Engine
- class chatlearn.DPOEngine(reference: BaseModule, policy_trainer: BaseModule)
DPO (Direct Preference Optimization) engine.
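For context, the standard DPO objective compares how much the policy raises a chosen response over a rejected one, relative to the same margin under a reference model. This would explain why the engine takes a `reference` module alongside the `policy_trainer`. The sketch below computes that textbook per-pair loss in plain Python. The names `dpo_loss` and `softplus` and the `beta` default are illustrative assumptions. This is not chatlearn's implementation.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # hypothetical default; scales the log-ratio margin
) -> float:
    """Textbook DPO loss for one preference pair (illustrative, not chatlearn's code).

    Each argument is the summed log-probability of a full response under
    the policy or the reference model.
    """
    # Log-ratio of policy to reference for the chosen and rejected responses.
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_ratio - rejected_ratio)
    # -log(sigmoid(margin)) equals softplus(-margin), which avoids overflow.
    return softplus(-margin)


if __name__ == "__main__":
    # Policy identical to the reference: the margin is 0, so the loss is log(2).
    print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
```

The loss falls below log(2) once the policy favors the chosen response more strongly than the reference does, and rises above it otherwise.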
- class chatlearn.OnlineDPOEngine(policy: BaseModule, reference: BaseModule, reward: BaseModule, policy_trainer: BaseModule)
Online DPO (Direct Preference Optimization) engine.
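In online DPO, preference pairs are typically built during training instead of being read from a fixed dataset. The policy samples several responses per prompt, a reward model scores them, and the best and worst responses become the chosen/rejected pair. The sketch below shows that best-versus-worst pairing in plain Python under those assumptions. `build_preference_pair`, `PreferencePair`, and the tie handling are hypothetical and not taken from chatlearn.

```python
from typing import NamedTuple, Optional, Sequence


class PreferencePair(NamedTuple):
    prompt: str
    chosen: str
    rejected: str


def build_preference_pair(
    prompt: str, responses: Sequence[str], rewards: Sequence[float]
) -> Optional[PreferencePair]:
    """Pair the highest- and lowest-reward responses for one prompt.

    Returns None when no preference can be formed (fewer than two
    responses, or all rewards tied).
    """
    if len(responses) != len(rewards):
        raise ValueError("responses and rewards must have the same length")
    if len(responses) < 2:
        return None
    best = max(range(len(rewards)), key=lambda i: rewards[i])
    worst = min(range(len(rewards)), key=lambda i: rewards[i])
    if rewards[best] == rewards[worst]:
        return None  # all scores tied: no preference signal
    return PreferencePair(prompt, responses[best], responses[worst])


if __name__ == "__main__":
    pair = build_preference_pair("What is 2 + 2?", ["4", "5", "four"], [0.9, -0.3, 0.7])
    print(pair)  # PreferencePair(prompt='What is 2 + 2?', chosen='4', rejected='5')
```

Discarding tied groups is one design choice among several. Another option is to pair every chosen/rejected combination whose rewards differ.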
- class chatlearn.RLHFEngine(policy: BaseModule, reference: BaseModule, reward: BaseModule, value: BaseModule, policy_trainer: BaseModule, value_trainer: BaseModule)
RLHF (reinforcement learning from human feedback) engine.
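An RLHF engine built from policy, reference, reward, and value modules typically follows a PPO-style recipe. The reward model scores a full response. The reference model supplies a per-token KL penalty that keeps the policy close to its starting point. The value model's estimates feed advantage computation. The sketch below illustrates the common InstructGPT-style reward shaping and Generalized Advantage Estimation (GAE) in plain Python. The function names and the `kl_coef`, `gamma`, and `lam` defaults are assumptions, and chatlearn's actual implementation may differ.

```python
from typing import List, Sequence


def shaped_token_rewards(
    policy_logps: Sequence[float],
    ref_logps: Sequence[float],
    reward_score: float,
    kl_coef: float = 0.05,  # hypothetical KL penalty coefficient
) -> List[float]:
    """Per-token reward -kl_coef * (log pi - log pi_ref), with the
    sequence-level reward-model score added on the final token."""
    if len(policy_logps) != len(ref_logps) or not policy_logps:
        raise ValueError("need equal-length, non-empty log-prob sequences")
    rewards = [-kl_coef * (p - r) for p, r in zip(policy_logps, ref_logps)]
    rewards[-1] += reward_score
    return rewards


def gae_advantages(
    rewards: Sequence[float],
    values: Sequence[float],
    gamma: float = 1.0,  # hypothetical discount factor
    lam: float = 0.95,   # hypothetical GAE lambda
) -> List[float]:
    """Generalized Advantage Estimation over one response.

    The value after the final token is taken as 0 because the episode ends.
    """
    advantages = [0.0] * len(rewards)
    next_value, running = 0.0, 0.0
    for t in reversed(range(len(rewards))):
        # TD error at step t, then the exponentially weighted sum of TD errors.
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages
```

In a PPO loop of this kind, the advantages would drive the policy update and `advantages + values` would serve as the value model's regression targets.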
- class chatlearn.EvalEngine(eval_flow=None, evaluator=None)
Evaluation engine.
- set_dataset(dataset)
Set the prompt dataset.
- Parameters:
dataset (list) – a list of prompt strings
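`set_dataset` expects a plain Python list of prompt strings. The sketch below builds that list from a JSON Lines file. It assumes a hypothetical layout with one object per line and the prompt stored under a `"prompt"` key. The function name, file layout, and key are illustrative, not a chatlearn convention.

```python
import json
from typing import List


def load_prompts(path: str, key: str = "prompt") -> List[str]:
    """Read one JSON object per line and collect the prompt strings."""
    prompts = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # skip blank lines
            record = json.loads(line)
            value = record.get(key)
            if not isinstance(value, str):
                raise ValueError(f"line {line_no}: missing string field {key!r}")
            prompts.append(value)
    return prompts
```

The result can then be passed to an existing engine instance, for example `engine.set_dataset(load_prompts("prompts.jsonl"))`, where `engine` and the file name are placeholders.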
- class chatlearn.Evaluator(model_flow)
Evaluator.
- Parameters:
models ([BaseModule]) – models to evaluate
args (RuntimeConfig) – defaults to None
- set_post_process_func(post_process_func)
Set the post-processing function applied to model evaluation results.
- Parameters:
post_process_func – a function that accepts two arguments: (1) results, a list of evaluation results; (2) eval_info, a metadata dict containing “train_iteration” and “episode_iteration”.
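The following post-processing function follows the documented contract. It receives `results`, a list of evaluation results, and `eval_info`, a dict containing `"train_iteration"` and `"episode_iteration"`. Each result is assumed, hypothetically, to be a dict that may carry a numeric `"score"` field. The expected return value, if any, is not specified above, so this version simply logs and returns a summary dict. The name `summarize_eval` is illustrative.

```python
from statistics import mean
from typing import Any, Dict, List


def summarize_eval(results: List[Dict[str, Any]], eval_info: Dict[str, int]) -> Dict[str, Any]:
    """Average a hypothetical per-sample "score" field and tag the summary
    with the iteration counters supplied in eval_info."""
    scores = [r["score"] for r in results if "score" in r]
    summary = {
        "train_iteration": eval_info["train_iteration"],
        "episode_iteration": eval_info["episode_iteration"],
        "num_results": len(results),
        "mean_score": mean(scores) if scores else float("nan"),
    }
    print(f"[eval] iter={summary['train_iteration']} mean_score={summary['mean_score']:.4f}")
    return summary
```

It would be registered with `evaluator.set_post_process_func(summarize_eval)`, where `evaluator` is an existing `Evaluator` instance.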
- eval(cur_iter=None, train_iteration=None)
Run evaluation.
- Parameters:
cur_iter (int) – current iteration.
train_iteration (int) – current training iteration.
- set_dataset(dataset)