from __future__ import annotations
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
import argparse
import json
from pathlib import Path

import numpy as np
import pandas as pd

from ftpredict.dataio import load_synthetic
from ftpredict.column_map import SYN_RUNS, SYN_CURVES
from ftpredict.estimation import estimate_L_from_runs, add_intrinsic_and_vopt
from ftpredict.powerlaw import fit_power_law_by_group
from ftpredict.repro import get_env_metadata
from ftpredict import plots

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, default="data")
    ap.add_argument("--out_dir", type=str, default="figures")
    args = ap.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    data = load_synthetic(args.data_dir)

    # Prefer released synthetic_risk_curves.csv if available (already aggregated)
    curves = data.risk_curves.copy()
    if {"task_id","probe_c","L_hat"}.issubset(curves.columns):
        curve = curves.rename(columns={"task_id":"task_id","probe_c":"probe_c","L_hat":"L_hat"})
        curve = curve.sort_values(["task_id","probe_c"])
        # enforce monotonicity per task
        # reuse add_intrinsic_and_vopt; requires L_hat in curve already
        curve = add_intrinsic_and_vopt(curve, group_cols=["task_id"], c_col="probe_c")
    else:
        # fall back to run_logs
        runs = data.run_logs.copy()
        curve = estimate_L_from_runs(
            runs,
            group_cols=[SYN_RUNS.task],
            c_col=SYN_RUNS.c,
            squared_error_col=SYN_RUNS.squared_error,
            outcome_col=SYN_RUNS.outcome,
            enforce_monotone=True,
        )
        curve = add_intrinsic_and_vopt(curve, group_cols=[SYN_RUNS.task], c_col=SYN_RUNS.c)

    fits = fit_power_law_by_group(curve, group_cols=["task_id"], c_col="probe_c", v_col="V_opt")
    points = fits.merge(curve.groupby("task_id", as_index=False)["L_int"].first(), on="task_id", how="left")

    # Synthetic: Figure S1 (decay curves by true regime if present)
    if "regime_true" in curves.columns:
        curve = curve.merge(curves[["task_id","regime_true"]].drop_duplicates(), on="task_id", how="left")
        plots.fig_population_decay_by_regime(curve.rename(columns={"regime_true":"regime"}), out_dir/"Fig_S1_synth_decay_by_regime.pdf", regime_col="regime", c_col="probe_c")

    # Synthetic: Figure S4 (phase diagram, colored by true regime if present)
    if "regime_true" in curves.columns:
        pts = points.merge(curves[["task_id","regime_true"]].drop_duplicates(), on="task_id", how="left").rename(columns={"regime_true":"regime"})
        plots.fig_phase_diagram(pts.rename(columns={"alpha_hat":"alpha_hat","L_int":"L_int"}), out_dir/"Fig_S4_synth_phase_diagram.pdf", x="L_int", y="alpha_hat", hue="regime")
    else:
        plots.fig_phase_diagram(points, out_dir/"Fig_S4_synth_phase_diagram.pdf", x="L_int", y="alpha_hat", hue=None)

    manifest = {
        "generated_at": str(pd.Timestamp.utcnow()),
        "inputs": {
            "synthetic_task_specs.csv": "synthetic_task_specs.csv",
            "synthetic_run_logs.csv": "synthetic_run_logs.csv",
            "synthetic_risk_curves.csv": "synthetic_risk_curves.csv",
            "synthetic_schema.md": "synthetic_schema.md",
        },
        "env": get_env_metadata(),
        "outputs": [
            "Fig_S1_synth_decay_by_regime.pdf",
            "Fig_S4_synth_phase_diagram.pdf",
        ],
    }
    (out_dir/"manifest_synth.json").write_text(json.dumps(manifest, indent=2))

if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())