"""
plotting the results of the toy gradient problem
"""

import numpy as np
import matplotlib.pyplot as plt

import os
import os.path as osp

import argparse

import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import method_colors, method_linestyles, method_markers, method_names


# Command-line interface: --show_gq selects which pair of methods is compared.
parser = argparse.ArgumentParser()
parser.add_argument('--show_gq', type=int, choices=[0, 1], default=0)
args = parser.parse_args()


# Global figure styling applied to every subplot below.
sns.set_style('whitegrid')
rc('font', family='serif')



# Methods to draw on each panel, plus the suffix used in the output file name.
# With --show_gq 1 the adjoint and gq methods are compared; otherwise direct
# and adjoint.
methods, save_name = (
    (['adjoint', 'gq'], 'with_gq') if args.show_gq
    else (['direct', 'adjoint'], 'without_gq')
)


# Subplot grid: `height` rows by `width` columns (6 panels in total).
height, width = 3, 2

# Font sizes for axis labels, panel titles and legends, and legend opacity.
axis_fontsize, title_fontsize, legend_fontsize = 14, 18, 12
legend_alpha = 0.9

# One panel per (metric key, display title) pair. The keys are the names
# add_to_plot dispatches on; the titles label each panel's y-axis and header.
_panels = (
    ('loss', 'Loss'),
    ('agrad', 'dL/da'),
    ('loss_error', 'Loss Error'),
    ('agrad_error', 'dL/da Error'),
    ('loss_rel_error', 'Relative Loss Error'),
    ('a_grad_rel_error', 'Relative dL/da Error'),
)
metrics = [key for key, _ in _panels]
titles = [title for _, title in _panels]


def true_loss(z, a, T):
    """Return the analytic loss ``z**2 * exp(2*a*T)``.

    Used as the ground truth when computing relative loss errors. Evaluates
    elementwise, so ``z``, ``a`` and ``T`` may be scalars or NumPy arrays.
    """
    growth = np.exp(2*a*T)
    return growth * z**2

def dlda(z, a, T):
    """Return the analytic derivative with respect to ``a`` of ``z**2 * exp(2*a*T)``.

    That is ``2*T * z**2 * exp(2*a*T)``; used as the ground truth for the
    relative dL/da error. Evaluates elementwise on scalars or NumPy arrays.
    """
    prefactor = 2*T*z**2
    return prefactor * np.exp(2*a*T)




def add_to_plot(method, metric):
    """Plot one method's curve for one metric on the current matplotlib axes.

    Results are read from ``results/direct_adjoint_gradients/<method>/``;
    ``as.npy`` supplies the x-values (the sampled values of ``a``). Only the
    files needed for the requested metric are loaded.

    Args:
        method: key into ``method_names`` / ``method_linestyles`` /
            ``method_colors`` and the name of the results sub-folder.
        metric: one of ``'loss'``, ``'agrad'``, ``'loss_error'``,
            ``'agrad_error'``, ``'loss_rel_error'``, ``'a_grad_rel_error'``.

    Raises:
        ValueError: if ``metric`` is not one of the names above (previously
            this silently drew nothing).
    """
    label = method_names[method]
    line = method_linestyles[method]
    color = method_colors[method]
    folder = osp.join('results/', 'direct_adjoint_gradients/', method)

    def load(filename):
        return np.load(osp.join(folder, filename))

    # Each entry lazily produces the y-values for its metric, so a call reads
    # at most two files instead of all five.
    y_getters = {
        'loss': lambda: load('loss.npy'),
        'agrad': lambda: load('a_grad.npy'),
        'loss_error': lambda: load('loss_errors.npy'),
        'agrad_error': lambda: load('a_grad_errors.npy'),
        # Relative errors divide the stored errors by the analytic values
        # evaluated at z=1, T=1.
        'loss_rel_error': lambda: load('loss_errors.npy') / true_loss(1, a_s, 1),
        'a_grad_rel_error': lambda: load('a_grad_errors.npy') / dlda(1, a_s, 1),
    }
    if metric not in y_getters:
        raise ValueError(
            f"unknown metric {metric!r}; expected one of {sorted(y_getters)}"
        )

    a_s = load('as.npy')
    plt.plot(a_s, y_getters[metric](), label=label, linestyle=line, c=color)

        

def make_one_plot(i):
    """Draw panel ``i``: every method in ``methods`` for metric ``metrics[i]``.

    Args:
        i: zero-based panel index into ``metrics`` and ``titles``; the panel is
            placed in a ``height`` x ``width`` subplot grid.
    """
    # plt.subplot uses 1-based positions, hence i + 1.
    plt.subplot(height, width, i + 1)
    for m in methods:
        add_to_plot(m, metrics[i])
    title = titles[i]
    plt.xlabel('a', fontsize=axis_fontsize)
    plt.ylabel(title, fontsize=axis_fontsize)
    plt.legend(fontsize=legend_fontsize, framealpha=legend_alpha, loc='upper left')
    plt.title(title, fontsize=title_fontsize)


        
        
# Figure sized per panel: 6 x 5 inches for each cell of the height x width grid.
fig = plt.figure(figsize=[6*width, 5*height])
fig.subplots_adjust(hspace=0.3, wspace=0.3)


# One panel per metric (previously hard-coded as range(6)).
for i in range(len(metrics)):
    make_one_plot(i)


# save figure; create the output directory first so savefig cannot fail on a
# fresh checkout where plotting/plots does not exist yet.
out_dir = osp.join('plotting', 'plots')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(osp.join(out_dir, 'direct_adjoint_grads_' + save_name + '.pdf'), bbox_inches='tight')







