"""
compute_algorithmic_stats.py
============================
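Summarise *confidence-evaluation* results (method 4) for one model.

For every dataset, read
``<base_dir>/<model>/<dataset>/algorithmic_eval_results_4.json``, write
per-question statistics to ``algorithmic_eval_results_4_stats.log`` in the
same directory, and finally print a model-wide ASCII leaderboard ranked by
overall **success-likelihood percentage** plus per-question rankings.

Only minimal edits from the original binary-evaluation script.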
"""
from __future__ import annotations

import argparse
import ast
import json
from collections import Counter
from pathlib import Path
from typing import List, Dict

# ---------------------------------------------------------------------------
# Constant: full dataset list for Gemini-2.0-Flash
# ---------------------------------------------------------------------------
# One dataset directory name per line; the block is split on whitespace, so
# order is preserved and names must not contain spaces.
DEFAULT_DATASETS_GEMINI: list[str] = """
    clutrr
    clevr
    leaf
    omnimath-1
    bbeh_word_sorting
    bbeh_dyck_languages
    bbeh_object_counting
    bbeh_object_properties
    bbeh_boardgame_qa
    bbeh_boolean_expressions
    bbeh_zebra_puzzles
    bbeh_buggy_tables
    bbeh_spatial_reasoning
    bbeh_multistep_arithmetic
    bbeh_geometric_shapes
    bbeh_shuffled_objects
    bbeh_temporal_sequence
    bbeh_disambiguation_qa
    bbeh_causal_understanding
    bbeh_time_arithmetic
    bbeh_web_of_lies
    bbeh_sarc_triples
    bbeh_nycc
    bbeh_sportqa
    bbeh_linguini
    bbeh_movie_recommendation
    bbeh_hyperbaton
""".split()

# Labels for the ten scores carried by every record: nine reflection items
# followed by the overall success likelihood.  A label's position in this
# list is the position of its score inside a record.
QUESTIONS: list[str] = [
    "Input / output clear?",        # Q1
    "Know correct algorithm?",      # Q2
    "Syntax confidence?",           # Q3
    "Execution confidence?",        # Q4
    "Functional correctness?",      # Q5
    "Overall code confidence?",     # Q6
    "Decomposition ability?",       # Q7
    "Complexity manageable?",       # Q8
    "Mental execution accurate?",   # Q9
    "Overall success likelihood",   # Q10
]
FINAL_IDX: int = len(QUESTIONS) - 1  # index of the overall score (9)
NUM_Q: int = FINAL_IDX + 1           # scores expected per record (10)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _safe_parse_entry(entry):
    """
    Normalise one raw record into a list of ``NUM_Q`` floats.

    ``entry`` may already be a list, or a string holding a Python literal,
    which is decoded with ``ast.literal_eval``.  ``None`` is returned when the
    string cannot be decoded, when the (decoded) value is not a list of
    exactly ``NUM_Q`` items, when an item cannot be converted with
    ``float()``, or when a converted value lies outside [0, 1].
    """
    candidate = entry
    if isinstance(candidate, str):
        try:
            candidate = ast.literal_eval(candidate)
        except Exception:
            # Best effort: any undecodable string counts as an invalid row.
            return None

    if not (isinstance(candidate, list) and len(candidate) == NUM_Q):
        return None

    scores: list[float] = []
    for item in candidate:
        try:
            value = float(item)
        except (TypeError, ValueError):
            return None
        # The chained comparison is False for NaN, so NaN is rejected too.
        if 0.0 <= value <= 1.0:
            scores.append(value)
        else:
            return None
    return scores


def _per_question_means(records: List[List[float]]):
    """Return the arithmetic mean of each of the ``NUM_Q`` score columns.

    ``records`` must be non-empty; an empty list raises ``ZeroDivisionError``.
    """
    row_count = len(records)
    averages = []
    for q in range(NUM_Q):
        column_total = sum(row[q] for row in records)
        averages.append(column_total / row_count)
    return averages


def _compute_stats(records):
    """
    Build one ``(mean, mode, n_high, n_low)`` tuple per question.

    For compatibility with the old log format a pseudo-binary view is still
    produced: a score counts as "high" when it is >= 0.5 and as "low"
    otherwise.  ``mode`` is ``1`` when high scores are the majority, ``0``
    when low scores are, and the string ``"tie"`` when the counts are equal.
    """
    summary = []
    for q in range(NUM_Q):
        scores = [row[q] for row in records]
        average = sum(scores) / len(scores)

        # Threshold at 0.5 to recover the old 0/1 counts.
        high = sum(1 for s in scores if s >= 0.5)
        low = len(scores) - high

        if high == low:
            majority = "tie"
        elif high > low:
            majority = 1
        else:
            majority = 0

        summary.append((average, majority, high, low))
    return summary

# ---------------------------------------------------------------------------
# Dataset processing
# ---------------------------------------------------------------------------

def process_dataset(base: Path, model: str, ds: str):
    """
    Summarise the method-4 confidence results of one dataset.

    Reads ``<base>/<model>/<ds>/algorithmic_eval_results_4.json``, keeps the
    rows accepted by ``_safe_parse_entry``, writes per-question statistics to
    ``algorithmic_eval_results_4_stats.log`` next to the input file and prints
    a one-line status message.

    Returns:
        A dict with the keys ``dataset``, ``alg%`` (mean of the final
        question), ``means`` (all per-question means), ``valid`` and
        ``skip`` (row counts); or ``None`` when the file is missing, cannot
        be read or decoded, does not hold a JSON list, or has no valid rows.
        Every ``None`` case prints a ``[WARN]``/``[ERR ]`` line instead of
        raising, so one bad dataset does not abort the whole run.
    """
    # Method-4 outputs
    path = base / model / ds / "algorithmic_eval_results_4.json"
    if not path.is_file():
        print(f"[WARN] {ds}: file not found")
        return None

    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        print(f"[ERR ] {ds}: json decode {e}")
        return None
    except (OSError, UnicodeDecodeError) as e:
        print(f"[ERR ] {ds}: cannot read file: {e}")
        return None

    # Rows are expected in a top-level list; anything else (dict, number,
    # null, ...) would either crash on len()/iteration or be mis-counted.
    if not isinstance(data, list):
        print(f"[ERR ] {ds}: expected a JSON list, got {type(data).__name__}")
        return None

    parsed = (_safe_parse_entry(row) for row in data)
    valid = [row for row in parsed if row is not None]
    skipped = len(data) - len(valid)
    if not valid:
        print(f"[WARN] {ds}: no valid rows")
        return None

    means = _per_question_means(valid)
    stats = _compute_stats(valid)

    out_log = path.with_name("algorithmic_eval_results_4_stats.log")
    with out_log.open("w", encoding="utf-8") as f:
        for i, (mean, mode, c1, c0) in enumerate(stats):
            f.write(
                f"Q{i+1}: {QUESTIONS[i]}\n"
                f"  mean = {mean:.3f}\n"
                f"  mode = {mode}\n"
                f"  ≥0.5 = {c1}\n"
                f"  <0.5 = {c0}\n\n"
            )

    print(f"[OK  ] {ds:25} | success {means[FINAL_IDX]:6.2%} | valid {len(valid):3d} skip {skipped:3d}")
    return {
        "dataset": ds,
        "alg%": means[FINAL_IDX],   # keep key name for downstream scripts
        "means": means,
        "valid": len(valid),
        "skip": skipped,
    }

# ---------------------------------------------------------------------------
# Table printer
# ---------------------------------------------------------------------------

def print_leaderboard(rows: List[Dict]):
    """
    Print the leaderboard as a pipe-separated ASCII table.

    Each row must provide the keys ``rank``, ``dataset``, ``alg%``,
    ``Q1`` .. ``Q<NUM_Q>``, ``valid`` and ``skip``.  ``alg%`` is rendered as
    a percentage; ``dataset`` is left-aligned; every other cell is converted
    with ``str()`` and right-aligned.  Rows are printed in the order given.
    Prints ``No results.`` when ``rows`` is empty.
    """
    if not rows:
        print("No results.")
        return

    # The dataset column must fit its own header as well as the longest
    # name; otherwise the header row is wider than the data rows whenever
    # every dataset name is shorter than "DATASET".
    name_w = max(len("dataset"), max(len(r["dataset"]) for r in rows))
    cols = [
        ("rank", 4),
        ("dataset", name_w),
        ("alg%", 7),
        *[(f"Q{i + 1}", 10) for i in range(NUM_Q)],  # one column per question
        ("valid", 6),
        ("skip", 5),
    ]
    header = " | ".join(k.upper().ljust(w) for k, w in cols)
    print("\n" + header)
    print("-" * len(header))

    for r in rows:
        cells = []
        for key, width in cols:
            if key == "dataset":
                cells.append(r[key].ljust(width))
            elif key == "alg%":
                cells.append(f"{r[key]:6.2%}".rjust(width))
            else:
                # rank, Q-columns, valid and skip share the same rendering.
                cells.append(str(r[key]).rjust(width))
        print(" | ".join(cells))

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """
    Command-line entry point.

    Arguments:
        ``--model``     model sub-directory under the base dir
                        (default ``gemini-2.0-flash``).
        ``datasets``    optional positional dataset names.
        ``--base_dir``  root log directory (default ``logs``).

    Dataset selection: explicit positional names win; otherwise the default
    model uses ``DEFAULT_DATASETS_GEMINI``; any other model uses every
    sub-directory of ``<base_dir>/<model>``, in sorted order.

    Each dataset is passed to ``process_dataset``; datasets that yield no
    result are dropped.  Every remaining row gets a ``"<rank> (<pct>%)"``
    string per question and a ``rank`` based on ``alg%``, then the table is
    printed.

    Raises:
        SystemExit: if the base directory, or the model directory needed for
            dataset discovery, does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="gemini-2.0-flash")
    ap.add_argument("datasets", nargs="*")
    ap.add_argument("--base_dir", default="logs")
    args = ap.parse_args()

    base = Path(args.base_dir)
    if not base.is_dir():
        raise SystemExit(f"Base dir {base} missing")

    if args.datasets:
        datasets = args.datasets
    elif args.model == "gemini-2.0-flash":
        datasets = DEFAULT_DATASETS_GEMINI
    else:
        model_dir = base / args.model
        # Fail with a clean message (like the base-dir check) instead of a
        # FileNotFoundError traceback from iterdir().
        if not model_dir.is_dir():
            raise SystemExit(f"Model dir {model_dir} missing")
        # iterdir() yields entries in arbitrary order; sort so processing
        # order and tie-breaking in the rankings are reproducible.
        datasets = sorted(d.name for d in model_dir.iterdir() if d.is_dir())

    results: List[Dict] = []
    for ds in datasets:
        res = process_dataset(base, args.model, ds)
        if res:
            results.append(res)

    # compute per-question rankings (Q1–Q10)
    for q in range(NUM_Q):
        ranked = sorted(results, key=lambda r: r["means"][q], reverse=True)
        for rank, row in enumerate(ranked, start=1):
            pct = row["means"][q] * 100
            row[f"Q{q+1}"] = f"{rank} ({pct:.0f}%)"

    # final sort by overall success likelihood
    results.sort(key=lambda r: r["alg%"], reverse=True)
    for idx, row in enumerate(results, start=1):
        row["rank"] = idx

    print_leaderboard(results)


# Run the command-line interface only when this file is executed as a
# script; importing it as a module has no side effects.
if __name__ == "__main__":
    main()
