import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os
from tqdm import tqdm
import re

def extract_last_number(s):
    """Return the last number that appears in the text ``s``.

    Commas sitting between two digits are stripped first so that thousands
    separators (``"1,234"``) read as a single number. The last match is
    returned as an ``int`` when it has no decimal point and as a ``float``
    otherwise. Minus signs are not captured. When ``s`` contains no digits
    at all, ``0.0`` is returned.
    """
    cleaned = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", s)
    matches = re.findall(r"(\d+(\.\d+)?)", cleaned)
    if not matches:
        return 0.0
    # Each findall item is a (whole_match, fractional_part) tuple; take the whole match.
    last = matches[-1][0]
    return float(last) if "." in last else int(last)


def generate_output(tokenizer, model, input_texts, max_new_tokens=20000):
    """Generate continuations for a batch of prompts without sampling.

    Args:
        tokenizer: Hugging Face tokenizer used to encode the prompts and decode
            the results. Its ``eos_token_id`` doubles as the padding id.
        model: causal LM exposing ``parameters()`` and ``generate()``.
        input_texts: list of prompt strings, encoded together with padding.
        max_new_tokens: cap on the number of generated tokens per prompt
            (default 20000, the previously hard-coded value).

    Returns:
        ``(texts, lengths)`` where ``texts[i]`` is the decoded continuation of
        prompt ``i`` (special tokens removed). ``lengths[i]`` is the number of
        generated tokens up to and including the first EOS, or all generated
        tokens when no EOS was produced.
    """
    device = next(model.parameters()).device
    inputs = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
    ).to(device)
    eos_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=eos_id,
            do_sample=False,
        )

    # Every prompt row has the same padded width, so each output row's
    # continuation starts at that column. NOTE(review): this assumes left
    # padding, as configured where the tokenizer is loaded.
    prompt_width = inputs["input_ids"].shape[-1]

    all_outputs = []
    all_lengths = []
    for out in outputs:
        generated = out[prompt_width:]
        all_outputs.append(tokenizer.decode(generated, skip_special_tokens=True))

        # Rows that finish early are filled with pad (== EOS) tokens up to the
        # batch's longest row. Count only through the first EOS so padding is
        # not reported as generated length.
        length = len(generated)
        if eos_id is not None:
            eos_hits = (generated == eos_id).nonzero(as_tuple=True)[0]
            if eos_hits.numel() > 0:
                length = int(eos_hits[0]) + 1
        all_lengths.append(length)
    return all_outputs, all_lengths

# Answer files (relative to the input directory) that process_jsonl works on.
# Only AIME 2024 II problems 6-15 are selected. AIME 2024 I problems 1-15 and
# AIME 2024 II problems 1-5 are excluded; to include them, widen the ranges,
# e.g. add [f"2024-I-{i}.json" for i in range(1, 16)].
files = [f"2024-II-{i}.json" for i in range(6, 16)]


def process_jsonl(input_directory, output_dir, tokenizer, model, file_names=None):
    """Re-generate each stored answer from its chain of thought cut at the first "Wait,".

    For every name in ``file_names`` this function does the following:
    - it loads the JSON record from ``input_directory``;
    - it takes the text of ``ModelAnswer`` before its first ``"Wait,"``;
    - it asks the model (via ``generate_output``) to continue from
      ``<｜User｜><｜Assistant｜><think>{prefix}``;
    - it stores the continuation as ``MaskedAnswer`` and the combined
      prefix + continuation token count as ``MaskedLength``;
    - it writes the augmented record to ``<output_dir>/<ID>.json``.

    Args:
        input_directory: directory holding the answer JSON files.
        output_dir: destination directory. It is created if it does not exist.
        tokenizer: tokenizer passed through to ``generate_output``; it is also
            used to count prefix tokens.
        model: causal LM passed through to ``generate_output``.
        file_names: file names to process. Defaults to the module-level
            ``files`` list.
    """
    if file_names is None:
        file_names = files

    # Create the destination up front. Otherwise a missing directory is only
    # discovered after a long generation has already run.
    os.makedirs(output_dir, exist_ok=True)

    for file_name in file_names:
        input_file = os.path.join(input_directory, file_name)
        # Read and close the input before generating, which can take a long time.
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        question_id = data.get("ID")
        if question_id is None:
            # Fall back to the file stem so records without an ID do not all
            # overwrite one another in "None.json".
            question_id = os.path.splitext(file_name)[0]

        answer = data.get("ModelAnswer")
        if answer is None:
            print(f"跳过 {file_name}：缺少 ModelAnswer 字段")
            continue
        model_short_answer = extract_last_number(answer)

        splitted_answer = answer.split("Wait,")
        # Diagnostic only: report when a segment before the final one already
        # ends with the extracted final answer (checked in its last 10 chars).
        for segment in splitted_answer[:-1]:
            if str(model_short_answer) in segment[-10:]:
                print(f"Short answer {model_short_answer} in thinking.")
                break

        # NOTE(review): the original code accumulated a longer prefix, up to
        # the segment containing the answer, but then unconditionally
        # overwrote it with the first segment. This keeps that effective
        # behavior. Confirm which truncation the experiment intends.
        thinking = splitted_answer[0]
        # NOTE(review): the user turn is empty (the problem statement is not
        # included). Confirm this masking is intentional.
        input_text = f"<｜User｜><｜Assistant｜><think>{thinking}"

        model_answers, lengths = generate_output(tokenizer, model, [input_text])
        masked_answer = model_answers[0]
        prefix_tokens = len(tokenizer.encode(thinking, add_special_tokens=False))
        # Cap at 20000, the same number used as the generation budget in
        # generate_output. Presumably this keeps MaskedLength comparable with
        # the original runs; confirm.
        masked_length = min(lengths[0] + prefix_tokens, 20000)

        data["MaskedAnswer"] = masked_answer
        data["MaskedLength"] = masked_length

        output_path = os.path.join(output_dir, f"{question_id}.json")
        with open(output_path, 'w', encoding='utf-8') as outfile:
            json.dump(data, outfile, ensure_ascii=False, indent=4)

        # "Processed question ID: ..., saved to: ..."
        print(f"已处理问题 ID: {question_id}，保存至: {output_path}")

# Placeholder paths: source answer files and destination for re-generations.
input_directory = "path/to/LongCoT/AIME/aime_2024_answers"
output_dir = "path/to/LongCoT/AIME/aime_2024_answers_masked"

# Placeholder path to a local checkpoint of the distilled reasoning model.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
print(f"正在加载模型: {model_path} ...")  # "Loading model: ..."

# Left padding so that prompts in a batch end at the same column, directly
# before the generated tokens.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    local_files_only=True,  # load only from the local path, no download
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

process_jsonl(input_directory, output_dir, tokenizer, model)

print(f"处理完成！所有结果已保存到目录: {output_dir}")  # "Done! All results saved to: ..."