
#!/usr/bin/env python3
"""
Analyze TopPR sweeps to find the setup where ΔPR = |P[0]-P[1]| + |R[0]-R[1]|
is (i) **maximum for normal models** and simultaneously (ii) **minimal on Min-Max models**.

"Normal models" here means Generation == "Normal" and Train != "Min-Max".
"Min-Max model" here means Generation == "Normal" and Train == "Min-Max".

You can compute ΔPR using either the plain (P, R) values or the DINO-based
(P_dino, R_dino) values; select between them with --metric {plain,dino}.

Usage examples:
  python analyze_toppr_delta.py --base-dir evals-test/toppr --metric dino
  python analyze_toppr_delta.py --base-dir evals-test/toppr --metric plain --by-model
"""

import argparse
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


# ------------------------- helpers to parse naming -------------------------
def infer_model(path: str) -> str:
    """
    Map a run/root name to its model label based on a substring marker.

    "_vp" is checked first, then "_ve"; names containing neither yield "EDM-?".
    """
    markers = (("_vp", "EDM-VP"), ("_ve", "EDM-VE"))
    for marker, label in markers:
        if marker in path:
            return label
    return "EDM-?"


def infer_train(path: str) -> str:
    """
    Map a run/root name to its training label based on a substring marker.

    "weak" takes precedence over "minmax"; names with neither are "pretrained".
    """
    markers = (("weak", "RW"), ("minmax", "Min-Max"))
    for marker, label in markers:
        if marker in path:
            return label
    return "pretrained"


def infer_generation(path: str) -> str:
    """
    Map a run/root name to its generation label based on a substring marker.

    "_mgo" is checked before "_ego"; names with neither are "Normal".
    """
    markers = (("_mgo", "RS-MGO"), ("_ego", "RS-EGO"))
    for marker, label in markers:
        if marker in path:
            return label
    return "Normal"


def parse_toppr_dirname(dirname: str) -> Optional[Dict[str, object]]:
    """
    Parse a TopPR sweep directory name into its hyperparameters.

    Expected: {model}_toppr_{pcaTag}_a{alpha}_rp{true|false}_l2{true|false}_n{N}_r{reps}
    Where pcaTag is either "pcaFalse" or "pcaTrue_d{dim}". A bare "pcaTrue"
    without a dimension token is accepted and yields pca_dim=None.

    Returns a dict with keys "root", "pca", "pca_dim", "alpha", "randproj",
    "l2norm", "n" and "repeats" (tokens that are absent map to None), or None
    when the name lacks "_toppr_", has an unknown pca tag, or contains a
    numeric field that cannot be converted.
    """
    if "_toppr_" not in dirname:
        return None

    try:
        model, rest = dirname.split("_toppr_", 1)
        parts = rest.split("_")

        pca_tag = parts[0]
        pca_dim = None
        idx = 1
        if pca_tag.startswith("pcaFalse"):
            pca = False
        elif pca_tag.startswith("pcaTrue"):
            pca = True
            # Splitting on "_" separates "pcaTrue" from its "d{dim}" suffix, so
            # the dimension (when present) is the *next* token, never part of
            # parts[0]. Consume it here so it is not lost.
            if len(parts) > 1 and parts[1][:1] == "d" and parts[1][1:].isdecimal():
                pca_dim = int(parts[1][1:])
                idx = 2
        else:
            # unexpected pca tag
            return None

        # Remaining tokens should be like a{alpha}, rp{true|false}, l2{true|false}, n{N}, r{reps}
        kv = {}
        for tok in parts[idx:]:
            if tok.startswith("a"):
                kv["alpha"] = float(tok[1:])
            elif tok.startswith("rp"):
                # Must be tested before the bare "r" prefix used for repeats.
                kv["randproj"] = tok[2:] == "true"
            elif tok.startswith("l2"):
                kv["l2norm"] = tok[2:] == "true"
            elif tok.startswith("n"):
                kv["n"] = int(tok[1:])
            elif tok.startswith("r"):
                kv["repeats"] = int(tok[1:])

        return {
            "root": model,
            "pca": pca,
            "pca_dim": pca_dim,
            "alpha": kv.get("alpha"),
            "randproj": kv.get("randproj"),
            "l2norm": kv.get("l2norm"),
            "n": kv.get("n"),
            "repeats": kv.get("repeats"),
        }
    except ValueError:
        # int()/float() rejected a token: treat the name as not matching.
        return None


def abs_delta_10(arr: List[float]) -> float:
    """
    Absolute difference between the first two entries of `arr`.

    Yields NaN when `arr` is None, has fewer than two entries, or the
    subtraction/conversion fails for any reason.
    """
    nan = float("nan")
    try:
        if arr is not None and len(arr) >= 2:
            first, second = arr[0], arr[1]
            return float(abs(second - first))
        return nan
    except Exception:
        return nan


def compute_delta_pr_from_metrics(metrics: Dict[str, object], metric_kind: str) -> float:
    """
    Compute ΔPR = |P-1 - P-0| + |R-1 - R-0| from a flat metrics dict.

    metric_kind: 'plain' => read keys 'P-{i}' and 'R-{i}'
                 anything else (e.g. 'dino') => read 'P_dino-{i}' and 'R_dino-{i}'

    The number of labels is taken from the count of 'fid_dino-{int}' keys,
    regardless of metric_kind. NaN is returned when fewer than two such keys
    exist or when either per-metric difference is NaN.
    """
    nan = float("nan")
    suffix = "" if metric_kind == "plain" else "_dino"
    key_P = "P" + suffix
    key_R = "R" + suffix

    def _is_label_key(name) -> bool:
        # A label key looks like 'fid_dino-0', 'fid_dino-1', ...
        if not name.startswith("fid_dino-"):
            return False
        try:
            int(name.split("-")[1])
        except Exception:
            return False
        return True

    num_labels = sum(1 for name in metrics.keys() if _is_label_key(name))
    if num_labels <= 1:
        return nan

    labels = range(num_labels)
    P = [metrics.get(f"{key_P}-{i}", nan) for i in labels]
    R = [metrics.get(f"{key_R}-{i}", nan) for i in labels]

    deltas = (abs_delta_10(P), abs_delta_10(R))
    if any(math.isnan(d) for d in deltas):
        return nan
    return float(deltas[0] + deltas[1])


def load_results_jsonl(path: Path) -> Optional[Dict[str, object]]:
    """
    Read `path` and parse its entire content as a single JSON document.

    Despite the "jsonl" name, the file is not parsed line by line. Any failure
    (missing file, unreadable file, invalid JSON) yields None instead of raising.
    """
    try:
        text = path.read_text()
        return json.loads(text)
    except Exception:
        return None


def collect_rows(base_dir: Path, metric_kind: str) -> pd.DataFrame:
    """
    Scan `base_dir` recursively and return one dataframe row per TopPR run.

    A directory counts as a run when it contains results_eval.jsonl, its name
    parses via parse_toppr_dirname, and the results file loads to a non-empty
    value. Each row carries the parsed hyperparameters, the labels inferred
    from the root name, the run's ΔPR (for `metric_kind`) and the results path.
    A non-empty result is sorted with ordered categoricals for Model, Train
    and Generation.
    """
    results_name = "results_eval.jsonl"
    hparam_fields = ("pca", "pca_dim", "alpha", "randproj", "l2norm", "n", "repeats")

    records = []
    for dirpath, _subdirs, filenames in os.walk(base_dir):
        # Skip anything that is not a run directory holding a results file.
        if results_name not in filenames:
            continue

        run_dir = Path(dirpath)
        meta = parse_toppr_dirname(run_dir.name)
        if not meta:
            continue

        metrics_path = run_dir / results_name
        metrics = load_results_jsonl(metrics_path)
        if not metrics:
            continue

        root = meta["root"]
        record = {
            "Run": run_dir.name,
            "Root": root,
            "Model": infer_model(root),
            "Train": infer_train(root),
            "Generation": infer_generation(root),
        }
        for field in hparam_fields:
            record[field] = meta[field]
        record["DeltaPR"] = compute_delta_pr_from_metrics(metrics, metric_kind=metric_kind)
        record["ResultsPath"] = str(metrics_path)
        records.append(record)

    df = pd.DataFrame(records)
    if df.empty:
        return df

    # Ordered categoricals give a stable, readable sort order; labels outside
    # these lists are not valid categories for the column.
    category_orders = {
        "Model": ["EDM-VP", "EDM-VE"],
        "Train": ["pretrained", "RW", "Min-Max"],
        "Generation": ["Normal", "RS-MGO", "RS-EGO"],
    }
    for column, order in category_orders.items():
        df[column] = pd.Categorical(df[column], categories=order, ordered=True)

    sort_columns = ["Model", "Train", "Generation", "alpha", "randproj", "pca", "pca_dim", "n", "repeats"]
    return df.sort_values(sort_columns, ignore_index=True)


def summarize_by_hparams(df: pd.DataFrame) -> pd.DataFrame:
    """
    Rank TopPR hyperparameter setups by how well ΔPR separates the two groups.

    Rows are grouped by hyperparameters only (not by model/training). For each
    setup this computes:
      - DeltaPR_NormalMean: mean ΔPR over Generation == "Normal", Train != "Min-Max"
      - DeltaPR_MinMaxMean: mean ΔPR over Generation == "Normal", Train == "Min-Max"
      - Score = DeltaPR_NormalMean - DeltaPR_MinMaxMean
    The result is sorted by Score, highest first. An empty input is returned
    unchanged.
    """
    if df.empty:
        return df

    # Grouping keys (the TopPR hyperparams only)
    hparam_cols = ["pca", "pca_dim", "alpha", "randproj", "l2norm", "n", "repeats"]

    def _mean_delta(rows: pd.DataFrame, out_name: str) -> pd.DataFrame:
        # dropna=False keeps setups whose key columns contain missing values.
        grouped = rows.groupby(hparam_cols, dropna=False, as_index=False)["DeltaPR"]
        return grouped.mean().rename(columns={"DeltaPR": out_name})

    plain_generation = df["Generation"] == "Normal"
    trained_minmax = df["Train"] == "Min-Max"
    trained_other = df["Train"] != "Min-Max"

    normal_means = _mean_delta(df[plain_generation & trained_other], "DeltaPR_NormalMean")
    minmax_means = _mean_delta(df[plain_generation & trained_minmax], "DeltaPR_MinMaxMean")

    # Outer merge keeps setups that appear in only one of the two groups.
    summary = pd.merge(normal_means, minmax_means, on=hparam_cols, how="outer")
    for col in ("DeltaPR_NormalMean", "DeltaPR_MinMaxMean"):
        summary[col] = summary[col].astype(float)
    summary["Score"] = summary["DeltaPR_NormalMean"] - summary["DeltaPR_MinMaxMean"]
    return summary.sort_values("Score", ascending=False, ignore_index=True)


def summarize_by_hparams_and_model(df: pd.DataFrame) -> pd.DataFrame:
    """
    Same as summarize_by_hparams, but produce a separate score per Model (EDM-VP vs EDM-VE).

    Returns the per-model summaries concatenated, with an added "Model" column,
    sorted by Model and then by Score (descending). An empty input is returned
    unchanged; if no model has any rows, an empty DataFrame is returned.
    """
    if df.empty:
        return df

    per_model = []
    # observed=True: "Model" may be a Categorical with categories that do not
    # occur in the data. Iterating such groups can produce empty sub-frames,
    # and summarize_by_hparams returns an empty input unchanged, which would
    # leak the per-run columns into the concatenated summary.
    for model, sub in df.groupby("Model", observed=True):
        if sub.empty:
            continue
        summary = summarize_by_hparams(sub).copy()
        summary["Model"] = model
        per_model.append(summary)

    if not per_model:
        return pd.DataFrame()

    out = pd.concat(per_model, ignore_index=True)
    # Within each model, sort by Score desc for readability
    out = out.sort_values(["Model", "Score"], ascending=[True, False], ignore_index=True)
    return out


def main():
    """
    CLI entry point: collect runs, print the top-K setups, and write CSVs.

    Reads --base-dir, --metric, --by-model and --topk from the command line.
    Prints a preview of the collected runs and the aggregated ranking, plus a
    per-model ranking when --by-model is given, then writes the tables to
    <base-dir>/_analysis/. Returns early (after printing a message) when the
    base directory is missing or nothing could be collected/aggregated.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base-dir", type=str, default="evals-test/toppr", help="Directory containing TopPR sweep runs.")
    ap.add_argument("--metric", type=str, choices=["plain", "dino"], default="dino",
                    help="Which P/R to use when computing ΔPR: 'plain' uses (P, R), 'dino' uses (P_dino, R_dino).")
    ap.add_argument("--by-model", action="store_true", help="Also show the best setup per model (EDM-VP, EDM-VE).")
    ap.add_argument("--topk", type=int, default=10, help="Show top-K setups.")
    args = ap.parse_args()

    base_dir = Path(args.base_dir)
    if not base_dir.exists():
        print(f"[ERROR] Base directory does not exist: {base_dir}")
        return

    df = collect_rows(base_dir, metric_kind=args.metric)
    if df.empty:
        print("[WARN] No runs found (or no results_eval.jsonl files).")
        return

    # Show a few collected rows as a sanity check (values are not rounded).
    df_preview = df.head(5).copy()
    print("Example rows:")
    print(df_preview.to_string(index=False))

    print("\n=== Aggregated across all models (maximize Normal mean, minimize Min-Max mean) ===")
    agg = summarize_by_hparams(df)
    if agg.empty:
        print("No aggregations available.")
        return

    # Display top-K
    topk = min(args.topk, len(agg))
    cols = ["pca", "pca_dim", "alpha", "randproj", "l2norm", "n", "repeats",
            "DeltaPR_NormalMean", "DeltaPR_MinMaxMean", "Score"]
    print(agg[cols].head(topk).to_string(index=False))

    # argparse exposes "--by-model" as the attribute `by_model`; writing
    # `args.by-model` would be parsed as the subtraction `args.by - model`.
    agg_m = None
    if args.by_model:
        print("\n=== Per-model breakdown (EDM-VP / EDM-VE) ===")
        agg_m = summarize_by_hparams_and_model(df)
        if not agg_m.empty:
            # Show top-K per model
            for model, sub in agg_m.groupby("Model"):
                print(f"\n-- {model} --")
                print(sub[cols].head(topk).to_string(index=False))

    # Also write CSVs for convenience
    out_dir = base_dir / "_analysis"
    out_dir.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_dir / "runs_deltapr.csv", index=False)
    agg.to_csv(out_dir / "best_setups_overall.csv", index=False)
    if args.by_model:
        agg_m.to_csv(out_dir / "best_setups_per_model.csv", index=False)

    print(f"\nSaved: {out_dir/'runs_deltapr.csv'}")
    print(f"Saved: {out_dir/'best_setups_overall.csv'}")
    if args.by_model:
        print(f"Saved: {out_dir/'best_setups_per_model.csv'}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
