"""
Functions for processing the experiment results.
"""
import os
import pickle
import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt

# Color palette for plot curves (presumably passed as the ``colors`` argument
# of the plotting functions in this module). Every entry is distinct so that
# no two curves end up sharing a color.
colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'orange', 'lightblue', 'lightgreen',
          'lightyellow', 'gray', 'purple', 'pink', 'brown', 'darkgreen', 'navy', 'salmon',
          'gold', 'darkred']

def plot_mean_var(datasets, color, label, ax, timestamp=None):
    """Plot the mean curve and a +/- one-std band for one group of runs.

    Each run is resampled by linear interpolation onto a shared integer time
    grid. The per-point mean and standard deviation across runs are then
    drawn on ``ax``.

    Args:
        datasets: Sequence of runs; each run is indexable so that ``run[0]``
            holds its time stamps and ``run[1]`` its accuracy values.
        color: Matplotlib color used for both the line and the shaded band.
        label: Legend label for the mean line.
        ax: Matplotlib ``Axes`` to draw on.
        timestamp: Optional number of leading grid points to plot. ``None``
            (the default) plots the whole grid.

    Raises:
        ValueError: If ``datasets`` is empty.
    """
    if len(datasets) == 0:
        raise ValueError("plot_mean_var needs at least one dataset")

    # The grid ends at the *shortest* run's last time stamp, so every run
    # covers the whole grid. Points past a run's end are never needed.
    end_time = np.min([np.max(data[0]) for data in datasets])
    time_points = np.arange(0, end_time)

    # One row per run. "extrapolate" only affects grid points that lie before
    # a run's first time stamp.
    interpolated_accuracies = np.array([
        interpolate.interp1d(data[0], data[1], kind='linear',
                             fill_value="extrapolate")(time_points)
        for data in datasets
    ])

    mean_accuracies = np.mean(interpolated_accuracies, axis=0)
    std_accuracies = np.std(interpolated_accuracies, axis=0)

    if timestamp is not None:
        # Truncate so curves from different groups can share a common length.
        time_points = time_points[:timestamp]
        mean_accuracies = mean_accuracies[:timestamp]
        std_accuracies = std_accuracies[:timestamp]

    ax.plot(time_points, mean_accuracies, label=f'{label}', color=color)
    ax.fill_between(time_points, mean_accuracies - std_accuracies,
                    mean_accuracies + std_accuracies, color=color, alpha=0.2)
    # ax.fill_between(time_points, mean_accuracies - std_accuracies, mean_accuracies + std_accuracies, color=color, alpha=0.2, label=f'{label} std')

def _num_grid_points(datasets):
    """Return the length of the time grid plot_mean_var builds for ``datasets``."""
    return len(np.arange(0, np.min([np.max(data[0]) for data in datasets])))


def plot_mean_var_among_datasets(datasets_list, colors, labels, caption, ax, first_group_size=2):
    """Plot mean/std curves for several groups of runs on one axis.

    ``datasets_list`` is split into two consecutive groups: the first
    ``first_group_size`` entries and all remaining entries. The original code
    named these groups "sync" and "async". Within each group, every curve is
    cut to the shortest grid length in that group, so the group's curves end
    at the same time.

    Args:
        datasets_list: Sequence of run collections, one per curve; each is
            passed unchanged to ``plot_mean_var``.
        colors: Per-curve colors, indexed in parallel with ``datasets_list``.
        labels: Per-curve legend labels, indexed in parallel with
            ``datasets_list``.
        caption: Text drawn centered below the axis.
        ax: Matplotlib ``Axes`` to draw on.
        first_group_size: Number of leading entries that form the first group
            (default 2, the previously hard-coded split).
    """
    total = len(datasets_list)
    split = max(0, min(first_group_size, total))
    for start, stop in ((0, split), (split, total)):
        if start >= stop:
            continue  # Empty group, e.g. fewer curves than first_group_size.
        timestamp = min(_num_grid_points(datasets_list[i]) for i in range(start, stop))
        for i in range(start, stop):
            plot_mean_var(datasets_list[i], colors[i], labels[i], ax, timestamp)

    ax.set_xlabel('Time (sec)', fontsize=12)
    # The caption is placed below the x-axis label, in axes coordinates.
    ax.text(0.5, -0.2, caption, fontsize=15, transform=ax.transAxes, ha='center')
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.legend()

def plot_dirs(dirs, colors, labels, caption, ax):
    """Load every ``metric*.pkl`` file under each directory and plot the groups.

    All metric files found anywhere below one entry of ``dirs`` form one
    group of runs, which becomes one curve. A directory that contains no
    metric file is skipped, so later groups shift down by one position in
    ``colors``/``labels``.

    Note: files are read with ``pickle.load``. Only use this on trusted
    result directories.
    """
    data = []
    for directory in dirs:
        dirpath_data = []
        for dirpath, dirnames, filenames in os.walk(directory):
            dirnames.sort()  # Make the traversal order deterministic.
            for filename in sorted(filenames):
                if filename.startswith('metric') and filename.endswith('.pkl'):
                    with open(os.path.join(dirpath, filename), 'rb') as f:
                        dirpath_data.append(pickle.load(f))
        # Append exactly once per top-level directory, after the full walk.
        # Appending inside the walk loop would add the same list repeatedly.
        if dirpath_data:
            data.append(dirpath_data)
    plot_mean_var_among_datasets(data, colors, labels, caption, ax)

def plot_multiple_dirs(multiple_dirs, multiple_labels, captions, output_dir, output_filename, colors):
    """Draw one subplot per directory group, side by side, and save it as a PDF.

    Args:
        multiple_dirs: One entry per subplot; each entry is the ``dirs``
            argument of ``plot_dirs``.
        multiple_labels: Per-subplot label lists, parallel to ``multiple_dirs``.
        captions: Per-subplot captions, indexed in parallel with
            ``multiple_dirs``.
        output_dir: Directory to write into; created if missing.
        output_filename: File name without extension; ``.pdf`` is appended.
        colors: Color list shared by every subplot.
    """
    # exist_ok makes a prior existence check redundant, and avoids the race
    # between checking and creating.
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f'{output_filename}.pdf')
    num_subplots = len(multiple_dirs)
    fig = plt.figure(figsize=(6 * num_subplots, 4))
    for i, (dirs, labels) in enumerate(zip(multiple_dirs, multiple_labels)):
        ax = fig.add_subplot(1, num_subplots, i + 1)
        plot_dirs(dirs, colors, labels, captions[i], ax)
    fig.tight_layout()
    fig.savefig(save_path)

def obtain_max_accuracy(name, dir):
    """Print and return the mean/std of per-run peak accuracy under ``dir``.

    Every ``metric*.pkl`` file anywhere below ``dir`` is treated as one run,
    whose ``data[1]`` is its accuracy sequence. Each run's maximum accuracy
    is collected, and the maxima are summarized.

    Args:
        name: Label used as the prefix of the printed summary line.
        dir: Root directory to search recursively.

    Returns:
        ``(mean, std)`` of the per-run maxima as floats. Both are NaN when no
        metric file is found.

    Note: files are read with ``pickle.load``. Only use this on trusted
    results.
    """
    max_accuracy = []
    for dirpath, _, filenames in os.walk(dir):
        for filename in filenames:
            if filename.startswith('metric') and filename.endswith('.pkl'):
                with open(os.path.join(dirpath, filename), 'rb') as f:
                    data = pickle.load(f)
                max_accuracy.append(max(data[1]))

    if max_accuracy:
        mean = float(np.mean(max_accuracy))
        std = float(np.std(max_accuracy))
    else:
        # Report NaN explicitly rather than triggering numpy's
        # "Mean of empty slice" RuntimeWarning.
        mean = std = float('nan')
    print(f'{name} - Total:{len(max_accuracy)} Average: {mean} Std: {std}')
    return mean, std

def obtain_time_result(names, dirs, target_acc, norm_idx=-2):
    """Print and return each algorithm's normalized mean time to reach ``target_acc``.

    Each entry of ``dirs`` holds one algorithm's runs. Every ``metric*.pkl``
    file under it contributes the first time stamp (``data[0]``) at which its
    accuracy (``data[1]``) is >= ``target_acc``. Runs that never reach the
    target are ignored.

    To compare equal sample sizes, each algorithm keeps only its
    ``min_length`` fastest runs. ``min_length`` is the smallest number of
    successful runs across all algorithms. The mean times are then divided by
    the mean of the algorithm at position ``norm_idx``.

    Args:
        names: Display names, parallel to ``dirs``.
        dirs: One result directory per algorithm, searched recursively.
        target_acc: Accuracy threshold to reach.
        norm_idx: Index of the algorithm used as the normalization reference.

    Returns:
        numpy array of normalized mean times, one per directory. It is empty
        if ``dirs`` is empty, and all NaN if some algorithm never reached the
        target.

    Note: files are read with ``pickle.load``. Only use this on trusted
    results.
    """
    time_for_acc = []
    for alg_dir in dirs:
        alg_time = []
        for dirpath, _, filenames in os.walk(alg_dir):
            for filename in filenames:
                if filename.startswith('metric') and filename.endswith('.pkl'):
                    with open(os.path.join(dirpath, filename), 'rb') as f:
                        data = pickle.load(f)
                    time, acc = data[0], data[1]
                    # Record only the first time stamp at which the target is hit.
                    for t, a in zip(time, acc):
                        if a >= target_acc:
                            alg_time.append(t)
                            break
        alg_time.sort()
        time_for_acc.append(alg_time)

    if not time_for_acc:
        return np.array([])

    min_length = min(len(alg_time) for alg_time in time_for_acc)
    if min_length == 0:
        # Some algorithm never reached the target, so there is nothing to
        # compare on equal footing.
        result = np.full(len(time_for_acc), np.nan)
    else:
        # Must be an ndarray: in-place division is not defined for lists.
        result = np.array([np.mean(alg_time[:min_length]) for alg_time in time_for_acc],
                          dtype=float)
        result /= result[norm_idx]

    for name, res in zip(names, result):
        print(f'{name} - Time: {res} out of {min_length} results')
    return result