from matplotlib import pyplot as plt
import numpy as np
import random
import torch
import os
from curve_params import *
from copy import deepcopy
import math



#figsize
#prop



# def plot_loss_diff_LSTM(name, optimizers, curve_type, multi_run_acc_train, multi_run_acc_test,  acc_ppl, loc = 'upper left', ylim=(80, 101), ygap=5, xgap=25):
#     metric = ''
#     print("fn start: ", name)
#
#
#
#     loc = 'upper right'
#     metric = 'Perplexity'
#     # print(type(multi_run_acc))
#     # print(multi_run_acc.shape)
#     num_epochs = multi_run_acc_train.shape[-1]
#
#
#     plt.figure()
#
#     y_axis_list = np.arange(ylim[0], ylim[1] + 1, ygap)
#     x_axis_list = range(0, num_epochs + 1, xgap)
#
#     plt.yticks(y_axis_list)
#     plt.xticks(x_axis_list, rotation=90)
#
#     plt.ylim(ylim)
#     multi_run_acc_train = np.log(multi_run_acc_train)
#     multi_run_acc_test = np.log(multi_run_acc_test)
#
#
#     steps = np.linspace(50, 200, 4)
#     for index, optimizer in enumerate(optimizers):
#         # mean = np.mean(multi_run_acc[index], axis=0)
#         mean_train = multi_run_acc_train[index][0]
#         mean_test = multi_run_acc_test[index][0]
#         # if optimizer == 'Fromage':
#         #     print(mean)
#         # print(optimizer, ' : ', mean.shape)
#         standard_dev_train = np.std(multi_run_acc_train[index], axis=0)
#         standard_dev_test = np.std(multi_run_acc_test[index], axis=0)
#
#         plt.plot(mean_train, '-', label='Train Loss', linewidth=1.5)
#         plt.fill_between(range(num_epochs), mean_train - standard_dev_train, mean_train + standard_dev_train, alpha=0.3)
#
#         plt.plot(mean_test, '-', label='Test Loss', linewidth=1.5)
#         print('mean train shape: ', mean_train.shape)
#         plt.fill_between(range(num_epochs), mean_test - standard_dev_test, mean_test + standard_dev_test, alpha=0.3)
#         for step in steps:
#
#             plt.text(step, ylim[0], str.format('{:.2f}', abs(mean_train[int(step - 1)] - mean_test[int(step - 1)])), rotation=270)
#
#
#
#         # print('Diff value: ', optimizers[0], ' '+ str(abs(mean_train[199] - mean_test[199])))
#
#
#
#
#     # plt.grid(markevery=(int, int))
#     plt.grid()
#     plt.legend(fontsize=14, loc='upper right', ncol=2, prop={'size': 9})
#     # plt.legend(fontsize=14, loc='upper right', ncol=2)
#     curve_type = 'T' + curve_type[1:]
#     plt.title('Train/Test Loss ~ Training Epoch')
#     # plt.xlabel('Training Epoch')
#     # plt.ylabel(metric)
#     # plt.show()
#     print(name)
#     plt.savefig(name + '_' + optimizers[0] +'_.png', bbox_inches='tight', pad_inches=0.1, dpi=200)
#

def plot_loss_diff_CIFAR(name, optimizers, curve_type, multi_run_acc_train, multi_run_acc_test,  acc_ppl, loc = 'upper left', ylim=(80, 101), ygap=5, xgap=25):
    """Plot train vs. test curves per optimizer and annotate their gap.

    For every optimizer, row 0 of its train/test arrays is drawn as a line
    (when the data comes from ``make_data`` this row is the original,
    un-noised curve). A +/- one standard deviation band, computed over all
    runs, is shaded around it. The absolute train/test gap is written as
    rotated text at 16 evenly spaced epochs between 50 and 200. Epochs past
    the end of the curves are skipped. The figure is saved to
    ``<name>_<optimizers[0]>_.png`` and then closed.

    Args:
        name: Output file prefix.
        optimizers: Optimizer labels; position i selects ``multi_run_acc_*[i]``.
        curve_type: Unused; kept for interface compatibility.
        multi_run_acc_train: Array indexed as ``[optimizer][run][epoch]``.
        multi_run_acc_test: Same layout as ``multi_run_acc_train``.
        acc_ppl: Unused; kept for interface compatibility.
        loc: Unused; the legend is always placed 'upper right'.
        ylim: y-axis limits; gap annotations are drawn just above ``ylim[0]``.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.

    Returns:
        ``(diffs, train_accs)`` for the LAST optimizer processed, where
        ``diffs = (|train - test| at the epoch of max test value, optimizer)``
        and ``train_accs = (train value at that epoch, optimizer)``.
        Returns ``(None, None)`` without plotting if ``optimizers`` is empty.
    """
    if len(optimizers) == 0:
        # Nothing to plot, and the output file name needs optimizers[0].
        return None, None

    num_epochs = multi_run_acc_train.shape[-1]
    epochs = range(num_epochs)

    fig = plt.figure()
    plt.yticks(np.arange(ylim[0], ylim[1] + 1, ygap))
    plt.xticks(range(0, num_epochs + 1, xgap))
    plt.ylim(ylim)

    # 1-based epochs at which the gap is annotated. Drop any beyond the end of
    # the curves so that runs shorter than 200 epochs do not raise IndexError.
    steps = [step for step in np.linspace(50, 200, 16) if int(step - 1) < num_epochs]

    diffs = None
    train_accs = None
    for index, optimizer in enumerate(optimizers):
        mean_train = multi_run_acc_train[index][0]
        mean_test = multi_run_acc_test[index][0]
        print(optimizer)
        # Report the gap at the epoch with the best (maximum) test value.
        pos = np.argmax(mean_test)
        print('Diff: ', np.around(abs(mean_train[pos] - mean_test[pos]), 2), ' Train acc: ', mean_train[pos], 'Test acc: ', mean_test[pos])
        diffs = (abs(mean_train[pos] - mean_test[pos]), optimizer)
        train_accs = (mean_train[pos], optimizer)

        standard_dev_train = np.std(multi_run_acc_train[index], axis=0)
        standard_dev_test = np.std(multi_run_acc_test[index], axis=0)

        # NOTE(review): the legend and title say "Loss", but the caller computes
        # accuracy gaps (metric 'Accuracy'); confirm the intended wording.
        plt.plot(mean_train, '-', label='Train Loss', linewidth=1.5)
        plt.fill_between(epochs, mean_train - standard_dev_train, mean_train + standard_dev_train, alpha=0.3)
        plt.plot(mean_test, '-', label='Test Loss', linewidth=1.5)
        plt.fill_between(epochs, mean_test - standard_dev_test, mean_test + standard_dev_test, alpha=0.3)

        for step in steps:
            epoch_idx = int(step - 1)
            gap = abs(mean_train[epoch_idx] - mean_test[epoch_idx])
            plt.text(step, ylim[0] + 0.1, '{:.2f}'.format(gap), rotation=90)

    plt.grid()
    plt.legend(fontsize=14, loc='upper right', ncol=2, prop={'size': 9})
    plt.title('Train/Test Loss ~ Training Epoch')
    plt.savefig(name + '_' + optimizers[0] +'_.png', bbox_inches='tight', pad_inches=0.1, dpi=200)
    # Release the figure: this function is called once per optimizer/model,
    # and unclosed pyplot figures accumulate for the life of the process.
    plt.close(fig)
    return diffs, train_accs







def get_vector_std_dev(tuple_list_std_dev, length=200):
    """Expand piecewise std-dev specs into a per-epoch vector plus a marker mask.

    Args:
        tuple_list_std_dev: Sequence of ``(lower, higher, stddev)`` or
            ``(lower, higher, stddev, flag)`` tuples.
            - Each tuple contributes ``stddev`` repeated over the inclusive
              range ``lower..higher``.
            - Contributions are concatenated in the given order, so the
              tuples are expected to cover consecutive epochs.
            - A 4-element tuple additionally sets
              ``marker[lower..higher] = 1``; only the presence of the 4th
              element matters, not its value.
        length: Size of the returned marker array. The default of 200 matches
            the number of epochs the curves are truncated to elsewhere.

    Returns:
        ``(std_dev_vec, marker)``:
        - ``std_dev_vec`` has shape ``(1, N)``, where ``N`` is the total
          length of all spans.
        - ``marker`` is a float array of shape ``(length,)`` containing 0s
          and 1s.
    """
    vector_stdev = []
    marker = np.zeros(length)
    for spec in tuple_list_std_dev:
        lower, higher, stddev = spec[0], spec[1], spec[2]
        if len(spec) == 4:
            for i in range(lower, higher + 1):
                marker[i] = 1
        vector_stdev.extend([stddev] * (higher - lower + 1))

    return np.expand_dims(np.array(vector_stdev), axis=0), marker

# tuples = [(0, 3, 0.7), (4, 7, 0.8), (8,9, 0.33)]
# vec = get_vector_std_dev(tuples)
# print(vec)
# print(vec.shape)

# Output: multi_run_acc: dictionary containing keys as optimizers names and values as array of shape (num_runs X num_epochs)
def make_data(optimizers_acc_ppl_dict, optimizer_std_dev, num_runs):
    """Synthesize ``num_runs`` noisy copies of each optimizer's curve.

    Gaussian noise is scaled by the per-epoch std-dev vector from
    ``get_vector_std_dev`` and added to the original curve. At epochs flagged
    in the marker, the noise for every run is replaced by a single value drawn
    from U(0.99, 1) for that epoch.

    Args:
        optimizers_acc_ppl_dict: Mapping optimizer name -> 1-D curve array.
        optimizer_std_dev: Mapping optimizer name -> std-dev spec list accepted
            by ``get_vector_std_dev``.
        num_runs: Number of synthetic runs to generate per optimizer.

    Returns:
        Dict mapping optimizer name to an array of shape
        ``(num_runs + 1, num_epochs)``. Row 0 is the original curve and the
        remaining rows are the synthetic runs.
    """
    result = {}
    for opt_name, curve in optimizers_acc_ppl_dict.items():
        base_row = curve.reshape((1, -1))
        noise = np.random.randn(num_runs, curve.shape[0])
        scale, flags = get_vector_std_dev(optimizer_std_dev[opt_name])
        # One scalar draw per flagged epoch, shared by all runs, taken in
        # ascending epoch order.
        for epoch in np.flatnonzero(flags == 1):
            noise[:, epoch] = np.random.uniform(0.99, 1)
        result[opt_name] = np.concatenate([base_row, base_row + noise * scale], axis=0)
    return result






# single_run_acc = np.random.randn(3, 100)
# multi_run_acc = make_data(np.array([0.3, 0.4, 0.1]), single_run_acc, 3)
# optimizers=['Adabelief', 'Adam', 'g']

# plot_shaded(optimizers, 'Train', multi_run_acc, ylim=(-3, 3))

def get_data(names, folder_path='./curve'):
    """Load each named curve file from ``folder_path`` with ``torch.load``.

    Args:
        names: File names relative to ``folder_path``.
        folder_path: Directory holding the curve files. Defaults to
            './curve', the previously hard-coded location.

    Returns:
        Dict mapping each name to the object ``torch.load`` returned for it.
    """
    # NOTE: torch.load unpickles its input; only point this at trusted files.
    return {name: torch.load(os.path.join(folder_path, name)) for name in names}


def main_plot_diff(model, name, curve_folder, names, labels, acc_ppl, train_test, optimizer_std_dev, num_runs, ylim, ygap, xgap):
    """Load train/test curves, synthesize extra runs, and plot per-optimizer gaps.

    Args:
        model: 'LSTM_diff' or 'CIFAR10_diff'. Selects the plotting routine and
            which optimizer labels get plotted; any other value plots nothing.
        name: Output file prefix forwarded to the plotting function.
        curve_folder: Directory containing the curve files.
        names: Curve file names inside ``curve_folder``, paired with ``labels``.
        labels: Optimizer label for each entry of ``names``.
        acc_ppl: Metric suffix. Curves are read from the ``'train_<acc_ppl>'``
            and ``'test_<acc_ppl>'`` entries of each loaded file and truncated
            to the first 200 epochs.
        train_test: Forwarded to the plotting function as its ``curve_type``.
        optimizer_std_dev: Per-optimizer std-dev specs passed to ``make_data``.
        num_runs: Total runs per optimizer, including the original curve.
        ylim, ygap, xgap: Axis settings forwarded to the plotting function.

    Returns:
        ``(diffs, accs)``: lists collecting the ``(diff, acc)`` pairs returned
        by ``plot_loss_diff_CIFAR`` in CIFAR mode. Both lists are empty in
        any other mode.
    """
    diffs = []
    accs = []

    optimizer_acc_ppl_dict_train = {}
    optimizer_acc_ppl_dict_test = {}
    for file_name, label in zip(names, labels):
        # NOTE: torch.load unpickles its input; only use trusted curve files.
        optimizer_data = torch.load(os.path.join(curve_folder, file_name))
        optimizer_acc_ppl_dict_train[label] = np.array(optimizer_data['{}_{}'.format('train', acc_ppl)])[:200]
        optimizer_acc_ppl_dict_test[label] = np.array(optimizer_data['{}_{}'.format('test', acc_ppl)])[:200]

    # make_data prepends the original curve, so each entry has num_runs rows.
    multi_run_acc_ppl_dict_train = make_data(optimizer_acc_ppl_dict_train, optimizer_std_dev, num_runs - 1)
    multi_run_acc_ppl_dict_test = make_data(optimizer_acc_ppl_dict_test, optimizer_std_dev, num_runs - 1)

    def _single(data_dict, label):
        # Wrap one optimizer's (runs, epochs) array as (1, runs, epochs), the
        # [optimizer][run][epoch] layout the plotting functions index into.
        # Looking up by label avoids re-stacking every optimizer per iteration
        # and avoids a hard-coded epoch count in a reshape.
        return data_dict[label][np.newaxis]

    if model == 'LSTM_diff':
        for label in labels:
            if label in ('AdaBelief', 'Adam', 'SGD'):
                # NOTE(review): plot_loss_diff_LSTM appears to be commented out
                # in this file, in which case this call raises NameError.
                plot_loss_diff_LSTM(name, [label], train_test,
                                    _single(multi_run_acc_ppl_dict_train, label),
                                    _single(multi_run_acc_ppl_dict_test, label),
                                    acc_ppl,
                                    loc='upper left', ylim=ylim, ygap=ygap, xgap=xgap)
    elif model == 'CIFAR10_diff':
        for label in labels:
            if label in ('AdaBelief', 'Adam', 'SGD', 'Apollo'):
                diff, acc = plot_loss_diff_CIFAR(name, [label], train_test,
                                                 _single(multi_run_acc_ppl_dict_train, label),
                                                 _single(multi_run_acc_ppl_dict_test, label),
                                                 acc_ppl,
                                                 loc='upper left', ylim=ylim, ygap=ygap, xgap=xgap)
                diffs.append(diff)
                accs.append(acc)

    return diffs, accs




# main_plot('/Users/anirudhb/Desktop/RC2021_spring/Adabelief_optimizer_RC2021/PyTorch_Experiments/classification_cifar10/curve', names, labels, 'acc', 'test', optimizer_std_dev, 3, (88, 96), 1, 25)
# resnet_labels = deepcopy(labels)
# resnet_labels.append('Apollo')
# Key selecting which entry of the curve_params configs the plot functions read.
curve_type = "train"


def LSTM_loss_diff():
    """Plot train/test loss gaps for the three-layer LSTM configuration.

    Uses every label except the last, 3 runs, y-limits (3.5, 4.6), y-ticks
    every 0.1 and x-ticks every 25.

    NOTE(review): this goes through main_plot_diff's 'LSTM_diff' branch, which
    calls plot_loss_diff_LSTM; that function appears to be commented out here.
    """
    config = LSTM_three_layer_config[curve_type]
    main_plot_diff(
        'LSTM_diff',
        'LSTM_three_layer_train_test_loss_diff',
        LSTM_curve_root,
        config['curve_paths'],
        labels[:-1],
        config['acc_ppl'],
        curve_type,
        config['std_dev'],
        3,
        (3.5, 4.6),
        0.1,
        25,
    )

def CIFAR_acc_diff():
    """Plot train/test gaps for the CIFAR-10 VGG, ResNet and DenseNet runs.

    For each architecture, call ``main_plot_diff`` in 'CIFAR10_diff' mode
    (3 runs, y-limits (70, 100), y-ticks every 2.5, x-ticks every 25). Then
    print the optimizer labels:
    - ordered by train/test gap, descending ("Variance");
    - ordered by train value at the best-test epoch, ascending ("Bias");
    followed by a blank line.
    """
    experiments = [
        ('CIFAR10_VGG_train_test_loss_diff', CIFAR10_curve_root, CIFAR10_vgg_config),
        # These two previously used 'CIFAR100_...' output names despite
        # plotting CIFAR-10 data, which mislabeled the images and collided
        # with the CIFAR-100 output file names.
        ('CIFAR10_Resnet_train_test_loss_diff', CIFAR10_curve_root, CIFAR10_resnet_config),
        ('CIFAR10_Densenet_train_test_loss_diff', CIFAR10_curve_root, CIFAR10_densenet_config),
        # The CIFAR-100 variants (formerly commented-out code) can be enabled
        # by adding entries using CIFAR100_curve_root with
        # CIFAR100_{vgg,resnet,densenet}_config and output names
        # 'CIFAR100_{VGG,Resnet,Densenet}_train_test_loss_diff'.
    ]

    for name, curve_root, config in experiments:
        settings = config[curve_type]
        diffs, accs = main_plot_diff('CIFAR10_diff', name, curve_root,
                                     settings['curve_paths'], labels,
                                     settings['acc_ppl'], curve_type,
                                     settings['std_dev'], 3, (70, 100), 2.5, 25)
        # Entries are (value, optimizer) tuples; rank by value.
        print('Variance (H to L) ', [label for _, label in sorted(diffs, reverse=True)])
        print('Bias     (H to L) ', [label for _, label in sorted(accs)])
        print()



# Script entry point: only the CIFAR accuracy-gap plots are currently enabled.
# NOTE(review): this call executes at import time as well as when run as a
# script; consider an `if __name__ == '__main__':` guard if the module is ever
# imported. plot_WT2 is not defined in the visible part of this file.
# LSTM_loss_diff()
# plot_WT2('train')
# plot_WT2('test')
CIFAR_acc_diff()