"""
Stable-QDA: Quadratic Discriminant Analysis with α-stable likelihoods.

This module implements the main StableQDA classifier, which replaces the
Gaussian likelihood in classical QDA with a symmetric α-stable likelihood
that decays polynomially rather than exponentially in Mahalanobis distance.
"""

import numpy as np
from scipy.special import softmax
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.preprocessing import LabelEncoder

from .estimators import (
    spatial_median,
    tyler_m_estimator,
    ledoit_wolf_shrinkage,
)
from .alpha_estimation import estimate_alpha


class StableQDA(BaseEstimator, ClassifierMixin):
    """
    Quadratic Discriminant Analysis with α-stable likelihoods.
    
    Stable-QDA replaces the Gaussian likelihood in classical QDA with a
    symmetric α-stable likelihood. The key difference is in how the
    Mahalanobis distance enters the likelihood:
    
    - Gaussian: log f(x) ∝ -D(x)/2  (exponential decay)
    - Stable:   log f(x) ∝ -((α+p)/2) * log(1 + D(x))  (polynomial decay)
    
    This prevents over-penalization of tail observations that causes
    Gaussian QDA to misclassify heavy-tailed data.
    
    Parameters
    ----------
    alpha : float or 'auto', default=1.5
        Stability index controlling tail heaviness. Must lie in (0, 2]:
        - alpha=2: Gaussian (exponential tails) for the stable law itself
        - alpha=1: Cauchy (very heavy tails)
        - alpha='auto': Estimate from data using McCulloch's method
        Recommended: Use fixed alpha=1.5 (robust across settings).
        Note that the discriminant computed by this class keeps the
        polynomial-decay form for every alpha, including alpha=2; use
        GaussianQDA for the classical exponential-decay discriminant.
        
    estimator : {'standard', 'robust'}, default='standard'
        Parameter estimation strategy:
        - 'standard': Sample mean + Ledoit-Wolf covariance (recommended)
        - 'robust': Spatial median + Tyler's M-estimator
        Any other value raises ValueError in ``fit``.
        
    reg_param : float, default=1e-6
        Regularization added to covariance diagonal for numerical stability.
        
    store_covariance : bool, default=False
        If True, store the class covariance matrices in `covariance_`.
        
    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,)
        Unique class labels.
        
    priors_ : ndarray of shape (n_classes,)
        Class prior probabilities (empirical class frequencies).
        
    means_ : ndarray of shape (n_classes, n_features)
        Class means (or spatial medians if estimator='robust').
        
    covariance_ : list of ndarray, each of shape (n_features, n_features)
        Regularized class covariance matrices. Only present if
        store_covariance=True.
        
    covariance_inv_ : list of ndarray
        Inverse covariance matrices for each class.
        
    log_det_ : ndarray of shape (n_classes,)
        Log-determinants of class covariance matrices.
        
    alpha_ : float
        The stability index used (estimated if alpha='auto').
        
    n_features_in_ : int
        Number of features seen during fit.
        
    Examples
    --------
    >>> import numpy as np
    >>> from stable_qda import StableQDA
    >>> rng = np.random.default_rng(0)
    >>> X = rng.standard_normal((100, 5))
    >>> y = np.array([0]*50 + [1]*50)
    >>> clf = StableQDA(alpha=1.5).fit(X, y)
    >>> clf.predict(X[:5]).shape
    (5,)
    
    References
    ----------
    [1] Samorodnitsky, G. and Taqqu, M.S. (1994). Stable Non-Gaussian
        Random Processes. Chapman and Hall.
    [2] Ledoit, O. and Wolf, M. (2004). A well-conditioned estimator for
        large-dimensional covariance matrices. Journal of Multivariate
        Analysis, 88(2):365-411.
    """
    
    def __init__(
        self,
        alpha=1.5,
        estimator='standard',
        reg_param=1e-6,
        store_covariance=False,
    ):
        self.alpha = alpha
        self.estimator = estimator
        self.reg_param = reg_param
        self.store_covariance = store_covariance
    
    def fit(self, X, y):
        """
        Fit the Stable-QDA model.
        
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
            
        y : array-like of shape (n_samples,)
            Target values.
            
        Returns
        -------
        self : object
            Fitted estimator.
        
        Raises
        ------
        ValueError
            If ``estimator`` is not 'standard' or 'robust', or if ``alpha``
            is neither 'auto' nor a number in (0, 2].
        numpy.linalg.LinAlgError
            If a regularized class covariance has a non-positive
            determinant or cannot be inverted.
        """
        # Validate before any data work so a typo fails fast instead of
        # silently selecting the 'standard' branch.
        if self.estimator not in ('standard', 'robust'):
            raise ValueError(
                "estimator must be 'standard' or 'robust', "
                f"got {self.estimator!r}."
            )
        
        X, y = check_X_y(X, y)
        
        self.n_features_in_ = X.shape[1]
        self._label_encoder = LabelEncoder()
        y_encoded = self._label_encoder.fit_transform(y)
        self.classes_ = self._label_encoder.classes_
        n_classes = len(self.classes_)
        n_features = X.shape[1]
        
        self.alpha_ = self._resolve_alpha(X, y_encoded)
        
        # Initialize storage
        self.means_ = np.zeros((n_classes, n_features))
        self.covariance_inv_ = []
        self.log_det_ = np.zeros(n_classes)
        self.priors_ = np.zeros(n_classes)
        
        if self.store_covariance:
            self.covariance_ = []
        
        ridge = self.reg_param * np.eye(n_features)
        
        # Fit class-specific parameters
        for k in range(n_classes):
            X_k = X[y_encoded == k]
            self.priors_[k] = X_k.shape[0] / X.shape[0]
            
            # Location and dispersion estimation
            if self.estimator == 'robust':
                self.means_[k] = spatial_median(X_k)
                cov_k = tyler_m_estimator(X_k, self.means_[k])
            else:
                self.means_[k] = np.mean(X_k, axis=0)
                cov_k = ledoit_wolf_shrinkage(X_k)
            
            # Regularize into a fresh float array: an in-place ``+=`` would
            # mutate whatever array the estimator helper returned.
            cov_k = np.asarray(cov_k, dtype=float) + ridge
            
            if self.store_covariance:
                self.covariance_.append(cov_k)
            
            # A non-positive determinant means the matrix is not positive
            # definite, so its log-determinant would be meaningless.
            sign, logdet = np.linalg.slogdet(cov_k)
            if sign <= 0:
                raise np.linalg.LinAlgError(
                    f"Regularized covariance for class {self.classes_[k]!r} "
                    "is singular or not positive definite; consider "
                    "increasing reg_param."
                )
            self.covariance_inv_.append(np.linalg.inv(cov_k))
            self.log_det_[k] = logdet
        
        return self
    
    def _resolve_alpha(self, X, y_encoded):
        """
        Return the stability index to use for this fit.
        
        'auto' delegates to ``estimate_alpha``; any other string, and any
        number outside the documented range (0, 2], raises ValueError.
        """
        if isinstance(self.alpha, str):
            if self.alpha != 'auto':
                raise ValueError(
                    f"alpha must be 'auto' or a float in (0, 2], got {self.alpha!r}."
                )
            return estimate_alpha(X, y_encoded)
        alpha = float(self.alpha)
        # Written as a negated range test so NaN is also rejected.
        if not 0.0 < alpha <= 2.0:
            raise ValueError(
                f"alpha must be 'auto' or a float in (0, 2], got {self.alpha!r}."
            )
        return alpha
    
    def _validate_X(self, X):
        """
        Check the model is fitted and X matches the training feature count.
        
        Returns the validated ndarray produced by ``check_array``.
        """
        check_is_fitted(self)
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but {type(self).__name__} "
                f"is expecting {self.n_features_in_} features as input."
            )
        return X
    
    def _compute_discriminant(self, X):
        """
        Compute discriminant scores for each class.
        
        The stable discriminant is:
            δ_k(x) = log(π_k) - ((α+p)/2) * log(1 + D_k(x)) - (1/2) * log|Σ_k|
        
        where D_k is the squared Mahalanobis distance to class k and p is
        the number of features seen during fit.
        
        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Input samples (already validated).
            
        Returns
        -------
        scores : ndarray of shape (n_samples, n_classes)
            Discriminant scores for each sample and class.
        """
        n_classes = len(self.classes_)
        exponent = (self.alpha_ + self.n_features_in_) / 2
        log_priors = np.log(self.priors_)
        
        scores = np.zeros((X.shape[0], n_classes))
        
        for k in range(n_classes):
            # Squared Mahalanobis distance, row-wise: diag(diff Σ⁻¹ diffᵀ)
            diff = X - self.means_[k]
            mahal_sq = np.sum(diff @ self.covariance_inv_[k] * diff, axis=1)
            
            # Stable log-likelihood (up to constants)
            log_likelihood = (
                -exponent * np.log1p(mahal_sq)
                - 0.5 * self.log_det_[k]
            )
            
            scores[:, k] = log_priors[k] + log_likelihood
        
        return scores
    
    def predict(self, X):
        """
        Predict class labels for samples in X.
        
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
            
        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted class labels.
        """
        X = self._validate_X(X)
        scores = self._compute_discriminant(X)
        return self.classes_[np.argmax(scores, axis=1)]
    
    def predict_proba(self, X):
        """
        Predict class probabilities for samples in X.
        
        Probabilities are obtained via softmax normalization of
        discriminant scores.
        
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
            
        Returns
        -------
        proba : ndarray of shape (n_samples, n_classes)
            Class probabilities for each sample.
        """
        X = self._validate_X(X)
        scores = self._compute_discriminant(X)
        return softmax(scores, axis=1)
    
    def predict_log_proba(self, X):
        """
        Predict log class probabilities for samples in X.
        
        Computed as a numerically stable log-softmax of the discriminant
        scores, so very small probabilities keep their exact log value
        instead of being clamped by an additive epsilon.
        
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
            
        Returns
        -------
        log_proba : ndarray of shape (n_samples, n_classes)
            Log class probabilities for each sample.
        """
        X = self._validate_X(X)
        scores = self._compute_discriminant(X)
        # Subtract the row max before exponentiating to avoid overflow.
        shifted = scores - scores.max(axis=1, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    
    def decision_function(self, X):
        """
        Return discriminant scores for samples in X.
        
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples.
            
        Returns
        -------
        scores : ndarray of shape (n_samples, n_classes)
            Discriminant scores.
        """
        X = self._validate_X(X)
        return self._compute_discriminant(X)
    
    def score(self, X, y, sample_weight=None):
        """
        Return the (optionally weighted) mean accuracy on the given data.
        
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
            
        y : array-like of shape (n_samples,)
            True labels.
        
        sample_weight : array-like of shape (n_samples,), default=None
            Per-sample weights, matching ``ClassifierMixin.score``. When
            None, every sample counts equally.
            
        Returns
        -------
        score : float
            Mean accuracy.
        """
        correct = self.predict(X) == np.asarray(y)
        return np.average(correct, weights=sample_weight)


class GaussianQDA(StableQDA):
    """
    Classical Gaussian QDA (for comparison).
    
    Reuses StableQDA's fitting procedure (priors, locations, regularized
    covariances, their inverses and log-determinants) with the stability
    index fixed at 2.0. It then scores samples with the standard Gaussian
    discriminant::
    
        δ_k(x) = log(π_k) - D_k(x)/2 - (1/2) * log|Σ_k|
    
    where D_k is the squared Mahalanobis distance to class k.
    
    This is *not* the same model as ``StableQDA(alpha=2)``. StableQDA's
    discriminant uses the polynomial-decay term ``-((α+p)/2) * log(1 + D)``
    at every alpha. This class uses the exponential-decay term ``-D/2``.
    
    Parameters
    ----------
    estimator : {'standard', 'robust'}, default='standard'
        Parameter estimation strategy.
        
    reg_param : float, default=1e-6
        Regularization for numerical stability.
    
    store_covariance : bool, default=False
        If True, store the class covariance matrices in `covariance_`.
    """
    
    def __init__(self, estimator='standard', reg_param=1e-6, store_covariance=False):
        super().__init__(
            alpha=2.0,
            estimator=estimator,
            reg_param=reg_param,
            store_covariance=store_covariance,
        )
    
    def _compute_discriminant(self, X):
        """
        Compute the Gaussian discriminant (linear in squared Mahalanobis distance).
        
        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Input samples (already validated).
        
        Returns
        -------
        scores : ndarray of shape (n_samples, n_classes)
            Discriminant scores for each sample and class.
        """
        n_classes = len(self.classes_)
        log_priors = np.log(self.priors_)
        
        scores = np.zeros((X.shape[0], n_classes))
        
        for k in range(n_classes):
            diff = X - self.means_[k]
            mahal_sq = np.sum(diff @ self.covariance_inv_[k] * diff, axis=1)
            
            # Gaussian log-likelihood: -D/2 - (1/2)*log|Σ|
            log_likelihood = -0.5 * mahal_sq - 0.5 * self.log_det_[k]
            scores[:, k] = log_priors[k] + log_likelihood
        
        return scores
