"""
Audit Protocol for Distributional Parity vs. Tail Disparity
===========================================================

A reproducible, assumption-light auditing toolkit that tests whether
low-order parity (means/variances/positive-rate) can coexist with
distribution- or tail-level disparities across a binary sensitive group.

Implements:
  - Matching/reweighting via propensity scores (with caliper trimming)
  - Balance checks on low-order statistics
  - Tail error gap with within-group bootstrap CIs
  - Weighted two-sample KS test with multiplier bootstrap p-value
  - Groupwise Expected Calibration Error (ECE) and ECE gap with bootstrap CIs
  - Alternative distances: MMD (Gaussian kernel, median heuristic) and Wasserstein-1
  - Multiple-testing control: Benjamini–Hochberg (BH) and Benjamini–Yekutieli (BY)
  - Robustness knobs: the tail quantile and balance tolerances are exposed as parameters and CLI flags

Dependencies:
  - numpy, pandas, scipy, scikit-learn
"""

from __future__ import annotations
import argparse
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, List

import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.stats import wasserstein_distance
from sklearn.linear_model import LogisticRegression, Ridge, LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler


# --------------------------- Utilities ---------------------------

def _weighted_mean(x: np.ndarray, w: np.ndarray) -> float:
    """Return the weighted average of ``x``, treating negative weights as zero.

    Both inputs are converted to float arrays. When the clipped weights do not
    sum to a positive number there is nothing to average and ``np.nan`` is
    returned instead.
    """
    weights = np.clip(np.asarray(w, float), 0.0, np.inf)
    values = np.asarray(x, float)
    total = weights.sum()
    if not total > 0:
        return np.nan
    return float(np.sum(weights * values) / total)


def _weighted_var(x: np.ndarray, w: np.ndarray) -> float:
    """Return the weighted population variance of ``x``.

    Inputs are converted to float arrays and negative weights are clipped to
    zero, so the mean, the squared deviations and the normaliser all use the
    same set of weights. No small-sample (Bessel-type) correction is applied.

    Returns ``np.nan`` when the clipped weights do not sum to a positive number.
    """
    values = np.asarray(x, float)
    weights = np.clip(np.asarray(w, float), 0.0, np.inf)
    total = weights.sum()
    if not total > 0:
        return np.nan
    # Weighted mean computed with the very same clipped weights used below.
    center = np.sum(weights * values) / total
    return float(np.sum(weights * (values - center) ** 2) / total)


def _check_binary(arr: np.ndarray, name: str):
    """Raise ``ValueError`` unless every non-NaN entry of ``arr`` is 0 or 1.

    ``name`` is only used to label the offending argument in the error message.
    """
    observed = np.unique(arr[~np.isnan(arr)])
    if set(observed) <= {0, 1}:
        return
    raise ValueError(f"{name} must be binary in {{0,1}}.")


def _safe_quantile(x: np.ndarray, w: np.ndarray, q: float) -> float:
    """Return the weighted ``q``-quantile of ``x`` by inverting the weighted CDF.

    The samples are sorted, the normalised cumulative weights form the CDF, and
    the result is the first sorted sample whose cumulative weight exceeds ``q``
    (the largest sample when none does).

    Raises:
        ValueError: if ``q`` is not strictly between 0 and 1.
    """
    if not (0.0 < q < 1.0):
        raise ValueError("q must be in (0,1).")

    ranking = np.argsort(x)
    sorted_x = x[ranking]
    sorted_w = w[ranking]
    cum_share = np.cumsum(sorted_w) / np.sum(sorted_w)

    # side="right" skips every entry whose cumulative share is <= q.
    pos = np.searchsorted(cum_share, q, side="right")
    last = len(sorted_x) - 1
    pos = min(max(pos, 0), last)
    return float(sorted_x[pos])


def _binarize_from_threshold(scores: np.ndarray, tau: float) -> np.ndarray:
    """Map scores to 0/1 decisions: 1 where ``scores >= tau``, otherwise 0."""
    at_or_above = scores >= tau
    return at_or_above.astype(int)


# -------------------- Matching / Reweighting ---------------------

@dataclass
class MatchConfig:
    """Settings passed to ``propensity_weights`` to control reweighting."""
    # Name of the matching method. Only "propensity" is listed as supported;
    # NOTE(review): nothing visible in this file reads this field.
    method: str = "propensity"  # "propensity" (supported)
    # Two-sided tail level: samples whose standardised propensity score exceeds
    # norm.ppf(1 - caliper / 2) in absolute value get their weight halved.
    caliper: float = 0.2
    # Propensity scores are clipped to [trim_ps, 1 - trim_ps] before inversion.
    trim_ps: float = 0.02
    # NOTE(review): not read by the visible code.
    max_iter: int = 1  # placeholder for future iterative refinements


def propensity_weights(
    X: np.ndarray, G: np.ndarray, cfg: MatchConfig
) -> np.ndarray:
    """
    Build per-sample balancing weights from an estimated propensity score.

    Steps: standardise ``X``, fit a logistic regression for p(G=1|X), clip the
    fitted scores to ``[cfg.trim_ps, 1 - cfg.trim_ps]``, form stabilised inverse
    probability weights, halve the weights of samples whose score lies far from
    the marginal rate of ``G`` (cut-off derived from ``cfg.caliper``), and
    finally rescale each group so that its weights sum to the group size (up to
    a 1e-12 guard in the denominator).

    With no covariates (``X.size == 0``) every sample receives weight 1.

    Raises ``ValueError`` (from ``_check_binary``) when ``G`` is not 0/1.
    """
    _check_binary(G, "G")

    # Nothing to condition on: leave the sample unweighted.
    if X.size == 0:
        return np.ones_like(G, dtype=float)

    features = StandardScaler().fit_transform(X)
    model = LogisticRegression(max_iter=200, solver="lbfgs")
    model.fit(features, G)
    fitted = model.predict_proba(features)[:, 1]
    scores = np.clip(fitted, cfg.trim_ps, 1.0 - cfg.trim_ps)

    # Stabilised inverse probability weights: marginal rate over propensity.
    rate1 = np.mean(G)
    rate0 = 1.0 - rate1
    weights = np.where(G == 1, rate1 / scores, rate0 / (1.0 - scores))

    # Damp (rather than drop) samples whose score is far from the marginal
    # rate, which keeps the support but reduces their leverage.
    spread = np.sqrt(rate1 * (1 - rate1) + 1e-8)
    cutoff = norm.ppf(1 - cfg.caliper / 2.0)
    far = np.abs((scores - rate1) / spread) > cutoff
    if np.any(far):
        weights[far] *= 0.5

    # Rescale within each group so the weights add up to the group count.
    for label in (0, 1):
        members = G == label
        group_w = weights[members]
        weights[members] = group_w * (len(group_w) / (group_w.sum() + 1e-12))
    return weights


# ------------------------ Balance checks -------------------------

@dataclass
class BalanceTargets:
    """Tolerances that ``check_balance`` compares the balance statistics against."""
    # Largest acceptable absolute difference in weighted group means ("mu_diff").
    delta_mu: float
    # Largest acceptable absolute difference in weighted group variances ("var_diff").
    delta_var: float
    # Largest acceptable absolute difference in positive rates ("pi_diff");
    # when None the positive-rate condition is not enforced.
    delta_pi: Optional[float] = None  # optional if only scores are available
    # NOTE(review): check_balance does not read this field — confirm which
    # caller is expected to consume it.
    tau: Optional[float] = None       # threshold for score-positive rate


def balance_statistics(
    S: np.ndarray,
    G: np.ndarray,
    w: np.ndarray,
    Yhat: Optional[np.ndarray] = None,
    tau: Optional[float] = None,
) -> Dict[str, float]:
    """Summarise low-order balance of the scores ``S`` between the two groups.

    The result always holds the weighted per-group mean and variance
    (``mu0``, ``mu1``, ``var0``, ``var1``) and their absolute differences
    (``mu_diff``, ``var_diff``). Positive rates (``pi0``, ``pi1``, ``pi_diff``)
    are added when decisions are available: taken from ``Yhat`` if it is given,
    otherwise obtained by thresholding ``S`` at ``tau`` if that is given.

    Raises ``ValueError`` (from ``_check_binary``) when ``Yhat`` is not 0/1.
    """
    in0 = G == 0
    in1 = G == 1
    s0, s1 = S[in0], S[in1]
    w0, w1 = w[in0], w[in1]

    mu0 = _weighted_mean(s0, w0)
    mu1 = _weighted_mean(s1, w1)
    var0 = _weighted_var(s0, w0)
    var1 = _weighted_var(s1, w1)

    out = {
        "mu_diff": abs(mu1 - mu0),
        "var_diff": abs(var1 - var0),
        "mu0": mu0, "mu1": mu1, "var0": var0, "var1": var1,
    }

    # Explicit decisions win over thresholded scores; with neither there is no
    # positive rate to report.
    if Yhat is not None:
        _check_binary(Yhat, "Yhat")
        dec0, dec1 = Yhat[in0], Yhat[in1]
    elif tau is not None:
        dec0 = _binarize_from_threshold(s0, tau)
        dec1 = _binarize_from_threshold(s1, tau)
    else:
        return out

    rate0 = _weighted_mean(dec0, w0)
    rate1 = _weighted_mean(dec1, w1)
    out.update(pi_diff=abs(rate1 - rate0), pi0=rate0, pi1=rate1)
    return out


def check_balance(
    stats: Dict[str, float],
    targets: BalanceTargets,
) -> bool:
    """Return True when every balance statistic is within its tolerance.

    The mean and variance differences are always checked. The positive-rate
    difference is checked only when ``targets.delta_pi`` is set and ``stats``
    contains ``"pi_diff"``. A NaN statistic fails its check.
    """
    within = [
        stats["mu_diff"] <= targets.delta_mu,
        stats["var_diff"] <= targets.delta_var,
    ]
    if targets.delta_pi is not None and "pi_diff" in stats:
        within.append(stats["pi_diff"] <= targets.delta_pi)
    return bool(all(within))


# ------------------- Tail error gap with bootstrap -------------------

def tail_error_gap(
    S: np.ndarray,
    Y: np.ndarray,
    G: np.ndarray,
    w: np.ndarray,
    q: float = 0.9,
    alpha: float = 0.05,
    n_boot: int = 2000,
    rng: Optional[np.random.RandomState] = None,
) -> Dict[str, float]:
    """
    Estimate the between-group gap in tail error at quantile ``q``.

    For each group the tail is the set of samples whose score is at or above
    the group's weighted ``q``-quantile. Every tail sample counts as a positive
    decision, so the tail error is the weighted share of tail samples with
    ``Y != 1``. The gap is ``error(G=1) - error(G=0)``.

    The confidence interval is a percentile bootstrap that resamples indices
    with replacement within each group. A sample drawn ``k`` times enters the
    replicate with weight ``k * w`` so that draw multiplicities are honoured
    (a plain boolean mask would silently turn the resample into a subsample).

    Returns a dict with ``tail_gap`` (point estimate), ``ci_low``/``ci_high``
    (the ``1 - alpha`` interval) and ``reject`` (True only when the interval is
    defined and excludes 0).

    Raises ``ValueError`` (from ``_check_binary``) when ``G`` or ``Y`` is not 0/1.
    """
    S = np.asarray(S, float)
    Y = np.asarray(Y)
    G = np.asarray(G)
    w = np.asarray(w, float)
    _check_binary(G, "G")
    _check_binary(Y, "Y")
    rng = np.random.RandomState(123) if rng is None else rng

    def _tail_error(weights, members):
        # Weighted misclassification rate among the members at/above their
        # own weighted q-quantile; NaN when the tail carries no weight.
        if weights[members].sum() <= 0:
            return np.nan
        threshold = _safe_quantile(S[members], weights[members], q)
        tail = members & (S >= threshold)
        if tail.sum() == 0 or weights[tail].sum() == 0:
            return np.nan
        # Inside the tail the decision is always 1, so an error is Y != 1.
        missed = (Y[tail] != 1).astype(float)
        return _weighted_mean(missed, weights[tail])

    idx0 = (G == 0)
    idx1 = (G == 1)
    delta = _tail_error(w, idx1) - _tail_error(w, idx0)

    # Bootstrap (within-group resampling of indices with replacement).
    n = len(G)
    i0 = np.where(idx0)[0]
    i1 = np.where(idx1)[0]
    deltas = []
    for _ in range(n_boot):
        b0 = rng.choice(i0, size=len(i0), replace=True)
        b1 = rng.choice(i1, size=len(i1), replace=True)
        # Multiplicity of every original index in this replicate.
        counts = np.bincount(np.concatenate((b0, b1)), minlength=n)
        drawn = counts > 0
        w_rep = w * counts
        e0 = _tail_error(w_rep, idx0 & drawn)
        e1 = _tail_error(w_rep, idx1 & drawn)
        deltas.append(e1 - e0)

    deltas = np.asarray(deltas, dtype=float)
    usable = deltas[~np.isnan(deltas)]
    if usable.size:
        lo, hi = np.percentile(usable, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    else:
        lo = hi = np.nan  # no replicate produced a defined gap

    return {
        "tail_gap": float(delta),
        "ci_low": float(lo),
        "ci_high": float(hi),
        # An undefined interval is not evidence of a disparity.
        "reject": bool(usable.size and not (lo <= 0.0 <= hi)),
    }


# ------------------ Weighted KS with multiplier bootstrap ------------------

def weighted_ecdf(x: np.ndarray, w: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Evaluate the weighted empirical CDF of ``x`` at every point of ``grid``.

    The ECDF is the right-continuous step function
    ``F(t) = sum_i w_i * 1{x_i <= t} / sum_i w_i``. Values are looked up with a
    binary search instead of linear interpolation, so ``F`` stays flat between
    sample points and ties in ``x`` are handled exactly. Negative weights are
    clipped to zero.

    Returns an array shaped like ``grid``; it is all-NaN when ``x`` is empty or
    the weights do not sum to a positive number.
    """
    x = np.asarray(x, float)
    w = np.clip(np.asarray(w, float), 0.0, np.inf)
    grid = np.asarray(grid, float)

    order = np.argsort(x, kind="stable")
    cum = np.cumsum(w[order])
    total = cum[-1] if cum.size else 0.0
    if not total > 0:
        return np.full(grid.shape, np.nan)

    # steps[k] is the mass of the k smallest samples; dividing by cum[-1]
    # makes the final step exactly 1.
    steps = np.concatenate(([0.0], cum / total))
    # side="right" counts the samples that are <= each grid point.
    return steps[np.searchsorted(x[order], grid, side="right")]


def ks_weighted_test(
    S: np.ndarray,
    G: np.ndarray,
    w: np.ndarray,
    alpha: float = 0.05,
    n_boot: int = 2000,
    rng: Optional[np.random.RandomState] = None,
) -> Dict[str, float]:
    """
    Two-sample weighted KS test with a multiplier-bootstrap p-value.

    The statistic is ``D = max_t |F1(t) - F0(t)|`` where ``F_g`` is the weighted
    step ECDF of group ``g`` evaluated on a grid of pooled scores (all unique
    values, or 2000 equally spaced points when there are more than 2000).

    The null distribution is approximated by perturbing each ECDF around
    itself: with normalised weights ``p_i`` and i.i.d. standard normal
    multipliers ``xi_i``,
    ``delta_g(t) = sum_i xi_i * p_i * (1{x_i <= t} - F_g(t))``,
    and the bootstrap statistic is ``max_t |delta_1(t) - delta_0(t)|``. These
    draws are centred at zero, so they are compared with ``D`` directly. The
    weights are treated as fixed (variability from estimating them is not
    propagated).

    The p-value is ``(1 + #{D* >= D}) / (n_boot + 1)``, which is never exactly 0.

    Returns a dict with ``ks_stat``, ``p_value`` and ``reject`` (p-value below
    ``alpha``).

    Raises:
        ValueError: if either group has no sample with positive weight.
    """
    rng = np.random.RandomState(123) if rng is None else rng
    S = np.asarray(S, float)
    G = np.asarray(G)
    w = np.asarray(w, float)

    # Grid over pooled unique scores
    grid = np.unique(S)
    if grid.size > 2000:  # subsample grid if too dense
        grid = np.linspace(np.min(S), np.max(S), 2000)

    def _prepare(label):
        # Sort one group once; return its normalised weights (in sorted order),
        # the lookup position of every grid point, and its ECDF on the grid.
        members = G == label
        x = S[members]
        weights = np.clip(w[members], 0.0, np.inf)
        if x.size == 0 or not weights.sum() > 0:
            raise ValueError(f"Group G == {label} has no samples with positive weight.")
        order = np.argsort(x, kind="stable")
        cum = np.cumsum(weights[order])
        probs = weights[order] / cum[-1]
        pos = np.searchsorted(x[order], grid, side="right")
        ecdf = np.concatenate(([0.0], cum / cum[-1]))[pos]
        return probs, pos, ecdf

    def _mass_below(mass, pos):
        # Cumulative (possibly signed) mass of the samples <= each grid point.
        return np.concatenate(([0.0], np.cumsum(mass)))[pos]

    p0, pos0, F0 = _prepare(0)
    p1, pos1, F1 = _prepare(1)
    D = float(np.max(np.abs(F1 - F0)))

    exceed = 0
    for _ in range(n_boot):
        m0 = rng.standard_normal(p0.size) * p0
        m1 = rng.standard_normal(p1.size) * p1
        # Subtracting F_g * sum(m_g) centres the perturbation of each ECDF.
        dev0 = _mass_below(m0, pos0) - F0 * m0.sum()
        dev1 = _mass_below(m1, pos1) - F1 * m1.sum()
        if np.max(np.abs(dev1 - dev0)) >= D:
            exceed += 1

    pval = float((exceed + 1) / (n_boot + 1))
    return {"ks_stat": D, "p_value": pval, "reject": bool(pval < alpha)}


# ---------------------- Groupwise ECE with bootstrap ----------------------

def groupwise_ece(
    P: np.ndarray,  # predicted probabilities
    Y: np.ndarray,
    G: np.ndarray,
    w: np.ndarray,
    n_bins: int = 10,
) -> Dict[str, float]:
    """Compute the weighted Expected Calibration Error per group and their gap.

    Within each group the predictions are split into (up to) ``n_bins``
    equal-frequency bins whose edges are weighted quantiles of ``P``. Every
    sample belongs to exactly one bin: bins are half-open ``[lo, hi)`` except
    the last, which also contains its upper edge. ECE is the weight-share
    average over bins of ``|mean(Y) - mean(P)|``.

    Returns a dict with ``ece0``, ``ece1`` and ``ece_gap = |ece1 - ece0|``.
    A group that is empty or has no positive weight yields NaN.

    Raises ``ValueError`` (from ``_check_binary``) when ``G`` or ``Y`` is not 0/1.
    """
    P = np.asarray(P, float)
    Y = np.asarray(Y)
    G = np.asarray(G)
    w = np.asarray(w, float)
    _check_binary(G, "G")
    _check_binary(Y, "Y")

    def _ece_one(p, y, ww):
        ww = np.clip(ww, 0.0, np.inf)
        total = ww.sum()
        if p.size == 0 or not total > 0:
            return float("nan")

        # Equal-frequency bins by weighted quantiles
        interior = [_safe_quantile(p, ww, b / n_bins) for b in range(1, n_bins)]
        edges = np.unique(np.clip([0.0, *interior, 1.0], 0.0, 1.0))
        n_cells = len(edges) - 1

        # The interior edges are sample values, so the assignment must be
        # unambiguous: side="right" sends a sample sitting on an edge to the
        # bin that starts there; clipping folds p == 1 into the last bin.
        cell = np.clip(np.searchsorted(edges, p, side="right") - 1, 0, n_cells - 1)

        # share_b * |acc_b - conf_b| == |sum_b(w*y) - sum_b(w*p)| / total
        conf_mass = np.bincount(cell, weights=ww * p, minlength=n_cells)
        acc_mass = np.bincount(cell, weights=ww * y.astype(float), minlength=n_cells)
        return float(np.abs(acc_mass - conf_mass).sum() / total)

    idx0 = (G == 0)
    idx1 = (G == 1)
    e0 = _ece_one(P[idx0], Y[idx0], w[idx0])
    e1 = _ece_one(P[idx1], Y[idx1], w[idx1])
    return {"ece0": e0, "ece1": e1, "ece_gap": abs(e1 - e0)}


def ece_gap_bootstrap(
    P: np.ndarray,
    Y: np.ndarray,
    G: np.ndarray,
    w: np.ndarray,
    n_bins: int = 10,
    alpha: float = 0.05,
    n_boot: int = 2000,
    rng: Optional[np.random.RandomState] = None,
) -> Dict[str, float]:
    """Bootstrap the between-group ECE gap computed by ``groupwise_ece``.

    Indices are resampled with replacement within each group. A sample drawn
    ``k`` times enters the replicate with weight ``k * w`` so that draw
    multiplicities are honoured.

    Returns a dict with
      - ``ece0``, ``ece1``, ``ece_gap``: point estimates on the full sample;
      - ``ci_low``, ``ci_high``: percentile interval of the absolute gap;
      - ``signed_gap``, ``signed_ci_low``, ``signed_ci_high``: the same for
        ``ece1 - ece0``;
      - ``reject``: True when the signed interval is defined and excludes 0.

    The decision uses the signed gap because the absolute gap is non-negative,
    so its interval could only contain 0 if the lower bound were exactly 0.
    """
    rng = np.random.RandomState(123) if rng is None else rng
    P = np.asarray(P, float)
    Y = np.asarray(Y)
    G = np.asarray(G)
    w = np.asarray(w, float)

    base = groupwise_ece(P, Y, G, w, n_bins=n_bins)
    idx0 = np.where(G == 0)[0]
    idx1 = np.where(G == 1)[0]
    n = len(G)

    abs_gaps = []
    signed_gaps = []
    for _ in range(n_boot):
        b0 = rng.choice(idx0, size=len(idx0), replace=True)
        b1 = rng.choice(idx1, size=len(idx1), replace=True)
        # Multiplicity of every original index in this replicate.
        counts = np.bincount(np.concatenate((b0, b1)), minlength=n)
        drawn = counts > 0
        g = groupwise_ece(
            P[drawn], Y[drawn], G[drawn], w[drawn] * counts[drawn], n_bins=n_bins
        )
        abs_gaps.append(g["ece_gap"])
        signed_gaps.append(g["ece1"] - g["ece0"])

    def _interval(values):
        # Percentile interval over the defined replicates; NaNs when none exist.
        arr = np.asarray(values, dtype=float)
        arr = arr[~np.isnan(arr)]
        if arr.size == 0:
            return float("nan"), float("nan")
        lo, hi = np.percentile(arr, [100 * alpha / 2, 100 * (1 - alpha / 2)])
        return float(lo), float(hi)

    lo, hi = _interval(abs_gaps)
    s_lo, s_hi = _interval(signed_gaps)
    defined = not (np.isnan(s_lo) or np.isnan(s_hi))
    return {
        "ece0": base["ece0"],
        "ece1": base["ece1"],
        "ece_gap": base["ece_gap"],
        "ci_low": lo,
        "ci_high": hi,
        "signed_gap": float(base["ece1"] - base["ece0"]),
        "signed_ci_low": s_lo,
        "signed_ci_high": s_hi,
        "reject": bool(defined and not (s_lo <= 0.0 <= s_hi)),
    }


# ------------------- Alternative distances: MMD & W1 -------------------

def _median_heuristic_sigma(x0: np.ndarray, x1: np.ndarray) -> float:
    """Return a kernel bandwidth from the pooled pairwise distances.

    The bandwidth is the median of all strictly positive absolute differences
    between pooled samples. If that median is not positive (for example when
    every value coincides) the pooled standard deviation plus 1e-6 is used.
    """
    pooled = np.concatenate([x0, x1])
    gaps = np.abs(np.subtract.outer(pooled, pooled))
    median_gap = np.median(gaps[gaps > 0])
    if median_gap > 0:
        return float(median_gap)
    return float(np.std(pooled) + 1e-6)


def mmd_gaussian_weighted(
    S: np.ndarray, G: np.ndarray, w: np.ndarray, sigma: Optional[float] = None
) -> float:
    """Weighted squared MMD between the two groups' scores (Gaussian kernel).

    Computes ``p0' K00 p0 + p1' K11 p1 - 2 p0' K01 p1`` with
    ``k(a, b) = exp(-(a - b)^2 / (2 sigma^2))`` and ``p_g`` the group weights
    normalised to sum to 1. This is the plug-in (V-statistic) estimate, which
    includes the diagonal terms. When ``sigma`` is None the bandwidth comes
    from ``_median_heuristic_sigma``.

    Memory use is quadratic in the group sizes because the full kernel
    matrices are materialised.

    Returns the estimate clamped at 0 (rounding can push it slightly below),
    or NaN when either group is empty.
    """
    S = np.asarray(S, float)
    G = np.asarray(G)
    w = np.asarray(w, float)
    idx0 = (G == 0)
    idx1 = (G == 1)
    x0 = S[idx0]
    x1 = S[idx1]
    if x0.size == 0 or x1.size == 0:
        return float("nan")

    if sigma is None:
        sigma = _median_heuristic_sigma(x0, x1)
    gamma = 1.0 / (2.0 * sigma ** 2 + 1e-12)

    def _kernel(a, b):
        # Pairwise Gaussian kernel between two 1-D samples -> (len(a), len(b)).
        return np.exp(-gamma * (a[:, None] - b[None, :]) ** 2)

    # Normalise weights to sum to 1 within each group.
    p0 = w[idx0] / (w[idx0].sum() + 1e-12)
    p1 = w[idx1] / (w[idx1].sum() + 1e-12)

    mmd2 = (
        p0 @ _kernel(x0, x0) @ p0
        + p1 @ _kernel(x1, x1) @ p1
        - 2.0 * (p0 @ _kernel(x0, x1) @ p1)
    )
    return float(max(mmd2, 0.0))


def wasserstein1_weighted(S: np.ndarray, G: np.ndarray, w: np.ndarray) -> float:
    """Return the Wasserstein-1 distance between the two groups' weighted scores.

    Delegates to ``scipy.stats.wasserstein_distance``, passing each group's
    scores together with that group's weights.
    """
    in0 = G == 0
    in1 = G == 1
    distance = wasserstein_distance(
        S[in0], S[in1], u_weights=w[in0], v_weights=w[in1]
    )
    return float(distance)


# --------------------- Multiple-testing control -----------------------

def fdr_correction(pvals: List[float], alpha: float = 0.05, method: str = "bh") -> Dict:
    """
    Benjamini–Hochberg (bh) or Benjamini–Yekutieli (by) FDR control.

    With the ``m`` valid p-values sorted ascending as ``p_(1) <= ... <= p_(m)``,
    the adjusted value of the i-th smallest is
    ``min_{j >= i} p_(j) * m * c / j`` capped at 1, where ``c = 1`` for BH and
    ``c = sum_{k=1..m} 1/k`` for BY. Results are reported in the input order.

    NaN p-values are excluded from ``m``; they get a NaN q-value and are never
    rejected.

    Returns a dict with ``q_values`` (list of floats) and ``reject`` (list of
    bools, ``q <= alpha``).

    Raises:
        ValueError: if ``method`` is neither 'bh' nor 'by'.
    """
    key = method.lower()
    if key not in ("bh", "by"):
        raise ValueError("method must be 'bh' or 'by'")

    p = np.asarray(pvals, float).ravel()
    q_full = np.full(p.shape, np.nan)
    valid = ~np.isnan(p)
    m = int(valid.sum())

    if m:
        p_valid = p[valid]
        order = np.argsort(p_valid, kind="stable")
        ranks = np.arange(1, m + 1)
        scale = m / ranks
        if key == "by":
            scale = scale * np.sum(1.0 / ranks)
        raw = p_valid[order] * scale
        # Enforce monotonicity: running minimum taken from the largest p-value
        # down to the smallest.
        adjusted = np.clip(np.minimum.accumulate(raw[::-1])[::-1], 0.0, 1.0)
        q_valid = np.empty(m)
        q_valid[order] = adjusted  # back to the input order
        q_full[valid] = q_valid

    # Comparisons with NaN are False, so NaN inputs are never rejected.
    reject = q_full <= alpha
    return {"q_values": q_full.tolist(), "reject": reject.tolist()}


# ---------------------------- Orchestrator ----------------------------

@dataclass
class AuditResult:
    """Container for everything ``run_audit`` computes in one pass."""
    # Outcome of check_balance on balance_stats.
    balance_ok: bool
    # Output of balance_statistics (group means/variances and their gaps,
    # plus positive rates when decisions or a threshold were available).
    balance_stats: Dict[str, float]
    # Output of tail_error_gap; None when no labels Y were supplied.
    tail_gap: Optional[Dict[str, float]]
    # Output of ks_weighted_test ("ks_stat", "p_value", "reject").
    ks: Dict[str, float]
    # Output of ece_gap_bootstrap; None unless labels Y were supplied and
    # every score lies in [0, 1].
    ece: Optional[Dict[str, float]]
    # Alternative distances keyed "mmd2_gaussian" and "wasserstein1".
    distances: Dict[str, float]


def run_audit(
    S: np.ndarray,
    G: np.ndarray,
    X: Optional[np.ndarray] = None,
    Y: Optional[np.ndarray] = None,
    Yhat: Optional[np.ndarray] = None,
    tau: Optional[float] = None,
    q: float = 0.9,
    alpha: float = 0.05,
    targets: Optional[BalanceTargets] = None,
    match_cfg: Optional[MatchConfig] = None,
    n_boot: int = 2000,
    rng: Optional[np.random.RandomState] = None,
) -> AuditResult:
    """
    Execute the audit protocol:
      1) Match/reweight via propensity scores on X (if given)
      2) Verify low-order balance (mu, var, and optional pi)
      3) Tail error gap (if Y available) with CI
      4) Weighted KS test with multiplier bootstrap
      5) ECE gap (if probabilistic scores P= S in [0,1] and Y available)
      6) Alternative distances (MMD, W1)

    Defaults:
      - ``targets=None`` means tolerances of 0.02 for mean, variance and
        positive rate; ``match_cfg=None`` means ``MatchConfig()``. Both are
        created per call rather than shared between calls.
      - ``tau=None`` falls back to ``targets.tau``.
      - ``rng=None`` uses ``np.random.RandomState(123)``; the same generator is
        passed to every resampling step in turn.

    Inputs are converted to arrays; a 1-D ``X`` is treated as one covariate.

    Raises:
        ValueError: if ``S`` is not 1-D or any other input has a different
            number of rows than ``S``.
    """
    if rng is None:
        rng = np.random.RandomState(123)
    if targets is None:
        targets = BalanceTargets(delta_mu=0.02, delta_var=0.02, delta_pi=0.02, tau=None)
    if match_cfg is None:
        match_cfg = MatchConfig()
    if tau is None:
        tau = targets.tau

    S = np.asarray(S, float)
    if S.ndim != 1:
        raise ValueError("S must be one-dimensional.")
    n = len(S)
    G = np.asarray(G)
    Y = None if Y is None else np.asarray(Y)
    Yhat = None if Yhat is None else np.asarray(Yhat)
    if X is None:
        X = np.empty((n, 0))
    else:
        X = np.asarray(X, float)
        if X.ndim == 1:
            X = X.reshape(-1, 1)

    for label, arr in (("G", G), ("X", X), ("Y", Y), ("Yhat", Yhat)):
        if arr is not None and len(arr) != n:
            raise ValueError(f"{label} must have the same number of rows as S.")

    # 1) Matching / reweighting
    w = propensity_weights(X, G, match_cfg)

    # 2) Balance checks (explicit decisions take precedence over a threshold)
    stats = balance_statistics(
        S=S, G=G, w=w, Yhat=Yhat, tau=tau if Yhat is None else None
    )
    bal_ok = check_balance(stats, targets)

    # 3) Tail error gap (requires Y)
    tail = None
    if Y is not None:
        tail = tail_error_gap(S=S, Y=Y, G=G, w=w, q=q, alpha=alpha, n_boot=n_boot, rng=rng)

    # 4) Weighted KS test
    ks = ks_weighted_test(S=S, G=G, w=w, alpha=alpha, n_boot=n_boot, rng=rng)

    # 5) ECE gap (only meaningful when the scores can be read as probabilities)
    ece = None
    if Y is not None and np.all((S >= 0) & (S <= 1)):
        ece = ece_gap_bootstrap(P=S, Y=Y, G=G, w=w, n_bins=10, alpha=alpha, n_boot=n_boot, rng=rng)

    # 6) Alternative distances
    distances = {
        "mmd2_gaussian": mmd_gaussian_weighted(S=S, G=G, w=w),
        "wasserstein1": wasserstein1_weighted(S=S, G=G, w=w),
    }

    return AuditResult(
        balance_ok=bal_ok,
        balance_stats=stats,
        tail_gap=tail,
        ks=ks,
        ece=ece,
        distances=distances,
    )


# ------------------------------- CLI ---------------------------------

def _load_csv(
    path: str,
    s_col: str,
    g_col: str,
    y_col: Optional[str],
    x_cols: Optional[List[str]],
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
    """Read the audit inputs from a CSV file.

    Returns ``(S, G, Y, X)``: scores as float, group as int, labels as int (or
    None when ``y_col`` is empty), and covariates as a 2-D float array (or None
    when ``x_cols`` is empty).

    Raises:
        ValueError: if any requested column is absent. A label column that was
            asked for but is missing is an error too; skipping it silently
            would drop the label-based tests without telling the user.
    """
    df = pd.read_csv(path)

    requested = [s_col, g_col]
    if y_col:
        requested.append(y_col)
    if x_cols:
        requested.extend(x_cols)
    missing = [col for col in requested if col not in df.columns]
    if missing:
        raise ValueError(f"Required columns are missing from the CSV: {missing}")

    S = df[s_col].to_numpy(float)
    G = df[g_col].to_numpy(int)
    Y = df[y_col].to_numpy(int) if y_col else None
    X = df[x_cols].to_numpy(float) if x_cols else None
    return S, G, Y, X


def main():
    """Command-line entry point: parse options, load the CSV, audit, report.

    Prints the balance statistics, each available disparity result and the
    alternative distances, then a final ``flag_fairness_illusion`` line that is
    True when balance holds while at least one disparity test rejects.
    """
    cli = argparse.ArgumentParser(description="Audit protocol for fairness illusion.")
    option_table = [
        ("--csv", dict(type=str, required=True, help="Path to CSV file.")),
        ("--score_col", dict(type=str, default="S", help="Score column name.")),
        ("--group_col", dict(type=str, default="G", help="Sensitive group column name (0/1).")),
        ("--label_col", dict(type=str, default=None, help="Binary label column name (optional).")),
        ("--x_cols", dict(type=str, nargs="*", default=None, help="List of non-sensitive feature columns for matching.")),
        ("--alpha", dict(type=float, default=0.05, help="Significance level.")),
        ("--q", dict(type=float, default=0.9, help="Tail quantile for tail error gap.")),
        ("--delta_mu", dict(type=float, default=0.02, help="Tolerance for mean difference.")),
        ("--delta_var", dict(type=float, default=0.02, help="Tolerance for variance difference.")),
        ("--delta_pi", dict(type=float, default=0.02, help="Tolerance for positive-rate difference.")),
        ("--tau", dict(type=float, default=None, help="Threshold for positive-rate when labels are absent.")),
        ("--n_boot", dict(type=int, default=2000, help="Bootstrap replicates.")),
    ]
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    S, G, Y, X = _load_csv(
        path=opts.csv,
        s_col=opts.score_col,
        g_col=opts.group_col,
        y_col=opts.label_col,
        x_cols=opts.x_cols,
    )

    # The positive-rate tolerance is only armed when a label column or a
    # threshold was supplied on the command line.
    wants_pi = bool(opts.label_col) or opts.tau is not None
    targets = BalanceTargets(
        delta_mu=opts.delta_mu,
        delta_var=opts.delta_var,
        delta_pi=opts.delta_pi if wants_pi else None,
        tau=opts.tau,
    )

    res = run_audit(
        S=S, G=G, X=X, Y=Y, Yhat=None,
        tau=opts.tau, q=opts.q, alpha=opts.alpha,
        targets=targets, n_boot=opts.n_boot,
    )

    # --- Reporting ---
    print("# Balance")
    print(f"balance_ok: {res.balance_ok}")
    print(res.balance_stats)

    # (heading, payload, skip_when_missing)
    report = [
        ("\n# Tail error gap", res.tail_gap, True),
        ("\n# Weighted KS", res.ks, False),
        ("\n# ECE gap", res.ece, True),
        ("\n# Distances", res.distances, False),
    ]
    for heading, payload, optional in report:
        if optional and payload is None:
            continue
        print(heading)
        print(payload)

    # Flag fairness illusion if balance holds but any disparity test rejects
    rejections = [
        (res.tail_gap is not None and res.tail_gap["reject"]),
        res.ks["reject"],
        (res.ece is not None and res.ece["reject"]),
    ]
    fairness_illusion = bool(res.balance_ok and any(rejections))
    print(f"\nflag_fairness_illusion: {fairness_illusion}")

# Run the command-line audit only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
