"""
Experiment 1: When Does the Stable Likelihood Help?
====================================================

This is the KEY experiment for the paper (generates Figure 1).

Question: Does the stable likelihood improve classification over Gaussian QDA?
          How does this interact with estimator choice and class heteroscedasticity?

Setup:
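- d=10 dimensions, n=500 samples per class, 20 repeats (5 with --quick), 80/20 train/test split
- Vary α over the grid 1.0, 1.1, ..., 2.0 (tail index: smaller α = heavier tails; α = 2 is Gaussian)
- Vary scale ratio ∈ {1.0, 2.0, 3.0} (class heteroscedasticity; determinant ratio = scale_ratio^d)
- Compare: Gaussian QDA, Stable (mean+LW: sample mean + Ledoit-Wolf),
  Stable (smed+Tyler: spatial median + Tyler)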

Key findings:
1. Robust estimators (smed+Tyler) dominate at heavy tails (α < 1.5)
2. Standard estimators (mean+LW) win at moderate tails with heteroscedasticity
3. Gaussian QDA wins at light tails (α > 1.8) with heteroscedasticity

Output:
- Figure 1: 3-panel accuracy vs α plot with background shading
- Tables D.1-D.3: Complete accuracy results for appendix
- Table 3: Tyler threshold summary for the main paper (largest α at which
  Stable (smed+Tyler) has the best mean accuracy, per determinant ratio)

Usage:
    python exp1_likelihood_benefit.py
    python exp1_likelihood_benefit.py --quick  # Fast test run
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
import os
import argparse

from common import (
    StableMixtureParams,
    make_heteroscedastic_params,
    make_gaussian_qda,
    make_stable_qda_standard,
    make_stable_qda_robust,
    run_experiment,
    summarize_results,
    compute_paired_ttest,
    print_section,
    set_seed,
)


# =============================================================================
# Configuration
# =============================================================================

# Full-size run used for the paper.
DEFAULT_CONFIG = dict(
    n_per_class=500,           # samples per class handed to the data generator
    d=10,                      # data dimension
    n_repeats=20,              # independent train/test repetitions per configuration
    alphas=[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],  # tail-index grid
    scale_ratios=[1.0, 2.0, 3.0],  # between-class scale ratios (heteroscedasticity)
    separation=2.0,            # forwarded to make_heteroscedastic_params
    base_seed=42,              # offset from which every per-run seed is derived
    save_path='outputs/',      # replaced by the --output CLI flag in main()
)

# Smoke-test variant: same settings, fewer repeats and a coarser alpha grid.
QUICK_CONFIG = dict(
    DEFAULT_CONFIG,
    n_repeats=5,
    alphas=[1.0, 1.3, 1.5, 1.7, 2.0],
)


# =============================================================================
# Experiment
# =============================================================================

def build_param_grid(config: dict) -> list:
    """Enumerate every (scale ratio, alpha) configuration of the experiment.

    Args:
        config: Experiment config; reads 'd', 'separation', 'scale_ratios'
            and 'alphas'.

    Returns:
        List of dicts, scale-ratio-major (all alphas for the first ratio, then
        the next ratio, ...), each with keys:
          'params'      -- result of ``make_heteroscedastic_params``;
          'alpha'       -- tail index;
          'scale_ratio' -- between-class scale ratio;
          'det_ratio'   -- ``scale_ratio ** d``, kept for reference in outputs.
    """
    dim = config['d']
    sep = config['separation']
    return [
        {
            'params': make_heteroscedastic_params(
                d=dim, alpha=tail, scale_ratio=ratio, separation=sep
            ),
            'alpha': tail,
            'scale_ratio': ratio,
            'det_ratio': ratio ** dim,
        }
        for ratio in config['scale_ratios']
        for tail in config['alphas']
    ]


def get_classifiers(alpha: float) -> dict:
    """Return zero-argument classifier factories keyed by result column name.

    The keys ('gaussian', 'stable_standard', 'stable_robust') double as the
    accuracy column names in the results DataFrame. The two stable variants
    are bound to the given ``alpha``.
    """
    def stable_standard():
        return make_stable_qda_standard(alpha)

    def stable_robust():
        return make_stable_qda_robust(alpha)

    factories = {'gaussian': make_gaussian_qda}
    factories['stable_standard'] = stable_standard
    factories['stable_robust'] = stable_robust
    return factories


def _train_test_split(X, y, test_frac: float, seed: int):
    """Randomly hold out the first ``int(len(y) * test_frac)`` permuted samples as test set.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    n = len(y)
    n_test = int(n * test_frac)
    perm = np.random.default_rng(seed).permutation(n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]


def run_exp1(config: dict) -> pd.DataFrame:
    """Run Experiment 1 over the full (scale ratio, alpha) grid.

    For every configuration and repeat, draw a fresh balanced dataset, split it
    into train/test, fit each classifier from ``get_classifiers`` and record
    its test accuracy.

    Args:
        config: Experiment config. Reads 'n_repeats', 'base_seed',
            'n_per_class' (plus the keys used by ``build_param_grid``) and the
            optional 'test_frac' (held-out fraction, default 0.2).

    Returns:
        DataFrame with one row per (configuration, repeat) and columns
        'alpha', 'scale_ratio', 'det_ratio', 'repeat' and one accuracy column
        per classifier (NaN when that classifier raised).
    """
    # Imported once per call instead of on every repeat inside the loop.
    from common import generate_balanced_stable_mixture

    print_section("Experiment 1: When Does the Stable Likelihood Help?")

    param_grid = build_param_grid(config)
    n_configs = len(param_grid)
    test_frac = config.get('test_frac', 0.2)
    # Seeds are base + rep * stride + config_index; a stride of at least
    # n_configs keeps them unique (identical to the old 1000 for < 1000 configs).
    seed_stride = max(1000, n_configs)

    all_results = []

    for i, pg in enumerate(param_grid):
        alpha = pg['alpha']
        scale_ratio = pg['scale_ratio']

        print(f"\nConfiguration {i+1}/{n_configs}: α={alpha}, scale_ratio={scale_ratio}")

        classifiers = get_classifiers(alpha)

        for rep in range(config['n_repeats']):
            seed = config['base_seed'] + rep * seed_stride + i
            set_seed(seed)

            X, y = generate_balanced_stable_mixture(pg['params'], config['n_per_class'], seed=seed)

            # NOTE(review): the split RNG reuses the integer seed given to the
            # data generator; if that helper also builds default_rng(seed), the
            # permutation shares its bit stream with the data. An independent
            # stream (e.g. default_rng([seed, 1])) would avoid this but would
            # change the published splits — confirm before switching.
            X_train, y_train, X_test, y_test = _train_test_split(X, y, test_frac, seed)

            row = {
                'alpha': alpha,
                'scale_ratio': scale_ratio,
                'det_ratio': pg['det_ratio'],
                'repeat': rep,
            }

            for name, clf_factory in classifiers.items():
                # Deliberate best-effort: a failing classifier is recorded as
                # NaN so one numerical failure does not abort the whole sweep.
                try:
                    clf = clf_factory()
                    clf.fit(X_train, y_train)
                    row[name] = clf.score(X_test, y_test)
                except Exception as e:
                    print(f"  Warning: {name} failed (repeat {rep}, seed {seed}): {e}")
                    row[name] = np.nan

            all_results.append(row)

    return pd.DataFrame(all_results)


# =============================================================================
# Analysis
# =============================================================================

def analyze_results(df: pd.DataFrame) -> dict:
    """Per scale ratio, find the best classifier at each alpha and where it changes.

    Args:
        df: Per-repeat results with columns 'alpha', 'scale_ratio' and one
            accuracy column per classifier.

    Returns:
        Dict keyed by scale ratio (in order of first appearance in ``df``). Each
        value holds:
          'summary'    -- mean accuracy, alphas as rows, classifiers as columns;
          'winners'    -- Series mapping alpha to the classifier with the
                          highest mean accuracy;
          'crossovers' -- list of ``(alpha, old_winner, new_winner)``, one for
                          every alpha whose winner differs from the previous alpha's.
    """
    columns = ['gaussian', 'stable_standard', 'stable_robust']
    out = {}

    for ratio in df['scale_ratio'].unique():
        mean_acc = df.loc[df['scale_ratio'] == ratio].groupby('alpha')[columns].mean()
        best = mean_acc.idxmax(axis=1)

        ordered = list(best.items())
        changes = [
            (alpha_next, winner_prev, winner_next)
            for (_, winner_prev), (alpha_next, winner_next) in zip(ordered, ordered[1:])
            if winner_next != winner_prev
        ]

        out[ratio] = {
            'summary': mean_acc,
            'winners': best,
            'crossovers': changes,
        }

    return out


def compute_tyler_thresholds(analysis: dict, d: int = 10) -> pd.DataFrame:
    """Extract the Tyler threshold table from ``analyze_results`` output.

    Args:
        analysis: Dict keyed by scale ratio; each value must contain a
            'winners' Series (index = alpha, values = winning classifier name).
        d: Data dimension used to turn a scale ratio into a determinant ratio
            (``scale_ratio ** d``). Defaults to 10, matching DEFAULT_CONFIG;
            pass the configured ``d`` when running other dimensions.

    Returns:
        DataFrame with columns 'scale_ratio', 'det_ratio', 'tyler_threshold',
        one row per scale ratio in ``analysis`` order. 'tyler_threshold' is the
        largest alpha at which 'stable_robust' has the best mean accuracy, or
        0.0 if it never wins. The columns exist even when ``analysis`` is empty.
    """
    rows = []

    for scale_ratio, data in analysis.items():
        winners = data['winners']
        robust_wins = winners[winners == 'stable_robust']
        tyler_threshold = robust_wins.index.max() if len(robust_wins) > 0 else 0.0

        rows.append({
            'scale_ratio': scale_ratio,
            'det_ratio': scale_ratio ** d,
            'tyler_threshold': tyler_threshold,
        })

    return pd.DataFrame(rows, columns=['scale_ratio', 'det_ratio', 'tyler_threshold'])


# =============================================================================
# Visualization
# =============================================================================

def _winner_spans(alphas, winners) -> list:
    """Group runs of consecutive alphas sharing a winner into shaded x-spans.

    Args:
        alphas: Sorted alpha grid values (one per plotted point).
        winners: Winning classifier name at each alpha, aligned with ``alphas``.

    Returns:
        List of ``(x_start, x_end, winner)`` tuples. Span edges lie halfway
        between neighbouring grid points; the outermost edges are padded by
        half the adjacent grid spacing (0.05 when there is a single point), so
        consecutive spans tile the plotted range without gaps or overlaps.
    """
    grid = np.asarray(alphas, dtype=float)
    labels = list(winners)
    n = len(grid)
    if n == 0:
        return []
    if n == 1:
        edges = np.array([grid[0] - 0.05, grid[0] + 0.05])
    else:
        edges = np.concatenate((
            [grid[0] - (grid[1] - grid[0]) / 2],
            (grid[:-1] + grid[1:]) / 2,
            [grid[-1] + (grid[-1] - grid[-2]) / 2],
        ))

    spans = []
    run_start = 0
    for j in range(1, n + 1):
        # Close the current run at the end of the grid or when the winner changes.
        if j == n or labels[j] != labels[run_start]:
            spans.append((float(edges[run_start]), float(edges[j]), labels[run_start]))
            run_start = j
    return spans


def create_figure1(df: pd.DataFrame, save_path: str):
    """
    Create Figure 1 for the paper: accuracy vs α, one panel per scale ratio.

    Each panel shows:
    - mean accuracy per classifier with a ±1 standard-deviation band across repeats;
    - background shading marking the classifier with the best mean accuracy at
      each α (span boundaries halfway between neighbouring α grid points).

    Args:
        df: Per-repeat results with columns 'alpha', 'scale_ratio', the three
            classifier accuracy columns and, optionally, 'det_ratio'.
        save_path: Directory that receives ``fig1_likelihood_benefit.pdf`` and
            ``fig1_likelihood_benefit.png``; created if missing.
    """
    scale_ratios = sorted(df['scale_ratio'].unique())
    classifier_cols = ['gaussian', 'stable_standard', 'stable_robust']

    colors = {
        'gaussian': '#E74C3C',      # Red
        'stable_standard': '#27AE60',  # Green
        'stable_robust': '#3498DB',    # Blue
    }

    labels = {
        'gaussian': 'Gaussian QDA',
        'stable_standard': 'Stable (mean+LW)',
        'stable_robust': 'Stable (smed+Tyler)',
    }

    shade_colors = {
        'gaussian': '#FADBD8',
        'stable_standard': '#D5F5E3',
        'stable_robust': '#D6EAF8',
    }

    # One panel per scale ratio; with the default 3 ratios this reproduces the
    # original 14 x 4.5 inch layout. squeeze=False keeps `axes` 2-D even for a
    # single panel.
    n_panels = max(len(scale_ratios), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(14 * n_panels / 3, 4.5),
                             sharey=True, squeeze=False)
    axes = axes[0]

    try:
        for idx, (ax, scale_ratio) in enumerate(zip(axes, scale_ratios)):
            sub_df = df[df['scale_ratio'] == scale_ratio]

            grouped = sub_df.groupby('alpha')[classifier_cols]
            means = grouped.mean()
            stds = grouped.std()
            alphas = means.index.values

            # Background shading for the best classifier at each alpha.
            winners = means.idxmax(axis=1)
            for x_start, x_end, winner in _winner_spans(alphas, winners):
                if winner not in shade_colors:
                    continue  # e.g. undefined (NaN) winner when every classifier failed
                ax.axvspan(x_start, x_end, alpha=0.3,
                           color=shade_colors[winner], zorder=0)

            # Classifier lines with ±1 std bands.
            for clf in classifier_cols:
                mean = means[clf].values
                std = stds[clf].values

                ax.plot(alphas, mean, 'o-', color=colors[clf], label=labels[clf],
                        linewidth=2, markersize=5)
                ax.fill_between(alphas, mean - std, mean + std,
                                alpha=0.15, color=colors[clf])

            # Prefer the det ratio stored in the results (computed from the
            # configured d); fall back to the d=10 assumption if it is absent.
            if 'det_ratio' in sub_df.columns:
                det_ratio = int(sub_df['det_ratio'].iloc[0])
            else:
                det_ratio = int(scale_ratio ** 10)
            ax.set_title(f'Scale Ratio = {scale_ratio} (Det Ratio = {det_ratio:,})', fontsize=12)
            ax.set_xlabel(r'$\alpha$ (tail index)', fontsize=11)
            if idx == 0:
                ax.set_ylabel('Accuracy', fontsize=11)

            ax.set_xlim(0.95, 2.05)
            ax.set_ylim(0.45, 0.95)

            # Reference line at α=1.5
            ax.axvline(x=1.5, color='gray', linestyle='--', alpha=0.5, linewidth=1)
            ax.text(1.52, 0.47, r'$\alpha$=1.5', fontsize=9, color='gray')

            ax.grid(True, alpha=0.3)

        # Shared legend above all panels.
        handles, labels_list = axes[0].get_legend_handles_labels()
        fig.legend(handles, labels_list, loc='upper center', ncol=3,
                   bbox_to_anchor=(0.5, 1.02), fontsize=10)

        fig.tight_layout()
        fig.subplots_adjust(top=0.88)

        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, 'fig1_likelihood_benefit.pdf'),
                    bbox_inches='tight', dpi=300)
        fig.savefig(os.path.join(save_path, 'fig1_likelihood_benefit.png'),
                    bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"Figure saved to {save_path}")


# =============================================================================
# Tables
# =============================================================================

def _table_suffix(scale_ratio) -> str:
    """Return a file-name/label-safe token for a scale ratio.

    Integral ratios keep the historical form (2.0 -> '2'); other ratios replace
    the decimal point so distinct ratios never collide (1.5 -> '1p5', not '1').
    """
    ratio = float(scale_ratio)
    if ratio.is_integer():
        return str(int(ratio))
    return f"{ratio:g}".replace('.', 'p').replace('-', 'm')


def create_appendix_tables(df: pd.DataFrame, save_path: str):
    """Create detailed LaTeX accuracy tables for the appendix (Tables D.1-D.3).

    Writes one ``table_d<suffix>.tex`` per scale ratio (``table_d1.tex``,
    ``table_d2.tex``, ``table_d3.tex`` for the default ratios 1, 2, 3). Each row
    lists the mean test accuracy (%) per classifier at one alpha, with the best
    classifier in bold. Classifiers whose mean is NaN (failed in every repeat)
    are shown as '--' and are never chosen as winner.

    Args:
        df: Per-repeat results with columns 'alpha', 'scale_ratio', the three
            classifier accuracy columns and, optionally, 'det_ratio'.
        save_path: Output directory; created if missing.
    """
    os.makedirs(save_path, exist_ok=True)

    classifier_cols = ['gaussian', 'stable_standard', 'stable_robust']

    for scale_ratio in sorted(df['scale_ratio'].unique()):
        sub_df = df[df['scale_ratio'] == scale_ratio]
        means = sub_df.groupby('alpha')[classifier_cols].mean()

        # Prefer the det ratio stored in the results (computed from the
        # configured d); fall back to the d=10 assumption if it is absent.
        if 'det_ratio' in sub_df.columns:
            det_ratio = int(sub_df['det_ratio'].iloc[0])
        else:
            det_ratio = int(scale_ratio ** 10)
        suffix = _table_suffix(scale_ratio)

        latex = f"""\\begin{{table}}[h]
\\centering
\\caption{{Accuracy (\\%) for Scale Ratio = {scale_ratio} (Det Ratio = {det_ratio:,}).}}
\\label{{tab:exp1_scale{suffix}}}
\\small
\\begin{{tabular}}{{c|ccc|l}}
\\toprule
$\\alpha$ & Gaussian & Stable (mean+LW) & Stable (smed+Tyler) & Winner \\\\
\\midrule
"""

        for alpha, acc in means.iterrows():
            values = acc.to_numpy(dtype=float)
            # nanargmax ignores failed (NaN) classifiers; plain argmax would
            # return the NaN position and crown a classifier that never ran.
            winner_idx = None if np.isnan(values).all() else int(np.nanargmax(values))

            cells = []
            for i, value in enumerate(values):
                cell = '--' if np.isnan(value) else f"{100 * value:.1f}"
                if i == winner_idx:
                    cell = f"\\textbf{{{cell}}}"
                cells.append(cell)

            winner = '--' if winner_idx is None else classifier_cols[winner_idx].replace('_', ' ')
            latex += f"{alpha:.2f} & {' & '.join(cells)} & {winner} \\\\\n"

        latex += """\\bottomrule
\\end{tabular}
\\end{table}
"""

        with open(os.path.join(save_path, f'table_d{suffix}.tex'), 'w', encoding='utf-8') as f:
            f.write(latex)

    print(f"Tables saved to {save_path}")


def _det_ratio_label(det_ratio: float) -> str:
    """Bin a determinant ratio into the coarse LaTeX label used in Table 3.

    ``<`` / ``>`` are wrapped in math mode: in LaTeX text mode with the default
    OT1 font encoding they typeset as the inverted marks ¡ and ¿.
    """
    if det_ratio < 10:
        return "$<$ 10"
    if det_ratio < 100:
        return "10--100"
    if det_ratio < 1000:
        return "100--1000"
    return "$>$ 1000"


def create_tyler_threshold_table(analysis: dict, save_path: str):
    """Create Table 3: Tyler threshold summary for the main paper.

    One row per scale ratio (sorted by determinant ratio) giving the binned
    determinant ratio and the Tyler threshold from ``compute_tyler_thresholds``.
    Thresholds of 1.9 or more are printed as "2.0 (always)".

    Args:
        analysis: Output of ``analyze_results``.
        save_path: Directory for ``table3_tyler_threshold.tex``; created if missing.
    """
    thresholds = compute_tyler_thresholds(analysis)

    latex = """\\begin{table}[t]
\\centering
\\small
\\caption{Estimator selection: use robust (spatial median + Tyler) when
$\\alpha$ falls below this threshold; otherwise use standard (mean + Ledoit--Wolf).}
\\label{tab:tyler_threshold}
\\begin{tabular}{lc}
\\toprule
\\textbf{Determinant Ratio} & \\textbf{Use Robust if $\\alpha <$} \\\\
\\midrule
"""

    # Order rows by increasing determinant ratio so the table reads monotonically.
    # Guard: an empty result may have no columns to sort on.
    if not thresholds.empty:
        thresholds = thresholds.sort_values('det_ratio')

    # NOTE(review): with d=10 the default scale ratios 2 and 3 give det ratios
    # 1024 and 59049, which both fall in the "> 1000" bin, so the table can show
    # two rows with the same label. Also, the threshold is the largest alpha
    # where robust *wins*, while the header reads "alpha <" — confirm the
    # intended inequality with the paper text.
    for _, row in thresholds.iterrows():
        det_str = _det_ratio_label(row['det_ratio'])
        threshold = row['tyler_threshold']
        thresh_str = "2.0 (always)" if threshold >= 1.9 else f"{threshold:.1f}"
        latex += f"{det_str} & {thresh_str} \\\\\n"

    latex += """\\bottomrule
\\end{tabular}
\\end{table}
"""

    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'table3_tyler_threshold.tex'), 'w', encoding='utf-8') as f:
        f.write(latex)

    print(f"Tyler threshold table saved to {save_path}")


# =============================================================================
# Main
# =============================================================================

def main():
    """Command-line entry point: run Experiment 1 and write all outputs.

    Flags:
        --quick   use QUICK_CONFIG (fewer repeats, coarser alpha grid).
        --output  output directory (default 'outputs/').

    Writes ``exp1_results.csv`` to the output directory, Figure 1 to
    ``<output>/figures`` and the LaTeX tables to ``<output>/tables``, then
    prints the crossover points and Tyler thresholds.
    """
    parser = argparse.ArgumentParser(description='Experiment 1: Likelihood Benefit')
    parser.add_argument('--quick', action='store_true', help='Quick test run')
    parser.add_argument('--output', type=str, default='outputs/', help='Output directory')
    args = parser.parse_args()

    base_config = QUICK_CONFIG if args.quick else DEFAULT_CONFIG
    # Build a fresh dict: assigning into base_config would silently mutate the
    # module-level DEFAULT_CONFIG / QUICK_CONFIG for any later caller.
    config = {**base_config, 'save_path': args.output}
    save_path = config['save_path']
    figures_dir = os.path.join(save_path, 'figures')
    tables_dir = os.path.join(save_path, 'tables')

    # Run experiment
    df = run_exp1(config)

    # Save raw results
    os.makedirs(save_path, exist_ok=True)
    df.to_csv(os.path.join(save_path, 'exp1_results.csv'), index=False)

    # Analyze
    analysis = analyze_results(df)

    # Generate outputs
    print_section("Generating Outputs")

    create_figure1(df, figures_dir)
    create_appendix_tables(df, tables_dir)
    create_tyler_threshold_table(analysis, tables_dir)

    # Print summary
    print_section("Summary")

    for scale_ratio, data in analysis.items():
        print(f"\nScale Ratio = {scale_ratio}:")
        print(f"  Crossovers: {data['crossovers']}")

    thresholds = compute_tyler_thresholds(analysis)
    print("\nTyler Thresholds:")
    print(thresholds.to_string(index=False))

    print(f"\nAll results saved to {save_path}")


if __name__ == "__main__":
    main()
