import json
import os
import re
from transformers import AutoTokenizer
import statistics
import numpy as np
from collections import defaultdict, Counter
import math

def extract_last_number(s):
    """Return the last number that appears in *s*, or None if there is none.

    A comma sitting between two digits is removed first. Matching is
    non-overlapping, left to right, so ``1,234,567`` becomes ``1234567`` but
    ``1,2,3`` becomes ``12,3``. The final match is returned as a ``float``
    when it contains a decimal point and as an ``int`` otherwise. Signs are
    not recognised: ``-5`` yields ``5``.
    """
    cleaned = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", s)

    last_token = None
    for found in re.finditer(r"(\d+(\.\d+)?)", cleaned):
        last_token = found.group(1)

    if last_token is None:
        return None
    return float(last_token) if "." in last_token else int(last_token)

import re

def remove_outliers_by_mode(data, max_gap=10):

    if not data:
        return []

    counter = Counter(data)
    mode_value, mode_cnt = counter.most_common(1)[0] 
    if mode_cnt > len(data) / 2:
        return [mode_value]
    min_acceptable_order = mode_value / max_gap
    max_acceptable_order = mode_value * max_gap

    cleaned_data = []
    for x in data:
        if min_acceptable_order <= x <= max_acceptable_order:
            cleaned_data.append(x)

    return cleaned_data


def extract_first_number(s):
    """Parse a numeric answer out of the string *s*.

    A comma between two digits is removed first (e.g. ``1,000`` -> ``1000``).
    Then:

    * if a fraction ``a/b`` occurs anywhere in the text, the first such
      fraction is evaluated and returned as a ``float`` (``inf`` when ``b``
      is 0). This takes priority over a plain number that appears earlier in
      the text;
    * otherwise the first number is returned, as a ``float`` if it contains
      a decimal point and as an ``int`` otherwise;
    * ``None`` is returned when no digits are present.
    """
    text = re.sub(r"(\d),(\d)", r"\1\2", s)

    fraction = re.search(r"(\d+(?:\.\d+)?)\s*/\s*(\d+(?:\.\d+)?)", text)
    if fraction is not None:
        numerator, denominator = map(float, fraction.groups())
        return numerator / denominator if denominator != 0 else float('inf')

    first = re.search(r"\d+(?:\.\d+)?", text)
    if first is None:
        return None
    token = first.group()
    return float(token) if '.' in token else int(token)

from scipy.stats import gaussian_kde
import numpy as np


# --- Configuration ----------------------------------------------------------
year = 2024
# Directory prefix of the evaluated model's outputs. The loader below
# special-cases "" and "dpsk_", as well as any prefix containing "new".
model = ""
answer_dir = f"/path/to/LongCoT/AIME/{model}aime_{year}_answers"

# Key under which the reference answer is stored in each JSON file.
answer_key = "answer" if year == 2025 else "Answer"
if model == "" or "new" in model:
    answer_key = "CorrectAnswer"

# Error value assigned when no direct error can be computed (no CoT answer
# parsed, or no usable direct answers). It is large, so such problems sort
# into the high-error group.
NO_ERROR_SENTINEL = 10000

# --- Load per-problem records -----------------------------------------------
origin_lengths = []   # data['AnswerLength'] of each CoT response
origin_answers = []   # last number found in the CoT response
correct_answers = []  # reference answer parsed with extract_first_number
direct_answers = []   # raw list of direct-answer strings per problem
ids = []
# os.listdir order is arbitrary. Sorting makes runs reproducible, which
# matters because ties in direct_error are split by processing order below.
for file in sorted(os.listdir(answer_dir)):
    with open(os.path.join(answer_dir, file), 'r', encoding='utf-8') as f:
        data = json.load(f)
    ids.append(data.get("id") or data.get("ID"))
    model_answer_field = "model_answer" if model == "dpsk_" else "ModelAnswer"
    origin_answers.append(extract_last_number(data[model_answer_field]))
    origin_lengths.append(data['AnswerLength'])
    correct_answers.append(extract_first_number(str(data[answer_key])))

    # Direct answers live in a sibling directory for model == "", and inside
    # the answer file otherwise ([""] when absent).
    if model == "":
        direct_path = f"/path/to/LongCoT/AIME/{model}aime_{year}_direct_answers/{file}"
        with open(direct_path, 'r', encoding='utf-8') as f:
            da = json.load(f)
        direct_answers.append(da['direct_answers'])
    else:
        direct_answers.append(data.get('direct_answers', [""]))

num_problems = len(correct_answers)
if num_problems == 0:
    raise SystemExit(f"No answer files found in {answer_dir}")

print("Ori LEN")
# statistics.stdev needs at least two data points.
origin_std = statistics.stdev(origin_lengths) if num_problems > 1 else 0.0
print(f"Origin Length: {sum(origin_lengths) / num_problems:.2f} ± {origin_std:.2f}")

# CoT accuracy: exact string match between parsed CoT and reference answers,
# over the problems actually loaded (not a hard-coded 30).
origin_correct_cnt = sum(
    str(origin) == str(correct)
    for origin, correct in zip(origin_answers, correct_answers)
)
print(origin_correct_cnt / num_problems)

# --- Direct-answer error per problem ----------------------------------------
results = []
direct_correct_cnt = 0
for model_ans, direct_answ, correct_ans, idx, origin_length in zip(
    origin_answers, direct_answers, correct_answers, ids, origin_lengths
):
    direct_ans = [extract_first_number(ans) for ans in direct_answ]
    direct_ans = [item for item in direct_ans if item is not None]
    direct_ans = remove_outliers_by_mode(direct_ans)

    # Heuristic: a direct answer that is a multiple or a divisor of the CoT
    # answer is counted as the CoT answer itself. The truthiness guard also
    # avoids a modulo by zero when model_ans is 0.
    if model_ans:
        direct_ans = [
            model_ans if (d % model_ans == 0 or model_ans % d == 0) else d
            for d in direct_ans
        ]

    # Majority vote over direct answers (ties -> first-seen value). Skipped
    # when nothing parseable remains, which previously crashed max().
    if direct_ans:
        votes = Counter(direct_ans)
        if max(votes, key=votes.get) == int(correct_ans):
            direct_correct_cnt += 1

    if not model_ans or not direct_ans:
        direct_error = NO_ERROR_SENTINEL
    else:
        # Mean absolute distance of the direct answers from the reference.
        direct_error = sum(abs(d - int(correct_ans)) for d in direct_ans) / len(direct_ans)

    results.append({
        'direct_error': direct_error,
        'origin_length': origin_length,
        'correct_answer': correct_ans,
        'id': idx,
        'model_ans': model_ans
    })

print(f"DIRECT CORRECT: {direct_correct_cnt / num_problems}")

# --- Split into low / high direct-error halves -------------------------------
# sorted() is stable, so ties keep the (sorted-filename) processing order.
sorted_by_error = sorted(results, key=lambda x: x['direct_error'])
half = len(sorted_by_error) // 2
low_error_group = sorted_by_error[:half]
high_error_group = sorted_by_error[half:]  # takes the extra item when odd
print(low_error_group)
print(high_error_group)

def avg_length(group, length_key):
    """Return the mean of ``item[length_key]`` over *group*, ignoring None.

    Returns 0 when *group* is empty or every value is None.
    """
    values = [entry[length_key] for entry in group]
    present = [v for v in values if v is not None]
    if not present:
        return 0
    return sum(present) / len(present)

def std_length(group, length_key):
    """Return the sample standard deviation of ``item[length_key]`` over *group*.

    None values are ignored. Returns 0.0 when fewer than two values remain,
    since ``statistics.stdev`` needs at least two data points.
    """
    present = [entry[length_key] for entry in group]
    present = [v for v in present if v is not None]
    if len(present) < 2:
        return 0.0
    return statistics.stdev(present)

# Compare mean CoT response length between the low- and high-error groups.
print("\n=== 按模型误差分组，比较回答长度 ===")
low_avg_ori = avg_length(low_error_group, 'origin_length')
low_std_ori = std_length(low_error_group, 'origin_length')
high_avg_ori = avg_length(high_error_group, 'origin_length')
high_std_ori = std_length(high_error_group, 'origin_length')

# Low-error group mean origin_length ± std.
print(f"低误差组平均 origin_length: {low_avg_ori:.2f} ± {low_std_ori:.2f}")
# High-error group mean origin_length ± std.
print(f"高误差组平均 origin_length: {high_avg_ori:.2f} ± {high_std_ori:.2f}")
# Relative length increase of the high-error group over the low-error group.
# avg_length returns 0 for an empty group (e.g. with a single problem the
# low-error half is empty), so report NaN instead of dividing by zero.
r_delta = (high_avg_ori - low_avg_ori) / low_avg_ori if low_avg_ori else float('nan')
print(f"R_delta: {r_delta}")