import json
import os
import re
import statistics
import math
from typing import Dict, List, Tuple

_CATEGORY_RE = re.compile(r"(.+?)_\d+$")   # captures everything before the trailing _<number>

def compute_accuracy(filepath):
    """Return the fraction of samples whose model answer matches the reference.

    The file may hold either a JSON object mapping sample ids to sample dicts,
    or a JSON array of sample dicts.  Each sample's ``"model_answer"`` and
    ``"correct_answer"`` fields are compared case-insensitively.

    A sample whose ``model_answer`` or ``correct_answer`` is missing or
    ``null`` is counted as incorrect (it still counts toward the total).
    Non-string answers (e.g. numbers) are compared via ``str()``.

    Args:
        filepath: Path to the results JSON file.

    Returns:
        Accuracy as a float in ``[0.0, 1.0]``; ``0.0`` when there are no samples.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept both {"id": sample, ...} and [sample, ...] layouts.
    samples = list(data.values()) if isinstance(data, dict) else list(data)
    if not samples:
        return 0.0

    correct_count = 0
    for sample in samples:
        model_answer = sample.get("model_answer")
        correct_answer = sample.get("correct_answer")
        # A missing answer can never match; counting it as wrong avoids the
        # AttributeError that None.lower() would raise.
        if model_answer is None or correct_answer is None:
            continue
        if str(model_answer).lower() == str(correct_answer).lower():
            correct_count += 1

    return correct_count / len(samples)

def summarize(scores: List[float]) -> Tuple[float, float]:
    """Return ``(mean, standard error of the mean)`` for *scores*.

    The standard error is the sample standard deviation divided by
    ``sqrt(len(scores))``.  With a single score the spread is undefined, so
    the standard error is reported as ``0.0`` instead of raising.

    Also prints the number of scores being summarized.

    Raises:
        statistics.StatisticsError: if *scores* is empty.
    """
    n = len(scores)
    print(f"Number of scores: {n}")
    avg_acc = statistics.mean(scores)
    if n < 2:
        # statistics.stdev requires at least two data points.
        return avg_acc, 0.0
    std_error = statistics.stdev(scores) / math.sqrt(n)
    return avg_acc, std_error

def main() -> None:
    """Aggregate per-category accuracy for each model folder and print it.

    Each model folder is scanned for run directories named
    ``<category>_<number>``; each run's ``results.json`` is scored with
    ``compute_accuracy``.  Runs of the same category are pooled and reported
    as mean ± standard error (as returned by ``summarize``).
    """
    folderpaths = [
        "subgoals_full_experiments/gemini_2.0",
        "subgoals_full_experiments/gemini_2.5",
        "subgoals_full_experiments/o4_mini",
    ]

    results: Dict[str, Dict[str, Tuple[float, float]]] = {}

    for model_folder in folderpaths:
        category_scores: Dict[str, List[float]] = {}

        # Sorted so category order (and summarize's printed output) is
        # deterministic; os.listdir order is arbitrary.
        for sub in sorted(os.listdir(model_folder)):
            sub_path = os.path.join(model_folder, sub)
            if not os.path.isdir(sub_path):
                continue

            m = _CATEGORY_RE.match(sub)
            if not m:
                continue  # skip anything that doesn't follow the *_<number> pattern

            results_path = os.path.join(sub_path, "results.json")
            if not os.path.isfile(results_path):
                # An incomplete run has no results file; report and skip it
                # rather than aborting the whole aggregation.
                print(f"Warning: skipping {sub_path} (no results.json)")
                continue

            category = m.group(1)
            acc = compute_accuracy(results_path)
            category_scores.setdefault(category, []).append(acc)

        # compute mean / standard error per category
        results[model_folder] = {
            cat: summarize(scores) for cat, scores in category_scores.items()
        }

    # ─────── print nicely ───────
    for model_folder, cat_dict in results.items():
        print(f"\nResults for {os.path.basename(model_folder)}:")
        for cat, (mean_acc, sem_acc) in sorted(cat_dict.items()):
            print(f"  {cat:<30} {mean_acc*100:.2f}% ± {sem_acc*100:.2f}%")

if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())