import json
import os
import re
from transformers import AutoTokenizer
import statistics
import numpy as np
from collections import defaultdict


def extract_last_number(s):
    """Return the last number that appears in ``s``.

    A comma sitting between two digits (as in ``1,234``) is removed before
    searching. Only unsigned numbers are matched, so a leading minus sign is
    not part of the result.

    Args:
        s: Text to search.

    Returns:
        A ``float`` when the matched number contains a decimal point, an
        ``int`` otherwise, and ``0.0`` when ``s`` contains no digits.
    """
    cleaned = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", s)
    matches = re.findall(r"(\d+(\.\d+)?)", cleaned)
    if not matches:
        return 0.0

    # The pattern has two groups, so every match is a
    # (whole_number, fractional_part) tuple; only the whole number is needed.
    last_number, _fraction = matches[-1]
    return float(last_number) if "." in last_number else int(last_number)

def extract_first_number(s):
    """Return the first number that appears in ``s``.

    A comma sitting between two digits (as in ``1,234``) is removed before
    searching. Only unsigned numbers are matched.

    Args:
        s: Text to search.

    Returns:
        A ``float`` when the matched number contains a decimal point, an
        ``int`` otherwise, and ``0.0`` when ``s`` contains no digits.
    """
    text = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", s)
    first_match = re.search(r"(\d+(?:\.\d+)?)", text)
    if first_match is None:
        return 0.0

    num_str = first_match.group(1)
    if '.' in num_str:
        return float(num_str)
    return int(num_str)

# Tokenizer of the evaluated model; it is used further down to measure the
# token length of a reasoning prefix.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

# Input directories. Both are read with the same file names: one JSON file per
# problem in each directory.
year = 2025
answer_dir = f"path/to/LongCoT/AIME/aime_{year}_answers_masked"
direct_answer_dir = f"path/to/LongCoT/AIME/aime_{year}_direct_answers"

# Accumulators filled by the per-file loop below. Every list receives one
# entry per file, except avg_masked_position, which only receives an entry
# for files where a position is found.
origin_lengths, masked_lengths = [], []
origin_answers, masked_answers = [], []
correct_answers = []
avg_masked_position = []
direct_answers = []

# A recorded length equal to this value is treated as "no usable answer"
# (presumably the generation was cut off at the token limit -- TODO confirm).
TRUNCATED_LENGTH = 20000

# Iterate in sorted order: os.listdir returns entries in arbitrary order, and a
# fixed order keeps the per-problem lists (and any later tie-breaking that
# depends on their order) reproducible between runs.
for file in sorted(os.listdir(answer_dir)):
    file_path = os.path.join(answer_dir, file)
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    model_answer = data['ModelAnswer']
    # Extracted once and reused for the position search and the accuracy list.
    model_short_answer = extract_last_number(model_answer)
    short_answer_text = str(model_short_answer)

    # avg masked position
    # Split the response at every "Wait," and find the first segment (the
    # final segment is never considered) whose last 20 characters already
    # contain the final answer. The token length of the text up to and
    # including that segment is recorded.
    splitted_answer = model_answer.split("Wait,")
    # The final segment is excluded by position; comparing segment text
    # instead would stop early whenever an earlier segment had the same text.
    for idx, segment in enumerate(splitted_answer[:-1]):
        if short_answer_text not in segment[-20:]:
            continue
        # NOTE(review): a hit in the very first segment is not recorded;
        # confirm whether that case should count as a masked position.
        if idx > 0:
            thinking = "Wait,".join(splitted_answer[:idx + 1])
            pos = len(tokenizer.encode(thinking, add_special_tokens=False))
            avg_masked_position.append(pos)
        break

    # other calculations
    origin_length = data['AnswerLength']
    length = data["MaskedLength"]
    answer = data['MaskedAnswer']
    origin_lengths.append(origin_length)
    masked_lengths.append(length)

    # Responses whose length equals TRUNCATED_LENGTH get None as their answer.
    origin_answers.append(
        None if origin_length == TRUNCATED_LENGTH else model_short_answer
    )
    masked_answers.append(
        None if length == TRUNCATED_LENGTH else extract_last_number(answer)
    )

    correct_answers.append(extract_last_number(str(data['CorrectAnswer'])))

    # direct answers
    # The file with the same name in direct_answer_dir holds the list of
    # direct answers for this problem.
    with open(os.path.join(direct_answer_dir, file), 'r', encoding='utf-8') as f:
        direct_answers.append(json.load(f)['direct_answers'])

def _mean_and_std(values):
    """Return ``(mean, sample standard deviation)`` of ``values``.

    ``statistics.stdev`` needs at least two data points, so a single value
    yields a deviation of ``0.0``. An empty list yields ``(nan, nan)`` so that
    "no data" is not mistaken for a real zero.
    """
    if not values:
        return float("nan"), float("nan")
    mean = sum(values) / len(values)
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return mean, std


print("Ori LEN")
origin_mean, origin_std = _mean_and_std(origin_lengths)
print(f"Origin Length: {origin_mean:.2f} ± {origin_std:.2f}")
masked_mean, masked_std = _mean_and_std(masked_lengths)
print(f"Masked Length: {masked_mean:.2f} ± {masked_std:.2f}")

print("MASK POS")
# avg_masked_position only has entries for files where a position was found,
# so it can legitimately be empty or hold a single value.
position_mean, position_std = _mean_and_std(avg_masked_position)
print(f"Avg Masked Position: {position_mean:.2f} ± {position_std:.2f}")

# for ori, mas in zip(origin_lengths, masked_lengths):
#     print(ori, mas)

# Accuracy of the original and the masked responses.
# Answers are compared as strings, so a None answer never matches, and a float
# such as 42.0 does not match the integer 42.
origin_correct_cnt = 0
masked_correct_cnt = 0
for origin, masked, correct in zip(origin_answers, masked_answers, correct_answers):
    correct_text = str(correct)
    if str(origin) == correct_text:
        origin_correct_cnt += 1
    if str(masked) == correct_text:
        masked_correct_cnt += 1

# Divide by the number of problems actually loaded instead of a fixed 30, so
# the accuracy stays meaningful for a different problem count. "or 1" avoids a
# division by zero when nothing was loaded (the printed accuracy is then 0.0).
num_problems = len(correct_answers) or 1
print(origin_correct_cnt / num_problems)
print(masked_correct_cnt / num_problems)


# direct error
# For every problem: parse the direct answers, measure how far they are from
# the reference answer, and check whether the most frequent one is correct.
results = []
direct_correct_cnt = 0
for direct_ans, correct_ans, origin_length, masked_length in zip(
        direct_answers, correct_answers, origin_lengths, masked_lengths):
    target = int(correct_ans)

    # Only responses spanning more than one line are parsed; the answer is
    # taken to be the first number in the response.
    direct_ans = [extract_first_number(ans) for ans in direct_ans if len(ans.split("\n")) > 1]
    direct_ans_int = defaultdict(int)
    for dir_ans in direct_ans:
        direct_ans_int[dir_ans] += 1

    if direct_ans_int:
        # Mean absolute error over the distinct answers: each distinct value
        # counts once, regardless of how often it was given.
        direct_error = sum(abs(dir_ans - target) for dir_ans in direct_ans_int) / len(direct_ans_int)
        # Most frequent answer; ties go to the answer seen first.
        majority_ans = max(direct_ans_int, key=direct_ans_int.get)
    else:
        # No response passed the multi-line filter. Treat the problem as
        # maximally wrong instead of failing on an empty collection.
        direct_error = float("inf")
        majority_ans = None

    results.append({
        'direct_error': direct_error,
        'origin_length': origin_length,
        'correct_answer': correct_ans,
        'masked_length': masked_length,
    })

    if majority_ans is not None and majority_ans == target:
        direct_correct_cnt += 1

# Divide by the number of problems actually loaded instead of a fixed 30;
# "or 1" avoids a division by zero when nothing was loaded.
print(f"DIRECT CORRECT: {direct_correct_cnt / (len(correct_answers) or 1)}")

# Order the problems by ascending direct error and cut the list in the middle.
# With an odd number of problems the extra one lands in the high-error group.
sorted_by_error = list(results)
sorted_by_error.sort(key=lambda entry: entry['direct_error'])

half = len(sorted_by_error) // 2
low_error_group, high_error_group = sorted_by_error[:half], sorted_by_error[half:]
for error_group in (low_error_group, high_error_group):
    print(error_group)

def avg_length(group, length_key):
    """Return the mean of ``item[length_key]`` over ``group``.

    Entries whose value is ``None`` are ignored.

    Args:
        group: Iterable of dict-like items.
        length_key: Key of the value to average.

    Returns:
        The mean of the non-``None`` values, or ``0`` when there are none.
    """
    total = 0
    count = 0
    for item in group:
        value = item[length_key]
        if value is not None:
            total += value
            count += 1
    if count == 0:
        return 0
    return total / count

def std_length(group, length_key):
    """Return the sample standard deviation of ``item[length_key]`` over ``group``.

    Entries whose value is ``None`` are ignored.

    Args:
        group: Iterable of dict-like items.
        length_key: Key of the value to measure.

    Returns:
        ``statistics.stdev`` of the non-``None`` values, or ``0.0`` when fewer
        than two values remain.
    """
    values = []
    for item in group:
        value = item[length_key]
        if value is None:
            continue
        values.append(value)

    if len(values) <= 1:
        return 0.0
    return statistics.stdev(values)

# Compare response lengths between the low- and high-direct-error groups.
print("\n=== 按模型误差分组，比较回答长度 ===")

# Mean and standard deviation of the original response length per group.
low_avg_ori, low_std_ori = (
    avg_length(low_error_group, 'origin_length'),
    std_length(low_error_group, 'origin_length'),
)
high_avg_ori, high_std_ori = (
    avg_length(high_error_group, 'origin_length'),
    std_length(high_error_group, 'origin_length'),
)

# Same statistics for the masked response length.
low_avg_mask, low_std_mask = (
    avg_length(low_error_group, 'masked_length'),
    std_length(low_error_group, 'masked_length'),
)
high_avg_mask, high_std_mask = (
    avg_length(high_error_group, 'masked_length'),
    std_length(high_error_group, 'masked_length'),
)

print(f"低误差组（前 {half} 题）平均 origin_length: {low_avg_ori:.2f} ± {low_std_ori:.2f}")
print(f"高误差组（后 {len(sorted_by_error) - half} 题）平均 origin_length: {high_avg_ori:.2f} ± {high_std_ori:.2f}")

print(f"低误差组平均 masked_length: {low_avg_mask:.2f} ± {low_std_mask:.2f}")
print(f"高误差组平均 masked_length: {high_avg_mask:.2f} ± {high_std_mask:.2f}")