import os
from config import *
import sys
import pickle
import argparse
import matplotlib.pyplot as plt
import numpy as np

from config.config import get_experiment_config_mnist, get_experiment_config_fmnist, get_experiment_config_organamnist, get_experiment_config_kmnist
from data.data_loader import load_data
from algorithms.vr_sort import main as vr_sort_main

from algorithms.conftr import main as conftr_main
from evaluation import main as eval_main


# Turn off XLA's up-front GPU memory preallocation (see the JAX "GPU memory
# allocation" docs). NOTE(review): this runs after the project imports above;
# confirm none of them initialises the JAX/XLA backend at import time, otherwise
# the setting may not take effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

############# HELPER FUNCTIONS ###############
def plot_combined_losses(*loss_data, title, save_path, loss_type='Train'):
    """
    Plot several per-epoch loss curves on one figure, each surrounded by a
    shaded band of +/- one standard deviation.

    Parameters:
        *loss_data (tuple): One ``(losses, stds, label, color)`` tuple per curve,
            where ``losses`` and ``stds`` are per-epoch sequences of equal length.
        title (str): Title of the plot.
        save_path (str): Path to save the plot.
        loss_type (str): Type of loss, e.g. 'Train' or 'Test'; used in the
            y-axis label.
    """
    fig = plt.figure()
    try:
        for losses, std, label, color in loss_data:
            # Convert once so the band arithmetic below is element-wise.
            losses = np.asarray(losses)
            std = np.asarray(std)
            plt.plot(losses, label=label, color=color, linewidth=2)
            plt.fill_between(range(len(losses)),
                             losses - std,
                             losses + std,
                             color=color, alpha=0.3)

        plt.xlabel('Epoch')
        plt.ylabel(f'{loss_type} Loss')
        plt.title(title)
        plt.legend()
        plt.savefig(save_path)
    finally:
        # Release the figure even if plotting/saving fails, so figures don't pile up.
        plt.close(fig)
    print(f'Plot saved to {save_path}')

def plot_set_sizes(*size_data, title, save_path):
    """
    Plot per-epoch set sizes for several models on one figure, each surrounded
    by a shaded band of +/- one standard deviation.

    Parameters:
        *size_data (tuple): One ``(set_sizes, std, label, color)`` tuple per model,
            where ``set_sizes`` and ``std`` are per-epoch sequences of equal length.
        title (str): Title of the plot.
        save_path (str): Path to save the plot.
    """
    fig = plt.figure()
    try:
        for set_sizes, std, label, color in size_data:
            # Convert once so the band arithmetic below is element-wise.
            set_sizes = np.asarray(set_sizes)
            std = np.asarray(std)
            plt.plot(set_sizes, label=label, color=color, linewidth=2)
            plt.fill_between(range(len(set_sizes)),
                             set_sizes - std,
                             set_sizes + std,
                             color=color, alpha=0.3)

        plt.xlabel('Epoch')
        plt.ylabel('Set Size')
        plt.title(title)
        plt.legend()
        plt.savefig(save_path)
    finally:
        # Release the figure even if plotting/saving fails, so figures don't pile up.
        plt.close(fig)
    print(f'Plot saved to {save_path}')

def plot_accuracies(*accuracy_data, title, save_path):
    """
    Plot per-epoch accuracies for several models on one figure, each surrounded
    by a shaded band of +/- one standard deviation.

    Parameters:
        *accuracy_data (tuple): One ``(accuracies, std, label, color)`` tuple per
            model, where ``accuracies`` and ``std`` are per-epoch sequences of
            equal length.
        title (str): Title of the plot.
        save_path (str): Path to save the plot.
    """
    fig = plt.figure()
    try:
        for accuracies, std, label, color in accuracy_data:
            # Convert once so the band arithmetic below is element-wise.
            accuracies = np.asarray(accuracies)
            std = np.asarray(std)
            plt.plot(accuracies, label=label, color=color, linewidth=2)
            plt.fill_between(range(len(accuracies)),
                             accuracies - std,
                             accuracies + std,
                             color=color, alpha=0.3)

        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title(title)
        plt.legend()
        plt.savefig(save_path)
    finally:
        # Release the figure even if plotting/saving fails, so figures don't pile up.
        plt.close(fig)
    print(f'Plot saved to {save_path}')

def plot_avg_sizes(*model_data, title, save_path):
    """
    Draw a bar chart with one bar per model showing its average set size.

    Parameters:
        *model_data (tuple): ``(model_name, avg_size)`` pairs; bars appear in
            the order the pairs are given.
        title (str): Title of the chart.
        save_path (str): Where the figure is written.
    """
    names = [entry[0] for entry in model_data]
    heights = [entry[1] for entry in model_data]

    plt.figure()
    plt.bar(names, heights, color='skyblue')
    plt.title(title)
    plt.xlabel('Model Name')
    plt.ylabel('Average Size')
    # Slant the tick labels so long model names do not overlap.
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    plt.savefig(save_path)
    plt.close()
    print(f'Plot saved to {save_path}')

def average_results(results_list):
    """
    Average per-trial evaluation metrics key by key.

    Parameters:
        results_list (list[dict]): One metrics dict per trial. The keys of the
            first dict decide which metrics are averaged; every dict must
            contain those keys.

    Returns:
        dict: Maps each metric name to ``np.mean`` of its values across trials.
            An empty ``results_list`` yields an empty dict instead of raising
            ``IndexError``.
    """
    if not results_list:
        return {}
    return {
        key: np.mean([result[key] for result in results_list])
        for key in results_list[0]
    }

def save_model(params, model_name, results_dir):
    """
    Pickle ``params`` to ``os.path.join(results_dir, model_name)``.

    Any missing parent directories are created first. If ``model_name`` is an
    absolute path, ``os.path.join`` ignores ``results_dir``.

    Parameters:
        params: Any picklable object (callers pass the per-trial params/seeds
            returned by training).
        model_name (str): File name, optionally with sub-directories.
        results_dir (str): Base directory for the file; may be ``''`` (cwd).
    """
    model_save_path = os.path.join(results_dir, model_name)
    parent_dir = os.path.dirname(model_save_path)
    # A bare file name has no directory part, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(model_save_path, 'wb') as f:
        pickle.dump(params, f)

################ NUM_SORT (m) TUNING ################
def plot_tuning(results_dir, old_results_dir, num_sort_values=(4, 6, 8, 10, 16, 20)):
    """
    Re-plot a previously saved ``num_sort`` (m) sweep without re-training.

    Loads ``experiment_results_sort.pkl`` (the file ``run_tuning`` writes) from
    ``old_results_dir``. Writes the combined train/test loss, set-size and
    test-accuracy plots into ``results_dir``.

    Parameters:
        results_dir (str): Directory for the output plots; created if missing.
        old_results_dir (str): Directory containing ``experiment_results_sort.pkl``.
        num_sort_values (iterable of int): Values of m to plot; each must be a key
            of the pickled results. The default is the previously hard-coded grid,
            which leaves out m=2. Values without an entry in the colour map below
            are drawn with matplotlib's default colours.
    """
    experiment_results_path = os.path.join(old_results_dir, 'experiment_results_sort.pkl')
    # NOTE: unpickling can execute arbitrary code; only load result files
    # produced by run_tuning, never untrusted input.
    with open(experiment_results_path, 'rb') as f:
        sort_results = pickle.load(f)
    os.makedirs(results_dir, exist_ok=True)

    # Materialise once: the grid is iterated by several comprehensions below.
    num_sort_values = list(num_sort_values)

    # Fixed colour per m so plots from different runs are visually comparable.
    color_mapping = {
        2: 'black',
        4: 'red',
        6: 'blue',
        8: 'green',
        10: 'orange',
        16: 'purple',
        20: 'yellow',
    }

    # Plot training and test losses for the different m values.
    for loss_type in ['Train', 'Test']:
        key = loss_type.lower()
        plot_combined_losses(
            *[
                (sort_results[num_sort][f'{key}_losses'],
                 sort_results[num_sort][f'{key}_std'],
                 f'VR-ConfTr {loss_type} loss (m={num_sort})',
                 color_mapping.get(num_sort))
                for num_sort in num_sort_values
            ],
            title=f'{loss_type} Loss per Epoch for Different m',
            save_path=os.path.join(results_dir, f'combined_{key}_losses_sort.png'),
            loss_type=loss_type,
        )

    # Plot set sizes for the different m values.
    plot_set_sizes(
        *[
            (sort_results[num_sort]['set_sizes'],
             sort_results[num_sort]['set_std'],
             f'VR-ConfTr Model Set Sizes (m={num_sort})',
             color_mapping.get(num_sort))
            for num_sort in num_sort_values
        ],
        title='Set Sizes per Epoch for Different m',
        save_path=os.path.join(results_dir, 'set_sizes_per_epoch_sort.png'),
    )

    # Plot per-epoch test accuracy for the different m values.
    plot_accuracies(
        *[
            (sort_results[num_sort]['test_accuracies'],
             sort_results[num_sort]['test_accuracies_std'],
             f'VR-ConfTr Test Accuracy (m={num_sort})',
             color_mapping.get(num_sort))
            for num_sort in num_sort_values
        ],
        title='Test Accuracy per Epoch for Different m',
        save_path=os.path.join(results_dir, 'combined_test_accuracies_sort.png'),
    )
######## 


def run_tuning(config_func, results_dir, num_trials=2,
               num_sort_values=(2, 4, 6, 8, 10, 16, 20)):
    """
    Grid-search the VR-ConfTr sorting parameter m (``num_sort``) and plot the sweep.

    For every m in ``num_sort_values`` this trains ``num_trials`` models with
    ``vr_sort_main(config.vr, num_trials, num_sort=m)``. It pickles that grid
    point's params to ``<results_dir>/<sort_model_path stem>_m<m><ext>`` and
    collects the per-epoch curves. All curves are pickled to
    ``<results_dir>/experiment_results_sort.pkl``, and then the combined
    loss / set-size / accuracy plots are written to ``results_dir``.

    Parameters:
        config_func: Either a zero-argument callable that returns the experiment
            config (e.g. ``get_experiment_config_kmnist``), or an already-built
            config object. ``None`` falls back to the KMNIST config, which used
            to be hard-coded regardless of this argument.
        results_dir (str): Output directory; created if missing.
        num_trials (int): Training trials per grid point (default 2).
        num_sort_values (iterable of int): Grid of m values. Values without an
            entry in the colour map below are drawn with matplotlib's default colours.
    """
    if config_func is None:
        config = get_experiment_config_kmnist()
    elif callable(config_func):
        config = config_func()
    else:
        # Already a built config (e.g. a ConfigDict). Calling it is what raised
        # "TypeError: 'ConfigDict' object is not callable".
        config = config_func

    os.makedirs(results_dir, exist_ok=True)

    # Materialise once: the grid is iterated by the training loop and several plots.
    num_sort_values = list(num_sort_values)

    # Fixed colour per m so plots from different runs are visually comparable.
    color_mapping = {
        2: 'black',
        4: 'red',
        6: 'blue',
        8: 'green',
        10: 'orange',
        16: 'purple',
        20: 'yellow',
    }

    # Per-m results keyed by num_sort. This is the dict pickled below and read
    # back by plot_tuning.
    sort_results = {}

    model_root, model_ext = os.path.splitext(config.sort_model_path)

    for num_sort in num_sort_values:
        print(f"Running experiment with num_splits={num_sort}")

        (params_and_seeds,
         (train_losses, train_std),
         (test_losses, test_std),
         (test_accuracies, test_accuracies_std),
         loss_variances,
         (set_sizes, set_std)) = vr_sort_main(config.vr, num_trials, num_sort=num_sort)

        # One file per grid point. Saving every m to config.sort_model_path
        # overwrote the models of all earlier grid points.
        save_model(params_and_seeds, f'{model_root}_m{num_sort}{model_ext}', results_dir)

        sort_results[num_sort] = {
            'train_losses': train_losses,
            'train_std': train_std,
            'test_losses': test_losses,
            'test_std': test_std,
            'test_accuracies': test_accuracies,
            'test_accuracies_std': test_accuracies_std,
            'loss_variances': loss_variances,
            'set_sizes': set_sizes,
            'set_std': set_std,
        }

    # Persist the sweep before plotting so a plotting error cannot discard it.
    experiment_results_path = os.path.join(results_dir, 'experiment_results_sort.pkl')
    with open(experiment_results_path, 'wb') as f:
        pickle.dump(sort_results, f)
    print(f"Experiment results saved at {experiment_results_path}")

    # Plot training and test losses for the different m values.
    for loss_type in ['Train', 'Test']:
        key = loss_type.lower()
        plot_combined_losses(
            *[
                (sort_results[num_sort][f'{key}_losses'],
                 sort_results[num_sort][f'{key}_std'],
                 f'VR-CT {loss_type} loss (m={num_sort})',
                 color_mapping.get(num_sort))
                for num_sort in num_sort_values
            ],
            title=f'{loss_type} Loss per Epoch for Different num_sort',
            save_path=os.path.join(results_dir, f'combined_{key}_losses_sort.png'),
            loss_type=loss_type,
        )

    # Plot set sizes for the different m values.
    plot_set_sizes(
        *[
            (sort_results[num_sort]['set_sizes'],
             sort_results[num_sort]['set_std'],
             f'VR-CT Model Set Sizes (m={num_sort})',
             color_mapping.get(num_sort))
            for num_sort in num_sort_values
        ],
        title='Set Sizes per Epoch for Different num_sort',
        save_path=os.path.join(results_dir, 'set_sizes_per_epoch_sort.png'),
    )

    # Plot per-epoch test accuracy for the different m values.
    plot_accuracies(
        *[
            (sort_results[num_sort]['test_accuracies'],
             sort_results[num_sort]['test_accuracies_std'],
             f'VR-CT Test Accuracy (m={num_sort})',
             color_mapping.get(num_sort))
            for num_sort in num_sort_values
        ],
        title='Test Accuracy per Epoch for Different num_sort',
        save_path=os.path.join(results_dir, 'combined_test_accuracies_sort.png'),
    )

def run_experiment_mnist():
    """
    Train, evaluate and plot VR-ConfTr (sorting variant, m=4) against ConfTr on MNIST.

    Both models are trained ``config.num_trials`` times from ``config.vr``, and
    each model's per-trial params are pickled via ``save_model``. Every trial is
    then evaluated with ``eval_main``, and the per-trial metrics are averaged.
    All curves, per-trial evaluations and averages are pickled to
    ``<config.results_dir>/experiment_results.pkl`` before the comparison plots
    are written to the same directory.
    """
    config = get_experiment_config_mnist()
    num_trials = config.num_trials
    results_dir = config.results_dir
    print("results", results_dir)
    os.makedirs(results_dir, exist_ok=True)
    experiment_results = {}

    print("Sorting Version")
    (sort_trial_params_and_seeds,
     (sort_train_losses, sort_std_train_losses),
     (sort_test_losses, sort_std_test_losses),
     (sort_test_accuracies, sort_std_test_accuracies),
     sort_loss_variances,
     (sort_set_sizes, sort_std_set_sizes)) = vr_sort_main(config.vr, num_trials, num_sort=4)
    save_model(sort_trial_params_and_seeds, config.sort_model_path, results_dir)
    experiment_results['sort'] = {
        'model_params': sort_trial_params_and_seeds,
        'train_losses': sort_train_losses,
        'train_std': sort_std_train_losses,
        'test_losses': sort_test_losses,
        'test_std': sort_std_test_losses,
        'test_accuracies': sort_test_accuracies,
        'test_accuracies_std': sort_std_test_accuracies,
        'loss_variances': sort_loss_variances,
        'set_sizes': sort_set_sizes,
        'set_std': sort_std_set_sizes,
    }

    print("Conftr Version")
    (conftr_trial_params_and_seeds,
     (conftr_train_losses, conftr_std_train_losses),
     (conftr_test_losses, conftr_std_test_losses),
     (conftr_test_accuracies, conftr_std_test_accuracies),
     conftr_loss_variances,
     (conftr_set_sizes, conftr_std_set_sizes)) = conftr_main(config.vr, num_trials)
    save_model(conftr_trial_params_and_seeds, config.conftr_model_path, results_dir)
    experiment_results['conftr'] = {
        'model_params': conftr_trial_params_and_seeds,
        'train_losses': conftr_train_losses,
        'train_std': conftr_std_train_losses,
        'test_losses': conftr_test_losses,
        'test_std': conftr_std_test_losses,
        'test_accuracies': conftr_test_accuracies,
        'test_accuracies_std': conftr_std_test_accuracies,
        'loss_variances': conftr_loss_variances,
        'set_sizes': conftr_set_sizes,
        'set_std': conftr_std_set_sizes,
    }

    # Evaluate each trial's models on their corresponding test loader.
    all_results = []
    sort_results = []
    conftr_results = []
    for trial in range(num_trials):
        results = eval_main(models_info=[
            ("sort_model", config.vr, sort_trial_params_and_seeds[trial]),
            ("conftr_model", config.vr, conftr_trial_params_and_seeds[trial]),
        ])
        all_results.append(results)
        sort_results.append(results["sort_model"])
        conftr_results.append(results["conftr_model"])

    avg_sort_results = average_results(sort_results)
    avg_conftr_results = average_results(conftr_results)

    # Add evaluation results to the experiment results dictionary.
    experiment_results['evaluation'] = {
        'all_results': all_results,
        'conftr_results': conftr_results,
        'sort_results': sort_results,
    }
    experiment_results['avg_results'] = {
        'conftr': avg_conftr_results,
        'sort': avg_sort_results,
    }

    # Persist everything before plotting, so a failure in any plotting call
    # below cannot discard the training and evaluation results.
    experiment_results_path = os.path.join(results_dir, 'experiment_results.pkl')
    with open(experiment_results_path, 'wb') as f:
        pickle.dump(experiment_results, f)

    # Data for the bar chart of average set sizes.
    avg_sizes = {
        "Conftr Model": avg_conftr_results['avg_size'],
        "Sort Model": avg_sort_results['avg_size'],
    }

    # Plot train losses.
    plot_combined_losses(
        (sort_train_losses, sort_std_train_losses, 'VR-ConfTr Train loss', 'red'),
        (conftr_train_losses, conftr_std_train_losses, 'ConfTr Train loss', 'blue'),
        title='Training Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_losses.png'),
        loss_type='Train',
    )

    # Plot test loss per epoch.
    plot_combined_losses(
        (sort_test_losses, sort_std_test_losses, 'VR-ConfTr Test loss', 'red'),
        (conftr_test_losses, conftr_std_test_losses, 'ConfTr Test loss', 'blue'),
        title='Test Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_test_losses.png'),
        loss_type='Test',
    )

    # Plot set sizes per epoch.
    plot_set_sizes(
        (sort_set_sizes, sort_std_set_sizes, 'VR-ConfTr Model Set Sizes', 'red'),
        (conftr_set_sizes, conftr_std_set_sizes, 'ConfTr Model Set Sizes', 'blue'),
        title='Set Sizes per Epoch',
        save_path=os.path.join(results_dir, 'set_sizes_per_epoch.png'),
    )

    # Plot test accuracies per epoch.
    # NOTE(review): the file name says "train" but the curves are test
    # accuracies; the name is left unchanged so existing consumers keep working.
    plot_accuracies(
        (sort_test_accuracies, sort_std_test_accuracies, 'VR-ConfTr Test Accuracy', 'red'),
        (conftr_test_accuracies, conftr_std_test_accuracies, 'ConfTr Test Accuracy', 'blue'),
        title='Test Accuracy per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_accuracies.png'),
    )

    plot_avg_sizes(
        *avg_sizes.items(),
        title='Average Sizes of Models',
        save_path=os.path.join(results_dir, 'avg_sizes_plot.png'),
    )

def run_experiment_fmnist():
    """
    Train, evaluate and plot VR-ConfTr (sorting variant, m=4) on Fashion-MNIST
    against two ConfTr baselines: "ConfTr 100" (trained from ``config.ct100``)
    and "ConfTr 500" (trained from ``config.vr``).

    Each model is trained ``config.num_trials`` times, and its per-trial params
    are pickled via ``save_model``. Every trial is evaluated with ``eval_main``,
    and the per-trial metrics are averaged. All results are pickled to
    ``<config.results_dir>/experiment_results.pkl`` before the comparison plots
    are written to the same directory.
    """
    config = get_experiment_config_fmnist()
    num_trials = config.num_trials
    print("num_trials = ", num_trials)
    results_dir = config.results_dir
    print("results", results_dir)
    os.makedirs(results_dir, exist_ok=True)
    experiment_results = {}

    print("Sorting Version")
    (sort_trial_params_and_seeds,
     (sort_train_losses, sort_std_train_losses),
     (sort_test_losses, sort_std_test_losses),
     (sort_test_accuracies, sort_std_test_accuracies),
     sort_loss_variances,
     (sort_set_sizes, sort_std_set_sizes)) = vr_sort_main(config.vr, num_trials, num_sort=4)
    save_model(sort_trial_params_and_seeds, config.sort_model_path, results_dir)
    experiment_results['sort'] = {
        'model_params': sort_trial_params_and_seeds,
        'train_losses': sort_train_losses,
        'train_std': sort_std_train_losses,
        'test_losses': sort_test_losses,
        'test_std': sort_std_test_losses,
        'test_accuracies': sort_test_accuracies,
        'test_accuracies_std': sort_std_test_accuracies,
        'loss_variances': sort_loss_variances,
        'set_sizes': sort_set_sizes,
        'set_std': sort_std_set_sizes,
    }

    # ConfTr 100 and ConfTr 500 used to be saved to the same
    # config.conftr_model_path, so the 500 run silently overwrote the 100 run.
    # ConfTr 500 keeps that path (it is what the file ended up holding before),
    # and ConfTr 100 gets its own "<stem>_100<ext>" file.
    conftr_root, conftr_ext = os.path.splitext(config.conftr_model_path)
    conftr100_model_path = f'{conftr_root}_100{conftr_ext}'

    print("Conftr 100 Version")
    (conftr100_trial_params_and_seeds,
     (conftr100_train_losses, conftr100_std_train_losses),
     (conftr100_test_losses, conftr100_std_test_losses),
     (conftr100_test_accuracies, conftr100_std_test_accuracies),
     conftr100_loss_variances,
     (conftr100_set_sizes, conftr100_std_set_sizes)) = conftr_main(config.ct100, num_trials)
    save_model(conftr100_trial_params_and_seeds, conftr100_model_path, results_dir)
    experiment_results['conftr100'] = {
        'model_params': conftr100_trial_params_and_seeds,
        'train_losses': conftr100_train_losses,
        'train_std': conftr100_std_train_losses,
        'test_losses': conftr100_test_losses,
        'test_std': conftr100_std_test_losses,
        'test_accuracies': conftr100_test_accuracies,
        'test_accuracies_std': conftr100_std_test_accuracies,
        'loss_variances': conftr100_loss_variances,
        'set_sizes': conftr100_set_sizes,
        'set_std': conftr100_std_set_sizes,
    }

    print("Conftr 500 Version")
    (conftr500_trial_params_and_seeds,
     (conftr500_train_losses, conftr500_std_train_losses),
     (conftr500_test_losses, conftr500_std_test_losses),
     (conftr500_test_accuracies, conftr500_std_test_accuracies),
     conftr500_loss_variances,
     (conftr500_set_sizes, conftr500_std_set_sizes)) = conftr_main(config.vr, num_trials)
    save_model(conftr500_trial_params_and_seeds, config.conftr_model_path, results_dir)
    experiment_results['conftr500'] = {
        'model_params': conftr500_trial_params_and_seeds,
        'train_losses': conftr500_train_losses,
        'train_std': conftr500_std_train_losses,
        'test_losses': conftr500_test_losses,
        'test_std': conftr500_std_test_losses,
        'test_accuracies': conftr500_test_accuracies,
        'test_accuracies_std': conftr500_std_test_accuracies,
        'loss_variances': conftr500_loss_variances,
        'set_sizes': conftr500_set_sizes,
        'set_std': conftr500_std_set_sizes,
    }

    # Evaluate each trial's models on their corresponding test loader.
    all_results = []
    sort_results = []
    conftr100_results = []
    conftr500_results = []
    for trial in range(num_trials):
        results = eval_main(models_info=[
            ("sort_model", config.vr, sort_trial_params_and_seeds[trial]),
            ("conftr500_model", config.vr, conftr500_trial_params_and_seeds[trial]),
            ("conftr100_model", config.ct100, conftr100_trial_params_and_seeds[trial]),
        ])
        all_results.append(results)
        sort_results.append(results["sort_model"])
        conftr100_results.append(results["conftr100_model"])
        conftr500_results.append(results["conftr500_model"])

    avg_conftr100_results = average_results(conftr100_results)
    avg_sort_results = average_results(sort_results)
    avg_conftr500_results = average_results(conftr500_results)

    # Add evaluation results to the experiment results dictionary.
    experiment_results['evaluation'] = {
        'all_results': all_results,
        'conftr100_results': conftr100_results,
        'conftr500_results': conftr500_results,
        'sort_results': sort_results,
    }
    experiment_results['avg_results'] = {
        'conftr100': avg_conftr100_results,
        'conftr500': avg_conftr500_results,
        'sort': avg_sort_results,
    }

    # Persist everything before plotting, so a failure in any plotting call
    # below cannot discard the training and evaluation results.
    experiment_results_path = os.path.join(results_dir, 'experiment_results.pkl')
    with open(experiment_results_path, 'wb') as f:
        pickle.dump(experiment_results, f)

    # Data for the bar chart of average set sizes.
    avg_sizes = {
        "Conftr 100 Model": avg_conftr100_results['avg_size'],
        "Conftr 500 Model": avg_conftr500_results['avg_size'],
        "Sort Model": avg_sort_results['avg_size'],
    }

    # Plot train losses.
    plot_combined_losses(
        (sort_train_losses, sort_std_train_losses, 'VR-ConfTr Train loss', 'red'),
        (conftr500_train_losses, conftr500_std_train_losses, 'ConfTr 500 Train loss', 'blue'),
        (conftr100_train_losses, conftr100_std_train_losses, 'ConfTr 100 Train loss', 'purple'),
        title='Training Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_losses.png'),
        loss_type='Train',
    )

    # Plot test loss per epoch (the ConfTr 100 curve was mislabelled "Train loss").
    plot_combined_losses(
        (sort_test_losses, sort_std_test_losses, 'VR-ConfTr Test loss', 'red'),
        (conftr500_test_losses, conftr500_std_test_losses, 'ConfTr 500 Test loss', 'blue'),
        (conftr100_test_losses, conftr100_std_test_losses, 'ConfTr 100 Test loss', 'purple'),
        title='Test Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_test_losses.png'),
        loss_type='Test',
    )

    # Plot set sizes per epoch.
    plot_set_sizes(
        (sort_set_sizes, sort_std_set_sizes, 'VR-ConfTr Model Set Sizes', 'red'),
        (conftr100_set_sizes, conftr100_std_set_sizes, 'ConfTr 100 Model Set Sizes', 'purple'),
        (conftr500_set_sizes, conftr500_std_set_sizes, 'ConfTr 500 Model Set Sizes', 'blue'),
        title='Set Sizes per Epoch',
        save_path=os.path.join(results_dir, 'set_sizes_per_epoch.png'),
    )

    # Plot test accuracies per epoch.
    # NOTE(review): the file name says "train" but the curves are test
    # accuracies; the name is left unchanged so existing consumers keep working.
    plot_accuracies(
        (sort_test_accuracies, sort_std_test_accuracies, 'VR-ConfTr Test Accuracy', 'red'),
        (conftr500_test_accuracies, conftr500_std_test_accuracies, 'Conftr 500 Test Accuracy', 'blue'),
        (conftr100_test_accuracies, conftr100_std_test_accuracies, 'Conftr 100 Test Accuracy', 'purple'),
        title='Test Accuracy per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_accuracies.png'),
    )

    plot_avg_sizes(
        *avg_sizes.items(),
        title='Average Sizes of Models',
        save_path=os.path.join(results_dir, 'avg_sizes_plot.png'),
    )


###################  OrganA-MNIST ######################

def run_experiment_organamnist():
    """
    Train, evaluate and plot VR-ConfTr (sorting variant, m=5) against ConfTr on OrganA-MNIST.

    Both models are trained ``config.num_trials`` times from ``config.vr``, and
    each model's per-trial params are pickled via ``save_model``. Every trial is
    then evaluated with ``eval_main``, and the per-trial metrics are averaged.
    All curves, per-trial evaluations and averages are pickled to
    ``<config.results_dir>/experiment_results.pkl`` before the comparison plots
    are written to the same directory.
    """
    config = get_experiment_config_organamnist()
    num_trials = config.num_trials
    results_dir = config.results_dir
    print("results", results_dir)
    os.makedirs(results_dir, exist_ok=True)
    experiment_results = {}

    print("Sorting Version")
    (sort_trial_params_and_seeds,
     (sort_train_losses, sort_std_train_losses),
     (sort_test_losses, sort_std_test_losses),
     (sort_test_accuracies, sort_std_test_accuracies),
     sort_loss_variances,
     (sort_set_sizes, sort_std_set_sizes)) = vr_sort_main(config.vr, num_trials, num_sort=5)
    save_model(sort_trial_params_and_seeds, config.sort_model_path, results_dir)
    experiment_results['sort'] = {
        'model_params': sort_trial_params_and_seeds,
        'train_losses': sort_train_losses,
        'train_std': sort_std_train_losses,
        'test_losses': sort_test_losses,
        'test_std': sort_std_test_losses,
        'test_accuracies': sort_test_accuracies,
        'test_accuracies_std': sort_std_test_accuracies,
        'loss_variances': sort_loss_variances,
        'set_sizes': sort_set_sizes,
        'set_std': sort_std_set_sizes,
    }

    print("Conftr Version")
    (conftr_trial_params_and_seeds,
     (conftr_train_losses, conftr_std_train_losses),
     (conftr_test_losses, conftr_std_test_losses),
     (conftr_test_accuracies, conftr_std_test_accuracies),
     conftr_loss_variances,
     (conftr_set_sizes, conftr_std_set_sizes)) = conftr_main(config.vr, num_trials)
    save_model(conftr_trial_params_and_seeds, config.conftr_model_path, results_dir)
    experiment_results['conftr'] = {
        'model_params': conftr_trial_params_and_seeds,
        'train_losses': conftr_train_losses,
        'train_std': conftr_std_train_losses,
        'test_losses': conftr_test_losses,
        'test_std': conftr_std_test_losses,
        'test_accuracies': conftr_test_accuracies,
        'test_accuracies_std': conftr_std_test_accuracies,
        'loss_variances': conftr_loss_variances,
        'set_sizes': conftr_set_sizes,
        'set_std': conftr_std_set_sizes,
    }

    # Evaluate each trial's models on their corresponding test loader.
    all_results = []
    sort_results = []
    conftr_results = []
    for trial in range(num_trials):
        results = eval_main(models_info=[
            ("sort_model", config.vr, sort_trial_params_and_seeds[trial]),
            ("conftr_model", config.vr, conftr_trial_params_and_seeds[trial]),
        ])
        all_results.append(results)
        sort_results.append(results["sort_model"])
        conftr_results.append(results["conftr_model"])

    avg_sort_results = average_results(sort_results)
    avg_conftr_results = average_results(conftr_results)

    # Add evaluation results to the experiment results dictionary.
    experiment_results['evaluation'] = {
        'all_results': all_results,
        'conftr_results': conftr_results,
        'sort_results': sort_results,
    }
    experiment_results['avg_results'] = {
        'conftr': avg_conftr_results,
        'sort': avg_sort_results,
    }

    # Persist everything before plotting, so a failure in any plotting call
    # below cannot discard the training and evaluation results.
    experiment_results_path = os.path.join(results_dir, 'experiment_results.pkl')
    with open(experiment_results_path, 'wb') as f:
        pickle.dump(experiment_results, f)

    # Data for the bar chart of average set sizes.
    avg_sizes = {
        "Conftr Model": avg_conftr_results['avg_size'],
        "Sort Model": avg_sort_results['avg_size'],
    }

    # Plot train losses.
    plot_combined_losses(
        (sort_train_losses, sort_std_train_losses, 'VR-ConfTr Train loss', 'red'),
        (conftr_train_losses, conftr_std_train_losses, 'ConfTr Train loss', 'blue'),
        title='Training Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_losses.png'),
        loss_type='Train',
    )

    # Plot test loss per epoch.
    plot_combined_losses(
        (sort_test_losses, sort_std_test_losses, 'VR-ConfTr Test loss', 'red'),
        (conftr_test_losses, conftr_std_test_losses, 'ConfTr Test loss', 'blue'),
        title='Test Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_test_losses.png'),
        loss_type='Test',
    )

    # Plot set sizes per epoch.
    plot_set_sizes(
        (sort_set_sizes, sort_std_set_sizes, 'VR-ConfTr Model Set Sizes', 'red'),
        (conftr_set_sizes, conftr_std_set_sizes, 'ConfTr Model Set Sizes', 'blue'),
        title='Set Sizes per Epoch',
        save_path=os.path.join(results_dir, 'set_sizes_per_epoch.png'),
    )

    # Plot test accuracies per epoch.
    # NOTE(review): the file name says "train" but the curves are test
    # accuracies; the name is left unchanged so existing consumers keep working.
    plot_accuracies(
        (sort_test_accuracies, sort_std_test_accuracies, 'VR-ConfTr Test Accuracy', 'red'),
        (conftr_test_accuracies, conftr_std_test_accuracies, 'ConfTr Test Accuracy', 'blue'),
        title='Test Accuracy per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_accuracies.png'),
    )

    plot_avg_sizes(
        *avg_sizes.items(),
        title='Average Sizes of Models',
        save_path=os.path.join(results_dir, 'avg_sizes_plot.png'),
    )



##### KMNIST EXPERIMENT ############
def run_experiment_kmnist():
    """
    Train, evaluate and plot VR-ConfTr (sorting variant, m=4) against ConfTr on KMNIST.

    Both models are trained ``config.num_trials`` times from ``config.vr``, and
    each model's per-trial params are pickled via ``save_model``. Every trial is
    then evaluated with ``eval_main``, and the per-trial metrics are averaged.
    All curves, per-trial evaluations and averages are pickled to
    ``<config.results_dir>/experiment_results.pkl`` before the comparison plots
    are written to the same directory.
    """
    config = get_experiment_config_kmnist()
    num_trials = config.num_trials
    results_dir = config.results_dir
    print("results", results_dir)
    os.makedirs(results_dir, exist_ok=True)
    experiment_results = {}

    print("Sorting Version")
    (sort_trial_params_and_seeds,
     (sort_train_losses, sort_std_train_losses),
     (sort_test_losses, sort_std_test_losses),
     (sort_test_accuracies, sort_std_test_accuracies),
     sort_loss_variances,
     (sort_set_sizes, sort_std_set_sizes)) = vr_sort_main(config.vr, num_trials, num_sort=4)
    save_model(sort_trial_params_and_seeds, config.sort_model_path, results_dir)
    experiment_results['sort'] = {
        'model_params': sort_trial_params_and_seeds,
        'train_losses': sort_train_losses,
        'train_std': sort_std_train_losses,
        'test_losses': sort_test_losses,
        'test_std': sort_std_test_losses,
        'test_accuracies': sort_test_accuracies,
        'test_accuracies_std': sort_std_test_accuracies,
        'loss_variances': sort_loss_variances,
        'set_sizes': sort_set_sizes,
        'set_std': sort_std_set_sizes,
    }

    print("Conftr Version")
    # ConfTr model training.
    (conftr_trial_params_and_seeds,
     (conftr_train_losses, conftr_std_train_losses),
     (conftr_test_losses, conftr_std_test_losses),
     (conftr_test_accuracies, conftr_std_test_accuracies),
     conftr_loss_variances,
     (conftr_set_sizes, conftr_std_set_sizes)) = conftr_main(config.vr, num_trials)
    save_model(conftr_trial_params_and_seeds, config.conftr_model_path, results_dir)
    experiment_results['conftr'] = {
        'model_params': conftr_trial_params_and_seeds,
        'train_losses': conftr_train_losses,
        'train_std': conftr_std_train_losses,
        'test_losses': conftr_test_losses,
        'test_std': conftr_std_test_losses,
        'test_accuracies': conftr_test_accuracies,
        'test_accuracies_std': conftr_std_test_accuracies,
        'loss_variances': conftr_loss_variances,
        'set_sizes': conftr_set_sizes,
        'set_std': conftr_std_set_sizes,
    }

    # Evaluate each trial's models on their corresponding test loader.
    all_results = []
    sort_results = []
    conftr_results = []
    for trial in range(num_trials):
        results = eval_main(models_info=[
            ("sort_model", config.vr, sort_trial_params_and_seeds[trial]),
            ("conftr_model", config.vr, conftr_trial_params_and_seeds[trial]),
        ])
        all_results.append(results)
        sort_results.append(results["sort_model"])
        conftr_results.append(results["conftr_model"])

    avg_sort_results = average_results(sort_results)
    avg_conftr_results = average_results(conftr_results)

    # Add evaluation results to the experiment results dictionary.
    experiment_results['evaluation'] = {
        'all_results': all_results,
        'conftr_results': conftr_results,
        'sort_results': sort_results,
    }
    experiment_results['avg_results'] = {
        'conftr': avg_conftr_results,
        'sort': avg_sort_results,
    }

    # Persist everything before plotting, so a failure in any plotting call
    # below cannot discard the training and evaluation results.
    experiment_results_path = os.path.join(results_dir, 'experiment_results.pkl')
    with open(experiment_results_path, 'wb') as f:
        pickle.dump(experiment_results, f)

    # Data for the bar chart of average set sizes.
    avg_sizes = {
        "Conftr Model": avg_conftr_results['avg_size'],
        "Sort Model": avg_sort_results['avg_size'],
    }

    # Plot train losses.
    plot_combined_losses(
        (sort_train_losses, sort_std_train_losses, 'VR-ConfTr Train loss', 'red'),
        (conftr_train_losses, conftr_std_train_losses, 'ConfTr Train loss', 'blue'),
        title='Training Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_losses.png'),
        loss_type='Train',
    )

    # Plot test loss per epoch.
    plot_combined_losses(
        (sort_test_losses, sort_std_test_losses, 'VR-ConfTr Test loss', 'red'),
        (conftr_test_losses, conftr_std_test_losses, 'ConfTr Test loss', 'blue'),
        title='Test Loss per Epoch',
        save_path=os.path.join(results_dir, 'combined_test_losses.png'),
        loss_type='Test',
    )

    # Plot set sizes per epoch.
    plot_set_sizes(
        (sort_set_sizes, sort_std_set_sizes, 'VR-ConfTr Model Set Sizes', 'red'),
        (conftr_set_sizes, conftr_std_set_sizes, 'ConfTr Model Set Sizes', 'blue'),
        title='Set Sizes per Epoch',
        save_path=os.path.join(results_dir, 'set_sizes_per_epoch.png'),
    )

    # Plot test accuracies per epoch.
    # NOTE(review): the file name says "train" but the curves are test
    # accuracies; the name is left unchanged so existing consumers keep working.
    plot_accuracies(
        (sort_test_accuracies, sort_std_test_accuracies, 'VR-ConfTr Test Accuracy', 'red'),
        (conftr_test_accuracies, conftr_std_test_accuracies, 'ConfTr Test Accuracy', 'blue'),
        title='Test Accuracy per Epoch',
        save_path=os.path.join(results_dir, 'combined_train_accuracies.png'),
    )

    plot_avg_sizes(
        *avg_sizes.items(),
        title='Average Sizes of Models',
        save_path=os.path.join(results_dir, 'avg_sizes_plot.png'),
    )

def main():
    """
    Command-line entry point: run the experiment for the dataset named in ``sys.argv[1]``.

    The name is lower-cased and must be one of mnist, kmnist, fmnist or
    organamnist. A missing or unknown name prints the valid options and returns
    without running anything.
    """
    valid_datasets = ["mnist", "kmnist", "fmnist", "organamnist"]
    options_text = ', '.join(valid_datasets)

    cli_args = sys.argv[1:]
    if not cli_args:
        print(f"Please provide a dataset name. Valid options: {options_text}")
        return

    dataset_name = cli_args[0].lower()
    if dataset_name not in valid_datasets:
        print(f"Invalid dataset name '{dataset_name}'. Valid options: {options_text}")
        return

    # Built only after validation, so no runner is looked up for bad input.
    experiment_runners = {
        "mnist": run_experiment_mnist,
        "kmnist": run_experiment_kmnist,
        "fmnist": run_experiment_fmnist,
        "organamnist": run_experiment_organamnist,
    }
    experiment_runners[dataset_name]()

# Script entry point: run the experiment for the dataset named on the command line.
if __name__ == "__main__":
    main()
