from deepspeed import DeepSpeedConfig
from typing import Optional

# from trlx.utils.loading import get_orchestrator, get_pipeline, get_trainer
from .sampler import PPOSampler
from .pipelines import PPOPipeline, LMPipeline
from .trainer import PPOTrainer
from .reward import Reward

def train(
    args,
    tokenizer,
    reward_fn=None,
    teacher_model=None,
    prompt_data: Optional[str] = None,
    eval_prompt_data: Optional[str] = None,
    lm_data: Optional[str] = None,
    eval_lm_data: Optional[str] = None,
    ds_config: Optional[DeepSpeedConfig] = None,
):
    """Build a PPO trainer, sample rollouts, attach pipelines, and run training.

    Steps, in order: construct a ``PPOTrainer`` and hand it the teacher
    model; build the training prompt pipeline and run a ``PPOSampler`` over
    it; register the validation prompt pipeline; register the optional
    language-modeling pipelines; call ``trainer.train()``.

    Args:
        args: Configuration object. This function reads ``train_num``,
            ``dev_num``, ``chunk_size`` and ``num_rollouts_per_device`` from
            it and forwards the whole object to the trainer, sampler and
            pipelines.
        tokenizer: Tokenizer given to the trainer and to the training prompt
            pipeline. Later pipelines use ``trainer.tokenizer`` instead.
        reward_fn: Forwarded to ``PPOTrainer`` unchanged.
        teacher_model: Passed to ``trainer.set_teacher_model`` (also when it
            is ``None``).
        prompt_data: Source for the ``"train"`` prompt pipeline
            (presumably a data path -- confirm against ``PPOPipeline``).
        eval_prompt_data: Source for the ``"valid"`` prompt pipeline.
        lm_data: Source for the ``"train"`` LM pipeline; when ``None`` no
            such pipeline is built and ``None`` is registered instead.
        eval_lm_data: Source for the ``"valid"`` LM pipeline; when ``None``
            no such pipeline is built and ``None`` is registered instead.
        ds_config: Forwarded to ``PPOTrainer`` unchanged.

    Returns:
        The ``PPOTrainer`` instance, after ``trainer.train()`` has returned.
    """
    # Define the model, optimizer and loss function.
    ppo_trainer = PPOTrainer(
        args=args,
        tokenizer=tokenizer,
        reward_fn=reward_fn,
        ds_config=ds_config,
    )
    ppo_trainer.set_teacher_model(teacher_model)

    # Define the training data and sample rollouts from it.
    train_prompts = PPOPipeline(
        args,
        tokenizer,
        "train",
        prompt_data,
        num=args.train_num,
    )
    rollout_sampler = PPOSampler(
        args,
        ppo_trainer,
        train_prompts,
        chunk_size=args.chunk_size,
    )
    rollout_sampler.run_sample(args.num_rollouts_per_device)

    # Define the validation data.
    valid_prompts = PPOPipeline(
        args,
        ppo_trainer.tokenizer,
        "valid",
        eval_prompt_data,
        fix_prompts=True,
        num=args.dev_num,
    )
    ppo_trainer.add_eval_pipeline(valid_prompts)

    # Before distillation the student model was pre-trained on some dataset.
    # To keep the student from forgetting during distillation, training
    # continues on that dataset as well; "LM" refers to that dataset.
    train_lm = None
    if lm_data is not None:
        train_lm = LMPipeline(
            args,
            ppo_trainer.tokenizer,
            "train",
            lm_data,
            num=args.train_num,
        )

    valid_lm = None
    if eval_lm_data is not None:
        valid_lm = LMPipeline(
            args,
            ppo_trainer.tokenizer,
            "valid",
            eval_lm_data,
            num=args.dev_num,
        )

    ppo_trainer.add_lm_pipeline(train_lm, valid_lm)

    ppo_trainer.train()
    return ppo_trainer
