"""
Experiment 3: Robustness to Outlier Contamination
==================================================

Question: How does Stable-QDA perform under training data contamination?

Key finding: The stable likelihood's logarithmic form naturally limits
outlier influence. Stable-QDA with standard estimators shows zero
degradation at 20% contamination on Gaussian data.

Setup:
- d=10, n=500 per class
- Contamination rates: 0%, 5%, 10%, 15%, 20%
- Three outlier types: shift+scale, uniform, adversarial
- Two base distributions: Gaussian (α=2), Stable (α=1.8)

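Output (written under --output, default outputs/):
- exp3_contamination.csv: raw per-repeat accuracies
- figures/fig4_contamination.{pdf,png}: accuracy and degradation curves
  under shift+scale contamination
- figures/fig_outlier_types.pdf: supplementary comparison of the three
  outlier types (Gaussian base)
- tables/table_exp3_*.tex: LaTeX tables of accuracy at each contamination
  level (shift+scale outliers)

Usage:
    python exp3_contamination.py
    python exp3_contamination.py --quick
    python exp3_contamination.py --output results/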
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import argparse

from common import (
    make_homoscedastic_params,
    make_stable_qda_standard,
    make_stable_qda_robust,
    make_gaussian_qda,
    generate_balanced_stable_mixture,
    print_section,
    set_seed,
)


# =============================================================================
# Configuration
# =============================================================================

# Full-size run: 20 repeats over five contamination rates, on a Gaussian base
# (alpha=2.0) and a moderately heavy-tailed stable base (alpha=1.8).
DEFAULT_CONFIG = dict(
    n_per_class=500,
    d=10,
    n_repeats=20,
    contamination_rates=[0.0, 0.05, 0.10, 0.15, 0.20],
    base_alphas=[2.0, 1.8],
    separation=3.0,
    base_seed=42,
    save_path='outputs/',
)

# Smoke-test run (--quick): identical settings, fewer repeats and rates.
QUICK_CONFIG = dict(
    DEFAULT_CONFIG,
    n_repeats=5,
    contamination_rates=[0.0, 0.10, 0.20],
)


# =============================================================================
# Contamination Functions
# =============================================================================

def add_shift_scale_outliers(
    X: np.ndarray,
    y: np.ndarray,
    contamination_rate: float,
    shift_std: float = 5.0,
    scale_factor: float = 3.0,
    rng: np.random.Generator = None
) -> tuple:
    """
    Replace a fraction of training points with shift+scale outliers.

    Each selected row is replaced by::

        class_mean + shift_std * mean(class_std) * u + scale_factor * z * class_std

    where ``u`` is a random unit direction and ``z`` is standard normal noise.
    With the defaults the outlier centre lies 5 (average) class standard
    deviations from the class mean, and its per-feature spread is 3x the
    class standard deviation (i.e. 9x the variance). Class statistics are
    computed from the uncontaminated input ``X``.

    Args:
        X: Feature matrix of shape (n, d).
        y: Labels of shape (n,); labels are not modified.
        contamination_rate: Fraction in [0, 1] of rows to replace;
            ``int(n * contamination_rate)`` rows are replaced.
        shift_std: Shift magnitude in units of the mean class std.
        scale_factor: Multiplier on the class std for the outlier noise.
        rng: Random generator; a fresh unseeded one is used if None.

    Returns:
        ``(X_out, y_out)``: a contaminated copy of ``X`` and a copy of ``y``.

    Raises:
        ValueError: If ``contamination_rate`` is outside [0, 1].
    """
    if not 0.0 <= contamination_rate <= 1.0:
        raise ValueError(
            f"contamination_rate must be in [0, 1], got {contamination_rate}"
        )
    if rng is None:
        rng = np.random.default_rng()

    n = len(y)
    n_contaminate = int(n * contamination_rate)

    if n_contaminate == 0:
        return X.copy(), y.copy()

    X_out = X.copy()

    # Select points to contaminate
    contaminate_idx = rng.choice(n, n_contaminate, replace=False)

    # Class statistics come from the clean X and never change inside the loop,
    # so compute them once per affected class instead of once per outlier.
    class_stats = {}
    for label in np.unique(y[contaminate_idx]):
        class_rows = X[y == label]
        class_stats[label] = (class_rows.mean(axis=0), class_rows.std(axis=0))

    d = X.shape[1]
    for idx in contaminate_idx:
        class_mean, class_std = class_stats[y[idx]]

        # Two RNG draws per point in a fixed order (direction, then noise)
        # so seeded runs are reproducible.
        shift_direction = rng.standard_normal(d)
        shift_direction = shift_direction / np.linalg.norm(shift_direction)

        outlier = class_mean + shift_std * class_std.mean() * shift_direction
        outlier += scale_factor * rng.standard_normal(d) * class_std

        X_out[idx] = outlier

    return X_out, y.copy()


def add_uniform_outliers(
    X: np.ndarray,
    y: np.ndarray,
    contamination_rate: float,
    extension: float = 3.0,
    rng: np.random.Generator = None
) -> tuple:
    """
    Swap a random subset of rows for points drawn uniformly from a widened box.

    The box is the per-feature [min, max] of ``X`` grown by ``extension``
    times that feature's range on each side. ``int(len(y) * contamination_rate)``
    distinct rows are replaced, one uniform draw per row. Labels are untouched.

    Returns:
        ``(X_out, y)``: a contaminated copy of ``X`` and the labels. When no
        rows are replaced, both arrays are returned as copies.
    """
    rng = np.random.default_rng() if rng is None else rng

    num_rows = len(y)
    k = int(num_rows * contamination_rate)
    if not k:
        return X.copy(), y.copy()

    # Per-feature bounds of the clean data, widened on both sides.
    lo_clean, hi_clean = X.min(axis=0), X.max(axis=0)
    span = hi_clean - lo_clean
    low = lo_clean - extension * span
    high = hi_clean + extension * span

    picked = rng.choice(num_rows, k, replace=False)
    replacement_rows = [rng.uniform(low, high) for _ in picked]

    X_out = X.copy()
    X_out[picked] = replacement_rows
    return X_out, y


def add_adversarial_outliers(
    X: np.ndarray,
    y: np.ndarray,
    contamination_rate: float,
    rng: np.random.Generator = None,
    spread: float = 0.5,
) -> tuple:
    """
    Replace a fraction with adversarial outliers near the opposite class.

    Each selected row keeps its label but its features are redrawn as
    ``other_mean + spread * z * other_std``, where ``z`` is standard normal
    noise and ``other_mean``/``other_std`` are the clean statistics of the
    target class. For binary data the target is the opposite class. With more
    classes, the target is the first other label in sorted order.

    Args:
        X: Feature matrix of shape (n, d).
        y: Labels of shape (n,); at least two distinct labels are required
            when any row is contaminated.
        contamination_rate: Fraction in [0, 1] of rows to replace;
            ``int(n * contamination_rate)`` rows are replaced.
        rng: Random generator; a fresh unseeded one is used if None.
        spread: Noise scale relative to the target class std (default 0.5).

    Returns:
        ``(X_out, y_out)``: a contaminated copy of ``X`` and a copy of ``y``.

    Raises:
        ValueError: If ``contamination_rate`` is outside [0, 1], or if fewer
            than two classes are present while rows must be contaminated.
    """
    if not 0.0 <= contamination_rate <= 1.0:
        raise ValueError(
            f"contamination_rate must be in [0, 1], got {contamination_rate}"
        )
    if rng is None:
        rng = np.random.default_rng()

    n = len(y)
    n_contaminate = int(n * contamination_rate)

    if n_contaminate == 0:
        return X.copy(), y.copy()

    classes = np.unique(y)
    if len(classes) < 2:
        raise ValueError(
            "adversarial outliers need at least two classes, "
            f"got {len(classes)}"
        )

    X_out = X.copy()

    # Clean per-class statistics
    class_means = {c: X[y == c].mean(axis=0) for c in classes}
    class_stds = {c: X[y == c].std(axis=0) for c in classes}
    # Target class for each label, resolved once rather than per point.
    target_class = {c: next(o for o in classes if o != c) for c in classes}

    # Select points to contaminate
    contaminate_idx = rng.choice(n, n_contaminate, replace=False)

    d = X.shape[1]
    for idx in contaminate_idx:
        other_class = target_class[y[idx]]
        # Generate point near the target class center
        X_out[idx] = (
            class_means[other_class]
            + spread * rng.standard_normal(d) * class_stds[other_class]
        )

    return X_out, y.copy()


# Registry mapping outlier-type name -> contamination function. Its iteration
# order (shift_scale, uniform, adversarial) sets the experiment's loop order.
OUTLIER_FUNCTIONS = dict(
    shift_scale=add_shift_scale_outliers,
    uniform=add_uniform_outliers,
    adversarial=add_adversarial_outliers,
)


# =============================================================================
# Experiment
# =============================================================================

def run_contamination_experiment(config: dict) -> pd.DataFrame:
    """Run the contamination robustness experiment.

    For every base alpha in ``config['base_alphas']``, every outlier type in
    ``OUTLIER_FUNCTIONS`` and every rate in ``config['contamination_rates']``,
    the function repeats the following ``config['n_repeats']`` times:

    * draw a balanced dataset;
    * hold out a clean test split;
    * contaminate only the training features;
    * record ``clf.score`` on the test split for three classifiers.

    Optional config keys (defaults reproduce the original setup):
        'test_fraction': fraction of samples held out as a clean test set
            (default 0.2).
        'model_alpha': alpha passed to the Stable-QDA factories
            (default 1.5).

    Returns:
        One row per (base_alpha, outlier_type, contamination_rate, repeat).
        Columns 'gaussian', 'stable_standard' and 'stable_robust' hold the
        test score, or NaN when that classifier raised during fit/score. A
        warning is printed in that case.
    """
    print_section("Experiment 3: Contamination Robustness")

    test_fraction = config.get('test_fraction', 0.2)
    model_alpha = config.get('model_alpha', 1.5)

    results = []

    for base_alpha in config['base_alphas']:
        base_name = 'Gaussian' if base_alpha == 2.0 else f'Stable (α={base_alpha})'
        print(f"\nBase distribution: {base_name}")

        params = make_homoscedastic_params(
            d=config['d'],
            alpha=base_alpha,
            separation=config['separation']
        )

        for outlier_type, outlier_func in OUTLIER_FUNCTIONS.items():
            print(f"  Outlier type: {outlier_type}")

            for contamination_rate in config['contamination_rates']:
                for rep in range(config['n_repeats']):
                    # The seed ignores outlier_type and contamination_rate, so
                    # every (type, rate) cell of a given repeat sees the same
                    # split. Presumably the clean data is also the same
                    # (paired design), assuming the generator is deterministic
                    # given its seed.
                    # NOTE(review): the split/contamination RNG and the data
                    # generator get the same integer seed; if the helper also
                    # uses default_rng(seed), the two streams coincide. Confirm.
                    seed = config['base_seed'] + rep * 1000 + int(base_alpha * 100)
                    rng = np.random.default_rng(seed)

                    # Generate clean data
                    X, y = generate_balanced_stable_mixture(params, config['n_per_class'], seed=seed)

                    # Random train/test split; the test set stays clean.
                    n = len(y)
                    n_test = int(n * test_fraction)
                    perm = rng.permutation(n)
                    train_idx, test_idx = perm[n_test:], perm[:n_test]
                    X_train_clean, y_train = X[train_idx], y[train_idx]
                    X_test, y_test = X[test_idx], y[test_idx]

                    # Contaminate training features only; labels are kept.
                    X_train, _ = outlier_func(X_train_clean, y_train, contamination_rate, rng=rng)

                    row = {
                        'base_alpha': base_alpha,
                        'base_name': base_name,
                        'outlier_type': outlier_type,
                        'contamination_rate': contamination_rate,
                        'repeat': rep,
                    }

                    # Fresh estimators for every fit
                    classifiers = {
                        'gaussian': make_gaussian_qda(),
                        'stable_standard': make_stable_qda_standard(model_alpha),
                        'stable_robust': make_stable_qda_robust(model_alpha),
                    }

                    for name, clf in classifiers.items():
                        try:
                            clf.fit(X_train, y_train)
                            row[name] = clf.score(X_test, y_test)
                        except Exception as e:
                            # Best effort: one failing fit must not abort the
                            # whole sweep, but it should not go unnoticed.
                            print(
                                f"    WARNING: {name} failed (alpha={base_alpha}, "
                                f"{outlier_type}, rate={contamination_rate}, "
                                f"rep={rep}): {type(e).__name__}: {e}"
                            )
                            row[name] = np.nan

                    results.append(row)

    return pd.DataFrame(results)


# =============================================================================
# Analysis
# =============================================================================

def compute_degradation(
    df: pd.DataFrame,
    classifiers: list = None,
    baseline_rate: float = 0.0,
) -> pd.DataFrame:
    """Compute the accuracy drop from the clean baseline.

    For each (base_alpha, outlier_type, contamination_rate) group, the drop
    is measured against the mean accuracy at ``baseline_rate`` within the
    same base_alpha/outlier_type.

    Args:
        df: Per-repeat results with columns 'base_alpha', 'outlier_type',
            'contamination_rate' and one accuracy column (in [0, 1]) per
            classifier.
        classifiers: Accuracy columns to summarise. Defaults to
            ['gaussian', 'stable_standard', 'stable_robust'].
        baseline_rate: Contamination rate used as the clean reference.

    Returns:
        One row per group present in ``df``, with columns ``<clf>_drop``
        (baseline minus current mean accuracy, in percentage points) and
        ``<clf>_acc`` (current mean accuracy in %). If a group has no
        baseline rows, its drops are NaN.
    """
    if classifiers is None:
        classifiers = ['gaussian', 'stable_standard', 'stable_robust']

    degradation_results = []
    outlier_types = df['outlier_type'].unique()

    for base_alpha in df['base_alpha'].unique():
        # Filter by alpha once, not once per outlier type.
        alpha_df = df[df['base_alpha'] == base_alpha]

        for outlier_type in outlier_types:
            sub_df = alpha_df[alpha_df['outlier_type'] == outlier_type]

            # Clean baseline (NaN means if this rate was not run)
            clean_df = sub_df[sub_df['contamination_rate'] == baseline_rate]
            clean_accs = {clf: clean_df[clf].mean() for clf in classifiers}

            for contamination_rate in sub_df['contamination_rate'].unique():
                cont_df = sub_df[sub_df['contamination_rate'] == contamination_rate]

                row = {
                    'base_alpha': base_alpha,
                    'outlier_type': outlier_type,
                    'contamination_rate': contamination_rate,
                }

                for clf in classifiers:
                    current_acc = cont_df[clf].mean()
                    row[f'{clf}_drop'] = (clean_accs[clf] - current_acc) * 100
                    row[f'{clf}_acc'] = current_acc * 100

                degradation_results.append(row)

    return pd.DataFrame(degradation_results)


# =============================================================================
# Visualization
# =============================================================================

def create_contamination_figure(df: pd.DataFrame, save_path: str):
    """
    Create Figure 4: robustness to shift+scale contamination.

    The grid has one row per base distribution present in ``df``, ordered by
    decreasing alpha so the Gaussian base comes first. It has two columns:

    * left: mean accuracy (%) ± one std over repeats vs. contamination rate;
    * right: accuracy drop (percentage points) relative to the 0% baseline,
      computed with ``compute_degradation``.

    Only rows with ``outlier_type == 'shift_scale'`` are used. The figure is
    saved as ``fig4_contamination.pdf`` (300 dpi) and
    ``fig4_contamination.png`` (150 dpi) in ``save_path``.
    """
    classifiers = ['gaussian', 'stable_standard', 'stable_robust']
    colors = {
        'gaussian': '#E74C3C',
        'stable_standard': '#F39C12',
        'stable_robust': '#27AE60',
    }
    labels = {
        'gaussian': 'Gaussian QDA',
        'stable_standard': 'Stable (mean+LW)',
        'stable_robust': 'Stable (smed+Tyler)',
    }

    # Filter to shift_scale outliers for main figure
    df_ss = df[df['outlier_type'] == 'shift_scale']
    if df_ss.empty:
        print("No shift_scale results; contamination figure not created")
        return

    # Rows come from the data rather than a hard-coded [2.0, 1.8], so the
    # panel count and titles match whatever base_alphas were run.
    base_alphas = sorted(df_ss['base_alpha'].unique(), reverse=True)
    n_rows = len(base_alphas)
    # Pad the x-axis by one point beyond the observed rate range.
    x_max = df_ss['contamination_rate'].max() * 100 + 1

    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4.5 * n_rows), squeeze=False)

    for row_idx, base_alpha in enumerate(base_alphas):
        base_name = 'Gaussian' if base_alpha == 2.0 else f'Stable (α={base_alpha})'
        sub_df = df_ss[df_ss['base_alpha'] == base_alpha]

        # Left panel: accuracy (mean ± std over repeats)
        ax = axes[row_idx, 0]

        summary = sub_df.groupby('contamination_rate')[classifiers].agg(['mean', 'std'])
        contamination_rates = summary.index.values * 100

        for clf in classifiers:
            means = summary[(clf, 'mean')].values * 100
            stds = summary[(clf, 'std')].values * 100

            ax.errorbar(contamination_rates, means, yerr=stds, fmt='o-',
                        color=colors[clf], label=labels[clf], linewidth=2,
                        markersize=6, capsize=3)

        ax.set_xlabel('Contamination Rate (%)', fontsize=11)
        ax.set_ylabel('Accuracy (%)', fontsize=11)
        ax.set_title(f'({chr(97 + row_idx * 2)}) {base_name} Base: Accuracy', fontsize=12)
        ax.legend(loc='lower left', fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(-1, x_max)

        # Right panel: drop from the clean baseline. sub_df holds a single
        # base_alpha/outlier_type, so there is exactly one row per rate.
        ax = axes[row_idx, 1]

        degradation = compute_degradation(sub_df).sort_values('contamination_rate')
        rates_pct = degradation['contamination_rate'].values * 100

        for clf in classifiers:
            ax.plot(rates_pct, degradation[f'{clf}_drop'].values, 'o-',
                    color=colors[clf], label=labels[clf], linewidth=2, markersize=6)

        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
        ax.set_xlabel('Contamination Rate (%)', fontsize=11)
        ax.set_ylabel('Accuracy Drop (%)', fontsize=11)
        ax.set_title(f'({chr(98 + row_idx * 2)}) {base_name} Base: Degradation', fontsize=12)
        ax.legend(loc='upper left', fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(-1, x_max)

    plt.tight_layout()

    os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path, 'fig4_contamination.pdf'),
                bbox_inches='tight', dpi=300)
    fig.savefig(os.path.join(save_path, 'fig4_contamination.png'),
                bbox_inches='tight', dpi=150)
    plt.close(fig)

    print(f"Contamination figure saved to {save_path}")


def create_outlier_comparison_figure(df: pd.DataFrame, save_path: str):
    """Plot accuracy vs. contamination rate, one panel per outlier type.

    Supplementary figure restricted to the Gaussian base
    (``base_alpha == 2.0``). Each of the three panels (shift+scale, uniform,
    adversarial) shows mean ± std accuracy over repeats for the three
    classifiers. The figure is saved as ``fig_outlier_types.pdf`` in
    ``save_path``.
    """
    clf_names = ['gaussian', 'stable_standard', 'stable_robust']
    palette = {
        'gaussian': '#E74C3C',
        'stable_standard': '#F39C12',
        'stable_robust': '#27AE60',
    }
    legend_text = {
        'gaussian': 'Gaussian QDA',
        'stable_standard': 'Stable (mean+LW)',
        'stable_robust': 'Stable (smed+Tyler)',
    }
    # Panel order follows this dict's insertion order.
    panel_titles = {
        'shift_scale': '(a) Shift+Scale',
        'uniform': '(b) Uniform',
        'adversarial': '(c) Adversarial',
    }

    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    gaussian_base = df.loc[df['base_alpha'] == 2.0]

    for panel_no, (ax, kind) in enumerate(zip(axes, panel_titles)):
        per_type = gaussian_base.loc[gaussian_base['outlier_type'] == kind]
        stats = per_type.groupby('contamination_rate')[clf_names].agg(['mean', 'std'])
        x_pct = stats.index.values * 100

        for clf in clf_names:
            ax.errorbar(
                x_pct,
                stats[(clf, 'mean')].values * 100,
                yerr=stats[(clf, 'std')].values * 100,
                fmt='o-',
                color=palette[clf],
                label=legend_text[clf],
                linewidth=2,
                markersize=6,
                capsize=3,
            )

        ax.set_xlabel('Contamination Rate (%)', fontsize=11)
        if panel_no == 0:
            ax.set_ylabel('Accuracy (%)', fontsize=11)
        ax.set_title(panel_titles[kind], fontsize=12)
        ax.legend(loc='lower left', fontsize=9)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(-1, 21)

    plt.tight_layout()

    os.makedirs(save_path, exist_ok=True)
    pdf_file = os.path.join(save_path, 'fig_outlier_types.pdf')
    fig.savefig(pdf_file, bbox_inches='tight', dpi=300)
    plt.close(fig)

    print(f"Outlier comparison figure saved to {save_path}")


# =============================================================================
# Tables
# =============================================================================

def create_contamination_table(df: pd.DataFrame, save_path: str):
    """
    Write one LaTeX accuracy table per base distribution (shift+scale outliers).

    Each table lists mean accuracy (%) of the three classifiers at every
    contamination rate. It also gives the signed change from the clean (0%)
    baseline for Gaussian QDA and Stable (mean+LW): ``$-$x`` means accuracy
    fell by x points, and ``$+$x`` means it rose.

    Files are written to ``save_path`` as ``table_exp3_gaussian.tex`` and
    ``table_exp3_stable.tex``. When several non-Gaussian base alphas are
    present, each gets its own ``table_exp3_stable_a<alpha>.tex`` (with ``.``
    written as ``p``) so the tables do not overwrite one another.
    """
    os.makedirs(save_path, exist_ok=True)

    def signed_change(drop: float) -> str:
        # A positive drop is an accuracy loss, so it is shown with a minus sign.
        return f"{'$+$' if drop < 0 else '$-$'}{abs(drop):.1f}"

    base_alphas = df['base_alpha'].unique()
    n_stable = sum(1 for a in base_alphas if a != 2.0)

    for base_alpha in base_alphas:
        if base_alpha == 2.0:
            base_name = 'gaussian'
            base_desc = 'Gaussian'
        else:
            base_name = 'stable'
            if n_stable > 1:
                base_name = f'stable_a{base_alpha:g}'.replace('.', 'p')
            # Math-mode alpha compiles under pdflatex; a raw Unicode α does not.
            base_desc = f'stable ($\\alpha={base_alpha:g}$)'

        sub_df = df[(df['base_alpha'] == base_alpha) & (df['outlier_type'] == 'shift_scale')]

        latex = f"""\\begin{{table}}[h]
\\centering
\\caption{{Accuracy (\\%) on {base_desc} base data with shift+scale outliers.}}
\\label{{tab:exp3_{base_name}}}
\\small
\\begin{{tabular}}{{c|ccc|cc}}
\\toprule
Contam. & Gaussian & Stable & Stable & \\multicolumn{{2}}{{c}}{{Drop from Clean}} \\\\
Rate & QDA & (mean+LW) & (smed+Tyler) & Gaussian & mean+LW \\\\
\\midrule
"""

        # Clean baseline
        clean_df = sub_df[sub_df['contamination_rate'] == 0.0]
        clean_gauss = clean_df['gaussian'].mean() * 100
        clean_stable = clean_df['stable_standard'].mean() * 100

        rows = []
        for cont_rate in sorted(sub_df['contamination_rate'].unique()):
            cont_df = sub_df[sub_df['contamination_rate'] == cont_rate]

            g_acc = cont_df['gaussian'].mean() * 100
            s_acc = cont_df['stable_standard'].mean() * 100
            r_acc = cont_df['stable_robust'].mean() * 100

            # ':g' rather than int() so e.g. 0.29 prints as 29, not 28.
            rows.append(
                f"{cont_rate * 100:g}\\% & {g_acc:.1f} & {s_acc:.1f} & {r_acc:.1f} & "
                f"{signed_change(clean_gauss - g_acc)} & "
                f"{signed_change(clean_stable - s_acc)} \\\\\n"
            )
        latex += "".join(rows)

        latex += """\\bottomrule
\\end{tabular}
\\end{table}
"""

        table_file = os.path.join(save_path, f'table_exp3_{base_name}.tex')
        with open(table_file, 'w', encoding='utf-8') as f:
            f.write(latex)

    print(f"Contamination tables saved to {save_path}")


# =============================================================================
# Main
# =============================================================================

def main():
    """Parse CLI arguments, run the experiment and write all outputs.

    Writes the following under ``--output``:

    * the raw results CSV;
    * figures, in ``figures/``;
    * LaTeX tables, in ``tables/``.

    It then prints the shift+scale accuracy drop at the highest configured
    contamination rate.
    """
    parser = argparse.ArgumentParser(description='Experiment 3: Contamination Robustness')
    parser.add_argument('--quick', action='store_true', help='Quick test run')
    parser.add_argument('--output', type=str, default='outputs/', help='Output directory')
    args = parser.parse_args()

    # Build a new dict so the module-level presets are never mutated
    # (matters when main() is called more than once, e.g. from a notebook).
    preset = QUICK_CONFIG if args.quick else DEFAULT_CONFIG
    config = {**preset, 'save_path': args.output}

    # Run experiment
    df = run_contamination_experiment(config)

    # Save raw results
    os.makedirs(config['save_path'], exist_ok=True)
    df.to_csv(os.path.join(config['save_path'], 'exp3_contamination.csv'), index=False)

    # Generate outputs
    print_section("Generating Outputs")

    fig_path = os.path.join(config['save_path'], 'figures')
    table_path = os.path.join(config['save_path'], 'tables')

    create_contamination_figure(df, fig_path)
    create_outlier_comparison_figure(df, fig_path)
    create_contamination_table(df, table_path)

    # Print summary
    print_section("Summary")

    degradation = compute_degradation(df)

    # Report at the largest configured rate instead of a hard-coded 0.20,
    # which would raise IndexError for configs that do not include it.
    max_rate = max(config['contamination_rates'])
    print(f"\nAccuracy drop at {max_rate * 100:g}% contamination (shift+scale outliers):")
    for base_alpha in df['base_alpha'].unique():
        base_name = 'Gaussian' if base_alpha == 2.0 else f'Stable (α={base_alpha})'
        print(f"\n{base_name} base:")

        deg_df = degradation[(degradation['base_alpha'] == base_alpha) &
                             (degradation['outlier_type'] == 'shift_scale') &
                             (degradation['contamination_rate'] == max_rate)]
        if deg_df.empty:
            print("  (no results)")
            continue

        for clf in ['gaussian', 'stable_standard', 'stable_robust']:
            drop = deg_df[f'{clf}_drop'].iloc[0]
            print(f"  {clf}: {drop:+.1f}%")

    print(f"\nAll results saved to {config['save_path']}")


if __name__ == "__main__":
    main()
