import os
import json
import re
from transformers import AutoTokenizer
import statistics
import numpy as np

# Load the tokenizer from a local model directory; it is configured to pad
# on the left.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

def extract_last_digit(text):
    """Return the right-most number mentioned in *text*, or None if there is none.

    Two kinds of candidates are recognised:

    * runs of decimal digits (``\\d+``), converted as a whole with ``int`` —
      so "12" counts as the single number 12;
    * single Chinese numeral characters from the table below, each mapped to
      a value 0-9 on its own. Characters such as "十" or "百" are not in the
      table, so "十二" yields 2.

    Candidates are ranked by the index at which they start in *text*; the
    value of the one starting furthest to the right is returned.
    """
    numeral_values = {
        "零": 0, "一": 1, "二": 2, "三": 3, "四": 4,
        "五": 5, "六": 6, "七": 7, "八": 8, "九": 9,
        "两": 2, "仨": 3,
    }

    # Each candidate is stored as (start index, numeric value).
    candidates = []
    for found in re.finditer(r'\d+', text):
        candidates.append((found.start(), int(found.group())))
    for index, symbol in enumerate(text):
        if symbol in numeral_values:
            candidates.append((index, numeral_values[symbol]))

    if not candidates:
        return None

    # Start indices are unique (a character is either part of a digit run or
    # a Chinese numeral, never both), so the maximum is unambiguous.
    _, value = max(candidates, key=lambda entry: entry[0])
    return value

def longest_common_prefix(str1, str2):
    """Return the longest leading slice of *str1* that *str2* also starts with.

    The comparison is element by element and stops at the first mismatch or
    when the shorter input is exhausted. The result is empty when the inputs
    differ at index 0 or either one is empty.
    """
    shared = 0
    for left, right in zip(str1, str2):
        if left != right:
            break
        shared += 1
    return str1[:shared]

# Accumulators filled while iterating over the evaluation records.
origin_length = []        # item['model_answer_length'] for every record
origin_length_ = []       # same values, restricted to records with length <= 4000
masked_length = []        # item['masked_length'] for every record
origin_correct = 0        # records whose original answer matches correct_answer
masked_correct = 0        # records whose masked answer matches correct_answer
avg_masked_position = []  # token count of the prefix shared by both answers


def _format_mean_std(values):
    """Format *values* as 'mean ± sample-stdev' with two decimals.

    Returns "n/a" for an empty list. With a single value the standard
    deviation is undefined (statistics.stdev would raise), so it is
    reported as nan instead of aborting the script.
    """
    if not values:
        return "n/a"
    mean = sum(values) / len(values)
    spread = statistics.stdev(values) if len(values) > 1 else float("nan")
    return f"{mean:.2f} ± {spread:.2f}"


path = "path/to/LongCoT/CharCount/test/zh_masked.json"
with open(path, 'r', encoding='utf-8') as f:
    data = json.load(f)

for item in data:
    model_answer = extract_last_digit(item['model_answer'])
    masked_answer = extract_last_digit(item['masked_answer'])
    # NOTE(review): extract_last_digit returns an int (or None); the equality
    # checks below only succeed if 'correct_answer' is stored as an int too —
    # confirm against the JSON schema.
    correct_answer = item['correct_answer']

    answer_length = item['model_answer_length']
    if answer_length <= 4000:
        origin_length_.append(answer_length)
    origin_length.append(answer_length)

    # The two answers are identical up to the point where they diverge; the
    # token length of that shared prefix is recorded as the masked position.
    prefix = longest_common_prefix(item['model_answer'], item['masked_answer'])
    avg_masked_position.append(
        len(tokenizer.encode(prefix, add_special_tokens=False))
    )

    masked_length.append(item['masked_length'])

    if model_answer == correct_answer:
        origin_correct += 1
    if masked_answer == correct_answer:
        masked_correct += 1


print("Ori LEN")
print(f"Origin Length: {_format_mean_std(origin_length)}")
print(f"Masked Length: {_format_mean_std(masked_length)}")

print("MASK POS")
print(f"Avg Masked Position: {_format_mean_std(avg_masked_position)}")

print(origin_correct)
print(masked_correct)
