"""End-to-end estimators used in Section 7 / Appendix F.

This module implements:
  - Risk curve estimation \hat{L}(c)
  - Intrinsic floor \hat{L}_int := \hat{L}(c_max)
  - Reducible optimization variance \hat{V}_opt(c) := \hat{L}(c) - \hat{L}_int
  - Optional isotonic regression to enforce monotone decrease in c
  - Power-law fit to obtain \hat{\alpha}

The estimators operate on either:
  (A) run-level logs (seed-level rows) via squared_error, or
  (B) outcome variance across seeds when squared_error is unavailable.
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from sklearn.isotonic import IsotonicRegression
from typing import Optional, Tuple

from .powerlaw import fit_power_law_by_group

def isotonic_monotone_decreasing(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return the non-increasing isotonic fit of ``y`` as a function of ``x``.

    Wraps scikit-learn's ``IsotonicRegression`` with ``increasing=False`` and
    returns the fitted values at the training points (same length as ``y``).
    The fitted model itself is discarded.
    """
    model = IsotonicRegression(out_of_bounds="clip", increasing=False)
    fitted = model.fit_transform(x, y)
    return fitted

def _sample_variance(values: pd.Series) -> float:
    """Return the ddof=1 sample variance of the non-NaN entries of ``values``.

    Returns 0.0 when fewer than two non-NaN observations exist. This keeps the
    single-sample convention of ``estimate_L_from_runs``, and it avoids the NaN
    (plus RuntimeWarning) that ``np.nanvar(..., ddof=1)`` yields in that case.
    """
    arr = values.to_numpy(dtype=float, na_value=np.nan)
    arr = arr[~np.isnan(arr)]
    if arr.size < 2:
        return 0.0
    return float(np.var(arr, ddof=1))

def estimate_L_from_runs(
    runs: pd.DataFrame,
    group_cols: list[str],
    c_col: str = "probe_c",
    squared_error_col: str = "squared_error",
    outcome_col: str = "R_true",
    enforce_monotone: bool = True,
) -> pd.DataFrame:
    r"""Estimate \hat{L}(c) per group.

    If ``squared_error_col`` exists in ``runs``, \hat{L}(c) is the mean of that
    column over the rows sharing (group, c). Otherwise it is the sample variance
    (ddof=1) of ``outcome_col`` over those rows (e.g. across seeds). Cells with
    fewer than two non-NaN outcomes get a variance of 0.0.

    Args:
        runs: Run-level rows. Not modified.
        group_cols: Columns identifying one curve. An empty list treats all
            rows as a single curve.
        c_col: Column holding c.
        squared_error_col: Per-run squared-error column (used if present).
        outcome_col: Per-run outcome column (fallback when no squared error).
        enforce_monotone: If True, replace ``L_hat`` within each group by its
            non-increasing isotonic fit over ``c_col``. Rows with NaN ``L_hat``
            or NaN c are excluded from the fit and keep their original value.

    Returns:
        One row per (group, c), sorted by ``group_cols + [c_col]``, with the
        ``L_hat`` and ``n`` (non-NaN count) columns. The variance path also
        carries ``R_var`` (the unsmoothed variance) and ``R_mean``.

    Raises:
        ValueError: If neither ``squared_error_col`` nor ``outcome_col`` is a
            column of ``runs``.
    """
    keys = [*group_cols, c_col]
    if squared_error_col in runs.columns:
        agg = (
            runs.groupby(keys, dropna=False)[squared_error_col]
            .agg(L_hat="mean", n="count")
            .reset_index()
            .sort_values(keys)
        )
    else:
        if outcome_col not in runs.columns:
            raise ValueError(f"Need either '{squared_error_col}' or '{outcome_col}' in runs.")
        agg = (
            runs.groupby(keys, dropna=False)[outcome_col]
            .agg(R_var=_sample_variance, n="count", R_mean="mean")
            .reset_index()
            .sort_values(keys)
        )
        agg["L_hat"] = agg["R_var"]

    # Nothing to smooth on empty input (and pd.concat([]) would raise).
    if enforce_monotone and not agg.empty:
        # groupby([]) raises, so an empty group_cols means "one curve".
        groups = agg.groupby(list(group_cols), dropna=False) if group_cols else [(None, agg)]
        pieces = []
        for _, g in groups:
            g = g.copy()
            y = g["L_hat"].to_numpy(dtype=float, copy=True)
            # IsotonicRegression rejects NaN, so fit only on usable points.
            usable = g["L_hat"].notna().to_numpy() & g[c_col].notna().to_numpy()
            if usable.any():
                x = g[c_col].to_numpy()
                y[usable] = isotonic_monotone_decreasing(x[usable], y[usable])
            g["L_hat"] = y
            pieces.append(g)
        agg = pd.concat(pieces, ignore_index=True).sort_values(keys)
    return agg

def add_intrinsic_and_vopt(curve: pd.DataFrame, group_cols: list[str], c_col: str="probe_c") -> pd.DataFrame:
    r"""Attach c_max, the intrinsic floor \hat{L}_int and \hat{V}_opt to ``curve``.

    For each group (``group_cols``, NaN keys kept as their own group):
      - ``c_max``: the largest ``c_col`` value in the group.
      - ``L_int``: ``L_hat`` at ``c_max``. If several rows share ``c_max``,
        their ``L_hat`` values are averaged so each group has one floor.
      - ``V_opt``: ``L_hat - L_int``, clipped below at 0.0.

    Args:
        curve: Per-(group, c) rows with at least ``group_cols``, ``c_col`` and
            ``L_hat``. Not modified. Any existing ``c_max``/``L_int``/``V_opt``
            columns are replaced, so the function can be re-applied safely.
        group_cols: Columns identifying one curve.
        c_col: Column holding c.

    Returns:
        A new DataFrame with the original rows in their original order and a
        fresh RangeIndex, plus ``c_max``, ``L_int`` and ``V_opt`` columns.
    """
    keys = list(group_cols)
    # Re-applying to a previous output would otherwise create *_x/*_y columns
    # on merge and then fail looking up "c_max".
    curve = curve.drop(columns=["c_max", "L_int", "V_opt"], errors="ignore")
    cmax = (
        curve.groupby(keys, dropna=False)[c_col]
        .max()
        .reset_index()
        .rename(columns={c_col: "c_max"})
    )
    curve = curve.merge(cmax, on=keys, how="left")
    # Reduce to one L_int per group; merging duplicate c_max rows directly
    # would multiply rows (many-to-many merge).
    lint = (
        curve.loc[curve[c_col] == curve["c_max"]]
        .groupby(keys, dropna=False)["L_hat"]
        .mean()
        .reset_index()
        .rename(columns={"L_hat": "L_int"})
    )
    curve = curve.merge(lint, on=keys, how="left")
    curve["V_opt"] = (curve["L_hat"] - curve["L_int"]).clip(lower=0.0)
    return curve

def estimate_alpha(points_curve: pd.DataFrame, group_cols: list[str]) -> pd.DataFrame:
    r"""Fit a power law per group to obtain \hat{\alpha}.

    Delegates entirely to ``fit_power_law_by_group`` (imported from the sibling
    ``powerlaw`` module). See that function for the required input columns and
    the exact output schema; presumably one fitted row per group — confirm.
    """
    fitted = fit_power_law_by_group(points_curve, group_cols=group_cols)
    return fitted
