RLHF Module¶
- class chatlearn.models.base_module.BaseModule(name, args=None, replica_id=0)[source]¶
BaseModule is the base class for all models.
- Parameters:
name (str) – model name
- property is_colocate¶
- property runtime_args¶
Return the arguments related to alignment training, i.e. the settings specified under the “runtime” section of the YAML configuration file.
- property model_args¶
Return model arguments, such as those related to Megatron. These should be specified in a separate YAML configuration file for the model being used.
- property module_args¶
Return module arguments. module_args include num_gpu, gpu_per_process, model_config_file, etc.
- property parameter_sync_frequency¶
- forward_step(data, iteration)[source]¶
Perform a forward step for one batch.
- Parameters:
data (dict) – data for forward_step
iteration (int) – local forward iteration
- Returns:
A dict of results, where each key is a string and each value is a tensor or a list. The first dimension of the tensor, or the length of the list, equals the batch size.
- Return type:
Dict
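The return contract described above can be sketched as follows. This is a minimal illustration, not ChatLearn code: `BaseModuleStub` is a stand-in defined here so the snippet runs without ChatLearn installed (a real module would derive from `chatlearn.models.base_module.BaseModule`), the `LengthModel` class and the `"prompt"` and `"length"` keys are hypothetical, and plain lists are used in place of tensors.

```python
class BaseModuleStub:
    """Stand-in for chatlearn's BaseModule (assumption, for illustration only)."""

    def __init__(self, name):
        self.name = name


class LengthModel(BaseModuleStub):
    """Hypothetical module that reports the length of each prompt."""

    def forward_step(self, data, iteration):
        prompts = data["prompt"]  # "prompt" is an illustrative key
        # Every value in the result has the batch size as its length.
        return {"length": [len(p) for p in prompts]}


model = LengthModel("length_model")
batch = {"prompt": ["hello", "hi", "hey there"]}
result = model.forward_step(batch, iteration=0)
print(result)
```

Each list in `result` has one entry per sample in `batch`, which is the property the documented return value requires.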
- train_step(data, iteration)[source]¶
Perform a training step for one batch, which consists of a list of micro-batches.
- Parameters:
data ([Dict]) – A list of micro-batches for train_step; each micro-batch is a dict.
iteration (int) – local train iteration
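The shape of the `data` argument can be sketched as below. This is a hedged illustration only: `BaseModuleStub` stands in for `BaseModule` so the snippet runs on its own, `MeanLossModel` and the `"loss"` key are hypothetical, and averaging precomputed losses replaces the forward/backward pass and optimizer update that a real `train_step` would perform.

```python
class BaseModuleStub:
    """Stand-in for chatlearn's BaseModule (assumption, for illustration only)."""

    def __init__(self, name):
        self.name = name


class MeanLossModel(BaseModuleStub):
    """Hypothetical module that averages a per-sample loss over micro-batches."""

    def __init__(self, name):
        super().__init__(name)
        self.history = []  # (iteration, mean loss) pairs, kept for inspection

    def train_step(self, data, iteration):
        total, count = 0.0, 0
        for micro_batch in data:  # data is a list of dicts
            losses = micro_batch["loss"]  # "loss" is an illustrative key
            total += sum(losses)
            count += len(losses)
        self.history.append((iteration, total / count))


trainer = MeanLossModel("policy")
micro_batches = [{"loss": [1.0, 3.0]}, {"loss": [2.0, 2.0]}]
trainer.train_step(micro_batches, iteration=0)
print(trainer.history)
```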
- eval_step(data)[source]¶
Perform an evaluation step for one batch.
- Parameters:
data (Dict) – Data for eval_step.
- Returns:
A dict of results, where each key is a string and each value is a tensor or a list. The first dimension of the tensor, or the length of the list, equals the batch size.
- Return type:
Dict
- save_checkpoint(iteration)[source]¶
Save a checkpoint at the given iteration.
- Parameters:
iteration (int) – Current training iteration
- put(key, data)[source]¶
Put data into shared storage.
- Parameters:
key (str) – The key under which the data is stored.
data – Data to save.
- build_dataset(train_prompts, is_eval=False)[source]¶
Build the prompt dataset.
- Parameters:
train_prompts ([str]) – A list of prompt strings.
- Returns:
Dataset with user-defined collate_fn
- Return type:
torch.utils.data.Dataset
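A dataset with a user-defined `collate_fn` might look like the sketch below. The `PromptDataset` class and the `"prompt"` key are hypothetical. To keep the snippet runnable without PyTorch, the class does not inherit from `torch.utils.data.Dataset`, although a real implementation would, since that is the documented return type. The standalone `build_dataset` function mirrors the documented signature but is not the library's implementation.

```python
class PromptDataset:
    """Map-style dataset of prompts (hypothetical helper).

    In real code this would subclass torch.utils.data.Dataset; it is kept
    dependency-free here so the sketch runs without PyTorch.
    """

    def __init__(self, prompts):
        self.prompts = list(prompts)

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx]}

    def collate_fn(self, samples):
        # Merge a list of per-sample dicts into one dict of lists.
        return {"prompt": [s["prompt"] for s in samples]}


def build_dataset(train_prompts, is_eval=False):
    # is_eval mirrors the documented signature; it is unused in this sketch.
    return PromptDataset(train_prompts)


dataset = build_dataset(["What is RLHF?", "Explain PPO."])
collated = dataset.collate_fn([dataset[0], dataset[1]])
print(len(dataset), collated)
```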
- property concat_params_dict¶
- property to_fix_act_ordering_dict¶
- property to_fix_qkv_ordering_dict¶
- property to_fix_qkv_ordering_func¶
- add_padding_config(key, padding_value=0.0, padding_type='right')[source]¶
Add a special padding config for the data stored under a given key.
- Parameters:
key (str) – The key for data to be padded.
padding_value (float) – Padding value, default is 0.
padding_type (str) – Padding side, either 'right' or 'left'. Default is 'right'.
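The effect of `padding_value` and `padding_type` can be illustrated with a small standalone helper. `pad_sequence` is a hypothetical function written for this example and is not part of ChatLearn; it only shows what right and left padding mean for a single sequence, under the assumption that padding extends shorter sequences to a common length.

```python
def pad_sequence(seq, target_len, padding_value=0.0, padding_type="right"):
    """Pad seq to target_len on the requested side (illustrative helper)."""
    if padding_type not in ("right", "left"):
        raise ValueError("padding_type must be 'right' or 'left'")
    fill = [padding_value] * max(target_len - len(seq), 0)
    # Right padding appends the fill; left padding prepends it.
    return list(seq) + fill if padding_type == "right" else fill + list(seq)


right = pad_sequence([1.0, 2.0], 4)
left = pad_sequence([1.0, 2.0], 4, padding_value=-1.0, padding_type="left")
print(right, left)
```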
- property resume_training¶
Resume training from the last checkpoint.
- set_src_parameter_model(src_model)[source]¶
Set the source model that syncs its parameters to the current model.
- property src_parameter_model¶
The source model that syncs its parameters to the current model.
- property world_size¶
- class chatlearn.models.torch_module.TorchModule(*args, **kwargs)[source]¶
Bases: BaseModule
TorchModule is the class for Alignment Torch models.
- Parameters:
name (str) – model name
- property world_size¶
- class chatlearn.models.megatron_module.MegatronModule(*args, **kwargs)[source]¶
Bases: TorchModule
MegatronModule is the class for Alignment Megatron models.
- Parameters:
name (str) – model name