"""
Tail index (α) estimation for α-stable distributions.

This module implements McCulloch's quantile-based estimator for the stability
index α, which controls tail heaviness in α-stable distributions.

Key Finding (from paper):
    Stable-QDA is remarkably insensitive to α misspecification. A fixed
    α = 1.5 performs within 1% of oracle across all tail regimes.
    Estimation is optional but provided for diagnostic purposes.
"""

import numpy as np


# McCulloch calibration table: quantile ratio ν (key) -> stability index α (value).
# Calibrated from 5×10^5 samples per alpha value.
# Monte Carlo error < 10^-3.
# Tabulated α runs from 2.0 down to 0.8; there is no entry for α = 0.9.
MCCULLOCH_TABLE = {
    2.44: 2.0,
    2.51: 1.9,
    2.62: 1.8,
    2.74: 1.7,
    2.90: 1.6,
    3.16: 1.5,
    3.47: 1.4,
    3.87: 1.3,
    4.47: 1.2,
    5.24: 1.1,
    6.30: 1.0,
    10.4: 0.8,
}

# Parallel arrays of the table's (ν, α) pairs ordered by ascending ν, ready to
# be used as the x- and y-coordinates of a piecewise-linear interpolation.
_NU_VALUES, _ALPHA_VALUES = (
    np.array(column) for column in zip(*sorted(MCCULLOCH_TABLE.items()))
)


def _quantile_ratio(x):
    """
    Compute McCulloch's quantile ratio.

    ν_α = (x_{0.95} - x_{0.05}) / (x_{0.75} - x_{0.25})

    This ratio increases monotonically as α decreases (heavier tails),
    ranging from ~2.44 at α=2 (Gaussian) to >10 at α<0.8. Because it is a
    ratio of quantile differences, it does not depend on the location or
    the scale of ``x``; the degenerate-input check below is therefore made
    relative to the data rather than against an absolute threshold.

    Parameters
    ----------
    x : array-like
        Sample values; all entries are pooled into one set of percentiles.

    Returns
    -------
    nu : float
        The quantile ratio. Returns 2.44 (the Gaussian value) when the
        interquartile range is zero or too small to be resolved in floating
        point relative to the quartiles. A non-finite interquartile range
        is not special-cased and falls through to the division.
    """
    q05, q25, q75, q95 = np.percentile(x, [5, 25, 75, 95])

    iqr = q75 - q25

    # The IQR is meaningless once it is within a few dozen ulps of the
    # quartiles it was computed from. Scaling the tolerance by the quartile
    # magnitude keeps the check valid for data of any scale; an absolute
    # cutoff would misreport every small-scale sample as Gaussian.
    resolution = 64 * np.finfo(np.float64).eps * max(abs(q25), abs(q75))

    # The isfinite guard lets NaN/inf inputs propagate through the division
    # instead of being masked by the Gaussian fallback.
    if np.isfinite(iqr) and iqr <= resolution:
        # Near-constant feature -> Gaussian-like
        return 2.44

    return (q95 - q05) / iqr


def _nu_to_alpha(nu):
    """
    Map quantile ratio to α via linear interpolation on the lookup table.

    Parameters
    ----------
    nu : float or array-like
        Quantile ratio(s) ν.

    Returns
    -------
    alpha : float or ndarray
        Interpolated α for each ν. Ratios at or below the smallest tabulated
        ν map to the α stored for that entry, and ratios at or above the
        largest tabulated ν map to the α stored for the last entry.
    """
    # np.interp clamps to the first/last y-value outside the x-range by
    # default, so the bounds always follow the table itself and cannot drift
    # from it if the table is recalibrated or extended.
    return np.interp(nu, _NU_VALUES, _ALPHA_VALUES)


class AlphaEstimator:
    """
    McCulloch's quantile-based estimator for the stability index α.

    For univariate data, computes the quantile ratio:
        ν = (x_{0.95} - x_{0.05}) / (x_{0.75} - x_{0.25})

    and maps to α via a precomputed lookup table.

    For multivariate data, α is estimated separately for each feature and
    the per-feature estimates are combined according to ``aggregation``
    (the default, the median, is robust to features with atypical
    behavior). Features whose standard deviation is below 1e-10 are treated
    as constant: they are recorded with ν = 2.44 and α = 2.0 and are left
    out of the aggregate. If every feature is constant, ``alpha_`` is 2.0.

    Parameters
    ----------
    aggregation : {'median', 'mean', 'min'}, default='median'
        How to aggregate per-feature α estimates for multivariate data.
        The value is validated when ``fit`` is called.

    Attributes
    ----------
    alpha_ : float
        Estimated stability index.

    alpha_per_feature_ : ndarray of shape (n_features,)
        Per-feature α estimates.

    nu_per_feature_ : ndarray of shape (n_features,)
        Per-feature quantile ratios.

    Examples
    --------
    >>> import numpy as np
    >>> from stable_qda.alpha_estimation import AlphaEstimator
    >>> # Generate heavy-tailed data (Cauchy, i.e. α = 1)
    >>> X = np.random.standard_cauchy(size=(1000, 5))
    >>> est = AlphaEstimator()
    >>> est = est.fit(X)
    >>> est.alpha_per_feature_.shape
    (5,)
    >>> # est.alpha_ varies with the random draw.

    References
    ----------
    [1] McCulloch, J.H. (1986). Simple consistent estimators of stable
        distribution parameters. Communications in Statistics - Simulation
        and Computation, 15(4):1109-1136.
    """

    def __init__(self, aggregation='median'):
        self.aggregation = aggregation

    def _aggregation_func(self):
        """Return the NumPy reducer named by ``aggregation``; raise ValueError if unknown."""
        known = (('median', np.median), ('mean', np.mean), ('min', np.min))
        for name, reducer in known:
            if self.aggregation == name:
                return reducer
        raise ValueError(f"Unknown aggregation: {self.aggregation}")

    def fit(self, X, y=None):
        """
        Estimate α from data.

        Parameters
        ----------
        X : array-like of shape (n_samples,) or (n_samples, n_features)
            Input data. A 1-D input is treated as a single feature.

        y : ignored
            Present for sklearn compatibility.

        Returns
        -------
        self : object
            Fitted estimator.

        Raises
        ------
        ValueError
            If ``aggregation`` is not one of the supported names, if ``X``
            is neither 1-D nor 2-D, or if ``X`` contains no samples.
        """
        # Resolve the aggregation first so a misconfigured estimator fails
        # immediately and identically for every input, including data whose
        # features are all constant.
        aggregate = self._aggregation_func()

        X = np.asarray(X)

        if X.ndim == 1:
            X = X.reshape(-1, 1)
        elif X.ndim != 2:
            raise ValueError(
                f"X must be 1-D or 2-D, got an array with {X.ndim} dimensions"
            )

        if X.shape[0] == 0:
            # Percentiles of an empty sample are undefined.
            raise ValueError("X must contain at least one sample")

        n_features = X.shape[1]

        self.nu_per_feature_ = np.zeros(n_features)
        self.alpha_per_feature_ = np.zeros(n_features)

        valid_features = []

        for j in range(n_features):
            x_j = X[:, j]

            # Check for constant feature.
            # NOTE(review): this threshold is absolute, so a feature whose
            # natural scale is below ~1e-10 is also treated as constant.
            if np.std(x_j) < 1e-10:
                self.nu_per_feature_[j] = 2.44
                self.alpha_per_feature_[j] = 2.0
            else:
                nu_j = _quantile_ratio(x_j)
                self.nu_per_feature_[j] = nu_j
                self.alpha_per_feature_[j] = _nu_to_alpha(nu_j)
                valid_features.append(j)

        # Aggregate across the non-constant features only.
        if valid_features:
            self.alpha_ = aggregate(self.alpha_per_feature_[valid_features])
        else:
            # Nothing informative to aggregate: fall back to the Gaussian value.
            self.alpha_ = 2.0

        return self

    def fit_transform(self, X, y=None):
        """Fit and return estimated α (the scalar ``alpha_``, not a transformed ``X``)."""
        self.fit(X, y)
        return self.alpha_


def estimate_alpha(X, y=None, aggregation='median'):
    """
    Estimate α from data in a single call.

    Thin wrapper around ``AlphaEstimator`` for callers that only need the
    final scalar.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,), optional
        Class labels. When given, α is estimated separately on the rows of
        each distinct label and the unweighted mean of those per-class
        estimates is returned. When omitted, all rows are pooled.

    aggregation : {'median', 'mean', 'min'}, default='median'
        How to aggregate per-feature estimates; forwarded to
        ``AlphaEstimator``.

    Returns
    -------
    alpha : float
        Estimated stability index.

    Examples
    --------
    >>> from stable_qda.alpha_estimation import estimate_alpha
    >>> import numpy as np
    >>> X = np.random.randn(1000, 10)  # Gaussian data
    >>> alpha = estimate_alpha(X)  # expected to be close to 2.0
    """
    data = np.asarray(X)

    def _alpha_of(samples):
        # One fresh estimator per call; fit() returns the estimator itself.
        return AlphaEstimator(aggregation=aggregation).fit(samples).alpha_

    if y is None:
        return _alpha_of(data)

    # One estimate per distinct label, then a plain average over classes.
    labels = np.asarray(y)
    per_class = [_alpha_of(data[labels == label]) for label in np.unique(labels)]
    return np.mean(per_class)


def estimate_alpha_per_class(X, y, aggregation='median'):
    """
    Estimate α independently on the samples of each class.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,)
        Class labels.

    aggregation : {'median', 'mean', 'min'}, default='median'
        How to aggregate per-feature estimates within each class; forwarded
        to ``AlphaEstimator``.

    Returns
    -------
    alphas : dict
        Maps each distinct label (as returned by ``np.unique``) to the α
        estimated from the rows of ``X`` carrying that label.

    Notes
    -----
    From paper experiments (Appendix G): Per-class α estimation does NOT
    improve classification. Even when classes have different tail indices,
    a single shared α=1.5 performs better because different α values make
    likelihoods incomparable.
    """
    data = np.asarray(X)
    labels = np.asarray(y)

    # A fresh estimator is fitted on each class's rows; fit() returns the
    # estimator, so its alpha_ can be read off directly.
    return {
        label: AlphaEstimator(aggregation=aggregation).fit(data[labels == label]).alpha_
        for label in np.unique(labels)
    }
