import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os


def generate_output(tokenizer, model, input_texts, temperature=None, do_sample=False,
                    num_return_sequences=16, max_new_tokens=7):
    """Generate short continuations for each prompt in ``input_texts``.

    Prompts are tokenized and passed to ``model.generate`` one at a time
    (no padding or batching across prompts).

    Args:
        tokenizer: tokenizer callable as ``tokenizer(text, return_tensors="pt")``
            that also exposes ``decode`` and ``eos_token_id``.
        model: causal LM exposing ``parameters()`` and ``generate(...)``.
        input_texts: iterable of prompt strings.
        temperature: sampling temperature; forwarded only when ``do_sample``
            is true (``None`` otherwise).
        do_sample: sample when true, decode greedily when false.
        num_return_sequences: completions requested per prompt when sampling.
            Greedy decoding is deterministic and ``generate`` rejects
            ``num_return_sequences != 1`` for it, so exactly one completion
            is requested per prompt in that case.
        max_new_tokens: cap on the number of generated tokens.

    Returns:
        ``(batch_outputs, all_lengths)``. ``batch_outputs[i]`` is the list of
        decoded completions for ``input_texts[i]``, with the prompt tokens
        removed and special tokens skipped. ``all_lengths`` is always an empty
        list; it is kept only so callers that unpack two values keep working.
    """
    device = next(model.parameters()).device
    # Greedy search cannot produce several distinct sequences and transformers
    # raises a ValueError if asked to, so only request extras when sampling.
    n_sequences = num_return_sequences if do_sample else 1

    batch_outputs = []
    all_lengths = []
    for input_text in input_texts:
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        input_len = inputs["input_ids"].shape[1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            num_return_sequences=n_sequences,
        )

        # generate() returns prompt + continuation; keep only the new tokens.
        batch_outputs.append([
            tokenizer.decode(seq[input_len:], skip_special_tokens=True)
            for seq in outputs
        ])

    return batch_outputs, all_lengths


# Load the tokenizer and model from local files only (no download).
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
print(f"正在加载模型: {model_path} ...")

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left'
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True
)

# NOTE(review): not read elsewhere in this file; generate_output() takes the
# device from the model's own parameters instead.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Input results are read from results_dir; augmented copies go to output_dir.
results_dir = "path/to/LongCoT/CharCount/results/qwen_zh_results"
output_dir = "path/to/LongCoT/CharCount/qwen_zh_results_with_direct"
# exist_ok=True replaces the check-then-create pair, which could race.
os.makedirs(output_dir, exist_ok=True)

# Candidate inputs: files named results*.json.
result_files = [f for f in os.listdir(results_dir) if f.startswith("results") and f.endswith(".json")]
def get_checkpoint_number(folder):
    """Return the integer embedded in a results file name, used as a sort key.

    The name is cut at its first '.', and whatever follows the last
    occurrence of "results" in that stem is parsed as an int
    (e.g. "results12.json" -> 12). Names whose remainder is not an integer
    (e.g. "results.json") map to 0.
    """
    stem = folder.partition('.')[0]
    number_text = stem.rpartition("results")[2]
    try:
        return int(number_text)
    except ValueError:
        return 0
result_files = sorted(result_files, key=get_checkpoint_number)

# Slice [start_index, end_index) of the sorted files processed in this run.
start_index = 0
end_index = 85

# Prompt templates (placeholders: {origin_text}, {question}). Each one opens
# the assistant's <think> block with a short "answer directly, no reasoning"
# thought, closes it, and leaves the model to continue after "答案是："
# ("The answer is:"). All four share the same head and tail.
_PROMPT_HEAD = "<｜User｜>{origin_text}直接用一个数字回答问题。<｜Assistant｜><think>嗯，"
_PROMPT_TAIL = "</think>\n答案是："
_THOUGHTS = (
    "用户问我{question}\n让我直接回答他，不要有思考过程。",
    "用户的问题是{question}\n我现在直接进行回答。我不应该输出思考过程。",
    "用户想知道{question}\n我将会直接用一个数字回答问题，不需要思考。",
    "用户问{question}\n我不应该思考，我直接回答该问题。",
)
formats = [_PROMPT_HEAD + thought + _PROMPT_TAIL for thought in _THOUGHTS]

# Sampling temperatures tried after the greedy pass.
temperatures = [0.5]
# Number of result records handled per generation batch.
batch_size = 4

# For every selected results file: build "answer directly" prompts for each
# record, generate completions greedily and at each temperature in
# `temperatures`, append them under "direct_answers", and save the augmented
# records to output_dir under the same file name.
for file_name in tqdm(result_files[start_index:end_index], desc="Processing files"):
    results_file_path = os.path.join(results_dir, file_name)
    # Derive the output path from output_dir so the "already done" check and
    # the final write always refer to the same file (the check previously used
    # a separately hard-coded copy of the directory).
    output_file_path = os.path.join(output_dir, file_name)
    print(f"正在读取文件: {results_file_path}")
    if os.path.exists(output_file_path):
        print("RESULT ALLREADY EXISTS !!! SKIPPED !!!")
        continue

    try:
        with open(results_file_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
    except FileNotFoundError:
        print(f"文件 {results_file_path} 未找到，跳过。")
        continue

    n_formats = len(formats)
    num_batches = (len(results) + batch_size - 1) // batch_size  # ceiling division
    for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
        start = batch_idx * batch_size
        end = min(start + batch_size, len(results))
        batch_results = results[start:end]

        for sample_type in ["greedy"] + temperatures:
            # One prompt per (record, template), record-major, so record i owns
            # prompts [i * n_formats, (i + 1) * n_formats).
            batch_input_texts = []
            for result in batch_results:
                question = result["question"]
                # Keep only the text before the original instruction; each
                # template appends its own instruction.
                origin_text = question.split("直接用一个阿拉伯数字回答问题。")[0]
                for input_format in formats:
                    # NOTE(review): {question} is filled with origin_text rather
                    # than the full question; looks deliberate (keeps the
                    # instruction out of the echoed question) — confirm.
                    input_text = input_format.format(origin_text=origin_text, question=origin_text)
                    batch_input_texts.append(input_text)

            if sample_type == "greedy":
                outputs, _ = generate_output(tokenizer, model, batch_input_texts, do_sample=False)
            else:
                outputs, _ = generate_output(tokenizer, model, batch_input_texts, temperature=sample_type, do_sample=True)

            # Append this pass's per-template completion lists to each record.
            # Across passes they accumulate in order: greedy first, then one
            # group of n_formats lists per temperature.
            for i, result in enumerate(batch_results):
                start_idx = i * n_formats
                result.setdefault("direct_answers", []).extend(outputs[start_idx:start_idx + n_formats])

        for result in batch_results:
            # Drop the singular "direct_answer" key if present (presumably a
            # legacy field — confirm).
            result.pop("direct_answer", None)

    # Write to a temporary file, then atomically move it into place. If the
    # run is interrupted mid-dump, no truncated file is left at
    # output_file_path for the existence check above to mistake for a
    # finished result.
    tmp_path = output_file_path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, output_file_path)
    