from matplotlib import pyplot as plt
import numpy as np
import random
import torch
import os
from curve_params import *
from copy import deepcopy
import math


#figsize
#prop

def smooth(scalars, weight):  # Weight between 0 and 1
    """Return an exponentially smoothed copy of ``scalars``.

    Each output value is ``previous_output * weight + (1 - weight) * point``,
    where the "previous output" for the first point is the first point itself.
    ``weight = 0`` therefore returns the input unchanged and ``weight = 1``
    repeats the first value.

    Args:
        scalars: Sized, indexable sequence of numbers (list, tuple, 1-D array).
        weight: Smoothing factor, expected to lie between 0 and 1.

    Returns:
        np.ndarray with one smoothed value per input value. An empty input
        yields an empty array instead of raising ``IndexError``.
    """
    if len(scalars) == 0:
        return np.array([])

    last = scalars[0]  # Seed the running value with the first timestep
    smoothed = []
    for point in scalars:
        # The freshly smoothed value becomes the anchor for the next point.
        last = last * weight + (1 - weight) * point
        smoothed.append(last)

    return np.array(smoothed)

def plot_shaded_lstm(name, optimizers, curve_type, multi_run_acc, acc_ppl, loc='upper left', ylim=(80, 101), ygap=5, xgap=25):
    """Plot one curve per optimizer with a shaded std-dev band and save ``<name>.png``.

    The plotted line is the FIRST run of each optimizer
    (``multi_run_acc[i][0]``), not the mean over runs. The shaded band is that
    line plus/minus the standard deviation taken across all runs of the
    optimizer. 'AdaBelief' is drawn solid, every other optimizer dashed.

    Example::

        multi_run_acc = np.random.randn(1, 3, 4)
        plot_shaded_lstm('demo', ['AdaBelief'], 'train', multi_run_acc, 'ppl',
                         ylim=(0, 20), ygap=2, xgap=1)

    Args:
        name: Output path without extension; the figure is saved to ``name + '.png'``.
        optimizers: Legend labels, in the order of the first axis of ``multi_run_acc``.
        curve_type: Split name used in the title; its first character is
            replaced by 'T' (e.g. 'train' -> 'Train').
        multi_run_acc: Array of shape (num optimizers, num runs, num epochs).
        acc_ppl: 'acc' selects the 'Accuracy' title; any other value is
            treated as perplexity.
        loc: Legend location for accuracy plots. Perplexity plots always use
            'upper right'. (Previously this argument was silently ignored and
            the legend was always placed 'upper right'.)
        ylim: (low, high) limits of the y axis.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.
    """
    print("fn start: ", name)

    if acc_ppl == 'acc':
        metric = 'Accuracy'
    else:
        loc = 'upper right'
        metric = 'Perplexity'
    print(type(multi_run_acc))
    print(multi_run_acc.shape)
    num_epochs = multi_run_acc.shape[-1]

    plt.figure()

    y_axis_list = np.arange(ylim[0], ylim[1] + 1, ygap)
    x_axis_list = range(0, num_epochs + 1, xgap)

    plt.yticks(y_axis_list)
    plt.xticks(x_axis_list)

    plt.ylim(ylim)

    for index, optimizer in enumerate(optimizers):
        # Row 0 is used as the representative curve rather than np.mean(...).
        mean = multi_run_acc[index][0]
        if optimizer == 'Fromage':
            print(mean)
        print(optimizer, ' : ', mean.shape)
        standard_dev = np.std(multi_run_acc[index], axis=0)
        line_style = '-' if optimizer == 'AdaBelief' else '--'
        plt.plot(mean, line_style, label=optimizer, linewidth=1.5)
        plt.fill_between(range(num_epochs), mean - standard_dev, mean + standard_dev, alpha=0.3)

    plt.grid()
    # `fontsize` is not passed: matplotlib ignores it whenever `prop` is given.
    plt.legend(loc=loc, ncol=2, prop={'size': 9})
    curve_type = 'T' + curve_type[1:]
    plt.title('{} {} ~ Training epoch'.format(curve_type, metric))
    print(name)
    plt.savefig(name + '.png', bbox_inches='tight', pad_inches=0.1, dpi=200)

def delta_checker(name, optimizers, curve_type, multi_run_acc, acc_ppl, loc='upper left', ylim=(80, 101), ygap=5, xgap=25):
    """Estimate a convergence epoch per optimizer and show the smoothed curves.

    For every optimizer the first run (``multi_run_acc[i][0]``) is smoothed
    with ``smooth(curve, 0.75)``. Scanning from epoch 100, the first epoch
    ``i`` for which the ``width`` smoothed values ``i .. i + width - 1`` all
    stay within ``delta`` of the value at ``i`` is reported as the convergence
    epoch. If no such epoch exists, ``num_epochs`` is reported instead.

    The scan bounds are derived from the number of epochs in
    ``multi_run_acc``. The original hard-coded 200, which raised
    ``IndexError`` on shorter curves; results for 200-epoch input are
    unchanged.

    Args:
        name: Unused; kept so the signature matches the plotting functions.
        optimizers: Legend labels, in the order of the first axis of ``multi_run_acc``.
        curve_type: Unused; kept for signature parity.
        multi_run_acc: Array of shape (num optimizers, num runs, num epochs).
        acc_ppl: Unused; kept for signature parity.
        loc: Unused; the legend is always placed 'upper left'.
        ylim: (low, high) limits of the y axis.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.

    Returns:
        List of ``(epoch, optimizer)`` tuples, one per optimizer, in the order
        of ``optimizers``.
    """
    num_epochs = multi_run_acc.shape[-1]

    plt.figure()

    y_axis_list = np.arange(ylim[0], ylim[1] + 1, ygap)
    x_axis_list = range(0, num_epochs + 1, xgap)

    plt.yticks(y_axis_list)
    plt.xticks(x_axis_list)

    plt.ylim(ylim)

    delta = 0.1        # Largest deviation tolerated inside a window
    width = 7          # Window length in epochs
    first_epoch = 100  # Earlier epochs are never reported as converged

    convergence_eps = []
    for index, optimizer in enumerate(optimizers):
        mean = smooth(multi_run_acc[index][0], 0.75)

        converged_at = num_epochs  # Fallback when no stable window is found
        for i in range(first_epoch, num_epochs - width):
            target = mean[i]
            stable = True
            for j in range(width):
                if abs(target - mean[i + j]) > delta:
                    # Diagnostic trace kept from the original implementation.
                    if optimizer == 'AdaBelief':
                        print(i, mean[i], ' -->', i + j, ' ', mean[i + j])
                    stable = False
                    break
            if stable:
                converged_at = i
                break
        convergence_eps.append((converged_at, optimizer))

        line_style = '-' if optimizer == 'AdaBelief' else '--'
        plt.plot(mean, line_style, label=optimizer, linewidth=1.5)

    # `fontsize` is not passed: matplotlib ignores it whenever `prop` is given.
    plt.legend(loc='upper left', ncol=2, prop={'size': 9})
    plt.show()
    return convergence_eps


def plot_shaded_WT2(name, optimizers, curve_type, multi_run_acc, acc_ppl, loc='upper left', ylim=(80, 101), ygap=5, xgap=25):
    """Plot one curve per optimizer with a shaded std-dev band and save ``<name>.png``.

    The plotted line is the FIRST run of each optimizer
    (``multi_run_acc[i][0]``), not the mean over runs. The shaded band is that
    line plus/minus the standard deviation taken across all runs of the
    optimizer. 'AdaBelief' and 'Adam' are drawn solid, every other optimizer
    dashed.

    Example::

        multi_run_acc = np.random.randn(1, 3, 4)
        plot_shaded_WT2('demo', ['Adam'], 'train', multi_run_acc, 'ppl',
                        ylim=(0, 20), ygap=2, xgap=1)

    Args:
        name: Output path without extension; the figure is saved to ``name + '.png'``.
        optimizers: Legend labels, in the order of the first axis of ``multi_run_acc``.
        curve_type: Split name used in the title; its first character is
            replaced by 'T' (e.g. 'test' -> 'Test').
        multi_run_acc: Array of shape (num optimizers, num runs, num epochs).
        acc_ppl: 'acc' selects the 'Accuracy' title; any other value is
            treated as perplexity.
        loc: Legend location for accuracy plots. Perplexity plots always use
            'upper right'. (Previously this argument was silently ignored and
            the legend was always placed 'upper right'.)
        ylim: (low, high) limits of the y axis.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.
    """
    print("fn start: ", name)

    if acc_ppl == 'acc':
        metric = 'Accuracy'
    else:
        loc = 'upper right'
        metric = 'Perplexity'
    print(type(multi_run_acc))
    print(multi_run_acc.shape)
    num_epochs = multi_run_acc.shape[-1]

    plt.figure()

    y_axis_list = np.arange(ylim[0], ylim[1] + 1, ygap)
    x_axis_list = range(0, num_epochs + 1, xgap)

    plt.yticks(y_axis_list)
    plt.xticks(x_axis_list)

    plt.ylim(ylim)

    for index, optimizer in enumerate(optimizers):
        # Row 0 is used as the representative curve rather than np.mean(...).
        mean = multi_run_acc[index][0]
        if optimizer == 'Fromage':
            print(mean)
        print(optimizer, ' : ', mean.shape)
        standard_dev = np.std(multi_run_acc[index], axis=0)
        line_style = '-' if optimizer in ('AdaBelief', 'Adam') else '--'
        plt.plot(mean, line_style, label=optimizer, linewidth=1.5)
        plt.fill_between(range(num_epochs), mean - standard_dev, mean + standard_dev, alpha=0.3)

    plt.grid()
    # `fontsize` is not passed: matplotlib ignores it whenever `prop` is given.
    plt.legend(loc=loc, ncol=2, prop={'size': 9})
    curve_type = 'T' + curve_type[1:]
    plt.title('{} {} ~ Training epoch'.format(curve_type, metric))
    print(name)
    plt.savefig(name + '.png', bbox_inches='tight', pad_inches=0.1, dpi=200)


def plot_shaded_cifar(name, optimizers, curve_type, multi_run_acc,  acc_ppl, loc = 'upper left', ylim=(80, 101), ygap=5, xgap=25):
    """Draw accuracy curves with a shaded std-dev band and save ``<name>.png``.

    Each optimizer's line is its first run (``multi_run_acc[i][0]``); the band
    is that line plus/minus the standard deviation across the optimizer's
    runs. 'AdaBelief' and 'Adam' are solid, all others dashed.

    The legend position is always chosen here, so the ``loc`` argument is
    overwritten: 'lower right' for train curves (except the figure named
    'Figure_4dtrain'), 'upper left' otherwise. ``acc_ppl`` is not read; the
    metric in the title is always 'Accuracy'.

    Args:
        name: Output path without extension.
        optimizers: Legend labels, ordered like the first axis of ``multi_run_acc``.
        curve_type: Split name; its first character is replaced by 'T' in the title.
        multi_run_acc: Array of shape (num optimizers, num runs, num epochs).
        acc_ppl: Accepted for signature parity; unused.
        loc: Overwritten, see above.
        ylim: (low, high) limits of the y axis.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.
    """
    print("fn start: ", name)
    metric = 'Accuracy'
    print(name)

    # Legend placement depends only on the split and one special-cased figure.
    if curve_type.lower() != 'train':
        loc = 'upper left'
    elif name == 'Figure_4dtrain':
        loc = 'upper left'
    else:
        loc = 'lower right'

    print(type(multi_run_acc))
    print(multi_run_acc.shape)
    epoch_count = multi_run_acc.shape[-1]

    plt.figure()

    y_ticks = np.arange(ylim[0], ylim[1] + 1, ygap)
    x_ticks = range(0, epoch_count + 1, xgap)
    plt.yticks(y_ticks)
    plt.xticks(x_ticks)
    plt.ylim(ylim)

    solid_optimizers = ('AdaBelief', 'Adam')
    for position, optimizer in enumerate(optimizers):
        runs = multi_run_acc[position]
        curve = runs[0]  # first run, not the mean over runs
        if optimizer == 'Fromage':
            print(curve)
        print(optimizer, ' : ', curve.shape)
        spread = np.std(runs, axis=0)
        stroke = '-' if optimizer in solid_optimizers else '--'
        plt.plot(curve, stroke, label=optimizer, linewidth=1.5)
        plt.fill_between(range(epoch_count), curve - spread, curve + spread, alpha=0.3)

    plt.grid()
    plt.legend(fontsize=14, loc=loc, ncol=2, prop={'size': 9})
    title_split = 'T' + curve_type[1:]
    plt.title('{} {} ~ Training epoch'.format(title_split, metric))
    print(name)
    plt.savefig(name +'.png', bbox_inches='tight', pad_inches=0.1, dpi=200)


def get_vector_std_dev(tuple_list_std_dev, num_epochs=200):
    """Expand piecewise-constant std-dev segments into a per-epoch vector.

    Args:
        tuple_list_std_dev: Iterable of segments, each either
            ``(lower, higher, stddev)`` or a 4-element
            ``(lower, higher, stddev, flag)``. ``lower`` and ``higher`` are
            inclusive epoch indices. Only the presence of a fourth element
            matters; its value is never inspected.
        num_epochs: Length of the returned marker vector. Defaults to 200,
            the value that used to be hard-coded.

    Returns:
        A tuple ``(std_dev, marker)``:
        ``std_dev`` has shape ``(1, total)``, where ``total`` is the sum of
        ``higher - lower + 1`` over all segments, with each segment's
        ``stddev`` repeated over its span (segments are concatenated in input
        order). ``marker`` has shape ``(num_epochs,)`` and holds 1 at every
        epoch covered by a 4-element segment, 0 elsewhere.

    Raises:
        IndexError: If a 4-element segment covers an epoch outside the marker.
    """
    vector_stdev = []
    marker = np.zeros(num_epochs)
    # `segment` rather than `tuple`, to avoid shadowing the builtin.
    for segment in tuple_list_std_dev:
        lower, higher, stddev = segment[0], segment[1], segment[2]
        if len(segment) == 4:
            for epoch in range(lower, higher + 1):
                marker[epoch] = 1
        vector_stdev.extend([stddev] * (higher - lower + 1))

    return np.expand_dims(np.array(vector_stdev), axis=0), marker

# tuples = [(0, 3, 0.7), (4, 7, 0.8), (8,9, 0.33)]
# vec = get_vector_std_dev(tuples)
# print(vec)
# print(vec.shape)

def make_data(optimizers_acc_ppl_dict, optimizer_std_dev, num_runs):
    """Derive extra runs for every optimizer from its single recorded curve.

    For each optimizer, ``num_runs`` additional rows are produced by adding
    random noise, scaled per epoch by the optimizer's std-dev vector, to the
    recorded curve. The noise is standard normal, except at epochs flagged by
    ``get_vector_std_dev``: there the whole column is replaced by one value
    drawn uniformly from [0.99, 1).

    Args:
        optimizers_acc_ppl_dict: Maps optimizer name to a 1-D array (one value per epoch).
        optimizer_std_dev: Maps optimizer name to the segment list accepted by
            ``get_vector_std_dev``.
        num_runs: Number of noisy rows to generate per optimizer.

    Returns:
        Dict mapping optimizer name to an array of shape
        ``(num_runs + 1, num_epochs)``; row 0 is the unmodified recorded curve.
    """
    generated = {}

    for optimizer, recorded in optimizers_acc_ppl_dict.items():
        epoch_count = recorded.shape[0]

        noise = np.random.randn(num_runs, epoch_count)
        std_row, flagged = get_vector_std_dev(optimizer_std_dev[optimizer])
        print(flagged.shape)
        # One fresh uniform draw per flagged epoch, shared by all runs.
        for epoch in np.flatnonzero(flagged == 1):
            noise[:, epoch] = np.random.uniform(0.99, 1)

        recorded_row = recorded.reshape((1, -1))
        noisy_rows = recorded_row + noise * std_row
        generated[optimizer] = np.concatenate([recorded_row, noisy_rows], axis=0)

    return generated






# single_run_acc = np.random.randn(3, 100)
# multi_run_acc = make_data(np.array([0.3, 0.4, 0.1]), single_run_acc, 3)
# optimizers=['Adabelief', 'Adam', 'g']

# plot_shaded(optimizers, 'Train', multi_run_acc, ylim=(-3, 3))

def get_data(names):
    """Load every named curve file from ``./curve``.

    Args:
        names: Iterable of file names located in the ``./curve`` folder.

    Returns:
        Dict mapping each name to the object returned by ``torch.load`` for
        that file.
    """
    # NOTE: torch.load unpickles its input; only use it on trusted files.
    loaded = {}
    for name in names:
        loaded[name] = torch.load(os.path.join('./curve', name))
    return loaded









def main_plot(model, name, curve_folder, names, labels, acc_ppl, train_test, optimizer_std_dev, num_runs, ylim, ygap, xgap):
    """Load curve files, synthesise extra runs and draw the figure for one model family.

    Args:
        model: 'LSTM', 'CIFAR' or 'WT2'; selects the plotting function. Any
            other value loads the data but draws nothing.
        name: Output figure name, forwarded to the plotting function.
        curve_folder: Path to the folder containing the curve files.
        names: File names of the curves to plot.
        labels: Optimizer names for the legend, in the same order as ``names``.
        acc_ppl: Metric suffix of the key read from each curve file; the key
            is ``'<train_test lower-cased>_<acc_ppl>'``.
        train_test: 'train' or 'test'.
        optimizer_std_dev: Maps optimizer label to its std-dev segments.
        num_runs: Total rows per optimizer, counting the recorded curve;
            ``num_runs - 1`` noisy rows are requested from ``make_data``.
        ylim: (low, high) limits of the y axis.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.
    """
    paths = [os.path.join(curve_folder, curve_file) for curve_file in names]
    print(labels)
    print('(*(*(*(*(*')
    print(paths)

    recorded_curves = {}
    for curve_path, label in zip(paths, labels):
        # NOTE: torch.load unpickles its input; only use it on trusted files.
        stored = torch.load(curve_path)
        metric_key = '{}_{}'.format(train_test.lower(), acc_ppl)
        # Only the first 200 entries of each curve are kept.
        recorded_curves[label] = np.array(stored[metric_key])[:200]

    runs_by_optimizer = make_data(recorded_curves, optimizer_std_dev, num_runs - 1)
    print('Multi run acc dict val shape: ', len(runs_by_optimizer.values()))

    # Add a leading optimizer axis to each (runs, epochs) array.
    per_optimizer = [np.expand_dims(runs, axis=0) for runs in runs_by_optimizer.values()]

    if model == 'LSTM':
        plotter = plot_shaded_lstm
    elif model == 'CIFAR':
        plotter = plot_shaded_cifar
    elif model == 'WT2':
        plotter = plot_shaded_WT2
    else:
        return

    stacked = np.concatenate(per_optimizer, axis=0)
    plotter(name, labels, train_test, stacked, acc_ppl, ylim=ylim, ygap=ygap, xgap=xgap)
def main_plot_delta(model, name, curve_folder, names, labels, acc_ppl, train_test, optimizer_std_dev, num_runs, ylim, ygap, xgap):
    """Load curve files and run ``delta_checker`` to estimate convergence epochs.

    Args:
        model: 'LSTM' or 'CIFAR'. Both are handled identically; any other
            value is rejected up front.
        name: Forwarded to ``delta_checker``.
        curve_folder: Path to the folder containing the curve files.
        names: File names of the curves to analyse.
        labels: Optimizer names, in the same order as ``names``.
        acc_ppl: Metric suffix of the key read from each curve file; the key
            is ``'<train_test lower-cased>_<acc_ppl>'``.
        train_test: 'train' or 'test'.
        optimizer_std_dev: Maps optimizer label to its std-dev segments.
        num_runs: Total rows per optimizer, counting the recorded curve;
            ``num_runs - 1`` noisy rows are requested from ``make_data``.
        ylim: (low, high) limits of the y axis.
        ygap: Spacing between y ticks.
        xgap: Spacing between x ticks.

    Returns:
        The list of ``(epoch, optimizer)`` tuples produced by ``delta_checker``.

    Raises:
        ValueError: If ``model`` is not 'LSTM' or 'CIFAR'. Previously this
            case fell through both branches and failed with
            ``UnboundLocalError`` at the ``return`` statement, after all files
            had already been loaded.
    """
    if model not in ('LSTM', 'CIFAR'):
        raise ValueError(
            "main_plot_delta supports model 'LSTM' or 'CIFAR', got {!r}".format(model))

    paths = [os.path.join(curve_folder, curve_file) for curve_file in names]

    optimizer_acc_ppl_dict = {}
    for path, label in zip(paths, labels):
        # NOTE: torch.load unpickles its input; only use it on trusted files.
        optimizer_data = torch.load(path)
        metric_key = '{}_{}'.format(train_test.lower(), acc_ppl)
        # Only the first 200 entries of each curve are kept.
        optimizer_acc_ppl_dict[label] = np.array(optimizer_data[metric_key])[:200]

    multi_run_acc_ppl_dict = make_data(optimizer_acc_ppl_dict, optimizer_std_dev, num_runs - 1)

    # Add a leading optimizer axis to each (runs, epochs) array, then join them.
    multi_run_data = [np.expand_dims(a, axis=0) for a in multi_run_acc_ppl_dict.values()]
    stacked = np.concatenate(multi_run_data, axis=0)

    return delta_checker(name, labels, train_test, stacked, acc_ppl, ylim=ylim, ygap=ygap, xgap=xgap)

# main_plot('/Users/anirudhb/Desktop/RC2021_spring/Adabelief_optimizer_RC2021/PyTorch_Experiments/classification_cifar10/curve', names, labels, 'acc', 'test', optimizer_std_dev, 3, (88, 96), 1, 25)
# resnet_labels = deepcopy(labels)
# resnet_labels.append('Apollo')
# Split read by the module-level convergence analysis further down in this
# file (it indexes the *_config dicts with this value). The plot_* functions
# take their own `curve_type` argument and do not read this global.
curve_type = 'test'

def plot_cifar_10(curve_type):
    """Render the three CIFAR-10 figures for one split.

    Figures are produced in this order: ResNet ('Figure_4b'), VGG
    ('Figure_4a'), DenseNet ('Figure_4c'). Each figure name is suffixed with
    ``curve_type``.

    Args:
        curve_type: Key into the per-architecture config dicts; also passed
            to ``main_plot`` as the split.
    """
    def render(figure_id, config):
        # One figure: pull the split-specific settings and hand them to main_plot.
        name = figure_id + curve_type
        split_cfg = config[curve_type]
        main_plot(model='CIFAR', name=name, curve_folder=CIFAR10_curve_root,
                  names=split_cfg['curve_paths'], labels=labels,
                  acc_ppl=split_cfg['acc_ppl'], train_test=curve_type,
                  optimizer_std_dev=split_cfg['std_dev'], num_runs=3,
                  ylim=split_cfg['ylim'], ygap=split_cfg['ygap'], xgap=25)

    # CIFAR10 Resnet
    render('Figure_4b', CIFAR10_resnet_config)
    # CIFAR10 Vgg
    render('Figure_4a', CIFAR10_vgg_config)
    # CIFAR10 Densenet
    render('Figure_4c', CIFAR10_densenet_config)


def plot_cifar_100(curve_type):
    """Render the three CIFAR-100 figures for one split.

    Figures are produced in this order: ResNet ('Figure_4e'), VGG
    ('Figure_4d', whose name is also printed before plotting), DenseNet
    ('Figure_4f'). Each figure name is suffixed with ``curve_type``.

    Args:
        curve_type: Key into the per-architecture config dicts; also passed
            to ``main_plot`` as the split.
    """
    def render(figure_id, config, echo_name=False):
        # One figure: pull the split-specific settings and hand them to main_plot.
        name = figure_id + curve_type
        if echo_name:
            print(name)
        split_cfg = config[curve_type]
        main_plot(model='CIFAR', name=name, curve_folder=CIFAR100_curve_root,
                  names=split_cfg['curve_paths'], labels=labels,
                  acc_ppl=split_cfg['acc_ppl'], train_test=curve_type,
                  optimizer_std_dev=split_cfg['std_dev'], num_runs=3,
                  ylim=split_cfg['ylim'], ygap=split_cfg['ygap'], xgap=25)

    # CIFAR100 Resnet
    render('Figure_4e', CIFAR100_resnet_config)
    # CIFAR100 Vgg
    render('Figure_4d', CIFAR100_vgg_config, echo_name=True)
    # CIFAR100 Densenet
    render('Figure_4f', CIFAR100_densenet_config)


#
#LSTM one layer

def plot_lstm(curve_type):
    """Render the one-, two- and three-layer LSTM figures for one split.

    Figures are named 'Figure_5a', 'Figure_5b' and 'Figure_5c', each suffixed
    with ``curve_type``. The last entry of ``labels`` is left out for all
    three.

    Args:
        curve_type: Key into the per-depth config dicts; also passed to
            ``main_plot`` as the split.
    """
    def render(figure_id, config):
        # One figure: pull the split-specific settings and hand them to main_plot.
        name = figure_id + curve_type
        split_cfg = config[curve_type]
        main_plot(model='LSTM', name=name, curve_folder=LSTM_curve_root,
                  names=split_cfg['curve_paths'], labels=labels[:-1],
                  acc_ppl=split_cfg['acc_ppl'], train_test=curve_type,
                  optimizer_std_dev=split_cfg['std_dev'], num_runs=3,
                  ylim=split_cfg['ylim'], ygap=split_cfg['ygap'], xgap=25)

    # LSTM one layer
    render('Figure_5a', LSTM_one_layer_config)
    # LSTM two layer
    render('Figure_5b', LSTM_two_layer_config)
    # LSTM three layer
    render('Figure_5c', LSTM_three_layer_config)


# plot_cifar_10('train')
# plot_cifar_10('test')
# plot_cifar_100('train')
# plot_cifar_100('test')




def plot_WT2(curve_type):
    """Render the one-, two- and three-layer WT2 figures for one split.

    Figures are named 'WT-2_one_layer', 'WT-2_two_layer' and
    'WT-2_three_layer', each suffixed with ``curve_type``. Curve files are
    read from ``LSTM_curve_root`` and the legend uses ``WT2_labels``.

    Args:
        curve_type: Key into the per-depth config dicts; also passed to
            ``main_plot`` as the split.
    """
    def render(figure_id, config):
        # One figure: pull the split-specific settings and hand them to main_plot.
        name = figure_id + curve_type
        split_cfg = config[curve_type]
        main_plot(model='WT2', name=name, curve_folder=LSTM_curve_root,
                  names=split_cfg['curve_paths'], labels=WT2_labels,
                  acc_ppl=split_cfg['acc_ppl'], train_test=curve_type,
                  optimizer_std_dev=split_cfg['std_dev'], num_runs=3,
                  ylim=split_cfg['ylim'], ygap=split_cfg['ygap'], xgap=25)

    # WT2 one layer
    render('WT-2_one_layer', WT2_one_layer_config)
    # WT2 two layer
    render('WT-2_two_layer', WT2_two_layer_config)
    # WT2 three layer
    render('WT-2_three_layer', WT2_three_layer_config)

# Module-level script: this runs whenever the file is executed or imported.
# It estimates convergence epochs for the two-layer LSTM curves of the split
# selected by the module-level `curve_type`; main_plot_delta ends up calling
# plt.show() through delta_checker. `labels[:-1]` leaves out the last label.
# The commented-out calls are the one- and three-layer variants of the same
# analysis; enable one call at a time, since each overwrites `conv_eps`.
name = 'LSTM'
# conv_eps = main_plot_delta('LSTM', name, LSTM_curve_root, LSTM_three_layer_config[curve_type]['curve_paths'], labels[:-1], LSTM_three_layer_config[curve_type]['acc_ppl'], curve_type, LSTM_three_layer_config[curve_type]['std_dev'], 3, LSTM_three_layer_config[curve_type]['ylim'], LSTM_three_layer_config[curve_type]['ygap'], 25)
# conv_eps =main_plot_delta('LSTM', name, LSTM_curve_root, LSTM_one_layer_config[curve_type]['curve_paths'], labels[:-1],
#           LSTM_one_layer_config[curve_type]['acc_ppl'], curve_type, LSTM_one_layer_config[curve_type]['std_dev'], 3,
#           LSTM_one_layer_config[curve_type]['ylim'], LSTM_one_layer_config[curve_type]['ygap'], 25)
conv_eps = main_plot_delta('LSTM', name, LSTM_curve_root, LSTM_two_layer_config[curve_type]['curve_paths'], labels[:-1],
          LSTM_two_layer_config[curve_type]['acc_ppl'], curve_type, LSTM_two_layer_config[curve_type]['std_dev'], 3,
          LSTM_two_layer_config[curve_type]['ylim'], LSTM_two_layer_config[curve_type]['ygap'], 25)
# conv_eps = main_plot_delta('CIFAR', name, CIFAR10_curve_root, CIFAR10_densenet_config[curve_type]['curve_paths'], labels,
#               CIFAR10_densenet_config[curve_type]['acc_ppl'], curve_type,
#               CIFAR10_densenet_config[curve_type]['std_dev'], 3, CIFAR10_densenet_config[curve_type]['ylim'],
#               CIFAR10_densenet_config[curve_type]['ygap'], 25)

#CIFAR
# name = "CIFAR"
# conv_eps = main_plot_delta('CIFAR', name, CIFAR10_curve_root, CIFAR10_resnet_config[curve_type]['curve_paths'], labels,
#               CIFAR10_resnet_config[curve_type]['acc_ppl'], curve_type, CIFAR10_resnet_config[curve_type]['std_dev'], 3,
#               CIFAR10_resnet_config[curve_type]['ylim'], CIFAR10_resnet_config[curve_type]['ygap'], 25)
# conv_eps = main_plot_delta('CIFAR', name, CIFAR100_curve_root, CIFAR100_vgg_config[curve_type]['curve_paths'], labels,
#               CIFAR100_vgg_config[curve_type]['acc_ppl'], curve_type, CIFAR100_vgg_config[curve_type]['std_dev'], 3,
#                            (55, 70), 2, 25)
# conv_eps = main_plot_delta('CIFAR', name, CIFAR100_curve_root, CIFAR100_densenet_config[curve_type]['curve_paths'], labels,
#               CIFAR100_densenet_config[curve_type]['acc_ppl'], curve_type,
#               CIFAR100_densenet_config[curve_type]['std_dev'], 3, CIFAR100_densenet_config[curve_type]['ylim'],
#               CIFAR100_densenet_config[curve_type]['ygap'], 25)
# conv_eps = main_plot_delta('CIFAR', name, CIFAR100_curve_root, CIFAR100_resnet_config[curve_type]['curve_paths'], labels,
#               CIFAR100_resnet_config[curve_type]['acc_ppl'], curve_type, CIFAR100_resnet_config[curve_type]['std_dev'],
#               3, CIFAR100_resnet_config[curve_type]['ylim'], CIFAR100_resnet_config[curve_type]['ygap'], 25)
# conv_eps = main_plot_delta('CIFAR', name, CIFAR10_curve_root, CIFAR10_vgg_config[curve_type]['curve_paths'], labels,
#               CIFAR10_vgg_config[curve_type]['acc_ppl'], curve_type, CIFAR10_vgg_config[curve_type]['std_dev'], 3,
#               CIFAR10_vgg_config[curve_type]['ylim'], CIFAR10_vgg_config[curve_type]['ygap'], 25)

# conv_eps holds (epoch, optimizer) tuples; sorting them in place orders the
# optimizers by convergence epoch (ties broken by optimizer name) before the
# list is printed.
conv_eps.sort()
print(conv_eps)
# plot_WT2('train')
# plot_WT2('test')

#Delta 0.1
#LSTM
#three layer LSTM --> useless, AdaBelief --> 176 epoch, RAdam -> 61, AdamW -> 60
#one layer LSTM --> useless AdaBelief --> 138, RAdam --> 118, AdamW --> 119
#two layer LSTM --> useless AdaBelief --> 182, RAdam --> 117, AdamW --> 116

#CIFAR
#CIFAR10 ResNet --> Adam 162, AdamW 163 AdaBelief 166
#CIFAR10


#1 layer LSTM
