"""
Common utilities for synthetic experiments.

Provides data generation, experiment running infrastructure, and result formatting.
"""

import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Callable
from scipy import stats
import warnings

import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))

from stable_qda import StableQDA
from estimators import spatial_median, tyler_m_estimator, ledoit_wolf_shrinkage


# =============================================================================
# Data Generation
# =============================================================================

@dataclass
class StableMixtureParams:
    """
    Parameters for a two-class stable mixture.

    Class k in {0, 1} is described by a stability index ``alpha_k``, a
    location vector ``mu_k`` and a matrix ``Sigma_k``; these are the values
    passed to ``generate_stable_class`` when a mixture is sampled.
    """
    # Stability index of class 0; the sampler treats alpha / 2 >= 1 as the Gaussian case.
    alpha_0: float
    # Stability index of class 1.
    alpha_1: float
    # Location vector of class 0; its length is used as the data dimension d.
    mu_0: np.ndarray
    # Location vector of class 1.
    mu_1: np.ndarray
    # Symmetric (d, d) matrix of class 0; it is eigendecomposed to form Sigma^{1/2}.
    Sigma_0: np.ndarray
    # Symmetric (d, d) matrix of class 1.
    Sigma_1: np.ndarray
    # Presumably the prior probability of class 0. It is not read by
    # generate_balanced_stable_mixture, which always draws equal class sizes.
    prior_0: float = 0.5


def sample_positive_stable(n: int, alpha_half: float, rng: np.random.Generator) -> np.ndarray:
    """
    Draw ``n`` variates from a positive stable law with index ``alpha_half``.

    The draw follows the Chambers-Mallows-Stuck construction: one uniform
    angle on (-pi/2, pi/2) and one unit exponential are combined per variate.

    Parameters
    ----------
    n : int
        Number of variates to draw.
    alpha_half : float
        Stability index of the subordinator (alpha / 2 of the target law).
        For ``alpha_half >= 1`` (alpha >= 2, the Gaussian case) a vector of
        ones is returned and ``rng`` is not consumed.
    rng : numpy.random.Generator
        Source of randomness; the uniform angles are drawn before the
        exponentials.

    Returns
    -------
    ndarray of shape (n,)
        Sampled values, floored at 1e-10.
    """
    if alpha_half >= 1:
        return np.ones(n)

    half_pi = np.pi/2
    angle = rng.uniform(-half_pi, half_pi, n)
    expo = rng.exponential(1, n)

    index = alpha_half
    scaled_angle = index * (angle + half_pi)

    head = np.sin(scaled_angle) / (np.cos(angle) ** (1/index))
    tail = (np.cos(angle - scaled_angle) / expo) ** ((1 - index) / index)
    draws = head * tail

    # Keep the result strictly positive so callers can take square roots safely.
    return np.maximum(draws, 1e-10)


def generate_stable_class(
    n: int,
    alpha: float,
    mu: np.ndarray,
    Sigma: np.ndarray,
    rng: np.random.Generator
) -> np.ndarray:
    """
    Draw ``n`` points from a sub-Gaussian α-stable distribution.

    Each point is  X = μ + A^{1/2} Σ^{1/2} Z,  with A a positive
    (α/2)-stable mixing variable from ``sample_positive_stable`` and
    Z ~ N(0, I).

    Parameters
    ----------
    n : int
        Number of points.
    alpha : float
        Stability index; ``alpha / 2`` is passed to the subordinator sampler.
    mu : ndarray of length d
        Location vector.
    Sigma : ndarray of shape (d, d)
        Symmetric matrix; negative eigenvalues are clipped to zero before the
        square root is taken.
    rng : numpy.random.Generator
        Source of randomness (mixing variables are drawn before the Gaussians).

    Returns
    -------
    ndarray of shape (n, d)
    """
    dim = len(mu)

    # Draw order is part of the reproducible stream: mixing variable, then Gaussian.
    mixing = sample_positive_stable(n, alpha / 2, rng)
    gauss = rng.standard_normal((n, dim))

    # Symmetric square root of Sigma via its eigendecomposition.
    spectrum, basis = np.linalg.eigh(Sigma)
    root_spectrum = np.sqrt(np.maximum(spectrum, 0))
    sigma_root = basis @ np.diag(root_spectrum) @ basis.T

    # One radial scale per row, broadcast across the d coordinates.
    radial = np.sqrt(mixing)[:, np.newaxis]
    return mu + radial * (gauss @ sigma_root.T)


def generate_balanced_stable_mixture(
    params: StableMixtureParams,
    n_per_class: int,
    seed: int = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample a shuffled two-class data set with equal class sizes.

    Class 0 is drawn first, then class 1, both from one generator seeded with
    ``seed``; the stacked rows are then permuted with that same generator.

    Parameters
    ----------
    params : StableMixtureParams
        Per-class stability index, location and matrix.
    n_per_class : int
        Number of points drawn for each class.
    seed : int, optional
        Seed for ``np.random.default_rng``.

    Returns
    -------
    X : ndarray of shape (2*n_per_class, d)
    y : ndarray of shape (2*n_per_class,)
        Labels 0 / 1 aligned with the rows of ``X``.
    """
    rng = np.random.default_rng(seed)

    class_specs = [
        (params.alpha_0, params.mu_0, params.Sigma_0),
        (params.alpha_1, params.mu_1, params.Sigma_1),
    ]
    per_class = [
        generate_stable_class(n_per_class, alpha, mu, Sigma, rng)
        for alpha, mu, Sigma in class_specs
    ]

    X = np.vstack(per_class)
    y = np.array([label for label in (0, 1) for _ in range(n_per_class)])

    # Interleave the two classes with a single random ordering.
    order = rng.permutation(len(y))
    return X[order], y[order]


# =============================================================================
# Covariance Structure Generators
# =============================================================================

def make_homoscedastic_params(d: int, alpha: float, separation: float = 2.0) -> StableMixtureParams:
    """
    Build the equal-covariance setup: Σ_0 = Σ_1 = I (det ratio = 1).

    Both classes share the stability index ``alpha``. Class 0 is located at
    the origin and class 1 is shifted by ``separation`` along the first
    coordinate.
    """
    origin = np.zeros(d)
    shifted = np.zeros(d)
    shifted[0] = separation

    fields = {
        'alpha_0': alpha,
        'alpha_1': alpha,
        'mu_0': origin,
        'mu_1': shifted,
        # Two separate identity matrices so the classes never share storage.
        'Sigma_0': np.eye(d),
        'Sigma_1': np.eye(d),
        'prior_0': 0.5,
    }
    return StableMixtureParams(**fields)


def make_heteroscedastic_params(
    d: int,
    alpha: float,
    scale_ratio: float = 2.0,
    separation: float = 2.0
) -> StableMixtureParams:
    """
    Build the unequal-covariance setup with a controllable scale ratio.

    Σ_0 = I and Σ_1 = scale_ratio · I, hence
    det(Σ_1) / det(Σ_0) = scale_ratio^d.

    Parameters
    ----------
    d : int
        Data dimension.
    alpha : float
        Stability index shared by both classes.
    scale_ratio : float
        Multiplier applied to the identity to obtain Σ_1.
        - scale_ratio=1.0 → det ratio = 1
        - scale_ratio=2.0 → det ratio = 2^d = 1024 for d=10
        - scale_ratio=3.0 → det ratio = 3^d = 59049 for d=10
    separation : float
        Offset of the class-1 location along the first coordinate; class 0
        sits at the origin.
    """
    origin = np.zeros(d)
    shifted = np.zeros(d)
    shifted[0] = separation

    fields = {
        'alpha_0': alpha,
        'alpha_1': alpha,
        'mu_0': origin,
        'mu_1': shifted,
        'Sigma_0': np.eye(d),
        'Sigma_1': scale_ratio * np.eye(d),
        'prior_0': 0.5,
    }
    return StableMixtureParams(**fields)


# =============================================================================
# Classifiers
# =============================================================================

def make_gaussian_qda():
    """Return the Gaussian baseline: a StableQDA built with alpha=2.0 and estimator='standard'."""
    baseline = StableQDA(alpha=2.0, estimator='standard')
    return baseline


def make_stable_qda_standard(alpha: float = 1.5):
    """
    Return a StableQDA with the given ``alpha`` and estimator='standard'.

    The original author describes 'standard' as mean + Ledoit-Wolf; the
    estimator itself is defined in ``stable_qda``.
    """
    classifier = StableQDA(alpha=alpha, estimator='standard')
    return classifier


def make_stable_qda_robust(alpha: float = 1.5):
    """
    Return a StableQDA with the given ``alpha`` and estimator='robust'.

    The original author describes 'robust' as spatial median + Tyler; the
    estimator itself is defined in ``stable_qda``.
    """
    classifier = StableQDA(alpha=alpha, estimator='robust')
    return classifier


# =============================================================================
# Experiment Infrastructure
# =============================================================================

def run_single_trial(
    params: StableMixtureParams,
    classifiers: Dict[str, Callable],
    n_per_class: int = 500,
    test_size: float = 0.2,
    seed: int = None
) -> Dict[str, float]:
    """
    Generate one data set, split it, and score every classifier on it.

    Parameters
    ----------
    params : StableMixtureParams
        Data generation parameters.
    classifiers : dict
        Maps a classifier name to a zero-argument callable that builds a
        fresh classifier exposing ``fit`` and ``score``.
    n_per_class : int
        Samples drawn per class (before splitting).
    test_size : float
        Fraction of rows held out; the count is ``int(n * test_size)``.
    seed : int
        Seed used both for data generation and for the split permutation
        (two generators are created from the same value).

    Returns
    -------
    results : dict
        Maps each classifier name to the value returned by its ``score`` on
        the held-out rows, or NaN if building, fitting or scoring raised.
    """
    split_rng = np.random.default_rng(seed)

    features, labels = generate_balanced_stable_mixture(params, n_per_class, seed=seed)

    # The first `holdout` entries of a random ordering form the test set.
    total = len(labels)
    holdout = int(total * test_size)
    shuffled = split_rng.permutation(total)
    test_rows, train_rows = shuffled[:holdout], shuffled[holdout:]

    train_X, train_y = features[train_rows], labels[train_rows]
    test_X, test_y = features[test_rows], labels[test_rows]

    scores = {}
    for label, build in classifiers.items():
        try:
            model = build()
            model.fit(train_X, train_y)
            outcome = model.score(test_X, test_y)
        except Exception as err:
            # Best effort: one failing classifier must not abort the trial.
            warnings.warn(f"Classifier {label} failed: {err}")
            outcome = np.nan
        scores[label] = outcome

    return scores


def run_experiment(
    param_grid: List[Dict],
    classifiers: Dict[str, Callable],
    n_repeats: int = 20,
    n_per_class: int = 500,
    base_seed: int = 42,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Run every configuration of a grid ``n_repeats`` times and collect the scores.

    Parameters
    ----------
    param_grid : list of dict
        Each dict holds 'params' (StableMixtureParams); every other key is
        copied into the result rows as metadata.
    classifiers : dict
        Classifier factories, forwarded to ``run_single_trial``.
    n_repeats : int
        Number of random repeats per configuration.
    n_per_class : int
        Samples per class.
    base_seed : int
        Base random seed. Trial seeds are
        ``base_seed + rep * 1000 + <trials completed so far>``.
    verbose : bool
        Print a progress line after every 10th trial.

    Returns
    -------
    results_df : DataFrame
        One row per trial: metadata columns, 'repeat', and one column per
        classifier.
    """
    rows = []
    total = len(param_grid) * n_repeats
    done = 0

    for config in param_grid:
        # Split the entry into the generator parameters and the pass-through metadata.
        meta = dict(config)
        params = meta.pop('params')

        for rep in range(n_repeats):
            trial_seed = base_seed + rep * 1000 + done
            trial_scores = run_single_trial(
                params, classifiers, n_per_class, seed=trial_seed
            )

            record = dict(meta)
            record['repeat'] = rep
            record.update(trial_scores)
            rows.append(record)

            done += 1
            if verbose and done % 10 == 0:
                print(f"Progress: {done}/{total} ({100*done/total:.1f}%)")

    return pd.DataFrame(rows)


# =============================================================================
# Result Formatting
# =============================================================================

def summarize_results(
    df: pd.DataFrame,
    group_cols: List[str],
    classifier_cols: List[str]
) -> pd.DataFrame:
    """
    Aggregate raw trial results into per-group mean and standard deviation.

    Parameters
    ----------
    df : DataFrame
        Raw results from run_experiment.
    group_cols : list
        Columns to group by (e.g., ['alpha', 'scale_ratio']).
    classifier_cols : list
        Classifier name columns to summarize.

    Returns
    -------
    summary : DataFrame
        The group columns followed by ``<classifier>_mean`` and
        ``<classifier>_std`` for each requested classifier.
    """
    stat_names = ['mean', 'std']
    per_column = {metric: list(stat_names) for metric in classifier_cols}

    table = df.groupby(group_cols).agg(per_column)

    # Collapse the (classifier, statistic) column pairs into flat names.
    table.columns = list(map('_'.join, table.columns))
    return table.reset_index()


def find_best_method(row: pd.Series, methods: List[str]) -> str:
    """
    Return the name of the best performing method in a row.

    Missing scores (NaN) are skipped. This matters because failed trials are
    recorded as NaN, and a NaN in the first position would otherwise win
    every ``>`` comparison by default.

    Parameters
    ----------
    row : Series
        Maps method names to scores (larger is better).
    methods : list of str
        Candidate method names, in tie-breaking order.

    Returns
    -------
    str
        The method with the largest non-missing score. Ties go to the method
        listed first; if every score is missing, ``methods[0]`` is returned.

    Raises
    ------
    IndexError
        If ``methods`` is empty.
    """
    best = methods[0]
    best_val = row[best]
    for candidate in methods[1:]:
        value = row[candidate]
        if pd.isna(value):
            continue
        # A missing incumbent is replaced by the first valid score seen.
        if pd.isna(best_val) or value > best_val:
            best = candidate
            best_val = value
    return best


def compute_paired_ttest(
    df: pd.DataFrame,
    method1: str,
    method2: str,
    group_cols: List[str]
) -> pd.DataFrame:
    """
    Compute a paired t-test between two methods for each configuration.

    Rows in which either method's score is NaN (e.g. a trial recorded as NaN
    because the classifier failed) are dropped pairwise before testing, so a
    single failed trial does not turn the whole configuration's statistics
    into NaN.

    Parameters
    ----------
    df : DataFrame
        Results with one row per trial.
    method1, method2 : str
        Columns holding the paired scores to compare.
    group_cols : list of str
        Columns identifying a configuration. A single column name given as a
        plain string is accepted as well.

    Returns
    -------
    DataFrame
        One row per group containing the group keys plus:
        ``t_statistic`` and ``p_value`` (NaN when fewer than two valid pairs
        remain), ``cohens_d`` (mean difference divided by its sample standard
        deviation; 0 when that deviation is zero or undefined), ``mean_diff``
        (mean of method1 - method2 over the valid pairs) and ``significant``
        (``p_value < 0.05``).
    """
    # Normalise a bare string so keys and group names always line up.
    key_cols = [group_cols] if isinstance(group_cols, str) else list(group_cols)

    results = []

    for name, group in df.groupby(key_cols):
        vals1 = group[method1].to_numpy(dtype=float)
        vals2 = group[method2].to_numpy(dtype=float)

        # Keep only the trials where both methods produced a score.
        valid = ~(np.isnan(vals1) | np.isnan(vals2))
        vals1, vals2 = vals1[valid], vals2[valid]
        diff = vals1 - vals2
        n_pairs = diff.size

        # A paired t-test needs at least two pairs.
        if n_pairs >= 2:
            t_stat, p_value = stats.ttest_rel(vals1, vals2)
        else:
            t_stat, p_value = np.nan, np.nan

        mean_diff = np.mean(diff) if n_pairs > 0 else np.nan

        # Cohen's d for paired samples; 0 when the spread is zero or undefined.
        spread = np.std(diff, ddof=1) if n_pairs >= 2 else 0.0
        cohens_d = mean_diff / spread if spread > 0 else 0

        # Older pandas yields a scalar name for a single grouping column.
        keys = name if isinstance(name, tuple) else (name,)
        row = dict(zip(key_cols, keys))

        row.update({
            't_statistic': t_stat,
            'p_value': p_value,
            'cohens_d': cohens_d,
            'mean_diff': mean_diff,
            'significant': bool(p_value < 0.05)
        })
        results.append(row)

    return pd.DataFrame(results)


def results_to_latex(
    summary_df: pd.DataFrame,
    classifier_cols: List[str],
    caption: str = "Experimental results",
    label: str = "tab:results"
) -> str:
    """
    Render a summary DataFrame as a LaTeX tabular with caption and label.

    For every classifier that has both ``<name>_mean`` and ``<name>_std``
    columns, the pair is replaced by one ``<name>`` column formatted as
    "mean ± std" in percent with one decimal. Classifiers lacking either
    column are left untouched.

    The ``\\caption`` and ``\\label`` commands are inserted directly before
    ``\\begin{tabular}``; no surrounding ``table`` environment is emitted.
    """
    table = summary_df.copy()

    for name in classifier_cols:
        mean_key, std_key = f"{name}_mean", f"{name}_std"
        if mean_key not in table.columns or std_key not in table.columns:
            continue

        def as_cell(record, mean_key=mean_key, std_key=std_key):
            # Scores are fractions; report them as percentages.
            return f"{100*record[mean_key]:.1f} ± {100*record[std_key]:.1f}"

        table[name] = table.apply(as_cell, axis=1)
        table = table.drop(columns=[mean_key, std_key])

    body = table.to_latex(index=False, escape=False)

    # Prefix the tabular with its caption and label.
    header = f"\\caption{{{caption}}}\n\\label{{{label}}}\n\\begin{{tabular}}"
    return body.replace("\\begin{tabular}", header)


# =============================================================================
# Utility Functions
# =============================================================================

def set_seed(seed: int):
    """
    Seed NumPy's legacy global random state for reproducibility.

    Only code drawing from the global ``np.random`` functions is affected.
    The data-generation and trial functions in this module build their own
    generators with ``np.random.default_rng(seed)`` and take their seed as an
    argument instead.
    """
    np.random.seed(seed=seed)


def print_section(title: str, width: int = 70):
    """Print ``title`` centered in ``width`` columns between two rules of '=' characters."""
    rule = "=" * width
    for line in (rule, title.center(width), rule):
        print(line)
