RLHF Module

class chatlearn.models.base_module.BaseModule(name, args=None, replica_id=0)[source]

BaseModule is the base class for RLHF models.

Parameters:

name (str) – model name

property is_colocate
set_colocate(flag)[source]
property runtime_args

Return the arguments related to alignment training, i.e., the settings specified under the “runtime” section of the YAML configuration file.

property model_args

Return model arguments, such as those related to Megatron; these should be specified in a separate configuration YAML file for the model being used.

property module_args

Return module arguments. module_args include num_gpu, gpu_per_process, model_config_file, etc.
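As a hedged illustration of how these three properties map onto the configuration file: a sketch of the YAML layout, in which the “runtime” section name and the keys num_gpu, gpu_per_process, and model_config_file come from the descriptions above, while the “models”/“policy” section names and the example values are assumptions, not ChatLearn’s actual schema.

```yaml
runtime:                 # exposed via module.runtime_args
  num_episode: 100       # hypothetical runtime setting
models:
  policy:                # exposed via module.module_args
    num_gpu: 8
    gpu_per_process: 1
    model_config_file: policy.yaml   # its contents back module.model_args
```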

property parameter_sync_frequency
init()[source]

Initialize the environment.

setup()[source]

Create the model, optimizer, opt_param_scheduler, etc.

forward_step(data, iteration)[source]

Perform forward step for one batch.

Parameters:
  • data (dict) – data for forward_step

  • iteration (int) – local forward iteration

Returns:

A dict of results, where each key is a string and each value is a tensor or a list; the first dimension of the tensor, or the length of the list, equals the batch size.

Return type:

Dict
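The return contract above can be illustrated with a framework-free check; in this sketch plain lists stand in for tensors, and check_forward_output and the result keys are illustrative, not part of the ChatLearn API.

```python
def check_forward_output(results: dict, batch_size: int) -> None:
    """Validate the forward_step return contract: every key is a string
    and every value's first dimension (here, list length) equals the
    batch size."""
    for key, value in results.items():
        assert isinstance(key, str), f"key {key!r} is not a string"
        assert len(value) == batch_size, (
            f"{key}: got length {len(value)}, expected {batch_size}")

# A hypothetical result for a batch of 4 prompts.
results = {"logits": [[0.1], [0.2], [0.3], [0.4]],
           "response_ids": [[1], [2], [3], [4]]}
check_forward_output(results, batch_size=4)  # passes silently
```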

train_step(data, iteration)[source]

Perform a train step for one batch, which consists of a list of micro-batches.

Parameters:
  • data (List[Dict]) – A list of micro-batches for train_step; each micro-batch is a dict.

  • iteration (int) – local train iteration
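The list-of-micro-batches shape that train_step expects can be sketched with a plain-Python helper; the field names below are hypothetical.

```python
def split_into_micro_batches(batch: dict, micro_batch_size: int) -> list:
    """Split one global batch (a dict of equal-length lists) into the
    list-of-dicts shape that train_step consumes."""
    total = len(next(iter(batch.values())))
    return [{k: v[i:i + micro_batch_size] for k, v in batch.items()}
            for i in range(0, total, micro_batch_size)]

batch = {"input_ids": [[1], [2], [3], [4]], "rewards": [0.1, 0.2, 0.3, 0.4]}
micro_batches = split_into_micro_batches(batch, micro_batch_size=2)
# → a list of 2 dicts, each holding 2 samples
```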

eval_step(data)[source]

Perform eval_step for one batch.

Parameters:

data (Dict) – Data for eval_step.

Returns:

A dict of results, where each key is a string and each value is a tensor or a list; the first dimension of the tensor, or the length of the list, equals the batch size.

Return type:

Dict

save_checkpoint(iteration)[source]

Save checkpoint given iteration.

Parameters:

iteration (int) – Current training iteration

put(key, data)[source]

Put data into shared storage.

Parameters:
  • key (str) – The key under which to store the data.

  • data – The data to store.

get(key)[source]

Get data from shared storage using the given key.

Parameters:

key (str) – The key used to retrieve the data.
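A minimal, single-process sketch of these put/get semantics; the real storage is shared across processes, and the SharedStorage class and key name below are illustrative only.

```python
class SharedStorage:
    """A local stand-in for put/get: put stores data under a string
    key, get retrieves it by that key."""
    def __init__(self):
        self._store = {}

    def put(self, key: str, data):
        self._store[key] = data

    def get(self, key: str):
        return self._store[key]

storage = SharedStorage()
storage.put("ref_logprobs", [0.1, 0.2])   # e.g. producer model saves results
batch = storage.get("ref_logprobs")       # consumer model reads them back
# batch == [0.1, 0.2]
```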

before_episode()[source]

Operations before one episode.

after_episode()[source]

Operations after one episode.

build_dataset(train_prompts, is_eval=False)[source]

Build the prompt dataset.

Parameters:

train_prompts (List[str]) – A list of prompt strings.

Returns:

A dataset with a user-defined collate_fn.

Return type:

torch.utils.data.Dataset
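The returned object follows the torch.utils.data.Dataset protocol (`__len__`/`__getitem__` plus a collate_fn attribute). A pure-Python sketch that mirrors that shape without importing torch; the class and field names here are hypothetical, not ChatLearn's.

```python
class PromptDataset:
    """A map-style dataset sketch: indexable samples plus a
    user-defined collate_fn that merges samples into one batch dict."""
    def __init__(self, train_prompts):
        self.prompts = train_prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx]}

    @staticmethod
    def collate_fn(samples):
        # Gather the per-sample field into one batched list.
        return {"prompt": [s["prompt"] for s in samples]}

dataset = PromptDataset(["Hello", "How are you?"])
batch = dataset.collate_fn([dataset[0], dataset[1]])
# batch == {"prompt": ["Hello", "How are you?"]}
```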

destroy_collective_group()[source]
is_last_rank()[source]

Return whether this is the last rank.

property concat_params_dict
get_concat_params_dict()[source]
set_concat_params_dict(_concat_params_dict)[source]
property to_fix_act_ordering_dict
get_to_fix_act_ordering_dict()[source]
set_to_fix_act_ordering_dict(_to_fix_act_ordering_dict)[source]
property to_fix_qkv_ordering_dict
get_to_fix_qkv_ordering_dict()[source]
set_to_fix_qkv_ordering_dict(_to_fix_qkv_ordering_dict)[source]
property to_fix_qkv_ordering_func
get_to_fix_qkv_ordering_func()[source]
set_to_fix_qkv_ordering_func(_to_fix_qkv_ordering_func)[source]
reset_sync_parameters(trainable_param_names, pipe_stage=0)[source]
get_parameter_to_sync(name, pipe_stage)[source]
add_padding_config(key, padding_value=0.0, padding_type='right')[source]

Add a special padding config for a certain value.

Parameters:
  • key (str) – The key of the data to be padded.

  • padding_value (float) – The padding value; default is 0.0.

  • padding_type (str) – The padding side, either "right" or "left"; default is "right".
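What the padding config controls can be sketched as a standalone helper; this illustrates right/left padding semantics only and is not ChatLearn's implementation.

```python
def pad(sequences, padding_value=0.0, padding_type="right"):
    """Pad variable-length lists to one common length, filling with
    padding_value on the side selected by padding_type."""
    max_len = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        fill = [padding_value] * (max_len - len(seq))
        padded.append(seq + fill if padding_type == "right" else fill + seq)
    return padded

pad([[1, 2, 3], [4]], padding_value=0.0, padding_type="left")
# → [[1, 2, 3], [0.0, 0.0, 4]]
```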

property resume_training

Resume training from the last checkpoint.

is_master_node()[source]

Whether this node is the master node. :meta private:

set_src_parameter_model(src_model)[source]

Set the src_model that syncs parameters to the current model. :meta private:

property src_parameter_model

The src_model that syncs parameters to the current model.

offload_optimizer_states()[source]

Offload optimizer states.

onload_optimizer_states()[source]

Onload optimizer states.

offload_main_weights()[source]

Offload main weights.

onload_main_weights()[source]

Onload main weights.

offload_weights()[source]

Offload weights.

onload_weights()[source]

Onload weights.

free_grad_buffers()[source]

Free grad buffers and related tensors.

build_grad_buffers()[source]

Build grad buffers and related tensors.

onload()[source]
offload()[source]
property world_size
get_data_parallel_rank()[source]
get_data_parallel_size()[source]
class chatlearn.models.torch_module.TorchModule(*args, **kwargs)[source]

Bases: BaseModule

TorchModule is the class for Alignment Torch models.

Parameters:

name (str) – model name

is_last_rank()[source]

Return whether this is the last rank.

property world_size
onload(to_onload_weights=None, to_build_grad_buffers=None, to_onload_main_weights=None, to_onload_optimizer_states=None)[source]
offload(to_offload_weights=None, to_free_grad_buffers=None, to_offload_main_weights=None, to_offload_optimizer_states=None)[source]
class chatlearn.models.megatron_module.MegatronModule(*args, **kwargs)[source]

Bases: TorchModule

MegatronModule is the class for Alignment Megatron models.

Parameters:

name (str) – model name

add_extra_args(parser)[source]

Add extra arguments for Megatron.

Parameters:

parser (ArgumentParser) – The parser to which extra arguments are added.
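An override might look like the following sketch; the argument group and the --kl-coef flag are hypothetical examples, not actual ChatLearn or Megatron flags.

```python
import argparse

def add_extra_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register extra command-line arguments on the given parser and
    return it, mirroring the add_extra_args hook's shape."""
    group = parser.add_argument_group("rlhf")
    group.add_argument("--kl-coef", type=float, default=0.1,
                       help="KL penalty coefficient (hypothetical flag)")
    return parser

parser = add_extra_args(argparse.ArgumentParser())
args = parser.parse_args(["--kl-coef", "0.2"])
# args.kl_coef == 0.2
```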

megatron_model()[source]
offload_optimizer_states()[source]

Offload optimizer states.

onload_optimizer_states()[source]

Onload optimizer states.

offload_main_weights()[source]

Offload main weights.

onload_main_weights()[source]

Onload main weights.

offload_weights()[source]

Offload weights.

onload_weights()[source]

Onload weights.

free_grad_buffers()[source]

Free grad buffers and related tensors.

build_grad_buffers()[source]

Build grad buffers and related tensors.