"""
Create Individual PNG Visualizations for Cora Dataset Ablation Study
===================================================================

This script creates separate PNG files for each visualization,
focusing on the Cora dataset ablation study results.
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from pathlib import Path

def create_component_impact_plot():
    """Draw a bar chart of per-component accuracy impact on Cora.

    The ablation results CSV is read, and for every GraGR Core component the
    impact is computed as the test accuracy of the full model minus the test
    accuracy of the variant trained without that component. Bars with an
    impact above zero are green; all others are red. The figure is written to
    ``cora_component_impact_analysis.png`` in the visualizations directory.

    Raises:
        IndexError: if one of the required model names has no row in the CSV.
    """
    from matplotlib.patches import Patch

    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    def first_test_acc(model_name):
        # Test accuracy of the first row whose model_name matches exactly.
        matching = results[results['model_name'] == model_name]
        return matching['best_test_acc'].iloc[0]

    full_acc = first_test_acc('GraGR Core (Full)')

    # (bar label, name of the model trained without that component)
    ablations = (
        ('Conflict Detection', 'GraGR Core w/o Conflict'),
        ('Gradient Alignment', 'GraGR Core w/o Alignment'),
        ('Gradient Attention', 'GraGR Core w/o Attention'),
        ('Meta Modulation', 'GraGR Core w/o Meta'),
    )
    labels = []
    deltas = []
    for label, ablated_name in ablations:
        labels.append(label)
        deltas.append(full_acc - first_test_acc(ablated_name))

    plt.figure(figsize=(12, 8))

    # Strictly positive impact is green; zero or negative is red.
    bar_colors = []
    for delta in deltas:
        bar_colors.append('green' if delta > 0 else 'red')

    rects = plt.bar(labels, deltas, color=bar_colors, alpha=0.7,
                    edgecolor='black', linewidth=1.5)

    plt.title('GraGR Core Component Impact Analysis\n(Cora Dataset)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('GraGR Core Components', fontsize=12)
    plt.ylabel('Performance Impact (Test Accuracy)', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')

    # Zero reference line separating gains from losses.
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.5, linewidth=2)

    # Numeric label just above positive bars, just below the others.
    for rect, delta in zip(rects, deltas):
        bar_top = rect.get_height()
        if bar_top > 0:
            offset, anchor = 0.001, 'bottom'
        else:
            offset, anchor = -0.003, 'top'
        x_center = rect.get_x() + rect.get_width()/2
        plt.text(x_center, bar_top + offset, f'{delta:.3f}',
                 ha='center', va=anchor, fontsize=11, fontweight='bold')

    plt.legend(
        handles=[
            Patch(facecolor='green', label='Positive Impact'),
            Patch(facecolor='red', label='Negative Impact'),
        ],
        loc='upper right',
    )

    plt.tight_layout()

    # Write the figure to disk and release it.
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "cora_component_impact_analysis.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Component impact plot created: cora_component_impact_analysis.png")

def create_performance_ranking_plot():
    """Draw a horizontal bar chart ranking all models by test accuracy.

    Rows of the ablation results CSV are sorted by ``best_test_acc`` in
    ascending order and plotted as percentages. Bars are coloured by model
    family, decided from substrings of ``model_name``. The figure is written
    to ``cora_performance_ranking.png`` in the visualizations directory.
    """
    from matplotlib.patches import Patch

    baseline_color = '#FF6B6B'   # Red for baselines
    full_color = '#4ECDC4'       # Teal for full GraGR
    variant_color = '#45B7D1'    # Blue for GraGR variants

    def color_for(model_name):
        # Baseline check first, then the full model; anything else is a variant.
        if 'Baseline' in model_name:
            return baseline_color
        if 'GraGR Core (Full)' in model_name:
            return full_color
        return variant_color

    def short_name(model_name):
        # Compact label used on the y axis.
        return model_name.replace('GraGR Core ', 'Core ').replace(' w/o ', ' -')

    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")
    ranked = results.sort_values('best_test_acc', ascending=True)

    plt.figure(figsize=(14, 10))

    positions = np.arange(len(ranked))
    bar_colors = [color_for(name) for name in ranked['model_name']]
    acc_percent = ranked['best_test_acc'] * 100

    rects = plt.barh(positions, acc_percent, color=bar_colors, alpha=0.8,
                     edgecolor='black', linewidth=1)

    plt.title('Cora Dataset - Model Performance Ranking\n(Test Accuracy)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Test Accuracy (%)', fontsize=12)
    plt.ylabel('Model', fontsize=12)

    tick_labels = [short_name(name) for name in ranked['model_name']]
    plt.yticks(positions, tick_labels, fontsize=10)

    # Percentage label slightly to the right of each bar's end.
    for rect, percent in zip(rects, acc_percent):
        y_center = rect.get_y() + rect.get_height()/2
        plt.text(rect.get_width() + 0.2, y_center, f'{percent:.1f}%',
                 ha='left', va='center', fontsize=10, fontweight='bold')

    plt.legend(
        handles=[
            Patch(facecolor=baseline_color, label='Baseline Models'),
            Patch(facecolor=full_color, label='GraGR Core (Full)'),
            Patch(facecolor=variant_color, label='GraGR Core Variants'),
        ],
        loc='lower right',
    )

    plt.grid(True, alpha=0.3, axis='x')
    plt.tight_layout()

    # Write the figure to disk and release it.
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "cora_performance_ranking.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Performance ranking plot created: cora_performance_ranking.png")

def create_validation_vs_test_plot():
    """Scatter validation accuracy against test accuracy for every model.

    Rows whose ``model_name`` contains 'Baseline' and rows whose name contains
    'GraGR' are plotted as two separately coloured series (both in percent).
    A dashed y = x reference line spans the overall accuracy range, and every
    row of the CSV is annotated with a shortened model name. The figure is
    written to ``cora_validation_vs_test_accuracy.png``.
    """
    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    plt.figure(figsize=(12, 10))

    # (substring searched in model_name, marker colour, legend label)
    series_specs = (
        ('Baseline', '#FF6B6B', 'Baseline Models'),
        ('GraGR', '#4ECDC4', 'GraGR Models'),
    )
    # Select both subsets before drawing anything.
    series = [
        (results[results['model_name'].str.contains(pattern)], color, label)
        for pattern, color, label in series_specs
    ]
    for subset, color, label in series:
        plt.scatter(subset['best_val_acc'] * 100, subset['best_test_acc'] * 100,
                    c=color, s=150, alpha=0.8, label=label,
                    edgecolors='black', linewidth=1.5)

    # Reference diagonal from the smallest to the largest observed accuracy.
    val_acc = results['best_val_acc']
    test_acc = results['best_test_acc']
    low = min(val_acc.min(), test_acc.min()) * 100
    high = max(val_acc.max(), test_acc.max()) * 100
    plt.plot([low, high], [low, high], 'k--', alpha=0.5, linewidth=2, label='Perfect Correlation')

    # Label every point with a compact model name.
    for _, record in results.iterrows():
        short_name = record['model_name'].replace('GraGR Core ', 'Core ').replace(' w/o ', ' -')
        point = (record['best_val_acc'] * 100, record['best_test_acc'] * 100)
        plt.annotate(short_name, point, xytext=(5, 5),
                     textcoords='offset points', fontsize=9, alpha=0.8)

    plt.title('Cora Dataset - Validation vs Test Accuracy\n(All Models)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Validation Accuracy (%)', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    # Same scale on both axes so the diagonal sits at 45 degrees.
    plt.axis('equal')

    plt.tight_layout()

    # Write the figure to disk and release it.
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "cora_validation_vs_test_accuracy.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Validation vs test accuracy plot created: cora_validation_vs_test_accuracy.png")

def create_component_contribution_plot():
    """Create component contribution visualization for Cora dataset.

    Reads the ablation results CSV and, for each GraGR Core component,
    computes its contribution as the test accuracy of the full model minus
    the test accuracy of the variant trained without that component. The
    resulting bar chart is saved as
    ``cora_component_contribution_analysis.png``.

    Sign convention: a contribution is positive (green bar, "Positive
    Contribution") when removing the component lowers test accuracy, i.e. the
    component helps. This matches ``create_component_impact_plot``. The
    previous ``ablated - full`` formula inverted the sign, so helpful
    components were drawn red and labelled as negative contributions.

    Raises:
        IndexError: if one of the required model names has no row in the CSV.
    """
    from matplotlib.patches import Patch

    # Load results
    df = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    def test_acc(model_name):
        # Test accuracy of the first row whose model_name matches exactly.
        return df[df['model_name'] == model_name]['best_test_acc'].iloc[0]

    full_performance = test_acc('GraGR Core (Full)')

    # Contribution = accuracy lost when the component is removed.
    ablated_models = {
        'Conflict Detection': 'GraGR Core w/o Conflict',
        'Gradient Alignment': 'GraGR Core w/o Alignment',
        'Gradient Attention': 'GraGR Core w/o Attention',
        'Meta Modulation': 'GraGR Core w/o Meta',
    }
    contributions = {
        component: full_performance - test_acc(model_name)
        for component, model_name in ablated_models.items()
    }

    # Create the plot
    plt.figure(figsize=(12, 8))

    components = list(contributions.keys())
    contribs = list(contributions.values())

    # Green for a strictly positive contribution, red otherwise.
    colors = ['green' if contrib > 0 else 'red' for contrib in contribs]

    bars = plt.bar(components, contribs, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)

    plt.title('Individual Component Contribution Analysis\n(Cora Dataset)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('GraGR Core Components', fontsize=12)
    plt.ylabel('Contribution to Performance (Test Accuracy)', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')

    # Add horizontal line at zero
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.5, linewidth=2)

    # Value labels: just above positive bars, just below the others.
    for bar, contrib in zip(bars, contribs):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2,
                 height + (0.001 if height > 0 else -0.003),
                 f'{contrib:.3f}', ha='center', va='bottom' if height > 0 else 'top',
                 fontsize=11, fontweight='bold')

    # Add legend
    legend_elements = [Patch(facecolor='green', label='Positive Contribution'),
                       Patch(facecolor='red', label='Negative Contribution')]
    plt.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "cora_component_contribution_analysis.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Component contribution plot created: cora_component_contribution_analysis.png")

def create_efficiency_analysis_plot():
    """Create efficiency analysis visualization for Cora dataset.

    Reads the ablation results CSV and scatters ``training_time`` (x axis)
    against ``best_test_acc`` in percent (y axis), one point per row. Points
    are coloured by model family, decided from substrings of ``model_name``,
    and annotated with a shortened model name. The figure is saved as
    ``cora_efficiency_analysis.png``.

    The earlier version also computed an ``efficiency_score`` column
    (accuracy per second) that was never plotted or saved; that dead
    computation has been removed. The chart itself is unchanged.
    """
    from matplotlib.patches import Patch

    # Load results
    df = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")

    # Create the plot
    plt.figure(figsize=(12, 8))

    # Color points based on model type
    colors = []
    for model_name in df['model_name']:
        if 'Baseline' in model_name:
            colors.append('#FF6B6B')  # Red for baselines
        elif 'GraGR Core (Full)' in model_name:
            colors.append('#4ECDC4')  # Teal for full GraGR
        else:
            colors.append('#45B7D1')  # Blue for GraGR variants

    # Create scatter plot
    plt.scatter(df['training_time'], df['best_test_acc'] * 100,
                c=colors, s=200, alpha=0.8, edgecolors='black', linewidth=1.5)

    # Add model labels
    for _, row in df.iterrows():
        plt.annotate(row['model_name'].replace('GraGR Core ', 'Core ').replace(' w/o ', ' -'),
                     (row['training_time'], row['best_test_acc'] * 100),
                     xytext=(5, 5), textcoords='offset points', fontsize=9, alpha=0.8)

    plt.title('Cora Dataset - Training Time vs Performance\n(Efficiency Analysis)',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Training Time (seconds)', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.grid(True, alpha=0.3)

    # Add legend (colour patches, since the scatter itself carries no labels)
    legend_elements = [Patch(facecolor='#FF6B6B', label='Baseline Models'),
                       Patch(facecolor='#4ECDC4', label='GraGR Core (Full)'),
                       Patch(facecolor='#45B7D1', label='GraGR Core Variants')]
    plt.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()

    # Save plot
    output_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / "cora_efficiency_analysis.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Efficiency analysis plot created: cora_efficiency_analysis.png")

def create_ablation_summary_table():
    """Render the ablation results as a coloured table image.

    Rows of the results CSV are sorted by ``best_test_acc`` in descending
    order and shown with a shortened model name, validation accuracy, test
    accuracy (both to four decimals) and training time (two decimals, suffixed
    with 's'). Header cells are dark with white bold text; body rows are
    tinted by model family. The figure is written to
    ``cora_ablation_summary_table.png``.
    """
    results = pd.read_csv("GraGR_Research_Results/final_ablation_study/final_ablation_detailed_results.csv")
    ranked = results.sort_values('best_test_acc', ascending=False)

    # A bare axes with no frame or ticks hosts the table.
    _fig, ax = plt.subplots(figsize=(16, 10))
    ax.axis('tight')
    ax.axis('off')

    headers = ['Model', 'Val Accuracy', 'Test Accuracy', 'Training Time']
    n_cols = 4

    # One list of formatted strings per model.
    body = [
        [
            record['model_name'].replace('GraGR Core ', 'Core ').replace(' w/o ', ' -'),
            f"{record['best_val_acc']:.4f}",
            f"{record['best_test_acc']:.4f}",
            f"{record['training_time']:.2f}s",
        ]
        for _, record in ranked.iterrows()
    ]

    table = ax.table(cellText=body,
                     colLabels=headers,
                     cellLoc='center',
                     loc='center',
                     bbox=[0, 0, 1, 1])

    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1, 2)

    # Header row is row 0 of the table.
    for col in range(n_cols):
        header_cell = table[(0, col)]
        header_cell.set_facecolor('#40466e')
        header_cell.set_text_props(weight='bold', color='white')

    # Body rows start at table row 1; tint each by model family.
    for table_row, cells in enumerate(body, start=1):
        shown_name = cells[0]
        if 'Baseline' in shown_name:
            tint = '#FFE6E6'  # Light red
        elif 'Core (Full)' in shown_name:
            tint = '#E6F7FF'  # Light teal
        else:
            tint = '#E6F3FF'  # Light blue
        for col in range(n_cols):
            table[(table_row, col)].set_facecolor(tint)

    plt.title('Cora Dataset - Ablation Study Results Summary',
              fontsize=18, fontweight='bold', pad=20)

    plt.tight_layout()

    # Write the figure to disk and release it.
    target_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(target_dir / "cora_ablation_summary_table.png", dpi=300, bbox_inches='tight')
    plt.close()

    print("✓ Ablation summary table created: cora_ablation_summary_table.png")

def main():
    """Run every Cora ablation plot generator in sequence.

    Ensures the visualizations directory exists, announces and runs each
    plotting function in turn, then prints the output directory and the list
    of generated file names.
    """
    banner_rule = "=" * 60
    print("🎨 Creating Cora Dataset Ablation Study Visualizations")
    print(banner_rule)

    # Make sure the destination exists before any plot is produced.
    plots_dir = Path("GraGR_Research_Results/final_ablation_study/visualizations")
    plots_dir.mkdir(parents=True, exist_ok=True)

    # (progress message, generator) pairs, executed in this order.
    steps = (
        ("\n📊 Creating component impact plot...", create_component_impact_plot),
        ("\n📈 Creating performance ranking plot...", create_performance_ranking_plot),
        ("\n🔄 Creating validation vs test accuracy plot...", create_validation_vs_test_plot),
        ("\n🔧 Creating component contribution plot...", create_component_contribution_plot),
        ("\n⚡ Creating efficiency analysis plot...", create_efficiency_analysis_plot),
        ("\n📋 Creating ablation summary table...", create_ablation_summary_table),
    )
    for announcement, generate in steps:
        print(announcement)
        generate()

    print("\n✅ All Cora dataset ablation study visualizations created!")
    print(f"📁 Plots saved in: {plots_dir}/")
    print("\n📊 Generated Files:")
    file_lines = (
        "  - cora_component_impact_analysis.png",
        "  - cora_performance_ranking.png",
        "  - cora_validation_vs_test_accuracy.png",
        "  - cora_component_contribution_analysis.png",
        "  - cora_efficiency_analysis.png",
        "  - cora_ablation_summary_table.png",
    )
    for line in file_lines:
        print(line)

# Generate all plots only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
