from typing import Union, Sequence, Optional, Tuple
import numpy as np
import torch

@torch.no_grad()
def U_statistic_from_kernel(
    M: torch.Tensor,
    w: torch.Tensor,
    *,
    weighted: bool = True,
    exclude_diag: bool = True,
) -> float:
    """
    Compute a U-statistic (or V-statistic) from a pairwise kernel matrix.

    Args:
        M: Square kernel matrix; ``M[i, j]`` is the kernel value for pair (i, j).
        w: Per-sample weights, normalised internally to sum to one.
            Only used when ``weighted=True``.
        weighted: If True, average ``M`` under the product measure
            ``mu_n ⊗ mu_n`` with ``mu_n = w / w.sum()``. If False, every
            index pair gets the same weight.
        exclude_diag: If True, drop the ``i == j`` terms and renormalise by
            the remaining off-diagonal mass (U-statistic). If False, keep the
            diagonal (V-statistic).

    Returns:
        The statistic as a Python float, for every combination of flags.
        With fewer than two samples the off-diagonal mass is zero, so the
        U-statistic variants are not meaningful (NaN or inf).
    """
    n = M.shape[0]
    if weighted:
        w = w / w.sum()
        W = w[:, None] * w[None, :]
        total = (W * M).sum()
        if not exclude_diag:
            return total.item()
        # Remove the i == j contributions w_i^2 * M_ii and renormalise by the
        # remaining off-diagonal mass 1 - sum_i w_i^2.
        off_mass = 1.0 - (w * w).sum()
        val = total - (w * w * torch.diag(M)).sum()
        return (val / off_mass).item()
    if exclude_diag:
        # Mean over the n * (n - 1) ordered pairs with i != j.
        return ((M.sum() - torch.diag(M).sum()) / (n * (n - 1))).item()
    return M.mean().item()
        
def centered_matrix(D2: torch.Tensor) -> torch.Tensor:
    """
    Double-centre ``D2``.

    Subtracts the row means and the column means, then adds back the overall
    mean. For a 2-D input, every row and every column of the result averages
    to zero.
    """
    centered = D2 - D2.mean(dim=1, keepdim=True)
    centered = centered - D2.mean(dim=0, keepdim=True)
    return centered + D2.mean()


def _double_center(M: torch.Tensor) -> torch.Tensor:
    # Numerically stable double-centering: M - row_mean - col_mean + grand_mean
    row_mean = M.mean(dim=1, keepdim=True)
    col_mean = M.mean(dim=0, keepdim=True)
    grand_mean = M.mean()
    return M - row_mean - col_mean + grand_mean

def _eigvalsh_robust(
    C: torch.Tensor,
    max_tries: int = 6,
    base_jitter: float = 1e-12,
) -> torch.Tensor:
    """
    Robust symmetric eigenvalues:
      - force float64
      - try on current device first
      - if fail, move to CPU
      - add increasing diagonal jitter
    """
    # Work in float64 for stability
    C64 = C.to(dtype=torch.float64)

    # Force symmetry (important after centering)
    C64 = 0.5 * (C64 + C64.T)

    if not torch.isfinite(C64).all():
        bad = (~torch.isfinite(C64)).sum().item()
        raise ValueError(f"C contains non-finite entries (count={bad}). Check centering / upstream numerics.")

    # Scale jitter by matrix magnitude
    scale = torch.linalg.norm(C64, ord="fro").item()
    scale = max(scale, 1.0)

    last_err = None
    for t in range(max_tries):
        jitter = (base_jitter * (10.0 ** t)) * scale
        C_try = C64 + jitter * torch.eye(C64.shape[0], dtype=C64.dtype, device=C64.device)

        try:
            return torch.linalg.eigvalsh(C_try)
        except RuntimeError as e:
            last_err = e
            # Fallback to CPU after first failure (GPU solvers can be fragile)
            C64 = C64.cpu()

    # If we get here, all retries failed
    raise RuntimeError(f"eigvalsh failed after {max_tries} attempts. Last error: {last_err}")

import numpy as np
import torch
from typing import Optional, Sequence, Union

@torch.no_grad()
def chaos_quantile(
    D2: torch.Tensor,
    n_sims: int = 100_000,
    cutoffs: Union[float, Sequence[float], torch.Tensor] = 0.95,
    t_obs: Optional[Union[float, torch.Tensor]] = None,
    pvalue_tail: str = "upper",  # {"upper","lower","two-sided"}
    symmetrize: bool = True,
    clamp_eigs: bool = False,
    return_lambdas: bool = False,
    return_samples: bool = False,
    sim_chunk_size: int = 4096,
    sim_dtype: torch.dtype = torch.float32,
    generator: Optional[torch.Generator] = None,
):
    """
    Monte Carlo null distribution of ``sum_i lambda_i * (Z_i**2 - 1)``.

    The ``lambda_i`` are the eigenvalues of the double-centred matrix built
    from ``D2``, and the ``Z_i`` are i.i.d. standard normals. Quantiles, and
    optionally a Monte Carlo p-value for ``t_obs``, are read off the
    simulated samples.

    Device handling:
      - the eigen-solve runs on ``D2``'s device first and falls back to CPU on
        failure (``_eigvalsh_robust``);
      - the simulation, quantiles and p-value are computed on CPU.

    Args:
        D2: Square (n, n) matrix. It is symmetrised as ``0.5 * (D2 + D2.T)``
            before double-centring.
        n_sims: Number of Monte Carlo draws; must be positive.
        cutoffs: Quantile level(s) in [0, 1].
        t_obs: Observed statistic. When given, a p-value is computed.
        pvalue_tail: "upper" (P[S >= t]), "lower" (P[S <= t]) or
            "two-sided" (P[|S| >= |t|]; note that the null law is not
            symmetric in general). Only validated when ``t_obs`` is given.
        symmetrize: Symmetrise the centred matrix before the eigen-solve.
            ``_eigvalsh_robust`` symmetrises its input as well.
        clamp_eigs: Clamp negative eigenvalues to zero.
        return_lambdas: Include the eigenvalues (NumPy array) in the result.
        return_samples: Include the raw samples (NumPy array) in the result.
        sim_chunk_size: Draws simulated per chunk (bounds memory); must be positive.
        sim_dtype: dtype used for the normal draws and the weighted sum.
        generator: Optional CPU ``torch.Generator`` for reproducible draws.

    Returns:
        If none of ``t_obs``, ``return_lambdas`` or ``return_samples`` is set,
        the quantile(s): a float for a single cutoff, else a 1-D NumPy array.
        Otherwise, a dict with key "quantiles" plus "p_value", "lambdas" and
        "samples" as requested.
        If the eigendecomposition fails, a RuntimeWarning is emitted and
        ``{"quantiles": <NaN(s)>, "p_value": nan}`` is returned.

    Raises:
        ValueError: If ``D2`` is not square, if ``n_sims`` or ``sim_chunk_size``
            is not positive, if a cutoff lies outside [0, 1], or if
            ``pvalue_tail`` is unknown.
    """
    import warnings

    if D2.ndim != 2 or D2.shape[0] != D2.shape[1]:
        raise ValueError(f"D2 must be square (n,n). Got {tuple(D2.shape)}")
    if n_sims <= 0:
        raise ValueError(f"n_sims must be positive. Got {n_sims}")
    if sim_chunk_size <= 0:
        # A non-positive step would skip the loop and leave `samples` uninitialised.
        raise ValueError(f"sim_chunk_size must be positive. Got {sim_chunk_size}")

    # Validate cheap arguments before the eigen-solve and the simulation.
    q = torch.as_tensor(cutoffs, device="cpu", dtype=torch.float64).flatten()
    if torch.any((q < 0) | (q > 1)):
        raise ValueError("All cutoffs must lie in [0, 1].")

    tail = None
    if t_obs is not None:
        tail = pvalue_tail.lower()
        if tail not in {"upper", "lower", "two-sided"}:
            raise ValueError("pvalue_tail must be one of {'upper','lower','two-sided'}.")

    n = D2.shape[0]

    D2 = 0.5 * (D2 + D2.T)
    C = _double_center(D2)
    if symmetrize:
        C = 0.5 * (C + C.T)

    try:
        lambdas = _eigvalsh_robust(C).detach().to("cpu")
    except Exception as err:
        # Deliberate best-effort: return NaNs instead of crashing, but surface the cause.
        warnings.warn(
            f"chaos_quantile: eigendecomposition failed ({err!r}); returning NaN results.",
            RuntimeWarning,
            stacklevel=2,
        )
        # Mirror the success path: scalar for a single cutoff, flat array otherwise.
        nan_q = float("nan") if q.numel() == 1 else np.full(q.numel(), np.nan)
        return {"quantiles": nan_q, "p_value": np.nan}

    if clamp_eigs:
        lambdas = torch.clamp(lambdas, min=0.0)

    # ---- Monte Carlo: CPU only, in chunks of at most sim_chunk_size draws ----
    lam = lambdas.to(dtype=sim_dtype, device="cpu")
    samples = torch.empty((n_sims,), device="cpu", dtype=torch.float32)
    for start in range(0, n_sims, sim_chunk_size):
        stop = min(start + sim_chunk_size, n_sims)
        Z = torch.randn((stop - start, n), device="cpu", dtype=sim_dtype, generator=generator)
        samples[start:stop] = ((Z.square() - 1.0) @ lam).to(torch.float32)
        del Z  # release the chunk before the next allocation

    # ---- Quantiles (linear interpolation; NumPy has no input-size cap) ----
    quant_out = np.quantile(samples.numpy().astype(np.float64), q.numpy())
    if q.numel() == 1:
        quant_out = float(quant_out[0])

    # ---- p-value: CPU only ----
    p_out = None
    if t_obs is not None:
        if isinstance(t_obs, torch.Tensor):
            t_val = float(t_obs.detach().cpu().item())
        else:
            t_val = float(t_obs)

        if tail == "upper":
            count = int(torch.sum(samples >= t_val))
        elif tail == "lower":
            count = int(torch.sum(samples <= t_val))
        else:
            count = int(torch.sum(samples.abs() >= abs(t_val)))

        # Add-one correction keeps the Monte Carlo p-value strictly positive.
        p_out = (count + 1.0) / (n_sims + 1.0)

    if return_lambdas or return_samples or (t_obs is not None):
        out = {"quantiles": quant_out}
        if p_out is not None:
            out["p_value"] = p_out
        if return_lambdas:
            out["lambdas"] = lambdas.numpy()
        if return_samples:
            out["samples"] = samples.numpy()
        return out

    return quant_out


def pca_torch(X, k, center=True):
    """
    Project data onto its top-``k`` principal directions.

    Args:
        X: (N, D) floating-point torch tensor (CPU or GPU).
        k: Number of components to keep. Values above D yield D columns.
        center: Subtract the per-feature mean before computing the covariance.

    Returns:
        (N, min(k, D)) tensor of projections, ordered by decreasing covariance
        eigenvalue. Eigenvectors are only defined up to sign, so a component
        may come out sign-flipped on different backends.
    """
    Xc = X - X.mean(dim=0, keepdim=True) if center else X

    # (D, D) = (D, N) @ (N, D). Guard N == 1: dividing by N - 1 == 0 would
    # produce a NaN covariance. The scale does not affect the eigenvectors.
    denom = max(Xc.shape[0] - 1, 1)
    cov = Xc.T @ Xc / denom

    # cov is symmetric, so eigh applies; it returns eigenvalues in ascending order.
    eigvals, eigvecs = torch.linalg.eigh(cov)
    order = torch.argsort(eigvals, descending=True)
    top = eigvecs[:, order[:k]]

    return Xc @ top