"""
RBF MMD implementation with:
- U-statistic for small sample sizes (n, m <= threshold; default 10k)
- Linear-time paired-sample estimator (Gretton et al.) beyond the threshold

Includes median heuristic bandwidth selection.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Tuple

import numpy as np


# Annotation alias. The functions below call `.ndim` / `.shape` on their inputs,
# so they expect NumPy arrays (this is NOT the broader numpy.typing.ArrayLike).
ArrayLike = np.ndarray


def _pairwise_squared_dists(X: ArrayLike, Y: ArrayLike) -> ArrayLike:
    """Compute pairwise squared Euclidean distances between rows of X and Y.

    Uses a numerically stable formulation.
    """
    X2 = np.sum(X * X, axis=1, keepdims=True)
    Y2 = np.sum(Y * Y, axis=1, keepdims=True).T
    # (x - y)^2 = x^2 + y^2 - 2 x y
    D = X2 + Y2 - 2.0 * X @ Y.T
    # Clamp tiny negatives from numerical errors
    np.maximum(D, 0.0, out=D)
    return D


def rbf_kernel(X: ArrayLike, Y: ArrayLike, bandwidth: float) -> ArrayLike:
    """Return the Gaussian (RBF) Gram matrix between the rows of X and Y.

    Entry ``(i, j)`` is ``exp(-||X[i] - Y[j]||^2 / (2 * bandwidth^2))``.

    Parameters
    ----------
    X : np.ndarray, shape (n, d)
    Y : np.ndarray, shape (m, d)
        Raw feature matrices (not precomputed distances).
    bandwidth : float
        Kernel width sigma; must be strictly positive.

    Returns
    -------
    np.ndarray, shape (n, m)
    """
    assert bandwidth > 0.0, "bandwidth must be positive"
    assert X.ndim == 2 and Y.ndim == 2, "Inputs must be (n,d) and (m,d) feature matrices, not distances."
    two_sigma_sq = 2.0 * bandwidth * bandwidth
    sq_dists = _pairwise_squared_dists(X, Y)
    return np.exp(-sq_dists / two_sigma_sq)


def median_heuristic_bandwidth(X: ArrayLike, Y: ArrayLike, max_samples: int = 2000, rng: np.random.Generator | None = None) -> float:
    """Median heuristic for RBF bandwidth based on pairwise distances.

    The bandwidth is the median Euclidean distance over distinct pairs of rows
    of the pooled sample ``X ∪ Y``. If the pool has more than ``max_samples``
    rows, a subset of ``max_samples`` rows drawn without replacement with
    ``rng`` is used instead.

    Parameters
    ----------
    X, Y : np.ndarray with shape (n, d), (m, d)
        Feature matrices (not pairwise distances).
    max_samples : int
        Upper bound on the number of pooled rows used for the estimate.
    rng : np.random.Generator | None
        Generator for subsampling; defaults to ``np.random.default_rng(0)``
        so repeated calls are deterministic.

    Returns
    -------
    float
        sigma (not sigma^2), floored at 1e-12 so it is always a usable positive
        bandwidth. The floor is also returned when fewer than two points are
        available, since then no pairwise distance exists.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    assert X.ndim == 2 and Y.ndim == 2, "Inputs must be features, not pairwise distances."
    min_bandwidth = 1e-12

    Z = np.concatenate([X, Y], axis=0)
    num = Z.shape[0]
    if num > max_samples:
        idx = rng.choice(num, size=max_samples, replace=False)
        Z = Z[idx]

    if Z.shape[0] < 2:
        # No distinct pair exists: np.median of an empty array yields NaN (with
        # a RuntimeWarning), and max(nan, floor) would return that NaN.
        return min_bandwidth

    D2 = _pairwise_squared_dists(Z, Z)
    # Strict upper triangle: distinct pairs only, diagonal (self-distance) excluded.
    i_upper = np.triu_indices(D2.shape[0], k=1)
    med = float(np.median(np.sqrt(D2[i_upper])))
    # Avoid zero bandwidth (e.g. many duplicate points); fall back to small positive.
    return max(med, min_bandwidth)


@dataclass
class MMDConfig:
    """Estimator-selection settings for ``mmd2_rbf``.

    Attributes
    ----------
    mode : {"auto", "u", "linear"}
        "u" forces the quadratic U-statistic, "linear" forces the linear-time
        estimator, and "auto" picks based on ``threshold``.
    threshold : int
        In "auto" mode, the U-statistic is used only while both sample sizes
        are <= threshold; otherwise the linear-time estimator is used.
    block_size : int
        Not read by the estimators currently dispatched by ``mmd2_rbf``
        (kept for API parity).

    Raises
    ------
    ValueError
        If ``mode`` is not one of "auto", "u", "linear".
    """

    mode: Literal["auto", "u", "linear"] = "auto"
    threshold: int = 10_000  # if max(n, m) > threshold -> linear-time
    block_size: int = 2_048

    def __post_init__(self) -> None:
        # Reject typos early instead of letting them silently select an estimator.
        if self.mode not in ("auto", "u", "linear"):
            raise ValueError(f"mode must be 'auto', 'u' or 'linear', got {self.mode!r}")


def _mmd2_u_statistic(X: ArrayLike, Y: ArrayLike, bandwidth: float) -> float:
    """Quadratic-time U-statistic estimate of RBF MMD^2.

    Averages k(x, x') over distinct pairs within X, k(y, y') over distinct
    pairs within Y, and k(x, y) over all cross pairs, then returns
    ``mean_xx + mean_yy - 2 * mean_xy`` clipped at zero.

    Returns 0.0 when either sample has fewer than two rows.

    NOTE: for finite samples the un-clipped U-statistic can be negative
    (e.g. when both samples come from the same distribution); clipping at
    zero therefore adds a small upward bias.
    """
    n = X.shape[0]
    m = Y.shape[0]
    if min(n, m) < 2:
        return 0.0

    k_xx = rbf_kernel(X, X, bandwidth)
    k_yy = rbf_kernel(Y, Y, bandwidth)
    k_xy = rbf_kernel(X, Y, bandwidth)

    # Zero the i == j entries so the within-sample sums cover distinct pairs only.
    for gram in (k_xx, k_yy):
        np.fill_diagonal(gram, 0.0)

    within_x = k_xx.sum() / (n * (n - 1))
    within_y = k_yy.sum() / (m * (m - 1))
    between = k_xy.mean()
    estimate = float(within_x + within_y - 2.0 * between)
    return max(estimate, 0.0)


def _mmd2_linear_time_pairing(X: ArrayLike, Y: ArrayLike, bandwidth: float) -> float:
    """Linear-time MMD^2 estimator via paired samples (Gretton et al.).

    Pairs adjacent samples; if odd, drops the last one.
    """
    n = X.shape[0]
    m = Y.shape[0]
    n2 = n // 2
    m2 = m // 2
    if n2 == 0 or m2 == 0:
        return 0.0

    def pairwise_sum(Z: ArrayLike) -> float:
        Z1, Z2 = Z[0::2][: min(n2, Z.shape[0] // 2)], Z[1::2][: min(n2, Z.shape[0] // 2)]
        D2 = np.sum((Z1 - Z2) ** 2, axis=1)
        return float(np.mean(np.exp(-D2 / (2.0 * bandwidth * bandwidth))))

    # Cross terms: mix opposite indices
    X1, X2 = X[0::2][:n2], X[1::2][:n2]
    Y1, Y2 = Y[0::2][:m2], Y[1::2][:m2]
    # Align lengths
    L = min(n2, m2)
    X1, X2, Y1, Y2 = X1[:L], X2[:L], Y1[:L], Y2[:L]
    D2_x1y2 = np.sum((X1 - Y2) ** 2, axis=1)
    D2_x2y1 = np.sum((X2 - Y1) ** 2, axis=1)
    k_x1y2 = np.exp(-D2_x1y2 / (2.0 * bandwidth * bandwidth))
    k_x2y1 = np.exp(-D2_x2y1 / (2.0 * bandwidth * bandwidth))

    mmd2 = 2.0 * (pairwise_sum(X) + pairwise_sum(Y) - np.mean(k_x1y2) - np.mean(k_x2y1))
    return max(float(mmd2), 0.0)


def mmd2_rbf(
    X: ArrayLike,
    Y: ArrayLike,
    bandwidth: float | None = None,
    config: MMDConfig | None = None,
    rng: np.random.Generator | None = None,
) -> Tuple[float, float]:
    """Compute RBF MMD^2 between two samples.

    Parameters
    ----------
    X, Y : array-like with shapes (n, d), (m, d)
        Converted with ``np.asarray``, so nested lists are accepted as well.
    bandwidth : float | None
        RBF sigma. If None, use median heuristic on X∪Y.
    config : MMDConfig | None
        Controls whether to use U-statistic or linear-time estimator.
        Defaults to ``MMDConfig()``.
    rng : np.random.Generator | None
        RNG used for bandwidth heuristic subsampling.

    Returns
    -------
    mmd2 : float
        Estimated MMD^2.
    bandwidth : float
        The bandwidth used.

    Raises
    ------
    AssertionError
        If X/Y are not 2-D or their feature dimensions differ.
    ValueError
        If ``config.mode`` is not "auto", "u" or "linear".
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    assert X.ndim == 2 and Y.ndim == 2, "X and Y must be 2D (n,d)/(m,d) feature arrays."
    assert X.shape[1] == Y.shape[1], "dimension mismatch"

    if config is None:
        config = MMDConfig()

    n, m = X.shape[0], Y.shape[0]
    mode: Literal["u", "linear"]
    if config.mode == "auto":
        # Quadratic U-statistic only while both samples are within the threshold.
        mode = "u" if (n <= config.threshold and m <= config.threshold) else "linear"
    elif config.mode in ("u", "linear"):
        mode = config.mode
    else:
        # Reject unknown modes instead of silently routing them to the linear estimator.
        raise ValueError(f"config.mode must be 'auto', 'u' or 'linear', got {config.mode!r}")

    if bandwidth is None:
        bandwidth = median_heuristic_bandwidth(X, Y, rng=rng)

    if mode == "u":
        mmd2 = _mmd2_u_statistic(X, Y, bandwidth)
    else:
        # Linear-time: use pairing estimator; config.block_size kept for API parity
        mmd2 = _mmd2_linear_time_pairing(X, Y, bandwidth)
    return float(mmd2), float(bandwidth)


