import os
import json
import re

import statistics
import numpy as np

def extract_answer(s):
    """Return the option letters A-D found in *s*, concatenated in order.

    Every other character is discarded; repeated letters are preserved.
    """
    allowed = frozenset("ABCD")
    return "".join(filter(allowed.__contains__, s))

def correctness(text, correct):
    """Return True when the final answer in *text* matches *correct* exactly.

    Only the portion of *text* after the LAST "【" marker is graded: the option
    letters A-D are extracted from it (in order) and compared with the letters
    of *correct* joined into one string. Text without any "【" marker is
    graded as incorrect.
    """
    marker = "【"
    if marker not in text:
        return False
    # rpartition yields the tail after the final marker, same as split()[-1].
    _, _, tail = text.rpartition(marker)
    expected = "".join(correct)
    return extract_answer(tail) == expected

files = [
    "path/to/LongCoT/Knowlogic/mask_results/time-cn.json",
    "path/to/LongCoT/Knowlogic/mask_results/nature-cn.json",
    "path/to/LongCoT/Knowlogic/mask_results/social-cn.json",
    "path/to/LongCoT/Knowlogic/mask_results/space-cn.json",
    "path/to/LongCoT/Knowlogic/mask_results/space+nature-cn.json",
]

# Answers whose recorded length exceeds this value are counted as incorrect,
# regardless of their content.
MAX_ANSWER_LENGTH = 10000


def _mean_std(values):
    """Return ``(mean, sample standard deviation)`` of a non-empty sequence.

    ``statistics.stdev`` needs at least two data points, so a single value is
    reported with a spread of 0.0 instead of raising.
    """
    mean = sum(values) / len(values)
    spread = statistics.stdev(values) if len(values) > 1 else 0.0
    return mean, spread


origin_lengths = []
masked_lengths = []
avg_masked_position = []
origin_correct = masked_correct = 0

# NOTE(review): never updated in this script; kept only because it is printed
# at the end.
long = [0, 0]

for path in files:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for item in data:
        # Items without a masked run cannot be compared; skip them.
        if item.get("masked_length") is None:
            continue

        origin_lengths.append(item['model_answer_length'])
        masked_lengths.append(item['masked_length'])
        avg_masked_position.append(item['mask_position'])

        correct_answer = item['correct_answer']
        if origin_lengths[-1] <= MAX_ANSWER_LENGTH and correctness(item['model_answer'], correct_answer):
            origin_correct += 1
        if masked_lengths[-1] <= MAX_ANSWER_LENGTH and correctness(item['masked_answer'], correct_answer):
            masked_correct += 1

# Every statistic below divides by the item count; fail clearly instead of
# with a ZeroDivisionError / StatisticsError when nothing was scored.
if not origin_lengths:
    raise SystemExit("No items with 'masked_length' found in the input files.")

origin_mean, origin_std = _mean_std(origin_lengths)
masked_mean, masked_std = _mean_std(masked_lengths)
position_mean, position_std = _mean_std(avg_masked_position)

print("Ori LEN")
print(f"Origin Length: {origin_mean:.2f} ± {origin_std:.2f}")
print(f"Masked Length: {masked_mean:.2f} ± {masked_std:.2f}")

print("MASK POS")
print(f"Avg Masked Position: {position_mean:.2f} ± {position_std:.2f}")

# Accuracy of the original and the masked answers over all scored items.
print(origin_correct / len(origin_lengths))
print(masked_correct / len(masked_lengths))
print(long)