import numpy as np
import pickle as pkl
import os
import matplotlib.pyplot as plt
from pathlib import Path


# Whitening modes; each one maps to an ``./E{env}/whitening_{mode}/`` results directory.
whitenings = [
    'none',
    'whiten',
    'normalize',
]

# Regularisation grid: 9 log-spaced values from 1e-5 to 1e-1 (both ends included).
lambda_values = np.logspace(start=-5, stop=-1, num=9, endpoint=True)

# Palette for the single- vs multi-task plots.
nrc1_color, nrc2_color, nrc3_color = '#1E90FF', '#4682B4', '#FF8C00'  # light blue, darker blue, orange

# Palette for the whitening comparison plots.
green, cyan, purple = '#32CD32', '#00CED1', '#800080'


def plot_single_vs_multi(data, env, seeds, lambda_values, split):
    """Plot single-task vs multi-task MSE (mean +/- std over seeds) per lambda.

    Args:
        data: nested mapping read as ``data[lamW][task][seed]`` where ``task``
            is ``'single-task'`` or ``'multi-task'``.
        env: environment name, used as the output file-name prefix.
        seeds: seeds to aggregate over (used as the innermost lookup key).
        lambda_values: lambda keys of ``data``; also the x-axis categories.
        split: split name inserted into the output file name.

    Side effects:
        Writes ``{env}_single_vs_multi_{split}_mses.png`` to the current
        working directory and closes the figure.
    """
    def seed_stats(task):
        # For each lambda, collect one value per seed, then reduce to
        # mean and (population, ddof=0) standard deviation.
        per_lambda = [[data[lamW][task][seed] for seed in seeds]
                      for lamW in lambda_values]
        return [np.mean(vals) for vals in per_lambda], [np.std(vals) for vals in per_lambda]

    single_mean, single_std = seed_stats('single-task')
    multi_mean, multi_std = seed_stats('multi-task')

    plt.figure(figsize=(8, 8))

    positions = np.arange(len(lambda_values))
    # Every tick keeps its grid line, but only even-indexed ticks get a label.
    tick_text = [f'{lam:.1e}' if k % 2 == 0 else ''
                 for k, lam in enumerate(lambda_values)]

    plt.grid(True, linestyle='--', alpha=0.7)
    line_style = dict(capsize=5, alpha=1, fmt='-o', linewidth=2, markersize=6)
    plt.errorbar(positions, single_mean, yerr=single_std,
                 label='n Single-task', color=nrc2_color, **line_style)
    plt.errorbar(positions, multi_mean, yerr=multi_std,
                 label='Multi-task', color=nrc3_color, **line_style)

    plt.xticks(positions, tick_text, fontsize=22)
    plt.yticks(fontsize=24)
    plt.legend(fontsize=22, loc='upper left')
    plt.tight_layout()
    plt.savefig(f'{env}_single_vs_multi_{split}_mses.png')
    plt.close()


def plot_whitening(data, env, seeds, lambda_values, split):
    """Plot MSE (mean +/- std over seeds) per lambda for each whitening mode.

    Args:
        data: nested mapping read as ``data[lamW][whitening][seed]`` for every
            ``whitening`` in the module-level ``whitenings`` list.
        env: environment name, used as the output file-name prefix.
        seeds: seeds to aggregate over (used as the innermost lookup key).
        lambda_values: lambda keys of ``data``; also the x-axis categories.
        split: split name inserted into the output file name.

    Side effects:
        Writes ``{env}_whitening_{split}_mses.png`` to the current working
        directory and closes the figure.
    """
    # Legend label and line colour per whitening mode. Keyed by name (rather
    # than parallel lists indexed by position) so the styling cannot drift out
    # of sync with the order of ``whitenings``.
    styles = {
        'none': ('MSE', green),
        'whiten': ('MSE(de-whiten)', cyan),
        'normalize': ('MSE(de-normalize)', purple),
    }

    plt.figure(figsize=(8, 8))

    index = np.arange(len(lambda_values))
    scientific_labels = [f'{x:.1e}' for x in lambda_values]

    for whitening in whitenings:
        label, color = styles[whitening]
        # One list of per-seed values for each lambda.
        per_lambda = [[data[lamW][whitening][seed] for seed in seeds]
                      for lamW in lambda_values]
        whitening_means = [np.mean(vals) for vals in per_lambda]
        whitening_stds = [np.std(vals) for vals in per_lambda]

        plt.errorbar(index, whitening_means, yerr=whitening_stds, label=label, capsize=5, alpha=1,
                     fmt='-o', color=color, linewidth=2, markersize=6)

    # Show only alternate x-tick labels but keep all grid lines.
    labels = ['' if i % 2 else label for i, label in enumerate(scientific_labels)]
    plt.xticks(index, labels, fontsize=22)
    plt.yticks(fontsize=24)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=22, loc='upper left')
    plt.tight_layout()
    plt.savefig(f'{env}_whitening_{split}_mses.png')
    plt.close()


def _empty_results(lambda_values, keys, seeds):
    """Build zero-initialised accumulators shaped ``{lamW: {key: {seed: 0}}}``."""
    # Keyed by seed value (not list position) so arbitrary seed ids such as
    # [3, 4] work; the plot functions look values up as ``[...][seed]``.
    return {
        lamW: {key: {seed: 0 for seed in seeds} for key in keys}
        for lamW in lambda_values
    }


def _iter_mean_mses(dir_name):
    """Yield ``(lamW, mean train MSE, mean val MSE)`` for each ``.pkl`` in *dir_name*.

    The lambda is parsed from the last ``_``-separated token of the file stem
    and rounded to 6 decimals to match the rounded keys in the accumulators.
    Non-``.pkl`` entries are skipped. Raises ``FileNotFoundError`` if the
    directory is missing.
    """
    # Sorted so that, if two files round to the same lambda, the outcome is
    # deterministic rather than dependent on directory listing order.
    pkl_paths = sorted(p for p in dir_name.iterdir() if p.suffix == '.pkl')
    for path in pkl_paths:
        lamW = round(float(path.stem.rsplit('_')[-1]), 6)
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files produced by your own experiments.
        with open(path, 'rb') as f:
            data = pkl.load(f)
        yield lamW, np.mean(data['train_mses']), np.mean(data['val_mses'])


def read_and_plot_results(env, seeds, dims, lambda_values):
    """Load per-seed MSE pickles for *env* and write the comparison plots.

    Two plot families are produced, each for the training and test split:
      * single- vs multi-task (``whitening_none`` results, one directory per
        entry of *dims*; ``None`` denotes the multi-task run);
      * whitening comparison (``dim_None`` results for each mode in
        ``whitenings``).

    Results are read from ``./E{env}/whitening_{mode}/S{seed}/dim_{dim}/mses``.

    Args:
        env: environment name (directory component and plot-file prefix).
        seeds: seed ids to load; used as lookup keys, so any ids are allowed.
        dims: per-dimension single-task ids, plus ``None`` for multi-task.
        lambda_values: lambda grid; rounded to 6 decimals to match file names.

    Raises:
        KeyError: if a result file encodes a lambda that is not in the grid.
        FileNotFoundError: if an expected results directory is missing.
    """
    lambda_values = np.round(lambda_values, 6)

    tasks = ('single-task', 'multi-task')
    single_vs_multi_training_mses = _empty_results(lambda_values, tasks, seeds)
    single_vs_multi_test_mses = _empty_results(lambda_values, tasks, seeds)

    for seed in seeds:
        for dim in dims:
            dir_name = Path(f'./E{env}/whitening_none/S{seed}/dim_{dim}/mses')
            for lamW, train_mse, test_mse in _iter_mean_mses(dir_name):
                if dim is None:
                    single_vs_multi_training_mses[lamW]['multi-task'][seed] = train_mse
                    single_vs_multi_test_mses[lamW]['multi-task'][seed] = test_mse
                else:
                    # Single-task MSEs are summed over the per-dimension models
                    # (plotted as "n Single-task").
                    # NOTE(review): confirm sum (not mean) is the intended aggregate.
                    single_vs_multi_training_mses[lamW]['single-task'][seed] += train_mse
                    single_vs_multi_test_mses[lamW]['single-task'][seed] += test_mse

    plot_single_vs_multi(data=single_vs_multi_training_mses, env=env,
                         seeds=seeds, lambda_values=lambda_values, split='training')
    plot_single_vs_multi(data=single_vs_multi_test_mses, env=env,
                         seeds=seeds, lambda_values=lambda_values, split='test')

    whitening_training_mses = _empty_results(lambda_values, whitenings, seeds)
    whitening_test_mses = _empty_results(lambda_values, whitenings, seeds)

    for seed in seeds:
        for whitening in whitenings:
            dir_name = Path(f'./E{env}/whitening_{whitening}/S{seed}/dim_None/mses')
            for lamW, train_mse, test_mse in _iter_mean_mses(dir_name):
                whitening_training_mses[lamW][whitening][seed] = train_mse
                whitening_test_mses[lamW][whitening][seed] = test_mse

    plot_whitening(data=whitening_training_mses, env=env,
                   seeds=seeds, lambda_values=lambda_values, split='training')
    plot_whitening(data=whitening_test_mses, env=env, seeds=seeds,
                   lambda_values=lambda_values, split='test')


if __name__ == '__main__':
    # One row per environment: (env, seeds, single-task dims + None for the
    # multi-task run, lambda grid). Reacher uses a 1.5x-scaled grid.
    runs = [
        ('swimmer', [0, 1, 2], [0, 1, None], lambda_values),
        ('reacher', [0, 1, 2], [0, 1, None], 1.5 * lambda_values),
        ('hopper', [0, 1, 2], [0, 1, 2, None], lambda_values),
        ('carla2d', [0, 1], [0, 1, None], lambda_values),
    ]
    for run_env, run_seeds, run_dims, run_lambdas in runs:
        read_and_plot_results(
            env=run_env,
            seeds=run_seeds,
            dims=run_dims,
            lambda_values=run_lambdas,
        )
