"""
Parameter estimators for Stable-QDA.

This module provides location and dispersion estimators:
- Standard: Sample mean + Ledoit-Wolf shrinkage covariance
- Robust: Spatial median + Tyler's M-estimator

Key Finding (from paper):
    Standard estimators often outperform robust alternatives when class
    heteroscedasticity is discriminative. Tyler's M-estimator normalizes
    scatter matrices to fixed trace, discarding scale information that
    the QDA likelihood uses through the log-determinant term.
"""

import numpy as np
from sklearn.covariance import LedoitWolf


def spatial_median(X, tol=1e-6, max_iter=100):
    """
    Compute the spatial median (geometric median) via Weiszfeld's algorithm.

    The spatial median minimizes the sum of Euclidean distances:
        μ̂ = argmin_μ Σ_i ||x_i - μ||_2

    It is a robust location estimator with 50% breakdown point.

    The iteration uses the Vardi-Zhang modification of Weiszfeld's update.
    The unmodified update is undefined when the current estimate coincides
    with a data point (a division by zero distance). The modified update
    moves the estimate off such a point unless that point is itself the
    minimizer. This matters in practice because the componentwise-median
    starting value is often a data point.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    tol : float, default=1e-6
        Convergence tolerance on the Euclidean length of one update step.

    max_iter : int, default=100
        Maximum number of iterations.

    Returns
    -------
    median : ndarray of shape (n_features,)
        Spatial median (the most recent iterate).

    Raises
    ------
    ValueError
        If X is not two-dimensional or contains no samples.

    Notes
    -----
    The spatial median is √n-consistent (estimation error O_p(n^{-1/2}))
    under mild regularity conditions that require no moments. It therefore
    stays well-behaved for α-stable data with α < 2, where the sample mean
    has infinite variance (and, for α ≤ 1, no finite mean at all).

    References
    ----------
    [1] Weiszfeld, E. (1937). Sur le point pour lequel la somme des
        distances de n points donnés est minimum. Tohoku Mathematical
        Journal, 43:355-386.
    [2] Vardi, Y. and Zhang, C.-H. (2000). The multivariate L1-median and
        associated data depth. PNAS, 97(4):1423-1426.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError(
            f"X must be 2-D (n_samples, n_features); got shape {X.shape}"
        )
    if X.shape[0] == 0:
        raise ValueError("spatial_median requires at least one sample")

    # Initialize with componentwise median (robust starting point)
    median = np.median(X, axis=0)

    for _ in range(max_iter):
        new_median = _weiszfeld_step(X, median)
        step = np.linalg.norm(new_median - median)
        # Keep the newest iterate even when it triggers convergence.
        median = new_median
        if step < tol:
            break

    return median


def _weiszfeld_step(X, y, eps=1e-10):
    """Return one Vardi-Zhang modified Weiszfeld update of the estimate ``y``.

    Data points within ``eps`` of ``y`` are treated as coinciding with it.
    """
    diff = X - y
    dist = np.linalg.norm(diff, axis=1)
    far = dist > eps
    n_coincident = X.shape[0] - np.count_nonzero(far)

    # Every sample sits on y: y is trivially the minimizer.
    if not far.any():
        return y

    # Plain Weiszfeld step over the points distinct from y.
    w = 1.0 / dist[far]
    t_tilde = (w @ X[far]) / w.sum()
    if n_coincident == 0:
        return t_tilde

    # y is a data point. It is optimal iff the summed unit vectors pointing
    # to the other samples are no longer than the mass sitting at y.
    # Otherwise blend toward the Weiszfeld step in proportion to the excess.
    r = np.linalg.norm(w @ diff[far])
    if r <= n_coincident:
        return y
    gamma = n_coincident / r
    return (1.0 - gamma) * t_tilde + gamma * y


def tyler_m_estimator(X, location=None, tol=1e-6, max_iter=100):
    """
    Tyler's M-estimator of scatter.

    Computes the distribution-free M-estimator that solves the fixed-point
    equation:
        Σ = (p/n) Σ_i (x_i - μ)(x_i - μ)^T / ((x_i - μ)^T Σ^{-1} (x_i - μ))

    Each centered point enters only through its direction: the summand is
    unchanged if x_i - μ is rescaled. Consequently, for elliptical
    distributions the estimator consistently recovers the shape (scatter
    up to a scalar factor) regardless of tail index.

    Unlike the spatial median, Tyler's estimator does NOT have a 50%
    breakdown point. For affine-equivariant M-estimators of scatter the
    breakdown point shrinks with dimension (roughly 1/p; see [2]).

    WARNING: Tyler's estimator is defined only up to scale (normalized to
    have trace = p). This discards inter-class scale information that is
    discriminative for QDA. Use with caution when classes have different
    spreads.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    location : ndarray of shape (n_features,), optional
        Pre-computed location estimate. If None, uses spatial median.

    tol : float, default=1e-6
        Convergence tolerance (relative change in Frobenius norm).

    max_iter : int, default=100
        Maximum number of iterations.

    Returns
    -------
    scatter : ndarray of shape (n_features, n_features)
        Symmetric Tyler's M-estimate of scatter, normalized to
        trace = n_features (the most recent iterate).

    Raises
    ------
    numpy.linalg.LinAlgError
        If an iterate becomes singular, e.g. when there are fewer samples
        than features so the centered data cannot span R^p.

    References
    ----------
    [1] Tyler, D.E. (1987). A distribution-free M-estimator of multivariate
        scatter. The Annals of Statistics, 15(1):234-251.
    [2] Dümbgen, L. and Tyler, D.E. (2005). On the breakdown properties of
        some multivariate M-functionals. Scandinavian Journal of
        Statistics, 32(2):247-264.
    """
    X = np.asarray(X, dtype=float)
    n_samples, n_features = X.shape

    # Use spatial median if location not provided
    if location is None:
        location = spatial_median(X)

    X_centered = X - location

    # Initialize with identity
    scatter = np.eye(n_features)

    for _ in range(max_iter):
        # Squared Mahalanobis norms x_i^T Σ^{-1} x_i. Solving Σ Z = X^T is
        # numerically preferable to forming Σ^{-1} explicitly.
        solved = np.linalg.solve(scatter, X_centered.T)
        mahal_sq = np.sum(X_centered * solved.T, axis=1)

        # The floor avoids 0/0 for points lying exactly on the location
        # (they contribute nothing). It also caps the weight of points
        # extremely close to it.
        mahal_sq = np.maximum(mahal_sq, 1e-10)

        # Tyler update
        new_scatter = (n_features / n_samples) * (
            X_centered.T @ (X_centered / mahal_sq[:, np.newaxis])
        )
        # Remove floating-point asymmetry so every iterate is symmetric.
        new_scatter = 0.5 * (new_scatter + new_scatter.T)

        # Normalize to trace = n_features (Tyler's convention); this also
        # makes the (p/n) prefactor irrelevant.
        new_scatter *= n_features / np.trace(new_scatter)

        rel_change = (
            np.linalg.norm(new_scatter - scatter, "fro")
            / np.linalg.norm(scatter, "fro")
        )
        # Keep the newest iterate even when it triggers convergence.
        scatter = new_scatter
        if rel_change < tol:
            break

    return scatter


def ledoit_wolf_shrinkage(X):
    """
    Ledoit-Wolf shrinkage covariance estimator.

    Shrinks the empirical covariance S toward a multiple of the identity
    with the same average eigenvalue:
        Σ̂ = (1 - λ) S + λ (tr(S)/p) I

    The shrinkage intensity λ is the Ledoit-Wolf estimate of the value that
    minimizes the expected squared (Frobenius) error.

    Pulling the extreme eigenvalues of S toward their mean keeps the result
    well-conditioned. This gives some implicit protection against the
    extreme eigenvalues that heavy-tailed samples can produce.

    This is a thin wrapper around ``sklearn.covariance.LedoitWolf`` with its
    default settings. X is therefore centered by its sample mean, and S is
    the maximum-likelihood (divide-by-n) covariance rather than the ddof=1
    estimate.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Input data.

    Returns
    -------
    covariance : ndarray of shape (n_features, n_features)
        Shrinkage covariance estimate.

    Notes
    -----
    Key finding from paper: Ledoit-Wolf often outperforms Tyler's
    M-estimator for classification because it preserves inter-class
    scale differences that are discriminative.

    References
    ----------
    [1] Ledoit, O. and Wolf, M. (2004). A well-conditioned estimator for
        large-dimensional covariance matrices. Journal of Multivariate
        Analysis, 88(2):365-411.
    """
    # fit() returns the estimator itself, so the call chain is equivalent to
    # constructing, fitting, then reading covariance_.
    return LedoitWolf().fit(X).covariance_


def sample_covariance(X, ddof=1):
    """
    Standard sample covariance matrix.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data, one observation per row. A 1-D array is treated as a
        single feature.

    ddof : int, default=1
        Delta degrees of freedom. The sum of squared deviations is divided
        by ``n_samples - ddof``: 1 gives the unbiased estimate, 0 the
        maximum-likelihood estimate.

    Returns
    -------
    covariance : ndarray of shape (n_features, n_features)
        Sample covariance matrix. It is always 2-D, including the
        single-feature case where ``np.cov`` alone returns a 0-d array.
    """
    # np.cov returns a 0-d array for a single variable. Promote it so the
    # result always has the documented (n_features, n_features) shape, like
    # the matrices returned by the other covariance estimators.
    return np.atleast_2d(np.cov(X, rowvar=False, ddof=ddof))


def mad_scale(X, consistency_constant=1.4826):
    """
    Median Absolute Deviation (MAD) scale estimator.

    Computes a robust scale estimate for each feature (column):
        σ̂_j = c * median_i |x_ij - median_i(x_ij)|

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    consistency_constant : float, default=1.4826
        Multiplier c that turns the raw MAD into an estimate of a chosen
        scale parameter: c = 1 / median|Z| for Z drawn from the
        standardized target distribution.
        - 1.4826 ≈ 1/Φ⁻¹(3/4) makes σ̂ consistent for the standard
          deviation of Gaussian data.
        - For a Cauchy (α = 1 stable) distribution with scale γ, the MAD
          equals γ, so c = 1.0 recovers γ.

    Returns
    -------
    scales : ndarray of shape (n_features,)
        MAD scale estimates for each feature.

    Notes
    -----
    The MAD of a feature is zero when more than half of its values are
    identical. Downstream code that divides by these scales should guard
    against zeros.
    """
    X = np.asarray(X)
    abs_deviations = np.abs(X - np.median(X, axis=0))
    return consistency_constant * np.median(abs_deviations, axis=0)


def mad_spearman_covariance(X, location=None, consistency_constant=1.4826):
    """
    MAD-Spearman covariance estimator.

    Builds a covariance matrix from per-feature MAD scales and the Spearman
    rank-correlation matrix:
        Σ̂ = D R̂ D,    D = diag(σ̂_1, ..., σ̂_p)

    Here σ̂_j comes from ``mad_scale`` and R̂ is computed by
    ``scipy.stats.spearmanr`` with columns as variables. No
    Gaussian-consistency transform is applied to R̂. For bivariate normal
    data, Spearman's ρ_s and the Pearson correlation r are related by
    r = 2 sin(π ρ_s / 6), so R̂ is not the Gaussian correlation matrix.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Input data.

    location : ndarray of shape (n_features,), optional
        Ignored; accepted only so all estimators share one signature.

    consistency_constant : float, default=1.4826
        MAD consistency constant forwarded to ``mad_scale``.

    Returns
    -------
    covariance : ndarray of shape (n_features, n_features)
        MAD-Spearman covariance estimate.

    Notes
    -----
    From paper experiments (Appendix C): MAD-Spearman with Gaussian
    constant performs identically to α-corrected version because
    miscalibration affects both classes equally and cancels in
    likelihood ratios.

    A constant feature makes spearmanr return NaN correlations, which then
    propagate through the matrix products into the result.
    """
    from scipy.stats import spearmanr

    p = X.shape[1]
    sigma = mad_scale(X, consistency_constant)

    if p <= 1:
        # A single (or no) feature: the correlation matrix is just [[1]].
        rho = np.array([[1.0]])
    else:
        rho, _ = spearmanr(X)
        if p == 2:
            # With exactly two columns spearmanr returns one coefficient
            # rather than a matrix, so rebuild the 2x2 correlation matrix.
            rho = np.array([[1, rho], [rho, 1]])

    scale_matrix = np.diag(sigma)
    return scale_matrix @ rho @ scale_matrix


def compute_determinant_ratio(X, y):
    """
    Compute the determinant ratio between class covariances.

    This diagnostic helps decide between robust and standard estimators:
    - Large ratio (> 100): Standard estimators preferred
    - Small ratio (< 10): Robust estimators may be used

    Each class covariance is the ddof=1 sample covariance of that class's
    rows. Determinants are compared in the log domain. The ratio therefore
    stays accurate even when the determinants themselves would underflow
    to 0 or overflow to inf as floats.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,)
        Class labels.

    Returns
    -------
    det_ratio : float
        max(|Σ_k|) / min(|Σ_k|) over the classes whose covariance has a
        positive determinant. Classes with a singular, numerically
        non-positive-definite or undefined (NaN) covariance are skipped.
        The value is NaN if fewer than two classes qualify, and +inf if the
        ratio exceeds the floating-point range.

    trace_ratio : float
        max(tr(Σ_k)) / min(tr(Σ_k)) over all classes, with the denominator
        floored at 1e-300. It serves as a cheaper proxy for the scale
        disparity.
    """
    X = np.asarray(X)
    y = np.asarray(y)  # a plain list would make `y == label` a scalar

    logdets = []
    traces = []

    for label in np.unique(y):
        # np.cov returns a 0-d array for a single feature; promote to 2-D.
        cov_k = np.atleast_2d(np.cov(X[y == label], rowvar=False))

        sign, logdet = np.linalg.slogdet(cov_k)
        if sign > 0:
            logdets.append(logdet)
        traces.append(np.trace(cov_k))

    if len(logdets) >= 2:
        # exp(log|Σ_max| - log|Σ_min|) never forms the raw determinants.
        # A ratio beyond float range becomes inf without a warning.
        with np.errstate(over="ignore"):
            det_ratio = np.exp(max(logdets) - min(logdets))
    else:
        det_ratio = np.nan

    trace_ratio = max(traces) / max(min(traces), 1e-300)

    return det_ratio, trace_ratio
