import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import lib.utils as utils
from lib.ode_func import ODEFunc

import random
from random import SystemRandom

import scipy
import scipy.io as sio 
import scipy.linalg as scilin
from scipy.optimize import newton, brentq
from scipy.special import legendre, roots_legendre

import matplotlib.pyplot as plt

from torchdiffeq import odeint as odeint
import argparse

# Command-line configuration. The network-size flags (lE/lS/nE/nS/D1/D2) must
# match the values the checkpoint was trained with; otherwise the parameter
# shapes differ and load_state_dict below will reject the checkpoint.
parser = argparse.ArgumentParser(
	description='Evaluate a trained GENERIC-structured ODE checkpoint and plot its predictions.')
parser.add_argument('--id', type=int, default=0,
	help='experiment ID: loads experiments/experiment_<id>.ckpt and writes figures to experiments/experiment_<id>/')
parser.add_argument('--lE', type=int, default=2, help='number of layers (n_layers) of the energy network')
parser.add_argument('--lS', type=int, default=2, help='number of layers (n_layers) of the entropy network')
parser.add_argument('--nE', type=int, default=5, help='units per layer (n_units) of the energy network')
parser.add_argument('--nS', type=int, default=5, help='units per layer (n_units) of the entropy network')
parser.add_argument('--D1', type=int, default=2, help='rows of the friction factor friction_D (D1 x D2)')
parser.add_argument('--D2', type=int, default=2, help='columns of the friction factor friction_D (D1 x D2)')
# args.gpu is read when selecting the CUDA device below, so it must be declared.
parser.add_argument('--gpu', type=int, default=0, help='CUDA device index (ignored when CUDA is unavailable)')

args = parser.parse_args()

torch.set_default_dtype(torch.float64)
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

# Checkpoint to evaluate and the directory that receives the figures / npz output.
save_path = 'experiments/'
experimentID = args.id
ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID) + '.ckpt')
fig_save_path = os.path.join(save_path, "experiment_" + str(experimentID))
utils.makedirs(fig_save_path)
print(ckpt_path)

# Reference trajectory. Below, columns 0-1 are the plotted/compared state,
# column 2 is kept as S_ref and the last column as E_ref (presumably entropy
# and energy of the reference solution -- confirm against the .mat generator).
mat_data = sio.loadmat("../data/TGC/GENERIC_Int_TGC_d001_N1_T30_L1_p2_E4_ad5_RK3.mat")['x']
h_ref = 0.001  # step size of the stored trajectory (presumably; matches "d001" in the file name)
Time = 30
# round(), not floor(): Time/h_ref is a float quotient that can land just below
# an integer (e.g. 0.3/0.1 == 2.9999999999999996), which floor would undercount.
N_steps = int(round(Time / h_ref)) + 1
# Column vector of sample times on [0, Time] ...
t = np.expand_dims(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64), axis=-1)
# ... rescaled to [0, 3] for the model's time axis.
t = t / (t[-1]) * 3

S_ref = mat_data[:N_steps:1, 2]
E_ref = mat_data[:N_steps:1, -1]
data = mat_data[:N_steps:1, :2]

class ODEfunc(nn.Module):
	"""Vector field with GENERIC structure, suitable for torchdiffeq's odeint.

	The right-hand side evaluated by forward() is

	    dx/dt = Poisson_matvec(dE, dS) + friction_matvec(dE, dS)

	where dE and dS are the gradients, with respect to the state, of two
	learned scalar networks (``energy`` and ``entropy``).

	Learnable tensors (their attribute names are the state_dict keys that saved
	checkpoints use, so they must not be renamed):
	  * poisson_xi  (d, d, d)   -- antisymmetrised over all three indices
	  * friction_L  (d, d, D1)  -- antisymmetrised over its first two indices
	  * friction_D  (D1, D2)    -- enters only as friction_D @ friction_D.T
	with d = output_dim.  Reads the module-level ``args`` (D1, D2, lE, nE, lS,
	nS), ``utils`` and ``device`` at construction time.
	"""

	def __init__(self, output_dim):
		super().__init__()
		self.output_dim = output_dim
		self.dimD = args.D1
		self.dimD2 = args.D2
		d, m = output_dim, self.dimD

		# Unconstrained raw tensors; the symmetry structure is imposed in the
		# matvec methods.  Kept in this creation order so the randn draws match.
		self.friction_D = nn.Parameter(torch.randn(m, self.dimD2))   # [m, n]
		self.friction_L = nn.Parameter(torch.randn(d, d, m))         # [alpha, beta, m] or [mu, nu, n]
		self.poisson_xi = nn.Parameter(torch.randn(d, d, d))         # [alpha, beta, gamma]

		# Scalar networks built by the project helper; called as create_net(d, 1, ...).
		self.energy = utils.create_net(d, 1, n_layers=args.lE, n_units=args.nE, nonlinear=nn.Tanh).to(device)
		self.entropy = utils.create_net(d, 1, n_layers=args.lS, n_units=args.nS, nonlinear=nn.Tanh).to(device)

		# Counts forward() calls, i.e. solver function evaluations.
		self.NFE = 0

	def Poisson_matvec(self, dE, dS):
		"""Contract the antisymmetrised Poisson tensor with two gradient batches.

		Returns out[z, a] = sum_{b,c} xi[a, b, c] * dE[z, b] * dS[z, c], where xi
		is poisson_xi averaged over all six index permutations with sign +1 for
		even and -1 for odd permutations.  Because xi is antisymmetric in (b, c),
		passing the same tensor for both arguments yields zero up to round-off.

		dE, dS: 2-D tensors (batch, output_dim).
		"""
		p = self.poisson_xi
		xi = (
			p
			- p.permute(0, 2, 1)
			+ p.permute(1, 2, 0)
			- p.permute(1, 0, 2)
			+ p.permute(2, 0, 1)
			- p.permute(2, 1, 0)
		) / 6.0
		return torch.einsum('abc, zb, zc -> za', xi, dE, dS)

	def friction_matvec(self, dE, dS):
		"""Apply the friction term built from friction_L and friction_D.

		With D = friction_D @ friction_D.T (symmetric, D1 x D1) and Lskew the part
		of friction_L antisymmetric in its first two indices,

		    zeta[a, b, m, n] = sum_{p,q} Lskew[a, b, p] * D[p, q] * Lskew[m, n, q]
		    out[z, a]        = sum_{b,m,n} zeta[a, b, m, n] * dE[z, b] * dS[z, m] * dE[z, n]

		Since zeta is antisymmetric in (m, n), passing dE as dS gives zero up to
		round-off.  dE, dS: 2-D tensors (batch, output_dim).
		"""
		factor = self.friction_D
		D_sym = factor @ factor.transpose(0, 1)
		L_skew = (self.friction_L - self.friction_L.transpose(0, 1)) / 2.0
		zeta = torch.einsum('abm,mn,cdn->abcd', L_skew, D_sym, L_skew)
		return torch.einsum('abmn,zb,zm,zn->za', zeta, dE, dS, dE)

	def get_penalty(self):
		"""Return (LdS, MdE) as stored by the most recent forward() call.

		Raises AttributeError if forward() has not been called yet.
		"""
		return self.LdS, self.MdE

	def forward(self, t, y):
		"""Evaluate dx/dt at state y; t is not used by this vector field.

		y must take part in autograd, because dE and dS are gradients with
		respect to it (create_graph=True keeps them differentiable for training).
		Side effects: increments NFE and stores LdS = Poisson_matvec(dS, dS) and
		MdE = friction_matvec(dE, dE) for get_penalty() (presumably the GENERIC
		degeneracy residuals).
		"""
		def grad_wrt_state(potential):
			return torch.autograd.grad(potential(y).sum(), y, create_graph=True)[0]

		dE = grad_wrt_state(self.energy)
		dS = grad_wrt_state(self.entropy)

		rhs = self.Poisson_matvec(dE, dS) + self.friction_matvec(dE, dS)
		self.NFE += 1

		self.LdS = self.Poisson_matvec(dS, dS)
		self.MdE = self.friction_matvec(dE, dE)
		return rhs


# Build the model on the target device. ODEfunc.__init__ only moves the
# energy/entropy sub-networks to `device`; the raw nn.Parameters
# (poisson_xi, friction_L, friction_D) need the module-level .to(device).
odefunc = ODEfunc(4).to(device)

# Initial state: first row of the reference data, with components 2 and 3 zeroed.
x0 = mat_data[:1, :4].copy()
x0[0, 2], x0[0, 3] = 0, 0
# requires_grad is needed: forward() differentiates E and S with respect to the state.
x0 = torch.tensor(x0, device=device, requires_grad=True)

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host (and vice versa).
ckpt = torch.load(ckpt_path, map_location=device)
odefunc.load_state_dict(ckpt['state_dict'])

odefunc.NFE = 0

t = torch.tensor(t).to(device).squeeze()
pred_x = odeint(odefunc, x0, t, rtol=1e-5, atol=1e-6, method='dopri5').squeeze()

# Host-side numpy copies for plotting/saving; CUDA tensors cannot be
# converted with .numpy() or handed to matplotlib directly.
t_host = t.cpu().numpy()
pred_x = pred_x.detach().cpu().numpy()

# Trajectory comparison for the first two state components: reference (black) vs prediction (cyan dashed).
fig = plt.figure(figsize=(12, 4))
axes = []
for i in range(2):
	ax = fig.add_subplot(1, 2, i + 1)
	ax.plot(t_host, data[:, i], lw=3, color='k')
	ax.plot(t_host, pred_x[:, i], lw=2, color='c', ls='--')
	axes.append(ax)

save_file = os.path.join(fig_save_path, "image_best_post.png")
plt.savefig(save_file)
plt.close(fig)

# Plots 0.5 * x[:, 1]**2 for reference and prediction, i.e. the p^2/(2m) term
# with m = 1 -- assuming column 1 is the momentum (TODO confirm). Despite the
# output file name, no potential-energy terms are included here.
KPE_post = .5 * pred_x[:, 1]**2
KPE = .5 * data[:, 1]**2
fig = plt.figure(figsize=(4, 4))
plt.plot(t_host, KPE, 'k')
plt.plot(t_host, KPE_post, 'r:')
save_file = os.path.join(fig_save_path, "KE+PE.png")
plt.savefig(save_file)
plt.close(fig)

save_npz_file = os.path.join(fig_save_path, 'gnode_exp_{}.npz'.format(experimentID))
np.savez(save_npz_file, pred_x=pred_x)
