# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
import torch
import os
from rm_lora import *
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
from typing import Dict
from trl import ModelConfig, RewardConfig
import random
import numpy as np
from peft import LoraConfig, get_peft_model, TaskType

IGNORE_INDEX = 0

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
tqdm.pandas()
def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    set_seed(seed)
    torch.use_deterministic_algorithms(True)
seed_torch(0)

def rank0_print(rank, *args, **kwargs):
    """Print, but only on rank 0."""
    if rank == 0:
        print(*args, **kwargs)
        
if __name__ == "__main__":
    parser = HfArgumentParser((RewardConfig, ModelConfig))
    parser.add_argument('--pair_json_path', type=str, default="/home/v-xinyuguan/teamdrive/teamdrive/xy/xy/mcts/0924/qwen2SFT__ORM_filter.json")
    parser.add_argument('--test_pair_json_path', type=str, default=None)
    parser.add_argument('--metrics_path', type=str, default=None)
    parser.add_argument('--linear_tpye', type=str, default="single")
    parser.add_argument('--attn_impl', type=str, default="eager")
    config, model_config, remain_args = parser.parse_args_into_dataclasses()

    config.save_only_model = True
    config.load_best_model_at_end = False
    
    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    config.pair_json_path = remain_args.pair_json_path

    from accelerate import Accelerator
    accelerator = Accelerator()
    rank = accelerator.process_index
    print(rank)
    
    print(rank)
    
    # 1. 加载 Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, 
        trust_remote_code=True, 
        use_fast=True,
        padding_side="right",
        split_special_tokens=False,
    )

    # 2. 加载 Base Model
    rank0_print(rank, "Loading Base Model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, 
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if remain_args.attn_impl == "flash_attention_2" else torch.float32,
        attn_implementation=remain_args.attn_impl,
        use_cache=False,
    )
    
    # 3. 处理 Special Tokens (这一步必须在 LoRA 之前做，或者在 LoRA 中设置 modules_to_save)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": ['<code>', '<end_of_step>', '<end_of_code>', '<output>', '<end_of_output>', '<answer>', '<end_of_answer>', '<|user|>', '<|assistant|>', '<refine>', '<end_of_refine>', '\n<|assistant|>']
        },
        replace_additional_special_tokens=False,
    )
    base_model.resize_token_embeddings(len(tokenizer))

    # ==================== 新增：LoRA配置 ====================
    rank0_print(rank, "Applying LoRA...")
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, 
        inference_mode=False, 
        r=64,             # Rank, 建议 8, 16, 64。数据量大可以用 64
        lora_alpha=128,   # Alpha 建议是 r 的 2倍
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # Qwen 全量 Linear
        # 关键点：因为你 resize 了 embedding，而且把 embedding 设为了可训，
        # 为了让新加的 token 对应的 embedding 能被训练，需要把 embed_tokens 加入 modules_to_save
        modules_to_save=["embed_tokens"] 
    )
    
    # 将 Base Model 转换为 Peft Model
    base_model = get_peft_model(base_model, peft_config)
    
    # 打印可训练参数量 (这一步非常重要，用于确认 LoRA 是否生效)
    if rank == 0:
        base_model.print_trainable_parameters()
    # ========================================================

    # 4. 把加了 LoRA 的模型包裹进 RewardModel
    # RewardModelWithValueHead 里的 self.v_head 会是新初始化的全连接层，默认就是可训练的
    # 所以最终结构是：Base(冻结) + LoRA(可训) + Embed(可训) + ValueHead(可训)
    model = RewardModelWithValueHead(pretrained_model=base_model, linear_tpye=remain_args.linear_tpye)

    # model.config 等配置处理保持不变
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id
        
    model.to(torch.bfloat16) # 确保整个模型精度正确

    ################
    # Dataset
    ################
    if remain_args.pair_json_path is not None:
        raw_datasets = load_dataset('json', data_files=remain_args.pair_json_path, writer_batch_size=100)
        raw_datasets['train'] = raw_datasets['train'].shuffle(seed=42)
        if remain_args.test_pair_json_path is not None:
            raw_datasets['test'] = load_dataset('json', data_files=remain_args.test_pair_json_path)['train']
        else:
            raw_datasets['train'], raw_datasets['test'] = raw_datasets['train'].train_test_split(test_size=0.05, seed=42).values()
            rank0_print(rank, 'trainset size:', len(raw_datasets['train']), 'testset size:', len(raw_datasets['test']))

    remove_columns = ['prompt', 'neg', 'pos', 'neg_count', 'pos_count']
    remove_columns = [col for col in remove_columns if col in raw_datasets['train'].column_names]
    partial_func = partial(preprocess_value_dataset, tokenizer=tokenizer, max_length=config.max_length)
    raw_datasets = raw_datasets.map(
        partial_func,
        batched=True,
        num_proc=16,
        remove_columns=remove_columns
    ) 
    rank0_print(rank, 'after preprocess', raw_datasets['train'].column_names)
    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]
    rank0_print(rank, 'After filtering, trainset size:', len(train_dataset), 'testset size:', len(eval_dataset))

    ################
    # Training
    ################
    rank0_print(rank, config)
    rank0_print(rank, train_dataset.column_names)
    config.save_safetensors = False
    trainer = RMTrainer(
        model=model,
        tokenizer=tokenizer,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=PairwiseDataCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=config.max_length,
            padding='max_length'
            ),
        compute_metrics=ComputeAccuracy()
    )

    trainer.train()
    trainer.save_model(config.output_dir)
    trainer.save_state()
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    rank0_print(rank, metrics)
    import json
    import os
    if remain_args.metrics_path:
        os.makedirs(os.path.dirname(remain_args.metrics_path), exist_ok=True)
        with open(remain_args.metrics_path, 'w') as f:
            metrics['test_pair_json_path'] = remain_args.test_pair_json_path
            metrics['model_name_or_path'] = model_config.model_name_or_path
            json.dump(metrics, f)
