"""Estimation routines used throughout Section 7.

We operate in terms of squared error (risk) curves:
  L_hat(c) = mean over runs/seeds of squared_error at probe_c = c

and define:
  L_int := L_hat(c_max)
  V_opt(c) := max(L_hat(c) - L_int, 0)

Optionally apply isotonic regression so that L_hat(c) is non-increasing in c.
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from sklearn.isotonic import IsotonicRegression

def isotonic_monotone_decreasing(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return the isotonic fit of ``y`` on ``x`` constrained to be non-increasing.

    The regression is fitted on ``(x, y)`` and then evaluated at the same
    ``x`` values, so the result has one fitted value per input point
    (larger c -> smaller or equal risk).
    """
    model = IsotonicRegression(increasing=False, out_of_bounds="clip")
    model.fit(x, y)
    return model.transform(x)

def compute_risk_curve_from_runs(
    df: pd.DataFrame,
    group_cols: list[str],
    c_col: str = "probe_c",
    se_col: str = "squared_error",
    enforce_monotone: bool = True,
) -> pd.DataFrame:
    """Aggregate run-level logs into risk curves.

    For every combination of ``group_cols`` and ``c_col`` the mean of
    ``se_col`` (``L_hat``) and the number of non-null values (``n``) are
    computed. Rows whose key columns are NaN are kept as their own groups.

    Args:
        df: Run-level log with the grouping columns, ``c_col`` and ``se_col``.
        group_cols: Columns identifying one curve.
        c_col: Column holding the probe value c.
        se_col: Column holding the per-run squared error.
        enforce_monotone: If true, replace ``L_hat`` within each group by its
            non-increasing isotonic fit over ``c_col``.

    Returns:
        A dataframe with columns ``group_cols + [c_col, "L_hat", "n"]``,
        sorted by ``group_cols + [c_col]``. An empty input yields an empty
        dataframe with those columns.
    """
    group_cols = list(group_cols)
    sort_keys = group_cols + [c_col]
    agg = (
        df.groupby(sort_keys, dropna=False)[se_col]
        .agg(L_hat="mean", n="count")
        .reset_index()
        .sort_values(sort_keys)
    )
    # Nothing to smooth; also avoids pd.concat([]) raising on empty input.
    if not enforce_monotone or agg.empty:
        return agg

    # Apply the isotonic fit independently to each curve.
    smoothed = []
    for _, curve in agg.groupby(group_cols, dropna=False):
        curve = curve.copy()
        curve["L_hat"] = isotonic_monotone_decreasing(
            curve[c_col].to_numpy(), curve["L_hat"].to_numpy()
        )
        smoothed.append(curve)
    return pd.concat(smoothed, ignore_index=True).sort_values(sort_keys)

def add_intrinsic_and_vopt(curve: pd.DataFrame, group_cols: list[str], c_col: str="probe_c") -> pd.DataFrame:
    """Add ``c_max``, ``L_int`` and ``V_opt`` columns to a risk curve.

    Per group (NaN group keys form their own group):
      c_max    := largest value of ``c_col``
      L_int    := ``L_hat`` at ``c_col == c_max`` (mean if several rows tie)
      V_opt(c) := max(L_hat(c) - L_int, 0)

    Args:
        curve: Dataframe with ``group_cols``, ``c_col`` and ``L_hat`` columns.
        group_cols: Columns identifying one curve.
        c_col: Column holding the probe value c.

    Returns:
        A new dataframe with one row per input row, in input order, and the
        three derived columns appended. The input frame is not modified.
    """
    group_cols = list(group_cols)
    # Drop stale derived columns so repeated calls do not produce
    # "_x"/"_y" suffixed duplicates in the merges below.
    curve = curve.drop(columns=["c_max", "L_int", "V_opt"], errors="ignore")

    # c_max per group; dropna=False keeps groups whose key is NaN.
    cmax = (
        curve.groupby(group_cols, dropna=False)[c_col]
        .max()
        .reset_index()
        .rename(columns={c_col: "c_max"})
    )
    curve = curve.merge(cmax, on=group_cols, how="left")

    # Collapse to exactly one L_int per group so the merge below is
    # many-to-one and cannot duplicate rows.
    at_cmax = curve[c_col] == curve["c_max"]
    lint = (
        curve.loc[at_cmax]
        .groupby(group_cols, dropna=False)["L_hat"]
        .mean()
        .reset_index()
        .rename(columns={"L_hat": "L_int"})
    )
    curve = curve.merge(lint, on=group_cols, how="left")

    curve["V_opt"] = (curve["L_hat"] - curve["L_int"]).clip(lower=0.0)
    return curve
