import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os
import random

def generate_output(tokenizer, model, input_texts):
    """Generate one chat completion per prompt in a single batched call.

    Each prompt is wrapped as a single-turn user message, rendered through
    the tokenizer's chat template with a generation prompt appended,
    tokenized with padding, and passed to ``model.generate`` with gradient
    tracking disabled.

    Args:
        tokenizer: Object providing ``apply_chat_template``, ``__call__``,
            ``decode``, ``encode`` and ``eos_token_id``. The prompt is
            stripped by slicing off the first ``prompt_len`` positions of
            each generated row, which presumably relies on the tokenizer
            being configured for left padding -- confirm at the call site.
        model: Object providing ``parameters()`` and ``generate()``; inputs
            are moved to the device of its first parameter.
        input_texts: List of user prompt strings.

    Returns:
        A tuple ``(all_outputs, all_lengths)``. ``all_outputs`` holds the
        decoded completion for each prompt (special tokens skipped);
        ``all_lengths`` holds the number of tokens obtained by re-encoding
        each decoded completion without special tokens. Both are empty
        lists when ``input_texts`` is empty.
    """
    # Nothing to generate: return early instead of handing an empty batch
    # to the tokenizer/model.
    if not input_texts:
        return [], []

    device = next(model.parameters()).device

    input_messages = [
        [{"role": "user", "content": prompt}]
        for prompt in input_texts
    ]

    text = tokenizer.apply_chat_template(
        input_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=4096, pad_token_id=tokenizer.eos_token_id)

    # The batch is padded to a common length, so every row shares the same
    # prompt length; everything after it is newly generated.
    prompt_len = inputs['input_ids'].shape[1]

    all_outputs = []
    all_lengths = []
    for out in outputs:
        result = tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
        # Length is measured on the decoded text (re-encoded), so padding and
        # other special tokens in the raw output row are not counted.
        length = len(tokenizer.encode(result, add_special_tokens=False))
        all_outputs.append(result)
        all_lengths.append(length)
    return all_outputs, all_lengths


# Checkpoint directory; the weights are loaded from disk only
# (local_files_only=True below).
model_path = "path/to/model/QwQ-32B"
print(f"正在加载模型: {model_path} ...")

# The tokenizer pads on the left; generate_output() strips the prompt by
# slicing off the leading positions of each generated row.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

# bfloat16 weights, with placement across available devices delegated to
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    local_files_only=True,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# 71-86
# 106-119
# 139-149
# 157-170
# Driver: for each word file in [base_idx, base_idx + 100), ask the model how
# many times a letter occurs in a word and store the answers as JSON. If a
# results file already exists for an index, its records are kept and only the
# words not yet present in it are processed.
base_idx = 43
batch_size = 4
# Prompt template (loop-invariant). The word is recovered from a stored
# question by splitting on the text that follows "{word}" in this template.
origin_text_format = "{word}这个单词里面有几个字母{ch}？直接用一个阿拉伯数字回答问题。"

for idx in range(base_idx, base_idx + 100):
    words_path = f"path/to/LongCoT/CharCount/words/word{idx}.json"
    results_path = f"path/to/LongCoT/CharCount/results/QwQ_zh_results/results{idx}.json"
    print(words_path)

    # Reset per-file state on every iteration. Without this, an index with no
    # existing results file either hits a NameError (first iteration) or
    # silently reuses the previous index's records and processed-word list.
    results = []
    allready_words = set()
    if os.path.exists(results_path):
        print("RESULT ALLREADY EXISTS !!!")
        with open(results_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
        allready_words = {item['question'].split("这个单词里面")[0] for item in results}

    with open(words_path, 'r', encoding='utf-8') as f:
        words = json.load(f)

    # Each entry is indexed as [word, letter, expected_count]; skip words that
    # already have a stored result.
    # NOTE(review): de-duplication is by word only, not by (word, letter) --
    # confirm that a word never appears with two different letters.
    words = [word for word in words if word[0] not in allready_words]

    for start_idx in tqdm(range(0, len(words), batch_size), desc="Processing batches"):
        batch_words = words[start_idx:start_idx + batch_size]
        input_texts = [
            origin_text_format.format(word=word[0], ch=word[1])
            for word in batch_words
        ]

        output_strs, output_length = generate_output(tokenizer, model, input_texts)

        for word, question, answer, length in zip(batch_words, input_texts, output_strs, output_length):
            results.append({
                "question": question,
                "correct_answer": word[2],
                "model_answer": answer,
                "model_answer_length": length,
            })

    # Results are written once per word file, after all of its batches finish.
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)