"""Bootstrap utilities for uncertainty bands.

We bootstrap at the *unit of independence* (e.g., seeds within a task/dataset).
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from typing import Callable, Tuple

def bootstrap_ci(
    values: np.ndarray,
    stat_fn: Callable[[np.ndarray], float] = np.mean,
    n_boot: int = 1000,
    ci: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float]:
    """Return a percentile bootstrap confidence interval for ``stat_fn``.

    Non-finite entries (NaN, +/-inf) are discarded. The remaining values are
    resampled with replacement ``n_boot`` times, ``stat_fn`` is applied to
    each resample, and the ``(1 - ci) / 2`` and ``1 - (1 - ci) / 2``
    quantiles of those bootstrap statistics are returned.

    Args:
        values: Array-like of numbers (ndarray, list, Series, ...). Input of
            any shape is flattened by the finite-value mask.
        stat_fn: Statistic evaluated on each resample (default: mean).
        n_boot: Number of bootstrap resamples.
        ci: Central coverage of the interval, e.g. 0.95.
        seed: Seed for ``np.random.default_rng``; the result is deterministic.

    Returns:
        ``(lo, hi)`` as Python floats, or ``(nan, nan)`` if no finite values
        remain.
    """
    rng = np.random.default_rng(seed)
    # Convert first: boolean-mask indexing below needs an ndarray, and
    # dtype=float turns None entries into NaN so they get filtered too.
    arr = np.asarray(values, dtype=float)
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        return (float("nan"), float("nan"))
    # One rng.choice call per resample keeps the RNG stream identical to a
    # plain loop, so results are reproducible for a given seed.
    boots = [
        stat_fn(rng.choice(arr, size=arr.size, replace=True))
        for _ in range(n_boot)
    ]
    lo = np.quantile(boots, (1 - ci) / 2)
    hi = np.quantile(boots, 1 - (1 - ci) / 2)
    return float(lo), float(hi)

def bootstrap_curve_ci(
    df: pd.DataFrame,
    group_cols: list[str],
    x_col: str,
    y_col: str,
    unit_col: str,
    stat: str = "mean",
    n_boot: int = 500,
    ci: float = 0.95,
    seed: int = 0,
) -> pd.DataFrame:
    """Compute bootstrap CI for y at each (group, x), resampling over unit_col.

    Within every ``(*group_cols, x_col)`` cell, ``y_col`` is first averaged
    per unit (rows sharing a ``unit_col`` value), so each unit contributes
    exactly one value and units are the resampled observations. Those
    per-unit values are resampled with replacement ``n_boot`` times, and the
    percentile interval of ``stat`` over the resamples is reported.

    Args:
        df: Long-format data containing all referenced columns.
        group_cols: Columns identifying a curve (may be empty).
        x_col: Column giving the position along the curve.
        y_col: Numeric column to summarise.
        unit_col: Column identifying the resampling unit. Rows with a missing
            unit id are ignored.
        stat: ``"mean"`` or ``"median"``, the statistic across units.
        n_boot: Number of bootstrap resamples per cell.
        ci: Central coverage of the percentile interval, e.g. 0.95.
        seed: Seed for the single RNG shared (sequentially) by all cells.

    Returns:
        One row per cell that has at least one unit with a finite per-unit
        mean. Each row holds the ``group_cols`` and ``x_col`` values plus:
        ``y`` (the statistic on the observed units), ``y_lo`` / ``y_hi``
        (interval bounds) and ``n_units`` (units used). The columns are
        present even when no cell qualifies.

    Raises:
        ValueError: If ``stat`` is not ``"mean"`` or ``"median"``.
    """
    stat_fns = {"mean": np.mean, "median": np.median}
    if stat not in stat_fns:
        raise ValueError(f"stat must be 'mean' or 'median', got {stat!r}")
    stat_fn = stat_fns[stat]

    key_cols = group_cols + [x_col]
    out_cols = key_cols + ["y", "y_lo", "y_hi", "n_units"]
    rng = np.random.default_rng(seed)
    lo_q, hi_q = (1 - ci) / 2, 1 - (1 - ci) / 2
    out_rows = []
    # observed=True skips categorical combinations with no rows (those were
    # empty groups that got skipped anyway).
    for key, g in df.groupby(key_cols, dropna=False, observed=True):
        # Collapse to one value per unit so the resampled values are iid.
        # Missing unit ids are dropped (groupby's default dropna=True) rather
        # than pooled into a single fake unit. observed=True keeps categorical
        # unit levels absent from this cell from appearing as NaN entries.
        per_unit = (
            g.groupby(unit_col, observed=True)[y_col]
            .mean()
            .to_numpy(dtype=float)
        )
        # A unit whose y is all-NaN (or inf) here would make every resample
        # non-finite; drop such units, mirroring bootstrap_ci's input filter.
        per_unit = per_unit[np.isfinite(per_unit)]
        if per_unit.size == 0:
            continue
        # One rng.choice call per resample, in the same order as a plain loop.
        boots = [
            stat_fn(rng.choice(per_unit, size=per_unit.size, replace=True))
            for _ in range(n_boot)
        ]
        if not isinstance(key, tuple):
            # Older pandas yields a scalar key for a single grouping column.
            key = (key,)
        row = dict(zip(key_cols, key))
        row.update({
            "y": float(stat_fn(per_unit)),
            "y_lo": float(np.quantile(boots, lo_q)),
            "y_hi": float(np.quantile(boots, hi_q)),
            "n_units": int(per_unit.size),
        })
        out_rows.append(row)
    return pd.DataFrame(out_rows, columns=out_cols)
