"""
Plot Eigenworm Classification Results

This script plots the classification accuracy of the KSigPDE, KSig RFSF-TRP, and PowerSigJax kernels
across window sizes (up to 1024 by default), with window size on a log-scale x-axis and accuracy on the y-axis.
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def plot_eigenworm_results(csv_file='eigenworms_results_old.csv', max_window_size=1024,
                           *, output_prefix='eigenworm_classification_results', show=True):
    """
    Plot classification accuracy versus window size for several kernel methods.

    Reads a results CSV and keeps the KSigPDE, KSig RFSF-TRP and PowerSigJax rows
    whose window size is at most ``max_window_size`` and whose accuracy is
    positive. It then draws one accuracy-vs-window-size line per kernel on a
    log-scale x-axis, saves the figure as PNG and SVG, optionally shows it, and
    prints per-kernel summary statistics.

    Args:
        csv_file: Path to the CSV file containing results. It must provide the
            columns ``kernel_name``, ``window_size`` and ``accuracy``.
        max_window_size: Maximum window size to include in the plot (default: 1024).
        output_prefix: Path prefix for the saved figures. ``.png`` and ``.svg``
            are appended (default: 'eigenworm_classification_results').
        show: If True (default), display the figure with ``plt.show()`` before
            it is closed.
    """
    # Read the CSV file
    df = pd.read_csv(csv_file)

    kernels_of_interest = ['KSigPDE', 'KSig RFSF-TRP', 'PowerSigJax']
    # Legend labels that differ from the raw kernel name in the CSV.
    display_names = {'PowerSigJax': 'PowerSig'}

    # Define colors and markers for each kernel
    colors = {
        'KSigPDE': '#1f77b4',          # Blue
        'KSig RFSF-TRP': '#ff7f0e',    # Orange
        'PowerSigJax': '#2ca02c',      # Green
    }
    markers = {
        'KSigPDE': 'o',
        'KSig RFSF-TRP': 's',
        'PowerSigJax': '^',
    }

    # Keep the kernels of interest within the window-size limit. Zero accuracy
    # marks failed runs (OOM errors). NaN accuracies are dropped as well,
    # because NaN > 0 is False.
    df_filtered = df[
        df['kernel_name'].isin(kernels_of_interest)
        & (df['window_size'] <= max_window_size)
        & (df['accuracy'] > 0)
    ]
    # plt.plot joins points in row order. Sort by window size so an unordered
    # CSV does not produce a zig-zagging line. A stable sort keeps repeated
    # window sizes in file order.
    df_filtered = df_filtered.sort_values('window_size', kind='mergesort')

    # Split by kernel once and reuse the groups for both the plot and the
    # summary. Kernels with no remaining rows are skipped.
    per_kernel = {}
    for kernel in kernels_of_interest:
        kernel_data = df_filtered[df_filtered['kernel_name'] == kernel]
        if not kernel_data.empty:
            per_kernel[kernel] = kernel_data

    # Create the plot
    fig = plt.figure(figsize=(10, 6))

    for kernel, kernel_data in per_kernel.items():
        plt.plot(
            kernel_data['window_size'],
            kernel_data['accuracy'],
            marker=markers[kernel],
            markersize=8,
            linewidth=2,
            label=display_names.get(kernel, kernel),
            color=colors[kernel],
            linestyle='-',
            alpha=0.8,
        )

    # Set x-axis to log scale
    plt.xscale('log')

    # Place a tick at every window size present in the data, labelled as a
    # plain integer instead of log-scale scientific notation.
    window_sizes = sorted(df_filtered['window_size'].unique())
    plt.xticks(window_sizes, [str(int(ws)) for ws in window_sizes])

    # Add grid
    plt.grid(True, which='both', linestyle='--', alpha=0.3)

    # Labels and title
    plt.xlabel('Window Size (log scale)', fontsize=12, fontweight='bold')
    plt.ylabel('Accuracy', fontsize=12, fontweight='bold')
    plt.title('Eigenworm Classification: Accuracy vs Window Size', fontsize=14, fontweight='bold')

    # Build a legend only if something was plotted. With no labelled artists,
    # matplotlib emits a "No artists with labels" warning.
    if per_kernel:
        plt.legend(loc='best', fontsize=10, framealpha=0.9)

    # Accuracy is a fraction, so fix the y-axis to the full [0, 1] range
    plt.ylim(0, 1)

    # Tight layout for better spacing
    plt.tight_layout()

    # Save the plot in both PNG and SVG formats
    output_file_png = f'{output_prefix}.png'
    output_file_svg = f'{output_prefix}.svg'
    fig.savefig(output_file_png, dpi=300, bbox_inches='tight')
    fig.savefig(output_file_svg, format='svg', bbox_inches='tight')
    print(f"Plot saved to: {output_file_png}")
    print(f"Plot saved to: {output_file_svg}")

    if show:
        plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Print summary statistics
    print("\n=== Summary Statistics ===")
    for kernel, kernel_data in per_kernel.items():
        accuracy = kernel_data['accuracy']
        # idxmax returns an index label, so look the row up with .loc
        best_window = kernel_data.loc[accuracy.idxmax(), 'window_size']
        # Convert to plain ints so NumPy 2 prints 2 rather than np.int64(2).
        sizes = [int(ws) for ws in sorted(kernel_data['window_size'].unique())]
        print(f"\n{kernel}:")
        print(f"  Number of data points: {len(kernel_data)}")
        print(f"  Window sizes: {sizes}")
        print(f"  Accuracy range: {accuracy.min():.4f} - {accuracy.max():.4f}")
        print(f"  Best accuracy: {accuracy.max():.4f} at window size {best_window:.0f}")

if __name__ == '__main__':
    # Script entry point: plot the results stored in the default CSV,
    # restricted to window sizes of at most 1024.
    plot_eigenworm_results(
        'eigenworms_results_old.csv',
        max_window_size=1024,
    )

