"""
Experiment 2: Sensitivity to α Misspecification
=================================================

Question: How sensitive is Stable-QDA to misspecification of the tail index α?

Key finding: Stable-QDA is remarkably insensitive. Fixed α=1.5 performs
within 1% of oracle across all tail regimes.

Setup:
- d=10, n=500 per class, homoscedastic (Σ₀ = Σ₁ = I)
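- Experiment 2a (sensitivity): true α ∈ {1.2, 1.5, 1.8};
  fitted α ∈ {1.0, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0}
  (with --quick: {1.0, 1.2, 1.5, 1.8, 2.0})
- Experiment 2b (comparison): true α ∈ {1.0, 1.2, 1.4, 1.5, 1.6, 1.8, 2.0}
- Compare: Gaussian QDA, Fixed α=1.5, Estimated α, Oracle (true α)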

Output:
- Figure 2: Sensitivity curves showing accuracy vs fitted α
- Figure 3: Fixed vs estimated α comparison
- Table: Maximum accuracy loss from using α=1.5

Usage:
    python exp2_alpha_sensitivity.py
    python exp2_alpha_sensitivity.py --quick
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import argparse

from common import (
    make_homoscedastic_params,
    make_stable_qda_robust,
    make_gaussian_qda,
    generate_balanced_stable_mixture,
    print_section,
    set_seed,
)

import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
from alpha_estimation import estimate_alpha


# =============================================================================
# Configuration
# =============================================================================

# Settings used when the script is run without --quick.
DEFAULT_CONFIG = dict(
    n_per_class=500,    # samples generated per class
    d=10,               # feature dimension
    n_repeats=20,       # independent repetitions per true α
    true_alphas=[1.2, 1.5, 1.8],
    fitted_alphas=[1.0, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
    separation=2.0,     # forwarded to make_homoscedastic_params
    base_seed=42,       # per-run seeds are derived from this value
    save_path='outputs/',
)

# Settings used with --quick: same as the defaults, but with fewer repeats
# and a coarser grid of fitted α values.
QUICK_CONFIG = dict(
    DEFAULT_CONFIG,
    n_repeats=5,
    fitted_alphas=[1.0, 1.2, 1.5, 1.8, 2.0],
)


# =============================================================================
# Experiment
# =============================================================================

def run_sensitivity_analysis(config: dict) -> pd.DataFrame:
    """
    Run sensitivity analysis: vary fitted α for each true α.

    For every true α in ``config['true_alphas']`` and every repeat, one data
    set is generated, shuffled and split 80/20 into train/test. A Stable-QDA
    classifier is then fitted once per value in ``config['fitted_alphas']``,
    plus one Gaussian QDA baseline, and each test score is recorded.

    Args:
        config: Experiment settings. Reads ``true_alphas``, ``fitted_alphas``,
            ``n_repeats``, ``n_per_class``, ``d``, ``separation`` and
            ``base_seed``.

    Returns:
        Long-format DataFrame with one row per fitted classifier and columns
        ``true_alpha``, ``fitted_alpha``, ``repeat``, ``accuracy`` and
        ``is_gaussian``. Gaussian QDA rows are stored with
        ``fitted_alpha == 2.0`` and ``is_gaussian == True``; Stable-QDA rows
        have ``is_gaussian == False``. The frame is empty (no columns) when
        there is nothing to run.
    """
    print_section("Experiment 2a: Sensitivity Analysis")

    results = []

    for true_alpha in config['true_alphas']:
        print(f"\nTrue α = {true_alpha}")

        params = make_homoscedastic_params(
            d=config['d'],
            alpha=true_alpha,
            separation=config['separation']
        )

        for rep in range(config['n_repeats']):
            # One seed per (true α, repeat); it drives the global seed, the
            # data generator and the train/test shuffle alike.
            seed = config['base_seed'] + rep * 1000 + int(true_alpha * 100)
            set_seed(seed)

            # Generate data
            X, y = generate_balanced_stable_mixture(params, config['n_per_class'], seed=seed)

            # Split: the first 20% of a random permutation is the test set
            n = len(y)
            n_test = int(n * 0.2)
            rng = np.random.default_rng(seed)
            perm = rng.permutation(n)

            X_train, y_train = X[perm[n_test:]], y[perm[n_test:]]
            X_test, y_test = X[perm[:n_test]], y[perm[:n_test]]

            # Test each fitted alpha on the same split
            for fitted_alpha in config['fitted_alphas']:
                clf = make_stable_qda_robust(fitted_alpha)
                clf.fit(X_train, y_train)
                acc = clf.score(X_test, y_test)

                results.append({
                    'true_alpha': true_alpha,
                    'fitted_alpha': fitted_alpha,
                    'repeat': rep,
                    'accuracy': acc,
                    # Explicit flag so the column holds only True/False
                    # instead of a True/NaN mix.
                    'is_gaussian': False,
                })

            # Also test Gaussian QDA
            clf_gauss = make_gaussian_qda()
            clf_gauss.fit(X_train, y_train)
            acc_gauss = clf_gauss.score(X_test, y_test)

            results.append({
                'true_alpha': true_alpha,
                'fitted_alpha': 2.0,  # Gaussian = α=2
                'repeat': rep,
                'accuracy': acc_gauss,
                'is_gaussian': True,
            })

    return pd.DataFrame(results)


def run_estimation_comparison(config: dict) -> pd.DataFrame:
    """
    Compare: Gaussian QDA, Fixed α=1.5, Estimated α, Oracle.

    For every true α and repeat, one data set is generated and split 80/20
    into train/test; the four classifiers are then fitted on the same
    training split and scored on the same test split. The estimated α is
    obtained from the training split only.

    Args:
        config: Experiment settings. Reads ``n_repeats``, ``n_per_class``,
            ``d``, ``separation`` and ``base_seed``. The optional key
            ``comparison_true_alphas`` overrides the grid of true α values;
            when absent the grid is
            ``[1.0, 1.2, 1.4, 1.5, 1.6, 1.8, 2.0]``.

    Returns:
        DataFrame with one row per (true α, repeat) and columns
        ``true_alpha``, ``repeat``, ``gaussian``, ``fixed_1.5``,
        ``estimated``, ``estimated_alpha`` and ``oracle``.
    """
    print_section("Experiment 2b: Fixed vs Estimated α")

    results = []

    # Denser than the 2a grid by default; configurable so that shortened or
    # extended runs do not require editing this function.
    all_true_alphas = config.get(
        'comparison_true_alphas',
        [1.0, 1.2, 1.4, 1.5, 1.6, 1.8, 2.0],
    )

    for true_alpha in all_true_alphas:
        print(f"\nTrue α = {true_alpha}")

        params = make_homoscedastic_params(
            d=config['d'],
            alpha=true_alpha,
            separation=config['separation']
        )

        for rep in range(config['n_repeats']):
            # One seed per (true α, repeat); it drives the global seed, the
            # data generator and the train/test shuffle alike.
            seed = config['base_seed'] + rep * 1000 + int(true_alpha * 100)
            set_seed(seed)

            # Generate data
            X, y = generate_balanced_stable_mixture(params, config['n_per_class'], seed=seed)

            # Split: the first 20% of a random permutation is the test set
            n = len(y)
            n_test = int(n * 0.2)
            rng = np.random.default_rng(seed)
            perm = rng.permutation(n)

            X_train, y_train = X[perm[n_test:]], y[perm[n_test:]]
            X_test, y_test = X[perm[:n_test]], y[perm[:n_test]]

            row = {
                'true_alpha': true_alpha,
                'repeat': rep,
            }

            # 1. Gaussian QDA
            clf = make_gaussian_qda()
            clf.fit(X_train, y_train)
            row['gaussian'] = clf.score(X_test, y_test)

            # 2. Fixed α = 1.5
            clf = make_stable_qda_robust(1.5)
            clf.fit(X_train, y_train)
            row['fixed_1.5'] = clf.score(X_test, y_test)

            # 3. Estimated α (the test split is never shown to the estimator)
            est_alpha = estimate_alpha(X_train, y_train)
            clf = make_stable_qda_robust(est_alpha)
            clf.fit(X_train, y_train)
            row['estimated'] = clf.score(X_test, y_test)
            row['estimated_alpha'] = est_alpha

            # 4. Oracle (true α)
            clf = make_stable_qda_robust(true_alpha)
            clf.fit(X_train, y_train)
            row['oracle'] = clf.score(X_test, y_test)

            results.append(row)

    return pd.DataFrame(results)


# =============================================================================
# Visualization
# =============================================================================

def create_sensitivity_figure(df: pd.DataFrame, save_path: str):
    """
    Create Figure 2: Sensitivity to α misspecification.

    Each panel shows accuracy vs fitted α for a fixed true α. One panel is
    drawn per distinct ``true_alpha`` value found in ``df`` (three for the
    configurations defined in this file), so no regime is dropped and no
    panel is left blank when the grid of true α values changes.

    Args:
        df: Long-format results with columns ``true_alpha``,
            ``fitted_alpha`` and ``accuracy``. Rows whose optional
            ``is_gaussian`` column equals True are drawn as the horizontal
            Gaussian QDA baseline instead of as a point on the curve.
        save_path: Directory (created if missing) that receives
            ``fig2_alpha_sensitivity.pdf`` and ``fig2_alpha_sensitivity.png``.
    """
    true_alphas = sorted(df['true_alpha'].unique())

    # At least one axes so that saving still works for a frame without rows.
    n_panels = max(len(true_alphas), 1)

    # squeeze=False always returns a 2-D array of axes, even for one panel.
    # The width scales with the panel count; three panels give (14, 4).
    fig, axes = plt.subplots(1, n_panels, figsize=(14 * n_panels / 3, 4),
                             squeeze=False)

    # Boolean mask of Gaussian baseline rows; a missing flag (absent column
    # or NaN entry) counts as "not Gaussian".
    if 'is_gaussian' in df.columns:
        is_gauss = df['is_gaussian'].eq(True)
    else:
        is_gauss = pd.Series(False, index=df.index, dtype=bool)

    for idx, (ax, true_alpha) in enumerate(zip(axes[0], true_alphas)):
        at_alpha = df['true_alpha'] == true_alpha
        sub_df = df[at_alpha & ~is_gauss]

        # Group by fitted alpha
        summary = sub_df.groupby('fitted_alpha')['accuracy'].agg(['mean', 'std'])

        fitted_alphas = summary.index.values
        means = summary['mean'].values
        stds = summary['std'].values

        # Plot line with a ±1 std band across repeats
        ax.plot(fitted_alphas, means, 'o-', color='#3498DB', linewidth=2, markersize=6,
               label='Stable-QDA')
        ax.fill_between(fitted_alphas, means - stds, means + stds, alpha=0.2, color='#3498DB')

        # Gaussian baseline
        gauss_df = df[at_alpha & is_gauss]
        if len(gauss_df) > 0:
            gauss_mean = gauss_df['accuracy'].mean()
            ax.axhline(y=gauss_mean, color='#E74C3C', linestyle='-', linewidth=2,
                      label='Gaussian QDA')

        # Mark true alpha
        ax.axvline(x=true_alpha, color='#27AE60', linestyle='--', linewidth=2, alpha=0.7)
        ax.text(true_alpha + 0.05, ax.get_ylim()[0] + 0.02, f'True α={true_alpha}',
               fontsize=9, color='#27AE60')

        # Mark fixed α=1.5
        ax.axvline(x=1.5, color='#F39C12', linestyle=':', linewidth=2, alpha=0.7)
        ax.text(1.52, ax.get_ylim()[1] - 0.02, 'Fixed α=1.5', fontsize=9, color='#F39C12')

        ax.set_title(f'True α = {true_alpha}', fontsize=12)
        ax.set_xlabel(r'Fitted $\alpha$', fontsize=11)
        if idx == 0:
            ax.set_ylabel('Accuracy', fontsize=11)

        ax.set_xlim(0.95, 2.05)
        ax.grid(True, alpha=0.3)

        # A single legend on the first panel is enough; all panels share it.
        if idx == 0:
            ax.legend(loc='lower right', fontsize=9)

    plt.tight_layout()

    os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path, 'fig2_alpha_sensitivity.pdf'),
               bbox_inches='tight', dpi=300)
    fig.savefig(os.path.join(save_path, 'fig2_alpha_sensitivity.png'),
               bbox_inches='tight', dpi=150)
    plt.close(fig)

    print(f"Sensitivity figure saved to {save_path}")


def create_comparison_figure(df: pd.DataFrame, save_path: str):
    """
    Create Figure 3: Fixed vs estimated α comparison.

    Panel (a) is a grouped bar chart of mean accuracy (error bars: std over
    repeats) per true α for the four methods. Panel (b) plots, for each
    Stable-QDA variant, the difference between its mean accuracy and the
    Gaussian QDA mean accuracy, multiplied by 100.

    Args:
        df: One row per (true α, repeat) with columns ``true_alpha``,
            ``gaussian``, ``fixed_1.5``, ``estimated`` and ``oracle``.
        save_path: Directory (created if missing) that receives
            ``fig3_fixed_vs_estimated.pdf`` and ``fig3_fixed_vs_estimated.png``.
    """
    fig, (ax_acc, ax_gain) = plt.subplots(1, 2, figsize=(12, 4.5))

    alpha_levels = sorted(df['true_alpha'].unique())

    # (column, colour, legend label); the Gaussian baseline comes first.
    series_spec = [
        ('gaussian', '#E74C3C', 'Gaussian QDA'),
        ('fixed_1.5', '#F39C12', 'Fixed α=1.5'),
        ('estimated', '#9B59B6', 'Estimated α'),
        ('oracle', '#3498DB', 'Oracle (true α)'),
    ]

    # ---- Panel (a): Accuracy comparison --------------------------------
    positions = np.arange(len(alpha_levels))
    bar_width = 0.2
    by_alpha = df.groupby('true_alpha')

    for slot, (column, colour, legend_label) in enumerate(series_spec):
        ax_acc.bar(
            positions + slot * bar_width,
            by_alpha[column].mean(),
            bar_width,
            yerr=by_alpha[column].std(),
            label=legend_label,
            color=colour,
            alpha=0.8,
        )

    ax_acc.set_xlabel(r'True $\alpha$', fontsize=11)
    ax_acc.set_ylabel('Accuracy', fontsize=11)
    ax_acc.set_title('(a) Classification Accuracy', fontsize=12)
    # Tick sits in the middle of each group of four bars.
    ax_acc.set_xticks(positions + 1.5 * bar_width)
    ax_acc.set_xticklabels([f'{level:.1f}' for level in alpha_levels])
    ax_acc.legend(loc='lower right', fontsize=9)
    ax_acc.grid(True, alpha=0.3, axis='y')

    # ---- Panel (b): Improvement over Gaussian --------------------------
    rows_per_alpha = [df[df['true_alpha'] == level] for level in alpha_levels]

    for column, colour, legend_label in series_spec[1:]:
        gains = [
            (rows[column].mean() - rows['gaussian'].mean()) * 100
            for rows in rows_per_alpha
        ]
        ax_gain.plot(alpha_levels, gains, 'o-', color=colour, linewidth=2,
                     markersize=8, label=legend_label)

    ax_gain.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax_gain.set_xlabel(r'True $\alpha$', fontsize=11)
    ax_gain.set_ylabel('Improvement over Gaussian (%)', fontsize=11)
    ax_gain.set_title('(b) Improvement vs Gaussian QDA', fontsize=12)
    ax_gain.legend(loc='upper right', fontsize=9)
    ax_gain.grid(True, alpha=0.3)

    plt.tight_layout()

    os.makedirs(save_path, exist_ok=True)
    for extension, resolution in (('pdf', 300), ('png', 150)):
        target = os.path.join(save_path, 'fig3_fixed_vs_estimated.' + extension)
        fig.savefig(target, bbox_inches='tight', dpi=resolution)
    plt.close(fig)

    print(f"Comparison figure saved to {save_path}")


# =============================================================================
# Tables
# =============================================================================

def create_sensitivity_table(df: pd.DataFrame, save_path: str):
    """Create table showing accuracy for different fitted α values.

    Writes ``table_exp2_sensitivity.tex`` into ``save_path`` (created if
    missing). The table has one row per true α and one column per fitted α,
    holding the mean accuracy in percent, followed by a Gaussian QDA column.
    The best fitted α in each row is set in bold.

    Combinations with no data are rendered as ``--`` and are ignored when
    choosing the bold entry, so a missing cell is never marked as "best".

    Args:
        df: Long-format results with columns ``true_alpha``,
            ``fitted_alpha`` and ``accuracy``. Rows whose optional
            ``is_gaussian`` column equals True feed the Gaussian column
            instead of the fitted-α columns.
        save_path: Output directory for the ``.tex`` file.
    """
    # Boolean mask of Gaussian baseline rows; a missing flag (absent column
    # or NaN entry) counts as "not Gaussian".
    if 'is_gaussian' in df.columns:
        is_gauss = df['is_gaussian'].eq(True)
    else:
        is_gauss = pd.Series(False, index=df.index, dtype=bool)

    # Rows: true α, columns: fitted α, values: mean accuracy over repeats.
    pivot = (
        df[~is_gauss]
        .groupby(['true_alpha', 'fitted_alpha'])['accuracy']
        .mean()
        .unstack()
    )

    lines = [
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{Sensitivity analysis: accuracy (\\%) for different fitted $\\alpha$ values. ",
        "Bold indicates the best fitted $\\alpha$ for each true $\\alpha$.}",
        "\\label{tab:exp2_sensitivity}",
        "\\small",
        "\\begin{tabular}{c|" + "c" * len(pivot.columns) + "|c}",
        "\\toprule",
        # Header
        "True $\\alpha$ & " + " & ".join(f"{a:.1f}" for a in pivot.columns) + " & Gaussian \\\\",
        "\\midrule",
    ]

    # Rows
    for true_alpha in pivot.index:
        row_vals = pivot.loc[true_alpha].to_numpy(dtype=float) * 100
        missing = np.isnan(row_vals)

        # nanargmax skips empty cells; -1 means "nothing to highlight".
        best_idx = -1 if missing.all() else int(np.nanargmax(row_vals))

        cells = []
        for i, val in enumerate(row_vals):
            if missing[i]:
                cells.append("--")
            elif i == best_idx:
                cells.append(f"\\textbf{{{val:.1f}}}")
            else:
                cells.append(f"{val:.1f}")

        # Gaussian baseline for this true α
        gauss_vals = df.loc[is_gauss & (df['true_alpha'] == true_alpha), 'accuracy']
        gauss_cell = f"{gauss_vals.mean() * 100:.1f}" if len(gauss_vals) > 0 else "--"

        lines.append(f"{true_alpha:.1f} & " + " & ".join(cells) + f" & {gauss_cell} \\\\")

    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
    latex = "\n".join(lines) + "\n"

    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'table_exp2_sensitivity.tex'), 'w', encoding='utf-8') as f:
        f.write(latex)

    print(f"Sensitivity table saved to {save_path}")


# =============================================================================
# Main
# =============================================================================

def main():
    """Command-line entry point for Experiment 2.

    Parses ``--quick`` and ``--output``, runs both experiments, writes the
    raw results as CSV, generates the figures and the LaTeX table under the
    output directory, and prints the accuracy lost by fixing α=1.5 instead
    of using the true α (per true α, followed by the maximum).
    """
    parser = argparse.ArgumentParser(description='Experiment 2: Alpha Sensitivity')
    parser.add_argument('--quick', action='store_true', help='Quick test run')
    parser.add_argument('--output', type=str, default='outputs/', help='Output directory')
    args = parser.parse_args()

    # Work on a copy so the module-level presets are never mutated.
    preset = QUICK_CONFIG if args.quick else DEFAULT_CONFIG
    config = {**preset, 'save_path': args.output}

    # Run experiments
    df_sensitivity = run_sensitivity_analysis(config)
    df_comparison = run_estimation_comparison(config)

    # Save raw results
    os.makedirs(config['save_path'], exist_ok=True)
    df_sensitivity.to_csv(os.path.join(config['save_path'], 'exp2a_sensitivity.csv'), index=False)
    df_comparison.to_csv(os.path.join(config['save_path'], 'exp2b_comparison.csv'), index=False)

    # Generate outputs
    print_section("Generating Outputs")

    fig_path = os.path.join(config['save_path'], 'figures')
    table_path = os.path.join(config['save_path'], 'tables')

    create_sensitivity_figure(df_sensitivity, fig_path)
    create_comparison_figure(df_comparison, fig_path)
    create_sensitivity_table(df_sensitivity, table_path)

    # Print summary
    print_section("Summary")

    print("\nMax accuracy loss from using α=1.5 vs oracle:")
    losses = {}
    for true_alpha in sorted(df_comparison['true_alpha'].unique()):
        sub = df_comparison[df_comparison['true_alpha'] == true_alpha]
        fixed_acc = sub['fixed_1.5'].mean()
        oracle_acc = sub['oracle'].mean()
        # Positive value: the oracle beats fixed α=1.5 at this true α.
        loss = (oracle_acc - fixed_acc) * 100
        losses[true_alpha] = loss
        print(f"  True α={true_alpha}: {loss:+.2f}%")

    # Report the worst case announced by the heading above.
    if losses:
        worst_alpha = max(losses, key=losses.get)
        print(f"  Maximum: {losses[worst_alpha]:+.2f}% (at true α={worst_alpha})")

    print(f"\nAll results saved to {config['save_path']}")


if __name__ == "__main__":
    main()
