# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""arguments from command or yaml."""

import argparse
import ast
import os
from typing import List

import yaml

from chatlearn.utils.constant import LORA_LAYER, RAY_PG_STRATEGY, PARAM_SYNC_COMM_TYPE
from chatlearn.utils.logger import logger
from chatlearn.utils.utils import get_attributes


def get_path(fn, folder):
    if not fn.startswith("/") and not fn.startswith(folder):
        fn = os.path.join(folder, fn)
    assert os.path.exists(fn), f'{fn} does not exist'
    return fn


def convert_type(data):
    try:
        return ast.literal_eval(data)
    except Exception:
        return data
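
# Illustrative behavior: ast.literal_eval parses Python literals, so numeric and
# boolean strings round-trip to typed values while anything else stays a string:
#   convert_type("29500") -> 29500
#   convert_type("True")  -> True
#   convert_type("fixed") -> "fixed"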


def parse_value(value):
    if isinstance(value, dict):
        return {k: parse_value(v) for k, v in value.items()}

    if isinstance(value, str):
        if value.strip().startswith("${"):
            # ${env_name:default_value}
            placeholder = value.replace("${", "")[:-1]
            placeholder = placeholder.split(":")
            env_name = placeholder[0]
            if env_name in os.environ:
                value = convert_type(os.environ[env_name])
            else:
                if len(placeholder) > 1:
                    value = convert_type(placeholder[1])
                else:
                    logger.warning(f"cannot find value for {env_name}, set to None")
                    value = None
    return value
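
# Example of the ${env_name:default_value} syntax handled above; MASTER_PORT is
# a hypothetical variable name:
#   "${MASTER_PORT:29500}" -> the typed value of $MASTER_PORT if set, else 29500
#   "${MASTER_PORT}"       -> the env value if set, else None (with a warning)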


def update_dict(src, dst):
    # do not overwrite
    for k, v in src.items():
        if k not in dst:
            dst[k] = v
        else:
            if isinstance(v, dict) and isinstance(dst[k], dict):
                update_dict(v, dst[k])
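
# Merge semantics sketch: keys already in dst win, and nested dicts are merged
# recursively, e.g.
#   src = {"a": 1, "b": {"c": 2, "d": 3}}
#   dst = {"b": {"c": 9}}
#   update_dict(src, dst)   # dst == {"b": {"c": 9, "d": 3}, "a": 1}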


def parse_args_from_yaml(config_file, config_dir):
    with open(config_file, 'r', encoding='utf-8') as stream:
        config_vars = yaml.load(stream, Loader=yaml.SafeLoader)
        # empty yaml file
        if config_vars is None:
            return {}
        config_vars = {key: parse_value(value) for key, value in config_vars.items()}
        if 'includes' in config_vars:
            includes_vars = {}
        # iterate in reverse order, so that a later include overrides an earlier one
            for base in reversed(config_vars["includes"]):
                base_path = get_path(base, config_dir)
                base_config = parse_args_from_yaml(base_path, config_dir)
                update_dict(base_config, includes_vars)
            update_dict(includes_vars, config_vars)
        return config_vars
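
# Illustrative `includes` layout (file names are hypothetical): if main.yaml contains
#   includes:
#     - base.yaml
#   runtime:
#     num_episode: 200
# and base.yaml sets runtime.num_episode to 100, the including file wins (200).
# When several files are listed under `includes`, later entries override earlier ones.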


def parse_args():
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='ChatLearn Arguments',
                                     allow_abbrev=False)

    parser.add_argument("-c", "--config",
                        required=False,
                        help="where to load YAML configuration",
                        metavar="FILE")

    args, _ = parser.parse_known_args()

    if args.config:
        config_dir = os.path.dirname(args.config)
        args_yaml = parse_args_from_yaml(args.config, config_dir)
    else:
        config_dir = None
        args_yaml = None
    config = Config(args_yaml, config_dir)
    return config
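
# Typical invocation (the entry script name is hypothetical):
#   python train.py -c configs/rlhf.yaml
# Unknown flags are tolerated via parse_known_args, so the same command line can
# also carry arguments destined for other parsers.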


class BaseConfig:
    """Base class includes some common format functions."""

    def __init__(self):
        self._finalize = True

    def __str__(self):
        members = [attr for attr in dir(self) \
                   if not callable(getattr(self, attr)) and not attr.startswith("__")]
        ser_str = self.__class__.__name__ + " {\n"
        for key in members:
            if key.startswith('_'):
                continue
            attr = getattr(self, key)
            attr = '"{}"'.format(attr) if isinstance(attr, str) else attr
            ser_str += "    %s = %s,\n" % (key, attr)
        ser_str += "}"

        return ser_str

    def __repr__(self):
        return self.__str__()

    def validate(self):
        pass


class SubConfig(BaseConfig):
    """Sub Config"""
    _is_changed = False

    def __setattr__(self, name, value):
        if not name.startswith("_") and getattr(self, name) != value:
            self._is_changed = True
        super().__setattr__(name, value)

    def is_changed(self):
        return self._is_changed


class LoraConfig(SubConfig):
    """Config for LoRA."""

    #: enable LoRA, default False.
    enable_lora: bool = False
    #: Restricts LoRA conversion to a particular module scope. By default it is None,
    #: meaning there is no restriction and any module selected by "lora_layer" can be
    #: converted. If set to a specific value (e.g., "encoder"), only modules whose
    #: name scope contains "encoder" will be converted to LoRA.
    part_module_name: str = None
    #: The rank of the LoRA decomposition, i.e., the r dimension of the A/B matrices.
    lora_dim: int = 8
    #: Dropout ratio applied in the forward pass of the LoRA layer.
    #: By default, the dropout ratio is 0.0.
    lora_dropout: float = 0.0
    #: When adding the product of the LoRA A and B matrices to the original weight
    #: matrix, the update is scaled as "W = W + A * B * lora_scaling".
    #: By default, the scaling value is 1.0.
    lora_scaling: float = 1.0
    #: The layer class names involved in LoRA training, separated by commas.
    lora_layer: str = LORA_LAYER
    #: Enable LoRA training only in the ColumnParallelLinear layer of the MHA QKV module.
    column_only_qkv: bool = False

class BatchGenerationConfig(SubConfig):
    """Config for batch generation ranking and memory efficiency."""

    #: [optional] sort prompts by length each episode.
    ranking: bool = False
    #: [optional] min prompt length in the first stage of batch generation.
    min_prompt_length: int = 0

class ModelConfig(BaseConfig):
    """Config for model."""

    #: [legacy] number of GPUs used for one model, default 0.
    num_device: int = 0
    #: [required] number of GPUs used for one model, default 0, same as num_device
    num_gpu: int = 0
    #: [required] number of CPUs used for one model, default 0
    num_cpu: int = 0
    #: [optional] GPUs per process, e.g., for PyTorch DDP, Megatron, and DeepSpeed, `gpu_per_process` is set to 1
    gpu_per_process: int = None
    #: [optional] CPUs per process
    cpu_per_process: int = None
    #: [optional] number of module replicas,
    #: for a GPU model, num_replica = num_gpu // (TP * PP * DP),
    #: for a CPU model, num_replica = num_cpu // cpu_per_process
    num_replica: int = 1
    #: [required] whether the model is trainable
    trainable: bool = False
    #: [optional] tensor model parallel size
    tensor_model_parallel_size: int = None
    #: [optional] pipeline model parallel size
    pipeline_model_parallel_size: int = None
    #: [optional] zero size
    zero_size: int = None
    #: [optional] config file for the model
    model_config_file: str = ""
    config_dir: str = ""
    #: [optional] model type, e.g., Torch/Tensorflow, etc.
    model_type: str = ""
    #: [optional] placeholder for other args
    args_dict: dict = None
    #: [optional] generation batch size; overwrites the generation batch size in RuntimeConfig
    generation_batch_size: int = -1
    #: lora config
    lora: LoraConfig = None
    #: batch generation config
    batch_generation: BatchGenerationConfig = None
    #: offload optimizer states
    offload_optimizer_states = False
    #: parameter sync frequency
    sync_frequency = 1
    #: offload weights
    offload_weights = False
    #: free grad buffers
    free_grad_buffers = False
    #: overall switch for offloading optimizer states/weights and freeing grad buffers
    free_memory = False

    def __init__(self):
        super().__init__()
        self.args_dict = {}
        self.lora = LoraConfig()
        self.batch_generation = BatchGenerationConfig()

    def __str__(self):
        members = [attr for attr in dir(self) \
                   if not callable(getattr(self, attr)) and not attr.startswith("__")]
        ser_str = self.__class__.__name__ + " {\n"
        for key in members:
            if key.startswith('_'):
                continue
            attr = getattr(self, key)
            if key in ["lora", "batch_generation"]:
                if not attr.is_changed():
                    continue
            attr = '"{}"'.format(attr) if isinstance(attr, str) else attr
            ser_str += "    %s = %s,\n" % (key, attr)
        ser_str += "}"
        return ser_str
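
# Illustrative YAML snippet for the `models` section consumed by Config below;
# the model name "policy" and the values are hypothetical:
#
#   models:
#     policy:
#       model_config_file: policy.yaml
#       num_gpu: 8
#       trainable: false
#       generation_batch_size: 4
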
class RuntimeConfig(BaseConfig):
    """training related configs."""

    #: [required] number of episodes. One episode includes an inference and training loop.
    num_episode: int = 5000
    #: [required] number of samples per episode.
    sample_per_episode: int = 1000
    #: [optional] number of training epochs per episode, default 1.
    num_training_epoch: int = 1
    #: [required] generation (inference) batch size.
    generation_batch_size: int = 2
    #: [required] training micro batch size.
    train_micro_batch_size: int = 2
    #: [required] training global batch size.
    train_global_batch_size: int = None
    #: [required] save checkpoint per `save_episode_interval` episodes.
    save_episode_interval: int = None
    #: [optional] log time and memory per `log_interval` iterations.
    log_interval: int = 1
    #: [required] data_path for dataset
    data_path: str = None
    #: [optional] colocate models onto the same devices
    colocation: List[str] = []
    #: [optional] eval every N episodes; if 0, no evaluation is performed
    eval_episode_interval: int = 0
    #: [optional] enable resume training when a data checkpoint is set
    enable_resume_training: bool = True
    #: [optional] checkpoint path for the dataloader
    data_checkpoint_path: str = None
    #: [optional] max number of data checkpoints
    max_data_ckpt_nums: int = None
    #: [optional] load data checkpoint from the given iteration
    load_data_checkpoint_iteration: int = None
    #: [optional] stream_data_loader type, one of ["fixed", "dynamic"]
    stream_data_loader_type: str = "fixed"
    #: private
    debug: bool = False
    #: enable nsys nvtx
    nsys: bool = False
    #: profiler dir
    profiler_dir: str = None
    #: coalesce parameters in model sync
    coalesce_param: bool = True
    #: coalesce buffer size in MB
    coalesced_buffer_mb: int = 100
    #: concurrent parameter sync
    concurrent_comm: bool = True
    #: parameter sync communication type, broadcast/p2p
    param_sync_comm_type: str = PARAM_SYNC_COMM_TYPE.BROADCAST.value
    #: parameter sync max workers
    param_sync_max_workers: int = None
    #: max number of relay episodes; if `max_relay_episode` is set to -1, relay all episodes,
    #: if `max_relay_episode` is set to 0, relay is disabled
    max_relay_episode: int = 0
    #: relay after n episodes
    relay_episode_offset: int = 0
    #: consumed samples
    consumed_samples: int = 0
    #: concurrent model setup
    concurrent_setup: bool = False
    #: bucket size (in MB) in the memory manager to reduce peak memory
    bucket_size_mb_in_memory_manager: int = 1024
    #: free the collective group after parameter synchronization and rebuild it before the next synchronization
    free_sync_collective_group: bool = False
    #: [optional] CPU-only model scheduling policy, PACK or SPREAD
    #: PACK: All provided bundles are packed onto a single node on a best-effort basis.
    #: SPREAD: Each bundle is spread onto separate nodes on a best-effort basis.
    cpu_schedule_strategy: str = RAY_PG_STRATEGY.SPREAD.value
    #: exp name for each run
    exp_name: str = "CHATLEARN"
    #: output dir
    output_dir: str = "./"

    def __init__(self):
        super().__init__()
        self._args_dict = {}

    def get(self, key):
        """
        Get other config by key.

        Args
        ----
        key: str
            key to get config
        """
        if key not in self._args_dict:
            logger.warning(f"{key} not found in RuntimeConfig")
        else:
            return self._args_dict[key]

    def validate(self):
        """
        :meta private:
        """
        for key in self._args_dict:
            if key == "save_interval":
                raise Exception("save_interval is deprecated, please use save_episode_interval to save checkpoints")
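
# Illustrative YAML snippet for the `runtime` section; all values are hypothetical:
#
#   runtime:
#     num_episode: 100
#     sample_per_episode: 1024
#     train_micro_batch_size: 2
#     train_global_batch_size: 128
#     colocation:
#       - policy,reference
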
class RuntimeEnvConfig(BaseConfig):
    """Runtime env config; refer to https://docs.ray.io/en/latest/ray-core/handling-dependencies.html for more information."""

    #: pip install packages
    pip: List[str] = []
    #: python modules
    py_modules: List[str] = []
    #: working directory
    working_dir: str = os.getcwd()
    #: platform, e.g., DLC
    platform: str = ""
    #: exclude files from packaging
    excludes: List[str] = []

    def __init__(self):
        super().__init__()
        self._args_dict = {}

    def get(self, key):
        """
        Get other config by key.

        Args
        ----
        key: str
            Key to get config.
        """
        if key not in self._args_dict:
            logger.warning(f"{key} not found in RuntimeEnvConfig")
        else:
            return self._args_dict[key]
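
# Illustrative YAML snippet for the `runtime_env` section; the package name and
# exclude pattern are hypothetical:
#
#   runtime_env:
#     pip:
#       - transformers
#     excludes:
#       - "*.ckpt"
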
class Config(BaseConfig):
    """A class to manage chatlearn configuration.

    Args
    ----
    param_dict: dict
        dict format of parameters.
    """

    def __init__(self, param_dict=None, config_dir=None):
        super().__init__()
        self._finalize = False
        self.models = {}
        self.env_args = RuntimeEnvConfig()
        self.runtime_args = RuntimeConfig()
        self.config_dir = config_dir
        self._active_module_args = None
        self.initialized = False
        if param_dict:
            self._parse_params(param_dict)
            self._validate_params()
        # remove later, just for compatibility
        self.rlhf_args = self.runtime_args
        self._finalize = True

    def _parse_params(self, param_dict):
        """Parse params from param_dict."""

        def set_param(user_args, config_cls, instance):
            for attribute, default_value in get_attributes(config_cls):
                if attribute in user_args:
                    value = user_args[attribute]
                    if attribute == "colocation":
                        colocation_list = []
                        for group in value:
                            colocation_list.append(group.replace(' ', '').split(','))
                        value = colocation_list
                else:
                    value = default_value
                original_value = getattr(instance, attribute)
                if original_value is not None:
                    assert isinstance(original_value, type(value)), \
                        f"{instance}.{attribute} should be type of {type(original_value)} but got {type(value)}"
                setattr(instance, attribute, value)
            for user_attribute in user_args:
                if not hasattr(config_cls, user_attribute):
                    if hasattr(instance, "_args_dict"):
                        getattr(instance, "_args_dict")[user_attribute] = user_args[user_attribute]
                    else:
                        raise RuntimeError(f"attribute {user_attribute} not defined in {config_cls.__name__}")
            instance.validate()

        for model_name, model_args in param_dict["models"].items():
            model_config = ModelConfig()
            model_config.config_dir = self.config_dir
            for user_attribute, user_value in model_args.items():
                if hasattr(ModelConfig, user_attribute):
                    original_value = getattr(ModelConfig, user_attribute)
                    if 'num_device' == user_attribute:
                        logger.warning("num_device is deprecated, please use num_gpu instead")
                        if 'num_gpu' not in model_args.keys():
                            setattr(model_config, "num_gpu", user_value)
                        else:
                            logger.warning("both num_device and num_gpu are set, use num_gpu")
                        continue
                    if 'lora' == user_attribute:
                        set_param(user_value, LoraConfig, model_config.lora)
                        user_value = model_config.lora
                    elif "batch_generation" == user_attribute:
                        set_param(user_value, BatchGenerationConfig, model_config.batch_generation)
                        user_value = model_config.batch_generation
                    if original_value is not None:
                        assert isinstance(user_value, type(original_value)), \
                            f"ModelConfig.{user_attribute} should be type of {type(original_value)} but got {type(user_value)} ({user_value})"
                    setattr(model_config, user_attribute, user_value)
                else:
                    logger.warning(f"unknown argument {user_attribute}")

            self.models[model_name] = model_config
            if model_config.model_config_file:
                model_config.model_config_file = get_path(model_config.model_config_file, self.config_dir)
                model_config.args_dict = parse_args_from_yaml(model_config.model_config_file, self.config_dir)

        if "runtime" in param_dict:
            set_param(param_dict["runtime"], RuntimeConfig, self.runtime_args)
        elif "rlhf" in param_dict:
            logger.warning("rlhf is deprecated, please use runtime as the section name")
            set_param(param_dict["rlhf"], RuntimeConfig, self.runtime_args)
        if "runtime_env" in param_dict:
            set_param(param_dict["runtime_env"], RuntimeEnvConfig, self.env_args)

        def _get_and_check_type(value, default_value, key):
            # Note: all str-type values should be in lower case.
            if isinstance(value, str):
                value = value.lower()
            if default_value is None:
                return value
            if not isinstance(value, type(default_value)):
                raise ValueError("%s type error, expected: %s." \
                                 % (key, type(default_value)))
            return value

    def _validate_params(self):
        if self.runtime_args.train_global_batch_size is None:
            self.runtime_args.train_global_batch_size = self.runtime_args.train_micro_batch_size
        assert self.runtime_args.train_global_batch_size % self.runtime_args.train_micro_batch_size == 0, \
            f"train_global_batch_size should be a multiple of train_micro_batch_size, " \
            f"but got {self.runtime_args.train_global_batch_size}/{self.runtime_args.train_micro_batch_size}"
        assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"]
        assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY]
        assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE)

        for model_name, model_args in self.models.items():
            if model_args.num_gpu >= 1:
                if model_args.gpu_per_process is None:
                    model_args.gpu_per_process = 1
                else:
                    assert model_args.gpu_per_process <= model_args.num_gpu, \
                        f"{model_name}: gpu_per_process: {model_args.gpu_per_process}, num_gpu: {model_args.num_gpu}"
            elif model_args.num_cpu >= 1:
                if model_args.cpu_per_process is None:
                    model_args.cpu_per_process = 1
                else:
                    assert model_args.cpu_per_process <= model_args.num_cpu, \
                        f"{model_name}: cpu_per_process: {model_args.cpu_per_process}, num_cpu: {model_args.num_cpu}"
            if model_args.generation_batch_size is None or model_args.generation_batch_size <= 0:
                if self.runtime_args.generation_batch_size:
                    model_args.generation_batch_size = self.runtime_args.generation_batch_size
            for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "zero_size"]:
                if model_args.args_dict.get(key) is not None:
                    setattr(model_args, key, model_args.args_dict.get(key))
                    assert getattr(model_args, key) >= 1
                elif getattr(model_args, key) is None:
                    setattr(model_args, key, 1)
            if model_args.tensor_model_parallel_size > 1 or model_args.pipeline_model_parallel_size > 1:
                assert model_args.zero_size == 1 or model_args.zero_size is None
                assert model_args.num_gpu % (
                    model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size) == 0, \
                    "num_gpu must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size " \
                    f"for the {model_name} model, but got num_gpu = {model_args.num_gpu}, " \
                    f"tensor_model_parallel_size = {model_args.tensor_model_parallel_size}, and " \
                    f"pipeline_model_parallel_size = {model_args.pipeline_model_parallel_size}."
            assert model_args.num_gpu > 0 or model_args.num_cpu > 0, \
                f"{model_name} num_gpu: {model_args.num_gpu}, num_cpu: {model_args.num_cpu}, at least one of them should be set"
            if model_args.num_gpu >= 1:
                if model_args.zero_size > 1:
                    assert model_args.num_gpu % model_args.zero_size == 0
                    model_args.num_replica = model_args.num_gpu // model_args.zero_size
                else:
                    model_args.num_replica = model_args.num_gpu // (
                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size)
            elif model_args.num_cpu >= 1:
                model_args.num_replica = model_args.num_cpu // model_args.cpu_per_process
            assert model_args.num_replica * model_args.generation_batch_size <= self.runtime_args.sample_per_episode, \
                f"num_replica * batch_size {model_args.num_replica}*{model_args.generation_batch_size} " + \
                f"should be no greater than sample_per_episode {self.runtime_args.sample_per_episode}"
            if model_args.batch_generation.min_prompt_length:
                logger.info(f"Enable batch generation: "
                            f"min_prompt_length = {model_args.batch_generation.min_prompt_length}")
            if model_args.free_memory:
                model_args.offload_weights = True
                if model_args.trainable:
                    model_args.free_grad_buffers = True
                    model_args.offload_optimizer_states = True

        if self.runtime_args.colocation and len(self.runtime_args.colocation) > 0:
            model_set = set()
            for colocate_models in self.runtime_args.colocation:
                for model_name in colocate_models:
                    assert model_name not in model_set, \
                        f"Model {model_name} should only appear once in a colocation group"
                    model_set.add(model_name)
        if self.runtime_args.exp_name not in self.runtime_args.output_dir:
            self.runtime_args.output_dir = f"{self.runtime_args.output_dir}/{self.runtime_args.exp_name}"
        logger.info(f"Env Config: \n{self.env_args}")
        logger.info(f"Runtime Config: \n{self.runtime_args}")
        for name, model_args in self.models.items():
            logger.info(f"Model({name}) Config: \n{model_args}")

    @property
    def active_module_args(self):
        return self._active_module_args

    @active_module_args.setter
    def active_module_args(self, config):
        self._active_module_args = config
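
# Minimal programmatic usage sketch (hypothetical values), bypassing the CLI:
#
#   config = Config({
#       "models": {"policy": {"num_gpu": 1, "generation_batch_size": 1}},
#       "runtime": {"num_episode": 1, "sample_per_episode": 1,
#                   "train_micro_batch_size": 1, "generation_batch_size": 1},
#   })
#   assert config.runtime_args.num_episode == 1
#   assert config.models["policy"].num_replica == 1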