import os
from pathlib import Path

# These are set before the vllm / transformers / huggingface_hub imports below so
# the values are already in place when those libraries initialize.
# Expose GPUs 0 and 1 (two devices), matching the tensor-parallel size of 2 used for the LLM.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # use 2 GPUs
# Disable HF tokenizers' internal parallelism (presumably to avoid the fork warning
# emitted when worker processes are spawned after the tokenizer is used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Root of the Hugging Face cache; same path as model_cache_dir in the script config.
os.environ['HF_HOME'] = '/data/huggingface'

from tqdm import tqdm
from argparse import Namespace
from vllm import LLM, SamplingParams
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.parser import *


class InferenceVLLM:
    """Wrap a vLLM engine and its Hugging Face tokenizer for chat-style prompting.

    Attributes read from ``args``: ``model_name_or_path``, ``model_cache_dir``,
    ``gpu_memory_utilization``, ``system_prompt``, ``question_format`` and the
    optional ``tensor_parallel_size`` (defaults to 2 when absent).
    """

    def __init__(self, args):
        self.args = args

        # Resolve the model snapshot once and load BOTH the tokenizer and the
        # engine from that local directory, so they always come from the same
        # revision of the repository.
        # NOTE(review): nothing here forces offline mode; snapshot_download may
        # still contact the Hub unless HF_HUB_OFFLINE=1 is set in the environment.
        model_dir = snapshot_download(args.model_name_or_path, cache_dir=args.model_cache_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            padding_side="left",
        )

        self.llm = LLM(
            model=model_dir,
            gpu_memory_utilization=args.gpu_memory_utilization,
            # Defaults to 2; should not exceed the number of visible GPUs.
            tensor_parallel_size=getattr(args, 'tensor_parallel_size', 2),
        )

    def process_prompt(self, question, step=None):
        """Render the chat prompt for ``question``.

        Without ``step``, return the system + user messages rendered with the
        generation prompt appended, ready for free-form generation.

        With ``step``, append that partial reasoning text after the generation
        prompt, then a closing ``</think>`` and an opened ``\\boxed{``. This forces
        the model to emit its answer for the given reasoning prefix immediately.
        """
        messages = [
            {"role": "system", "content": self.args.system_prompt},
            {"role": "user", "content": self.args.question_format.format(question=question)},
        ]
        # enable_thinking is passed through to the chat template as an extra
        # variable; presumably only templates with a thinking toggle use it.
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
        if step is not None:
            prompt_text += step
            prompt_text += ' </think> The final answer is: \\boxed{'
        return prompt_text

    def inference_batch(self, prompts, **kwargs):
        """Generate completions for ``prompts``.

        ``kwargs`` are forwarded unchanged to ``SamplingParams`` (e.g. ``n``,
        ``max_tokens``, ``temperature``, ``stop``). Returns whatever
        ``LLM.generate`` returns for the batch.
        """
        sampling_params = SamplingParams(**kwargs)
        return self.llm.generate(prompts, sampling_params)


class Dataset:
    """Examples loaded from a JSONL file, with generation results attached in place.

    Every example is a dict. After loading, each example has an ``idx`` key
    (assigned from file order when the file does not provide one). If the file
    uses one of the aliases ``problem``/``Question``/``input`` instead of
    ``question``, that field is renamed to ``question``. ``self.data`` is sorted
    by ``idx``.

    NOTE(review): ``json`` and ``extract_answer`` are not imported explicitly in
    this file; presumably they come from ``from utils.parser import *``.
    """

    # Alternative field names renamed to 'question' on load (first match wins).
    QUESTION_ALIASES = ("problem", "Question", "input")

    def __init__(self, filename):
        self.filename = filename
        self.data = self.load_data()

    @staticmethod
    def _read_jsonl(path):
        """Return the JSON objects stored one per line in ``path``, skipping blank lines."""
        with open(path, 'r', encoding='utf-8') as file:
            return [json.loads(line) for line in file if line.strip()]

    def load_data(self):
        """Read ``self.filename`` and normalise the examples.

        An empty file yields an empty list. Otherwise:
        - if the first example has no ``idx``, examples are numbered by file order;
        - examples are sorted by ``idx``;
        - if the first example has no ``question``, the first alias in
          ``QUESTION_ALIASES`` present in it is renamed in every example.
        """
        data = self._read_jsonl(self.filename)
        if not data:
            # Nothing to index or rename (and data[0] would raise).
            return data

        if 'idx' not in data[0]:
            data = [{'idx': i, **example} for i, example in enumerate(data)]
        data = sorted(data, key=lambda x: x['idx'])

        if 'question' not in data[0]:
            for key in self.QUESTION_ALIASES:
                if key in data[0]:
                    for example in data:
                        example['question'] = example.pop(key)
                    break
        return data

    def get_elements(self, key):
        """Return ``str(example[key]).strip()`` for every example, in data order."""
        return [str(example[key]).strip() for example in self.data]

    def __len__(self):
        return len(self.data)

    def _positions_by_idx(self):
        """Map each example's ``idx`` value to its position in ``self.data``."""
        return {example['idx']: pos for pos, example in enumerate(self.data)}

    def save_all_outputs(self, outputs, tokenizer):
        """Attach full-generation results to the examples.

        ``outputs[i]`` is taken to belong to ``self.data[i]``; the caller builds
        the prompts from ``get_elements('question')``, so the order matches.
        Each completion is split into reasoning steps on blank lines (``'\\n\\n'``).
        Per-step answer slots are initialised here and filled by
        ``save_step_wise_outputs``.
        """
        for i, output in enumerate(outputs):
            example = self.data[i]
            completions = output.outputs
            responses = [c.text for c in completions]
            steps = [text.split('\n\n') for text in responses]

            example['generated_response'] = responses
            example['generated_finished_reason'] = [c.finish_reason for c in completions]
            example['generated_answer'] = [extract_answer(text, '--') for text in responses]
            example['generated_tokens'] = [len(c.token_ids) for c in completions]
            example['step_response'] = steps
            example['step_tokens'] = [
                [len(tokenizer.tokenize(step)) for step in step_list] for step_list in steps
            ]
            # Placeholders: one slot per (sample, step), filled in later.
            example['step_answer'] = [[""] * len(sample) for sample in steps]
            example['step_answer_response'] = [[""] * len(sample) for sample in steps]
            example['step_answer_tokens'] = [[0] * len(sample) for sample in steps]

    def construct_step_wise_prompt(self):
        """Build one record per (example, sample, step) prefix.

        Each record holds ``idx``, ``n_idx`` (sample index), ``step_idx``,
        ``question``, and ``step``. ``step`` is steps ``0..step_idx`` of that
        sample, re-joined with ``'\\n\\n'`` and stripped. Examples without
        ``step_response`` are skipped.
        """
        prompts = []
        for example in self.data:
            if 'step_response' not in example:
                continue
            for n, sample in enumerate(example['step_response']):
                for i, _ in enumerate(sample):
                    prompts.append({
                        'idx': example['idx'],
                        'n_idx': n,
                        'step_idx': i,
                        'question': example['question'],
                        'step': '\n\n'.join(sample[:i + 1]).strip(),
                    })
        return prompts

    def save_step_wise_outputs(self, outputs, constructed_step_wise_prompts):
        """Store step-wise answer generations into the matching per-step slots.

        ``outputs[i]`` must correspond to ``constructed_step_wise_prompts[i]``.
        The target example is located by its ``idx`` value, not by assuming
        ``idx`` equals its list position. This matters when a file supplies
        non-contiguous or 1-based indices.
        """
        positions = self._positions_by_idx()
        for i, output in enumerate(outputs):
            meta = constructed_step_wise_prompts[i]
            example = self.data[positions[meta['idx']]]
            sample_idx = meta['n_idx']
            step_idx = meta['step_idx']
            completion = output.outputs[0]
            example['step_answer_response'][sample_idx][step_idx] = completion.text
            example['step_answer'][sample_idx][step_idx] = extract_answer(completion.text, '--')
            example['step_answer_tokens'][sample_idx][step_idx] = len(completion.token_ids)

    def save_data(self, output_file):
        """Write ``self.data`` to ``output_file`` as JSONL (one example per line)."""
        with open(output_file, 'w', encoding='utf-8') as file:
            file.writelines(json.dumps(example) + '\n' for example in self.data)

    def load_checkpoint(self, checkpoint_file):
        """Replace ``self.data`` with the examples stored in a JSONL checkpoint."""
        self.data = self._read_jsonl(checkpoint_file)


if __name__ == '__main__':
    # Experiment configuration. Inline comments list alternative models /
    # system prompts this script has been used with.
    args = Namespace(
        model_tag='QwQ-32B',
        model_name_or_path='Qwen/QwQ-32B',  # Qwen/QwQ-32B; deepseek-ai/DeepSeek-R1-Distill-Qwen-32B; deepseek-ai/DeepSeek-R1-Distill-Llama-8B
        model_cache_dir='/data/huggingface',
        system_prompt="You are a helpful and harmless assistant. You are Qwen developed by Alibaba. "  # You are Qwen developed by Alibaba. ;You are Deepseek developed by Deepseek. ;
                      "You should think step-by-step and put your final answer within \\boxed{}. ",
        question_format="{question}",
        data_set='gpqa',  # overwritten for each dataset in the loop below
        max_tokens=32768,
        temperature=0.6,
        n=10,
        top_p=0.9,  # only for QwQ32B
        top_k=20,  # only for QwQ32B
        gpu_memory_utilization=0.9,  # for LLM
        tensor_parallel_size=2,  # should not exceed the GPUs listed in CUDA_VISIBLE_DEVICES
        step_batch_size=10000,  # for step-wise inference
    )

    inference = InferenceVLLM(args)

    for data_set_name in ['aime', 'gpqa', 'math', 'minerva', 'olympiadbench']:
        print(data_set_name)
        args.data_set = data_set_name

        data_dir = Path('/home/project/Reasoning/data/datasets') / args.data_set / 'test.jsonl'
        result_dir = Path('/data/project/Reasoning/results') / args.model_tag / args.data_set
        # Create the result directory, including any missing parents.
        result_dir.mkdir(parents=True, exist_ok=True)

        dataset = Dataset(data_dir)

        # Stage 1: full responses (args.n samples per question), cached on disk.
        main_results_path = result_dir / f'{args.data_set}_main_results.jsonl'
        if main_results_path.exists():
            print(f"Results already exist at {main_results_path}. Skipping generation.")
            dataset.load_checkpoint(main_results_path)
        else:
            questions = dataset.get_elements('question')
            prompts = [inference.process_prompt(question) for question in questions]
            generated_outputs = inference.inference_batch(
                prompts,
                n=args.n,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                top_k=args.top_k,  # was configured above but previously never applied
            )
            dataset.save_all_outputs(generated_outputs, inference.tokenizer)
            dataset.save_data(main_results_path)

        # Stage 2: force an answer after every reasoning-step prefix, cached on disk.
        step_results_path = result_dir / f'{args.data_set}_step_results.jsonl'
        if step_results_path.exists():
            print(f"Step-wise results already exist at {step_results_path}. Skipping generation.")
            dataset.load_checkpoint(step_results_path)
        else:
            constructed_step_wise_prompts = dataset.construct_step_wise_prompt()
            step_wise_prompts = [
                inference.process_prompt(prompt['question'], prompt['step'])
                for prompt in constructed_step_wise_prompts
            ]
            print(f"Generating step-wise responses for {len(step_wise_prompts)} prompts...")
            for start in tqdm(range(0, len(step_wise_prompts), args.step_batch_size)):
                end = start + args.step_batch_size
                generated_outputs = inference.inference_batch(
                    step_wise_prompts[start:end],
                    # The prompt already ends with an opened \boxed{, so a few
                    # tokens up to the first "}" cover short answers.
                    # NOTE(review): stopping at the first "}" truncates answers that
                    # contain braces themselves (e.g. \frac{1}{2}). Also, no
                    # temperature is passed here, so the SamplingParams default
                    # applies rather than args.temperature. Confirm both are intended.
                    max_tokens=10,
                    stop=['}'],
                )
                dataset.save_step_wise_outputs(generated_outputs, constructed_step_wise_prompts[start:end])
            dataset.save_data(step_results_path)
