import json
import os
import re
import statistics
import numpy as np
from tqdm import tqdm
from collections import defaultdict


def extract_abcd_options(s):
    """Keep only the option letters ``A``-``D`` from *s*, preserving their order.

    Every other element of *s* (lowercase letters, whitespace, punctuation,
    other capitals, ...) is dropped, and the survivors are concatenated.
    """
    allowed = frozenset("ABCD")
    return "".join(filter(allowed.__contains__, s))

# Result directories to analyse; each is scanned for files ending in "json".
# Swap in one of the commented-out paths to analyse a different model's output.
answer_dirs = [
    "path/to/LongCoT/Knowlogic/qwen_results",
    # "path/to/LongCoT/Knowlogic/qwq_results",
    # "path/to/LongCoT/Knowlogic/dpsk_results",
]

# Per-record data, kept as parallel lists: index i in every list refers to the
# same question record.
origin_lengths = []   # data['model_answer_length'] of each record
origin_answers = []   # A-D letters after the last "</think>", or "" if no "</think>"
correct_answers = []  # gold option letters joined into a single string
direct_answers = []   # raw data['direct_answers'] of each record

for answer_dir in answer_dirs:
    # Sort the listing: os.listdir returns entries in arbitrary order, and the
    # stable sort used later to split results into low/high-error halves breaks
    # ties by record order, so an unsorted listing makes the groups (and the
    # reported averages) depend on the filesystem.
    for file in tqdm(sorted(os.listdir(answer_dir))):
        if not file.endswith("json"):
            continue
        print(file)
        file_path = os.path.join(answer_dir, file)
        with open(file_path, 'r', encoding='utf-8') as f:
            datas = json.load(f)
        for data in datas:
            origin_lengths.append(data['model_answer_length'])

            # rpartition yields the text after the LAST "</think>" (same as
            # split(...)[-1]); an empty separator means the tag never appears,
            # in which case no final answer is recorded.
            _, think_end, final_part = data['model_answer'].rpartition("</think>")
            origin_answers.append(extract_abcd_options(final_part) if think_end else "")

            # Records store the gold label under "answer" or, when that is
            # missing or empty, under "correct_answer".
            gold = data.get("answer") or data["correct_answer"]
            correct_answers.append("".join(gold))
            direct_answers.append(data['direct_answers'])

# Overall mean ± sample standard deviation of the model's answer length.
print("Ori LEN")
print(
    "Origin Length: "
    f"{sum(origin_lengths) / len(origin_lengths):.2f} ± {statistics.stdev(origin_lengths):.2f}"
)

# Accuracy of the long-CoT final answer: exact string match with the gold letters.
origin_correct_cnt = sum(
    1 for predicted, gold in zip(origin_answers, correct_answers) if predicted == gold
)

print("Correct")
print(origin_correct_cnt / len(correct_answers))


# Direct-answer analysis: majority-vote each record's direct samples, and count
# how many of those samples disagree with the long-CoT final answer.
results = []
direct_correct_cnt = 0
for model_ans, raw_direct, correct_ans, origin_length in zip(
    origin_answers, direct_answers, correct_answers, origin_lengths
):
    # Each sample is cut at its first "】" and reduced to its A-D letters.
    direct_ans = [extract_abcd_options(ans.split("】")[0]) for ans in raw_direct]
    direct_ans_int = defaultdict(int)
    for dir_ans in direct_ans:
        direct_ans_int[dir_ans] += 1

    # Majority vote; ties go to the answer counted first. default=None avoids a
    # ValueError for a record with no direct samples (it then counts as wrong).
    majority = max(direct_ans_int, key=direct_ans_int.get, default=None)
    if majority == correct_ans:
        direct_correct_cnt += 1

    # Without a long-CoT final answer, or without any direct samples, there is
    # nothing to compare against, so the record gets no direct error.
    if model_ans == "" or not direct_ans:
        continue
    # Number of direct samples disagreeing with the long-CoT answer. Uses the
    # actual sample count rather than a hard-coded 16, so records with a
    # different number of samples are scored correctly. .get() avoids inserting
    # a key into the defaultdict just by reading it.
    direct_error = len(direct_ans) - direct_ans_int.get(model_ans, 0)
    results.append({
        'direct_error': direct_error,
        'origin_length': origin_length,
        'correct_answer': correct_ans
    })

print("Direct Correct")
print(direct_correct_cnt / len(correct_answers))


# Sort records by direct error (stable: ties keep their load order) and split
# them into a low-error and a high-error half of equal size. With an odd count,
# the middle record belongs to neither group.
sorted_by_error = sorted(results, key=lambda x: x['direct_error'])

half = len(sorted_by_error) // 2
low_error_group = sorted_by_error[:half]
# Slice from an explicit start index: `[-half:]` returns the WHOLE list when
# half == 0 (fewer than two results), because -0 == 0.
high_error_group = sorted_by_error[len(sorted_by_error) - half:]

# Minimum, split-point and maximum direct error, as a quick sanity check.
if sorted_by_error:
    print(sorted_by_error[0]['direct_error'])
    print(sorted_by_error[half]['direct_error'])
    print(sorted_by_error[-1]['direct_error'])
else:
    print("No results with a final model answer; nothing to group.")

def avg_length(group, length_key):
    """Return the arithmetic mean of ``item[length_key]`` over *group*.

    Items whose value is ``None`` are ignored. When no non-``None`` values
    remain, the int ``0`` is returned instead of raising.
    """
    values = [row[length_key] for row in group]
    present = [value for value in values if value is not None]
    if not present:
        return 0
    return sum(present) / len(present)

def std_length(group, length_key):
    """Return the sample standard deviation of ``item[length_key]`` over *group*.

    Items whose value is ``None`` are ignored. Returns ``0.0`` when fewer than
    two values remain, since ``statistics.stdev`` needs at least two points.
    """
    values = [row[length_key] for row in group]
    present = [value for value in values if value is not None]
    if len(present) < 2:
        return 0.0
    return statistics.stdev(present)

# Report: compare answer length (mean ± sample std) between the low-error and
# high-error groups.
print("\n=== 按模型误差分组，比较回答长度 ===")
low_avg_ori, low_std_ori = (
    avg_length(low_error_group, 'origin_length'),
    std_length(low_error_group, 'origin_length'),
)
high_avg_ori, high_std_ori = (
    avg_length(high_error_group, 'origin_length'),
    std_length(high_error_group, 'origin_length'),
)

print(f"低误差组（前 {len(low_error_group)} 题）平均 origin_length: {low_avg_ori:.2f} ± {low_std_ori:.2f}")
print(f"高误差组（后 {len(high_error_group)} 题）平均 origin_length: {high_avg_ori:.2f} ± {high_std_ori:.2f}")
