from __future__ import annotations
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
import argparse
import json
from pathlib import Path

import numpy as np
import pandas as pd
import yaml

from ftpredict.dataio import load_real
from ftpredict.column_map import REAL_RUNS
from ftpredict.estimation import estimate_L_from_runs, add_intrinsic_and_vopt
from ftpredict.powerlaw import fit_power_law_by_group
from ftpredict.regimes import assign_regime
from ftpredict.marginal_gain import compute_delta, normalize_delta
from ftpredict.repro import get_env_metadata
from ftpredict import plots

def _json_default(obj):
    """Convert NumPy scalars/arrays and paths into JSON-encodable values.

    Used as the ``default=`` hook of ``json.dumps`` for the manifest.

    Raises:
        TypeError: for any other type, so unexpected objects still fail
            loudly instead of being silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _mean_and_se(values):
    """Return ``(mean, standard error)`` of a float array, ignoring NaNs.

    The standard error is the population standard deviation (``ddof=0``)
    divided by the square root of the number of finite entries. When there
    is no finite entry both statistics are NaN; returning early avoids a
    division by zero and the associated NumPy runtime warnings.
    """
    n_finite = int(np.isfinite(values).sum())
    if n_finite == 0:
        return float("nan"), float("nan")
    mean = float(np.nanmean(values))
    se = float(np.nanstd(values) / np.sqrt(n_finite))
    return mean, se


def _ablation_table(df_c):
    """Summarize per-run squared errors into one row per (regime, estimator).

    Args:
        df_c: frame holding a ``regime`` column and the ``mse_static``,
            ``mse_dynamic`` and ``mse_hybrid`` columns.

    Returns:
        DataFrame with columns ``regime``, ``estimator``, ``mse_mean`` and
        ``mse_se``. The columns are present even when *df_c* has no rows.
    """
    estimators = (
        ("static-only", "mse_static"),
        ("dynamic-only", "mse_dynamic"),
        ("hybrid", "mse_hybrid"),
    )
    rows = []
    # pandas groupby drops rows whose key is NaN by default, so runs without
    # an assigned regime do not contribute to the table.
    for regime, grp in df_c.groupby("regime"):
        for label, col in estimators:
            mean, se = _mean_and_se(grp[col].to_numpy(dtype=float, na_value=np.nan))
            rows.append({"regime": regime, "estimator": label, "mse_mean": mean, "mse_se": se})
    return pd.DataFrame(rows, columns=["regime", "estimator", "mse_mean", "mse_se"])


def main() -> None:
    """Build the real-data figures and write a reproducibility manifest.

    Parses the command line, loads the runs table from ``--data_dir``,
    estimates per-dataset risk curves, fits a power law per dataset, assigns
    regimes, renders every figure into ``--out_dir`` and finally writes
    ``manifest_real.json`` next to them.

    Raises:
        ValueError: if the runs table lacks the dataset or ``c`` column, or
            if the ablation can find neither an (outcome, prediction) pair
            nor a squared-error column.
    """
    import warnings  # local import: only used for the empty-ablation notice

    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, default="data")
    ap.add_argument("--out_dir", type=str, default="figures")
    ap.add_argument("--config_regimes", type=str, default="configs/regime_thresholds.yaml")
    ap.add_argument("--probe_c_for_ablation", type=int, default=200)
    ap.add_argument("--hard_quantile", type=float, default=0.95, help="top quantile of L_int to mark as hard cases")
    ap.add_argument("--delta_norm", type=str, default="max", choices=["max","total"])
    args = ap.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cfg = yaml.safe_load(Path(args.config_regimes).read_text(encoding="utf-8"))
    data = load_real(args.data_dir)

    # Work on a copy so the loaded table is never mutated.
    runs = data.runs.copy()

    # ---------------------------
    # (7.2) Risk decomposition curves: L_hat(c), L_int, V_opt(c)
    # ---------------------------
    required = {REAL_RUNS.dataset, REAL_RUNS.c}
    missing = required - set(runs.columns)
    if missing:
        raise ValueError(f"{REAL_RUNS.dataset}/{REAL_RUNS.c} missing in metadataset_Risk.csv: {missing}")

    curve = estimate_L_from_runs(
        runs,
        group_cols=[REAL_RUNS.dataset],
        c_col=REAL_RUNS.c,
        squared_error_col=REAL_RUNS.squared_error,
        outcome_col=REAL_RUNS.outcome,
        enforce_monotone=True,
    )
    curve = add_intrinsic_and_vopt(curve, group_cols=[REAL_RUNS.dataset], c_col=REAL_RUNS.c)

    # ---------------------------
    # (7.2) Power-law fit per dataset to obtain alpha_hat
    # ---------------------------
    fits = fit_power_law_by_group(curve, group_cols=[REAL_RUNS.dataset], c_col=REAL_RUNS.c, v_col="V_opt")
    # One L_int value per dataset (the first row of each group) joined onto the fits.
    points = fits.merge(
        curve.groupby(REAL_RUNS.dataset, as_index=False)["L_int"].first(),
        on=REAL_RUNS.dataset,
        how="left",
    )

    # Regime assignment (post hoc; visualization only).
    # assign_regime is handed a renamed copy keyed by "dataset_name"; `points`
    # itself keeps its original column names, so nothing has to be renamed back.
    regime, thr = assign_regime(points.rename(columns={REAL_RUNS.dataset: "dataset_name"}), cfg)
    points["regime"] = regime

    # attach regime to curve
    curve = curve.merge(points[[REAL_RUNS.dataset, "regime"]], on=REAL_RUNS.dataset, how="left")

    # ---------------------------
    # Fig 7.1: population decay curves by regime (all tasks, not cherry-picked)
    # ---------------------------
    plots.fig_population_decay_by_regime(
        curve=curve, out_path=out_dir/"Fig_7_1_population_decay_full.pdf", regime_col="regime", c_col=REAL_RUNS.c
    )

    # ---------------------------
    # Fig 7.3: phase diagram (L_int, alpha_hat)
    # ---------------------------
    plots.fig_phase_diagram(
        points=points.copy(),  # hand the plot its own copy of the table
        out_path=out_dir/"Fig_7_3_phase_diagram.pdf",
        x="L_int",
        y="alpha_hat",
        hue="regime",
    )

    # ---------------------------
    # Fig 7.4: normalized marginal gain (efficiency frontier proxy)
    # ---------------------------
    d = compute_delta(curve, group_cols=[REAL_RUNS.dataset], c_col=REAL_RUNS.c, v_col="V_opt")
    d = normalize_delta(d, group_cols=[REAL_RUNS.dataset], method=args.delta_norm)
    # NOTE(review): assumes the helpers above add "delta_norm"/"_c_next" and
    # carry the "regime" column of `curve` through -- confirm against
    # ftpredict.marginal_gain.
    d = d.dropna(subset=["delta_norm", "_c_next"])
    marg_reg = d.groupby(["regime", REAL_RUNS.c], as_index=False)["delta_norm"].mean()
    plots.fig_efficiency_frontier(marg_reg, out_dir/"Fig_7_4_efficiency_frontier.pdf", regime_col="regime", c_col=REAL_RUNS.c)

    # ---------------------------
    # Appendix figures
    # ---------------------------
    plots.fig_plateau_by_regime(curve, out_dir/"Fig_7_2_L_plateau_by_regime.pdf", regime_col="regime", c_col=REAL_RUNS.c)
    plots.fig_alpha_distribution(points.copy(), out_dir/"Fig_7_2b_alpha_distribution.pdf")

    # Ablation A1:
    # If explicit estimator outputs exist, replace these proxies accordingly.
    df_c = runs[runs[REAL_RUNS.c] == args.probe_c_for_ablation].copy()
    if df_c.empty:
        warnings.warn(
            f"No runs found at {REAL_RUNS.c}={args.probe_c_for_ablation}; "
            "the ablation A1 table will be empty.",
            stacklevel=2,
        )

    # Hybrid error: prefer recomputing from (outcome, prediction); otherwise
    # fall back to the stored squared error.
    if {REAL_RUNS.outcome, REAL_RUNS.pred}.issubset(df_c.columns):
        df_c["mse_hybrid"] = (df_c[REAL_RUNS.outcome] - df_c[REAL_RUNS.pred])**2
    elif REAL_RUNS.squared_error in df_c.columns:
        df_c["mse_hybrid"] = df_c[REAL_RUNS.squared_error]
    else:
        raise ValueError("Need (R_true,R_pred) or squared_error for ablation.")

    # Static-only proxy: per-dataset mean predictor
    if REAL_RUNS.outcome in df_c.columns:
        static_pred = df_c.groupby(REAL_RUNS.dataset)[REAL_RUNS.outcome].transform("mean")
        df_c["mse_static"] = (df_c[REAL_RUNS.outcome] - static_pred)**2
    else:
        df_c["mse_static"] = df_c["mse_hybrid"]

    # Dynamic-only proxy: per-(dataset,seed) mean at this c if seed exists
    if REAL_RUNS.seed in df_c.columns and REAL_RUNS.outcome in df_c.columns:
        dyn_pred = df_c.groupby([REAL_RUNS.dataset, REAL_RUNS.seed])[REAL_RUNS.outcome].transform("mean")
        df_c["mse_dynamic"] = (df_c[REAL_RUNS.outcome] - dyn_pred)**2
    else:
        df_c["mse_dynamic"] = df_c["mse_hybrid"]

    df_c = df_c.merge(points[[REAL_RUNS.dataset, "regime"]], on=REAL_RUNS.dataset, how="left")
    ab = _ablation_table(df_c)
    plots.fig_ablation_a1(ab, out_dir/"Fig_7_4_ablation_A1_static_dynamic_hybrid.pdf")

    # Hard cases on phase diagram (top quantile by L_int)
    thr_hard = float(points["L_int"].quantile(args.hard_quantile))
    hard = points[points["L_int"] >= thr_hard].copy()
    plots.fig_hard_cases_on_phase(hard, points, out_dir/"Fig_7_8_hard_cases_on_phase.pdf")

    # ---------------------------
    # Write a reproducibility manifest
    # ---------------------------
    manifest = {
        "generated_at": str(pd.Timestamp.now(tz="UTC")),
        "inputs": {
            "metadataset_Risk.csv": "metadataset_Risk.csv",
            "risk_curve_by_dataset.csv": "risk_curve_by_dataset.csv",
        },
        # Record where the inputs were actually read from for this run.
        "data_dir": str(args.data_dir),
        "config_regimes": str(args.config_regimes),
        "regime_thresholds": thr,
        "delta_norm": args.delta_norm,
        "probe_c_for_ablation": args.probe_c_for_ablation,
        "hard_quantile": args.hard_quantile,
        "env": get_env_metadata(),
        "outputs": [
            "Fig_7_1_population_decay_full.pdf",
            "Fig_7_3_phase_diagram.pdf",
            "Fig_7_4_efficiency_frontier.pdf",
            "Fig_7_2_L_plateau_by_regime.pdf",
            "Fig_7_2b_alpha_distribution.pdf",
            "Fig_7_4_ablation_A1_static_dynamic_hybrid.pdf",
            "Fig_7_8_hard_cases_on_phase.pdf",
        ],
    }
    # _json_default lets NumPy scalars/arrays in the thresholds or environment
    # metadata serialize instead of raising TypeError.
    (out_dir/"manifest_real.json").write_text(
        json.dumps(manifest, indent=2, default=_json_default),
        encoding="utf-8",
    )

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()