"""
Plot the Ornstein-Uhlenbeck (OU) timing figure, and print ablation times and final KL divergences.
"""

import numpy as np
import matplotlib.pyplot as plt

import os
import os.path as osp

import argparse

import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import method_colors, method_linestyles, method_markers, method_names, method_filenames



# Global figure styling: seaborn "whitegrid" theme, then serif fonts.
# NOTE(review): the font override is applied after the seaborn theme,
# presumably so the theme does not reset it -- keep this order.
sns.set_style(style='whitegrid')
rc('font', **{'family': 'serif'})



# Hidden widths swept for each method ('sde_direct' only has the first three).
sizes = {
    'gq': np.array([5, 20, 50, 100]),
    'sde_adjoint': np.array([5, 20, 50, 100]),
    'sde_direct': np.array([5, 20, 50]),
}

# Methods are drawn on every subplot in this order.
methods = ['sde_direct', 'sde_adjoint', 'gq']

# Figure layout (a `height` x `width` grid of subplots) and text styling.
markersize = 5
height, width = 1, 3
axis_fontsize, title_fontsize, legend_fontsize = 14, 18, 12
legend_alpha = 0.9

# Parallel per-subplot settings: entry i configures subplot i.
metrics = ['epoch_times', 'epoch_loss_history']   # .npy basename to read
titles = ['Training Time', 'KL Divergence']       # subplot title / y-label
xaxis = ['Parameters', 'Hidden Width']            # x-axis quantity


def add_to_plot(method, metric, x, n_runs=5, ncosines=10):
    """Add one method's mean +- std error-bar curve to the current axes.

    For every hidden width in ``sizes[method]``, the last value of ``metric``
    is read from each of ``n_runs`` repeated runs stored in
    ``results/ornstein_uhlenbeck/<ncosines>/<width>/<method_filenames[method]>/<run>/``
    (runs are numbered ``1..n_runs``). Each plotted point is the mean over
    runs; the error bar is the population standard deviation (``np.std``
    default, ddof=0).

    Args:
        method: key into ``sizes`` and the style dictionaries imported from
            ``plotting.colors_and_styles``.
        metric: basename (without ``.npy``) of the per-epoch array to read.
            For ``'epoch_times'`` the per-epoch values are accumulated, so the
            final value is the total training time.
        x: ``'Parameters'`` to place each point at the parameter count stored
            in ``nparams.npy``; any other value uses the hidden width.
        n_runs: number of repeated runs per width (default 5).
        ncosines: value of the ``<ncosines>`` path component (default 10).

    Raises:
        ValueError: if ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError('n_runs must be at least 1, got {}'.format(n_runs))

    label = method_names[method]
    line = method_linestyles[method]
    color = method_colors[method]
    widths = sizes[method]

    means = np.empty(len(widths))
    stds = np.empty(len(widths))
    pltx = np.empty(len(widths))
    for w, width_value in enumerate(widths):
        folder = osp.join('results', 'ornstein_uhlenbeck', str(ncosines),
                          str(width_value), method_filenames[method])
        final_values = np.empty(n_runs)
        for run in range(n_runs):
            history = np.load(osp.join(folder, str(run + 1), metric + '.npy'))
            if metric == 'epoch_times':
                # Per-epoch durations -> running total; last entry is total time.
                history = np.cumsum(history)
            final_values[run] = history[-1]
        means[w] = np.mean(final_values)
        stds[w] = np.std(final_values)
        if x == 'Parameters':
            # Parameter count is read from the last run's folder.
            # NOTE(review): assumes it is identical across runs -- confirm.
            nparams = np.load(osp.join(folder, str(n_runs), 'nparams.npy'))
            # .item() extracts the scalar from a 0-d or size-1 array without
            # relying on the deprecated ndim>0 array -> scalar conversion.
            pltx[w] = np.asarray(nparams).item()
        else:
            pltx[w] = width_value
    plt.errorbar(x=pltx, y=means, yerr=stds, label=label, color=color,
                 linestyle=line)



def make_one_plot(i):
    """Draw subplot ``i`` (0-based) of the ``height`` x ``width`` figure grid.

    Every method in ``methods`` contributes one error-bar curve of
    ``metrics[i]`` against ``xaxis[i]``. The subplot's title and y-label are
    both ``titles[i]``. A legend is drawn only when ``i`` is a multiple of 3
    (presumably the first subplot of each 3-wide row).
    """
    plt.subplot(height, width, i + 1)
    metric, x_quantity, title = metrics[i], xaxis[i], titles[i]

    for method in methods:
        add_to_plot(method, metric, x_quantity)

    plt.xlabel(x_quantity, fontsize=axis_fontsize)
    plt.ylabel(title, fontsize=axis_fontsize)
    if not i % 3:
        plt.legend(fontsize=legend_fontsize, framealpha=legend_alpha,
                   loc='upper left')
    plt.title(title, fontsize=title_fontsize)







# Build the figure: one subplot per configured metric, laid out on the
# `height` x `width` grid (5 x 4 inches per grid cell).
fig = plt.figure(figsize=[5 * width, 4 * height])
fig.subplots_adjust(hspace=0.0, wspace=0.3)

# metrics/titles/xaxis are parallel lists; draw every configured panel
# instead of a hard-coded count.
for i in range(len(metrics)):
    make_one_plot(i)

# Save the figure, creating the output directory if it does not exist yet
# (savefig does not create missing directories).
out_dir = osp.join('plotting', 'plots')
os.makedirs(out_dir, exist_ok=True)
plt.savefig(osp.join(out_dir, 'ou_times.pdf'), bbox_inches='tight')




# Now print the total training time and final KL divergence for each ablation setting.

def get_values(ncosines, n_runs=5):
    """Print (and return) ablation statistics for one ``ncosines`` setting.

    For each run ``1..n_runs`` under
    ``results/ornstein_uhlenbeck/<ncosines>/20/adjoint_gq/<run>/``, the
    total training time is the last entry of the cumulative sum of
    ``epoch_times.npy``. The final KL divergence is the last entry of
    ``epoch_loss_history.npy``. Means and population standard deviations
    (``np.std`` default, ddof=0) across runs are printed.

    Args:
        ncosines: ablation setting; used as a path component of the results
            tree and printed as the section header.
        n_runs: number of repeated runs to aggregate (default 5).

    Returns:
        Tuple ``(mean_time, std_time, mean_kl, std_kl)``.

    Raises:
        ValueError: if ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError('n_runs must be at least 1, got {}'.format(n_runs))

    folder = osp.join('results', 'ornstein_uhlenbeck', str(ncosines), '20',
                      'adjoint_gq')
    total_times = []
    final_kls = []
    for run in range(1, n_runs + 1):
        run_dir = osp.join(folder, str(run))
        epoch_times = np.load(osp.join(run_dir, 'epoch_times.npy'))
        loss_history = np.load(osp.join(run_dir, 'epoch_loss_history.npy'))
        # Last element of the running total == total training time.
        total_times.append(np.cumsum(epoch_times)[-1])
        final_kls.append(loss_history[-1])

    total_times = np.array(total_times)
    final_kls = np.array(final_kls)
    mean_time = np.mean(total_times)
    std_time = np.std(total_times)
    mean_kl = np.mean(final_kls)
    std_kl = np.std(final_kls)

    print(ncosines)
    print('Time: {} +- {}'.format(mean_time, std_time))
    print('KL: {} +- {}'.format(mean_kl, std_kl))
    print('\n')
    return mean_time, std_time, mean_kl, std_kl




# Report statistics for every ncosines ablation setting, in increasing order.
for ncosines in (0, 1, 5, 10, 25, 50):
    get_values(ncosines)