"""Marginal gain / stopping statistics used for Fig. 7.4-style plots."""
from __future__ import annotations
import numpy as np
import pandas as pd

def compute_delta(
    curve: pd.DataFrame,
    group_cols: list[str],
    c_col: str = "probe_c",
    v_col: str = "V_opt",
) -> pd.DataFrame:
    """Compute discrete marginal gain: Δ(c) = V(c) - V(next c).

    Rows are sorted by ``group_cols`` then ``c_col``. Within each group,
    "next c" means the next row in that sorted order.

    Args:
        curve: Input table; it is not modified.
        group_cols: Columns identifying independent curves. May be empty,
            in which case the whole frame is treated as a single curve.
        c_col: Column holding the probe value ``c``.
        v_col: Column holding ``V`` at each ``c``.

    Returns:
        A sorted copy of *curve* with three added columns:
        ``_v_next`` / ``_c_next`` hold V and c of the next row in the group,
        and are NaN on each group's last row. ``delta`` is
        ``V - _v_next`` clipped at 0, so an increase in V counts as zero
        gain; it is NaN on each group's last row.
    """
    # Accept any sequence. groupby() would treat a tuple as ONE key, and
    # tuple + list concatenation fails, so normalise to a list first.
    keys = list(group_cols)
    # Stable sort keeps the input order of rows with tied (group, c) values.
    df = curve.sort_values([*keys, c_col], kind="mergesort").copy()
    if keys:
        grouped = df.groupby(keys, dropna=False)
        v_next = grouped[v_col].shift(-1)
        c_next = grouped[c_col].shift(-1)
    else:
        # groupby([]) raises "No group keys passed!". Shift over the whole frame.
        v_next = df[v_col].shift(-1)
        c_next = df[c_col].shift(-1)
    df["_v_next"] = v_next
    df["_c_next"] = c_next
    df["delta"] = (df[v_col] - df["_v_next"]).clip(lower=0.0)
    return df

def normalize_delta(
    df: pd.DataFrame,
    group_cols: list[str],
    method: str = "max",
    eps: float = 1e-12,
) -> pd.DataFrame:
    """Normalize delta per group for comparability across tasks.

    Adds a ``delta_norm`` column to a copy of *df*; *df* itself is not
    modified.

    method:
      - 'max': divide by max delta within group (largest gain maps to 1.0)
      - 'total': divide by sum of deltas (approx total reducible variance)

    NaN deltas are skipped when computing the denominator. A group gets
    ``delta_norm = NaN`` when its denominator is NaN or has magnitude
    ``<= eps`` (e.g. a flat curve whose gains are all zero). This avoids
    a blown-up ratio.

    ``group_cols`` may be empty, in which case the whole frame is one group.

    Raises:
        ValueError: if *method* is not 'max' or 'total'.
    """
    reducers = {"max": "max", "total": "sum"}
    if method not in reducers:
        raise ValueError(f"Unknown normalization method: {method}")
    reducer = reducers[method]
    out = df.copy()
    keys = list(group_cols)
    if keys:
        # pandas reducers skip NaN without the "All-NaN slice" warning that
        # np.nanmax emits for single-row groups (whose only delta is NaN).
        denom = out.groupby(keys, dropna=False)["delta"].transform(reducer)
    else:
        # groupby([]) raises, so broadcast the whole-frame reduction.
        denom = pd.Series(out["delta"].agg(reducer), index=out.index)
    # Mask degenerate denominators instead of adding eps. Adding eps after
    # zeros were already NaN guarded nothing and skewed every ratio: the
    # per-group max never reached exactly 1.0.
    denom = denom.where(denom.abs() > eps)
    out["delta_norm"] = out["delta"] / denom
    return out
