import os

import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # expose GPUs 0 and 1 (2 GPUs)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['HF_HOME'] = '/data/huggingface'

import json
from tqdm import tqdm
from argparse import Namespace
from vllm import LLM, SamplingParams
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.parser import *
from utils.math_normalization import *
from utils.grader import *
from src.early_stop_cot import EarlyStopCoT


class InferenceVLLM:
    """Pair a Hugging Face tokenizer with a vLLM engine for batched generation.

    ``args`` is a namespace-like object. The attributes read by this class are
    ``model_name_or_path``, ``model_cache_dir``, ``gpu_memory_utilization``,
    ``system_prompt``, ``question_format`` and, optionally,
    ``tensor_parallel_size`` (falls back to 2 when the attribute is absent).
    """

    def __init__(self, args):
        """Load the tokenizer and build the vLLM engine described by ``args``."""
        self.args = args

        # Load the tokenizer. The original note describes this as a
        # local/offline load; nothing in this block enforces offline mode.
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            cache_dir=args.model_cache_dir,
            padding_side="left",
        )

        # Resolve the model snapshot inside the cache directory and hand the
        # resulting local directory to vLLM.
        model_dir = snapshot_download(args.model_name_or_path, cache_dir=args.model_cache_dir)
        self.llm = LLM(
            model=model_dir,
            gpu_memory_utilization=args.gpu_memory_utilization,
            # Previously hard-coded to 2; an explicit args value now wins so
            # the shard count can follow the number of visible GPUs.
            tensor_parallel_size=getattr(args, "tensor_parallel_size", 2),
        )

    def process_prompt(self, question, step=None):
        """Render the chat prompt for ``question``, optionally pre-filling reasoning.

        Args:
            question: Text substituted into ``args.question_format``.
            step: Optional reasoning text. When it is not ``None`` it is
                appended after the rendered chat template, followed by a
                fixed suffix that closes the thinking section and opens a
                boxed final answer for the model to complete.

        Returns:
            The prompt string to feed to the engine.
        """
        messages = [
            {"role": "system", "content": self.args.system_prompt},
            {"role": "user", "content": self.args.question_format.format(question=question)}
        ]
        # enable_thinking is forwarded to the chat template as-is.
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
        if step is not None:
            prompt_text += step
            prompt_text += ' </think> The final answer is: \\boxed{'
        return prompt_text

    def inference_batch(self, prompts, **kwargs):
        """Generate completions for ``prompts``.

        Args:
            prompts: Prompt strings passed straight to ``LLM.generate``.
            **kwargs: Forwarded unchanged to ``SamplingParams``.

        Returns:
            Whatever ``LLM.generate`` returns for the given prompts.
        """
        sampling_params = SamplingParams(**kwargs)
        return self.llm.generate(prompts, sampling_params)


def process_data(data, sample_idx):
    """Build one prompt record per example from its ``sample_idx``-th step list.

    Examples without a ``'step_response'`` key are skipped. For the rest, the
    selected list of steps is joined with blank lines and stripped.

    Args:
        data: Iterable of dicts with ``'idx'``, ``'question'`` and (optionally)
            ``'step_response'``, the latter indexable by ``sample_idx``.
        sample_idx: Which entry of ``'step_response'`` to use.

    Returns:
        A list of dicts with keys ``'idx'``, ``'n_idx'``, ``'step_idx'``
        (index of the last step), ``'question'`` and ``'step'``.
    """
    records = []
    for item in data:
        if 'step_response' not in item:
            continue
        steps = item['step_response'][sample_idx]
        records.append({
            'idx': item['idx'],
            'n_idx': sample_idx,
            'step_idx': len(steps) - 1,
            'question': item['question'],
            # All steps are used, separated by a blank line.
            'step': '\n\n'.join(steps).strip(),
        })
    return records


def save_results(outputs, prompts, output_file):
    """Write one JSON line per generation result to ``output_file``.

    Each entry of ``outputs`` is matched by position with the entry of
    ``prompts`` at the same index. All records are assembled before the file
    is opened, and the file is opened in ``'w'`` mode.

    Args:
        outputs: Sequence of objects exposing ``.outputs``, whose items have
            ``.text``, ``.finish_reason`` and ``.token_ids``.
        prompts: Sequence of dicts providing ``'idx'``, ``'n_idx'`` and
            ``'step_idx'`` for the output at the same position.
        output_file: Destination path of the JSONL file.
    """
    def to_record(position, request_output):
        # Combine the prompt metadata with the texts, finish reasons and
        # token counts of every completion of this request.
        meta = prompts[position]
        completions = request_output.outputs
        return {
            'idx': meta['idx'],
            'n_idx': meta['n_idx'],
            'step_idx': meta['step_idx'],
            'generated_response': [c.text for c in completions],
            'generated_finished_reason': [c.finish_reason for c in completions],
            'generated_tokens': [len(c.token_ids) for c in completions],
        }

    records = [to_record(pos, out) for pos, out in enumerate(outputs)]
    with open(output_file, 'w') as handle:
        handle.writelines(json.dumps(record) + '\n' for record in records)
        print(f"Saved results to {output_file}")


if __name__ == '__main__':
    # Script entry point: for every (model, dataset) pair, load the processed
    # step-wise results, pre-fill each prompt with the selected reasoning
    # sample, let the model complete a short boxed answer several times, and
    # save the completions next to the input file.
    import gc
    from pathlib import Path

    args = Namespace(
        model_tag='QwQ-32B',
        model_name_or_path='Qwen/QwQ-32B',  # Qwen/QwQ-32B; deepseek-ai/DeepSeek-R1-Distill-Qwen-32B; deepseek-ai/DeepSeek-R1-Distill-Llama-8B
        model_cache_dir='/data/huggingface',
        system_prompt="You are a helpful and harmless assistant. You are Qwen developed by Alibaba. "  # You are Qwen developed by Alibaba. ;You are Deepseek developed by Deepseek. ;
                      "You should think step-by-step and put your final answer within \\boxed{}. ",
        question_format="{question}",
        data_set='gpqa',
        max_tokens=32768,
        temperature=0.6,  # 0.7
        n=10,
        top_p=0.95,  # 1.0
        top_k=20,  # original note: "does not exist"; not forwarded to sampling below
        gpu_memory_utilization=0.9,  # for LLM
        step_batch_size=10000,  # for step-wise inference
    )

    sample_idx = 0  # 0-9
    final_sample_num = 10
    rootpath = '/data/project/Reasoning/results/'
    model_tag_lis = ['DeepSeek-R1-Distill-Llama-8B', 'QwQ-32B', 'Qwen3-8B']
    model_name_lis = ['deepseek-ai/DeepSeek-R1-Distill-Llama-8B', 'Qwen/QwQ-32B', 'Qwen/Qwen3-8B']
    dataset_lis = ['aime', 'gpqa', 'math', 'minerva', 'olympiadbench']

    # Single source of truth for both the input and the output location.
    results_root = Path(rootpath)

    for model_tag, model_name in zip(model_tag_lis, model_name_lis):
        args.model_tag = model_tag
        args.model_name_or_path = model_name
        inference = InferenceVLLM(args)

        for dataset in dataset_lis:
            args.data_set = dataset
            dataset_dir = results_root / model_tag / dataset

            path = dataset_dir / f'{dataset}_step_results_processed.jsonl'
            early_stop_cot = EarlyStopCoT(path)

            prompts = process_data(early_stop_cot.data, sample_idx)
            processed_prompts = [inference.process_prompt(prompt['question'], prompt['step']) for prompt in prompts]
            # Only a short completion is requested: the prompt already ends
            # with an opened boxed answer.
            generated_outputs = inference.inference_batch(
                processed_prompts,
                n=final_sample_num,
                max_tokens=20,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            output_file = dataset_dir / 'final_answer_converge.jsonl'
            save_results(generated_outputs, prompts, output_file)

        # Best-effort release of the current engine before the next model is
        # constructed; otherwise the old engine is still referenced while the
        # new one is being built.
        # NOTE(review): confirm this fully frees GPU memory for the vLLM
        # version in use.
        del inference
        gc.collect()
        torch.cuda.empty_cache()
