import os

# HF_ENDPOINT must be set before anything imports huggingface_hub (pulled in by
# transformers, datasets and trl below): the hub client reads the variable once,
# at import time, so assigning it after those imports would leave the default
# endpoint in use.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import warnings  # noqa: E402
import torch  # noqa: E402
# Import the new classes.
from rm_logits import RewardModelWithProb, preprocess_value_dataset, RMTrainer, PairwiseDataCollatorWithPadding, ComputeAccuracy, IGNORE_INDEX  # noqa: E402
from datasets import load_dataset  # noqa: E402
from tqdm import tqdm  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed  # noqa: E402
from trl import ModelConfig, RewardConfig  # noqa: E402
import random  # noqa: E402
import numpy as np  # noqa: E402
from functools import partial  # noqa: E402

# Register tqdm's progress-bar helpers on pandas objects.
tqdm.pandas()
def seed_torch(seed=42):
    """Seed every random number generator in use and force deterministic kernels.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    calls ``transformers.set_seed``, turns off cuDNN autotuning and enables
    ``torch.use_deterministic_algorithms``.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Deterministic mode makes cuBLAS-backed CUDA ops raise RuntimeError unless
    # a fixed workspace configuration is exported (CUDA >= 10.2). setdefault
    # keeps whatever value the launcher may already have provided.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Autotuning may pick different kernels from run to run, so disable it and
    # ask cuDNN for deterministic implementations.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    set_seed(seed)
    torch.use_deterministic_algorithms(True)
# Seed all RNGs with 0 at module load, before the script body below runs.
seed_torch(0)

def rank0_print(rank, *args, **kwargs):
    """Forward ``*args`` / ``**kwargs`` to ``print`` only when ``rank`` is 0.

    Args:
        rank: Index of the calling process; any value other than 0 is silent.
        *args: Positional arguments passed through to ``print``.
        **kwargs: Keyword arguments passed through to ``print``.
    """
    if rank != 0:
        return
    print(*args, **kwargs)
        
if __name__ == "__main__":
    # Command line = TRL's RewardConfig + ModelConfig plus the script-specific
    # options registered below; those extra options come back as the trailing
    # namespace `remain_args`.
    parser = HfArgumentParser((RewardConfig, ModelConfig))
    parser.add_argument('--pair_json_path', type=str, default="/home/v-xinyuguan/teamdrive/teamdrive/xy/xy/mcts/0924/qwen2SFT__ORM_filter.json")
    parser.add_argument('--test_pair_json_path', type=str, default=None)
    parser.add_argument('--metrics_path', type=str, default=None)
    parser.add_argument('--linear_tpye', type=str, default="single")  # kept only for argument compatibility
    parser.add_argument('--attn_impl', type=str, default="eager")
    config, model_config, remain_args = parser.parse_args_into_dataclasses()

    # Fail fast with a clear message: without a training file no dataset can be
    # built, and the script would otherwise only crash later (after the model
    # has been loaded) with a NameError on `raw_datasets`.
    if remain_args.pair_json_path is None:
        raise ValueError("--pair_json_path must point to the training pair JSON file")

    # Trainer settings forced by this script regardless of the command line.
    config.save_only_model = True
    config.load_best_model_at_end = False
    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    config.pair_json_path = remain_args.pair_json_path

    from accelerate import Accelerator
    accelerator = Accelerator()
    rank = accelerator.process_index
    print(rank)

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=True,
        use_fast=True,
        padding_side="right",
        split_special_tokens=False,
    )

    # Weights are loaded in bf16 only when flash-attention 2 is requested and in
    # fp32 otherwise; the wrapped model is cast to bf16 further down either way.
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if remain_args.attn_impl == "flash_attention_2" else torch.float32,
        attn_implementation=remain_args.attn_impl,
        use_cache=False,
    )

    # === Use the new reward-model wrapper ===
    # No token ids need to be passed in, because that logic is handled with
    # CrossEntropy inside rm.py.
    model = RewardModelWithProb(pretrained_model=model)

    # Fall back to the EOS token for padding when no pad token is configured.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None:
        model.config.pad_token_id = model.config.eos_token_id

    # Register the extra markup tokens (appended to any existing additional
    # special tokens), then resize the embeddings to the new vocabulary size.
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": ['<code>', '<end_of_step>', '<end_of_code>', '<output>', '<end_of_output>', '<answer>', '<end_of_answer>', '<|user|>', '<|assistant|>', '<refine>', '<end_of_refine>', '\n<|assistant|>']
        },
        replace_additional_special_tokens=False,
    )
    model.pretrained_model.resize_token_embeddings(len(tokenizer))
    model.to(torch.bfloat16)

    ################
    # Dataset
    ################
    raw_datasets = load_dataset('json', data_files=remain_args.pair_json_path, writer_batch_size=100)
    raw_datasets['train'] = raw_datasets['train'].shuffle(seed=42)
    if remain_args.test_pair_json_path is not None:
        # A dedicated evaluation file was given: use it as the test split.
        raw_datasets['test'] = load_dataset('json', data_files=remain_args.test_pair_json_path)['train']
    else:
        # Otherwise hold out 5% of the shuffled training data for evaluation.
        raw_datasets['train'], raw_datasets['test'] = raw_datasets['train'].train_test_split(test_size=0.05, seed=42).values()

    # Drop the raw input columns after preprocessing, limited to the ones that
    # actually exist in the training split.
    remove_columns = ['prompt', 'neg', 'pos', 'neg_count', 'pos_count']
    remove_columns = [col for col in remove_columns if col in raw_datasets['train'].column_names]

    partial_func = partial(preprocess_value_dataset, tokenizer=tokenizer, max_length=config.max_length)

    raw_datasets = raw_datasets.map(
        partial_func,
        batched=True,
        num_proc=16,
        remove_columns=remove_columns
    )

    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]
    rank0_print(rank, 'After filtering, trainset size:', len(train_dataset), 'testset size:', len(eval_dataset))

    ################
    # Training
    ################
    trainer = RMTrainer(
        model=model,
        tokenizer=tokenizer,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=PairwiseDataCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=config.max_length,
            padding='max_length'
            ),
        compute_metrics=ComputeAccuracy()
    )

    trainer.train()
    trainer.save_model(config.output_dir)