"""
Utility functions and evaluation metrics for Stable-QDA.

This module provides:
- Tail-conditional recall (TailRec): Evaluates classifier performance
  in the distributional tails where heavy-tailed modeling matters most
- Standard classification metrics wrappers
- Data generation utilities for synthetic experiments
"""

import numpy as np
from scipy.stats import chi2
from sklearn.metrics import (
    accuracy_score,
    precision_recall_curve,
    average_precision_score,
    f1_score,
    roc_auc_score,
)


def tail_conditional_recall(y_true, y_pred, X, class_label=1, epsilon=0.10):
    """
    Compute tail-conditional recall at level epsilon.

    Measures classifier recall restricted to the samples of the target
    class that lie in the outer ε-fraction of that class, as ranked by
    Mahalanobis distance to the empirical class center.

    TailRec_{1-ε} = P(Ŷ = c | Y = c, X ∈ T_c(ε))

    where T_c(ε) = {x : r_c(x) > q_{1-ε}}, r_c(x) is the Mahalanobis
    distance to the class center and q_{1-ε} is the empirical (1-ε)
    quantile of r_c over the class samples.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True labels.

    y_pred : array-like of shape (n_samples,)
        Predicted labels.

    X : array-like of shape (n_samples, n_features) or (n_samples,)
        Feature matrix. A 1-D input is treated as a single feature.

    class_label : int, default=1
        The class for which to compute tail recall (typically minority/positive).

    epsilon : float, default=0.10
        Tail probability. TailRec_{0.90} uses epsilon=0.10.

    Returns
    -------
    tail_recall : float
        Recall among the class samples in the tail region. ``nan`` when
        the class has fewer than two samples, the covariance cannot be
        inverted, or the tail region is empty.

    n_tail : int
        Number of class samples in the tail region (0 whenever
        ``tail_recall`` is ``nan``).

    Raises
    ------
    ValueError
        Propagated from ``np.percentile`` when ``epsilon`` is outside
        ``[0, 1]``.

    Notes
    -----
    Standard metrics (accuracy, ROC-AUC, PR-AUC) are dominated by
    high-density regions. TailRec explicitly evaluates the regime
    where heavy-tailed modeling makes a difference.

    The class mean and covariance are estimated from the samples passed
    in (those with ``y_true == class_label``); the covariance is
    regularized by adding ``1e-6`` to its diagonal before inversion.

    Examples
    --------
    >>> from stable_qda.utils import tail_conditional_recall
    >>> # Compute recall in the 10% tail
    >>> tail_rec, n = tail_conditional_recall(y_test, y_pred, X_test, epsilon=0.10)
    >>> print(f"Tail recall (90%): {tail_rec:.3f} (n={n})")
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    X = np.asarray(X)
    if X.ndim == 1:
        # Treat a flat vector as one feature column.
        X = X.reshape(-1, 1)

    # Restrict to the samples of the target class.
    class_mask = y_true == class_label
    X_class = X[class_mask]
    n_class = X_class.shape[0]
    n_features = X.shape[1]

    # A sample covariance needs at least two observations; bail out early
    # instead of letting np.cov emit warnings and fill the matrix with NaN.
    if n_class < 2:
        return np.nan, 0

    mean = np.mean(X_class, axis=0)
    # np.cov returns a 0-d array for a single feature; promote it to (1, 1).
    cov = np.atleast_2d(np.cov(X_class, rowvar=False))

    # Regularize so that near-singular covariances remain invertible.
    cov = cov + 1e-6 * np.eye(n_features)

    try:
        cov_inv = np.linalg.inv(cov)
    except np.linalg.LinAlgError:
        return np.nan, 0

    # Mahalanobis distance of every class sample to the class center.
    diff = X_class - mean
    sq_dist = np.sum((diff @ cov_inv) * diff, axis=1)
    # Round-off can push a tiny quadratic form slightly below zero.
    mahal_distances = np.sqrt(np.maximum(sq_dist, 0.0))

    # Tail threshold: empirical (1 - epsilon) quantile over class samples.
    threshold = np.percentile(mahal_distances, 100 * (1 - epsilon))

    # Compare the very same distances that produced the threshold. Only
    # class samples can be in the tail, so recomputing distances for the
    # full dataset is unnecessary, and a second matrix product may round
    # differently for samples sitting right at the threshold.
    in_tail = mahal_distances > threshold
    n_tail = int(np.count_nonzero(in_tail))

    if n_tail == 0:
        return np.nan, 0

    # Fraction of tail samples that were predicted as the target class.
    class_pred = y_pred[class_mask]
    tail_recall = float(np.mean(class_pred[in_tail] == class_label))

    return tail_recall, n_tail


def recall_at_precision(y_true, y_score, min_precision=0.95):
    """
    Return the best recall reachable without dropping below a precision floor.

    The precision-recall curve is computed with
    ``sklearn.metrics.precision_recall_curve`` and the largest recall
    among the curve points whose precision is at least ``min_precision``
    is reported.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        Binary true labels.

    y_score : array-like of shape (n_samples,)
        Target scores (e.g., predicted probabilities for positive class).

    min_precision : float, default=0.95
        Precision floor that a curve point must meet to be considered.

    Returns
    -------
    recall : float
        Largest recall among curve points with precision >= min_precision,
        or 0.0 when no curve point meets the floor.

    Examples
    --------
    >>> from stable_qda.utils import recall_at_precision
    >>> recall = recall_at_precision(y_test, clf.predict_proba(X_test)[:, 1])
    >>> print(f"Recall@P95: {recall:.3f}")
    """
    prec_curve, rec_curve, _thresholds = precision_recall_curve(y_true, y_score)

    # Boolean selector for curve points that satisfy the precision floor.
    meets_floor = prec_curve >= min_precision

    if np.any(meets_floor):
        return np.max(rec_curve[meets_floor])

    return 0.0


def evaluate_classifier(clf, X_test, y_test, positive_label=1):
    """
    Comprehensive evaluation of a fitted classifier.

    Parameters
    ----------
    clf : classifier
        Fitted classifier with a ``predict`` method and either
        ``predict_proba`` or ``decision_function``. If it exposes
        ``classes_``, that attribute is used to locate the score column
        of the positive class; otherwise the unique labels of ``y_test``
        are used.

    X_test : array-like of shape (n_samples, n_features)
        Test features.

    y_test : array-like of shape (n_samples,)
        Test labels.

    positive_label : int, default=1
        Label of the positive/minority class.

    Returns
    -------
    metrics : dict
        Dictionary containing:
        - accuracy: Classification accuracy
        - f1: F1 score for the positive class
        - pr_auc: Precision-recall AUC
        - roc_auc: ROC AUC
        - recall_at_p95: Recall at 95% precision
        - tail_recall_90, n_tail_90: Tail recall at the 90% level and the
          number of samples it is based on
        - tail_recall_95, n_tail_95: Same at the 95% level
        - tail_recall_99, n_tail_99: Same at the 99% level

        ``pr_auc``, ``roc_auc`` and ``recall_at_p95`` are ``nan`` unless
        ``y_test`` contains exactly two distinct labels, and also when
        any of the three computations raises ``ValueError``.
    """
    # Convert once so that ``y_test == positive_label`` is element-wise
    # even when a plain list is passed.
    y_test = np.asarray(y_test)
    y_pred = clf.predict(X_test)

    # Position of the positive class among the classifier's classes.
    classes = np.asarray(
        clf.classes_ if hasattr(clf, 'classes_') else np.unique(y_test)
    )
    pos_idx = np.flatnonzero(classes == positive_label)

    # Scores must be oriented so that larger means "more positive".
    if hasattr(clf, 'predict_proba'):
        y_score = np.asarray(clf.predict_proba(X_test))
        if y_score.ndim > 1:
            if len(pos_idx) > 0:
                y_score = y_score[:, pos_idx[0]]
            else:
                # Positive label unknown to the classifier: fall back to
                # the second column (or the only one).
                y_score = y_score[:, 1] if y_score.shape[1] > 1 else y_score[:, 0]
    else:
        y_score = np.asarray(clf.decision_function(X_test))
        if y_score.ndim > 1:
            # One column per class: pick the positive class's column,
            # falling back to column 1 when it cannot be located.
            if len(pos_idx) > 0 and pos_idx[0] < y_score.shape[1]:
                y_score = y_score[:, pos_idx[0]]
            else:
                y_score = y_score[:, 1]
        elif len(classes) == 2 and len(pos_idx) > 0 and pos_idx[0] == 0:
            # By scikit-learn convention a 1-D binary decision function
            # scores the second class; flip it when the positive label is
            # the first class.
            y_score = -y_score

    metrics = {}

    # Standard metrics
    metrics['accuracy'] = accuracy_score(y_test, y_pred)
    metrics['f1'] = f1_score(y_test, y_pred, pos_label=positive_label, zero_division=0)

    # Only compute AUC metrics for binary classification
    unique_labels = np.unique(y_test)
    if len(unique_labels) == 2:
        # One-vs-rest indicator so every ranking metric agrees on which
        # class is positive. Passing raw labels to roc_auc_score would
        # treat the greater label as positive regardless of positive_label.
        is_positive = (y_test == positive_label).astype(int)
        try:
            metrics['pr_auc'] = average_precision_score(y_test, y_score, pos_label=positive_label)
            metrics['roc_auc'] = roc_auc_score(is_positive, y_score)
            metrics['recall_at_p95'] = recall_at_precision(is_positive, y_score, 0.95)
        except ValueError:
            metrics['pr_auc'] = np.nan
            metrics['roc_auc'] = np.nan
            metrics['recall_at_p95'] = np.nan
    else:
        metrics['pr_auc'] = np.nan
        metrics['roc_auc'] = np.nan
        metrics['recall_at_p95'] = np.nan

    # Tail-conditional recall at multiple levels
    for epsilon, level in [(0.10, 90), (0.05, 95), (0.01, 99)]:
        tail_rec, n_tail = tail_conditional_recall(
            y_test, y_pred, X_test,
            class_label=positive_label,
            epsilon=epsilon
        )
        metrics[f'tail_recall_{level}'] = tail_rec
        metrics[f'n_tail_{level}'] = n_tail

    return metrics


def generate_stable_data(
    n_samples=500,
    n_features=10,
    alpha=1.5,
    class_separation=0.5,
    scale_ratio=1.0,
    random_state=None,
):
    """
    Draw a two-class synthetic dataset with sub-Gaussian α-stable features.

    Each sample is built as a Gaussian scale mixture:
        X = μ + s · A^{1/2} Z
    where Z ~ N(0, I), A is a positive (totally skewed) stable variable
    of index α/2 drawn with the Chambers-Mallows-Stuck method, μ is the
    class shift and s the class scale. Class 0 uses μ = 0 and s = 1;
    class 1 uses μ = class_separation and s = scale_ratio.

    Parameters
    ----------
    n_samples : int, default=500
        Samples per class.

    n_features : int, default=10
        Number of features.

    alpha : float, default=1.5
        Stability index (tail parameter).

    class_separation : float, default=0.5
        Offset added to every coordinate of class 1.

    scale_ratio : float, default=1.0
        Ratio of class 1 scale to class 0 scale.

    random_state : int or None, default=None
        Seed passed to ``np.random.RandomState``.

    Returns
    -------
    X : ndarray of shape (2*n_samples, n_features)
        Feature matrix, rows shuffled.

    y : ndarray of shape (2*n_samples,)
        Labels (0 and 1), shuffled consistently with ``X``.

    Notes
    -----
    For α >= 2 the mixing variable is fixed at 1, so the data are Gaussian.
    For α < 2, data has heavy tails with infinite variance.
    """
    rng = np.random.RandomState(random_state)
    half_pi = np.pi / 2
    index = alpha / 2  # stability index of the mixing variable

    def draw_mixing(count):
        # Index >= 1 corresponds to α >= 2: no mixing, plain Gaussian.
        if index >= 1:
            return np.ones(count)

        # Chambers-Mallows-Stuck draw for the totally skewed (β = 1) case.
        angle = rng.uniform(-half_pi, half_pi, count)
        expo = rng.exponential(1, count)
        shifted = index * (angle + half_pi)
        draws = (
            np.sin(shifted) / (np.cos(angle) ** (1 / index))
            * (np.cos(angle - shifted) / expo) ** ((1 - index) / index)
        )

        # Floor the draws so the square root below stays well defined.
        return np.maximum(draws, 1e-10)

    def draw_radial_and_gauss():
        # Mixing variable first, Gaussian block second: the order of RNG
        # calls determines the generated stream.
        radial = np.sqrt(draw_mixing(n_samples))[:, np.newaxis]
        gauss = rng.randn(n_samples, n_features)
        return radial, gauss

    # Class 0 sits at the origin with unit scale.
    radial0, gauss0 = draw_radial_and_gauss()
    X0 = radial0 * gauss0

    # Class 1 is rescaled and shifted along every coordinate.
    radial1, gauss1 = draw_radial_and_gauss()
    X1 = scale_ratio * radial1 * gauss1 + class_separation

    X = np.vstack([X0, X1])
    y = np.array([label for label in (0, 1) for _ in range(n_samples)])

    # Shuffle rows and labels with one shared permutation.
    order = rng.permutation(2 * n_samples)
    return X[order], y[order]


def cross_validate_comparison(classifiers, X, y, cv=5, random_state=None,
                              positive_label=1):
    """
    Compare multiple classifiers using stratified cross-validation.

    Every classifier is deep-copied, fitted and evaluated on the same
    shuffled stratified folds, so the per-fold metrics are directly
    comparable across classifiers.

    Parameters
    ----------
    classifiers : dict
        Dictionary mapping names to classifier instances. The instances
        themselves are not fitted; a deep copy is fitted per fold.

    X : array-like of shape (n_samples, n_features)
        Features. Converted with ``np.asarray`` so that lists and other
        array-likes can be indexed by fold.

    y : array-like of shape (n_samples,)
        Labels. Converted with ``np.asarray``.

    cv : int, default=5
        Number of folds.

    random_state : int or None, default=None
        Random seed for reproducibility.

    positive_label : int, default=1
        Label of the positive/minority class, forwarded to
        ``evaluate_classifier``.

    Returns
    -------
    results : dict
        Dictionary mapping classifier names to lists of metric dicts
        (one per fold).
    """
    from copy import deepcopy

    from sklearn.model_selection import StratifiedKFold

    # Fold indices are integer arrays; plain lists cannot be indexed with
    # them, so normalize the inputs to ndarrays up front.
    X = np.asarray(X)
    y = np.asarray(y)

    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=random_state)

    results = {name: [] for name in classifiers}

    for train_idx, test_idx in skf.split(X, y):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        for name, clf_template in classifiers.items():
            # Fit a fresh copy so folds and callers never share fitted state.
            clf = deepcopy(clf_template)
            clf.fit(X_train, y_train)
            metrics = evaluate_classifier(
                clf, X_test, y_test, positive_label=positive_label
            )
            results[name].append(metrics)

    return results


def summarize_cv_results(results, metric='accuracy'):
    """
    Reduce per-fold metrics to a mean and standard deviation per classifier.

    Folds in which the requested metric is missing or NaN are left out
    of the statistics.

    Parameters
    ----------
    results : dict
        Output from cross_validate_comparison: classifier name mapped to
        a list of per-fold metric dicts.

    metric : str, default='accuracy'
        Key of the metric to summarize.

    Returns
    -------
    summary : dict
        Dictionary mapping classifier names to (mean, std) tuples, or
        (nan, nan) when no fold has a usable value for the metric.
    """
    summary = {}

    for name, folds in results.items():
        usable = []
        for fold in folds:
            score = fold.get(metric, np.nan)
            # Missing and NaN entries are both skipped.
            if np.isnan(score):
                continue
            usable.append(score)

        summary[name] = (np.mean(usable), np.std(usable)) if usable else (np.nan, np.nan)

    return summary
