"""
Create Enhanced Ablation Study Visualizations
============================================

This script creates individual PNG files for the enhanced ablation study
with all 6 GraGR components (4 Core + 2 GraGR++ components).
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from pathlib import Path

def create_component_impact_plot():
    """Create the component impact bar chart for all 6 GraGR components.

    Reads ``enhanced_ablation_detailed_results.csv`` and, for each component,
    computes ``best_test_acc`` of ``'GraGR++ (Full)'`` minus ``best_test_acc``
    of the variant trained without that component.  The bar chart is written
    to ``enhanced_component_impact_analysis.png`` inside the
    ``visualizations`` directory, which is created if it does not exist.

    Raises:
        IndexError: If the results file has no row for the full model or for
            one of the ablated variants.
    """
    from matplotlib.patches import Patch

    # Load results
    results_path = "GraGR_Research_Results/enhanced_ablation_study/enhanced_ablation_detailed_results.csv"
    df = pd.read_csv(results_path)

    def accuracy_of(model_name):
        """Return ``best_test_acc`` of the first row whose name is *model_name*."""
        matches = df.loc[df['model_name'] == model_name, 'best_test_acc']
        if matches.empty:
            # Same exception type a bare ``.iloc[0]`` would raise, but the
            # message says which model is missing and from which file.
            raise IndexError(f"No row for model {model_name!r} in {results_path}")
        return matches.iloc[0]

    # Component label -> name of the ablated model trained without it
    ablated_models = {
        'Conflict Detection': 'GraGR++ w/o Conflict',
        'Gradient Alignment': 'GraGR++ w/o Alignment',
        'Gradient Attention': 'GraGR++ w/o Attention',
        'Meta Modulation': 'GraGR++ w/o Meta',
        'Multiple Pathways': 'GraGR++ w/o Multiple Pathways',
        'Adaptive Scheduling': 'GraGR++ w/o Adaptive Scheduling',
    }
    # The two GraGR++ components get their own colour pair in the chart
    plus_components = {'Multiple Pathways', 'Adaptive Scheduling'}

    # Calculate component impact (full model accuracy minus ablated accuracy)
    full_performance = accuracy_of('GraGR++ (Full)')
    component_impacts = {
        component: full_performance - accuracy_of(model_name)
        for component, model_name in ablated_models.items()
    }

    components = list(component_impacts)
    impacts = list(component_impacts.values())

    # Color bars based on component type and sign of the impact
    colors = []
    for comp, impact in zip(components, impacts):
        if comp in plus_components:
            positive, non_positive = 'green', 'red'
        else:
            positive, non_positive = 'blue', 'orange'
        colors.append(positive if impact > 0 else non_positive)

    # Create the plot
    plt.figure(figsize=(14, 8))
    bars = plt.bar(components, impacts, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)

    plt.title('Enhanced GraGR Component Impact Analysis\n(All 6 Components - PubMed Dataset)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('GraGR Components', fontsize=12)
    plt.ylabel('Performance Impact (Test Accuracy)', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')

    # Add horizontal line at zero
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.5, linewidth=2)

    # Add value labels: just above positive bars, just below the others
    for bar, impact in zip(bars, impacts):
        height = bar.get_height()
        above = height > 0
        plt.text(bar.get_x() + bar.get_width() / 2,
                 height + (0.001 if above else -0.003),
                 f'{impact:.3f}', ha='center', va='bottom' if above else 'top',
                 fontsize=11, fontweight='bold')

    # Add legend
    legend_elements = [Patch(facecolor='blue', label='Core Components (Positive)'),
                       Patch(facecolor='orange', label='Core Components (Negative)'),
                       Patch(facecolor='green', label='GraGR++ Components (Positive)'),
                       Patch(facecolor='red', label='GraGR++ Components (Negative)')]
    plt.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/enhanced_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "enhanced_component_impact_analysis.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Enhanced component impact plot created: enhanced_component_impact_analysis.png")

def create_performance_ranking_plot():
    """Draw a horizontal bar chart ranking every model by test accuracy.

    Models are read from ``enhanced_ablation_detailed_results.csv``, sorted by
    ``best_test_acc`` in ascending order, coloured by model family and
    labelled with their accuracy in percent.  The figure is saved as
    ``enhanced_performance_ranking.png`` in the ``visualizations`` directory.
    """
    from matplotlib.patches import Patch

    # (name fragment, bar colour, legend label); the first matching fragment wins
    families = (
        ('Baseline', '#FF6B6B', 'Baseline Models'),
        ('GraGR++ (Full)', '#4ECDC4', 'GraGR++ (Full)'),
        ('GraGR Core (Full)', '#45B7D1', 'GraGR Core (Full)'),
        ('GraGR++', '#90EE90', 'GraGR++ Variants'),
    )
    # Anything matching no fragment is drawn as a GraGR Core variant
    fallback_color, fallback_label = '#87CEEB', 'GraGR Core Variants'

    def pick_color(model_name):
        for fragment, color, _ in families:
            if fragment in model_name:
                return color
        return fallback_color

    def shorten(model_name):
        compact = model_name.replace('GraGR Core ', 'Core ')
        compact = compact.replace('GraGR++ ', '++ ')
        return compact.replace(' w/o ', ' -')

    results = pd.read_csv("GraGR_Research_Results/enhanced_ablation_study/enhanced_ablation_detailed_results.csv")
    ranked = results.sort_values('best_test_acc', ascending=True)
    accuracy_pct = ranked['best_test_acc'] * 100
    positions = np.arange(len(ranked))
    bar_colors = [pick_color(name) for name in ranked['model_name']]

    plt.figure(figsize=(16, 12))
    bars = plt.barh(positions, accuracy_pct, color=bar_colors, alpha=0.8, edgecolor='black', linewidth=1)

    plt.title('Enhanced Ablation Study - Model Performance Ranking\n(All 6 Components - PubMed Dataset)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Test Accuracy (%)', fontsize=12)
    plt.ylabel('Model', fontsize=12)

    # One shortened model name per bar on the y-axis
    plt.yticks(positions, [shorten(name) for name in ranked['model_name']], fontsize=10)

    # Print each accuracy just to the right of its bar
    for bar, acc in zip(bars, accuracy_pct):
        plt.text(bar.get_width() + 0.2, bar.get_y() + bar.get_height() / 2,
                 f'{acc:.1f}%', ha='left', va='center', fontsize=9, fontweight='bold')

    legend_elements = [Patch(facecolor=color, label=label) for _, color, label in families]
    legend_elements.append(Patch(facecolor=fallback_color, label=fallback_label))
    plt.legend(handles=legend_elements, loc='lower right')

    plt.grid(True, alpha=0.3, axis='x')
    plt.tight_layout()

    # Write the figure to disk
    target_dir = Path("GraGR_Research_Results/enhanced_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "enhanced_performance_ranking.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Enhanced performance ranking plot created: enhanced_performance_ranking.png")

def create_core_vs_plus_components_plot():
    """Create a side-by-side comparison of Core and GraGR++ component impacts.

    Reads ``enhanced_ablation_detailed_results.csv`` and computes, for every
    component, ``best_test_acc`` of ``'GraGR++ (Full)'`` minus
    ``best_test_acc`` of the variant trained without that component.  The four
    core components are drawn in the left panel and the two GraGR++ components
    in the right panel.  The figure is saved as
    ``core_vs_plus_components_analysis.png`` in the ``visualizations``
    directory, which is created if it does not exist.

    Raises:
        IndexError: If the results file has no row for the full model or for
            one of the ablated variants.
    """
    # Load results
    results_path = "GraGR_Research_Results/enhanced_ablation_study/enhanced_ablation_detailed_results.csv"
    df = pd.read_csv(results_path)

    def accuracy_of(model_name):
        """Return ``best_test_acc`` of the first row whose name is *model_name*."""
        matches = df.loc[df['model_name'] == model_name, 'best_test_acc']
        if matches.empty:
            # Same exception type a bare ``.iloc[0]`` would raise, but the
            # message says which model is missing and from which file.
            raise IndexError(f"No row for model {model_name!r} in {results_path}")
        return matches.iloc[0]

    # Component label -> name of the ablated model trained without it
    core_ablations = {
        'Conflict Detection': 'GraGR++ w/o Conflict',
        'Gradient Alignment': 'GraGR++ w/o Alignment',
        'Gradient Attention': 'GraGR++ w/o Attention',
        'Meta Modulation': 'GraGR++ w/o Meta',
    }
    plus_ablations = {
        'Multiple Pathways': 'GraGR++ w/o Multiple Pathways',
        'Adaptive Scheduling': 'GraGR++ w/o Adaptive Scheduling',
    }

    full_performance = accuracy_of('GraGR++ (Full)')

    core_components = list(core_ablations)
    plus_components = list(plus_ablations)
    core_impacts = [full_performance - accuracy_of(name) for name in core_ablations.values()]
    plus_impacts = [full_performance - accuracy_of(name) for name in plus_ablations.values()]

    # Create the plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    fig.suptitle('Core vs GraGR++ Components Impact Analysis\n(PubMed Dataset)',
                 fontsize=16, fontweight='bold', y=1.02)

    # (axes, panel title, component names, impacts, colour if > 0, colour otherwise)
    panels = (
        (ax1, 'GraGR Core Components (1-4)', core_components, core_impacts, 'blue', 'orange'),
        (ax2, 'GraGR++ Components (5-6)', plus_components, plus_impacts, 'green', 'red'),
    )

    for ax, panel_title, names, impacts, positive, non_positive in panels:
        bar_colors = [positive if impact > 0 else non_positive for impact in impacts]
        bars = ax.bar(names, impacts, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=1.5)
        ax.set_title(panel_title, fontweight='bold')
        ax.set_ylabel('Performance Impact (Test Accuracy)')
        ax.tick_params(axis='x', rotation=45)
        # Right-align the rotated labels so each one ends under its own bar
        # instead of being centred on it and drifting towards the neighbour.
        plt.setp(ax.get_xticklabels(), ha='right')
        ax.grid(True, alpha=0.3, axis='y')
        ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)

        # Add value labels: just above positive bars, just below the others
        for bar, impact in zip(bars, impacts):
            height = bar.get_height()
            above = height > 0
            ax.text(bar.get_x() + bar.get_width() / 2,
                    height + (0.001 if above else -0.003),
                    f'{impact:.3f}', ha='center', va='bottom' if above else 'top',
                    fontsize=10, fontweight='bold')

    plt.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/enhanced_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "core_vs_plus_components_analysis.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Core vs GraGR++ components plot created: core_vs_plus_components_analysis.png")

def create_efficiency_analysis_plot():
    """Create the training-time versus test-accuracy scatter plot.

    Reads ``enhanced_ablation_detailed_results.csv`` and draws one point per
    model at (``training_time``, ``best_test_acc`` in percent), coloured by
    model family and annotated with a shortened model name.  The figure is
    saved as ``enhanced_efficiency_analysis.png`` in the ``visualizations``
    directory, which is created if it does not exist.
    """
    from matplotlib.patches import Patch

    # Load results
    df = pd.read_csv("GraGR_Research_Results/enhanced_ablation_study/enhanced_ablation_detailed_results.csv")

    # (name fragment, point colour, legend label); the first matching fragment
    # wins, so the "(Full)" entries must stay ahead of the generic 'GraGR++'.
    families = (
        ('Baseline', '#FF6B6B', 'Baseline Models'),
        ('GraGR++ (Full)', '#4ECDC4', 'GraGR++ (Full)'),
        ('GraGR Core (Full)', '#45B7D1', 'GraGR Core (Full)'),
        ('GraGR++', '#90EE90', 'GraGR++ Variants'),
    )
    # Anything matching no fragment is drawn as a GraGR Core variant
    fallback_color, fallback_label = '#87CEEB', 'GraGR Core Variants'

    def color_for(model_name):
        """Return the colour of the first family whose fragment is in *model_name*."""
        for fragment, color, _ in families:
            if fragment in model_name:
                return color
        return fallback_color

    colors = [color_for(model_name) for model_name in df['model_name']]
    accuracy_pct = df['best_test_acc'] * 100

    # Create the plot
    plt.figure(figsize=(14, 10))
    plt.scatter(df['training_time'], accuracy_pct,
                c=colors, s=200, alpha=0.8, edgecolors='black', linewidth=1.5)

    # Add a shortened model name next to every point
    for model_name, seconds, accuracy in zip(df['model_name'], df['training_time'], accuracy_pct):
        label = model_name.replace('GraGR Core ', 'Core ').replace('GraGR++ ', '++ ').replace(' w/o ', ' -')
        plt.annotate(label, (seconds, accuracy),
                     xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.8)

    plt.title('Enhanced Ablation Study - Training Time vs Performance\n(Efficiency Analysis - PubMed Dataset)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Training Time (seconds)', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.grid(True, alpha=0.3)

    # Add legend (same order as the family table, fallback last)
    legend_elements = [Patch(facecolor=color, label=label) for _, color, label in families]
    legend_elements.append(Patch(facecolor=fallback_color, label=fallback_label))
    plt.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/enhanced_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "enhanced_efficiency_analysis.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Enhanced efficiency analysis plot created: enhanced_efficiency_analysis.png")

def create_component_heatmap():
    """Create the component importance heatmap.

    Reads ``enhanced_ablation_detailed_results.csv`` and builds a 6x1 matrix
    holding, for each component, ``best_test_acc`` of ``'GraGR++ (Full)'``
    minus ``best_test_acc`` of the variant trained without that component.
    The matrix is drawn as an annotated seaborn heatmap with a single
    ``'PubMed'`` column and saved as ``enhanced_component_heatmap.png`` in the
    ``visualizations`` directory, which is created if it does not exist.

    Raises:
        IndexError: If the results file has no row for the full model or for
            one of the ablated variants.
    """
    # Load results
    results_path = "GraGR_Research_Results/enhanced_ablation_study/enhanced_ablation_detailed_results.csv"
    df = pd.read_csv(results_path)

    def accuracy_of(model_name):
        """Return ``best_test_acc`` of the first row whose name is *model_name*."""
        matches = df.loc[df['model_name'] == model_name, 'best_test_acc']
        if matches.empty:
            # Same exception type a bare ``.iloc[0]`` would raise, but the
            # message says which model is missing and from which file.
            raise IndexError(f"No row for model {model_name!r} in {results_path}")
        return matches.iloc[0]

    # Component label -> name of the ablated model trained without it
    ablated_models = {
        'Conflict Detection': 'GraGR++ w/o Conflict',
        'Gradient Alignment': 'GraGR++ w/o Alignment',
        'Gradient Attention': 'GraGR++ w/o Attention',
        'Meta Modulation': 'GraGR++ w/o Meta',
        'Multiple Pathways': 'GraGR++ w/o Multiple Pathways',
        'Adaptive Scheduling': 'GraGR++ w/o Adaptive Scheduling',
    }
    components = list(ablated_models)

    full_performance = accuracy_of('GraGR++ (Full)')

    # One row per component, one column for the single dataset
    importance_matrix = np.array(
        [[full_performance - accuracy_of(model_name)] for model_name in ablated_models.values()],
        dtype=float,
    )

    # Create the plot
    plt.figure(figsize=(8, 10))

    sns.heatmap(importance_matrix,
                xticklabels=['PubMed'],
                yticklabels=components,
                annot=True,
                fmt='.3f',
                cmap='RdYlGn',
                center=0,
                cbar_kws={'label': 'Performance Impact (Test Accuracy)'})

    plt.title('Component Importance Heatmap\n(GraGR++ Full - w/o Component)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Dataset', fontsize=12)
    plt.ylabel('GraGR Components', fontsize=12)

    plt.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/enhanced_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "enhanced_component_heatmap.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Enhanced component heatmap created: enhanced_component_heatmap.png")

def main():
    """Build every enhanced ablation study figure and list the files written."""
    print("🎨 Creating Enhanced Ablation Study Visualizations")
    print("=" * 60)

    # Make sure the destination exists before any figure is produced
    output_dir = Path("GraGR_Research_Results/enhanced_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)

    # (progress banner, figure builder), run in this order
    steps = (
        ("\n📊 Creating enhanced component impact plot...", create_component_impact_plot),
        ("\n📈 Creating enhanced performance ranking plot...", create_performance_ranking_plot),
        ("\n🔧 Creating core vs GraGR++ components plot...", create_core_vs_plus_components_plot),
        ("\n⚡ Creating enhanced efficiency analysis plot...", create_efficiency_analysis_plot),
        ("\n🔥 Creating enhanced component heatmap...", create_component_heatmap),
    )
    for banner, build_figure in steps:
        print(banner)
        build_figure()

    print("\n✅ All enhanced ablation study visualizations created!")
    print(f"📁 Plots saved in: {output_dir}/")
    print("\n📊 Generated Files:")
    generated_files = (
        "enhanced_component_impact_analysis.png",
        "enhanced_performance_ranking.png",
        "core_vs_plus_components_analysis.png",
        "enhanced_efficiency_analysis.png",
        "enhanced_component_heatmap.png",
    )
    for filename in generated_files:
        print(f"  - {filename}")

# Generate all figures only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
