#!/usr/bin/env python3
"""
lomo_generalization_metrics.py
------------------------------
Compute Leave-One-Model-Out (LOMO) generalization metrics for Overton coverage, and
also output a per-model delta table comparing **human adjusted coverage** vs **LOMO-substituted adjusted coverage**.

Inputs:
  --data       : Path to CSV for human data (same as prediction/baselines). If unset, load from Hugging Face.
  --human_csv  : Deprecated alias for --data; use --data.
  --source     : HF split when loading human from HF: full (default), modelslant, or prism. Also used for preds_csv and save_deltas naming.
  --preds_csv  : CSV with judge predictions (same keys). With --source modelslant/prism, _source is appended to default path.
  --pred_col   : name of prediction column in preds_csv (e.g., gemini_fr+fs_avg)
  --cluster_col: name of cluster column in human_csv (default: cluster_kmeans)
  --tau        : coverage threshold (default: 4.0)
  --save_deltas: optional CSV path to save a per-model table with human adj_coverage,
                 LOMO adj_coverage (for that model's fold), and the delta (LOMO - human).

What it computes:
  • Human baseline: OLS (LPM) with question FEs, cluster-robust SEs; adjusted coverage = average prediction
    first within question, then across questions (equal weight per question).
  • LOMO folds: for each target model, substitute that model’s ratings with predictions, recompute adjusted
    coverage, refit OLS, and compare to the human baseline.
  • Generalization metrics per fold (rank correlations, coef correlations/MAE, direction agreement, target sig replication).
  • A **delta table**: one row per model, showing human adj_coverage vs that model's LOMO adj_coverage and their difference.

Note: This script focuses on metrics + an optional delta CSV. It prints summaries to stdout.
"""

import sys
from pathlib import Path
_src = Path(__file__).resolve().parent.parent
if str(_src) not in sys.path:
    sys.path.insert(0, str(_src))
from load_dataset import get_overtonbench_data
import helper_functions

import numpy as np
import pandas as pd
from scipy.stats import spearmanr, kendalltau, pearsonr
import statsmodels.formula.api as smf


def build_coverage(df: pd.DataFrame, rating_col: str, cluster_col: str, tau: float) -> pd.DataFrame:
    """Average ratings per (question, model, cluster) cell and flag covered cells.

    Args:
        df: Long-format ratings containing ``question_id``, ``model``,
            ``cluster_col`` and ``rating_col`` columns.
        rating_col: Name of the column whose values are averaged.
        cluster_col: Name of the column holding the cluster label.
        tau: Coverage threshold; a cell is covered when its mean rating is >= tau.

    Returns:
        One row per (question_id, model, cluster) cell with the cell's
        ``mean_rating`` and an integer 0/1 ``covered`` indicator.
    """
    cell_keys = ["question_id", "model", cluster_col]
    cell_means = df.groupby(cell_keys, as_index=False)[rating_col].mean()
    cell_means = cell_means.rename(columns={rating_col: "mean_rating"})
    # Threshold the cell mean (not the individual ratings) to get coverage.
    is_covered = cell_means["mean_rating"] >= tau
    cell_means["covered"] = is_covered.astype(int)
    return cell_means


def ols_with_qfe(clustered: pd.DataFrame) -> pd.DataFrame:
    """Fit OLS of ``covered`` on model dummies plus question fixed effects.

    Standard errors are cluster-robust with ``question_id`` as the grouping
    variable. Two things are derived from the fit:

    * ``adj_coverage``: fitted values averaged within each (question, model)
      pair and then across questions, so every question carries equal weight.
    * Each model's coefficient contrasted against the mean of all model
      coefficients, with its standard error, t statistic and p-value.

    Args:
        clustered: Table with ``question_id``, ``model`` and ``covered`` columns
            (the output of ``build_coverage``).

    Returns:
        One row per model, sorted by ``model``, with columns ``model``,
        ``adj_coverage``, ``coef_vs_grand_mean``, ``se``, ``t`` and ``p``.

    Raises:
        RuntimeError: If the fitted model has no ``C(model)`` coefficients.
    """
    fit = smf.ols("covered ~ 0 + C(model) + C(question_id)", data=clustered).fit(
        cov_type="cluster",
        cov_kwds={"groups": clustered["question_id"]},
    )

    # Work on a copy so the caller's frame does not gain a "pred" column.
    scored = clustered.copy()
    scored["pred"] = fit.predict(scored)

    # Equal weight per question: average within question first, then across.
    per_question = scored.groupby(["question_id", "model"], as_index=False)["pred"].mean()
    per_question = per_question.rename(columns={"pred": "pred_q"})
    adj_cov = per_question.groupby("model", as_index=False)["pred_q"].mean()
    adj_cov = adj_cov.rename(columns={"pred_q": "adj_coverage"})

    coef_names = list(fit.params.index)
    model_pos = [i for i, label in enumerate(coef_names) if label.startswith("C(model)")]
    if not model_pos:
        raise RuntimeError("No model effects found in regression.")

    # Shared part of every contrast: minus the average of all model coefficients.
    grand_mean_part = np.zeros((len(coef_names),), dtype=float)
    grand_mean_part[model_pos] -= 1.0 / len(model_pos)

    records = []
    for pos in model_pos:
        contrast = grand_mean_part.copy()
        contrast[pos] += 1.0  # this model's own coefficient
        test = fit.t_test(contrast)

        # Recover the level name from "C(model)[T.name]" or "C(model)[name]".
        label = coef_names[pos]
        if "[T." in label:
            model_name = label.split("[T.", 1)[1][:-1]
        else:
            model_name = label.split("[", 1)[1][:-1]

        records.append({
            "model": model_name,
            "coef_vs_grand_mean": float(test.effect.item()),
            "se": float(test.sd.item()),
            "t": float(test.tvalue.item()),
            "p": float(test.pvalue.item()),
        })

    inference = pd.DataFrame(records)
    combined = adj_cov.merge(inference, on="model", how="outer")
    return combined.sort_values("model").reset_index(drop=True)


def fold_metrics(human_tab: pd.DataFrame, fold_tab: pd.DataFrame, target: str) -> dict:
    """Compare one LOMO fold's per-model table against the human baseline table.

    Args:
        human_tab: Baseline table with ``model``, ``adj_coverage``,
            ``coef_vs_grand_mean`` and ``p`` columns.
        fold_tab: Table of the same shape computed for the fold.
        target: Model whose ratings were substituted in this fold; it must be
            present in ``fold_tab``.

    Returns:
        Dict with the fold target, Spearman and Kendall rank correlations of
        adjusted coverage, Pearson correlation and mean absolute error of the
        model coefficients, the percentage of models whose coefficient sign
        matches, and whether the target's effect is significant (p < 0.05, with
        a non-zero coefficient) in both tables with the same sign.
    """
    baseline_cols = human_tab.rename(columns={
        "adj_coverage": "adj_coverage_human",
        "coef_vs_grand_mean": "coef_vs_grand_mean_human",
        "p": "p_human",
    })[["model", "adj_coverage_human", "coef_vs_grand_mean_human", "p_human"]]
    comp = fold_tab.merge(baseline_cols, on="model", how="left").sort_values("model")

    cov_human, cov_fold = comp["adj_coverage_human"], comp["adj_coverage"]
    coef_human, coef_fold = comp["coef_vs_grand_mean_human"], comp["coef_vs_grand_mean"]

    # Agreement of the model ranking implied by adjusted coverage.
    rho = spearmanr(cov_human, cov_fold)[0]
    tau = kendalltau(cov_human, cov_fold)[0]

    # Agreement of the model-effect coefficient vectors.
    r = pearsonr(coef_human, coef_fold)[0]
    mae = (coef_human - coef_fold).abs().mean()
    signs_human = np.sign(coef_human).astype(int)
    signs_fold = np.sign(coef_fold).astype(int)
    dir_agree = float((signs_human == signs_fold).mean() * 100.0)

    # Does the target model's effect replicate: significant in both, same direction?
    target_row = comp[comp["model"] == target].iloc[0]
    target_sign_human = np.sign(target_row["coef_vs_grand_mean_human"])
    target_sign_fold = np.sign(target_row["coef_vs_grand_mean"])
    human_sig = (target_row["p_human"] < 0.05) and (target_sign_human != 0)
    fold_sig = (target_row["p"] < 0.05) and (target_sign_fold != 0)
    replicated = bool(human_sig and fold_sig and (target_sign_human == target_sign_fold))

    return {
        "fold_target": target,
        "spearman_rho": float(rho),
        "kendall_tau": float(tau),
        "pearson_r_coef": float(r),
        "mae_coef": float(mae),
        "direction_agreement_pct": dir_agree,
        "target_sig_replication": replicated,
    }


def main():
    """CLI entry point: compute LOMO generalization metrics and the delta table.

    Steps:
      1. Parse arguments and load the human data and the judge predictions.
      2. Inner-join them on (user, question_id, model).
      3. Build the human baseline table from ``representation_rating``.
      4. For every model, replace that model's ratings with the predictions,
         rebuild the table, and record fold metrics plus the target model's
         change in adjusted coverage.
      5. Print per-fold metrics, their means, and the per-model delta table;
         optionally save the delta table to ``--save_deltas``.

    Raises:
        SystemExit: Via argparse, on invalid arguments or when the predictions
            CSV lacks a required column.
        ValueError: If no rows survive the join between human data and predictions.
    """
    import argparse
    import os

    ap = argparse.ArgumentParser(description="LOMO generalization metrics for Overton coverage + delta table.")
    ap.add_argument("--human_csv", default=None, help="Path to human data CSV; default = load from Hugging Face (deprecated: prefer --data).")
    ap.add_argument("--data", default=None,
                    help="Path to CSV to use instead of Hugging Face (overrides .env DATASET). Same as prediction/baselines; use for human data.")
    _default_preds_csv = "outputs/predictions/gemini_all_rows_fr+fs.csv"
    ap.add_argument("--preds_csv", default=_default_preds_csv,
                    help="Path to judge predictions CSV. When --source is modelslant/prism and this is left at default, _source is appended to the path. If your CSV has a different prediction column name, pass --pred_col too.")
    ap.add_argument("--pred_col", default="gemini_fr+fs_avg",
                    help="Name of the prediction column in preds_csv (default: gemini_fr+fs_avg). Must match the column in your CSV when using a custom --preds_csv.")
    ap.add_argument("--cluster_col", default="cluster_kmeans")
    ap.add_argument("--tau", type=float, default=4.0)
    ap.add_argument("--save_deltas", default=None, help="Optional output CSV for per-model LOMO deltas.")
    ap.add_argument("--source", default=None,
                    help="Question source split when loading from HF: full (default), modelslant, or prism. Used for human load and for preds_csv/save_deltas naming.")
    args = ap.parse_args()

    # Normalise the split name once so file naming and data loading agree
    # (e.g. "--source Prism" behaves like "--source prism").
    source = (args.source or "").strip().lower() or "full"
    # When the flag was not given, pass the original falsy value through unchanged.
    source_split_opt = source if args.source else args.source
    path = args.data or args.human_csv

    def _path_with_source(path_str: str) -> str:
        """Insert ``_<source>`` before the file extension for non-default splits."""
        if source not in ("modelslant", "prism"):
            return path_str
        # splitext only looks at the final path component, so dots in directory
        # names ("./out/deltas", "runs/v1.2/deltas") are left alone.
        base, ext = os.path.splitext(path_str)
        return f"{base}_{source}{ext}"

    helper_functions.set_data_options(path=path, source_split=source_split_opt)
    human = get_overtonbench_data(path=path, source_split=source)

    # Only append _source to preds path when using the default path; explicit --preds_csv is used as-is
    if args.preds_csv == _default_preds_csv:
        preds_path = _path_with_source(args.preds_csv)
    else:
        preds_path = args.preds_csv
    preds = pd.read_csv(preds_path)

    join_keys = ["user", "question_id", "model"]
    missing_cols = [c for c in [*join_keys, args.pred_col] if c not in preds.columns]
    if missing_cols:
        ap.error(f"{preds_path} is missing required column(s): {', '.join(missing_cols)}")

    merged = human.merge(
        preds[[*join_keys, args.pred_col]],
        on=join_keys,
        how="inner"
    )
    if merged.empty:
        raise ValueError(
            f"No rows matched between the human data and {preds_path} "
            f"on ({', '.join(join_keys)})."
        )

    # Human baseline table
    human_cov = build_coverage(merged, "representation_rating", args.cluster_col, args.tau)
    human_tab = ols_with_qfe(human_cov)  # has adj_coverage + model-effect stats

    models = sorted(merged["model"].unique().tolist())
    per_fold_metrics = []
    delta_rows = []

    for target in models:
        # Substitute predictions for the target model only; all others keep human ratings.
        df_fold = merged.copy()
        df_fold["rating_lomo"] = np.where(df_fold["model"] == target,
                                          df_fold[args.pred_col],
                                          df_fold["representation_rating"])
        fold_cov = build_coverage(df_fold, "rating_lomo", args.cluster_col, args.tau)
        fold_tab = ols_with_qfe(fold_cov)

        # Metrics
        per_fold_metrics.append(fold_metrics(human_tab, fold_tab, target))

        # Delta table row: target model's LOMO adj_coverage minus human
        lomo_val = float(fold_tab.loc[fold_tab["model"] == target, "adj_coverage"].values[0])
        human_val = float(human_tab.loc[human_tab["model"] == target, "adj_coverage"].values[0])
        delta_rows.append({
            "model": target,
            "adj_coverage_human": human_val,
            "adj_coverage_lomo": lomo_val,
            "delta": lomo_val - human_val
        })

    per_fold_df = pd.DataFrame(per_fold_metrics)

    # Aggregated means across folds
    agg = {
        "spearman_rho_mean": per_fold_df["spearman_rho"].mean(),
        "kendall_tau_mean": per_fold_df["kendall_tau"].mean(),
        "pearson_r_coef_mean": per_fold_df["pearson_r_coef"].mean(),
        "mae_coef_mean": per_fold_df["mae_coef"].mean(),
        "direction_agreement_pct_mean": per_fold_df["direction_agreement_pct"].mean(),
        "target_sig_replication_rate_pct": per_fold_df["target_sig_replication"].mean() * 100.0
    }

    print("=== LOMO Generalization Metrics ===")
    print(f"tau={args.tau} | cluster_col={args.cluster_col} | pred_col={args.pred_col}")
    print("\nPer-fold metrics:")
    print(per_fold_df.to_string(index=False, justify='center', float_format=lambda x: f'{x:.6f}'))
    print("\nAggregated means across folds:")
    for k, v in agg.items():
        if "pct" in k:
            print(f"  {k}: {v:.2f}%")
        else:
            print(f"  {k}: {v:.6f}")

    # Delta table
    delta_df = pd.DataFrame(delta_rows).sort_values("adj_coverage_human", ascending=False).reset_index(drop=True)
    print("\nPer-model LOMO deltas (adj_coverage_lomo - adj_coverage_human):")
    print(delta_df.to_string(index=False, justify='center', float_format=lambda x: f'{x:.3f}'))

    if args.save_deltas:
        out_path = _path_with_source(args.save_deltas)
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
        delta_df.to_csv(out_path, index=False)
        print(f"\n[OK] Saved delta table to {out_path}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
