Engine

class chatlearn.DPOEngine(reference: BaseModule, policy_trainer: BaseModule)

DPO Engine.
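
DPOEngine takes a reference module as well as the policy trainer. In Direct Preference Optimization, the reference supplies the baseline log-probabilities that the policy's preference margin is measured against. The function below is a minimal sketch of the standard DPO loss for one (chosen, rejected) pair, shown only to illustrate why a reference model is needed. The name dpo_loss, its arguments, and the default beta are assumptions for illustration and are not ChatLearn's API or necessarily its exact implementation.

```python
import math

def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard DPO loss for a single (chosen, rejected) preference pair.

    Each argument is the summed log-probability of a full response under the
    trainable policy or the frozen reference model (hypothetical helper).
    """
    # How much more the policy favours each response than the reference does.
    chosen_log_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_log_ratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_log_ratio - rejected_log_ratio)
    # -log(sigmoid(margin)) == softplus(-margin); branch to avoid exp overflow.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Example: the policy has shifted probability toward the chosen response.
print(dpo_loss(-10.0, -14.0, -12.0, -12.0))
```

When the policy and reference agree exactly, the margin is zero and the loss is log 2; it falls below that as the policy favours the chosen response more than the reference does.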

class chatlearn.OnlineDPOEngine(policy: BaseModule, reference: BaseModule, reward: BaseModule, policy_trainer: BaseModule)

Online DPO Engine.
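
Compared with DPOEngine, OnlineDPOEngine additionally takes a policy and a reward module, which suggests that preference pairs are built on the fly from sampled responses rather than read from a fixed dataset. The sketch below shows one common way to do that: score several responses for the same prompt and keep the best and worst as (chosen, rejected). The helper build_preference_pair is hypothetical, and this selection rule is an assumption, not a description of how ChatLearn forms its pairs.

```python
from typing import List, Optional, Tuple

def build_preference_pair(responses: List[str],
                          scores: List[float]) -> Optional[Tuple[str, str]]:
    """Turn responses sampled for one prompt into a (chosen, rejected) pair.

    Hypothetical helper: the highest-scoring response becomes "chosen" and the
    lowest-scoring one becomes "rejected". Returns None when every score ties,
    since such a pair carries no preference signal.
    """
    if len(responses) != len(scores) or len(responses) < 2:
        raise ValueError("need at least two responses, each with a score")
    best = max(range(len(scores)), key=scores.__getitem__)
    worst = min(range(len(scores)), key=scores.__getitem__)
    if scores[best] == scores[worst]:
        return None
    return responses[best], responses[worst]


# Example: scores as a reward model might assign them to three samples.
print(build_preference_pair(["A", "B", "C"], [0.2, 0.9, -0.4]))
```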

class chatlearn.RLHFEngine(policy: BaseModule, reference: BaseModule, reward: BaseModule, value: BaseModule, policy_trainer: BaseModule, value_trainer: BaseModule)

RLHF Engine.
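
Judging by its parameters, RLHFEngine follows a PPO-style RLHF setup: a policy that generates, a frozen reference, a reward model, a value model, and separate trainers for the policy and value networks. In that recipe the reward-model score is commonly combined with a per-token KL penalty between policy and reference log-probabilities. The sketch below shows that common formulation only; the function name, the kl_coef default, and the choice to add the score on the last token are assumptions, and this reference does not state whether or how ChatLearn applies a KL penalty.

```python
from typing import List

def kl_penalized_rewards(score: float, policy_logps: List[float],
                         ref_logps: List[float], kl_coef: float = 0.1) -> List[float]:
    """Per-token rewards for one response in a PPO-style RLHF setup (hypothetical).

    Each token is penalized by kl_coef * (log pi_policy - log pi_ref), and the
    scalar reward-model score is added on the final token.
    """
    if len(policy_logps) != len(ref_logps) or not policy_logps:
        raise ValueError("log-prob lists must be non-empty and the same length")
    rewards = [-kl_coef * (p - r) for p, r in zip(policy_logps, ref_logps)]
    rewards[-1] += score
    return rewards


# Example: a three-token response scored 1.0 by the reward model.
print(kl_penalized_rewards(1.0, [-0.5, -1.2, -0.3], [-0.7, -1.0, -0.3]))
```

The value model would then be trained to predict returns computed from these rewards, which is why a value_trainer appears alongside the policy_trainer.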

class chatlearn.EvalEngine(eval_flow=None, evaluator=None)

Evaluation Engine.

set_dataset(dataset)

Set prompt dataset.

Parameters:

dataset (list) – a list of prompt strings.
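
Since set_dataset expects a plain list of prompt strings, any loader only needs to produce that shape. The sketch below reads prompts from JSON Lines text; the file layout and the "prompt" key are assumptions for illustration, and load_prompts is a hypothetical helper, not part of ChatLearn.

```python
import json
from typing import Iterable, List

def load_prompts(lines: Iterable[str], key: str = "prompt") -> List[str]:
    """Collect prompt strings from JSON Lines text, skipping blank lines."""
    prompts = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        prompts.append(str(record[key]))
    return prompts


# Example with in-memory lines; a real file object works the same way.
print(load_prompts(['{"prompt": "Hi"}', '', '{"prompt": "What is 2+2?"}']))
```

With a file, usage would look like `eval_engine.set_dataset(load_prompts(open("eval_prompts.jsonl")))`, where eval_engine and the file name are hypothetical.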

set_post_process_func(post_process_func)

Set post process function.

Parameters:

post_process_func – a function that accepts two arguments: (1) results, a list of evaluation results; and (2) eval_info, a dict of metadata containing "train_iteration" and "episode_iteration".
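
A post-process function only has to accept the two documented arguments. The minimal sketch below formats a one-line summary tagged with the iteration counters from eval_info. The function name is hypothetical, and because this reference does not say whether the return value is used, the summary is both printed and returned for convenience.

```python
from typing import Any, Dict, List

def log_eval_results(results: List[Any], eval_info: Dict[str, Any]) -> str:
    """Post-process hook matching the documented (results, eval_info) signature."""
    # eval_info is documented to carry these two counters.
    line = (f"train_iteration={eval_info.get('train_iteration')} "
            f"episode_iteration={eval_info.get('episode_iteration')} "
            f"num_results={len(results)}")
    print(line)
    return line


log_eval_results(["r1", "r2"], {"train_iteration": 3, "episode_iteration": 1})
```

It would be registered with `eval_engine.set_post_process_func(log_eval_results)`, where eval_engine is a hypothetical EvalEngine instance.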

eval(cur_iter=None, train_iteration=None)

Start evaluating.

class chatlearn.Evaluator(model_flow)

Evaluator.

Parameters:

set_post_process_func(post_process_func)

Set post process function for model evaluation results.

Parameters:

post_process_func – a function that accepts two arguments: (1) results, a list of evaluation results; and (2) eval_info, a dict of metadata containing "train_iteration" and "episode_iteration".
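
For Evaluator, a post-process function often reduces the raw results to a metric. The sketch below averages a per-result "score" field and tags it with the iteration counters. The shape of each result (a dict with a "score" key) is an assumption, since this reference only says results is a list, and mean_score is a hypothetical name.

```python
from statistics import mean
from typing import Any, Dict, List

def mean_score(results: List[Dict[str, float]], eval_info: Dict[str, Any]) -> Dict[str, Any]:
    """Average a hypothetical per-result 'score' field, tagged with iteration info."""
    # Results without a "score" key are ignored rather than treated as zero.
    scores = [r["score"] for r in results if "score" in r]
    return {
        "train_iteration": eval_info.get("train_iteration"),
        "episode_iteration": eval_info.get("episode_iteration"),
        "mean_score": mean(scores) if scores else None,
    }


print(mean_score([{"score": 1.0}, {"score": 0.0}], {"train_iteration": 2, "episode_iteration": 0}))
```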

eval(cur_iter=None, train_iteration=None)

Start evaluating.

Parameters:
  • cur_iter (int) – current iteration.

  • train_iteration (int) – current training iteration.

set_dataset(dataset)