#!/usr/bin/env python3
"""
Generate visualization plots that match the actual results from the CSV data.
This script creates publication-ready plots for the research paper.
"""

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path

# Set style for publication-quality plots.
# Order matters here: the 'default' style sheet is applied first, then the
# seaborn palette and the font-size overrides are layered on top of it.
plt.style.use('default')
# Make seaborn's "husl" palette the default colour cycle. The bar charts in
# this file pass explicit colours, so this only affects artists that do not.
sns.set_palette("husl")
plt.rcParams.update({
    'font.size': 12,          # base font size for text without its own setting
    'axes.titlesize': 14,     # per-axes titles
    'axes.labelsize': 12,     # x/y axis labels
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16    # figure-level suptitle
})

def load_and_clean_data(csv_path='../leaderboards/all_results.csv'):
    """Load the results CSV and normalise model and dataset display names.

    Args:
        csv_path: Location of the results CSV. The default is the path this
            script has always read; being relative, it is resolved against
            the current working directory.

    Returns:
        A DataFrame holding the CSV contents, in which the known upper-case
        ``Model`` and ``Dataset`` identifiers are replaced by their
        presentation spelling. Values without a mapping are kept unchanged.
    """
    data = pd.read_csv(csv_path)

    # Raw identifier -> spelling used in the figures.
    model_name_mapping = {
        'GRAPHSAGE': 'GraphSAGE',
        'MLP': 'MLP',
        'XGBOOST': 'XGBoost',
        'GATV2': 'GATv2',
        'RANDOMFOREST': 'RandomForest',
        'LOGISTIC': 'Logistic',
        'GIN': 'GIN',
    }
    dataset_mapping = {
        'CISCO': 'Cisco',
        'NSL_KDD': 'NSL-KDD',
    }

    # Series.map() yields NaN for values missing from the dict, which would
    # erase any model or dataset name not listed above. Fall back to the raw
    # value so such rows keep a usable label.
    data['Model'] = data['Model'].map(model_name_mapping).fillna(data['Model'])
    data['Dataset'] = data['Dataset'].map(dataset_mapping).fillna(data['Dataset'])

    return data

def create_nsl_kdd_ranking():
    """Plot NSL-KDD models as horizontal bars ranked by F1 and save the PNG.

    Reads the results via ``load_and_clean_data()``, keeps the rows whose
    ``Dataset`` is ``'NSL-KDD'``, sorts them by ``F1`` ascending and writes
    ``images/nsl_kdd_performance_ranking.png``.

    Raises:
        ValueError: If the results contain no NSL-KDD rows.
    """
    data = load_and_clean_data()
    nsl_data = data[data['Dataset'] == 'NSL-KDD'].sort_values('F1', ascending=True)

    if nsl_data.empty:
        # Fail before any figure is opened; otherwise the axis-limit
        # computation below would raise a far less helpful error.
        raise ValueError("No NSL-KDD rows found in the results data")

    # Allow this function to be called on its own, not only through main().
    Path('images').mkdir(exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        bars = ax.barh(nsl_data['Model'], nsl_data['F1'], color='skyblue', alpha=0.8)

        # Add value labels just past the end of each bar
        for bar, f1 in zip(bars, nsl_data['F1']):
            ax.text(f1 + 0.01, bar.get_y() + bar.get_height() / 2,
                    f'{f1:.3f}', va='center', ha='left', fontweight='bold')

        ax.set_xlabel('F1 Score')
        ax.set_ylabel('Model')
        ax.set_title('NSL-KDD Dataset: Model Performance Ranking by F1 Score')
        # Leave 15% headroom so the value labels fit inside the axes.
        # Series.max() skips NaN, unlike the builtin max().
        ax.set_xlim(0, nsl_data['F1'].max() * 1.15)
        ax.grid(axis='x', alpha=0.3)
        fig.tight_layout()

        fig.savefig('images/nsl_kdd_performance_ranking.png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def create_cisco_ranking():
    """Plot Cisco models as horizontal bars ranked by F1 and save the PNG.

    Reads the results via ``load_and_clean_data()``, keeps the rows whose
    ``Dataset`` is ``'Cisco'``, sorts them by ``F1`` ascending and writes
    ``images/cisco_performance_ranking.png``.

    Raises:
        ValueError: If the results contain no Cisco rows.
    """
    data = load_and_clean_data()
    cisco_data = data[data['Dataset'] == 'Cisco'].sort_values('F1', ascending=True)

    if cisco_data.empty:
        # Fail before any figure is opened; otherwise the axis-limit
        # computation below would raise a far less helpful error.
        raise ValueError("No Cisco rows found in the results data")

    # Allow this function to be called on its own, not only through main().
    Path('images').mkdir(exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        bars = ax.barh(cisco_data['Model'], cisco_data['F1'], color='lightcoral', alpha=0.8)

        # Add value labels just past the end of each bar
        for bar, f1 in zip(bars, cisco_data['F1']):
            ax.text(f1 + 0.02, bar.get_y() + bar.get_height() / 2,
                    f'{f1:.3f}', va='center', ha='left', fontweight='bold')

        ax.set_xlabel('F1 Score')
        ax.set_ylabel('Model')
        ax.set_title('Cisco Dataset: Model Performance Ranking by F1 Score')
        # Leave 15% headroom so the value labels fit inside the axes.
        # Series.max() skips NaN, unlike the builtin max().
        ax.set_xlim(0, cisco_data['F1'].max() * 1.15)
        ax.grid(axis='x', alpha=0.3)
        fig.tight_layout()

        fig.savefig('images/cisco_performance_ranking.png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def create_summary_comparison():
    """Draw a 2x2 grid of F1 and accuracy bar charts for both datasets.

    Top row: F1 per model for NSL-KDD (left) and Cisco (right).
    Bottom row: accuracy per model for the same two datasets.
    Within each dataset the models are ordered by descending F1 in both
    rows. The figure is written to ``images/summary_comparison.png``.
    """
    results = load_and_clean_data()

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # One frame per dataset, best F1 first; both rows reuse this ordering.
    ranked = {
        dataset: results[results['Dataset'] == dataset].sort_values('F1', ascending=False)
        for dataset in ('NSL-KDD', 'Cisco')
    }

    # (axes, dataset, column, y-axis label, title, extra bar keyword arguments)
    panels = [
        (axes[0, 0], 'NSL-KDD', 'F1', 'F1 Score', 'NSL-KDD F1 Performance',
         {'color': 'skyblue', 'label': 'NSL-KDD'}),
        (axes[0, 1], 'Cisco', 'F1', 'F1 Score', 'Cisco F1 Performance',
         {'color': 'lightcoral', 'label': 'Cisco'}),
        (axes[1, 0], 'NSL-KDD', 'Accuracy', 'Accuracy', 'NSL-KDD Accuracy',
         {'color': 'lightgreen'}),
        (axes[1, 1], 'Cisco', 'Accuracy', 'Accuracy', 'Cisco Accuracy',
         {'color': 'gold'}),
    ]

    for axis, dataset, column, y_label, title, bar_options in panels:
        frame = ranked[dataset]
        positions = range(len(frame))

        axis.bar(positions, frame[column], alpha=0.7, **bar_options)
        axis.set_xticks(positions)
        axis.set_xticklabels(frame['Model'], rotation=45, ha='right')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.grid(axis='y', alpha=0.3)

    fig.suptitle('Cross-Dataset Performance Comparison', fontsize=16, fontweight='bold')
    fig.tight_layout()

    fig.savefig('images/summary_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

def create_performance_heatmap():
    """Render side-by-side metric-by-model heatmaps for NSL-KDD and Cisco.

    Each heatmap has one row per metric (F1, Accuracy, Precision, Recall)
    and one column per model, annotated with three decimals. The combined
    figure is saved as ``images/nsl_kdd_class_f1_heatmap.png``.
    """
    results = load_and_clean_data()

    metric_columns = ['F1', 'Accuracy', 'Precision', 'Recall']

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # (dataset label, colour map, panel title), left to right
    panels = (
        ('NSL-KDD', 'Blues', 'NSL-KDD Performance Heatmap'),
        ('Cisco', 'Reds', 'Cisco Performance Heatmap'),
    )

    for axis, (dataset, colormap, title) in zip(axes, panels):
        subset = results[results['Dataset'] == dataset]

        # Transpose so metrics become rows and models become columns.
        score_grid = subset.set_index('Model')[metric_columns].T

        sns.heatmap(
            score_grid,
            annot=True,
            fmt='.3f',
            cmap=colormap,
            ax=axis,
            cbar_kws={'label': 'Performance Score'},
        )
        axis.set_title(title)
        axis.set_xlabel('Model')
        axis.set_ylabel('Metric')

    fig.tight_layout()
    fig.savefig('images/nsl_kdd_class_f1_heatmap.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

def main():
    """Create the output directory, build every plot, and list the files."""
    # savefig() does not create missing directories, so do it up front.
    Path('images').mkdir(exist_ok=True)

    # Progress message paired with the function that builds that plot.
    steps = (
        ("Generating NSL-KDD performance ranking...", create_nsl_kdd_ranking),
        ("Generating Cisco performance ranking...", create_cisco_ranking),
        ("Generating summary comparison...", create_summary_comparison),
        ("Generating performance heatmap...", create_performance_heatmap),
    )
    for announcement, build_plot in steps:
        print(announcement)
        build_plot()

    print("All visualizations generated successfully!")
    print("Files saved in images/ directory:")
    for entry in (
        "- nsl_kdd_performance_ranking.png",
        "- cisco_performance_ranking.png",
        "- summary_comparison.png",
        "- nsl_kdd_class_f1_heatmap.png",
    ):
        print(entry)

# Generate the plots only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
