Source code for chatlearn.models.torch_module

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Torch module"""

import gc
import os

import ray
import torch
import torch.distributed as dist

from chatlearn.utils.logger import log_rank_0, debug_rank_0
from chatlearn.utils.utils import get_full_proc_memory_info
from .base_module import BaseModule

class TorchModule(BaseModule):
    """TorchModule is the class for Alignment Torch models.

    Args
    ----
    name : str
        model name
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def model_setup(self):
        """
        :meta private:
        """
        super().model_setup()
        if self.runtime_args.profiler_dir is not None and self.replica_id == 0:
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1),
                profile_memory=False,
                record_shapes=False,
                with_stack=False,
                with_flops=False,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(self.runtime_args.profiler_dir)
            )
            self.profiler.start()

    def get_visible_gpus(self):
        """
        :meta private:
        """
        return ray.get_gpu_ids()

    def set_env(self, args):
        """
        :meta private:
        """
        for key in ['RANK', 'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'LOCAL_RANK']:
            assert key in args, f"{key} is not set for TorchModule"
            os.environ[key] = str(args[key])
        self._rank = int(os.environ['RANK'])
        return 1

    def get_dist_env(self):
        """
        :meta private:
        """
        envs = {}
        for key in ['RANK', 'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'LOCAL_RANK']:
            envs[key] = os.environ[key]
        return envs

    def peak_memory(self):
        """
        :meta private:
        """
        self._peak_memory = max(self._peak_memory, torch.cuda.max_memory_allocated() / (1024 ** 3))
        return self._peak_memory

    def empty_cache(self):
        """
        :meta private:
        """
        if not self.timers("empty_cache").started_:
            self.timers("empty_cache").start()
        peak_mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
        debug_rank_0(f"{self.name} replica: {self.replica_id}, before empty cache, peak mem: {peak_mem:.2f} GiB", self._logger)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        peak_mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
        debug_rank_0(f"{self.name} replica: {self.replica_id}, after empty cache, peak mem: {peak_mem:.2f} GiB", self._logger)
        self.timers("empty_cache").stop()

    def check_param_exists(self, names):
        """
        check if the given names exist in the current model

        :meta private:
        """
        not_exists = []
        for name in names:
            if not self.exist_parameter(name):
                not_exists.append(name)
        if not_exists:
            log_rank_0(f"parameters do not exist: {not_exists} in model {self.name}", self._logger)
            return False
        return True
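    # Illustrative sketch (not part of the module): the dict consumed by `set_env`
    # and returned by `get_dist_env` carries exactly the variables that
    # torch.distributed reads with its default ``env://`` init method. Assuming a
    # hypothetical single-node, 2-GPU setup and a `module` handle for rank 0:
    #
    #     dist_env = {
    #         'RANK': 0,
    #         'MASTER_ADDR': '127.0.0.1',
    #         'MASTER_PORT': 29500,
    #         'WORLD_SIZE': 2,
    #         'LOCAL_RANK': 0,
    #     }
    #     module.set_env(dist_env)                  # exports the variables as strings, records the rank
    #     dist.init_process_group(backend='nccl')   # picks them up via the env:// init method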
    def is_last_rank(self):
        """
        Return True if this process is the last rank.
        """
        if dist.is_initialized():
            return dist.get_rank() == (dist.get_world_size() - 1)
        return True
    @property
    def world_size(self):
        return dist.get_world_size()

    def _get_if_not_none(self, to_set, default):
        if not default:
            return False
        if to_set is not None:
            return to_set
        return default
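    # `_get_if_not_none` resolves a per-call override against the module-level
    # default: a disabled module_args flag always wins, otherwise an explicit
    # True/False override beats the default. Illustrative truth table,
    # (to_set, default) -> result:
    #     (None,  True)  -> True     # no override, fall back to module_args
    #     (False, True)  -> False    # explicit override disables for this call
    #     (True,  False) -> False    # feature off in module_args, override ignored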
    def onload(self, to_onload_weights=None, to_build_grad_buffers=None,
               to_onload_main_weights=None, to_onload_optimizer_states=None):
        if not self.is_colocate:
            return
        to_onload_weights = self._get_if_not_none(to_onload_weights, self.module_args.offload_weights)
        to_build_grad_buffers = self._get_if_not_none(to_build_grad_buffers, self.module_args.free_grad_buffers)
        to_onload_main_weights = self._get_if_not_none(to_onload_main_weights, self.module_args.offload_weights)
        to_onload_optimizer_states = self._get_if_not_none(to_onload_optimizer_states, self.module_args.offload_optimizer_states)
        if to_onload_weights or to_build_grad_buffers or to_onload_main_weights or to_onload_optimizer_states:
            log_rank_0(get_full_proc_memory_info('Before onload'), self._logger)
            torch.cuda.synchronize()

            timer = self.timers(f'{self.name}_free_memory')
            if not timer.started_:
                timer.start()
            torch.distributed.barrier()
            if to_onload_weights:
                self.onload_weights()
            if self.trainable:
                if to_build_grad_buffers:
                    self.build_grad_buffers()
                if to_onload_main_weights:
                    self.onload_main_weights()
                if to_onload_optimizer_states:
                    self.onload_optimizer_states()
            torch.distributed.barrier()

            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            gc.collect()
            timer.stop()
            log_rank_0(get_full_proc_memory_info('After onload'), self._logger)
    def offload(self, to_offload_weights=None, to_free_grad_buffers=None,
                to_offload_main_weights=None, to_offload_optimizer_states=None):
        # The first call to `offload_weights` and `offload_main_weights` has a higher peak memory.
        # So `free_grad_buffers` is called first to free memory, and `offload_weights` is called
        # afterward to make more space for `offload_main_weights`.
        if not self.is_colocate:
            return
        to_offload_weights = self._get_if_not_none(to_offload_weights, self.module_args.offload_weights)
        to_offload_main_weights = self._get_if_not_none(to_offload_main_weights, self.module_args.offload_weights)
        to_free_grad_buffers = self._get_if_not_none(to_free_grad_buffers, self.module_args.free_grad_buffers)
        to_offload_optimizer_states = self._get_if_not_none(to_offload_optimizer_states, self.module_args.offload_optimizer_states)
        if to_free_grad_buffers or to_offload_weights or to_offload_optimizer_states or to_offload_main_weights:
            log_rank_0(get_full_proc_memory_info('Before offload'), self._logger)
            torch.cuda.synchronize()

            timer = self.timers(f'{self.name}_free_memory')
            if not timer.started_:
                timer.start()
            torch.distributed.barrier()
            if self.trainable:
                if to_free_grad_buffers:
                    self.free_grad_buffers()
                if to_offload_main_weights:
                    self.offload_main_weights()
                if to_offload_optimizer_states:
                    self.offload_optimizer_states()
            if to_offload_weights:
                self.offload_weights()
            torch.distributed.barrier()

            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            gc.collect()
            timer.stop()
            log_rank_0(get_full_proc_memory_info('After offload'), self._logger)
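# Usage sketch (illustrative only): with colocated models, a driver typically
# offloads one module's GPU state before another module runs on the same devices,
# and onloads it again before the next use. The `policy_trainer` handle below is
# hypothetical; which states move is resolved from module_args as shown above.
#
#     policy_trainer.offload()   # free grad buffers, offload main weights, optimizer states, weights
#     ...                        # run the colocated inference module here
#     policy_trainer.onload()    # bring everything back before the next training step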