"""
Testing the smoothness of effective loss surfaces
"""

import argparse

import os
import os.path as osp

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


# parse the command-line arguments
parser = argparse.ArgumentParser(
    description='Compare the ODE solver output, loss and gradient with respect to a '
                'against their closed-form values over a sweep of a.',
    # show each option's default value in --help
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rtol', type=float, default=1e-6,
                    help='relative tolerance passed to the ODE solver')
parser.add_argument('--atol', type=float, default=1e-7,
                    help='absolute tolerance passed to the ODE solver')
parser.add_argument('--method', type=str, choices=['adjoint', 'direct', 'gq'], default='direct',
                    help='which torchdiffeq_gq solver to use: odeint_adjoint, odeint '
                         'or odeint_adjoint_gq; also names the results sub-folder')
parser.add_argument('--z0', type=float, default=1.0,
                    help='initial state of the ODE')
parser.add_argument('--T', type=float, default=1.0,
                    help='end of the integration interval [0, T]')
args = parser.parse_args()


# bind `odeint` to the torchdiffeq_gq solver selected by --method
# (the branches are mutually exclusive; the default, 'direct', is checked first)
if args.method == 'direct':
    from torchdiffeq_gq import odeint
elif args.method == 'adjoint':
    from torchdiffeq_gq import odeint_adjoint as odeint
elif args.method == 'gq':
    from torchdiffeq_gq import odeint_adjoint_gq as odeint


# wrap the scalar command-line values as float32 tensors
z0 = torch.tensor([args.z0], dtype=torch.float32)  # one-element, 1-d tensor
T = torch.tensor(args.T, dtype=torch.float32)  # 0-dim tensor


# loss function, plus the closed-form solution, loss and loss gradients for dz/dt = a*z
def loss_func(z):
    """Return the squared-state loss ``z**2``."""
    squared = z ** 2
    return squared

def true_loss(z, a, T):
    """Return the closed-form loss ``z**2 * exp(2*a*T)``.

    This is ``z**2`` evaluated at the analytic solution ``z * exp(a*T)``.
    ``2*a*T`` is handed to ``torch.exp``, so it must evaluate to a tensor.
    """
    growth = torch.exp(2 * a * T)
    return (z ** 2) * growth

def true_func(z, a, T):
    """Return the closed-form state ``z * exp(a*T)``.

    ``a*T`` is handed to ``torch.exp``, so it must evaluate to a tensor.
    """
    growth = torch.exp(a * T)
    return z * growth

def dldz(z, a, T):
    """Return ``2*z*exp(2*a*T)``, the derivative of ``z**2 * exp(2*a*T)`` w.r.t. ``z``."""
    growth = torch.exp(2 * a * T)
    return 2 * z * growth

def dlda(z, a, T):
    """Return ``2*T*z**2*exp(2*a*T)``, the derivative of ``z**2 * exp(2*a*T)`` w.r.t. ``a``."""
    growth = torch.exp(2 * a * T)
    prefactor = 2 * T * (z ** 2)
    return prefactor * growth

def dldT(z, a, T):
    """Return ``2*a*z**2*exp(2*a*T)``, the derivative of ``z**2 * exp(2*a*T)`` w.r.t. ``T``."""
    growth = torch.exp(2 * a * T)
    prefactor = 2 * a * (z ** 2)
    return prefactor * growth


# nn modules defining the odefunction and the odeblock
class odefunc(nn.Module):
    """Right-hand side of the linear ODE ``dz/dt = a * z`` with a learnable ``a``."""

    def __init__(self, a):
        """Store ``a`` as a float32 parameter and reset the evaluation counter."""
        super().__init__()
        # number of forward evaluations performed so far
        self.nfe = 0
        self.a = nn.Parameter(torch.tensor(a).float())

    def forward(self, t, z):
        """Return ``a * z``; ``t`` is accepted but not used."""
        self.nfe = self.nfe + 1
        return self.a * z


class odeblock(nn.Module):
    """Integrate ``odefunc`` from time 0 to ``T`` starting at ``z0``.

    ``z0`` and the time grid ``[0, T]`` are registered as parameters, alongside
    the ``a`` parameter held by ``odefunc``.
    """

    def __init__(self, z0, a, T):
        """Build the ODE function and register ``z0`` and ``[0, T]`` as parameters."""
        super().__init__()
        self.func = odefunc(a)
        self.z0 = nn.Parameter(z0)
        self.times = nn.Parameter(torch.tensor([0.0, T]))

    def forward(self):
        """Solve the ODE with the tolerances from the command line.

        Returns entry 1 of the solver output, presumably the state at
        ``times[1] = T`` (torchdiffeq convention; confirm for torchdiffeq_gq).
        """
        # if using torchdiffeq methods
        solution = odeint(self.func, self.z0, self.times, rtol=args.rtol, atol=args.atol)
        return solution[1]


# get the errors for one particular a, print them and append them to the result lists
def get_error(a, z_list, z_error_list, loss_list, loss_error_list, a_grad_list, a_grad_error_list):
    """Solve the ODE for one value of ``a`` and record predictions and errors.

    Builds an ``odeblock`` from the module-level ``z0`` and ``T``, runs the
    forward solve and a backward pass on the loss, and compares the final
    state, the loss and the gradient with respect to ``a`` against their
    closed-form values. Each error is ``true - predicted``.

    The results are printed and appended in place to the six lists passed
    in; nothing is returned.
    """
    block = odeblock(z0, a, T)

    # final state
    z_final = block()
    z_value = z_final.item()
    z_error = true_func(z0, a, T).item() - z_value

    # loss; the backward pass populates the parameter gradients
    loss = loss_func(z_final)
    loss.backward()
    loss_value = loss.item()
    loss_error = true_loss(z0, a, T).item() - loss_value

    # gradient of the loss with respect to a
    agrad_value = block.func.a.grad.item()
    agrad_error = dlda(z0, a, T).item() - agrad_value

    print('a: {:.3f}'.format(a))
    print('z: {:.5f}, loss: {:.5f}, a grad: {:.5f}'.format(z_value, loss_value, agrad_value))
    print('z error: {:.5f}, loss error: {:.5f}, a grad error: {:.5f}'.format(z_error, loss_error, agrad_error))
    print('\n')

    for results, value in (
        (z_list, z_value),
        (z_error_list, z_error),
        (loss_list, loss_value),
        (loss_error_list, loss_error),
        (a_grad_list, agrad_value),
        (a_grad_error_list, agrad_error),
    ):
        results.append(value)



# grid of a values to sweep (step 0.02, starting at 0; the stop of 6.02 is
# meant to include a = 6), plus empty lists that collect the results
a_list = np.arange(0.0, 6.02, 0.02)
z_list = []
z_error_list = []
loss_list = []
loss_error_list = []
a_grad_list = []
a_grad_error_list = []

# results are stored in one sub-folder per solver method
folder = osp.join('results/', 'direct_adjoint_gradients/', args.method)
# exist_ok avoids the check-then-create race of `if not exists: makedirs`
os.makedirs(folder, exist_ok=True)

# run experiment for each a
for a in a_list:
    get_error(a, z_list, z_error_list, loss_list, loss_error_list, a_grad_list, a_grad_error_list)

# save the results, one .npy file per quantity
for filename, values in (
    ('as.npy', a_list),
    ('z.npy', z_list),
    ('loss.npy', loss_list),
    ('a_grad.npy', a_grad_list),
    ('z_errors.npy', z_error_list),
    ('loss_errors.npy', loss_error_list),
    ('a_grad_errors.npy', a_grad_error_list),
):
    np.save(osp.join(folder, filename), np.array(values))
