import sys, os

# vLLM environment switches are set here, before vLLM is imported below.
# os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"
# Opt out of the vLLM V1 engine. Set before `from vllm import ...` so the value
# is already in the environment no matter when vLLM first reads it.
os.environ["VLLM_USE_V1"] = "0"

from transformers import AutoTokenizer
import datasets
# HACK: replace the `datasets` free-disk-space check with one that always
# passes (presumably to avoid spurious "not enough disk space" failures).
datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory='.': True
import torch


from vllm import LLM, SamplingParams

from tqdm import tqdm
import numpy as np
import json
import random
from dataclasses import dataclass, field
from transformers import HfArgumentParser,set_seed
from typing import Optional
from vllm.lora.request import LoRARequest
import os

from accelerate import Accelerator
import torch.distributed as dist
from load_judge_dataset import load_benchmark_data
from load_prompt import load_all_prompt
# 214 token qwen 2.5

# Model names (last path component of --model_path) for which the chat
# template is called with an explicit ``enable_thinking`` flag.
hybrid_thinking_models = [
    'Qwen3-14B',
    'Qwen3-8B',
    'Qwen3-4B',
    'Qwen3-1.7B',
    'Qwen3-0.6B',
    'Qwen3-32B',
    'Phi-4-reasoning-plus',
]


@dataclass
class ScriptArguments:
    """Command-line arguments for the pairwise judge-generation script.

    Parsed with ``HfArgumentParser``; each field's ``help`` metadata is the
    ``--help`` text of the matching flag.
    """
    output_dir: Optional[str] = field(default="result", metadata={"help": "root directory under which generations are saved"},)
    model_path: Optional[str] = field(default="", metadata={"help": "Hugging Face model id or local path of the judge model"},)
    seed: Optional[int] = field(default=42, metadata={"help": "random seed"},)
    local_index: Optional[int] = field(default=0, metadata={"help": "index of the data shard this process handles (ignored with use_tensor_parallel)"},)
    dataset_name: Optional[str] = field(default="ultrafeedback", metadata={"help": "the dataset name to load"},)
    use_tensor_parallel: Optional[bool] = field(default=False, metadata={"help": "process the whole dataset in this single process instead of only the local_index shard"},)
    mode: Optional[str] = field(default="instruct", metadata={"help": "instruct or reasoning"},)
    num_gen: Optional[int] = field(default=1, metadata={"help": "number of judgements generated per example"},)

# Substrings that mark a verdict for Assistant A / Assistant B.
_VERDICT_MARKERS_A = ("[[A]]", "[A]", "\\boxed{Assistant 1}")
_VERDICT_MARKERS_B = ("[[B]]", "[B]", "\\boxed{Assistant 2}")


def _last_marker_pos(text, markers):
    """Return the largest start index of any of ``markers`` in ``text``, or -1."""
    return max(text.rfind(marker) for marker in markers)


def check_correct(flip, gens):
    """Score each judge output as picking the chosen answer (1) or not (0).

    The verdict is the LAST A- or B-marker occurring in the text, so format
    instructions or reasoning that mention ``[A]``/``[[A]]`` before the final
    verdict no longer bias the result toward A. Text with no marker scores 0.

    Args:
        flip: 0 if the chosen answer was shown as Assistant A, 1 if it was
            shown as Assistant B.
        gens: iterable of generated judge texts.

    Returns:
        list of ints (0/1), one per element of ``gens``.
    """
    chosen_is_a = flip != 1
    corrects = []
    for gen in gens:
        pos_a = _last_marker_pos(gen, _VERDICT_MARKERS_A)
        pos_b = _last_marker_pos(gen, _VERDICT_MARKERS_B)
        if pos_a < 0 and pos_b < 0:
            # No parseable verdict counts as incorrect.
            corrects.append(0)
            continue
        picked_a = pos_a > pos_b
        corrects.append(int(picked_a == chosen_is_a))
    return corrects
    
@torch.no_grad()
def generate_response_vllm(dataset, model, model_name, tokenizer, sys_prompt, user_template, mode, n=1):
    """Judge every (chosen, rejected) pair in ``dataset`` ``n`` times with vLLM.

    For each row and each of the ``n`` judgements, a fresh fair coin decides
    whether the chosen answer is shown as ``answer_a`` or ``answer_b`` in
    ``user_template``. All prompts are rendered with the tokenizer's chat
    template and generated in one batched ``model.generate`` call.

    Args:
        dataset: pandas DataFrame with ``prompt``, ``chosen`` and ``rejected``
            columns. It is modified in place and also returned.
        model: vLLM ``LLM`` used for generation.
        model_name: bare model name; if it is in ``hybrid_thinking_models``
            the chat template additionally receives ``enable_thinking``.
        tokenizer: tokenizer used for ``apply_chat_template``; its
            ``eos_token_id`` is used as a stop token.
        sys_prompt: system message content.
        user_template: ``str.format`` template with ``{question}``,
            ``{answer_a}`` and ``{answer_b}`` placeholders.
        mode: ``'reasoning'`` enables thinking and allows up to 16384 new
            tokens; any other value disables thinking and allows 4096.
        n: number of independent judgements per row.

    Returns:
        ``dataset`` with four new columns, each holding a list of ``n`` values
        per row: ``judge`` (stripped generated text), ``flip`` (0 if the
        chosen answer was in slot A, 1 if in slot B), ``check`` (result of
        ``check_correct`` for that judgement) and ``num_tokens`` (number of
        generated tokens).
    """
    with torch.inference_mode():
        enable_thinking = mode == 'reasoning'
        max_tokens = 16384 if enable_thinking else 4096
        # Each of the n judgements per row is already submitted as its own
        # prompt (with its own random flip), so request exactly one completion
        # per prompt. Passing n=n here generated n*n completions per row while
        # only outputs[0] of each request was ever read.
        sampling_params = SamplingParams(
            temperature=0.6,
            top_p=1.0,
            n=1,
            max_tokens=max_tokens,
            stop_token_ids=[tokenizer.eos_token_id],
        )

        all_prompt = dataset['prompt'].tolist()
        all_chosen = dataset['chosen'].tolist()
        all_rejected = dataset['rejected'].tolist()
        use_thinking_kwarg = model_name in hybrid_thinking_models

        chat_prompts = []
        all_flips = []
        for prompt, chosen, rejected in zip(all_prompt, all_chosen, all_rejected):
            # One independent position flip per generation of this row.
            flips_per_sample = []
            for _ in range(n):
                # coin == 1: chosen shown as A (flip 0); coin == 0: chosen shown as B (flip 1).
                coin = np.random.binomial(n=1, p=0.5, size=1)[0]
                if coin == 1:
                    answer_a, answer_b = chosen, rejected
                    flips_per_sample.append(0)
                else:
                    answer_a, answer_b = rejected, chosen
                    flips_per_sample.append(1)

                user_content = user_template.format(question=prompt, answer_a=answer_a, answer_b=answer_b)
                prompt_message = [{"role": "system", "content": sys_prompt},
                                  {"role": "user", "content": user_content}]
                template_kwargs = {"enable_thinking": enable_thinking} if use_thinking_kwarg else {}
                chat_prompts.append(tokenizer.apply_chat_template(
                    prompt_message, tokenize=False, add_generation_prompt=True, **template_kwargs))
            all_flips.append(flips_per_sample)

        # Outputs are read back assuming they follow the order of chat_prompts:
        # entries [i * n, (i + 1) * n) belong to row i.
        responses = model.generate(chat_prompts, sampling_params) if chat_prompts else []

        judges, checks, tokens = [], [], []
        for row_idx, row_flips in enumerate(all_flips):
            row_outputs = responses[row_idx * n:(row_idx + 1) * n]
            row_texts = [out.outputs[0].text.strip() for out in row_outputs]
            judges.append(row_texts)
            checks.append([check_correct(flip, [text])[0] for flip, text in zip(row_flips, row_texts)])
            tokens.append([len(out.outputs[0].token_ids) for out in row_outputs])

        dataset['judge'] = judges
        dataset['flip'] = all_flips
        dataset['check'] = checks
        dataset['num_tokens'] = tokens
        return dataset
            



    

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    output_dir = script_args.output_dir
    local_index = script_args.local_index
    set_seed(script_args.seed)
    model_name_or_path = script_args.model_path
    model_name = model_name_or_path.split("/")[-1]
    mode = script_args.mode
    n = script_args.num_gen

    sys_prompt_override = None  # only for Nemotron
    original_model_name = model_name_or_path.split("/")[-1]
    
    if model_name_or_path=="microsoft/Phi-4-mini":
        if mode == 'reasoning':
            model_name_or_path = "microsoft/Phi-4-mini-flash-reasoning"
        elif mode == 'instruct':
            model_name_or_path = "microsoft/Phi-4-mini-instruct"
    
    elif model_name_or_path=="microsoft/phi-4":
        if mode == 'reasoning':
            model_name_or_path = "microsoft/phi-4"
        elif mode == 'instruct':
            model_name_or_path = "microsoft/Phi-4-mini-instruct"

    elif model_name_or_path=="Qwen/Qwen3-4B-2507":
        if mode == 'reasoning':
            model_name_or_path = "Qwen/Qwen3-4B-Thinking-2507"
        elif mode == 'instruct':
            model_name_or_path = "Qwen/Qwen3-4B-Instruct-2507"
    
    elif model_name_or_path == "nvidia/NVIDIA-Nemotron-Nano-12B-v2" or model_name_or_path == "nvidia/NVIDIA-Nemotron-Nano-9B-v2":
        sys_prompt_override = "/think" if mode=="reasoning" else "/no_think"



    elif model_name_or_path=="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B":
        if mode == 'reasoning':
            model_name_or_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
        elif mode == 'instruct':
            model_name_or_path = "Qwen/Qwen2.5-14B-Instruct"

    elif model_name_or_path=="deepseek-ai/DeepSeek-R1-Distill-Llama-8B":
            if mode == 'reasoning':
                model_name_or_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
            elif mode == 'instruct':
                model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"


    actual_model_name = original_model_name.split("/")[-1]
    
    llm_kwargs = dict(
        model=model_name_or_path,
        tokenizer=model_name_or_path,
        gpu_memory_utilization=0.9,
        swap_space=16,
        tensor_parallel_size=torch.cuda.device_count(),
        dtype="bfloat16",
        trust_remote_code=True,
        max_model_len=32768,
    )
    
    if model_name_or_path == "microsoft/Phi-4-reasoning-plus":
        if mode == "reasoning":
            llm_kwargs["enable_reasoning"] = True
            llm_kwargs["reasoning_parser"] = "deepseek_r1"
        else:
            llm_kwargs["enable_reasoning"] = False

    print(model_name_or_path)

    ref_model = LLM(**llm_kwargs)
        
    dataset_name = script_args.dataset_name
    use_tensor_parallel = script_args.use_tensor_parallel

    sys_prompt, user_template = load_all_prompt(model_name_or_path, mode)

    # 对 Nemotron 的特殊处理
    if sys_prompt_override is not None:
        # 把原来的 sys_prompt 内容移到 user_template 开头
        user_template = sys_prompt + "\n\n" + user_template
        sys_prompt = sys_prompt_override  # 替换为 /think 或 /no_think

    # for nemotron
    if sys_prompt_override is not None:
        sys_prompt = sys_prompt_override

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    accelerator = Accelerator()


    world_size = int(os.getenv("WORLD_SIZE", "4"))

    ### for train dataset
    ds = load_benchmark_data(dataset_name) #.select(range(5000))  # for debug use a smaller dataset
    ds = ds.shuffle(seed=42)

    ds = ds.select(range(min(20000, len(ds))))

    data_size = len(ds["chosen"])
    share = int(data_size / world_size) + 1

    os.makedirs(f'{output_dir}/{model_name}_{mode}/{dataset_name}', exist_ok=True)
    output_path = f'{output_dir}/{model_name}_{mode}/{dataset_name}/generations'

    if not use_tensor_parallel:
        world_size = int(os.getenv("WORLD_SIZE", "4"))  # 默认4个进程
        share = int(data_size / world_size) + 1
        
        start_idx = local_index * share
        end_idx = min((local_index + 1) * share, data_size)
        ds = ds.select(np.arange(start_idx, end_idx))
        
        print(f'Process {local_index}/{world_size}: processing samples [{start_idx}:{end_idx}], size: {len(ds)}/{data_size}')
        
        df = ds.to_pandas()
        gen_judge_df = generate_response_vllm(df, ref_model, actual_model_name, tokenizer, sys_prompt, user_template, mode, n=n)
        gen_judge_df = datasets.Dataset.from_pandas(gen_judge_df).to_list()
        
        with open(output_path + str(local_index) + '.json', "w", encoding="utf-8") as f:
            json.dump(gen_judge_df, f, indent=4, ensure_ascii=False)
        
        print(f'Process {local_index}: Saved to {output_path}{local_index}.json')
    
    else:
        print(f'Using tensor parallel mode, processing all {data_size} samples')
    
        df = ds.to_pandas()
        gen_judge_df = generate_response_vllm(df, ref_model, actual_model_name, tokenizer, sys_prompt, user_template, mode, n=n)
        gen_judge_df = datasets.Dataset.from_pandas(gen_judge_df).to_list()
        
        with open(output_path + '.json', "w", encoding="utf-8") as f:
            json.dump(gen_judge_df, f, indent=4, ensure_ascii=False)
        
        print(f'Saved to {output_path}.json')