import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def plot_training_curves(file_path, output_path='traing_process.pdf'):
    """
    Plot training curves as mean lines with standard-deviation shadows.

    The JSON file at ``file_path`` is read as a mapping from optimizer name
    to a list of runs; every run is indexed by 'train_losses',
    'train_accuracies' and 'test_accuracies', each a sequence with one value
    per epoch. One panel is drawn per metric. Within a panel, the runs of an
    optimizer are truncated to their shortest length, then averaged, and the
    band mean ± std is shaded around the mean line. The y-axis of each panel
    is fitted to the finite range of the plotted bands plus a 5% margin.

    Args:
        file_path: Path of the JSON history file.
        output_path: Where the figure is saved. The default keeps the
            historical file name (including its spelling) so existing
            consumers of the output keep working.

    Side effects:
        Saves the figure to ``output_path`` and calls ``plt.show()``.
    """
    with open(file_path, 'r') as f:
        data = json.load(f)

    _, axes = plt.subplots(1, 3, figsize=(20, 5.5))

    # Colors and line styles for each optimizer; unknown optimizers fall back
    # to black / solid below.
    colors = {
        'SGD': '#1f77b4',
        'SGD_Nesterov': '#ff7f0e',
        'Adam': '#2ca02c',
        'Lion': '#d62728',
        'HomM': '#9467bd',
    }

    line_styles = {
        'SGD': '-',
        'SGD_Nesterov': '-',
        'Adam': '-.',
        'Lion': ':',
        'HomM': '-',
    }

    # (key inside each run, panel title, target axis)
    metrics = [
        ('train_losses', 'Training Loss', axes[0]),
        ('train_accuracies', 'Training Accuracy (%)', axes[1]),
        ('test_accuracies', 'Test Accuracy (%)', axes[2]),
    ]

    for metric_name, title, ax in metrics:
        ax.set_xlabel('Epoch', fontsize=20, fontweight='bold')
        ax.set_title(title, fontsize=20, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Extremes of the shaded bands, collected per optimizer
        band_lows = []
        band_highs = []

        for optimizer_name, runs in data.items():
            if not runs:  # Skip empty optimizers
                continue

            color = colors.get(optimizer_name, '#000000')
            line_style = line_styles.get(optimizer_name, '-')

            # Runs may differ in length: cut all of them to the shortest one
            # so they stack into a rectangular (runs, epochs) array.
            all_curves = [run[metric_name] for run in runs]
            min_length = min(len(curve) for curve in all_curves)
            curves_array = np.array(
                [curve[:min_length] for curve in all_curves], dtype=float
            )

            mean_curve = np.mean(curves_array, axis=0)
            std_curve = np.std(curves_array, axis=0)
            epochs = np.arange(1, len(mean_curve) + 1)

            lower_bound = mean_curve - std_curve
            upper_bound = mean_curve + std_curve

            # Only finite values may drive the axis limits: a NaN/Inf entry
            # would otherwise reach set_ylim, and an empty curve has no
            # minimum at all.
            band = np.concatenate([lower_bound, upper_bound])
            band = band[np.isfinite(band)]
            if band.size:
                band_lows.append(band.min())
                band_highs.append(band.max())

            # Plot mean line
            ax.plot(epochs, mean_curve, color=color, linewidth=2.5,
                    linestyle=line_style, label=f'{optimizer_name}', alpha=0.9)

            # Plot standard deviation shadow
            ax.fill_between(epochs, lower_bound, upper_bound,
                            color=color, alpha=0.15)

        # Fit the y-axis to the data with a 5% margin. A zero span is left
        # to matplotlib's autoscaling instead of passing identical limits.
        if band_lows and band_highs:
            data_min = min(band_lows)
            data_max = max(band_highs)
            if data_max > data_min:
                margin = (data_max - data_min) * 0.05
                ax.set_ylim(data_min - margin, data_max + margin)

        ax.legend(loc='best', fontsize=20)
        ax.tick_params(axis='both', labelsize=20)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()

# ------------------------------------------------------------
# ----  Plot performance difference with respect to HomM  ----
# ------------------------------------------------------------

    
def _difference_stats(ref_all, opt_all, lower_is_better):
    """Return (mean, std) per epoch of an optimizer's difference to the reference.

    Both inputs are (runs, epochs) arrays with the same number of epochs.
    With equal run counts the runs are differenced pairwise (run i against
    run i). Otherwise the two groups are treated as independent samples:
    the mean difference is the difference of the means and the spread is
    sqrt(var_opt + var_ref). The sign is flipped when lower values are
    better, so a positive result always favours the optimizer.
    """
    if opt_all.shape[0] == ref_all.shape[0]:
        diffs = opt_all - ref_all
        diff_mean = diffs.mean(axis=0)
        diff_std = diffs.std(axis=0)
    else:
        diff_mean = opt_all.mean(axis=0) - ref_all.mean(axis=0)
        diff_std = np.sqrt(opt_all.var(axis=0) + ref_all.var(axis=0))

    if lower_is_better:
        diff_mean = -diff_mean
    return diff_mean, diff_std


def plot_difference_analysis(file_path, reference_optimizer='HomM',
                             output_path='performance_difference.pdf'):
    """
    Plot every optimizer's difference from a reference optimizer.

    The JSON file at ``file_path`` is read as a mapping from optimizer name
    to a list of runs; every run is indexed by 'train_losses',
    'train_accuracies' and 'test_accuracies'. One panel is drawn per metric,
    showing the mean difference per epoch with a ± std shadow. Differences
    are oriented so that values above the zero line mean the optimizer has
    a lower loss, or a higher accuracy, than the reference.

    Histories are truncated to the shortest length shared by the reference
    and the optimizer being compared, so unequal run lengths do not break
    the subtraction.

    Args:
        file_path: Path of the JSON history file.
        reference_optimizer: Key of the optimizer all others are compared to.
        output_path: Where the figure is saved.

    Returns:
        None. If the reference optimizer is missing or has no runs, a
        message is printed and nothing is plotted.

    Side effects:
        Saves the figure to ``output_path`` and calls ``plt.show()``.
    """
    with open(file_path, 'r') as f:
        data = json.load(f)

    if reference_optimizer not in data or not data[reference_optimizer]:
        print(f"Reference optimizer {reference_optimizer} not found")
        return

    _, axes = plt.subplots(1, 3, figsize=(20, 6))

    colors = {
        'SGD': '#1f77b4',
        'SGD_Nesterov': '#ff7f0e',
        'Adam': '#2ca02c',
        'Lion': '#d62728',
        'HomM': '#9467bd',
    }

    # (key inside each run, panel title, target axis, lower_is_better)
    panels = [
        ('train_losses', 'Training Loss Difference', axes[0], True),
        ('train_accuracies', 'Training Accuracy Difference (%)', axes[1], False),
        ('test_accuracies', 'Test Accuracy Difference (%)', axes[2], False),
    ]

    ref_runs = data[reference_optimizer]
    # Shortest reference history over all metrics, so each stacked reference
    # array below is rectangular and shares one epoch range.
    ref_length = min(
        len(run[metric_name])
        for run in ref_runs
        for metric_name, _, _, _ in panels
    )

    for metric_name, title, ax, lower_is_better in panels:
        ax.set_title(title, fontsize=20, fontweight='bold')
        ax.grid(True, alpha=0.3)
        # Zero line = parity with the reference optimizer
        ax.axhline(y=0, color='black', linestyle='-', alpha=0.5, linewidth=1.5)

        ref_all = np.array(
            [run[metric_name][:ref_length] for run in ref_runs], dtype=float
        )

        for opt_name, runs in data.items():
            if not runs or opt_name == reference_optimizer:
                continue

            # Compare only over the epochs both sides actually have.
            length = min(ref_length, min(len(run[metric_name]) for run in runs))
            opt_all = np.array(
                [run[metric_name][:length] for run in runs], dtype=float
            )

            diff_mean, diff_std = _difference_stats(
                ref_all[:, :length], opt_all, lower_is_better
            )
            epochs = np.arange(1, length + 1)
            color = colors.get(opt_name, '#000000')

            # Plot mean line
            ax.plot(epochs, diff_mean, color=color,
                    linewidth=2.5, label=opt_name, alpha=0.9)

            # Plot standard deviation shadow
            ax.fill_between(epochs,
                            diff_mean - diff_std,
                            diff_mean + diff_std,
                            color=color, alpha=0.15)

        ax.set_xlabel('Epoch', fontsize=20, fontweight='bold')
        ax.set_ylabel('Difference', fontsize=20, fontweight='bold')
        ax.tick_params(axis='both', labelsize=20)
        ax.legend(loc='best', fontsize=20, frameon=True)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()





def print_summary_stats(data):
    """
    Print mean ± std of the end-of-training metrics for every optimizer.

    ``data`` maps optimizer name -> list of runs; each run is indexed by
    'train_losses', 'train_accuracies' and 'test_accuracies'. For each
    optimizer with at least one run, the last value of every metric and the
    maximum test accuracy are aggregated across its runs. Optimizers with
    no runs are skipped.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("FINAL PERFORMANCE SUMMARY")
    print(rule)

    last_value_keys = ('train_losses', 'train_accuracies', 'test_accuracies')

    for optimizer_name, runs in data.items():
        if runs:
            # Gather everything before printing so the section header only
            # appears once all values are available.
            last_values = {
                key: [run[key][-1] for run in runs] for key in last_value_keys
            }
            peak_test = [max(run['test_accuracies']) for run in runs]

            print(f"\n{optimizer_name}:")

            losses = last_values['train_losses']
            print(f"  Final Train Loss: {np.mean(losses):.4f} ± {np.std(losses):.4f}")

            # Accuracy rows share one format; labels carry their own padding.
            accuracy_rows = (
                ("Final Train Acc:  ", last_values['train_accuracies']),
                ("Final Test Acc:   ", last_values['test_accuracies']),
                ("Best Test Acc:    ", peak_test),
            )
            for label, values in accuracy_rows:
                print(f"  {label}{np.mean(values):.2f}% ± {np.std(values):.2f}%")

# Script entry point: render both figures, then print the console summary
if __name__ == "__main__":
    file_path = "results/training_histories.json"
    reference = 'HomM'

    print("Generating enhanced visualizations...")
    print("This will create multiple plots to highlight differences between optimizers.")

    # 1. Mean ± std training curves for every optimizer
    plot_training_curves(file_path)

    # 2. Difference analysis relative to a reference optimizer
    plot_difference_analysis(file_path, reference_optimizer=reference)

    # 3. Print detailed summary
    with open(file_path, 'r') as f:
        data = json.load(f)
    print_summary_stats(data)

    # List only what this script actually produces.
    print("\nVisualization complete! Check the outputs to see:")
    print("1. Mean ± std training curves (traing_process.pdf)")
    print(f"2. Relative differences to {reference} (performance_difference.pdf)")
    print("3. Statistical summaries in console")