"""
plot sines solutions and print results
"""



import numpy as np
import torch

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import colors_dict, linestyles_dict

from models.neural_de import neural_de
from models.modules import mlp, mlp_t, zero_aug, remove_aug, identity, mlp_sonode, remove_aug
from datasets.datasets import sinedata

from torch.utils.data import DataLoader

import os
import os.path as osp

from torchdiffeq_gq import odeint



# Result folders for the four sines configurations (10/50 points, regular vs
# irregular sampling); each contains one numbered sub-directory per run.
_sines_folder_template = 'results/sines/{}/1000/adjoint_gq'
sines10regfolder = _sines_folder_template.format('10_regular')
sines10irregfolder = _sines_folder_template.format('10_irregular')
sines50regfolder = _sines_folder_template.format('50_regular')
sines50irregfolder = _sines_folder_template.format('50_irregular')



# Dataset whose samples seed the plotted trajectories.
# NOTE(review): the meaning of the positional arguments (False, 10, True) is
# defined by sinedata -- confirm there.
data = sinedata(False, 10, True)

# model: neural_de built around an mlp_sonode vector field (the 20 is
# presumably a hidden width -- confirm in models.modules), restored on CPU
# from a saved checkpoint.
data_dim = 1
checkpoint_path = osp.join('results/sines/10_regular/20/adjoint_gq/5', 'trained_model.pth')
model = neural_de(mlp_sonode(data_dim, 20), None, identity(), remove_aug(data_dim))
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)






# plotting part

# Global plot styling: seaborn's 'white' style and serif fonts everywhere.
sns.set_style(style='white')
rc('font', **{'family': 'serif'})


# Scale factors applied to the base figure size, plus text/legend styling.
# (The legend_* settings are not used in this visible script.)
height, width = 1, 1
axis_fontsize, title_fontsize, legend_fontsize = 14, 18, 12
legend_alpha = 0.9


def plot_one_solution(i, t_max=10, n_points=200, show_true=False):
    """Plot the model's trajectory for the i-th sample of ``data``.

    The initial condition ``data[i][0]`` is integrated with ``model.evaluate``
    over ``n_points`` evenly spaced times in ``[0, t_max]``. The result is
    drawn on the current matplotlib axes. Axis labels and the title are
    (re)set on every call.

    Args:
        i: index into the module-level ``data`` dataset.
        t_max: end of the evaluation time grid (default 10, as before).
        n_points: number of points in the time grid (default 200, as before).
        show_true: if True, also draw ``data[i][2]`` against ``data[i][1]``
            in red. NOTE(review): this assumes those entries are the observed
            times and values, as the author's earlier commented-out code
            suggested -- confirm against ``sinedata``.
    """
    # (n_points, 1) column of times -- the shape the original passed to evaluate.
    times = torch.linspace(0, t_max, n_points).unsqueeze(-1)
    # Prepend a size-1 leading dimension (presumably batch -- confirm in neural_de).
    initial = data[i][0].unsqueeze(0)
    solution = model.evaluate(initial, times).squeeze().detach().numpy()
    plt.plot(times.squeeze().detach().numpy(), solution,
             color=colors_dict['blue'], alpha=0.5)
    if show_true:
        true_times = data[i][1].squeeze().detach().numpy()
        true_sol = data[i][2].squeeze().detach().numpy()
        plt.plot(true_times, true_sol, color=colors_dict['red'], alpha=0.5)
    plt.xlabel('t', fontsize=axis_fontsize)
    plt.ylabel('x', fontsize=axis_fontsize)
    plt.title('Extrapolated sine Solutions', fontsize=title_fontsize)



fig = plt.figure(figsize=[6*width, 5*height])
fig.subplots_adjust(hspace=0.0, wspace=0.0)

# Overlay the trajectories of the first five dataset samples on one axes.
for i in range(5):
    plot_one_solution(i)

# Dashed vertical line at t = 2*pi. NOTE(review): presumably marks the end of
# the observed window (extrapolation beyond it) -- confirm against sinedata.
plt.axvline(x=2*np.pi, linestyle='--', c='k')

# savefig does not create missing directories, so ensure the target exists.
plots_dir = osp.join('plotting', 'plots')
os.makedirs(plots_dir, exist_ok=True)
plt.savefig(osp.join(plots_dir, 'sines_solutions.pdf'), bbox_inches='tight')






# now get the ablation values
def get_values(folder, n_experiments=5):
    """Print and return summary statistics over the numbered runs in ``folder``.

    The runs are the sub-directories ``folder/<ex>``, with ``ex`` in
    ``1..n_experiments``. For each run the function reads two files:

    * ``epoch_times.npy``, summed to give the run's total training time.
    * ``epoch_val_metric_history.npy``, whose last entry is taken as the
      run's final validation metric (reported as MSE).

    The mean and standard deviation across runs are printed. The standard
    deviation is ``np.std`` with its default ddof=0.

    Args:
        folder: directory containing one numbered sub-directory per run.
        n_experiments: number of runs to read (default 5, as before).

    Returns:
        Tuple ``(mean_time, std_time, mean_mse, std_mse)``.

    Raises:
        ValueError: if ``n_experiments < 1``.
        FileNotFoundError: if a run's ``.npy`` file is missing.
    """
    if n_experiments < 1:
        raise ValueError('n_experiments must be >= 1, got {}'.format(n_experiments))

    total_times = []
    total_mse = []
    for ex in range(1, n_experiments + 1):
        run_dir = osp.join(folder, str(ex))
        epoch_times = np.load(osp.join(run_dir, 'epoch_times.npy'))
        val_metric = np.load(osp.join(run_dir, 'epoch_val_metric_history.npy'))
        total_times.append(np.sum(epoch_times))
        total_mse.append(val_metric[-1])

    total_times = np.array(total_times)
    total_mse = np.array(total_mse)
    mean_time = np.mean(total_times)
    std_time = np.std(total_times)
    mean_mse = np.mean(total_mse)
    std_mse = np.std(total_mse)

    print(folder)
    print('Time: {} +- {}'.format(mean_time, std_time))
    print('MSE: {} +- {}'.format(mean_mse, std_mse))
    print('\n')
    return mean_time, std_time, mean_mse, std_mse
    

# Print the summary for every sines configuration, in the original order.
for results_folder in (sines10regfolder, sines10irregfolder,
                       sines50regfolder, sines50irregfolder):
    get_values(results_folder)