"""
Gradient-quantile proxy Lx(q): compute q-quantile of squared input-gradient norms.

Supports sklearn Ridge and LogisticRegression via analytic formulas, and PyTorch nn.Module via autograd.
"""
from __future__ import annotations

from typing import Iterable, Literal, Optional, Tuple

import numpy as np

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except Exception:  # pragma: no cover
    torch = None  # type: ignore
    nn = None  # type: ignore
    F = None  # type: ignore


def _sqnorm(a: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    return np.sum(a * a, axis=axis)


def _ridge_sklearn_grad_norms_sq(X: np.ndarray, y: np.ndarray, coef: np.ndarray, intercept: float = 0.0) -> np.ndarray:
    """Squared gradient norms for squared loss: 0.5 (y - w^T x - b)^2."""
    # grad_x = -(y - w^T x - b) * w
    residual = y - (X @ coef + intercept)
    # Clip residuals to stabilize Lx proxy against outliers, per user suggestion
    res_abs_q99 = np.quantile(np.abs(residual), 0.99)
    residual_clipped = np.clip(residual, -res_abs_q99, res_abs_q99)

    coeff_norm2 = float(np.sum(coef * coef))
    norms_sq = (residual_clipped * residual_clipped) * coeff_norm2
    return norms_sq.astype(np.float64)


def _logistic_sklearn_grad_norms_sq(X: np.ndarray, y: np.ndarray, coef: np.ndarray, intercept: float = 0.0) -> np.ndarray:
    """Squared gradient norms for logistic CE loss with label y in {0,1}."""
    logits = X @ coef + intercept
    probs = 1.0 / (1.0 + np.exp(-logits))
    diff = probs - y  # dℓ/dz
    coeff_norm2 = float(np.sum(coef * coef))
    norms_sq = (diff * diff) * coeff_norm2
    return norms_sq.astype(np.float64)


def _torch_autograd_grad_norms_sq(model: "nn.Module", X: np.ndarray, y: np.ndarray, loss: Literal["mse", "ce"], device: str = "cpu", batch_size: int = 1024) -> np.ndarray:
    assert torch is not None and nn is not None and F is not None
    model.eval()
    dev = torch.device(device)
    dtype = torch.float64
    norms: list[float] = []
    with torch.no_grad():
        pass
    for start in range(0, X.shape[0], batch_size):
        end = min(start + batch_size, X.shape[0])
        xb = torch.tensor(X[start:end], dtype=dtype, device=dev, requires_grad=True)
        yb_np = y[start:end]
        if loss == "mse":
            yb = torch.tensor(yb_np.reshape(-1, 1), dtype=dtype, device=dev)
            preds = model(xb)
            lb = 0.5 * torch.mean((preds - yb) ** 2, dim=1)
        else:
            # Assume model returns logits of shape (B, C); y is int labels
            yb = torch.tensor(yb_np.astype(np.int64), dtype=torch.long, device=dev)
            logits = model(xb)
            lb = F.cross_entropy(logits, yb, reduction="none")

        grads = torch.autograd.grad(lb.sum(), xb)[0]
        norms_sq = torch.sum(grads * grads, dim=1).detach().cpu().numpy()
        norms.extend(norms_sq.tolist())
    return np.asarray(norms, dtype=np.float64)


def gradient_quantile_proxy(
    model,
    X: np.ndarray,
    y: np.ndarray,
    model_type: Literal["sklearn_ridge", "sklearn_logistic", "torch"] = "torch",
    loss: Literal["mse", "ce"] = "ce",
    q: float = 0.9,
    device: str = "cpu",
    intercept: float | None = None,
) -> float:
    """Compute Lx(q): q-quantile of squared input-gradient norms.

    For sklearn models, uses analytic formulas. For torch models, uses autograd.
    Returns the q-quantile of ||∇_x ℓ(f(x), y)||^2 over samples.
    """
    assert 0.0 < q < 1.0
    if model_type == "sklearn_ridge":
        coef = getattr(model, "coef_")
        # Use provided intercept if available, otherwise get from model
        intercept_val = intercept if intercept is not None else float(getattr(model, "intercept_", 0.0))
        norms_sq = _ridge_sklearn_grad_norms_sq(X, y, coef, intercept_val)
    elif model_type == "sklearn_logistic":
        # Binary logistic assumed; if multiclass, treat as one-vs-rest average (not used here)
        coef = getattr(model, "coef_")
        intercept_arr = getattr(model, "intercept_", np.array([0.0]))
        coef = coef.reshape(-1)
        intercept = float(intercept_arr.reshape(-1)[0])
        norms_sq = _logistic_sklearn_grad_norms_sq(X, y, coef, intercept)
    else:
        norms_sq = _torch_autograd_grad_norms_sq(model, X, y, loss=loss, device=device)
    return float(np.quantile(norms_sq, q))




