import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os
import random

def generate_output(tokenizer, model, input_texts, max_new_tokens=20000):
    """Generate one chat-formatted completion per prompt in a single batch.

    Every prompt is wrapped as a single-turn ``user`` message, rendered to
    text with the tokenizer's chat template (generation prompt appended),
    tokenized with padding, and passed to ``model.generate``.

    No ``do_sample`` / ``temperature`` is passed here, so the decoding
    strategy is whatever ``model.generate`` uses by default for this model.

    Args:
        tokenizer: Object providing ``apply_chat_template``, ``__call__``,
            ``decode``, ``encode`` and ``eos_token_id``.
        model: Object providing ``parameters()`` and ``generate``.
        input_texts: Sequence of user prompt strings.
        max_new_tokens: Upper bound on newly generated tokens per prompt.

    Returns:
        ``(texts, lengths)``: two lists aligned with ``input_texts``.
        ``texts[i]`` is the decoded continuation with special tokens
        skipped; ``lengths[i]`` is the token count obtained by re-encoding
        that decoded text without special tokens.
    """
    if not input_texts:
        # Nothing to generate; avoid handing an empty batch to the tokenizer.
        return [], []

    device = next(model.parameters()).device

    conversations = [
        [{"role": "user", "content": prompt}]
        for prompt in input_texts
    ]

    rendered = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True
    )

    # The chat template has already inserted every special token it needs;
    # letting the tokenizer add its own again could duplicate them (e.g. BOS).
    inputs = tokenizer(
        rendered,
        return_tensors="pt",
        padding=True,
        add_special_tokens=False
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id
        )

    all_outputs = []
    all_lengths = []
    for prompt_ids, sequence in zip(inputs['input_ids'], outputs):
        # Each returned sequence starts with the (padded) prompt; drop it.
        completion_ids = sequence[len(prompt_ids):]
        completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
        # Length is measured on the decoded text rather than on the raw ids,
        # which presumably keeps trailing EOS/pad positions out of the count.
        all_outputs.append(completion)
        all_lengths.append(len(tokenizer.encode(completion, add_special_tokens=False)))
    return all_outputs, all_lengths


def generate_direct_output(tokenizer, model, input_texts, temperature=None, do_sample=False, max_new_tokens=10):
    """Generate a short continuation for each already-formatted prompt.

    Unlike ``generate_output``, the prompts are tokenized as given (no chat
    template is applied), so callers are expected to pass fully formatted
    text.

    Args:
        tokenizer: Object providing ``__call__``, ``decode`` and
            ``eos_token_id``.
        model: Object providing ``parameters()`` and ``generate``.
        input_texts: Sequence of prompt strings.
        temperature: Sampling temperature; only forwarded when
            ``do_sample`` is true and the value is not ``None``.
        do_sample: When true, sampling is explicitly enabled; otherwise
            sampling is explicitly disabled.
        max_new_tokens: Upper bound on newly generated tokens per prompt.

    Returns:
        ``(texts, lengths)``: two lists aligned with ``input_texts``.
        ``texts[i]`` is the decoded continuation with special tokens
        skipped; ``lengths[i]`` is the number of token positions returned
        after the prompt (NOTE(review): in a batch this may include
        trailing pad/EOS positions -- confirm if exact lengths matter).
    """
    if not input_texts:
        # Nothing to generate; avoid handing an empty batch to the tokenizer.
        return [], []

    device = next(model.parameters()).device
    inputs = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True
    ).to(device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        # Always state the decoding mode explicitly. Previously the sampling
        # branch passed only ``temperature``, so whether sampling actually
        # happened depended on the model's default generation settings.
        "do_sample": bool(do_sample),
    }
    if do_sample and temperature is not None:
        generation_kwargs["temperature"] = temperature

    outputs = model.generate(**inputs, **generation_kwargs)

    all_outputs = []
    all_lengths = []
    for prompt_ids, sequence in zip(inputs['input_ids'], outputs):
        # Each returned sequence starts with the (padded) prompt; drop it.
        completion_ids = sequence[len(prompt_ids):]
        all_outputs.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
        all_lengths.append(len(completion_ids))
    return all_outputs, all_lengths

# ---------------------------------------------------------------------------
# Driver script: for every question in each input JSONL file, collect one
# full chat-template answer plus a set of short "direct" answers (one per
# prompt format, for greedy decoding and for each sampling temperature), and
# write the result to <output_dir>/<question_id>.json.
# ---------------------------------------------------------------------------
model_path = "path/to/model/QwQ-32B"
print(f"正在加载模型: {model_path} ...")

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    # The generation helpers strip the prompt by slicing off the first
    # len(input_ids) positions of each output, which relies on padding being
    # placed on the left of the prompt.
    padding_side='left'
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True
)

# JSONL files to process; each line is expected to be a JSON object with a
# "question" field.
input_directorys = [
    "path/to/LongCoT/AIME/aime2025-I.jsonl",
    # "path/to/LongCoT/AIME/aime2025-II.jsonl"
]
output_dir = "path/to/LongCoT/AIME/qwq_aime_2025_answers"
os.makedirs(output_dir, exist_ok=True)

# Hand-written prompts that pre-fill a closed <think> block so the model is
# pushed to emit an answer immediately; {origin_text} is the question.
formats = [
    "<|im_start|>user\n{origin_text}<|im_end|>\n<|im_start|>assistant\n<think>\nLet me answer him without thinking more.</think>\nAnswer: ",
    "<|im_start|>user\n{origin_text}<|im_end|>\n<|im_start|>assistant\n<think>\nI will answer directly. I won't output any thinking process.</think>\nAnswer: ",
    "<|im_start|>user\n{origin_text}<|im_end|>\n<|im_start|>assistant\n<think>\nI will answer with a single number.</think>\nThe answer is: ",
    "<|im_start|>user\n{origin_text}<|im_end|>\n<|im_start|>assistant\n<think>\nI should not think, but should answer directly.</think>\nThe answer is: "
]

# Sampling temperatures used for the direct answers, in addition to greedy.
temperatures = [0.5, 1.0, 2.0]

for input_directory in input_directorys:
    # File name up to its first dot, e.g. ".../aime2025-I.jsonl" -> "aime2025-I".
    dataset_name = input_directory.split("/")[-1].split(".")[0]
    idx = 0
    with open(input_directory, 'r', encoding='utf-8') as infile:
        for line in tqdm(infile):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) without
                # consuming a question number.
                continue
            idx += 1
            question_id = f"{dataset_name}-{idx}"
            output_path = os.path.join(output_dir, f"{question_id}.json")

            # Resume support: skip questions whose result file already exists.
            if os.path.exists(output_path):
                continue

            data = json.loads(line)
            problem = data.get("question")
            if problem is None:
                print(question_id, "has no 'question' field, skipped")
                continue
            print("SOLVING", question_id)

            # Full answer through the chat template.
            output_strs, output_length = generate_output(tokenizer, model, [problem])
            data["ModelAnswer"] = output_strs[0]
            data["AnswerLength"] = output_length[0]
            data["direct_answers"] = []

            # Direct answers: the same batch of formatted prompts is reused
            # for greedy decoding and for every sampling temperature, so
            # build it once per question.
            batch_input_texts = [
                input_format.format(origin_text=problem)
                for input_format in formats
            ]
            for sample_type in ["greedy"] + temperatures:
                if sample_type == "greedy":
                    outputs, _ = generate_direct_output(tokenizer, model, batch_input_texts, do_sample=False)
                else:
                    outputs, _ = generate_direct_output(tokenizer, model, batch_input_texts, temperature=sample_type, do_sample=True)
                data["direct_answers"].extend(outputs)

            # Write to a temporary file and move it into place, so an
            # interrupted run never leaves a truncated result file that the
            # resume check above would treat as finished.
            tmp_path = output_path + ".tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)
            os.replace(tmp_path, output_path)

            print(question_id, "DONE!!!")

