"""
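Plot g1d trajectories of the neural ODE models trained with 0 and 3 augmented
dimensions, then print summary statistics for the augmentation ablation runs.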
"""



import numpy as np
import torch

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import rc
from plotting.colors_and_styles import colors_dict, linestyles_dict

from models.neural_de import neural_de
from models.modules import mlp, mlp_t, zero_aug, remove_aug 
from datasets.datasets import g1d

from torch.utils.data import DataLoader

import os
import os.path as osp

from torchdiffeq_gq import odeint



# Checkpoint folders of the two trained models whose trajectories are plotted.
aug0folder = 'results/g1d/aug0/20/adjoint_gq/1'
aug3folder = 'results/g1d/aug3/20/adjoint_gq/10'


data = g1d()
# NOTE(review): `loader` is not used by the code visible in this script.
loader = torch.utils.data.DataLoader(data, batch_size=len(data))

# aug0 model: ODE state has data_dim entries (0 augmented dimensions).
data_dim = 1
dim = data_dim + 0
model0 = neural_de(mlp(dim, 20), None, zero_aug(0), remove_aug(data_dim, True))
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; everything downstream in this script runs on the CPU.
model0.load_state_dict(
    torch.load(osp.join(aug0folder, 'trained_model.pth'), map_location='cpu'))

# aug3 model: ODE state has data_dim + 3 entries (3 augmented dimensions).
dim = data_dim + 3
model3 = neural_de(mlp(dim, 20), None, zero_aug(3), remove_aug(data_dim, True))
model3.load_state_dict(
    torch.load(osp.join(aug3folder, 'trained_model.pth'), map_location='cpu'))




# Both models start from the same two inputs, x = 1.0 and x = -1.0, and are
# integrated on 40 evenly spaced time points in [0, 1].
in_data = torch.tensor([[1.0], [-1.0]])
times = torch.linspace(0, 1, 40)


def _integrate(model):
    """Integrate ``model.defunc`` from ``model.encoder(in_data)`` over ``times``."""
    start = model.encoder(in_data)
    return odeint(model.defunc, start, times, rtol=1e-3, atol=1e-3)


solution0 = _integrate(model0)
solution3 = _integrate(model3)

# odeint stacks states along dim 0 (time); take state component 0 of batch
# rows 0 and 1 (the inputs 1.0 and -1.0) and convert them to NumPy arrays.
x0, y0 = (solution0[:, row, 0].detach().numpy() for row in (0, 1))
x3, y3 = (solution3[:, row, 0].detach().numpy() for row in (0, 1))
times = times.numpy()



# plotting part: one panel per model; in each panel the solid green curve is
# batch row 0 (input 1.0) and the dashed pink curve is batch row 1 (input -1.0).

sns.set_style('white')
rc('font', family='serif')


height = 1  # number of subplot rows
width = 2   # number of subplot columns
axis_fontsize = 14
title_fontsize = 18
legend_fontsize = 12  # NOTE(review): unused -- no legend is drawn
legend_alpha = 0.9    # NOTE(review): unused -- no legend is drawn


fig = plt.figure(figsize=[6*width, 5*height])
fig.subplots_adjust(hspace=0.0, wspace=0.2)

# (first trajectory, second trajectory, panel title), in subplot order.
panels = [
    (x0, y0, '0 Augmented Dimensions'),
    (x3, y3, '3 Augmented Dimensions'),
]
for panel_index, (first_traj, second_traj, panel_title) in enumerate(panels, start=1):
    ax = plt.subplot(height, width, panel_index)
    ax.plot(times, first_traj, color=colors_dict['green'], linestyle=linestyles_dict['solid'])
    ax.plot(times, second_traj, color=colors_dict['pink'], linestyle=linestyles_dict['dashed'])
    ax.set_xlabel('t', fontsize=axis_fontsize)
    ax.set_ylabel('x', fontsize=axis_fontsize)
    ax.set_title(panel_title, fontsize=title_fontsize)

# savefig does not create missing directories, so ensure the target exists.
plot_dir = osp.join('plotting', 'plots')
os.makedirs(plot_dir, exist_ok=True)
plt.savefig(osp.join(plot_dir, 'g1d_solutions.pdf'), bbox_inches='tight')






# now get the ablation values
def get_values(i, n_runs=5, epochs=1500, root='results/g1d'):
    """Print and return final-epoch statistics across repeated runs of one setting.

    For each run ``1..n_runs`` under ``<root>/aug<i>/<epochs>/adjoint_gq/<run>``,
    load ``epoch_fnfe_history.npy`` and ``epoch_bnfe_history.npy``. Keep the
    last entry of each array, then print the mean and population standard
    deviation (``np.std`` default ``ddof=0``) over the runs.

    The two files presumably hold forward/backward NFE per epoch (inferred
    from their names -- confirm against the training script). The printed
    labels therefore follow the file names.

    Args:
        i: number of augmented dimensions; selects the ``aug<i>`` folder.
        n_runs: number of run sub-folders, numbered from 1. Defaults to 5.
        epochs: name of the epoch-count folder. Defaults to 1500.
        root: results root directory. Defaults to ``'results/g1d'``.

    Returns:
        ``(mean_fnfe, std_fnfe, mean_bnfe, std_bnfe)`` as Python floats.

    Raises:
        ValueError: if ``n_runs`` is less than 1.
        FileNotFoundError: (from ``np.load``) if a run's history file is missing.
    """
    if n_runs < 1:
        raise ValueError('n_runs must be at least 1, got {}'.format(n_runs))
    folder = osp.join(root, 'aug' + str(i), str(epochs), 'adjoint_gq')
    final_fnfe = []
    final_bnfe = []
    for run in range(1, n_runs + 1):
        run_dir = osp.join(folder, str(run))
        # Only the value recorded at the last epoch of each run is summarised.
        final_fnfe.append(np.load(osp.join(run_dir, 'epoch_fnfe_history.npy'))[-1])
        final_bnfe.append(np.load(osp.join(run_dir, 'epoch_bnfe_history.npy'))[-1])
    mean_fnfe, std_fnfe = float(np.mean(final_fnfe)), float(np.std(final_fnfe))
    mean_bnfe, std_bnfe = float(np.mean(final_bnfe)), float(np.std(final_bnfe))
    print(i)
    print('FNFE: {} +- {}'.format(mean_fnfe, std_fnfe))
    print('BNFE: {} +- {}'.format(mean_bnfe, std_bnfe))
    print('\n')
    return mean_fnfe, std_fnfe, mean_bnfe, std_bnfe
    

# Report the ablation statistics for 0, 1, 2 and 3 augmented dimensions.
for aug_dims in range(4):
    get_values(aug_dims)