"""
Generate supplementary figures for the Appendix.

This script generates the following supplementary figures:
1. EM convergence plot
2. Goalie random effects visualization
3. Class imbalance ROC curves
4. Multimodal state diversity comparison
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
import os

# Create the output directory up front so every savefig() call can write into it.
os.makedirs('figures', exist_ok=True)

# Shared plot style. Matplotlib 3.6 renamed the bundled seaborn styles to
# "seaborn-v0_8-*", and plt.style.use raises OSError for an unknown name, so
# try the new name first and then the pre-3.6 one. If neither exists, the
# default Matplotlib style is kept (styling is best-effort, not essential).
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    try:
        plt.style.use(_style_name)
    except OSError:
        continue
    break

plt.rcParams.update({
    'font.family': 'sans-serif',
    # Arial is not installed everywhere. DejaVu Sans ships with Matplotlib and
    # acts as a fallback, which avoids a stream of "findfont" warnings.
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16,
})

# Seed NumPy's global RNG so the simulated data (and thus figures) are reproducible.
np.random.seed(42)

# Figure 1: EM Convergence Plot
def generate_em_convergence_plot():
    """Plot simulated EM convergence curves for three model variants.

    The curves are synthetic: each follows ``100 - a * exp(-r * t)`` plus
    Gaussian noise drawn from NumPy's global RNG, for iterations t = 1..20.
    A dashed horizontal line at y = 99 marks a convergence threshold.
    The figure is saved to ``figures/em_convergence.pdf``.
    """
    n_iterations = 20
    iterations = np.arange(1, n_iterations + 1)

    # NOTE(review): the y-axis is labelled "Log-Likelihood", but these
    # simulated values are positive and approach 100. They are illustrative,
    # not real log-likelihoods.

    # Standard HMM: fastest rate (0.3), least noise (sd 1).
    hmm_convergence = 100 - 90 * np.exp(-0.3 * iterations) + np.random.normal(0, 1, n_iterations)

    # Context-aware HMM: slower rate (0.25), more noise (sd 1.5).
    context_hmm_convergence = 100 - 85 * np.exp(-0.25 * iterations) + np.random.normal(0, 1.5, n_iterations)

    # HMM-GLM: slowest rate (0.2), most noise (sd 2).
    hmm_glm_convergence = 100 - 80 * np.exp(-0.2 * iterations) + np.random.normal(0, 2, n_iterations)

    plt.figure(figsize=(10, 6))

    plt.plot(iterations, hmm_convergence, 'o-', label='Standard HMM', linewidth=2)
    plt.plot(iterations, context_hmm_convergence, 's-', label='Context-aware HMM', linewidth=2)
    plt.plot(iterations, hmm_glm_convergence, '^-', label='HMM-GLM', linewidth=2)

    plt.xlabel('EM Iteration')
    plt.ylabel('Log-Likelihood')
    plt.title('EM Algorithm Convergence for Different Models')
    plt.grid(True, alpha=0.3)

    # Draw the threshold line BEFORE building the legend. plt.legend() only
    # collects artists that exist when it is called, so creating the legend
    # first would leave this labelled line out of it.
    plt.axhline(y=99, color='r', linestyle='--', alpha=0.7, label='Convergence Threshold')
    plt.legend()

    plt.tight_layout()
    plt.savefig('figures/em_convergence.pdf', format='pdf')
    plt.close()

    print("Generated EM convergence plot")

# Figure 2: Goalie Random Effects Visualization
def generate_goalie_random_effects():
    """Plot simulated goalie random effects, overall and split by shot type.

    Thirty synthetic goalies (IDs ``G01`` .. ``G30``) each get a base effect
    drawn from N(0, 0.5^2). Each shot-type effect is that base plus a further
    normal shift. The top panel shows the base effects and the bottom panel
    the per-shot-type effects, with goalies ordered by descending base effect
    in both. The figure is saved to ``figures/goalie_random_effects.pdf``.
    """
    n_goalies = 30
    labels = [f'G{k:02d}' for k in range(1, n_goalies + 1)]

    # (DataFrame column, legend label, mean shift, sd). Random draws are made
    # in this order, after the base effects.
    shot_types = (
        ('wrist_effect', 'Wrist Shot', 0.1, 0.2),
        ('slap_effect', 'Slap Shot', -0.1, 0.2),
        ('deflection_effect', 'Deflection', -0.2, 0.3),
    )

    base = np.random.normal(0, 0.5, n_goalies)
    columns = {'goalie_id': labels, 'base_effect': base}
    for column, _label, shift, spread in shot_types:
        columns[column] = base + np.random.normal(shift, spread, n_goalies)

    # Order goalies from strongest to weakest base effect.
    effects = pd.DataFrame(columns).sort_values('base_effect', ascending=False)

    def _finish_panel(title):
        """Apply the zero line, axis labels, title and tick rotation shared by both panels."""
        plt.axhline(y=0, color='r', linestyle='-', alpha=0.7)
        plt.xlabel('Goalie ID')
        plt.ylabel('Random Effect')
        plt.title(title)
        plt.xticks(rotation=45)

    plt.figure(figsize=(12, 8))

    # Top panel: one bar per goalie for the base effect.
    plt.subplot(2, 1, 1)
    sns.barplot(x='goalie_id', y='base_effect', data=effects, color='steelblue')
    _finish_panel('Goalie Base Random Effects')

    # Bottom panel: long-form data so seaborn can group bars by shot type.
    plt.subplot(2, 1, 2)
    long_form = effects.melt(
        id_vars=['goalie_id'],
        value_vars=[column for column, *_ in shot_types],
        var_name='shot_type',
        value_name='effect',
    )
    long_form['shot_type'] = long_form['shot_type'].map(
        {column: label for column, label, *_ in shot_types}
    )
    sns.barplot(x='goalie_id', y='effect', hue='shot_type', data=long_form)
    _finish_panel('Goalie Random Effects by Shot Type')
    plt.legend(title='Shot Type')

    plt.tight_layout()
    plt.savefig('figures/goalie_random_effects.pdf', format='pdf')
    plt.close()

    print("Generated goalie random effects visualization")

# Figure 3: Class Imbalance ROC Curves
def generate_class_imbalance_roc():
    """Plot ROC curves for simulated class-imbalance handling strategies.

    A synthetic binary outcome is drawn from a logistic model on two
    standard-normal features. The success probability is scaled by 0.2 so
    positives are rare. Four "strategies" are emulated by fixed logistic
    scores with slightly different coefficients. Their ROC curves on a 30%
    held-out split are saved to ``figures/class_imbalance_roc.pdf``.
    """

    def _logistic_score(features, w0, w1, bias):
        """Return sigmoid(w0 * x0 + w1 * x1 + bias) for each row of *features*."""
        return 1 / (1 + np.exp(-(w0 * features[:, 0] + w1 * features[:, 1] + bias)))

    n_samples = 1000

    # Two independent standard-normal features per sample.
    X = np.random.normal(0, 1, (n_samples, 2))

    # Ground-truth probabilities scaled by 0.2 to make positives rare
    # (roughly a 10% positive rate with these coefficients).
    true_probs = _logistic_score(X, 0.5, 0.8, -0.2)
    y = np.random.binomial(1, true_probs * 0.2)

    # Only the held-out portion is used. The "models" below are fixed
    # formulas, not fitted estimators, so the training split is discarded.
    _, X_test, _, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # (legend label, w0, w1, bias, scale). The scale factors mimic differently
    # calibrated probabilities. ROC/AUC depend only on the ranking of the
    # scores, so multiplying by a positive constant does not change the curves.
    strategies = [
        ('No Weighting', 0.4, 0.7, -0.3, 0.15),
        ('Basic Class Weighting', 0.45, 0.75, -0.25, 0.18),
        ('Context-Aware Weighting', 0.48, 0.78, -0.22, 0.19),
        ('Combined Weighting', 0.49, 0.79, -0.21, 0.195),
    ]

    plt.figure(figsize=(10, 8))

    for label, w0, w1, bias, scale in strategies:
        scores = _logistic_score(X_test, w0, w1, bias) * scale
        fpr, tpr, _ = roc_curve(y_test, scores)
        plt.plot(fpr, tpr, label=f'{label} (AUC = {auc(fpr, tpr):.3f})', linewidth=2)

    # Chance-level reference diagonal.
    plt.plot([0, 1], [0, 1], 'k--', label='Random Classifier')

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Different Class Imbalance Handling Strategies')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('figures/class_imbalance_roc.pdf', format='pdf')
    plt.close()

    print("Generated class imbalance ROC curves")

# Figure 4: Multimodal State Diversity
def generate_multimodal_state_diversity():
    """Plot simulated state-diversity scores by sport and data-modality mix.

    Each (sport, modality) score is a sport-specific base value plus an
    additive modality effect plus Gaussian noise (sd 0.03), clipped to [0, 1].
    The grouped bar chart is annotated with each bar's value and saved to
    ``figures/multimodal_state_diversity.pdf``.
    """
    sports = ['NBA', 'MLB', 'NHL']
    modalities = ['Spatiotemporal Only', 'Spatiotemporal + Biomechanical',
                  'Spatiotemporal + Physiological', 'All Modalities']

    # A private generator seeded with 42 reproduces exactly the draws that
    # np.random.seed(42) followed by np.random.normal would give. It does so
    # without resetting NumPy's global RNG as a hidden side effect.
    rng = np.random.RandomState(42)

    # Base diversity per sport (higher = more diverse states).
    base_diversity = {
        'NBA': 0.65,
        'MLB': 0.60,
        'NHL': 0.45
    }

    # Additive effect of each modality combination.
    modality_effects = {
        'Spatiotemporal Only': 0.0,
        'Spatiotemporal + Biomechanical': 0.1,
        'Spatiotemporal + Physiological': 0.08,
        'All Modalities': 0.15
    }

    data = []
    values = {}
    for sport in sports:
        for modality in modalities:
            diversity = base_diversity[sport] + modality_effects[modality]
            diversity += rng.normal(0, 0.03)
            # Keep the index inside its valid [0, 1] range.
            diversity = np.clip(diversity, 0, 1)

            data.append({
                'Sport': sport,
                'Modality': modality,
                'State Diversity': diversity
            })
            values[(sport, modality)] = diversity

    df = pd.DataFrame(data)

    plt.figure(figsize=(12, 8))
    sns.barplot(x='Sport', y='State Diversity', hue='Modality', data=df)

    plt.xlabel('Sport')
    plt.ylabel('State Diversity Index')
    plt.title('State Diversity with Different Multimodal Data Combinations')
    plt.legend(title='Data Modalities')
    plt.grid(True, alpha=0.3)

    # Place a value label above each bar. This assumes seaborn's default bar
    # width of 0.8 per group, split evenly among the hue levels and centred
    # on the group's tick. Bar j of n therefore sits at offset
    # (j - (n - 1) / 2) * (0.8 / n) from tick i.
    n_levels = len(modalities)
    bar_width = 0.8 / n_levels
    center = (n_levels - 1) / 2
    for i, sport in enumerate(sports):
        for j, modality in enumerate(modalities):
            diversity = values[(sport, modality)]
            plt.text(i + (j - center) * bar_width, diversity + 0.02, f'{diversity:.2f}',
                     ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.savefig('figures/multimodal_state_diversity.pdf', format='pdf')
    plt.close()

    print("Generated multimodal state diversity visualization")

# Generate all figures
if __name__ == "__main__":
    print("Generating supplementary figures...")
    generate_em_convergence_plot()
    generate_goalie_random_effects()
    generate_class_imbalance_roc()
    generate_multimodal_state_diversity()
    print("All supplementary figures generated successfully!")
