"""
Plot total training time and final validation metric against model size
(parameter count / hidden width) for the ODE experiments.
"""

import numpy as np
import matplotlib.pyplot as plt

import os
import os.path as osp

import argparse

import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import method_colors, method_linestyles, method_markers, method_names, method_filenames


# Command-line options: which experiment's results to plot, and the batch
# size used to select the ablation folder / output name of the image datasets.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--experiment',
    type=str,
    choices=['g1d', 'nested_spheres', 'sines10', 'sines50', 'mnist', 'cifar10', 'svhn'],
    default='g1d',
)
parser.add_argument('--batchsize', type=int, default=16)
args = parser.parse_args()


# Global figure styling: white grid background, serif fonts.
sns.set_style('whitegrid')
rc('font', family='serif')



# Model sizes swept per experiment (hidden width, or hidden channels for the
# image datasets). Each value names a sub-folder of the results tree.
sizes = {
    'g1d': np.array([5, 20, 100, 500, 1000, 1500, 2000, 2500, 3000]),
    'nested_spheres': np.array([5, 20, 100, 500, 1000, 1500, 2000, 2500, 3000]),
    'sines10': np.array([5, 20, 200, 1000, 2000]),
    'sines50': np.array([5, 20, 200, 1000, 2000]),
    'mnist': np.array([50, 200, 350, 500]),
    'cifar10': np.array([50, 200, 350, 500]),
    'svhn': np.array([50, 200, 350, 500]),
}

# Ablation sub-folder read for each experiment; the image datasets use the
# --batchsize value. ('ornstein_uhlenbeck' is not a selectable --experiment.)
ablations = {
    'g1d': 'aug3',
    'nested_spheres': 'gtol0.003',
    'sines10': '10_regular',
    'sines50': '50_regular',
    'mnist': str(args.batchsize),
    'cifar10': str(args.batchsize),
    'svhn': str(args.batchsize),
    'ornstein_uhlenbeck': '10',
}

# Number of repeated runs per configuration (run folders are numbered 1..N).
nexperiments = {
    'g1d': 10,
    'nested_spheres': 10,
    'sines10': 5,
    'sines50': 5,
    'mnist': 3,
    'cifar10': 3,
    'svhn': 3,
    'ornstein_uhlenbeck': 5,
}

# The 'direct' method is not plotted for the image datasets.
methods = (['adjoint', 'seminorm', 'gq']
           if args.experiment in ('mnist', 'cifar10', 'svhn')
           else ['direct', 'adjoint', 'seminorm', 'gq'])


# Figure layout and typography.
markersize = 5  # not referenced elsewhere in this script
height = 1      # subplot rows
width = 2       # subplot columns
axis_fontsize = 14
title_fontsize = 18
legend_fontsize = 12
legend_alpha = 0.9
# Per-run result files to load; entry i pairs with titles[...][i] and xaxis[...][i].
metrics = ['epoch_times', 'epoch_val_metric_history']

# [title of the time subplot, title of the metric subplot]; also used as y labels.
titles = {
    'g1d': ['Training Time', 'MSE'],
    'nested_spheres': ['Training Time', 'Test Accuracy'],
    'sines10': ['Training Time', 'Test MSE'],
    'sines50': ['Training Time', 'Test MSE'],
    'mnist': ['Training Time', 'Test Accuracy'],
    'cifar10': ['Training Time', 'Test Accuracy'],
    'svhn': ['Training Time', 'Test Accuracy'],
}

# x-axis quantity per subplot: parameter count for the first, model width
# (channels for the image datasets) for the second.
xaxis = {
    exp: ['Parameters', 'Hidden Channels' if exp in ('mnist', 'cifar10', 'svhn') else 'Hidden Width']
    for exp in ('g1d', 'nested_spheres', 'sines10', 'sines50', 'mnist', 'cifar10', 'svhn')
}


def add_to_plot(method, metric, x):
    """Draw one errorbar curve for ``method`` on the current axes.

    For every size in ``sizes[args.experiment]`` this loads ``<metric>.npy``
    from each repeated run under
    ``results/<experiment>/<ablation>/<size>/<method folder>/<run>/`` and
    reduces each run to its last value. When ``metric == 'epoch_times'`` the
    array is cumulatively summed first, so the last value is the total over
    all entries. The curve shows the mean across runs with a population
    standard deviation (``np.std``, ddof=0) as the error bar.

    Args:
        method: key into the ``method_*`` name/style/filename tables.
        metric: base name of the ``.npy`` file loaded from each run folder.
        x: ``'Parameters'`` to place points at the value stored in
            ``nparams.npy``; any other value places them at the size itself.
    """
    widths = sizes[args.experiment]
    nruns = nexperiments[args.experiment]
    # sines10 and sines50 share the 'sines' results folder and differ only by ablation.
    experiment = 'sines' if args.experiment in ('sines10', 'sines50') else args.experiment
    ablation = ablations[args.experiment]

    means = np.empty(len(widths))
    stds = np.empty(len(widths))
    pltx = np.empty(len(widths))
    for w, size in enumerate(widths):
        folder = osp.join('results', experiment, ablation, str(size), method_filenames[method])
        to_plot = np.empty(nruns)
        for i in range(nruns):
            # Run folders are numbered starting at 1.
            a = np.load(osp.join(folder, str(i + 1), metric + '.npy'))
            if metric == 'epoch_times':
                a = np.cumsum(a)
            to_plot[i] = a[-1]
        means[w] = np.mean(to_plot)
        stds[w] = np.std(to_plot)
        if x == 'Parameters':
            # Read from the last run's folder explicitly (previously this depended
            # on the inner loop variable leaking); .item() converts the stored
            # size-1 array into a scalar for the float slot.
            pltx[w] = np.load(osp.join(folder, str(nruns), 'nparams.npy')).item()
        else:
            pltx[w] = size
    # NOTE: method_markers / markersize are not applied; only line style and color are.
    plt.errorbar(x=pltx, y=means, yerr=stds,
                 label=method_names[method],
                 color=method_colors[method],
                 linestyle=method_linestyles[method])



def make_one_plot(i):
    """Fill subplot ``i`` (0-based) of the ``height`` x ``width`` grid.

    Each entry of ``methods`` is drawn for ``metrics[i]`` against the x-axis
    quantity ``xaxis[args.experiment][i]``. The string
    ``titles[args.experiment][i]`` serves as both the y-axis label and the
    subplot title. A legend is added only when ``i`` is a multiple of 3,
    which among the two subplots drawn by this script is only the first.
    """
    plt.subplot(height, width, i + 1)
    x_label = xaxis[args.experiment][i]
    metric = metrics[i]
    for name in methods:
        add_to_plot(name, metric, x_label)
    plt.xlabel(x_label, fontsize=axis_fontsize)
    heading = titles[args.experiment][i]
    plt.ylabel(heading, fontsize=axis_fontsize)
    if not i % 3:
        plt.legend(fontsize=legend_fontsize, framealpha=legend_alpha, loc='upper left')
    plt.title(heading, fontsize=title_fontsize)






fig = plt.figure(figsize=[5*width, 4*height])
fig.subplots_adjust(hspace=0.0, wspace=0.3)

# One subplot per entry of `metrics` (training time, then final validation metric).
for i in range(len(metrics)):
    make_one_plot(i)


# save figure
# For the image datasets the batch size is appended to the file name, since
# their results are selected by the --batchsize ablation folder.
if args.experiment in ('mnist', 'cifar10', 'svhn'):
    name = args.experiment + str(args.batchsize)
else:
    name = args.experiment
out_dir = osp.join('plotting', 'plots')
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(out_dir, exist_ok=True)
plt.savefig(osp.join(out_dir, name + '_times.pdf'), bbox_inches='tight')

     
        
        
        
        
        
        