import json
import os
import re
import statistics
import numpy as np
from tqdm import tqdm
from collections import defaultdict

def extract_last_number(s):
    """Return the last number that appears in *s*.

    A comma sitting between two digits is stripped first, so a
    thousands-separated value such as ``"1,234"`` is read as ``1234``.
    A matched token containing a decimal point is returned as ``float``;
    any other token is returned as ``int``. If *s* contains no digits at
    all, ``0.0`` is returned.
    """
    cleaned = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", s)  # drop digit-group commas
    last_token = None
    # Walk every match and remember only the final one.
    for match in re.finditer(r"(\d+(\.\d+)?)", cleaned):
        last_token = match.group(1)
    if last_token is None:
        return 0.0
    if "." in last_token:
        return float(last_token)
    return int(last_token)

def extract_first_number(s):
    """Return the first number that appears in *s*.

    Commas between two digits are removed before matching, so ``"1,000"``
    is read as ``1000``. A token containing a decimal point is returned as
    ``float``, otherwise as ``int``. When *s* contains no digits the
    function falls back to ``0.0`` (it never returns ``None``).
    """
    cleaned = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", s)
    # The leftmost match is the same one findall() would list first.
    match = re.search(r"(\d+(?:\.\d+)?)", cleaned)
    if match is None:
        return 0.0
    token = match.group(1)
    if '.' in token:
        return float(token)
    return int(token)

# Load every JSON result file in ``answer_dir``, score the long-form answers
# and the sampled "direct" answers against ``correct_answer``, then rank items
# by the mean absolute error of their direct answers.
answer_dir = "path/to/LongCoT/CharCount/results/QwQ_zh_results_with_more_direct"

# Items whose ``model_answer_length`` reaches this value are treated as
# truncated: no answer is extracted and they are scored as incorrect.
MAX_ANSWER_LENGTH = 4000

origin_lengths = []
origin_answers = []
correct_answers = []
direct_answers = []

# Sort the listing so the processing order (and therefore tie-breaking in the
# stable error sort below) is reproducible across filesystems.
for file in tqdm(sorted(os.listdir(answer_dir))):
    file_path = os.path.join(answer_dir, file)
    with open(file_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)
    for data in datas:
        length = data['model_answer_length']
        origin_lengths.append(length)
        if length >= MAX_ANSWER_LENGTH:
            origin_answers.append(None)
        else:
            origin_answers.append(extract_last_number(data['model_answer']))
        correct_answers.append(data["correct_answer"])
        direct_answers.append(data['direct_answers'])

if not correct_answers:
    # Every ratio below divides by the number of records.
    raise SystemExit(f"No records found in {answer_dir}")

# Long-form accuracy. ``correct_answer`` is normalised with int(), exactly as
# in the direct-answer scoring below, so both metrics compare like with like
# (e.g. a string label "5" would otherwise never equal the parsed 5).
origin_correct_cnt = sum(
    1
    for origin, correct in zip(origin_answers, correct_answers)
    if origin is not None and origin == int(correct)
)

# Direct-answer scoring: majority-vote accuracy plus mean absolute error.
results = []
direct_correct_cnt = 0
for origin_length, raw_direct, correct_ans in zip(origin_lengths, direct_answers, correct_answers):
    correct_int = int(correct_ans)
    # extract_first_number falls back to 0.0 and never returns None, so every
    # sample is kept; parse each one once.
    parsed = [extract_first_number(ans) for ans in raw_direct]
    if not parsed:
        # No samples: there is no vote and no mean error. Skip the item instead
        # of crashing in max()/division; it counts as not direct-correct.
        continue

    # Majority vote; on a tie the first value encountered wins.
    votes = defaultdict(int)
    for value in parsed:
        votes[value] += 1
    if max(votes, key=votes.get) == correct_int:
        direct_correct_cnt += 1

    # Mean absolute error; each sample is truncated with int() first.
    direct_error = sum(abs(int(value) - correct_int) for value in parsed) / len(parsed)

    results.append({
        'direct_error': direct_error,
        'origin_length': origin_length,
        'correct_answer': correct_ans,
    })

sorted_by_error = sorted(results, key=lambda x: x['direct_error'])

# NOTE: despite its name, ``half`` is one *third* of the items. The lowest-
# and highest-error thirds are compared, and the middle third is dropped.
half = len(sorted_by_error) // 3
low_error_group = sorted_by_error[:half]
# Use an explicit start index: ``[-half:]`` would yield the WHOLE list when
# half == 0 (fewer than three items).
high_error_group = sorted_by_error[len(sorted_by_error) - half:]

groups = low_error_group + high_error_group
selected_lengths = [x['origin_length'] for x in groups]
print("Ori LEN")
if selected_lengths:
    mean_len = sum(selected_lengths) / len(selected_lengths)
    # stdev() needs at least two data points.
    std_len = statistics.stdev(selected_lengths) if len(selected_lengths) > 1 else 0.0
    print(f"Origin Length: {mean_len:.2f} ± {std_len:.2f}")
else:
    print("Origin Length: n/a (no items in the low/high error groups)")
print("Correct")
print(origin_correct_cnt / len(correct_answers))
print("Direct Correct")
print(direct_correct_cnt / len(correct_answers))


def avg_length(group, length_key):
    """Return the mean of ``item[length_key]`` across *group*.

    Items whose value for *length_key* is ``None`` are ignored. When no
    usable values remain (including an empty *group*), ``0`` is returned.
    """
    present = [entry[length_key] for entry in group]
    present = [value for value in present if value is not None]
    if not present:
        return 0
    return sum(present) / len(present)

def std_length(group, length_key):
    """Return the sample standard deviation of ``item[length_key]`` over *group*.

    ``None`` values are skipped. Fewer than two usable values yield ``0.0``,
    because a sample standard deviation is undefined there.
    """
    present = []
    for entry in group:
        value = entry[length_key]
        if value is not None:
            present.append(value)
    if len(present) < 2:
        return 0.0
    return statistics.stdev(present)

# Compare long-form answer lengths between the low- and high-error groups
# (grouped by direct-answer error).
print("\n=== 按模型误差分组，比较回答长度 ===")
low_avg_ori = avg_length(low_error_group, 'origin_length')
low_std_ori = std_length(low_error_group, 'origin_length')
high_avg_ori = avg_length(high_error_group, 'origin_length')
high_std_ori = std_length(high_error_group, 'origin_length')

# Report each group's actual size. The high-error group is a slice taken from
# the end of ``sorted_by_error``; ``len(sorted_by_error) - half`` overstated
# it by also counting the middle items that belong to neither group.
print(f"低误差组（前 {len(low_error_group)} 题）平均 origin_length: {low_avg_ori:.2f} ± {low_std_ori:.2f}")
print(f"高误差组（后 {len(high_error_group)} 题）平均 origin_length: {high_avg_ori:.2f} ± {high_std_ori:.2f}")
