"""Run statistical tests (BCa CI, sign-flip, Hedges' g, ROPE) on all checkpoint comparisons."""

import os
import sys
import numpy as np
import torch

# Put the parent of this script's directory on sys.path so the `src` package
# imported on the next line resolves when the script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.stats import full_comparison_report, format_report

# Directory the checkpoints are read from; main() also writes its output here.
CKPT_DIR = os.path.join(os.path.dirname(__file__), "..", "checkpoints")

# ROPE (region of practical equivalence) settings. Every comparison is given
# the symmetric interval [-rope, +rope].
# Log-likelihood ("ll"): the half-width is relative, 0.005 x the mean absolute
# LL of the two groups being compared, not an absolute 0.005.
ROPE_LL = 0.005
# Absolute half-width for the "acc" and "probe_acc" metrics.
ROPE_ACC = 0.01
# Absolute half-width for the "pc1_rho" metric.
ROPE_CORR = 0.05


def load_ckpt(name):
    """Load a checkpoint file from ``CKPT_DIR``.

    Args:
        name: File name of the checkpoint, relative to ``CKPT_DIR``.

    Returns:
        The object stored in the file, or ``None`` if the file does not
        exist (callers treat ``None`` as "skip this section").
    """
    path = os.path.join(CKPT_DIR, name)
    if not os.path.exists(path):
        return None
    # map_location="cpu": a checkpoint written on a GPU machine must still
    # load on a CPU-only one, and the callers feed the values straight into
    # NumPy, which cannot consume CUDA tensors.
    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # load checkpoints produced by this project, never untrusted files.
    return torch.load(path, map_location="cpu", weights_only=False)


def test_ablation_seeds():
    """Compare GRU against MLP across seeds for every ablation ``k``.

    Reads ``ablation_seeds.pt``, whose ``results`` mapping is keyed by
    ``(model, k, seed)`` tuples and holds ``"ll"`` and ``"acc"`` entries.
    For each ``k`` two reports are produced: one on log-likelihood and one
    on accuracy, with GRU passed as the first sample and MLP as the second.

    Returns:
        A list of reports from ``full_comparison_report`` (two per ``k``),
        or an empty list when the checkpoint file is missing.
    """
    ck = load_ckpt("ablation_seeds.pt")
    if ck is None:
        print("[skipped] ablation_seeds.pt not found")
        return []

    results = ck["results"]
    seeds = ck["seeds"]
    k_values = ck["k_values"]
    reports = []

    for k in k_values:
        gru_ll = np.array([results[("gru", k, s)]["ll"] for s in seeds])
        mlp_ll = np.array([results[("mlp", k, s)]["ll"] for s in seeds])
        gru_acc = np.array([results[("gru", k, s)]["acc"] for s in seeds])
        mlp_acc = np.array([results[("mlp", k, s)]["acc"] for s in seeds])

        # The LL ROPE is relative: ROPE_LL is a fraction of the mean absolute
        # log-likelihood over both models, so the interval scales with the
        # magnitude of the values. Using the constant (instead of a repeated
        # literal) keeps a single place to tune it.
        baseline_ll = np.mean(np.abs(np.concatenate([gru_ll, mlp_ll])))
        rope_ll = baseline_ll * ROPE_LL

        reports.append(full_comparison_report(
            gru_ll, mlp_ll,
            rope_lo=-rope_ll, rope_hi=rope_ll,
            label=f"GRU vs MLP  k={k}  LL"))
        # Accuracy uses the fixed absolute half-width.
        reports.append(full_comparison_report(
            gru_acc, mlp_acc,
            rope_lo=-ROPE_ACC, rope_hi=ROPE_ACC,
            label=f"GRU vs MLP  k={k}  Accuracy"))

    return reports


def test_arch_baselines():
    """Compare GRU against each alternative architecture across seeds.

    Reads ``arch_baselines.pt``, whose ``results`` mapping is keyed by
    ``(arch, seed)`` tuples. For every architecture in
    ``("lstm", "gtrxl", "mingru")`` and every metric, GRU values are passed
    as the first sample and the alternative's values as the second.

    Returns:
        A list of reports from ``full_comparison_report`` (one per
        architecture/metric pair), or an empty list when the checkpoint
        file is missing.
    """
    ck = load_ckpt("arch_baselines.pt")
    if ck is None:
        print("[skipped] arch_baselines.pt not found")
        return []

    results = ck["results"]
    seeds = ck["seeds"]
    others = ["lstm", "gtrxl", "mingru"]
    # (metric key, ROPE setting). For "ll" the setting is a fraction of the
    # mean absolute LL; for every other metric it is an absolute half-width.
    metrics = [
        ("ll", ROPE_LL),
        ("acc", ROPE_ACC),
        ("probe_acc", ROPE_ACC),
        ("pc1_rho", ROPE_CORR),
    ]
    reports = []

    for arch in others:
        for metric, rope_setting in metrics:
            gru_vals = np.array([results[("gru", s)][metric] for s in seeds])
            alt_vals = np.array([results[(arch, s)][metric] for s in seeds])

            if metric == "ll":
                # Relative ROPE: scale by the magnitude of the LL values.
                baseline = np.mean(np.abs(np.concatenate([gru_vals, alt_vals])))
                rope = baseline * rope_setting
            else:
                rope = rope_setting

            reports.append(full_comparison_report(
                gru_vals, alt_vals,
                rope_lo=-rope, rope_hi=rope,
                label=f"GRU vs {arch.upper()}  {metric}"))

    return reports


def test_dim_sweep():
    """Compare each hidden dimension against the next smaller one.

    Reads ``dim_sweep_results.pt``. ``results[dim][metric]`` is converted to
    an array per dimension; dimensions are sorted ascending and only
    adjacent pairs are compared. The larger dimension is passed as the first
    sample and the smaller as the second.

    Returns:
        A list of reports from ``full_comparison_report`` (one per adjacent
        pair and metric), or an empty list when the checkpoint file is
        missing. Fewer than two dimensions yields an empty list as well.
    """
    ck = load_ckpt("dim_sweep_results.pt")
    if ck is None:
        print("[skipped] dim_sweep_results.pt not found")
        return []

    results = ck["results"]
    dims = sorted(ck["dims"])
    # (metric key, ROPE setting). For "ll" the setting is a fraction of the
    # mean absolute LL; for the other metrics it is an absolute half-width.
    metrics = [
        ("ll", ROPE_LL),
        ("probe_acc", ROPE_ACC),
        ("pc1_rho", ROPE_CORR),
    ]
    reports = []

    # zip over the list and its one-step shift walks the adjacent pairs.
    for d1, d2 in zip(dims, dims[1:]):
        for metric, rope_setting in metrics:
            v1 = np.array(results[d1][metric])
            v2 = np.array(results[d2][metric])

            if metric == "ll":
                # Relative ROPE: scale by the magnitude of the LL values.
                baseline = np.mean(np.abs(np.concatenate([v1, v2])))
                rope = baseline * rope_setting
            else:
                rope = rope_setting

            reports.append(full_comparison_report(
                v2, v1,  # larger dim - smaller dim
                rope_lo=-rope, rope_hi=rope,
                label=f"dim {d2} vs {d1}  {metric}"))

    return reports


def test_leave_one_out():
    """Compare cross-animal against within-animal results per metric.

    Reads ``leave_one_out.pt``. ``cross_results`` and ``within_results`` are
    both indexed by animal and then by metric; one value per animal is
    collected for each side. Cross values are passed as the first sample and
    within values as the second.

    Returns:
        A list of reports from ``full_comparison_report`` (one per metric),
        or an empty list when the checkpoint file is missing or it lists
        fewer than two animals.
    """
    ck = load_ckpt("leave_one_out.pt")
    if ck is None:
        print("[skipped] leave_one_out.pt not found")
        return []

    animals = ck["animals"]
    cross_r = ck["cross_results"]
    within_r = ck["within_results"]

    # With fewer than two animals there is only one paired observation.
    if len(animals) < 2:
        print(f"[skipped] leave_one_out: only {len(animals)} animal(s), need >= 2")
        return []

    # (metric key, ROPE setting). For "ll" the setting is a fraction of the
    # mean absolute LL; for the other metrics it is an absolute half-width.
    metrics = [
        ("ll", ROPE_LL),
        ("acc", ROPE_ACC),
        ("probe_acc", ROPE_ACC),
    ]
    reports = []

    for metric, rope_setting in metrics:
        cross_vals = np.array([cross_r[a][metric] for a in animals])
        within_vals = np.array([within_r[a][metric] for a in animals])

        if metric == "ll":
            # Relative ROPE: scale by the magnitude of the LL values.
            baseline = np.mean(np.abs(np.concatenate([cross_vals, within_vals])))
            rope = baseline * rope_setting
        else:
            rope = rope_setting

        reports.append(full_comparison_report(
            cross_vals, within_vals,
            rope_lo=-rope, rope_hi=rope,
            label=f"cross vs within  {metric}"))

    return reports


def main():
    """Run every comparison section, print the reports, and save them.

    Each section function returns a list of reports (empty when its
    checkpoint is unavailable). Reports are printed through
    ``format_report`` and collected in a dict keyed by section title, which
    is then written to ``statistical_tests.pt`` inside ``CKPT_DIR``.
    """
    all_reports = {}

    sections = [
        ("Ablation Seeds (GRU vs MLP)", test_ablation_seeds),
        ("Architecture Baselines (GRU vs others)", test_arch_baselines),
        ("Dimension Sweep (adjacent pairs)", test_dim_sweep),
        ("Leave-One-Out (cross vs within)", test_leave_one_out),
    ]

    for title, fn in sections:
        print(f"\n{title} \n")
        reports = fn()
        all_reports[title] = reports
        for r in reports:
            print(format_report(r))
            print()

    # If the checkpoint directory is missing, every section above skips
    # cleanly; make sure the final save does not then crash on the same
    # missing directory.
    os.makedirs(CKPT_DIR, exist_ok=True)
    out_path = os.path.join(CKPT_DIR, "statistical_tests.pt")
    torch.save(all_reports, out_path)
    print(f"Saved {sum(len(v) for v in all_reports.values())} comparisons "
          f"to {out_path}")


# Run the comparisons only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
