import os
from dataclasses import dataclass, field
from typing import Dict, Optional

import transformers
from datasets import load_dataset, load_from_disk, concatenate_datasets
from torch import distributed as dist
from trl import PRMConfig, PRMTrainer
import torch

import random
import numpy as np

@dataclass
class ModelArguments:
    """Command-line options selecting the base model checkpoint."""

    # Forwarded to `from_pretrained` for both the model and the tokenizer.
    model_name_or_path: Optional[str] = "Qwen/Qwen2-0.5B"


@dataclass
class DataArguments:
    """Command-line options describing the training/evaluation data."""

    # Location of the training data; the default is a Hugging Face Hub dataset id.
    train_data_path: str = field(default="trl-lib/math_shepherd")
    # Optional separate evaluation source; ``None`` when not supplied.
    eval_data_path: Optional[str] = field(default=None)
    # NOTE(review): not read anywhere in this module as far as visible.
    lazy_preprocess: bool = False
    # Presumably: keep only training examples whose labels contain a 0.
    only_vul: bool = False
    # Presumably: oversample training examples whose labels contain a 0.
    resample: bool = False


@dataclass
class TrainingArguments(PRMConfig):
    """``PRMConfig`` extended with the extra options this script reads.

    ``max_length`` and ``max_completion_length`` presumably re-declare fields
    inherited from ``PRMConfig`` only to change their defaults -- confirm
    against the installed trl version.
    """

    # Passed as ``cache_dir`` when loading the model and the tokenizer.
    cache_dir: Optional[str] = None
    # Script-specific default (see class docstring).
    max_length: int = 128000
    # Script-specific default (see class docstring).
    max_completion_length: int = 8000
    # When True, ``model.model`` is frozen (requires_grad_(False)) before training.
    fix_llm: bool = False


def safe_save_model_for_hf_trainer(
        trainer: transformers.Trainer,
        output_dir: str,
    ):
    """Save the trained model (and processing class) to ``output_dir``.

    Delegates to the public ``Trainer.save_model``, which handles
    distributed / sharded setups and decides internally which process
    writes to disk (``args.should_save``).

    Call this on every rank: ``save_model`` may run collective operations
    (e.g. gathering sharded weights) before the writer process saves.

    Args:
        trainer: the trainer whose model should be saved.
        output_dir: destination directory.
    """
    # Do not gate on ``local_rank == 0``: local_rank is -1 in single-process
    # runs, which would silently skip saving entirely.
    trainer.save_model(output_dir)
        


def make_supervised_data_module(data_args) -> Dict:
    """Build the train/eval datasets passed to ``PRMTrainer``.

    The source named by ``data_args.train_data_path`` must expose ``"train"``
    and ``"test"`` splits. It is first read with ``load_from_disk``; if no
    saved dataset exists at that path, it is treated as a dataset name for
    ``load_dataset`` (e.g. the default ``"trl-lib/math_shepherd"``).

    Optional transforms on the training split, controlled by ``data_args``:

    * ``resample`` -- append two extra copies of every example whose
      ``labels`` contain a 0, then shuffle with seed 42.
    * ``only_vul`` -- keep only examples whose ``labels`` contain a 0
      (applied after resampling).

    Args:
        data_args: parsed ``DataArguments``.

    Returns:
        ``dict(train_dataset=..., eval_dataset=...)``.

    Raises:
        ValueError: if ``data_args.train_data_path`` is ``None``.
    """
    path = data_args.train_data_path
    if path is None:
        raise ValueError("data_args.train_data_path must be set")

    # Load the dataset once and take both splits from the same object.
    try:
        splits = load_from_disk(path)
    except FileNotFoundError:
        # Not a local save_to_disk directory: treat it as a dataset name.
        splits = load_dataset(path)
    train_dataset = splits["train"]
    eval_dataset = splits["test"]
    # NOTE(review): data_args.eval_data_path is not consulted; evaluation
    # always uses the "test" split of train_data_path.

    def _has_zero_label(example):
        # `0 in labels` also matches boolean False labels (False == 0).
        return 0 in example["labels"]

    if getattr(data_args, "resample", False):
        # Oversample the minority (zero-label) examples.
        replicate_ratio = 2
        minority = train_dataset.filter(_has_zero_label)
        train_dataset = concatenate_datasets(
            [train_dataset] + [minority] * replicate_ratio
        ).shuffle(seed=42)

    if getattr(data_args, "only_vul", False):
        train_dataset = train_dataset.filter(_has_zero_label)

    return dict(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

def get_rng_state():
    """Snapshot the Python, NumPy and torch random-number-generator states.

    Returns:
        A dict with keys ``"py_random"``, ``"np_random"`` and ``"torch_cpu"``,
        plus ``"torch_cuda_all"`` (one state per CUDA device) when CUDA is
        available.
    """
    snapshot = dict(
        py_random=random.getstate(),
        np_random=np.random.get_state(),
        torch_cpu=torch.get_rng_state(),
    )
    if not torch.cuda.is_available():
        return snapshot
    snapshot["torch_cuda_all"] = torch.cuda.get_rng_state_all()
    return snapshot

def set_rng_state(state):
    """Restore RNG states from a dict shaped like ``get_rng_state()`` output.

    ``"py_random"``, ``"np_random"`` and ``"torch_cpu"`` are required. CUDA
    states (``"torch_cuda_all"``) are restored only when CUDA is available
    and the key is present.
    """
    py_state = state["py_random"]
    random.setstate(py_state)

    np_state = state["np_random"]
    np.random.set_state(np_state)

    cpu_state = state["torch_cpu"]
    torch.set_rng_state(cpu_state)

    if not torch.cuda.is_available():
        return
    if "torch_cuda_all" in state:
        torch.cuda.set_rng_state_all(state["torch_cuda_all"])
        
def load(path, model, optimizer,
         scheduler=None, scaler=None, map_location="cpu"):
    """Restore a training checkpoint saved as a single dict via ``torch.save``.

    Required keys: ``"model"``, ``"optimizer"``, ``"rng"``, ``"epoch"``,
    ``"global_step"``. Optional keys: ``"scheduler"``, ``"scaler"``,
    ``"step_in_epoch"`` (defaults to 0) and ``"extra"`` (defaults to ``{}``).

    Args:
        path: checkpoint file path.
        model: module to restore; if it exposes ``.module`` (a wrapper such
            as DDP), the wrapped module is restored instead.
        optimizer: optimizer whose state is restored.
        scheduler: restored only when given and present in the checkpoint.
        scaler: restored only when given and present in the checkpoint.
        map_location: forwarded to ``torch.load``.

    Returns:
        ``(epoch, global_step, step_in_epoch, extra)``.

    Raises:
        KeyError: if a required key is missing from the checkpoint.
    """
    # weights_only=False: the checkpoint bundles optimizer state and Python /
    # NumPy RNG states (NumPy's contains an ndarray), which the weights-only
    # unpickler (PyTorch's default since 2.6) can reject.
    # SECURITY: this unpickles arbitrary objects -- only load trusted files.
    ckpt = torch.load(path, map_location=map_location, weights_only=False)

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])
    set_rng_state(ckpt["rng"])

    epoch = ckpt["epoch"]
    global_step = ckpt["global_step"]
    step_in_epoch = ckpt.get("step_in_epoch", 0)
    extra = ckpt.get("extra", {})
    return epoch, global_step, step_in_epoch, extra
    
def train():
    """Entry point: parse CLI arguments, fine-tune a PRM, and save it.

    Flags are parsed into ``ModelArguments``, ``DataArguments`` and
    ``TrainingArguments``; unrecognized flags are ignored
    (``return_remaining_strings=True``). After training, the trainer state
    and the final model are saved to ``training_args.output_dir``.
    """
    # Default W&B project name; an explicitly exported WANDB_PROJECT wins.
    os.environ.setdefault("WANDB_PROJECT", "PRM_Math_Shepherd")

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        _,  # unrecognized command-line strings are ignored
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    # Seed Python, NumPy and torch before building the model so that any
    # weights not present in the checkpoint are initialized reproducibly.
    torch.manual_seed(training_args.seed)
    np.random.seed(training_args.seed)
    random.seed(training_args.seed)

    # Load model and tokenizer.
    model = transformers.AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        use_cache=False,
        torch_dtype=torch.bfloat16,  # load weights directly in bf16
        low_cpu_mem_usage=True,  # lower peak host memory while loading
        device_map=None,  # device placement is left to DeepSpeed / the trainer
    )

    # Optionally freeze the entire `model.model` backbone. Only parameters
    # outside it (presumably the token-classification head) stay trainable.
    if training_args.fix_llm:
        model.model.requires_grad_(False)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        use_fast=False,
        trust_remote_code=True,
    )

    # Load data.
    data_module = make_supervised_data_module(data_args=data_args)

    # Build the trainer and train.
    print('training_args', training_args, 'model_args', model_args, 'data_args', data_args)
    trainer = PRMTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        **data_module,
    )

    trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
    )

    # Only tear down a process group that was actually created; calling
    # destroy_process_group() in a single-process run raises.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


# Script entry point: train() parses its configuration from the command line.
if __name__ == "__main__":
    train()