import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os


def generate_output(tokenizer, model, input_texts, max_new_tokens=10000):
    """Greedily continue a batch of prompts and return the decoded continuations.

    Args:
        tokenizer: Hugging Face-style tokenizer. It encodes the prompts as one
            padded batch (``padding=True``) and decodes the generated ids.
            Batched decoder-only generation expects left padding; the loader
            in this script sets ``padding_side='left'``.
        model: causal LM exposing ``parameters()`` and ``generate()``. The
            encoded batch is moved to the device of its first parameter.
        input_texts: list of prompt strings.
        max_new_tokens: upper bound on newly generated tokens per prompt.
            The default of 10000 is the value that used to be hard-coded.

    Returns:
        ``(outputs, lengths)``. ``outputs`` holds the decoded continuations
        with special tokens stripped. ``lengths`` holds, for each
        continuation, the number of tokens in that decoded text re-encoded
        without special tokens.
    """
    device = next(model.parameters()).device
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,  # greedy decoding -> deterministic continuations
        )

    # All prompt rows share the same padded length, and generate() returns
    # prompt + continuation. Slicing at that length keeps only the new ids.
    prompt_len = inputs['input_ids'].shape[-1]

    all_outputs = []
    all_lengths = []
    for out in outputs:
        text = tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
        all_outputs.append(text)
        all_lengths.append(len(tokenizer.encode(text, add_special_tokens=False)))
    return all_outputs, all_lengths


# Load the tokenizer and model once, at module level; the processing loop
# below uses both.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
print(f"正在加载模型: {model_path} ...")

# Left padding keeps every prompt right-aligned, which batched decoder-only
# generation (see generate_output) relies on.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left',
    local_files_only=True,  # match the model load below: never contact the Hub
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # let transformers/accelerate place weights on available devices
    low_cpu_mem_usage=True,
    local_files_only=True,
)

# JSON files to process. Each name is joined onto the direct_results input
# directory and the mask_results output directory.
# Other available splits (add them to the list to process them too):
#   "nature-cn.json", "social-cn.json", "space+nature-cn.json", "time-cn.json"
file_paths = ["space-cn.json"]

def split_answer(
    text,
    *,
    conclusion_markers=("所以", "因此"),
    answer_markers=("答案", "选项"),
    hedge_markers=("或者", "不过", "但是", "可是"),
    sep="\n\n",
):
    """Truncate a reasoning trace just after its first "premature" conclusion.

    ``text`` is split into paragraphs on ``sep``. The cut point is the first
    paragraph that meets both of these conditions:

    * it contains a conclusion marker (default: 所以 "so", 因此 "therefore")
      together with an answer marker (default: 答案 "answer", 选项 "option");
    * the next paragraph contains a hedge marker (default: 或者 "or",
      不过 / 但是 / 可是 "but / however").

    The last paragraph can never be the cut point, because it has no successor.

    Args:
        text: the full model answer.
        conclusion_markers, answer_markers, hedge_markers: substrings searched
            with ``in``. The defaults reproduce the original behaviour.
        sep: paragraph separator.

    Returns:
        The paragraphs up to and including the cut point, joined with ``sep``.
        Returns ``""`` when no paragraph qualifies.
    """
    paragraphs = text.split(sep)
    # Only paragraphs that have a successor can qualify.
    for i, para in enumerate(paragraphs[:-1]):
        following = paragraphs[i + 1]
        if (
            any(m in para for m in conclusion_markers)
            and any(m in para for m in answer_markers)
            and any(m in following for m in hedge_markers)
        ):
            return sep.join(paragraphs[: i + 1])
    return ""

def _dump_json_atomic(path, obj):
    """Write ``obj`` to ``path`` as indented UTF-8 JSON, replacing it atomically.

    The data is written to ``path + '.tmp'`` first and then moved into place
    with ``os.replace``. An interruption mid-write therefore never leaves a
    truncated checkpoint behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, path)


batch_size = 4
input_dir = "path/to/LongCoT/Knowlogic/direct_results"
output_dir = "path/to/LongCoT/Knowlogic/mask_results"

for file in file_paths:
    print(file)
    with open(os.path.join(input_dir, file), 'r', encoding='utf-8') as f:
        data = json.load(f)

    out_path = os.path.join(output_dir, file)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    for start_idx in tqdm(range(0, len(data), batch_size), desc="Processing batches"):
        batch_data = data[start_idx:start_idx + batch_size]

        input_texts = []
        items = []  # records that need a regenerated continuation
        anses = []  # truncated answer prefix for each record in `items`
        for item in batch_data:
            # Drop the direct answers from the saved output. pop() also copes
            # with records that no longer carry the key.
            item.pop('direct_answers', None)
            if item.get("masked_length") is not None:
                continue  # record already has a masked result (e.g. from a previous pass)

            ans = split_answer(item['model_answer'])
            if ans:
                # Hoisted out of the f-string: a backslash inside an f-string
                # replacement field is a SyntaxError before Python 3.12.
                options = item['question'].split('\n\n选项：')[-1]
                input_texts.append(f"<｜User｜>选项：{options}<｜Assistant｜><think>" + ans)
                items.append(item)
                anses.append(ans)
            else:
                # No truncation point: the full answer is kept as the masked answer.
                item['masked_answer'] = item['model_answer']
                item['masked_length'] = item['model_answer_length']
                item['mask_position'] = item['model_answer_length']

        if not input_texts:
            continue

        output_strs, output_length = generate_output(tokenizer, model, input_texts)

        for item, ans, out_str, out_len in zip(items, anses, output_strs, output_length):
            prefix_len = len(tokenizer.encode(ans, add_special_tokens=False))
            item['masked_answer'] = out_str
            # Total length = kept prefix + regenerated continuation, capped at
            # 10000 (the same value generate_output uses for max_new_tokens).
            item['masked_length'] = min(10000, out_len + prefix_len)
            item['mask_position'] = prefix_len

        # Checkpoint after each generation batch so an interrupted run leaves
        # partial results on disk.
        _dump_json_atomic(out_path, data)

    # Final save. Batches that needed no generation `continue` before the
    # checkpoint above, so without this their updates (and, if no batch
    # generated anything, the whole output file) would be lost.
    _dump_json_atomic(out_path, data)