import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import lib.utils as utils
from lib.ode_func import ODEFunc

import random
from random import SystemRandom

import scipy
import scipy.io as sio 
import scipy.linalg as scilin
from scipy.optimize import newton, brentq
from scipy.special import legendre, roots_legendre

import matplotlib.pyplot as plt

from torchdiffeq import odeint as odeint

torch.set_default_dtype(torch.float64)
# There is no CLI parsing in this script, so pick the current CUDA device when one is
# available (select a specific GPU with CUDA_VISIBLE_DEVICES), otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG used below (torch, NumPy's global RNG in get_batch, stdlib random).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Output locations. SystemRandom is used for the id because the seeded RNGs above
# would otherwise produce the same id on every run.
save_path = 'experiments/'
utils.makedirs(save_path)
experimentID = int(SystemRandom().random()*100000)
ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID) + '.ckpt')
fig_save_path = os.path.join(save_path, "experiment_" + str(experimentID))
utils.makedirs(fig_save_path)
# NOTE(review): ckpt_path_outer appears unused in this script.
ckpt_path_outer = os.path.join(save_path, "experiment_" + str(experimentID) + '_outer.ckpt')
print(ckpt_path)

# Reference trajectory (variable 'x' of the .mat file). Resolved against PARENT_DIR so
# the script does not depend on the current working directory; this is the same file
# as "../data/..." when run from the script's own directory.
data_file = os.path.join(PARENT_DIR, 'data', 'TGC', 'GENERIC_Int_TGC_d001_N1_T30_L1_p2_E4_ad5_RK3.mat')
mat_data = sio.loadmat(data_file)['x']

# Uniform time grid on [0, Time] with spacing h_ref.
h_ref = 0.001
Time = 30
# round() instead of floor(): Time / h_ref is a float quotient that can land just
# below the intended integer.
N_steps = int(round(Time / h_ref)) + 1
t = np.expand_dims(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64), axis=-1)  # (N_steps, 1)
# Rescale time to [0, 3].
t = t / t[-1] * 3

# First N_steps samples, first two state components.
data = mat_data[:N_steps, :2]

class ODEfunc(nn.Module):
	"""Learned vector field for the neural ODE dy/dt = net(y).

	The field is autonomous: ``forward`` ignores its time argument. Every call
	to ``forward`` increments ``NFE`` so the number of function evaluations
	performed by the ODE solver can be inspected.
	"""

	def __init__(self, output_dim):
		super().__init__()
		self.output_dim = output_dim
		# Maps a state of size output_dim to a derivative of the same size.
		# Presumably 3 layers of 10 tanh units; see lib.utils.create_net to confirm.
		network = utils.create_net(output_dim, output_dim, n_layers=3, n_units=10, nonlinear=nn.Tanh)
		self.net = network.to(device)
		self.NFE = 0

	def forward(self, t, y):
		"""Return net(y); ``t`` is accepted only for the solver's call signature."""
		self.NFE += 1
		return self.net(y)

# Training window: the first `train_end` samples of the reference trajectory.
train_end = 10000
t_q = torch.tensor(t[:train_end].copy()).squeeze()  # (train_end,)
# Observations are fixed regression targets, so they are created without
# requires_grad: otherwise every backward pass would also compute gradients w.r.t.
# the data and accumulate them in data_qp.grad, which the optimizer never clears.
data_qp = torch.tensor(data[:train_end].copy())  # (train_end, 2)

def get_batch(data_qps, batch_len=80, batch_size=40):
	"""Draw a random mini-batch of short trajectory windows from ``data_qps``.

	Window start indices are sampled without replacement from
	``[0, train_end - batch_len)`` with NumPy's global RNG; each window spans
	``batch_len`` consecutive rows of ``data_qps`` (a 2-D tensor).

	Returns:
		A tuple ``(y0, t, y)``, all moved to ``device``: the rows at the window
		starts, shape (M, D); the first ``batch_len`` training times, shape (T,);
		and the windows restricted to their first (up to) two columns, shape
		(T, M, 2).
	"""
	candidates = np.arange(train_end - batch_len, dtype=np.int64)
	starts = torch.from_numpy(np.random.choice(candidates, batch_size, replace=False))
	# (T, M) index grid: row k holds the k-th sample index of every window.
	offsets = torch.arange(batch_len, dtype=starts.dtype).unsqueeze(1)
	window_idx = starts.unsqueeze(0) + offsets
	y0 = data_qps[starts]
	windows = data_qps[window_idx][:, :, :2]
	times = t_q[:batch_len]
	return y0.to(device), times.to(device), windows.to(device)

odefunc = ODEfunc(2)

best_loss = 1e30
best_loss_outer = 1e30  # NOTE(review): appears unused in this script.
params = odefunc.parameters()
optimizer = optim.Adamax(params, lr=1e-2)

# Initial condition for the long validation rollout: first sample, first two
# components, shape (1, 2). It is not optimized, so it is created directly on
# `device` without requires_grad.
x0 = torch.tensor(mat_data[:1, :2], device=device)

# The validation window is twice the training window, so it includes
# extrapolation past train_end. t_val lives on `device` (it is passed to the
# solver); data_val stays on the CPU.
t_val = torch.tensor(t[:train_end*2].copy()).to(device).squeeze()  # (2*train_end,)
data_val = torch.tensor(data[:train_end*2].copy())  # (2*train_end, 2)
frame = 0  # index of the next validation plot to write


def _save_validation_plot(pred, frame_idx):
	"""Plot reference (solid) vs. predicted (dashed) components over t_val and save a PNG."""
	save_file = os.path.join(fig_save_path, "image_{:03d}.png".format(frame_idx))
	# matplotlib needs host arrays; t_val and pred may live on a CUDA device.
	t_plot = t_val.cpu().numpy()
	ref_np = data_val.cpu().numpy()
	pred_np = pred.detach().cpu().numpy()
	fig, axes = plt.subplots(1, 2, figsize=(8, 4))
	for i, ax in enumerate(axes):
		ax.plot(t_plot, ref_np[:, i], lw=2, color='k')
		ax.plot(t_plot, pred_np[:, i], lw=2, color='c', ls='--')
	fig.tight_layout()
	fig.savefig(save_file)
	plt.close(fig)


for itr in range(1, 50001):
	optimizer.zero_grad()
	batch_y0, batch_t, batch_y = get_batch(data_qp)
	pred_y = odeint(odefunc, batch_y0, batch_t, rtol=1e-5, atol=1e-6, method='dopri5')  # (T, M, D)
	# Squared error summed over the state dimension, averaged over time and batch;
	# the same reduction as the validation loss below.
	loss = torch.mean((pred_y[:, :, :2] - batch_y).pow(2).sum(dim=-1))
	print(itr, loss.item())
	loss.backward()
	optimizer.step()

	if itr % 500 != 0:
		continue

	# Validation: one long rollout from x0. No gradients are needed, so skip
	# building an autograd graph over thousands of solver steps.
	with torch.no_grad():
		pred_x = odeint(odefunc, x0, t_val, rtol=1e-5, atol=1e-6, method='dopri5').squeeze()  # (len(t_val), 2)
		val_loss = torch.mean((pred_x[:, :2] - data_val.to(pred_x.device)).pow(2).sum(dim=1)).item()
	print('Iter {:04d} | Total Loss {:.6f}'.format(itr, val_loss))

	if best_loss > val_loss:
		print('saving...', val_loss)
		torch.save({
			'state_dict': odefunc.state_dict(),
			}, ckpt_path)
		best_loss = val_loss

	_save_validation_plot(pred_x, frame)
	frame += 1
