import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import lib.utils as utils
from lib.ode_func import ODEFunc

import random
from random import SystemRandom

import scipy
import scipy.io as sio 
import scipy.linalg as scilin
from scipy.optimize import newton, brentq
from scipy.special import legendre, roots_legendre

import matplotlib.pyplot as plt

from torchdiffeq import odeint as odeint
import argparse

parser = argparse.ArgumentParser(description='.')
# The architecture flags below must match the values used when the checkpoint was
# trained, otherwise ODEfunc.load_state_dict() cannot load it.
parser.add_argument('--id', type=int, default=0, help='experiment ID (selects experiments/experiment_<id> files)')
parser.add_argument('--lE', type=int, default=2, help='n_layers of the energy network (utils.create_net)')
parser.add_argument('--lS', type=int, default=2, help='n_layers of the entropy network (utils.create_net)')
parser.add_argument('--nE', type=int, default=5, help='n_units of the energy network (utils.create_net)')
parser.add_argument('--nS', type=int, default=5, help='n_units of the entropy network (utils.create_net)')
parser.add_argument('--D1', type=int, default=2, help='number of rows of the friction factor friction_D')
parser.add_argument('--D2', type=int, default=2, help='number of columns of the friction factor friction_D')
# Previously missing: the device selection below reads args.gpu and crashed on CUDA machines.
parser.add_argument('--gpu', type=int, default=0, help='CUDA device index (used only when CUDA is available)')

args = parser.parse_args()

torch.set_default_dtype(torch.float64)
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

# Reference trajectory. NOTE: this path is relative to the current working directory.
mat_data = sio.loadmat("../data/TGC/GENERIC_Int_TGC_d001_N1_T30_L1_p2_E4_ad5_RK3.mat")['x']
h_ref = 0.001  # step size of the time grid
Time = 30      # final time of the grid
# round() instead of floor(): Time/h_ref may land just below an integer in floating point.
N_steps = int(round(Time / h_ref)) + 1
t = np.expand_dims(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64), axis=-1)
# Rescale time from [0, Time] to [0, 3]; the plot relabels the axis ends as 0 and 30.
t = t / t[-1] * 3
data = mat_data[:N_steps, :2]

class ODEfunc(nn.Module):
    """Learned vector field dy/dt = L(dE, dS) + M(dE, dS, dE).

    E and S are scalar networks of the state y, and dE, dS are their gradients
    w.r.t. y. The Poisson part contracts them with the totally antisymmetric
    part of ``poisson_xi``. The friction part contracts them with a 4-tensor
    built from ``friction_D`` and ``friction_L``.

    The attribute names ``friction_D``, ``friction_L``, ``poisson_xi``,
    ``energy`` and ``entropy`` are the checkpoint's state_dict keys; keep them.
    """

    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim
        self.dimD = args.D1
        self.dimD2 = args.D2
        # Factor of the friction matrix D = friction_D @ friction_D^T.
        self.friction_D = nn.Parameter(torch.randn((self.dimD, self.dimD2)))
        # [alpha, beta, m] or [mu, nu, n]; antisymmetrised over its first two axes before use.
        self.friction_L = nn.Parameter(torch.randn((self.output_dim, self.output_dim, self.dimD)))
        # Raw 3-tensor; only its totally antisymmetric part is used.
        self.poisson_xi = nn.Parameter(torch.randn((self.output_dim, self.output_dim, self.output_dim)))

        self.energy = utils.create_net(output_dim, 1, n_layers=args.lE, n_units=args.nE, nonlinear=nn.Tanh).to(device)
        self.entropy = utils.create_net(output_dim, 1, n_layers=args.lS, n_units=args.nS, nonlinear=nn.Tanh).to(device)
        # Count of forward() calls.
        self.NFE = 0

    def _poisson_tensor(self):
        """Return xi[alpha, beta, gamma], the totally antisymmetric part of poisson_xi."""
        p = self.poisson_xi
        # Signed sum over all 6 index permutations (+ for even, - for odd), divided by 6.
        return (p - p.permute(0, 2, 1) + p.permute(1, 2, 0)
                - p.permute(1, 0, 2) + p.permute(2, 0, 1) - p.permute(2, 1, 0)) / 6.0

    def _friction_tensor(self):
        """Return zeta[alpha, beta, mu, nu] = sum_{m,n} L[alpha,beta,m] D[m,n] L[mu,nu,n].

        D = friction_D @ friction_D^T is symmetric positive semi-definite.
        L is friction_L antisymmetrised over its first two axes, so zeta is
        antisymmetric in (alpha, beta) and in (mu, nu).
        """
        D = self.friction_D @ torch.transpose(self.friction_D, 0, 1)
        L = (self.friction_L - torch.transpose(self.friction_L, 0, 1)) / 2.0
        return torch.einsum('abm,mn,cdn->abcd', L, D, L)

    def Poisson_matvec(self, dE, dS):
        """Poisson term: out[z, a] = sum_{b,c} xi[a, b, c] * dE[z, b] * dS[z, c].

        dE and dS are 2-D, [batch, output_dim]; the result has the same shape.
        """
        return torch.einsum('abc,zb,zc->za', self._poisson_tensor(), dE, dS)

    def friction_matvec(self, dE, dS):
        """Friction term: out[z, a] = sum_{b,m,n} zeta[a, b, m, n] * dE[z, b] * dS[z, m] * dE[z, n].

        dE and dS are 2-D, [batch, output_dim]; the result has the same shape.
        """
        return torch.einsum('abmn,zb,zm,zn->za', self._friction_tensor(), dE, dS, dE)

    def get_dSdt(self, dS, dE):
        """Return dS/dt along the vector field, one value per batch row.

        This equals sum_a dS[z, a] * (Poisson_matvec(dE, dS) + friction_matvec(dE, dS))[z, a].
        The Poisson contribution vanishes analytically, because xi is
        antisymmetric in the two slots contracted with dS. It is still
        evaluated numerically here.
        """
        dSLdE = torch.einsum('abc,za,zb,zc->z', self._poisson_tensor(), dS, dE, dS)
        dSMdS = torch.einsum('abmn,za,zb,zm,zn->z', self._friction_tensor(), dS, dE, dS, dE)
        return dSLdE + dSMdS

    def get_penalty(self):
        """Return (LdS, MdE) cached by the most recent forward() call.

        Raises AttributeError if forward() has not been called yet.
        """
        return self.LdS, self.MdE

    def forward(self, t, y):
        """Evaluate dy/dt at state y.

        ``t`` is unused. ``y`` must have requires_grad=True, because dE and dS
        are taken with torch.autograd.grad w.r.t. y.

        Side effects: increments NFE and caches LdS, MdE (for get_penalty) and
        dSdt (per-row entropy rate).
        """
        E = self.energy(y)
        S = self.entropy(y)

        dE = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
        dS = torch.autograd.grad(S.sum(), y, create_graph=True)[0]

        LdE = self.Poisson_matvec(dE, dS)
        MdS = self.friction_matvec(dE, dS)
        output = LdE + MdS
        self.NFE = self.NFE + 1

        # Degeneracy-condition residuals; both are zero by the tensor antisymmetries.
        self.LdS = self.Poisson_matvec(dS, dS)
        self.MdE = self.friction_matvec(dE, dE)
        self.dSdt = self.get_dSdt(dS, dE)
        return output

save_path = 'experiments/'
experimentID = args.id

fig_save_path = os.path.join(save_path, "experiment_" + str(experimentID))
save_npz_file = os.path.join(fig_save_path, 'gnode_exp_{}.npz'.format(experimentID))
ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID) + '.ckpt')

# Predicted trajectory from the training run. Only the first 4 columns are used,
# to match ODEfunc(4). Copy as float64 to match torch.set_default_dtype(torch.float64).
# The context manager closes the underlying .npz file handle.
with np.load(save_npz_file) as npzfile:
    pred_x = np.array(npzfile['pred_x'][:, :4], dtype=np.float64)

# map_location lets a checkpoint saved on one device (e.g. a GPU) load on the current one.
ckpt = torch.load(ckpt_path, map_location=device)

# Move the whole module, not only the energy/entropy nets, so every parameter
# lives on the same device as the input.
odefunc = ODEfunc(4).to(device)
odefunc.load_state_dict(ckpt['state_dict'])

# forward() differentiates w.r.t. the state, so the input must require grad.
# As a side effect it caches the entropy rate on odefunc.dSdt.
x = torch.tensor(pred_x, requires_grad=True, device=device)
odefunc(None, x)
dSdt = odefunc.dSdt.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(3, 3), dpi=300)
ax.plot(t, dSdt, lw=1, ls='-', color='blue')
ax.hlines(0, 0, 3, lw=.5, ls='--', color='red')  # zero reference line
ylim = ax.get_ylim()
# Vertical reference at rescaled t = 0.5, i.e. 5 s in original time units.
ax.vlines(.5, ylim[0], ylim[1], lw=.5, ls='--', color='red')
ax.set_xlim(0, 3)
# t was rescaled from [0, 30] to [0, 3]; label the ends in original seconds.
ax.set_xticks([0, 3])
ax.set_xticklabels(['0', '30'])
ax.set_xlabel('$t$ (second)')
ax.set_title(r'$\frac{dS}{dt}$ - GNODE')
fig.tight_layout()

save_file = os.path.join(fig_save_path, "dSdt_gnode.png")
fig.savefig(save_file)
plt.close(fig)
