"""
Create Multi-Dataset Ablation Study Visualizations
=================================================

This script creates individual PNG files for multi-dataset ablation study results.
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from pathlib import Path

def create_dataset_comparison_plot():
    """Create the per-dataset test-accuracy comparison figure.

    Reads the detailed ablation results CSV and draws one horizontal bar
    chart per dataset, with models sorted by ``best_test_acc`` (plotted as a
    percentage). The figure is saved as ``multi_dataset_comparison.png`` in
    the visualizations directory, which is created if missing.

    Bars are coloured by model family: names containing ``'Baseline'``,
    names containing ``'GraGR Core (Full)'``, and everything else (treated
    as GraGR Core variants).
    """

    # Load results
    df = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    # Unique datasets, in order of first appearance in the CSV
    datasets = df['dataset'].unique()

    # One panel per dataset. plt.subplots creates the figure itself, so no
    # separate plt.figure() call is made (an extra figure would stay open,
    # because only this one is closed below). squeeze=False keeps `axes`
    # two-dimensional, so axes[0, i] is valid for any number of datasets.
    n_panels = max(len(datasets), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
    fig.suptitle('Multi-Dataset Ablation Study Results\n(Test Accuracy Comparison)',
                 fontsize=16, fontweight='bold', y=1.02)

    for i, dataset in enumerate(datasets):
        ax = axes[0, i]

        # Filter data for this dataset; ascending sort puts the best model on top
        dataset_data = df[df['dataset'] == dataset].copy()
        dataset_data = dataset_data.sort_values('best_test_acc', ascending=True)
        accuracy_pct = dataset_data['best_test_acc'] * 100

        # Color bars based on model type
        colors = []
        for model_name in dataset_data['model_name']:
            if 'Baseline' in model_name:
                colors.append('#FF6B6B')  # Red for baselines
            elif 'GraGR Core (Full)' in model_name:
                colors.append('#4ECDC4')  # Teal for full GraGR
            else:
                colors.append('#45B7D1')  # Blue for GraGR variants

        y_pos = np.arange(len(dataset_data))
        bars = ax.barh(y_pos, accuracy_pct, color=colors, alpha=0.8, edgecolor='black', linewidth=1)

        ax.set_title(f'{dataset.upper()} Dataset', fontweight='bold', fontsize=12)
        ax.set_xlabel('Test Accuracy (%)', fontsize=10)
        ax.set_ylabel('Model', fontsize=10)

        # Shorten model names so the y-axis labels fit
        model_labels = [name.replace('GraGR Core ', 'Core ').replace(' w/o ', ' -')
                        for name in dataset_data['model_name']]
        ax.set_yticks(y_pos)
        ax.set_yticklabels(model_labels, fontsize=8)

        # Add value labels just past the end of each bar
        for bar, acc in zip(bars, accuracy_pct):
            ax.text(bar.get_width() + 0.2, bar.get_y() + bar.get_height() / 2,
                    f'{acc:.1f}%', ha='left', va='center', fontsize=8, fontweight='bold')

        ax.grid(True, alpha=0.3, axis='x')

    # Add a single figure-level legend for the colour coding
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#FF6B6B', label='Baseline Models'),
                       Patch(facecolor='#4ECDC4', label='GraGR Core (Full)'),
                       Patch(facecolor='#45B7D1', label='GraGR Core Variants')]
    fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=3)

    fig.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir / "multi_dataset_comparison.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

    print("✓ Multi-dataset comparison plot created: multi_dataset_comparison.png")

def create_component_impact_heatmap():
    """Render the component-impact heatmap (components x datasets).

    Each cell is the ``best_test_acc`` of ``'GraGR Core (Full)'`` minus the
    ``best_test_acc`` of the model with that component removed, for one
    dataset. The figure is saved as ``component_impact_heatmap.png`` in the
    visualizations directory.

    Raises:
        IndexError: if a dataset has no row for the full model or for one of
            the ablated models.
    """

    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    # Component label -> name of the model trained without that component
    ablated_model_for = {
        'Conflict Detection': 'GraGR Core w/o Conflict',
        'Gradient Alignment': 'GraGR Core w/o Alignment',
        'Gradient Attention': 'GraGR Core w/o Attention',
        'Meta Modulation': 'GraGR Core w/o Meta',
    }
    components = list(ablated_model_for)
    datasets = results['dataset'].unique()

    def accuracy_of(frame, model):
        # First matching row's accuracy for the given model name
        return frame[frame['model_name'] == model]['best_test_acc'].iloc[0]

    impact_matrix = np.zeros((len(components), len(datasets)))
    for col, dataset in enumerate(datasets):
        subset = results[results['dataset'] == dataset]
        full_acc = accuracy_of(subset, 'GraGR Core (Full)')
        for row, component in enumerate(components):
            ablated_acc = accuracy_of(subset, ablated_model_for[component])
            impact_matrix[row, col] = full_acc - ablated_acc

    # Draw the heatmap; center=0 puts the colormap midpoint at "no impact"
    plt.figure(figsize=(10, 8))
    dataset_labels = [d.upper() for d in datasets]
    sns.heatmap(
        impact_matrix,
        xticklabels=dataset_labels,
        yticklabels=components,
        annot=True,
        fmt='.3f',
        cmap='RdYlGn',
        center=0,
        cbar_kws={'label': 'Performance Impact (Test Accuracy)'},
    )

    plt.title('Component Impact Heatmap\n(Full GraGR Core - w/o Component)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('GraGR Core Components', fontsize=12)
    plt.tight_layout()

    # Write the figure to disk
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "component_impact_heatmap.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Component impact heatmap created: component_impact_heatmap.png")

def create_best_models_plot():
    """Plot the single best model (highest ``best_test_acc``) per dataset.

    One vertical bar per dataset shows the winning model's test accuracy as
    a percentage. The accuracy is printed above the bar and the shortened
    model name is written vertically inside it. The figure is saved as
    ``best_models_by_dataset.png`` in the visualizations directory.
    """

    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    # Collect the top-scoring row of every dataset
    winners = []
    for name in results['dataset'].unique():
        subset = results[results['dataset'] == name]
        top = subset.loc[subset['best_test_acc'].idxmax()]
        winners.append({
            'Dataset': name.upper(),
            'Model': top['model_name'],
            'Test Accuracy': top['best_test_acc'] * 100
        })
    best_df = pd.DataFrame(winners)

    def family_color(model_name):
        # Red: baselines, teal: full GraGR Core, blue: any other model
        if 'Baseline' in model_name:
            return '#FF6B6B'
        if 'GraGR Core (Full)' in model_name:
            return '#4ECDC4'
        return '#45B7D1'

    plt.figure(figsize=(12, 8))
    bar_colors = [family_color(model_name) for model_name in best_df['Model']]
    bars = plt.bar(best_df['Dataset'], best_df['Test Accuracy'],
                   color=bar_colors, alpha=0.8, edgecolor='black', linewidth=1.5)

    plt.title('Best Performing Models by Dataset\n(Test Accuracy)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.xticks(rotation=0)
    plt.grid(True, alpha=0.3, axis='y')

    # Accuracy value just above each bar
    for rect, accuracy in zip(bars, best_df['Test Accuracy']):
        x_mid = rect.get_x() + rect.get_width() / 2
        plt.text(x_mid, rect.get_height() + 0.2, f'{accuracy:.1f}%',
                 ha='center', va='bottom', fontweight='bold', fontsize=11)

    # Shortened model name, rotated, centred inside each bar
    for rect, model in zip(bars, best_df['Model']):
        x_mid = rect.get_x() + rect.get_width() / 2
        short_name = model.replace('GraGR Core ', 'Core ').replace(' w/o ', ' -')
        plt.text(x_mid, rect.get_height() / 2, short_name,
                 ha='center', va='center', fontsize=9, rotation=90, color='white', fontweight='bold')

    # Legend explaining the colour coding
    from matplotlib.patches import Patch
    legend_spec = (
        ('#FF6B6B', 'Baseline Models'),
        ('#4ECDC4', 'GraGR Core (Full)'),
        ('#45B7D1', 'GraGR Core Variants'),
    )
    plt.legend(handles=[Patch(facecolor=color, label=label) for color, label in legend_spec],
               loc='upper right')

    plt.tight_layout()

    # Write the figure to disk
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "best_models_by_dataset.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Best models plot created: best_models_by_dataset.png")

def create_gragr_vs_baseline_plot():
    """Plot grouped bars comparing baselines against GraGR Core per dataset.

    For every dataset three values are shown, each as a percentage:
    the mean ``best_test_acc`` over models whose name contains
    ``'Baseline'``, the accuracy of ``'GraGR Core (Full)'``, and the highest
    accuracy among models whose name contains ``'GraGR Core w/o'``. The
    figure is saved as ``gragr_vs_baseline_comparison.png`` in the
    visualizations directory.
    """

    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    rows = []
    for name in results['dataset'].unique():
        subset = results[results['dataset'] == name]
        model_names = subset['model_name']

        is_baseline = model_names.str.contains('Baseline')
        is_full = model_names == 'GraGR Core (Full)'
        is_ablated = model_names.str.contains('GraGR Core w/o')

        baseline_mean = subset[is_baseline]['best_test_acc'].mean()
        full_acc = subset[is_full]['best_test_acc'].iloc[0]
        best_ablated = subset[is_ablated]['best_test_acc'].max()

        rows.append({
            'Dataset': name.upper(),
            'Baseline Average': baseline_mean * 100,
            'GraGR Core (Full)': full_acc * 100,
            'Best GraGR Variant': best_ablated * 100
        })

    comp_df = pd.DataFrame(rows)

    plt.figure(figsize=(12, 8))

    positions = np.arange(len(comp_df))
    width = 0.25

    # (horizontal offset, column / legend label, colour) for each bar series
    series_spec = (
        (-width, 'Baseline Average', '#FF6B6B'),
        (0, 'GraGR Core (Full)', '#4ECDC4'),
        (width, 'Best GraGR Variant', '#45B7D1'),
    )
    drawn = [
        plt.bar(positions + offset, comp_df[column], width,
                label=column, alpha=0.8, color=color, edgecolor='black')
        for offset, column, color in series_spec
    ]

    plt.title('GraGR vs Baseline Performance Comparison\n(Across All Datasets)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.xticks(positions, comp_df['Dataset'])
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3, axis='y')

    # Label every bar with its height, series by series
    for rect in (rect for container in drawn for rect in container):
        top = rect.get_height()
        plt.text(rect.get_x() + rect.get_width() / 2, top + 0.2,
                 f'{top:.1f}%', ha='center', va='bottom', fontweight='bold', fontsize=9)

    plt.tight_layout()

    # Write the figure to disk
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "gragr_vs_baseline_comparison.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ GraGR vs Baseline comparison plot created: gragr_vs_baseline_comparison.png")

def create_efficiency_comparison_plot():
    """Create the per-dataset efficiency comparison figure.

    Efficiency is computed per row as ``best_test_acc / training_time * 100``.
    One horizontal bar chart is drawn per dataset with models sorted by that
    score, and the figure is saved as
    ``efficiency_comparison_across_datasets.png`` in the visualizations
    directory, which is created if missing.

    NOTE(review): the figure title calls this "Performance per Second",
    which assumes ``training_time`` is in seconds — confirm against the
    code that writes the CSV.
    """

    # Load results
    df = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    # Calculate efficiency score (accuracy per unit of training time)
    df['efficiency_score'] = df['best_test_acc'] / df['training_time'] * 100

    # Unique datasets, in order of first appearance in the CSV
    datasets = df['dataset'].unique()

    # One panel per dataset. plt.subplots creates the figure itself, so no
    # separate plt.figure() call is made (an extra figure would stay open,
    # because only this one is closed below). squeeze=False keeps `axes`
    # two-dimensional, so axes[0, i] is valid for any number of datasets.
    n_panels = max(len(datasets), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
    fig.suptitle('Efficiency Analysis Across Datasets\n(Performance per Second)',
                 fontsize=16, fontweight='bold', y=1.02)

    for i, dataset in enumerate(datasets):
        ax = axes[0, i]

        # Filter data for this dataset; ascending sort puts the best model on top
        dataset_data = df[df['dataset'] == dataset].copy()
        dataset_data = dataset_data.sort_values('efficiency_score', ascending=True)

        # Color bars based on model type
        colors = []
        for model_name in dataset_data['model_name']:
            if 'Baseline' in model_name:
                colors.append('#FF6B6B')  # Red for baselines
            elif 'GraGR Core (Full)' in model_name:
                colors.append('#4ECDC4')  # Teal for full GraGR
            else:
                colors.append('#45B7D1')  # Blue for GraGR variants

        y_pos = np.arange(len(dataset_data))
        bars = ax.barh(y_pos, dataset_data['efficiency_score'], color=colors, alpha=0.8, edgecolor='black', linewidth=1)

        ax.set_title(f'{dataset.upper()} Dataset', fontweight='bold', fontsize=12)
        ax.set_xlabel('Efficiency Score', fontsize=10)
        ax.set_ylabel('Model', fontsize=10)

        # Shorten model names so the y-axis labels fit
        model_labels = [name.replace('GraGR Core ', 'Core ').replace(' w/o ', ' -')
                        for name in dataset_data['model_name']]
        ax.set_yticks(y_pos)
        ax.set_yticklabels(model_labels, fontsize=8)

        # Add value labels just past the end of each bar
        for bar, score in zip(bars, dataset_data['efficiency_score']):
            ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height() / 2,
                    f'{score:.1f}', ha='left', va='center', fontsize=8, fontweight='bold')

        ax.grid(True, alpha=0.3, axis='x')

    # Add a single figure-level legend for the colour coding
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#FF6B6B', label='Baseline Models'),
                       Patch(facecolor='#4ECDC4', label='GraGR Core (Full)'),
                       Patch(facecolor='#45B7D1', label='GraGR Core Variants')]
    fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=3)

    fig.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir / "efficiency_comparison_across_datasets.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

    print("✓ Efficiency comparison plot created: efficiency_comparison_across_datasets.png")

def main():
    """Generate every multi-dataset ablation figure and report the outputs.

    Ensures the visualizations directory exists, runs each plot builder in a
    fixed order (announcing each one first), then prints the output
    directory and the list of generated file names.
    """
    print("🎨 Creating Multi-Dataset Ablation Study Visualizations")
    print("=" * 60)

    # Make sure the destination exists before any builder runs
    output_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)

    # (announcement, builder) pairs, executed top to bottom
    steps = (
        ("\n📊 Creating multi-dataset comparison plot...", create_dataset_comparison_plot),
        ("\n🔥 Creating component impact heatmap...", create_component_impact_heatmap),
        ("\n🏆 Creating best models plot...", create_best_models_plot),
        ("\n⚔️ Creating GraGR vs Baseline comparison plot...", create_gragr_vs_baseline_plot),
        ("\n⚡ Creating efficiency comparison plot...", create_efficiency_comparison_plot),
    )
    for announcement, build_plot in steps:
        print(announcement)
        build_plot()

    summary_lines = (
        "\n✅ All multi-dataset ablation study visualizations created!",
        f"📁 Plots saved in: {output_dir}/",
        "\n📊 Generated Files:",
        "  - multi_dataset_comparison.png",
        "  - component_impact_heatmap.png",
        "  - best_models_by_dataset.png",
        "  - gragr_vs_baseline_comparison.png",
        "  - efficiency_comparison_across_datasets.png",
    )
    for line in summary_lines:
        print(line)

# Build all figures only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
