import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

# Put the repository root (the parent of this script's directory) on sys.path
# so the project-local `lib` package imported below can be resolved.
_script_dir = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(_script_dir)
sys.path.append(PARENT_DIR)

import lib.utils as utils
from lib.ode_func import ODEFunc

import random
from random import SystemRandom

import scipy
import scipy.io as sio 
import scipy.linalg as scilin
from scipy.optimize import newton, brentq
from scipy.special import legendre, roots_legendre

import matplotlib.pyplot as plt

from torchdiffeq import odeint as odeint
import argparse
parser = argparse.ArgumentParser(description='Train a GENERIC-structured neural ODE (energy/entropy networks with learned L and M matrices).')
parser.add_argument('--r', type=int, default=0, help='random seed for torch, numpy and random')
parser.add_argument('--lE', type=int, default=1, help='n_layers of the energy network (passed to utils.create_net)')
parser.add_argument('--lS', type=int, default=1, help='n_layers of the entropy network (passed to utils.create_net)')
parser.add_argument('--nE', type=int, default=10, help='n_units of the energy network (passed to utils.create_net)')
parser.add_argument('--nS', type=int, default=10, help='n_units of the entropy network (passed to utils.create_net)')
parser.add_argument('--pen', type=float, default=1e-4, help='weight of the penalty on L*dS and M*dE')
# `device` below reads args.gpu; without this option the script raised
# AttributeError on every machine where CUDA is available.
parser.add_argument('--gpu', type=int, default=0, help='CUDA device index (only used when CUDA is available)')

args = parser.parse_args()

torch.set_default_dtype(torch.float64)
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

# Seed the torch, numpy and stdlib RNGs from --r.
seed = args.r
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Output locations. Each run draws a random numeric ID (0..99999) from OS
# entropy, so checkpoints and figures of separate runs are unlikely to collide.
save_path = 'experiments/'
utils.makedirs(save_path)
experimentID = int(SystemRandom().random() * 100000)
run_prefix = os.path.join(save_path, f"experiment_{experimentID}")
ckpt_path = run_prefix + '.ckpt'
ckpt_path_outer = run_prefix + '_outer.ckpt'
fig_save_path = run_prefix
utils.makedirs(fig_save_path)
print(ckpt_path)

# Load the reference trajectory: array 'x' of the .mat file, one time sample per
# row. Columns 0-1 are the observed coordinates used as training targets;
# column 2 is kept separately in s_data (presumably the entropy-like coordinate
# of the generating model — confirm against the data generator).
mat_data = sio.loadmat("../data/DHO/GENERIC_Int_Nonlinear_d001_gamma_d01_k1_N1_T180_q0-2_RK3.mat")['x']
h_ref = 0.001  # sampling interval assumed for the rows of mat_data (confirm against the generator)
Time = 60
# round() instead of floor(): Time/h_ref is a float quotient and can land a hair
# below the intended integer (e.g. 0.3/0.1 == 2.9999999999999996), which floor
# would truncate to one step too few.
N_steps = int(round(Time / h_ref)) + 1
t = np.expand_dims(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64), axis=-1)
# Rescale the time grid from [0, Time] to [0, 3].
t = t / (t[-1]) * 3

data = mat_data[:N_steps, :2]
s_data = mat_data[:N_steps, 2]

class ODEfunc(nn.Module):
	"""GENERIC-structured vector field dy/dt = L dE/dy + M dS/dy.

	E (energy) and S (entropy) are scalar networks of the state. L is the
	skew-symmetric part of the learned matrix Poisson_matrix, and M = B B^T is
	built from the learned matrix B = friction_vector, so M is symmetric positive
	semi-definite. Every forward call also caches L dS and M dE (see get_penalty)
	so the caller can penalise violations of L dS = 0 and M dE = 0.

	Args:
		output_dim: dimension of the state y. It is passed as the first argument
			of utils.create_net (presumably the input size) and sets the size of
			the square L and M matrices.
	"""

	def __init__(self, output_dim):
		super(ODEfunc, self).__init__()
		self.output_dim = output_dim
		self.energy = utils.create_net(output_dim, 1, n_layers=args.lE, n_units=args.nE, nonlinear=nn.Tanh).to(device)
		self.entropy = utils.create_net(output_dim, 1, n_layers=args.lS, n_units=args.nS, nonlinear=nn.Tanh).to(device)
		# L and M act on dE/dy and dS/dy, so they must match the state dimension.
		self.dimD = output_dim
		self.dimD2 = 3  # number of columns of B; bounds the rank of M = B B^T
		# nn.Parameter already sets requires_grad=True. The tensors are moved to
		# `device` so they live alongside the energy/entropy networks.
		# (State-dict names are kept unchanged for checkpoint compatibility.)
		self.friction_vector = nn.Parameter(torch.randn((self.dimD, self.dimD2)).to(device))
		self.Poisson_matrix = nn.Parameter(torch.randn((self.dimD, self.dimD)).to(device))

		self.NFE = 0  # number of forward evaluations
		# Penalty terms cached by forward(); None until the first call.
		self.LdS = None
		self.MdE = None

	def Poisson_energy_matvec(self, y):
		"""Apply L = (P - P^T)/2 to every row of the 2-D tensor y."""
		L = (self.Poisson_matrix - torch.transpose(self.Poisson_matrix, 0, 1)) / 2.0
		return torch.einsum('ab,cb->ca', L, y)

	def friction_entropy_matvec(self, y):
		"""Apply M = B B^T to every row of the 2-D tensor y."""
		D = self.friction_vector @ torch.transpose(self.friction_vector, 0, 1)
		return torch.einsum('ab,cb->ca', D, y)

	def get_penalty(self):
		"""Return (L dS, M dE) computed by the most recent forward call.

		Raises:
			RuntimeError: if forward has not been called yet.
		"""
		if self.LdS is None or self.MdE is None:
			raise RuntimeError('get_penalty() called before forward(); no penalty terms cached yet')
		return self.LdS, self.MdE

	def forward(self, t, y):
		"""Evaluate L dE/dy + M dS/dy for a batch of states.

		Args:
			t: time; unused (the field is autonomous) but required by odeint.
			y: 2-D tensor with one state per row. It must be part of an autograd
				graph, because dE/dy and dS/dy are taken with torch.autograd.grad.

		Returns:
			Tensor of the same shape as y.
		"""
		E = self.energy(y)
		S = self.entropy(y)

		# create_graph=True keeps the gradients differentiable for training.
		dE = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
		dS = torch.autograd.grad(S.sum(), y, create_graph=True)[0]

		LdE = self.Poisson_energy_matvec(dE)
		MdS = self.friction_entropy_matvec(dS)
		output = LdE + MdS
		self.NFE = self.NFE + 1

		# GENERIC degeneracy conditions require L dS = 0 and M dE = 0; cache both
		# so the training loop can penalise their magnitude.
		self.LdS = self.Poisson_energy_matvec(dS)
		self.MdE = self.friction_entropy_matvec(dE)
		return output

# Training window: the first train_end samples. The ODE state has three
# columns: the two observed ones (data_qp) plus an unobserved third coordinate
# (data_s) that starts as a linear ramp on [0, 1].
train_end = 20000
t_q = torch.tensor(t[:train_end].copy()).squeeze()
data_qp = torch.tensor(data[:train_end].copy(), requires_grad=True)
s_init = torch.linspace(0., 1., train_end, requires_grad=True)
data_s = s_init.unsqueeze(-1)
data_qps = torch.cat((data_qp, data_s), dim=-1)


def get_batch(data_qps, batch_len=120, batch_size=20, t_grid=None, target_device=None):
	"""Sample random contiguous windows of a trajectory for mini-batch training.

	Args:
		data_qps: (N, D) tensor of full states, one time step per row, D >= 2.
		batch_len: number of consecutive time steps per window.
		batch_size: number of windows; start indices are drawn without replacement.
		t_grid: time points; its first batch_len entries form the batch time axis.
			Defaults to the module-level t_q.
		target_device: device the outputs are moved to. Defaults to the
			module-level device.

	Returns:
		batch_y0: (batch_size, D) first state of each window (all columns).
		batch_t: (batch_len,) time points.
		batch_y: (batch_len, batch_size, 2) targets, first two state columns only.

	Raises:
		ValueError: (from numpy) if batch_size exceeds the number of valid starts.
	"""
	if t_grid is None:
		t_grid = t_q
	if target_device is None:
		target_device = device
	# Valid starts are 0 .. N - batch_len inclusive, i.e. N - batch_len + 1 of
	# them; using N - batch_len as the population excluded the final window.
	n_starts = data_qps.shape[0] - batch_len + 1
	s = torch.from_numpy(np.random.choice(np.arange(n_starts, dtype=np.int64), batch_size, replace=False))
	batch_y0 = data_qps[s]  # (M, D)
	batch_t = t_grid[:batch_len]  # (T,)
	# (T, M) index grid whose row i is s + i: gathers every window in one indexing op.
	idx = s.unsqueeze(0) + torch.arange(batch_len).unsqueeze(1)
	batch_y = data_qps[idx][:, :, :2]  # (T, M, 2)
	return batch_y0.to(target_device), batch_t.to(target_device), batch_y.to(target_device)

# Model, optimiser and validation data. .to(device) moves every parameter,
# including the Poisson/friction matrices that ODEfunc registers directly
# (only its energy/entropy sub-networks are moved inside ODEfunc itself).
odefunc = ODEfunc(3).to(device)

best_loss = 1e30
best_loss_outer = 1e30  # not read anywhere below
params = odefunc.parameters()
optimizer = optim.Adamax(params, lr=1e-2)
# Initial state for full-trajectory validation: first row, all three columns.
# requires_grad is needed because ODEfunc.forward differentiates E and S with
# respect to the state; creating it directly on `device` keeps x0 the leaf.
x0 = torch.tensor(mat_data[:1, :3].copy(), requires_grad=True, device=device)
t_val = torch.tensor(t[:train_end*2].copy()).to(device).squeeze()
data_val = torch.tensor(data[:train_end*2].copy())
frame = 0


def _save_validation_plot(t_axis, reference, prediction, save_file):
	"""Save a 1x3 figure: prediction columns 0-1 over the reference, plus column 2 alone.

	Inputs are detached and moved to CPU before plotting, so this also works when
	training on a GPU.
	"""
	t_np = t_axis.detach().cpu().numpy()
	ref_np = reference.detach().cpu().numpy()
	pred_np = prediction.detach().cpu().numpy()
	fig = plt.figure(figsize=(12, 4))
	for i in range(3):
		ax = fig.add_subplot(1, 3, i + 1)
		if i < 2:
			# Observed coordinates: ground truth (solid black) under the prediction.
			ax.plot(t_np, ref_np[:, i], lw=2, color='k')
		ax.plot(t_np, pred_np[:, i], lw=2, color='c', ls='--')
	# tight_layout must run on this figure after its axes exist; the previous
	# code applied it to a separate empty figure, so it had no effect, and a
	# trailing plt.clf() left a fresh empty figure open after every save.
	fig.tight_layout()
	fig.savefig(save_file)
	plt.close(fig)


for itr in range(1, 30001):
	optimizer.zero_grad()
	batch_y0, batch_t, batch_y = get_batch(data_qps)
	pred_y = odeint(odefunc, batch_y0, batch_t, rtol=1e-5, atol=1e-6, method='dopri5')
	# L1 error on the two observed columns, summed over the batch axis and
	# averaged over time steps and coordinates.
	data_loss = torch.mean(torch.abs(pred_y[:, :, :2] - batch_y).sum(axis=1))
	# Evaluate the vector field on every predicted state only to refresh the
	# cached L dS and M dE terms that feed the degeneracy penalty.
	odefunc(batch_t, pred_y.reshape((-1, 3)))
	p1, p2 = odefunc.get_penalty()
	pen1 = torch.mean(p1.pow(2).sum(axis=-1))
	pen2 = torch.mean(p2.pow(2).sum(axis=-1))
	loss = data_loss + args.pen * (pen1 + pen2)
	print(itr, loss.item(), data_loss.item(), pen1.item(), pen2.item())
	loss.backward()
	optimizer.step()

	if itr % 500 == 0:
		# Validate by integrating from x0 over the whole t_val grid.
		pred_x = odeint(odefunc, x0, t_val, rtol=1e-5, atol=1e-6, method='dopri5').squeeze()
		# Squared error (|e|^2 == e^2, so no abs() needed); data_val is moved to
		# pred_x's device so this also works on GPU.
		val_loss = torch.mean((pred_x[:, :2] - data_val.to(pred_x.device)).pow(2).sum(axis=1))
		print('Iter {:04d} | Total Loss {:.6f}'.format(itr, val_loss.item()))
		# Replace the unobserved third training column with the model's own
		# prediction (detached copy, kept on the same device as data_qp).
		data_s = pred_x[:train_end, 2:].clone().detach().to(data_qp.device).requires_grad_(True)
		data_qps = torch.cat((data_qp, data_s), axis=-1)

		if best_loss > val_loss.item():
			print('saving ode...', val_loss.item())
			torch.save({
				'state_dict': odefunc.state_dict(),
				}, ckpt_path)
			best_loss = val_loss.item()
		print(val_loss.item())

		save_file = os.path.join(fig_save_path, "image_{:03d}.png".format(frame))
		_save_validation_plot(t_val, data_val, pred_x, save_file)
		frame += 1
