from collections import Counter
from xml.parsers.expat import model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
import utilities.data_generation as dg
import json
from torch.optim.lr_scheduler import LambdaLR

def plot_training_results(train_losses, val_losses, val_accuracies, noise_stabilities, folder_name, val_perplexities=None, task_type='classification'):
    """
    Plot the curves of a single training run, save the figure and show it.

    The figure has one row of panels: loss, validation accuracy, optionally
    validation perplexity, and optionally noise stability. It is written to
    ``{folder_name}/training_results.png`` (the folder is created if needed)
    and then displayed with ``plt.show()``.

    Args:
        train_losses: Sequence of training-loss values, plotted against index.
        val_losses: Sequence of validation-loss values, plotted against index.
        val_accuracies: Sequence of validation-accuracy values.
        noise_stabilities: Mapping from a numeric ``r`` value to a sequence of
            noise-stability values. Any empty value (``None``, ``[]``, ``{}``)
            disables the noise-stability panel.
        folder_name: Directory that receives the saved figure.
        val_perplexities: Optional sequence of validation perplexities; only
            plotted when ``task_type`` is ``'language_modeling'`` and the
            sequence is non-empty.
        task_type: Task identifier; only ``'language_modeling'`` enables the
            perplexity panel.
    """
    os.makedirs(folder_name, exist_ok=True)

    # Decide once which optional panels exist so that the subplot count and
    # the drawing code below can never disagree.
    has_perplexity = (task_type == 'language_modeling' and val_perplexities is not None and len(val_perplexities) > 0)
    # Truthiness treats None, [] and {} alike; comparing against [] alone
    # would reserve an empty panel for {} and crash on None.
    has_noise = bool(noise_stabilities)
    num_plots = 2 + (1 if has_perplexity else 0) + (1 if has_noise else 0)

    # num_plots is always >= 2, so plt.subplots returns an array of axes.
    fig, axes = plt.subplots(1, num_plots, figsize=(5*num_plots, 5))

    plot_idx = 0

    # Plot main losses
    axes[plot_idx].plot(train_losses, label='Training Loss')
    axes[plot_idx].plot(val_losses, label='Validation Loss')
    axes[plot_idx].set_xlabel('Epochs')
    axes[plot_idx].set_ylabel('Loss')
    axes[plot_idx].set_title('Training and Validation Loss')
    axes[plot_idx].legend()
    plot_idx += 1

    # Plot accuracy
    axes[plot_idx].plot(val_accuracies)
    axes[plot_idx].set_xlabel('Epochs')
    axes[plot_idx].set_ylabel('Accuracy (%)')
    axes[plot_idx].set_title('Validation Accuracy')
    plot_idx += 1

    # Plot perplexity for language modeling
    if has_perplexity:
        axes[plot_idx].plot(val_perplexities)
        axes[plot_idx].set_xlabel('Epochs')
        axes[plot_idx].set_ylabel('Perplexity')
        axes[plot_idx].set_title('Validation Perplexity')
        # Log scale on the y-axis for the perplexity curve.
        axes[plot_idx].set_yscale('log')
        plot_idx += 1

    # Plot noise stability: one line per r value
    if has_noise:
        for r_val, stability in noise_stabilities.items():
            axes[plot_idx].plot(stability, label=f'Model r={r_val:.2f}', linestyle='-')

        axes[plot_idx].set_xlabel('Epochs')
        axes[plot_idx].set_ylabel('Noise Stability')
        axes[plot_idx].set_title('Model Noise Stability')
        axes[plot_idx].legend()

    plt.tight_layout()
    plt.savefig(f'{folder_name}/training_results.png')
    plt.show()

def plot_combined_results(results, folder_name, epoch_period=1, learn_function_stabilities=None):
    """
    Plot results across multiple seeds showing variance.

    Every curve is drawn as the mean over axis 0 of the corresponding array
    with a shaded band of +/- one standard deviation. Two files are written
    into ``folder_name`` (created if it does not exist):
    ``combined_training_results.png`` and ``test_accuracy_distribution.png``.

    Args:
        results: Dict with the keys ``'train_losses'``, ``'val_losses'``,
            ``'val_accuracies'`` and ``'test_accuracies'``. Optional keys:
            ``'task_type'`` (defaults to ``'classification'``),
            ``'val_perplexities'`` (used only for ``'language_modeling'``)
            and ``'noise_stabilities'`` (mapping from a numeric ``r`` value to
            a 2-D list of stability values; a missing or empty value
            disables that panel).
        folder_name: Directory that receives the saved figures.
        epoch_period: Step used to downsample the training losses and to
            build the shared epoch axis. NOTE(review): only the training loss
            is downsampled here; validation loss, accuracy and perplexity are
            plotted as given, so they presumably are already recorded once
            every ``epoch_period`` epochs -- confirm against the caller.
        learn_function_stabilities: Optional mapping from ``r`` value to a
            single stability value, drawn as a dashed horizontal reference
            line in the noise-stability panel.
    """
    # savefig does not create missing directories.
    os.makedirs(folder_name, exist_ok=True)

    # Determine task type and if we have perplexity
    task_type = results.get('task_type', 'classification')
    has_perplexity = (task_type == 'language_modeling' and
                      'val_perplexities' in results and
                      len(results['val_perplexities']) > 0)

    # Treat a missing 'noise_stabilities' key like an empty one instead of
    # raising KeyError, in line with the other optional keys.
    noise_stabilities = results.get('noise_stabilities')
    has_noise = bool(noise_stabilities)

    # Create figure with subplots (always >= 2 panels, so axes is an array)
    num_plots = 2 + (1 if has_perplexity else 0) + (1 if has_noise else 0)
    fig, axes = plt.subplots(1, num_plots, figsize=(5*num_plots, 5))

    # Shared x-axis: every epoch_period-th epoch, starting at 1
    n_epochs = len(results['train_losses'][0])
    epochs = np.arange(1, n_epochs + 1, step=epoch_period)

    plot_idx = 0

    # Plot train and validation losses
    train_losses = np.array(results['train_losses'])
    train_losses = train_losses[:, ::epoch_period]  # Downsample by epoch_period
    val_losses = np.array(results['val_losses'])

    train_mean = np.mean(train_losses, axis=0)
    train_std = np.std(train_losses, axis=0)
    val_mean = np.mean(val_losses, axis=0)
    val_std = np.std(val_losses, axis=0)

    axes[plot_idx].plot(epochs, train_mean, 'b-', label='Train Loss')
    axes[plot_idx].fill_between(epochs, train_mean - train_std, train_mean + train_std, alpha=0.3, color='b')
    axes[plot_idx].plot(epochs, val_mean, 'r-', label='Val Loss')
    axes[plot_idx].fill_between(epochs, val_mean - val_std, val_mean + val_std, alpha=0.3, color='r')
    axes[plot_idx].set_xlabel('Epochs')
    axes[plot_idx].set_ylabel('Loss')
    axes[plot_idx].set_title('Training and Validation Loss')
    axes[plot_idx].legend()
    plot_idx += 1

    # Plot validation accuracy
    val_accuracies = np.array(results['val_accuracies'])
    acc_mean = np.mean(val_accuracies, axis=0)
    acc_std = np.std(val_accuracies, axis=0)

    axes[plot_idx].plot(epochs, acc_mean, 'g-')
    axes[plot_idx].fill_between(epochs, acc_mean - acc_std, acc_mean + acc_std, alpha=0.3, color='g')
    axes[plot_idx].set_xlabel('Epochs')
    axes[plot_idx].set_ylabel('Accuracy (%)')
    axes[plot_idx].set_title('Validation Accuracy')
    plot_idx += 1

    # Plot perplexity for language modeling
    if has_perplexity:
        val_perplexities = np.array(results['val_perplexities'])
        perp_mean = np.mean(val_perplexities, axis=0)
        perp_std = np.std(val_perplexities, axis=0)

        axes[plot_idx].plot(epochs, perp_mean, 'm-')
        axes[plot_idx].fill_between(epochs, perp_mean - perp_std, perp_mean + perp_std, alpha=0.3, color='m')
        axes[plot_idx].set_xlabel('Epochs')
        axes[plot_idx].set_ylabel('Perplexity')
        axes[plot_idx].set_title('Validation Perplexity')
        axes[plot_idx].set_yscale('log')
        plot_idx += 1

    # Plot noise stability if available
    if has_noise:
        for r_val, stability_lists in noise_stabilities.items():
            # 2-D array; axis 0 is averaged over below, axis 1 is the x-axis.
            stability_array = np.array(stability_lists)

            # This panel uses its own x-axis, derived from the number of
            # recorded stability values rather than from `epochs`.
            stability_epochs = np.arange(1, stability_array.shape[1] + 1)

            # Find mean and std across seeds for each recorded point.
            stability_mean = np.mean(stability_array, axis=0)
            stability_std = np.std(stability_array, axis=0)

            # Plot model noise stability
            axes[plot_idx].plot(stability_epochs, stability_mean,
                                label=f'Model r={r_val:.2f}', linestyle='-')
            axes[plot_idx].fill_between(stability_epochs,
                                        stability_mean - stability_std,
                                        stability_mean + stability_std,
                                        alpha=0.3)

            # Add the learned function stability reference line if available
            if learn_function_stabilities is not None and r_val in learn_function_stabilities:
                learn_function_stability_list = [learn_function_stabilities[r_val]] * len(stability_epochs)
                axes[plot_idx].plot(stability_epochs,
                                    learn_function_stability_list,
                                    label=f'Learned r={r_val:.2f}',
                                    linestyle='--', alpha=0.7)

        axes[plot_idx].set_xlabel('Epochs')
        axes[plot_idx].set_ylabel('Noise Stability')
        axes[plot_idx].set_title('Model Noise Stability')
        axes[plot_idx].legend()

    plt.tight_layout()
    plt.savefig(f'{folder_name}/combined_training_results.png')

    # Add a final plot showing test accuracy distribution
    test_accuracies = np.array(results['test_accuracies'])
    plt.figure(figsize=(6, 4))
    plt.boxplot(test_accuracies)
    # Overlay the individual values at x=1, where the single box is drawn.
    plt.scatter(np.ones_like(test_accuracies), test_accuracies, alpha=0.6, color='red')
    plt.title(f'Test Accuracy Distribution\nMean: {np.mean(test_accuracies):.2f}%, Std: {np.std(test_accuracies):.2f}%')
    plt.ylabel('Test Accuracy (%)')
    plt.grid(axis='y', alpha=0.3)
    plt.savefig(f'{folder_name}/test_accuracy_distribution.png')