import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os
import random


def generate_output(tokenizer, model, input_texts):
    """Generate a continuation for every prompt in ``input_texts``.

    The prompts are tokenized together with padding, moved to the device of
    the model's first parameter, and passed to ``model.generate`` with up to
    4096 new tokens per prompt.

    Args:
        tokenizer: Callable tokenizer that also provides ``decode``,
            ``encode`` and ``eos_token_id``.
        model: Model exposing ``parameters()`` and ``generate()``.
        input_texts: List of prompt strings forming one batch.

    Returns:
        A ``(texts, lengths)`` tuple of two lists aligned with
        ``input_texts``: the decoded continuations (special tokens removed)
        and the token count obtained by re-encoding each decoded
        continuation without special tokens.
    """
    target_device = next(model.parameters()).device
    encoded = tokenizer(input_texts, return_tensors="pt", padding=True).to(target_device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=4096,
            pad_token_id=tokenizer.eos_token_id,
        )

    completions = []
    token_counts = []
    for prompt_ids, sequence in zip(encoded['input_ids'], generated):
        # Drop the leading prompt-sized slice so only the new tokens remain.
        # NOTE(review): assumes each output row starts with the (padded)
        # prompt, as is usual for decoder-only generation -- confirm.
        new_tokens = sequence[len(prompt_ids):]
        completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
        completions.append(completion)
        # Length is measured on the decoded text, not on the raw tensor.
        token_counts.append(len(tokenizer.encode(completion, add_special_tokens=False)))
    return completions, token_counts

def extract_last_digit(text):
    """Return the number that starts latest in ``text``, or ``None``.

    Two kinds of numbers are recognised: runs of decimal digits (converted
    as a whole, so ``"12"`` yields ``12``) and single Chinese numeral
    characters from the table below (each one counts on its own).

    Args:
        text: String to scan.

    Returns:
        The integer value of the right-most number, or ``None`` when the
        text contains no number at all.
    """
    numeral_values = {
        "零": 0, "一": 1, "二": 2, "三": 3, "四": 4,
        "五": 5, "六": 6, "七": 7, "八": 8, "九": 9,
        "两": 2, "仨": 3,
    }

    latest_pos = -1
    latest_value = None

    # Matches arrive left to right, so the final one is the right-most run.
    for digit_run in re.finditer(r'\d+', text):
        latest_pos = digit_run.start()
        latest_value = int(digit_run.group())

    # A Chinese numeral wins only if it sits at or after the best digit run.
    for index, symbol in enumerate(text):
        if symbol in numeral_values and index >= latest_pos:
            latest_pos = index
            latest_value = numeral_values[symbol]

    return latest_value


def extract_first_ans(text):
    """Cut ``text`` at the first "Wait" after its final answer is mentioned.

    The final answer is the last number found by ``extract_last_digit``.
    The text is then rebuilt piece by piece (pieces are separated by the
    literal ``"Wait"``) until the rebuilt prefix contains one of the English
    keywords associated with that answer.

    Args:
        text: Full model response.

    Returns:
        ``None`` when the final answer is not an integer from 1 to 6.
        Otherwise the prefix of ``text`` ending just before the "Wait" that
        follows the first keyword hit. If no keyword ever appears, the
        result is the whole text followed by one extra ``"Wait"``.
    """
    final_answer = extract_last_digit(text)
    keywords_by_answer = {
        1: ["one", 'only', 'first'],
        2: ["two", "second"],
        3: ["three", 'third'],
        4: ["four", 'fourth'],
        5: ["five", 'fifth'],
        6: ["six", 'sixth'],
    }
    keywords = keywords_by_answer.get(final_answer)
    if keywords is None:
        return None

    prefix = ""
    for piece in text.split("Wait"):
        prefix += piece
        # The check runs on the whole rebuilt prefix, separators included.
        if any(keyword in prefix for keyword in keywords):
            return prefix
        prefix += "Wait"

    return prefix


# Checkpoint location; both the tokenizer and the model are read from it.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
print(f"正在加载模型: {model_path} ...")

# Prompts are padded on the left so that, in a batch, every prompt ends
# right where generation starts.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left',
    # Same offline setting as the model load below, so a wrong path fails
    # locally for both instead of the tokenizer trying to reach the Hub.
    local_files_only=True
)
# bfloat16 weights, automatic device placement, local files only.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True
)

# Collect the records of every result file in the directory into one list.
# Each file is expected to hold a JSON array (it is passed to list.extend).
results_dir = "path/to/LongCoT/CharCount/results/qwen_zh_results_with_direct"

all_data = []
for file in os.listdir(results_dir):
    with open(os.path.join(results_dir, file), "r", encoding='utf-8') as f:
        data = json.load(f)
    all_data.extend(data)

# Draw up to 1000 records without replacement. The size is clamped because
# random.sample raises ValueError when asked for more items than exist.
words = random.sample(all_data, min(1000, len(all_data)))


batch_size = 4
# Fixed instruction; the truncated reasoning prefix is appended after <think>.
input_format = "<｜User｜>Answer directly with an Arabic number.<｜Assistant｜><think>"

num_batches = len(words) // batch_size + (1 if len(words) % batch_size != 0 else 0)

output_path = "path/to/LongCoT/CharCount/test/en_masked.json"
# Make sure the checkpoint file can be written on the very first batch.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
    start_idx = batch_idx * batch_size
    end_idx = min(start_idx + batch_size, len(words))
    batch_words = words[start_idx:end_idx]

    # (record, reasoning prefix) pairs that can actually be sent to the model.
    pending = []
    for word in batch_words:
        # The direct answers are not part of the output; tolerate records
        # that never had the key.
        word.pop('direct_answers', None)
        ans = extract_first_ans(word['model_answer'])
        if ans is None:
            # No usable final answer (not 1-6): there is no prefix to
            # continue from, so the record is left without masked_* fields.
            continue
        pending.append((word, ans))

    # An all-skipped batch must not reach the tokenizer with an empty list.
    if pending:
        input_texts = [input_format + ans for _, ans in pending]
        output_strs, output_length = generate_output(tokenizer, model, input_texts)

        for (word, ans), output_str, new_length in zip(pending, output_strs, output_length):
            word['masked_answer'] = output_str
            # Total length = generated tokens + tokens of the supplied prefix.
            word['masked_length'] = new_length + len(tokenizer.encode(ans, add_special_tokens=False))

    # Rewrite the full file after every batch so an interrupted run keeps
    # everything processed so far.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(words, f, ensure_ascii=False, indent=4)
