import torch
import os
import json
import torch.distributed as dist
from accelerate import init_empty_weights

from transformers import (
    AutoModelForCausalLM,
    AutoConfig,)

from arguments import get_args
from utils import print_args, initialize, load_parallel, get_tokenizer, parallel_model_map

from minillm import train, Reward

from peft import PeftModel
from thop import profile
from pytorch_model_summary import summary
import torch_npu
from torch_npu.contrib import transfer_to_npu
import sys
# Make the Ascend toolkit's Python site-packages and TBE op implementations importable.
# NOTE(review): these paths are appended only after torch_npu has already been imported
# near the top of the file — confirm whether they are needed earlier (or at all).
sys.path.append('/usr/local/Ascend/ascend-toolkit/latest/python/site-packages')
sys.path.append('/usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_impl/ai_core/tbe')
import warnings
# Ignore every warning issued through the `warnings` module for the whole process,
# including those raised by third-party libraries.
warnings.filterwarnings("ignore")
# NOTE(review): torch_npu and transfer_to_npu were already imported above; these repeats
# are resolved from the module cache. `amp` is not referenced in the rest of this file.
import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

def get_teacher_model(args, device):
    """Build the teacher model used for distillation and put it in eval mode.

    Args:
        args: Parsed run arguments. Fields read here: ``teacher_model_path``,
            ``model_parallel``, ``model_type``, ``peft`` and, when
            ``peft == "lora"``, ``teacher_peft_path``.
        device: Target device. In the model-parallel branch it is passed to
            ``model.to``; otherwise it is used as ``device_map={"": device}``.

    Returns:
        The teacher model after ``model.eval()``. Weights are created in bf16
        when ``args.model_type == "qwen"`` and in fp16 otherwise.

    Raises:
        ValueError: ``args.peft == "lora"`` but ``args.teacher_peft_path`` is None.
        NotImplementedError: ``args.peft`` is set to anything other than ``"lora"``.
    """
    # NOTE(review): the architecture/dtype choice below keys on args.model_type,
    # while main() also sets args.teacher_model_type — confirm which one the
    # teacher should use when student and teacher types differ.

    # teacher_model_path may be a Hugging Face hub model id or a local directory
    # containing weights and config saved with save_pretrained().
    config = AutoConfig.from_pretrained(args.teacher_model_path)

    if args.model_parallel:
        config.is_model_parallel = True
        # Build the module skeleton without allocating real weights, so that
        # instantiating a very large model does not spike host memory.
        with init_empty_weights():
            if args.model_type == "qwen":
                model = parallel_model_map[args.model_type](config).to(torch.bfloat16)
            else:
                # Cast to half precision before moving to the device; inputs fed
                # to a half-precision model must be half precision as well.
                model = parallel_model_map[args.model_type](config).half()
        # Presumably fills the empty skeleton from the checkpoint at
        # teacher_model_path — confirm against utils.load_parallel.
        load_parallel(model, args.teacher_model_path)
        model = model.to(device)
    else:
        config.is_model_parallel = False
        model = AutoModelForCausalLM.from_pretrained(
            args.teacher_model_path,
            config=config,
            device_map={"": device},
            torch_dtype=torch.float16 if args.model_type != "qwen" else torch.bfloat16,
        )

        if args.peft is not None:
            if args.peft == "lora":
                # Explicit check instead of `assert`, which is stripped under -O.
                if args.teacher_peft_path is None:
                    raise ValueError(
                        "peft='lora' requires teacher_peft_path to be set for the teacher model"
                    )
                # Load the teacher's own adapter. Previously this used
                # args.peft_path, which is presumably the student's adapter.
                model = PeftModel.from_pretrained(model, args.teacher_peft_path)
            else:
                raise NotImplementedError
        elif dist.get_rank() == 0:
            print(' > number of parameters: {}'.format(
                sum([p.nelement() for p in model.parameters()])), flush=True)

    model.eval()

    return model


def main():
    """Launch MiniLLM distillation training.

    Steps:
      1. Bind this process to the device index given by ``LOCAL_RANK``.
      2. Parse and initialize the run arguments.
      3. On rank 0 only, record the arguments to ``<save>/args.json``.
      4. Override several DeepSpeed JSON settings with command-line values.
      5. Build the teacher model, tokenizer and reward, and hand everything to
         ``train``.
    """
    # LOCAL_RANK must be present in the environment; a KeyError is raised otherwise.
    node_local_rank = int(os.environ["LOCAL_RANK"])
    torch_npu.npu.set_device(node_local_rank)

    args = get_args()
    initialize(args)

    # Index of the currently selected device.
    # NOTE(review): torch.cuda is presumably redirected to the NPU by the
    # transfer_to_npu import at the top of the file — confirm.
    device = torch.cuda.current_device()

    # Every rank ensures the output directory exists. Only rank 0 prints the
    # arguments and writes them to disk.
    os.makedirs(args.save, exist_ok=True)
    if dist.get_rank() == 0:
        print_args(args)
        args_path = os.path.join(args.save, "args.json")
        with open(args_path, "w") as args_file:
            json.dump(vars(args), args_file)

    with open(args.deepspeed_config, "r") as cfg_file:
        ds_config = json.load(cfg_file)

    # Command-line values take precedence over the DeepSpeed JSON file.
    ds_config.update(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        train_micro_batch_size_per_gpu=args.batch_size,
        gradient_clipping=args.clip_grad,  # upper bound on the gradient norm
        steps_per_print=10000000,  # presumably keeps DeepSpeed's progress report out of the logs
    )

    # Train in fp32 unless the DeepSpeed config enables fp16.
    args.fp32 = not ds_config["fp16"]["enabled"]
    # The config file has been read into ds_config, so the path is cleared.
    # NOTE(review): presumably so downstream code relies on ds_config only — confirm.
    args.deepspeed_config = None

    # Default the teacher's model type to args.model_type when none was given.
    if args.teacher_model_type is None:
        args.teacher_model_type = args.model_type

    teacher_model = get_teacher_model(args, device)
    tokenizer = get_tokenizer(args)
    reward_fn = Reward(args, tokenizer, teacher_model).reward_fn

    # NOTE(review): evaluation reuses the training prompt / LM data directories.
    prompt_dir = args.prompt_data_dir
    lm_dir = args.lm_data_dir
    train(
        args=args,
        tokenizer=tokenizer,
        reward_fn=reward_fn,
        teacher_model=teacher_model,
        ds_config=ds_config,
        prompt_data=prompt_dir,
        eval_prompt_data=prompt_dir,
        lm_data=lm_dir,
        eval_lm_data=lm_dir,
    )


if __name__ == "__main__":
    # Run only when executed as a script, not on import. main() reads LOCAL_RANK
    # from the environment, so it expects a launcher that sets that variable.
    main()