import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os
import random

def generate_output(tokenizer, model, input_texts, max_new_tokens=10000):
    """Greedily generate a chat-formatted answer for each prompt in one batch.

    Each prompt is wrapped as a single user turn and rendered with the
    tokenizer's chat template (with the assistant generation prompt appended).
    All prompts are then padded into one batch and decoded greedily.

    Args:
        tokenizer: Hugging Face tokenizer providing ``apply_chat_template``.
            This script loads it with ``padding_side='left'``, which batched
            decoder-only generation expects.
        model: causal LM exposing ``generate`` and ``parameters``.
        input_texts: list of user prompt strings.
        max_new_tokens: upper bound on generated tokens per prompt
            (default 10000, the value previously hard-coded).

    Returns:
        A pair ``(outputs, lengths)``:
        - ``outputs`` holds the decoded completions, with special tokens
          stripped.
        - ``lengths`` holds, for each completion, the token count obtained by
          re-encoding that decoded text without special tokens.
    """
    if not input_texts:
        return [], []

    device = next(model.parameters()).device

    conversations = [[{"role": "user", "content": prompt}] for prompt in input_texts]
    rendered = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(rendered, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )

    # generate() appends new tokens after the (padded) prompt, so every row's
    # continuation starts at the same offset: the padded prompt length.
    prompt_len = inputs["input_ids"].shape[1]
    all_outputs = []
    all_lengths = []
    for sequence in outputs:
        result = tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
        all_outputs.append(result)
        all_lengths.append(len(tokenizer.encode(result, add_special_tokens=False)))
    return all_outputs, all_lengths


def generate_direct_output(tokenizer, model, input_texts, temperature=None, do_sample=False,
                           num_return_sequences=16, max_new_tokens=7):
    """Generate short continuations for each fully pre-formatted prompt.

    Prompts are tokenized one at a time as raw text; no chat template is
    applied, so callers must pass complete prompts.

    Args:
        tokenizer: Hugging Face tokenizer.
        model: causal LM exposing ``generate`` and ``parameters``.
        input_texts: list of fully formatted prompt strings.
        temperature: sampling temperature. It is only forwarded when
            ``do_sample`` is True; otherwise ``None`` is passed.
        do_sample: sample when True, decode greedily when False.
        num_return_sequences: continuations returned per prompt when sampling.
            Greedy decoding is deterministic, and transformers raises a
            ValueError for ``num_return_sequences > 1`` without sampling or
            beam search. Exactly one sequence is therefore produced when
            ``do_sample`` is False.
        max_new_tokens: generation budget per continuation (default 7).

    Returns:
        A pair ``(batch_outputs, all_lengths)``:
        - ``batch_outputs[i]`` is the list of decoded continuations (special
          tokens stripped) for ``input_texts[i]``.
        - ``all_lengths[i]`` holds the matching token counts, obtained by
          re-encoding each decoded continuation without special tokens.
    """
    device = next(model.parameters()).device
    # Greedy search can only return a single sequence per prompt.
    n_sequences = num_return_sequences if do_sample else 1

    batch_outputs = []
    all_lengths = []
    for input_text in input_texts:
        inputs = tokenizer(input_text, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                num_return_sequences=n_sequences,
            )

        input_len = inputs["input_ids"].shape[1]
        generated_texts = []
        lengths = []
        for out in outputs:
            result = tokenizer.decode(out[input_len:], skip_special_tokens=True)
            generated_texts.append(result)
            lengths.append(len(tokenizer.encode(result, add_special_tokens=False)))

        batch_outputs.append(generated_texts)
        all_lengths.append(lengths)

    return batch_outputs, all_lengths

model_path = "path/to/model/QwQ-32B"
print(f"正在加载模型: {model_path} ...")  # prints "Loading model: <path> ..."

# Pad on the left so each batched prompt ends right where generation begins.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

# bf16 weights, placed via device_map="auto", read only from the local path.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    local_files_only=True,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Questions are read from input_directory; annotated copies are written to output_dir.
input_directory = "path/to/LongCoT/Knowlogic/finaldata"
output_dir = "path/to/LongCoT/Knowlogic/qwq_results"
os.makedirs(output_dir, exist_ok=True)  # no-op when the directory already exists

# Dataset files to process; uncomment an entry to include it in the run.
file_paths = [
    # "nature-cn.json",
    # "social-cn.json",
    # "space-cn.json",
    "space+nature-cn.json",
    # "time-cn.json",
]

# "Direct answer" prompt formats, in Qwen chat markup. Each one:
# - pre-fills the assistant turn with an already-closed <think> block whose
#   text tells the model to skip reasoning;
# - opens the answer with "答案是：【" ("The answer is: [").
# {origin_text} is later filled with the question via str.format.
formats = [
    "<|im_start|>user\n{origin_text}<|im_end|>\n<|im_start|>assistant\n<think>\n"
    + thought
    + "</think>\n答案是：【"
    for thought in (
        # "Let me answer him directly, without a thinking process."
        "让我直接回答他，不要有思考过程。",
        # "I will answer directly now. I should not output a thinking process."
        "我现在直接进行回答。我不应该输出思考过程。",
        # "I will answer the question directly with a single number; no thinking needed."
        "我将会直接用一个数字回答问题，不需要思考。",
        # "I should not think; I will answer the question directly."
        "我不应该思考，我直接回答该问题。",
    )
]

# Sampling temperatures for the direct answers (run in addition to a greedy pass).
temperatures = [0.5, 1.0, 2.0]

# Number of questions per generation batch.
batch_size = 4
for file in file_paths:
    print(file)
    with open(os.path.join(input_directory, file), 'r', encoding='utf-8') as f:
        data = json.load(f)

    for batch_start in tqdm(range(0, len(data), batch_size), desc="Processing batches"):
        # The slice copies the list but not the item dicts, so annotations
        # written to batch_data also land in `data`, which is what gets dumped.
        batch_data = data[batch_start:batch_start + batch_size]
        questions = [item['question'] for item in batch_data]

        # Free-form greedy answer through the chat template.
        output_strs, output_lengths = generate_output(tokenizer, model, questions)
        for item, answer, length in zip(batch_data, output_strs, output_lengths):
            item['model_answer'] = answer
            item['model_answer_length'] = length

        # Direct answers: every question rendered with every format. The
        # prompts do not depend on the decoding mode, so build them once.
        direct_prompts = [
            input_format.format(origin_text=item["question"])
            for item in batch_data
            for input_format in formats
        ]
        n_formats = len(formats)

        for sample_type in ["greedy"] + temperatures:
            if sample_type == "greedy":
                outputs, _ = generate_direct_output(tokenizer, model, direct_prompts, do_sample=False)
            else:
                outputs, _ = generate_direct_output(tokenizer, model, direct_prompts, temperature=sample_type, do_sample=True)

            # Outputs are question-major: all formats for question 0, then question 1, ...
            for i, item in enumerate(batch_data):
                item.setdefault("direct_answers", []).extend(outputs[i * n_formats:(i + 1) * n_formats])

        # Rewrite the whole output file after every batch, so results from
        # completed batches are already on disk if the run stops early.
        with open(os.path.join(output_dir, file), 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=4)


