Source code for chatlearn.models.base_module

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""base module"""

from collections import defaultdict
from functools import reduce
from itertools import cycle
import math
import operator
import os
import torch

import ray
import ray.util.collective as col
from ray.util.collective.collective_group.base_collective_group import BaseGroup
from ray.util.collective.collective_group.nccl_collective_group import NCCLGroup
from torch.utils.data import DataLoader
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from chatlearn.data.sampler import SingleDataSampler, EpisodeDataSampler
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.utils import future
from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger
from chatlearn.utils.timer import Timers
from chatlearn.utils.utils import get_host_addr
from chatlearn.launcher import dlc_utils


class BaseModule:
    """BaseModule is the base class for Base models.

    Args
    ----
    name : str
        model name
    """

    def __init__(self, name, args=None, replica_id=0):
        self.name = name
        if args is None:
            global_args = get_args()
        else:
            global_args = args
            set_global_variables(args)
        self.global_args = global_args
        args = global_args.models[name]
        self.total_gpu = args.num_gpu
        self.total_cpu = args.num_cpu
        self.gpu_per_process = args.gpu_per_process
        self.trainable = args.trainable
        self._runtime_args = self.global_args.runtime_args
        self._module_args = args
        self.replica_id = replica_id
        self.config_dir = args.config_dir
        self._is_colocate = False

        if self.total_gpu > 0:
            self._num_gpu_per_replica = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.zero_size
            assert self._num_gpu_per_replica <= self.total_gpu
            assert self.total_gpu % self._num_gpu_per_replica == 0
            if not self.trainable:
                self._num_replica = args.num_gpu // self._num_gpu_per_replica
            else:
                # For trainable models, perform the DP inside DistActor
                self._num_replica = 1
                self._num_gpu_per_replica = self.total_gpu
        else:
            self._num_gpu_per_replica = 0
            self._num_replica = args.num_replica

        assert self._num_replica >= 1
        self._param_ranks = None
        self._named_parameters = None
        self._param_to_name = None
        self._parameters = None
        self._coalesced_parameters = None
        self.error_signal = None
        self._rank = None
        self._world_size = None
        self._group_names = []
        self._dataloader = None
        self._eval_dataloader = None
        self._kl_coef = None
        self._padding_config = {}
        self._storage = None
        self._timers = None
        self._data_iter = None
        self._eval_data_iter = None
        self.call_funcs = []
        self.trainable_funcs = []
        self._data_ckpt_manager = None
        self._peak_memory = 0
        self._parameters_to_sync = defaultdict(list)
        self._concat_params_dict = None
        self._to_fix_act_ordering_dict = None
        self._to_fix_qkv_ordering_dict = None
        self._to_fix_qkv_ordering_func = None
        # current compute iteration
        self._iteration = 0
        self._train_iteration = 0
        self.enable_lora = self._module_args.lora.enable_lora
        self._finalized = False
        self._resume_training = False
        self._address = dlc_utils.get_addr() if dlc_utils.in_dlc_env() else get_host_addr()
        self._is_master_node = os.environ.get("RANK", '0') == '0'
        self._logger = setup_logger(model_name=self.name, ip_addr=self._address)
        self._dummy_output = None
        self._dummy_inputs = []
        # parameter sync from src_model
        self._src_parameter_model = None
        self.profiler = None

    @property
    def is_colocate(self):
        return self._is_colocate

    def set_colocate(self, flag):
        self._is_colocate = flag

    def finalize(self):
        """
        finalize the class, any change from user after finalize will not work.

        :meta private:
        """
        self._finalized = True

    def _assert_not_finalized(self):
        """
        :meta private:
        """
        assert not self._finalized, f"{self} is finalized, any change to the class should happen before finalize."

    @property
    def runtime_args(self):
        """
        Return the arguments related to alignment training,
        the settings that are specified under the "runtime" section of the YAML configuration file.
        """
        return self._runtime_args

    @property
    def model_args(self):
        """
        Return model arguments, such as those related to Megatron,
        should be specified in a separate configuration yaml file for the model being used.
        """
        return self._module_args.args_dict

    @property
    def module_args(self):
        """
        Return module arguments. module_args include `num_gpu`, `gpu_per_process`, `model_config_file`, etc.
        """
        return self._module_args

    @property
    def parameter_sync_frequency(self):
        return self.module_args.sync_frequency

    def set_env(self, args):
        """
        set system env, private

        :meta private:
        """

    def set_error_signal(self, error_signal):
        """
        signal for handling errors

        :meta private:
        """
        self.error_signal = error_signal

    def error(self, error_msg=None):
        """
        :meta private:
        """
        future.wait(self.error_signal.set.remote(error_msg))

    def init(self):
        """
        Init env.
        """

    def setup(self):
        """
        Create model / optimizer / opt_param_scheduler / etc.
        """

    @property
    def data_ckpt_manager(self):
        """
        :meta private:
        """
        if self.runtime_args.data_checkpoint_path is not None:
            assert self._data_ckpt_manager is not None
        return self._data_ckpt_manager

    def model_setup(self):
        """
        :meta private:
        """
        self.global_args.active_module_args = self._module_args
        if self.runtime_args.data_checkpoint_path is not None:
            self._data_ckpt_manager = CheckpointManager(self, self.runtime_args.data_checkpoint_path,
                                                        self.runtime_args.max_data_ckpt_nums,
                                                        self.runtime_args.load_data_checkpoint_iteration)
            if self.runtime_args.enable_resume_training:
                meta = self._data_ckpt_manager.resume()
                if meta:
                    self._resume_training = self.runtime_args.consumed_samples > 0
                    start_episode = meta["episode"] + 1
                    self._iteration = start_episode * math.ceil(self.runtime_args.sample_per_episode /
                                                                self._num_replica / self.module_args.generation_batch_size)
                    log_rank_0(
                        f"{self.name} resume training {self._resume_training}: set start iteration to {self._iteration}",
                        self._logger)
        self.setup()

    def forward_step(self, data, iteration):
        """
        Perform forward step for one batch.

        Args
        ----
        data : dict
            data for forward_step
        iteration : int
            local forward iteration

        Returns
        -------
        Dict
            A dict of results, where each key is a string and each value is a tensor or a list
            whose first dimension (or length) equals the batch size
        """

    def train_step(self, data, iteration):
        """
        Perform train_step for one batch, including a list of micro-batches.

        Args
        ----
        data : [Dict]
            A list of micro-batches for train_step, where each micro-batch is a dict
        iteration : int
            local train iteration
        """

    def eval_step(self, data):
        """
        Perform eval_step for one batch.

        Args
        ----
        data : Dict
            Data for eval_step.

        Returns
        -------
        Dict
            A dict of results, where each key is a string and each value is a tensor or a list
            whose first dimension (or length) equals the batch size
        """

    def save_checkpoint(self, iteration):
        """
        Save checkpoint given iteration.

        Args
        ----
        iteration : int
            Current training iteration
        """

    def save_data_checkpoint(self, replica_id, iteration, episode_id):
        """
        Save checkpoint for dataloader.

        :meta private:
        """
        if self.data_ckpt_manager is not None:
            consumed_samples = self.runtime_args.consumed_samples
            self.data_ckpt_manager.save_checkpoint(replica_id, iteration, episode_id, consumed_samples)

    def put(self, key, data):
        """
        Put the data to shared storage.

        Args
        ----
        key : str
            The key used to store the data.
        data
            The data to save.
        """
        self._storage.put.remote(key, data)

    def get(self, key):
        """
        Get data from shared storage using key.

        Args
        ----
        key : str
            The key used to retrieve the data.
        """
        ref = self._storage.get.remote(key)
        return future.get(ref)

    def validate(self):
        """
        :meta private:
        """

    def before_episode(self):
        """
        Operations before one episode.
        """

    def after_episode(self):
        """
        Operations after one episode.
        """

    def build_dataset(self, train_prompts, is_eval=False):
        """
        Build prompt dataset.

        Args
        ----
        train_prompts : [str]
            A list of prompt strings.

        Returns
        -------
        torch.utils.data.Dataset
            Dataset with user-defined collate_fn
        """

    def _build_dataloader(self, data, batch_size, dynamic_batch_size_flag=False, is_eval=False):
        """
        Build and set the dataloader for the model.

        Args:
            data: a list of strings
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`)

        :meta private:
        """
        dataset = self.build_dataset(data, is_eval)  # pylint: disable=assignment-from-no-return
        consumed_samples = 0
        if not is_eval:
            if self.data_ckpt_manager is not None:
                consumed_samples = self.runtime_args.consumed_samples
        collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
        dataloader = self.build_dataloader(dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_fn,
                                           is_eval=is_eval,
                                           dynamic_batch_size_flag=dynamic_batch_size_flag,
                                           consumed_samples=consumed_samples)

        if is_eval:
            self._eval_dataloader = dataloader
            self._eval_data_iter = iter(self._eval_dataloader)
        else:
            self._data_iter = iter(dataloader)
            self._data_iter = cycle(self._data_iter)
            self._dataloader = dataloader

    def build_dataloader(self, dataset, batch_size, collate_fn=None, is_eval=False,
                         dynamic_batch_size_flag=False, consumed_samples=0):
        """
        Build the dataloader for the model.

        Args:
            dataset: a torch.utils.data.Dataset object
            batch_size: how many samples per batch to load
            collate_fn: set when loading from a map-style dataset (default: `None`)
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`)
            consumed_samples: consumed samples

        :meta private:
        """
        log_rank_0(f"Creating DataLoader... consumed_samples: {consumed_samples}", self._logger)
        if is_eval:
            batch_sampler = SingleDataSampler(total_samples=len(dataset),
                                              consumed_samples=0,
                                              micro_batch_size=batch_size,
                                              data_parallel_rank=self.replica_id,
                                              data_parallel_size=self._num_replica,
                                              dynamic_batch_size_flag=dynamic_batch_size_flag)
        else:
            batch_sampler = EpisodeDataSampler(total_samples=len(dataset),
                                               consumed_samples=consumed_samples,
                                               micro_batch_size=batch_size,
                                               data_parallel_rank=self.replica_id,
                                               data_parallel_size=self._num_replica,
                                               sample_per_episode=self.runtime_args.sample_per_episode)
        return DataLoader(
            dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, pin_memory=True
        )

    def reset_eval_data_iter(self):
        """
        :meta private:
        """
        if self._eval_dataloader is not None:
            self._eval_data_iter = iter(self._eval_dataloader)

    def next_batch(self, is_eval=False):
        """
        :meta private:
        """
        if is_eval:
            return next(self._eval_data_iter)
        else:
            return next(self._data_iter)

    @property
    def num_replica(self):
        """
        :meta private:
        """
        return self._num_replica

    @property
    def num_gpu_per_replica(self):
        """
        :meta private:
        """
        return self._num_gpu_per_replica

    def setup_collective_group(self, rank, world_size, backend, group_name):
        """
        :meta private:
        """
        self._group_names.append(group_name)
        self._world_size = world_size
        col.init_collective_group(
            world_size, rank, backend=backend, group_name=group_name)

    def _destroy_collective_group(self, group_name):
        """
        :meta private:
        """
        from ray.util.collective.collective import _group_mgr  # pylint: disable=import-outside-toplevel
        rank = col.get_rank(group_name)
        saved_group: BaseGroup = _group_mgr.get_group_by_name(group_name)
        saved_comm_keys = []
        if isinstance(saved_group, (NCCLGroup, )):
            saved_comm_keys = list(saved_group._dev_comm_map.keys())

        try:
            col.destroy_collective_group(group_name)
        except Exception as e:
            self._logger.warning(f"_destroy_collective_group {group_name} {e}")

        if isinstance(saved_group, (NCCLGroup, )):
            for comm_key in saved_comm_keys:
                group_key = saved_group._generate_group_key(comm_key)
                from ray.util.collective.const import get_store_name  # pylint: disable=import-outside-toplevel
                store_name = get_store_name(group_key)
                try:
                    store = ray.get_actor(store_name)
                    if rank == 0:
                        raise RuntimeError(f'{store_name} in group {group_name} should be killed on rank {rank}.')
                    self._logger.debug(f'Kill {store_name} in group {group_name} on rank {rank}')
                    ray.kill(store)
                except ValueError:
                    ...

    def destroy_collective_group(self):
        for group_name in self._group_names:
            self._destroy_collective_group(group_name)
        self._group_names = []

    def get_local_param_ranks(self):
        """
        :meta private:
        """

    def fuse_lora_layer(self):
        """
        :meta private:
        """
        from chatlearn.models.megatron.lora import fuse_lora_layer  # pylint: disable=import-outside-toplevel
        fuse_lora_layer(self.model)

    def unfuse_lora_layer(self):
        """
        :meta private:
        """
        from chatlearn.models.megatron.lora import unfuse_lora_layer  # pylint: disable=import-outside-toplevel
        unfuse_lora_layer(self.model)

    @property
    def rank(self):
        """
        :meta private:
        """
        return self._rank

    def get_rank(self):
        """
        :meta private:
        """
        return self.rank

    def is_last_rank(self):
        """
        Is last rank.
        """
        return True

    @property
    def parameters(self):
        """
        :meta private:
        """
        if self._parameters is None:
            if not isinstance(self.model, list):
                model = [self.model]
            else:
                model = self.model
            self._parameters = []
            for partition in model:
                for item in partition.parameters():
                    self._parameters.append(item)
        return self._parameters

    @property
    def named_parameters(self):
        """
        :meta private:
        """
        if self._named_parameters is None:
            if not isinstance(self.model, list):
                model = [self.model]
            else:
                model = self.model
            self._named_parameters = {}
            for partition in model:
                for item in partition.named_parameters():
                    self._named_parameters[item[0]] = item[1]
        return self._named_parameters

    @property
    def param_to_name(self):
        """
        :meta private:
        """
        if self._param_to_name is None:
            if not isinstance(self.model, list):
                model = [self.model]
            else:
                model = self.model
            self._param_to_name = {}
            for partition in model:
                for item in partition.named_parameters():
                    self._param_to_name[item[1]] = item[0]
        return self._param_to_name

    @property
    def concat_params_dict(self):
        return self._concat_params_dict

    def get_concat_params_dict(self):
        return self._concat_params_dict

    def set_concat_params_dict(self, _concat_params_dict):
        self._concat_params_dict = _concat_params_dict

    @property
    def to_fix_act_ordering_dict(self):
        return self._to_fix_act_ordering_dict

    def get_to_fix_act_ordering_dict(self):
        return self._to_fix_act_ordering_dict

    def set_to_fix_act_ordering_dict(self, _to_fix_act_ordering_dict):
        self._to_fix_act_ordering_dict = _to_fix_act_ordering_dict

    @property
    def to_fix_qkv_ordering_dict(self):
        return self._to_fix_qkv_ordering_dict

    def get_to_fix_qkv_ordering_dict(self):
        return self._to_fix_qkv_ordering_dict

    def set_to_fix_qkv_ordering_dict(self, _to_fix_qkv_ordering_dict):
        self._to_fix_qkv_ordering_dict = _to_fix_qkv_ordering_dict

    @property
    def to_fix_qkv_ordering_func(self):
        return self._to_fix_qkv_ordering_func

    def get_to_fix_qkv_ordering_func(self):
        return self._to_fix_qkv_ordering_func

    def set_to_fix_qkv_ordering_func(self, _to_fix_qkv_ordering_func):
        self._to_fix_qkv_ordering_func = _to_fix_qkv_ordering_func

    def _set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):  # pylint: disable=too-many-nested-blocks
        if parameters_to_sync is None:
            parameters_to_sync = defaultdict(list)
        concat = []
        set_sync_param_flag = False

        if self.concat_params_dict is not None:
            if isinstance(self.concat_params_dict, dict):
                assert "modules" in self.concat_params_dict
                assert "dim" in self.concat_params_dict
                assert isinstance(self.concat_params_dict["modules"], list)
                concat_modules_list = self.concat_params_dict["modules"]
                concat_dim = self.concat_params_dict["dim"]
            else:
                raise RuntimeError(
                    f"Expect concat_params_dict in {self} to be a dict or None, while {self.concat_params_dict}.")

        if self.to_fix_act_ordering_dict is not None:
            if isinstance(self.to_fix_act_ordering_dict, dict):
                assert "modules" in self.to_fix_act_ordering_dict
                assert "dim" in self.to_fix_act_ordering_dict
                assert isinstance(self.to_fix_act_ordering_dict["modules"], list)
                to_fix_act_ordering_list = self.to_fix_act_ordering_dict["modules"]
                fix_dim = self.to_fix_act_ordering_dict["dim"]
            else:
                raise RuntimeError(
                    f"Expect to_fix_act_ordering_dict in {self} to be a dict or None, while {self.to_fix_act_ordering_dict}.")

        if self.to_fix_qkv_ordering_dict is not None:
            if isinstance(self.to_fix_qkv_ordering_dict, dict):
                assert "modules" in self.to_fix_qkv_ordering_dict
                assert "layer_re" in self.to_fix_qkv_ordering_dict
                assert isinstance(self.to_fix_qkv_ordering_dict["modules"], list)
                to_fix_modules_list = self.to_fix_qkv_ordering_dict["modules"]
                layer_re = self.to_fix_qkv_ordering_dict["layer_re"]
            else:
                raise RuntimeError(
                    f"Expect to_fix_qkv_ordering_dict in {self} to be a dict or None, while {self.to_fix_qkv_ordering_dict}.")

        # Collect parameters for this pipeline stage, applying any configured
        # concatenation / activation-ordering / QKV-ordering fixes before registering them for sync.
        for name in trainable_param_names:
            if self.concat_params_dict is None and self.to_fix_act_ordering_dict is None:
                set_sync_param_flag = True
                _params_to_sync = self.named_parameters[name]
            else:
                need_concat_or_fix = False
                if self.concat_params_dict is not None:
                    if any([ele in name for ele in concat_modules_list]):  # pylint: disable=use-a-generator
                        concat.append(self.named_parameters[name])
                        need_concat_or_fix = True
                        if len(concat) == len(concat_modules_list):
                            set_sync_param_flag = True
                            _params_to_sync = torch.cat(concat, dim=concat_dim)

                if self.to_fix_act_ordering_dict is not None:
                    if any([ele in name for ele in to_fix_act_ordering_list]):  # pylint: disable=use-a-generator
                        val = self.named_parameters[name]
                        offset = val.shape[0] // 2
                        w1 = val[:offset, :]
                        w2 = val[offset:, :]
                        need_concat_or_fix = True
                        set_sync_param_flag = True
                        _params_to_sync = torch.cat([w2, w1], dim=fix_dim)

                if not need_concat_or_fix:
                    set_sync_param_flag = True
                    _params_to_sync = self.named_parameters[name]

            if not set_sync_param_flag:
                continue
            if self.to_fix_qkv_ordering_dict is not None:
                from chatlearn.utils.vllm_utils import split_attn_state  # pylint: disable=import-outside-toplevel
                m = layer_re.match(name)
                if m is not None:
                    op_name = m.group(2)
                    if op_name in to_fix_modules_list:
                        checkpoint_version = 3.0
                        tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
                        heads = self.module_args.args_dict["num_attention_heads"] // tp_size
                        hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]
                        if self._to_fix_qkv_ordering_func is split_attn_state:
                            _num_query_groups = self.module_args.args_dict["num_query_groups"] // tp_size \
                                if self.module_args.args_dict["group_query_attention"] else heads
                            _params_to_sync = self._to_fix_qkv_ordering_func(
                                _params_to_sync, heads, _num_query_groups, hidden_size_per_head,
                                self.module_args.args_dict["hidden_size"])
                        else:
                            input_shape = _params_to_sync.size()
                            shape = (heads, hidden_size_per_head, 3) + input_shape[1:]
                            division = reduce(operator.mul, shape, 1)
                            num_elements = _params_to_sync.numel()
                            if num_elements == division:
                                # models with GQA do not need to fix qkv ordering.
                                weight_or_bias = m.group(3)
                                _params_to_sync = self._to_fix_qkv_ordering_func(
                                    _params_to_sync, checkpoint_version, 3, heads, hidden_size_per_head
                                )
                                if weight_or_bias == "weight":
                                    _params_to_sync = _params_to_sync.contiguous()
            concat = []
            set_sync_param_flag = False
            parameters_to_sync[pipe_stage].append((name, _params_to_sync))
        return parameters_to_sync

    def set_sync_parameters(self, trainable_param_names, pipe_stage=0):
        """
        :meta private:
        """
        if pipe_stage not in self._parameters_to_sync or len(self._parameters_to_sync[pipe_stage]) == 0:  # pylint: disable=too-many-nested-blocks
            self._set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_sync)

    def reset_sync_parameters(self, trainable_param_names, pipe_stage=0):
        self._parameters_to_sync[pipe_stage] = []
        self._set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_sync)

    def get_parameter_names(self, requires_grad=True):
        """
        :meta private:
        """
        param_to_name = self.param_to_name
        if requires_grad:
            return [param_to_name[param] for param in self.parameters if param.requires_grad]
        else:
            return [param_to_name[param] for param in self.parameters]

    def get_parameter(self, name):
        """
        :meta private:
        """
        if name not in self.named_parameters:
            raise Exception(f"parameter {name} does not exist")
        return self.named_parameters[name]

    def get_parameter_to_sync(self, name, pipe_stage):
        for name0, param in self._parameters_to_sync[pipe_stage]:
            if name0 == name:
                return param.cpu()

    def exist_parameter(self, name):
        """
        :meta private:
        """
        return name in self.named_parameters

    def parameter_shape(self, name):
        """
        :meta private:
        """
        return self.get_parameter(name).shape

    def send_recv_parameter(self, name, rank, group_name, func, pipe_stage=0):
        """
        :meta private:
        """
        if self.runtime_args.coalesce_param:
            assert name is None
            tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
            dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
            debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}",
                         self._logger)
            for bucket in dense_buckets:
                tensor_changed = func is col.recv
                coalesced_comm_dense(bucket, func, extra_args=(rank, group_name), tensor_changed=tensor_changed)
            for param in sparse_bucket:
                func(param, rank, group_name)
        else:
            tensor = self.get_parameter(name)
            func(tensor, rank, group_name)

    def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        assert len(tensors) > 0
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}",
                     self._logger)
        tensor_changed = rank != src_rank
        for bucket in dense_buckets:
            coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)
        for param in sparse_bucket:
            col.broadcast(param, src_rank, group_name)

    def send_parameter(self, name, dst_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        self.send_recv_parameter(name, dst_rank, group_name, col.send, pipe_stage)

    def recv_parameter(self, name, src_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        self.send_recv_parameter(name, src_rank, group_name, col.recv, pipe_stage)

    def ray_put_parameter(self, name, group_name, pipe_stage=0):
        """
        :meta private:
        """
        name2ref = {}
        if self.runtime_args.coalesce_param:
            assert name is None
            tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
            dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
            debug_rank_0(f"{self.name} Put dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}",
                         self._logger)
            for bucket_id, bucket in enumerate(dense_buckets):
                flat_tensors = _flatten_dense_tensors(bucket)
                flat_tensors_ref = ray.put(flat_tensors)
                name2ref[group_name + ":dense_bucket_" + str(bucket_id)] = flat_tensors_ref
            for param_id, param in enumerate(sparse_bucket):
                param_ref = ray.put(param)
                name2ref[group_name + ":sparse_bucket_" + str(param_id)] = param_ref
        else:
            tensor = self.get_parameter(name)
            tensor_ref = ray.put(tensor)
            name2ref[group_name + ":" + name] = tensor_ref
        return name2ref

    def ray_get_parameter(self, name, group_name, name2ref, pipe_stage=0):
        """
        :meta private:
        """
        if self.runtime_args.coalesce_param:
            assert name is None
            tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
            dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
            debug_rank_0(f"{self.name} Get dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}",
                         self._logger)
            for bucket_id, bucket in enumerate(dense_buckets):
                put_ref = name2ref[group_name + ":dense_bucket_" + str(bucket_id)]
                flat_tensors = ray.get(put_ref)
                for tensor, synced in zip(
                        bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                    tensor.copy_(synced)
            for param_id, param in enumerate(sparse_bucket):
                put_ref = name2ref[group_name + ":sparse_bucket_" + str(param_id)]
                param.copy_(ray.get(put_ref))
        else:
            tensor = self.get_parameter(name)
            put_ref = name2ref[group_name + ":" + name]
            tensor.copy_(ray.get(put_ref))

    def pipeline_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.pipeline_model_parallel_size

    def tensor_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.tensor_model_parallel_size

    def num_layers(self):
        """
        :meta private:
        """

    def set_storage(self, storage):
        """
        :meta private:
        """
        self._storage = storage

    def timers(self, name):
        """
        :meta private:
        """
        if self._timers is None:
            self._timers = Timers()
        return self._timers(name)

    def timer_summary(self, e2e_cost=None):
        """
        :meta private:
        """
        if self._timers:
            return self._timers.log(e2e_cost=e2e_cost)

    def add_padding_config(self, key, padding_value=0.0, padding_type="right"):
        """
        Add special padding config for a certain value.

        Args
        ----
        key : str
            The key for data to be padded.
        padding_value : float
            Padding value, default is 0.
        padding_type : str
            Default right, can be right/left.
        """
        self._padding_config[key] = {"padding_value": padding_value, "padding_type": padding_type}

    def padding_config(self):
        """
        :meta private:
        """
        return self._padding_config

    def peak_memory(self):
        """
        :meta private:
        """
        return 0.0

    @property
    def resume_training(self):
        """
        resume training from last checkpoint.
        """
        return self._resume_training

    def get_address(self):
        """
        Get node address.

        :meta private:
        """
        return self._address

    def is_master_node(self):
        """
        Whether this node is master node.

        :meta private:
        """
        return self._is_master_node

    def set_src_parameter_model(self, src_model):
        """
        Set the src_model that syncs parameters to the current model.

        :meta private:
        """
        self._src_parameter_model = src_model

    @property
    def src_parameter_model(self):
        """
        The src_model that syncs parameters to the current model.
        """
        return self._src_parameter_model

    def offload_optimizer_states(self):
        """
        offload optimizer states
        """

    def onload_optimizer_states(self):
        """
        onload optimizer states
        """

    def offload_main_weights(self):
        """
        offload main weights
        """

    def onload_main_weights(self):
        """
        onload main weights
        """

    def offload_weights(self):
        """
        offload weights
        """

    def onload_weights(self):
        """
        onload weights
        """

    def free_grad_buffers(self):
        """
        free grad buffers and related tensors
        """

    def build_grad_buffers(self):
        """
        build grad buffers and related tensors
        """

    def onload(self):
        pass

    def offload(self):
        pass

    @property
    def world_size(self):
        pass

    @property
    def data_parallel_size(self):
        """
        data parallel size

        :meta private:
        """

    @property
    def data_parallel_rank(self):
        """
        data parallel rank

        :meta private:
        """

    def empty_cache(self):
        """
        :meta private:
        """

    def get_data_parallel_rank(self):
        return self.data_parallel_rank

    def get_data_parallel_size(self):
        return self.data_parallel_size
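
The sketch below is a minimal, hypothetical usage example and is not part of base_module.py. It illustrates the extension points documented above (setup, build_dataset, forward_step); the names MyPolicyInference and PromptDataset are invented for illustration, and the model construction is left as a placeholder.

# Hypothetical usage sketch (not part of base_module.py).
from torch.utils.data import Dataset

from chatlearn.models.base_module import BaseModule


class PromptDataset(Dataset):
    """Wraps a list of prompt strings; collate_fn batches them into a dict."""

    def __init__(self, prompts):
        self.prompts = prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return {"prompt": self.prompts[idx]}

    def collate_fn(self, samples):
        # Values keep the batch size as their first dimension / length,
        # matching the contract described in forward_step.
        return {"prompt": [s["prompt"] for s in samples]}


class MyPolicyInference(BaseModule):
    """A minimal non-trainable module that only implements the forward path."""

    def setup(self):
        # Create the model here; model_setup() invokes this hook.
        self.model = ...  # placeholder, replace with a real model

    def build_dataset(self, train_prompts, is_eval=False):
        return PromptDataset(train_prompts)

    def forward_step(self, data, iteration):
        # Return a dict whose values' first dimension equals the batch size.
        return {"prompt": data["prompt"]}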