# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Megatron module"""
import inspect
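
# Megatron is an optional dependency: if the imports below fail, mpu stays None and
# MegatronModule reports the missing Megatron python path at construction time.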
try:
    from chatlearn.utils.megatron_import_helper import get_args
    from chatlearn.utils.megatron_import_helper import mpu
    from chatlearn.utils.megatron_import_helper import initialize_megatron
    from chatlearn.utils.megatron_import_helper import save_checkpoint_and_time
    from chatlearn.utils.megatron_import_helper import set_jit_fusion_options
    from chatlearn.utils.megatron_utils import initialize_megatron as chatlearn_initialize_megatron
    from chatlearn.utils.megatron_utils import build_pipeline_layer_name_mapping
    from chatlearn.models.megatron.memory_manager import create_trainer_memory_manager, InferenceMemoryManager
except ImportError:
    mpu = None

from .torch_module import TorchModule

# pylint: disable=import-outside-toplevel

class MegatronModule(TorchModule):
    """MegatronModule is the class for Alignment Megatron models.

    Args
    ----
    name : str
        model name
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if mpu is None:
            print("Cannot import Megatron, please set the Megatron python path first.")
        if not self.trainable:
            # inference only
            if self.model_args.get("micro_batch_size") != self.module_args.generation_batch_size:
                self._logger.info(f"{self.name} Overwrite micro_batch_size with generation_batch_size {self.module_args.generation_batch_size}")
                self.model_args["micro_batch_size"] = self.module_args.generation_batch_size
        else:
            if self.model_args.get("micro_batch_size") != self.runtime_args.train_micro_batch_size:
                self._logger.info(f"{self.name} Overwrite micro_batch_size with train_micro_batch_size {self.runtime_args.train_micro_batch_size}")
            if self.model_args.get("global_batch_size") != self.runtime_args.train_global_batch_size:
                self._logger.info(f"{self.name} Overwrite global_batch_size with train_global_batch_size {self.runtime_args.train_global_batch_size}")
            self.model_args["micro_batch_size"] = self.runtime_args.train_micro_batch_size
            self.model_args["global_batch_size"] = self.runtime_args.train_global_batch_size
        if not self.model_args.get("tensorboard_dir") and self.runtime_args.output_dir is not None:
            self.model_args['tensorboard_dir'] = f"{self.runtime_args.output_dir}/tensorboard"

    def init(self):
        """
        :meta private:
        """
if "args_dict" in inspect.getfullargspec(initialize_megatron).args:
initialize_func = initialize_megatron
else:
initialize_func = chatlearn_initialize_megatron
initialize_func(extra_args_provider=self.add_extra_args,
ignore_unknown_args=True,
args_dict=self.model_args)
if self.trainable:
# slow down if set jit fusion for inference model
set_jit_fusion_options()

    def model_setup(self):
        """
        :meta private:
        """
        super().model_setup()
        # TODO: we may need to let setup return model, optimizer and opt_param_scheduler
        if self.trainable:
            assert hasattr(self, "model")
            assert hasattr(self, "optimizer")
            assert hasattr(self, "opt_param_scheduler")
            if self.module_args.offload_weights or self.module_args.free_grad_buffers or self.module_args.offload_optimizer_states:
                self._memory_manager = create_trainer_memory_manager(
                    self.megatron_model(),
                    self.optimizer,
                    self.megatron_args.use_distributed_optimizer,
                    self.megatron_args.accumulate_allreduce_grads_in_fp32,
                    self.megatron_args.params_dtype,
                    self.runtime_args.bucket_size_mb_in_memory_manager,
                )
                self.offload()
        else:
            assert hasattr(self, "model")
            self.model.eval()
            if self.module_args.offload_weights:
                self._memory_manager = InferenceMemoryManager(
                    self.megatron_model(),
                    self.runtime_args.bucket_size_mb_in_memory_manager,
                )
                self.offload()

    @property
    def megatron_args(self):
        """
        :meta private:
        """
        return get_args()

    def pipeline_model_parallel_size(self):
        """
        get pipeline_model_parallel_size
        :meta private:
        """
        return self.megatron_args.pipeline_model_parallel_size

    def tensor_model_parallel_size(self):
        """
        get tensor_model_parallel_size
        :meta private:
        """
        return self.megatron_args.tensor_model_parallel_size

    @property
    def data_parallel_size(self):
        """
        :meta private:
        """
        return mpu.get_data_parallel_world_size()

    @property
    def data_parallel_rank(self):
        """
        :meta private:
        """
        return mpu.get_data_parallel_rank()

    def pipeline_parallel_rank(self):
        """
        :meta private:
        """
        return mpu.get_pipeline_model_parallel_rank()

    def tensor_parallel_rank(self):
        """
        :meta private:
        """
        return mpu.get_tensor_model_parallel_rank()

    def num_layers(self):
        """
        :meta private:
        """
        return self.megatron_args.num_layers

    def megatron_model(self):
        """Return the underlying Megatron model, unwrapping the single-element list if needed."""
        if isinstance(self.model, list):
            assert len(self.model) == 1
            model = self.model[0]
        else:
            model = self.model
        return model

    def build_pipeline_layer_name_mapping(self, num_target_pipe_stage, target_pipe_rank, requires_grad=True):
        """
        Build the layer name mapping from the src model to the tgt model.

        Args:
            num_target_pipe_stage: number of pipeline stages in the target model
            target_pipe_rank: pipeline rank of the target model
            requires_grad: whether the returned layers require grad, as we only need to sync parameters that have changed
        :meta private:
        """
        src_layers_per_stage = self.num_layers() // self.pipeline_model_parallel_size()
        dst_layers_per_stage = self.num_layers() // num_target_pipe_stage
        assert dst_layers_per_stage % src_layers_per_stage == 0, \
            "We assume the target model has no more pipeline stages than the src model, so each target stage covers a whole number of src stages"
        mapping_interval = dst_layers_per_stage // src_layers_per_stage
        src_rank = mpu.get_pipeline_model_parallel_rank()
        self._logger.debug(f"build mapping for rank {src_rank} =========")
        model = self.megatron_model()
        is_tgt_last_stage = target_pipe_rank == num_target_pipe_stage - 1 and target_pipe_rank != 0
        name_mapping = build_pipeline_layer_name_mapping(src_layers_per_stage, src_rank, mapping_interval,
                                                         is_tgt_last_stage, model, requires_grad)
        return name_mapping

    def get_local_param_ranks(self):
        """
        :meta private:
        """
        data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS)
        return data_parallel_global_ranks, mpu.get_data_parallel_rank()

    def save_checkpoint(self, iteration):
        """
        save checkpoint at `iteration`
        :param iteration: save iteration
        :meta private:
        """
        if self.enable_lora:
            self.fuse_lora_layer()
        save_checkpoint_and_time(iteration, self.model, self.optimizer,
                                 self.opt_param_scheduler)
        if self.enable_lora:
            self.unfuse_lora_layer()

    def offload_optimizer_states(self):
        """
        offload optimizer states
        """
        if self.module_args.offload_optimizer_states:
            self._memory_manager.offload_optimizer_states()

    def onload_optimizer_states(self):
        """
        onload optimizer states
        """
        if self.module_args.offload_optimizer_states:
            self._memory_manager.onload_optimizer_states()

    def offload_main_weights(self):
        """
        offload main weights
        """
        if self.module_args.offload_weights:
            self._memory_manager.offload_main_weights()

    def onload_main_weights(self):
        """
        onload main weights
        """
        if self.module_args.offload_weights:
            self._memory_manager.onload_main_weights()

    def offload_weights(self):
        """
        offload weights
        """
        if self.module_args.offload_weights:
            self._memory_manager.offload_weights()

    def onload_weights(self):
        """
        onload weights
        """
        if self.module_args.offload_weights:
            self._memory_manager.onload_weights()

    def free_grad_buffers(self):
        """
        free grad buffers and related tensors
        """
        if self.module_args.free_grad_buffers:
            self._memory_manager.free_grad_buffers()

    def build_grad_buffers(self):
        """
        build grad buffers and related tensors
        """
        if self.module_args.free_grad_buffers:
            self._memory_manager.build_grad_buffers()
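
# Note: subclasses must provide self.model (plus self.optimizer and self.opt_param_scheduler
# when trainable) by the time model_setup runs, typically from their setup implementation.
# A minimal, hypothetical sketch (names are illustrative only):
#
#     class PolicyTrainer(MegatronModule):
#         def setup(self):
#             # build the model/optimizer with the Megatron utilities of your choice
#             self.model, self.optimizer, self.opt_param_scheduler = build_model_and_optimizer(...)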