"""Plotting utilities (Matplotlib, PDF outputs)."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def save_pdf(fig, out_path):
    """Save ``fig`` to ``out_path`` with a tight bounding box, then close it.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to write.
    out_path : str or os.PathLike
        Destination file. Missing parent directories are created. Plain
        strings are accepted as well as ``pathlib.Path`` objects.
    """
    try:
        out_path = Path(out_path)  # accept str as well as Path
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight")
    finally:
        # Release the figure from pyplot's registry even if saving fails,
        # so repeated calls do not accumulate open figures.
        plt.close(fig)

def fig_population_decay_by_regime(curve: pd.DataFrame, out_path, regime_col: str="regime", c_col: str="probe_c"):
    """Fig 7.1 style: population mean V_opt(c) (log-log) by regime.

    For every regime, ``V_opt`` is averaged over all rows sharing the same
    ``c_col`` value. A shaded band of +/- one standard error is drawn around
    the mean in the same color as the mean line.

    Parameters
    ----------
    curve : pd.DataFrame
        Must contain the columns ``regime_col``, ``c_col`` and ``"V_opt"``.
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    regime_col : str
        Column whose values define one line per group.
    c_col : str
        Column used as the (log-scaled) x axis.

    Notes
    -----
    At a ``c`` with a single sample, the pandas sample std is NaN, so no band
    is drawn there.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))

    for reg, g in curve.groupby(regime_col):
        # mean over datasets at each c
        stats = g.groupby(c_col)["V_opt"].agg(["mean", "std", "count"]).reset_index()
        # max(count, 1) avoids division by zero when every V_opt at a c is NaN
        se = stats["std"] / np.sqrt(np.maximum(stats["count"], 1))
        (line,) = ax.plot(stats[c_col], stats["mean"], label=str(reg))
        # Clip the lower edge to a tiny positive value so the band survives the
        # log y-axis; bind the band color to its line explicitly.
        ax.fill_between(
            stats[c_col],
            np.maximum(stats["mean"] - se, 1e-12),
            stats["mean"] + se,
            color=line.get_color(),
            alpha=0.2,
        )

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Probe compute $c$ (steps)")
    ax.set_ylabel(r"$\widehat{\mathcal{V}}_{\mathrm{opt}}(c)$")
    ax.legend(frameon=False)
    save_pdf(fig, out_path)

def fig_phase_diagram(points: pd.DataFrame, out_path, x="L_int", y="alpha_hat", hue="regime",
                      xlabel: str | None = None, ylabel: str | None = None):
    """Scatter ``points[y]`` against ``points[x]``, one color per ``hue`` group.

    Parameters
    ----------
    points : pd.DataFrame
        Must contain the columns ``x``, ``y`` and ``hue``.
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    x, y : str
        Columns plotted on the horizontal and vertical axes.
    hue : str
        Column whose values define the scatter groups (and legend entries).
    xlabel, ylabel : str, optional
        Axis labels. By default, the known columns ``"L_int"`` and
        ``"alpha_hat"`` get their LaTeX symbols. Any other column is
        labelled with its name, so overriding ``x``/``y`` no longer
        mislabels the axes.
    """
    default_labels = {
        "L_int": r"$\widehat{\mathcal{L}}_{\mathrm{int}}$",
        "alpha_hat": r"$\widehat{\alpha}$",
    }
    fig, ax = plt.subplots(figsize=(6.2, 3.8))
    for reg, g in points.groupby(hue):
        ax.scatter(g[x], g[y], s=18, alpha=0.8, label=str(reg))
    ax.set_xlabel(xlabel if xlabel is not None else default_labels.get(x, str(x)))
    ax.set_ylabel(ylabel if ylabel is not None else default_labels.get(y, str(y)))
    ax.legend(frameon=False, loc="best")
    save_pdf(fig, out_path)

def fig_efficiency_frontier(marginal: pd.DataFrame, out_path, regime_col="regime", c_col="probe_c"):
    """Fig 7.4 style: normalized marginal gain Delta(c) by regime.

    Parameters
    ----------
    marginal : pd.DataFrame
        Must contain the columns ``regime_col``, ``c_col`` and ``"delta_norm"``.
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    regime_col : str
        Column whose values define one line per group.
    c_col : str
        Column used as the (log-scaled) x axis.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    for reg, g in marginal.groupby(regime_col):
        # Sort by c so the line is drawn left-to-right even if rows are unordered.
        g = g.sort_values(c_col)
        ax.plot(g[c_col], g["delta_norm"], label=str(reg))
    ax.set_xscale("log")
    ax.set_xlabel("Probe compute $c$ (steps)")
    # Raw string: "\D" is an invalid escape in a normal string literal.
    ax.set_ylabel(r"Normalized marginal gain $\Delta(c)$")
    ax.legend(frameon=False)
    save_pdf(fig, out_path)

def fig_plateau_by_regime(curve: pd.DataFrame, out_path, regime_col="regime", c_col="probe_c"):
    """Plot the per-regime mean of ``L_hat`` against ``c`` with a +/- SE band.

    For every regime, ``L_hat`` is averaged over all rows sharing the same
    ``c_col`` value. A shaded band of +/- one standard error is drawn in the
    same color as the mean line. The x axis is log-scaled and the y axis is
    linear.

    Parameters
    ----------
    curve : pd.DataFrame
        Must contain the columns ``regime_col``, ``c_col`` and ``"L_hat"``.
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    regime_col : str
        Column whose values define one line per group.
    c_col : str
        Column used as the x axis.

    Notes
    -----
    At a ``c`` with a single sample, the pandas sample std is NaN, so no band
    is drawn there.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    for reg, g in curve.groupby(regime_col):
        stats = g.groupby(c_col)["L_hat"].agg(["mean", "std", "count"]).reset_index()
        # max(count, 1) avoids division by zero when every L_hat at a c is NaN
        se = stats["std"] / np.sqrt(np.maximum(stats["count"], 1))
        (line,) = ax.plot(stats[c_col], stats["mean"], label=str(reg))
        # Bind the band color to its line explicitly.
        ax.fill_between(
            stats[c_col],
            stats["mean"] - se,
            stats["mean"] + se,
            color=line.get_color(),
            alpha=0.2,
        )
    ax.set_xscale("log")
    ax.set_xlabel("Probe compute $c$ (steps)")
    ax.set_ylabel(r"$\widehat{\mathcal{L}}(c)$")
    ax.legend(frameon=False)
    save_pdf(fig, out_path)

def fig_alpha_distribution(points: pd.DataFrame, out_path, alpha_col="alpha_hat", regime_col="regime", bins=20):
    """Overlay per-regime histograms of the estimated decay exponent.

    All regimes share one set of bin edges, computed from the pooled finite
    values, so the overlaid bars are directly comparable.

    Parameters
    ----------
    points : pd.DataFrame
        Must contain the columns ``alpha_col`` and ``regime_col``. Rows with
        a missing regime are ignored.
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    alpha_col : str
        Column holding the values to histogram.
    regime_col : str
        Column whose values define one histogram per group. Groups are drawn
        in order of first appearance.
    bins : int, sequence or str
        Forwarded to ``np.histogram_bin_edges`` for the shared edges
        (default 20, matching the previous per-regime setting).
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.2))
    labelled = points[points[regime_col].notna()]

    pooled = labelled[alpha_col].to_numpy(dtype=float)
    # NaN/inf cannot be binned; drop them before computing shared edges.
    edges = np.histogram_bin_edges(pooled[np.isfinite(pooled)], bins=bins)

    for reg in labelled[regime_col].unique():
        vals = labelled.loc[labelled[regime_col] == reg, alpha_col].to_numpy(dtype=float)
        ax.hist(vals[np.isfinite(vals)], bins=edges, alpha=0.45, label=str(reg))
    ax.set_xlabel(r"Estimated decay exponent $\widehat{\alpha}$")
    ax.set_ylabel("Count")
    ax.legend(frameon=False)
    save_pdf(fig, out_path)

def fig_ablation_a1(ab: pd.DataFrame, out_path, regime_col="regime"):
    """Grouped bar chart of cross-validated MSE per regime and estimator.

    Expected columns: regime, estimator, mse_mean, mse_se. There should be at
    most one row per (regime, estimator) pair.

    Parameters
    ----------
    ab : pd.DataFrame
        Ablation results. The regime column name is ``regime_col``; the other
        columns are ``"estimator"``, ``"mse_mean"`` and ``"mse_se"``
        (the latter used as the error bar).
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    regime_col : str
        Column whose values form the x-axis groups (order of first appearance).

    Notes
    -----
    A (regime, estimator) combination absent from ``ab`` yields a NaN bar,
    which is not drawn, instead of raising ``KeyError``.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.2))
    regs = list(ab[regime_col].unique())
    ests = list(ab["estimator"].unique())
    n_est = len(ests)
    x = np.arange(len(regs))
    # max(n_est, 1) avoids ZeroDivisionError when ``ab`` is empty.
    width = 0.25 if n_est == 3 else 0.8 / max(n_est, 1)
    for i, est in enumerate(ests):
        # reindex (not .loc) tolerates regimes this estimator has no row for.
        g = ab[ab["estimator"] == est].set_index(regime_col).reindex(regs)
        offset = (i - (n_est - 1) / 2) * width  # center the group on each tick
        ax.bar(x + offset, g["mse_mean"], width=width, yerr=g["mse_se"], capsize=2, label=str(est))
    ax.set_xticks(x)
    ax.set_xticklabels([str(r) for r in regs])
    ax.set_ylabel("CV MSE")
    ax.legend(frameon=False)
    save_pdf(fig, out_path)

def fig_hard_cases_on_phase(hard: pd.DataFrame, all_points: pd.DataFrame, out_path, x="L_int", y="alpha_hat",
                            xlabel: str | None = None, ylabel: str | None = None):
    """Highlight ``hard`` cases on top of the full ``all_points`` scatter.

    Parameters
    ----------
    hard : pd.DataFrame
        Subset to emphasize. Must contain columns ``x`` and ``y``.
    all_points : pd.DataFrame
        Background points, drawn faintly. Must contain columns ``x`` and ``y``.
    out_path : str or os.PathLike
        Destination passed to ``save_pdf``.
    x, y : str
        Columns plotted on the horizontal and vertical axes.
    xlabel, ylabel : str, optional
        Axis labels. By default, the known columns ``"L_int"`` and
        ``"alpha_hat"`` get their LaTeX symbols. Any other column is
        labelled with its name, so overriding ``x``/``y`` no longer
        mislabels the axes.
    """
    default_labels = {
        "L_int": r"$\widehat{\mathcal{L}}_{\mathrm{int}}$",
        "alpha_hat": r"$\widehat{\alpha}$",
    }
    fig, ax = plt.subplots(figsize=(6.2, 3.8))
    ax.scatter(all_points[x], all_points[y], s=12, alpha=0.25, label="all tasks")
    ax.scatter(hard[x], hard[y], s=28, alpha=0.9, label="hard cases")
    ax.set_xlabel(xlabel if xlabel is not None else default_labels.get(x, str(x)))
    ax.set_ylabel(ylabel if ylabel is not None else default_labels.get(y, str(y)))
    ax.legend(frameon=False)
    save_pdf(fig, out_path)
