"""
Dataset diagnostics for Stable-QDA estimator selection.

This module implements the data-driven diagnostic procedure described in
Section 4.5 and Appendix B of the paper. Given a dataset, it recommends
whether to use:
    - Gaussian QDA (light tails)
    - Stable-QDA with standard estimators (moderate tails + heteroscedasticity)
    - Stable-QDA with robust estimators (heavy tails + homoscedasticity)

Key Insight (from paper):
    Tyler's M-estimator normalizes scatter matrices to fixed trace, discarding
    scale information. When classes have different spreads (high determinant
    ratio), this hurts classification. The diagnostic procedure quantifies
    this trade-off.
"""

import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

from .alpha_estimation import estimate_alpha, estimate_alpha_per_class
from .estimators import compute_determinant_ratio


# Tyler threshold table from paper (Table 3 / Appendix D).
# Each row is (exclusive upper bound on the determinant ratio, α threshold).
# Rows are ordered by increasing bound; `_get_tyler_threshold` returns the
# α threshold of the first row whose bound strictly exceeds the observed ratio.
TYLER_THRESHOLDS = [
    (10, 2.0),      # ratio < 10
    (50, 1.9),      # 10 <= ratio < 50
    (100, 1.8),     # 50 <= ratio < 100
    (1000, 1.7),    # 100 <= ratio < 1000
    (np.inf, 1.6),  # any larger finite ratio
]

# Same layout, keyed on the trace ratio. Used as a proxy when the
# determinant ratio is unstable (NaN / non-positive / non-finite).
TYLER_THRESHOLDS_TRACE = [
    (1.25, 2.0),    # ratio < 1.25
    (1.5, 1.8),     # 1.25 <= ratio < 1.5
    (2.0, 1.7),     # 1.5 <= ratio < 2.0
    (np.inf, 1.6),  # any larger ratio
]


@dataclass
class DiagnosticResult:
    """
    Container for everything `DatasetDiagnostics.fit` computes.

    The object carries both the final recommendation and the intermediate
    quantities that led to it, so callers can inspect or log the evidence.

    Attributes
    ----------
    recommendation : str
        One of 'gaussian', 'stable_standard', 'stable_robust'.

    alpha_estimate : float
        Estimated tail index for the whole dataset.

    alpha_per_class : dict
        Per-class α estimates, keyed by class label.

    determinant_ratio : float
        max|Σ_k| / min|Σ_k| across classes. Consumers in this module guard
        it with ``np.isnan`` before use, so NaN should be expected when the
        determinant is numerically unstable.

    trace_ratio : float
        Trace ratio, used as a proxy for the determinant ratio.

    tyler_threshold : float
        The α threshold below which Tyler (robust) estimation is recommended,
        as looked up from the determinant- or trace-ratio table.

    heavy_tail_signals : dict
        Heavy-tail indicators keyed by 'low_alpha', 'high_outlier_rate'
        and 'mean_median_shift'; each value is a boolean flag.

    likely_heavy_tailed : bool
        True if ≥2 of the 3 heavy-tail signals are present.

    reasoning : str
        Human-readable explanation of the recommendation; the individual
        reasons are joined with ' | '.
    """
    recommendation: str
    alpha_estimate: float
    alpha_per_class: Dict[Any, float]
    determinant_ratio: float
    trace_ratio: float
    tyler_threshold: float
    heavy_tail_signals: Dict[str, bool]
    likely_heavy_tailed: bool
    reasoning: str


class DatasetDiagnostics:
    """
    Diagnostic tool for selecting Stable-QDA estimator configuration.

    Analyzes a dataset to determine:
    1. Tail heaviness (via α estimation)
    2. Class heteroscedasticity (via determinant/trace ratios)
    3. Presence of heavy-tail signals (low α, outlier rate, mean-median shift)

    Then recommends one of 'gaussian', 'stable_standard' or 'stable_robust'
    based on the threshold tables ``TYLER_THRESHOLDS`` and
    ``TYLER_THRESHOLDS_TRACE``.

    Parameters
    ----------
    outlier_threshold : float, default=0.05
        Expected outlier rate under Gaussianity. It defines the χ² cutoff
        used to count outliers (the ``1 - outlier_threshold`` quantile), and
        an observed rate > 1.4× this value is a heavy-tail signal. Must lie
        strictly between 0 and 1.

    mean_median_threshold : float, default=0.2
        Relative mean-median shift threshold for heavy-tail signal.

    Examples
    --------
    >>> from stable_qda.diagnostics import DatasetDiagnostics
    >>> import numpy as np
    >>> X = np.random.standard_t(df=3, size=(1000, 10))  # Heavy-tailed
    >>> y = np.array([0]*500 + [1]*500)
    >>> result = DatasetDiagnostics().fit(X, y)
    >>> result.recommendation in ('gaussian', 'stable_standard', 'stable_robust')
    True
    >>> print(result.reasoning)  # doctest: +SKIP
    """

    def __init__(self, outlier_threshold=0.05, mean_median_threshold=0.2):
        self.outlier_threshold = outlier_threshold
        self.mean_median_threshold = mean_median_threshold

    def fit(self, X, y) -> "DiagnosticResult":
        """
        Analyze dataset and produce recommendation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data.

        y : array-like of shape (n_samples,)
            Class labels.

        Returns
        -------
        result : DiagnosticResult
            Diagnostic results with recommendation.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, if ``y`` is not 1-D with one label per row
            of ``X``, or if ``outlier_threshold`` is not in (0, 1).
        """
        X = np.asarray(X)
        y = np.asarray(y)

        # Fail early with a clear message instead of an opaque unpacking or
        # boolean-indexing error further down.
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-D (n_samples, n_features); got shape {X.shape}"
            )
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError(
                f"y must be 1-D with {X.shape[0]} labels; got shape {y.shape}"
            )

        # 1. Estimate tail index
        alpha_estimate = estimate_alpha(X, y)
        alpha_per_class = estimate_alpha_per_class(X, y)

        # 2. Compute scale ratios
        det_ratio, trace_ratio = compute_determinant_ratio(X, y)

        # 3. Check heavy-tail signals. The α estimate is passed in so it is
        #    computed once and the 'low_alpha' flag always agrees with the
        #    reported alpha_estimate.
        heavy_tail_signals = self._check_heavy_tail_signals(
            X, y, alpha=alpha_estimate
        )
        likely_heavy_tailed = sum(heavy_tail_signals.values()) >= 2

        # 4. Determine Tyler threshold
        tyler_threshold = self._get_tyler_threshold(det_ratio, trace_ratio)

        # 5. Make recommendation
        recommendation, reasoning = self._make_recommendation(
            alpha_estimate,
            det_ratio,
            trace_ratio,
            tyler_threshold,
            likely_heavy_tailed,
            heavy_tail_signals,
        )

        return DiagnosticResult(
            recommendation=recommendation,
            alpha_estimate=alpha_estimate,
            alpha_per_class=alpha_per_class,
            determinant_ratio=det_ratio,
            trace_ratio=trace_ratio,
            tyler_threshold=tyler_threshold,
            heavy_tail_signals=heavy_tail_signals,
            likely_heavy_tailed=likely_heavy_tailed,
            reasoning=reasoning,
        )

    def _check_heavy_tail_signals(
        self, X, y, alpha: Optional[float] = None
    ) -> Dict[str, bool]:
        """
        Check three indicators of heavy-tailed behavior.

        1. α < 1.8
        2. Outlier rate > 1.4 × ``outlier_threshold``
           (7% with the default expected rate of 5%)
        3. Mean-median shift > ``mean_median_threshold``

        Parameters
        ----------
        X, y : array-like
            Data and labels, as passed to ``fit``.
        alpha : float, optional
            Pre-computed tail index. When omitted it is estimated here via
            ``estimate_alpha(X, y)``.

        Returns
        -------
        signals : dict
            Plain ``bool`` flags keyed by 'low_alpha', 'high_outlier_rate'
            and 'mean_median_shift'.
        """
        if alpha is None:
            alpha = estimate_alpha(X, y)

        outlier_rate = self._compute_outlier_rate(X, y)
        mean_median_shift = self._compute_mean_median_shift(X)

        # bool(...) strips numpy scalar types so the flags match the
        # declared Dict[str, bool] type.
        return {
            'low_alpha': bool(alpha < 1.8),
            'high_outlier_rate': bool(
                outlier_rate > 1.4 * self.outlier_threshold
            ),
            'mean_median_shift': bool(
                mean_median_shift > self.mean_median_threshold
            ),
        }

    def _compute_outlier_rate(self, X, y) -> float:
        """
        Compute the fraction of points whose squared Mahalanobis distance
        (to their own class mean, under the class sample covariance) exceeds
        the ``1 - outlier_threshold`` quantile of the χ²_p distribution.

        Classes with fewer than ``n_features + 1`` samples, or whose
        regularized covariance cannot be inverted, are skipped. If no class
        is usable, the expected rate ``outlier_threshold`` is returned so
        that the outlier signal does not fire.

        Raises
        ------
        ValueError
            If ``outlier_threshold`` is not strictly between 0 and 1.
        """
        from scipy.stats import chi2

        expected_rate = self.outlier_threshold
        if not 0.0 < expected_rate < 1.0:
            raise ValueError(
                f"outlier_threshold must be in (0, 1); got {expected_rate!r}"
            )

        X = np.asarray(X)
        y = np.asarray(y)

        n_features = X.shape[1]
        # Tie the cutoff to the configured expected rate; with the default
        # 0.05 this is the 95th percentile.
        cutoff = chi2.ppf(1.0 - expected_rate, df=n_features)
        ridge = 1e-6 * np.eye(n_features)

        total_outliers = 0
        total_samples = 0

        for k in np.unique(y):
            X_k = X[y == k]
            n_k = X_k.shape[0]

            # Too few points for a meaningful covariance estimate.
            if n_k < n_features + 1:
                continue

            mean_k = np.mean(X_k, axis=0)
            # np.cov collapses to a 0-d array for a single feature;
            # atleast_2d keeps the (p, p) shape so the ridge can be added.
            cov_k = np.atleast_2d(np.cov(X_k, rowvar=False)) + ridge

            try:
                cov_inv = np.linalg.inv(cov_k)
            except np.linalg.LinAlgError:
                continue

            diff = X_k - mean_k
            mahal_sq = np.sum((diff @ cov_inv) * diff, axis=1)

            total_outliers += int(np.count_nonzero(mahal_sq > cutoff))
            total_samples += n_k

        if total_samples == 0:
            return float(expected_rate)  # Default to expected rate

        return total_outliers / total_samples

    def _compute_mean_median_shift(self, X) -> float:
        """
        Compute relative norm difference between mean and median.

        Returns ``||mean - median|| / (||mean|| + ||median|| + 1e-10)``, with
        mean and median taken per feature over all rows of ``X``. The small
        constant guards against division by zero.
        """
        X = np.asarray(X)
        mean = np.mean(X, axis=0)
        median = np.median(X, axis=0)

        denom = np.linalg.norm(mean) + np.linalg.norm(median) + 1e-10
        return float(np.linalg.norm(mean - median) / denom)

    def _get_tyler_threshold(self, det_ratio, trace_ratio) -> float:
        """
        Look up Tyler threshold from empirical tables.

        The determinant-ratio table is used when ``det_ratio`` is a finite
        positive number. Otherwise (NaN, infinite or non-positive) the
        trace-ratio table is used. If the trace ratio matches no row
        (e.g. it is NaN), the most conservative threshold 1.6 is returned.
        """
        # Try determinant ratio first
        if np.isfinite(det_ratio) and det_ratio > 0:
            for upper_bound, alpha_thresh in TYLER_THRESHOLDS:
                if det_ratio < upper_bound:
                    return alpha_thresh

        # Fall back to trace ratio
        for upper_bound, alpha_thresh in TYLER_THRESHOLDS_TRACE:
            if trace_ratio < upper_bound:
                return alpha_thresh

        return 1.6  # Default conservative threshold

    def _make_recommendation(
        self,
        alpha: float,
        det_ratio: float,
        trace_ratio: float,
        tyler_threshold: float,
        likely_heavy_tailed: bool,
        signals: Dict[str, bool],
    ) -> Tuple[str, str]:
        """
        Apply decision rules from Appendix B.

        Rules are evaluated in order and the first match wins.

        Parameters
        ----------
        alpha : float
            Estimated tail index.
        det_ratio, trace_ratio : float
            Scale ratios; only used to word the explanation.
        tyler_threshold : float
            α threshold below which robust estimators are recommended.
        likely_heavy_tailed : bool
            Whether at least two heavy-tail signals fired.
        signals : dict
            Individual signal flags. Accepted for interface stability; the
            rules below do not read it.

        Returns
        -------
        (recommendation, reasoning) : tuple of str
            The recommendation label and its reasons joined with ' | '.
        """
        reasons = []

        # Rule 1: Light tails + no heavy-tail signals -> Gaussian
        if alpha > 1.8 and not likely_heavy_tailed:
            reasons.append(f"α={alpha:.2f} > 1.8 indicates light tails")
            reasons.append("No strong heavy-tail signals detected")
            reasons.append("Stable likelihood provides minimal benefit")
            return 'gaussian', ' | '.join(reasons)

        # Rule 2: α below the Tyler threshold -> Robust
        if alpha < tyler_threshold:
            reasons.append(f"α={alpha:.2f} < {tyler_threshold:.1f} (Tyler threshold)")
            reasons.append("Heavy tails warrant robust estimators")
            return 'stable_robust', ' | '.join(reasons)

        # Rule 3: Very heavy tails (α < 1.5) -> Robust, even when α is not
        # below the Tyler threshold. Every threshold produced by
        # _get_tyler_threshold is >= 1.6, so Rule 2 already covers this case
        # when called through fit(); the branch only matters for callers
        # that pass a custom threshold below 1.5.
        if alpha < 1.5:
            reasons.append(f"α={alpha:.2f} < 1.5 indicates heavy tails")
            reasons.append("But robust estimators recommended due to very heavy tails")
            return 'stable_robust', ' | '.join(reasons)

        # Rule 4: Moderate tails + heavy-tail signals -> Standard (preserve scale)
        if likely_heavy_tailed:
            reasons.append(f"α={alpha:.2f} indicates moderate tails")
            if not np.isnan(det_ratio):
                reasons.append(f"Determinant ratio={det_ratio:.1f} suggests heteroscedasticity")
            else:
                reasons.append(f"Trace ratio={trace_ratio:.2f}")
            reasons.append("Standard estimators preserve discriminative scale information")
            return 'stable_standard', ' | '.join(reasons)

        # Rule 5: Default to Gaussian
        reasons.append(f"α={alpha:.2f} near Gaussian")
        reasons.append("No strong evidence for heavy tails")
        return 'gaussian', ' | '.join(reasons)


def diagnose_dataset(X, y, verbose=True) -> DiagnosticResult:
    """
    Run `DatasetDiagnostics` with default settings and optionally print a report.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,)
        Class labels.

    verbose : bool, default=True
        If True, print a human-readable diagnostic summary to stdout.

    Returns
    -------
    result : DiagnosticResult
        Diagnostic results with recommendation (returned whether or not the
        summary is printed).

    Examples
    --------
    The numbers below are illustrative only; actual values depend on the
    data and on the α / scatter estimators.

    >>> from stable_qda.diagnostics import diagnose_dataset
    >>> from sklearn.datasets import load_iris
    >>> X, y = load_iris(return_X_y=True)
    >>> result = diagnose_dataset(X, y)  # doctest: +SKIP

    === Stable-QDA Dataset Diagnostics ===

    Tail Index Estimation:
      Overall α: 1.92
      Per-class: {0: 1.89, 1: 1.94, 2: 1.93}

    Scale Analysis:
      Determinant ratio: 15.3
      Trace ratio: 1.80
      Tyler threshold: α < 1.9

    Heavy-Tail Signals:
      Low α (< 1.8): False
      High outlier rate: False
      Mean-median shift: False
      Likely heavy-tailed: False

    RECOMMENDATION: gaussian
    Reasoning: α=1.92 > 1.8 indicates light tails | No strong heavy-tail signals detected | Stable likelihood provides minimal benefit
    """
    result = DatasetDiagnostics().fit(X, y)

    # Quiet mode: hand back the result without touching stdout.
    if not verbose:
        return result

    print("\n=== Stable-QDA Dataset Diagnostics ===\n")

    # --- tail index ---
    print("Tail Index Estimation:")
    print(f"  Overall α: {result.alpha_estimate:.2f}")
    per_class = ', '.join(
        f'{label}: {value:.2f}'
        for label, value in result.alpha_per_class.items()
    )
    print("  Per-class: {" + per_class + "}")

    # --- scale ratios ---
    print("\nScale Analysis:")
    if np.isnan(result.determinant_ratio):
        print("  Determinant ratio: N/A (numerically unstable)")
    else:
        print(f"  Determinant ratio: {result.determinant_ratio:.1f}")
    print(f"  Trace ratio: {result.trace_ratio:.2f}")
    print(f"  Tyler threshold: α < {result.tyler_threshold:.1f}")

    # --- heavy-tail signals ---
    signals = result.heavy_tail_signals
    print("\nHeavy-Tail Signals:")
    print(f"  Low α (< 1.8): {signals['low_alpha']}")
    print(f"  High outlier rate: {signals['high_outlier_rate']}")
    print(f"  Mean-median shift: {signals['mean_median_shift']}")
    print(f"  Likely heavy-tailed: {result.likely_heavy_tailed}")

    # --- verdict ---
    print(f"\nRECOMMENDATION: {result.recommendation}")
    print(f"Reasoning: {result.reasoning}")
    print()

    return result


def get_recommended_config(result: DiagnosticResult) -> dict:
    """
    Translate a diagnostic recommendation into StableQDA keyword arguments.

    Parameters
    ----------
    result : DiagnosticResult
        Output from DatasetDiagnostics.fit() or diagnose_dataset(). Only its
        ``recommendation`` attribute is read.

    Returns
    -------
    config : dict
        A new dict with keys 'alpha' and 'estimator', intended to be
        unpacked into the StableQDA constructor.

    Raises
    ------
    ValueError
        If ``result.recommendation`` is not one of 'gaussian',
        'stable_standard' or 'stable_robust'.

    Examples
    --------
    >>> from stable_qda import StableQDA
    >>> from stable_qda.diagnostics import diagnose_dataset, get_recommended_config
    >>> result = diagnose_dataset(X, y, verbose=False)
    >>> config = get_recommended_config(result)
    >>> clf = StableQDA(**config)
    """
    # (recommendation label, alpha, estimator) for every supported outcome.
    presets = (
        ('gaussian', 2.0, 'standard'),
        ('stable_standard', 1.5, 'standard'),
        ('stable_robust', 1.5, 'robust'),
    )

    for label, alpha, estimator in presets:
        if result.recommendation == label:
            # Build a fresh dict per call so callers may mutate it freely.
            return {'alpha': alpha, 'estimator': estimator}

    raise ValueError(f"Unknown recommendation: {result.recommendation}")
