#!/usr/bin/env python3
"""
Plotting utilities for regression accuracy analysis.
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import logging

# Configure the root logger at INFO (a no-op if the root logger already has
# handlers) so this module's logger.info output is visible, then fetch the
# module-scoped logger used throughout the file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


# Column names assumed (in this order) when the CSV has no header row.
_EXPECTED_COLUMNS = ['kernel_name', 'num_timesteps', 'seed', 'mse', 'mape', 'condition_number']
# Metric columns checked for infinities; inf values mark OOM runs and are dropped.
_METRIC_COLUMNS = ['mse', 'mape', 'condition_number']
_MSE_YLABEL = 'Mean Squared Error (MSE)'
_MAPE_YLABEL = 'Mean Absolute Percentage Error (MAPE) %'


def _display_label(kernel):
    """Return the legend label for *kernel* ('PowerSigJax' is shown as 'PowerSig')."""
    return 'PowerSig' if kernel == 'PowerSigJax' else kernel


def _load_results(csv_filename):
    """Read the results CSV, assuming ``_EXPECTED_COLUMNS`` when no header is present.

    Returns the DataFrame, or None (after logging an error) if the file is
    missing or empty.
    """
    try:
        df = pd.read_csv(csv_filename)
        # A header-less file gets its first data row parsed as the header, so
        # the expected names are missing; re-read with explicit names.
        if not all(col in df.columns for col in _EXPECTED_COLUMNS):
            logger.info("No headers detected, using default column names")
            df = pd.read_csv(csv_filename, header=None, names=_EXPECTED_COLUMNS)
    except FileNotFoundError:
        logger.error(f"CSV file {csv_filename} not found!")
        return None
    except pd.errors.EmptyDataError:
        logger.error(f"CSV file {csv_filename} is empty!")
        return None
    return df


def _filter_results(df, max_timesteps):
    """Drop rows with +/-inf in any metric column or more than *max_timesteps* timesteps.

    NaN metrics are kept (as before); they simply leave gaps in the plots.
    """
    # isin works on both float and object dtypes, unlike np.isfinite.
    finite_rows = ~df[_METRIC_COLUMNS].isin([np.inf, -np.inf]).any(axis=1)
    df = df[finite_rows]
    return df[df['num_timesteps'] <= max_timesteps]


def _plot_kernel_lines(ax, data, kernels, color_map, metric, marker, markersize, agg=None):
    """Draw one line per kernel of *metric* against num_timesteps on *ax*.

    With ``agg=None`` the raw rows are plotted in timestep order; otherwise rows
    are grouped by num_timesteps and reduced with *agg* (e.g. 'mean', 'var').
    Kernels with no rows in *data* are skipped.
    """
    for kernel in kernels:
        kernel_data = data[data['kernel_name'] == kernel]
        if kernel_data.empty:
            continue
        if agg is None:
            kernel_data = kernel_data.sort_values('num_timesteps')
            x, y = kernel_data['num_timesteps'], kernel_data[metric]
        else:
            # groupby sorts by num_timesteps, so the line is drawn left-to-right.
            reduced = kernel_data.groupby('num_timesteps')[metric].agg(agg)
            x, y = reduced.index, reduced.values
        ax.plot(x, y, marker=marker, linewidth=2, markersize=markersize,
                label=_display_label(kernel), color=color_map[kernel], alpha=0.8)


def _style_axis(ax, ylabel, title, legend_fontsize):
    """Apply the shared labels, log y-scale, legend and grid to *ax*."""
    ax.set_xlabel('Number of Timesteps', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_yscale('log')
    ax.legend(fontsize=legend_fontsize)
    ax.grid(True, alpha=0.3)


def _save_figure(fig, stem):
    """Save *fig* as ``{stem}.png`` (300 dpi) and ``{stem}.svg``, then close it."""
    fig.tight_layout()
    fig.savefig(f'{stem}.png', dpi=300, bbox_inches='tight')
    fig.savefig(f'{stem}.svg', bbox_inches='tight')
    plt.close(fig)  # close this specific figure to free memory


def _log_best(kernel_data, metric, name, unit):
    """Log the smallest *metric* value of *kernel_data* and its timestep count."""
    values = kernel_data[metric].dropna()
    if values.empty:
        logger.info(f"  Best {name}: n/a (no valid values)")
        return
    best_idx = values.idxmin()
    logger.info(f"  Best {name}: {values[best_idx]:.6f}{unit} at "
                f"{kernel_data.loc[best_idx, 'num_timesteps']} timesteps")


def plot_csv_results(csv_filename="recurrence_reg.csv", max_timesteps=3000,
                     output_prefix="recurrence_accuracy"):
    """
    Plot regression accuracy results from a CSV file.

    Writes, relative to the current working directory:
      * ``{output_prefix}_seed_{seed}.png/.svg``: MSE and MAPE vs timesteps, one figure per seed;
      * ``{output_prefix}.png/.svg``: MSE and MAPE averaged across seeds;
      * ``{output_prefix}_mean_var.png/.svg``: 2x2 grid of per-timestep mean and
        sample variance across seeds.
    It then logs a summary at INFO level (row count, kernels, seeds, timestep
    range, and best MSE/MAPE per kernel).

    Rows whose mse, mape or condition_number is +/-inf (e.g. OOM runs) are
    dropped, as are rows with more than ``max_timesteps`` timesteps.

    Args:
        csv_filename: Path to the CSV file. Expected columns are kernel_name,
            num_timesteps, seed, mse, mape, condition_number; if the file has no
            header row, these names are assumed in that order.
        max_timesteps: Rows with num_timesteps above this value are excluded.
        output_prefix: Filename prefix for all saved figures.

    Returns:
        None. Returns early (after logging) if the file is missing or empty, or
        if no rows survive filtering.
    """
    df = _load_results(csv_filename)
    if df is None:
        return

    df = _filter_results(df, max_timesteps)
    if df.empty:
        logger.warning(f"No plottable rows in {csv_filename} after filtering; nothing to plot")
        return

    kernels = df['kernel_name'].unique()
    # tolist() yields plain Python scalars, so the logged seed list is readable.
    seeds = sorted(df['seed'].unique().tolist())

    # One stable color per kernel, shared by every figure.
    colors = plt.cm.Set1(np.linspace(0, 1, len(kernels)))
    color_map = dict(zip(kernels, colors))

    # Per-seed figures.
    for seed in seeds:
        seed_data = df[df['seed'] == seed]
        if seed_data.empty:
            continue  # e.g. a NaN seed never compares equal
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        _plot_kernel_lines(axes[0], seed_data, kernels, color_map, 'mse', 'o', 6)
        _style_axis(axes[0], _MSE_YLABEL, 'MSE vs Number of Timesteps', 10)
        _plot_kernel_lines(axes[1], seed_data, kernels, color_map, 'mape', 's', 6)
        _style_axis(axes[1], _MAPE_YLABEL, 'MAPE vs Number of Timesteps', 10)
        stem = f'{output_prefix}_seed_{seed}'
        _save_figure(fig, stem)
        logger.info(f"Saved accuracy plot for seed {seed}: {stem}.png and .svg")

    # Combined figure: per-timestep mean across seeds.
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    _plot_kernel_lines(axes[0], df, kernels, color_map, 'mse', 'o', 8, agg='mean')
    _style_axis(axes[0], _MSE_YLABEL, 'MSE vs Number of Timesteps (All Seeds)', 11)
    _plot_kernel_lines(axes[1], df, kernels, color_map, 'mape', 's', 8, agg='mean')
    _style_axis(axes[1], _MAPE_YLABEL, 'MAPE vs Number of Timesteps (All Seeds)', 11)
    _save_figure(fig, output_prefix)
    logger.info(f"Saved combined accuracy plot: {output_prefix}.png and .svg")

    # 2x2 figure: mean (left column) and sample variance (right column) across
    # seeds. Variance uses ddof=1, so it is NaN (not drawn) with a single seed.
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    _plot_kernel_lines(axes[0, 0], df, kernels, color_map, 'mse', 'o', 8, agg='mean')
    _style_axis(axes[0, 0], _MSE_YLABEL, 'MSE Mean Across Seeds', 10)
    _plot_kernel_lines(axes[0, 1], df, kernels, color_map, 'mse', 's', 8, agg='var')
    _style_axis(axes[0, 1], 'MSE Variance Across Seeds', 'MSE Variance Across Seeds', 10)
    _plot_kernel_lines(axes[1, 0], df, kernels, color_map, 'mape', 'o', 8, agg='mean')
    _style_axis(axes[1, 0], _MAPE_YLABEL, 'MAPE Mean Across Seeds', 10)
    _plot_kernel_lines(axes[1, 1], df, kernels, color_map, 'mape', 's', 8, agg='var')
    _style_axis(axes[1, 1], 'MAPE Variance Across Seeds', 'MAPE Variance Across Seeds', 10)
    mean_var_stem = f'{output_prefix}_mean_var'
    _save_figure(fig, mean_var_stem)
    logger.info(f"Saved mean/variance analysis plot: {mean_var_stem}.png and .svg")

    # Summary statistics.
    logger.info("\nSummary Statistics:")
    logger.info(f"Total data points: {len(df)}")
    logger.info(f"Kernels analyzed: {list(kernels)}")
    logger.info(f"Seeds analyzed: {seeds}")
    logger.info(f"Timestep range: {df['num_timesteps'].min()} - {df['num_timesteps'].max()}")

    for kernel in kernels:
        kernel_data = df[df['kernel_name'] == kernel]
        logger.info(f"\n{kernel}:")
        _log_best(kernel_data, 'mse', 'MSE', '')
        _log_best(kernel_data, 'mape', 'MAPE', '%')
        logger.info(f"  Data points: {len(kernel_data)}")


if __name__ == "__main__":
    # Command-line entry point: the optional first argument is the CSV path;
    # any further arguments are ignored.
    import sys

    csv_file = sys.argv[1] if len(sys.argv) > 1 else "recurrence_reg_backup2.csv"
    plot_csv_results(csv_file)