import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import lib.utils as utils
from lib.ode_func import ODEFunc

import random
from random import SystemRandom

import scipy
import scipy.io as sio 
import scipy.linalg as scilin
from scipy.optimize import newton, brentq
from scipy.special import legendre, roots_legendre

import matplotlib.pyplot as plt

from torchdiffeq import odeint as odeint
import argparse

# Command-line options. Network sizes feed utils.create_net in ODEfunc; D1/D2 size
# the friction factors. --gpu is required because the device selection reads args.gpu.
parser = argparse.ArgumentParser(description='.')
parser.add_argument('--r', type=int, default=0, help='random seed')
parser.add_argument('--lE', type=int, default=2, help='n_layers of the energy network (utils.create_net)')
parser.add_argument('--lS', type=int, default=2, help='n_layers of the entropy network (utils.create_net)')
parser.add_argument('--nE', type=int, default=5, help='n_units of the energy network (utils.create_net)')
parser.add_argument('--nS', type=int, default=5, help='n_units of the entropy network (utils.create_net)')
parser.add_argument('--D1', type=int, default=2, help='rows of friction_D and last axis of friction_L')
parser.add_argument('--D2', type=int, default=2, help='columns of friction_D')
parser.add_argument('--gpu', type=int, default=0, help='CUDA device index (used only when CUDA is available)')

args = parser.parse_args()

torch.set_default_dtype(torch.float64)
# Fall back to device 0 if the parser does not register --gpu, instead of raising
# AttributeError on CUDA hosts.
gpu_index = getattr(args, 'gpu', 0)
device = torch.device('cuda:' + str(gpu_index) if torch.cuda.is_available() else 'cpu')

# Seed every RNG used by the script (torch, numpy, stdlib random) from --r.
seed = args.r
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Output locations; checkpoints and figures are keyed by a random experiment ID.
save_path = 'experiments/'
utils.makedirs(save_path)
experimentID = int(SystemRandom().random()*100000)
ckpt_path = os.path.join(save_path, "experiment_" + str(experimentID) + '.ckpt')
fig_save_path = os.path.join(save_path,"experiment_"+str(experimentID))
utils.makedirs(fig_save_path)
ckpt_path_outer = os.path.join(save_path, "experiment_" + str(experimentID) + '_outer.ckpt')
print(ckpt_path)

# Reference trajectory: columns 0-1 are the observed channels; column 2 is kept as s_data.
mat_data = sio.loadmat("../data/DHO/GENERIC_Int_Nonlinear_d001_gamma_d01_k1_N1_T180_q0-2_RK3.mat")['x']
h_ref = 0.001  # presumably the step between stored rows — TODO confirm against the .mat file
Time = 60
# round() instead of floor(): Time / h_ref is a float quotient and can land just
# below an integer, which floor() would turn into an off-by-one step count.
N_steps = int(round(Time / h_ref)) + 1
t = np.expand_dims(np.linspace(0, Time, N_steps, endpoint=True, dtype=np.float64), axis=-1)
t = t/(t[-1])*3  # rescale the time grid to span [0, 3]

data = mat_data[:N_steps, :2]
s_data = mat_data[:N_steps, 2]

class ODEfunc(nn.Module):
	"""Neural vector field with GENERIC structure.

	forward() returns ``L dE + M dS``, where ``dE`` and ``dS`` are the input
	gradients of two learned scalar networks (``energy`` and ``entropy``).
	``L`` contracts a totally antisymmetric 3-tensor ``xi`` with ``dS``.
	``M`` contracts ``zeta[a,b,c,d] = sum_{m,n} Lf[a,b,m] D[m,n] Lf[c,d,n]``
	with ``dE`` on both sides. Here ``Lf`` is antisymmetric in its first two
	indices and ``D = friction_D @ friction_D^T`` is positive semi-definite.

	By construction:
	  * ``L dS`` and ``M dE`` vanish up to round-off;
	  * ``dE . dy/dt = 0``;
	  * ``dS . (M dS) >= 0``.
	"""
	def __init__(self, output_dim):
		super(ODEfunc, self).__init__()
		self.output_dim = output_dim
		self.dimD = args.D1
		self.dimD2 = args.D2
		# Create the raw tensors directly on `device` so they live next to the
		# energy/entropy networks below, which are moved there with .to(device).
		# nn.Parameter already sets requires_grad=True.
		self.friction_D = nn.Parameter(torch.randn((self.dimD, self.dimD2), device=device))
		self.friction_L = nn.Parameter(torch.randn((self.output_dim, self.output_dim, self.dimD), device=device)) # [alpha, beta, m] or [mu, nu, n]

		self.poisson_xi = nn.Parameter(torch.randn((self.output_dim, self.output_dim, self.output_dim), device=device))

		self.energy = utils.create_net(output_dim, 1, n_layers=args.lE, n_units=args.nE, nonlinear = nn.Tanh).to(device)
		self.entropy = utils.create_net(output_dim, 1, n_layers=args.lS, n_units=args.nS, nonlinear = nn.Tanh).to(device)
		self.NFE = 0  # number of forward() evaluations
		# Degeneracy terms recorded by forward(); None until the first call.
		self.LdS = None
		self.MdE = None

	def _poisson_tensor(self):
		"""Return the totally antisymmetric part of `poisson_xi` (shape [a, b, c])."""
		xi = self.poisson_xi
		return (xi - xi.permute(0,2,1) + xi.permute(1,2,0) -
			xi.permute(1,0,2) + xi.permute(2,0,1) - xi.permute(2,1,0))/6.0

	def _friction_tensor(self):
		"""Return zeta[a, b, c, d] = sum_{m,n} Lf[a,b,m] D[m,n] Lf[c,d,n]."""
		D = self.friction_D @ torch.transpose(self.friction_D, 0, 1)  # PSD [m, n]
		L = (self.friction_L - torch.transpose(self.friction_L, 0, 1))/2.0  # antisym in (a, b)
		return torch.einsum('abm,mn,cdn->abcd', L, D, L)

	def Poisson_matvec(self, dE, dS, xi=None):
		"""Return (L dE)[z, a] = sum_{b,c} xi[a,b,c] dE[z,b] dS[z,c].

		dE, dS: [batch, dim]. `xi` may be passed in to reuse a tensor already
		built by _poisson_tensor(); by default it is recomputed.
		"""
		if xi is None:
			xi = self._poisson_tensor()
		return torch.einsum('abc, zb, zc -> za', xi, dE, dS)

	def friction_matvec(self, dE, dS, zeta=None):
		"""Return (M dS)[z, a] = sum zeta[a,b,m,n] dE[z,b] dS[z,m] dE[z,n].

		dE, dS: [batch, dim]. `zeta` may be passed in to reuse a tensor already
		built by _friction_tensor(); by default it is recomputed.
		"""
		if zeta is None:
			zeta = self._friction_tensor()
		return torch.einsum('abmn,zb,zm,zn->za', zeta, dE, dS, dE)

	def get_penalty(self):
		"""Return (L dS, M dE) from the most recent forward(); (None, None) before it."""
		return self.LdS, self.MdE

	def forward(self, t, y):
		"""Evaluate dy/dt at state y ([batch, dim]); `t` is unused (autonomous field)."""
		E = self.energy(y)
		S = self.entropy(y)

		# create_graph=True so the loss can backpropagate through these gradients.
		dE = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
		dS = torch.autograd.grad(S.sum(), y, create_graph=True)[0]

		# Build the structure tensors once per evaluation and reuse them below.
		xi = self._poisson_tensor()
		zeta = self._friction_tensor()

		output = self.Poisson_matvec(dE, dS, xi) + self.friction_matvec(dE, dS, zeta)
		self.NFE += 1

		# Degeneracy terms; these are zero up to round-off by construction.
		self.LdS = self.Poisson_matvec(dS, dS, xi)
		self.MdE = self.friction_matvec(dE, dE, zeta)
		return output


# Training split: the first `train_end` samples of the time grid and observed channels.
train_end = 20000
t_q = torch.tensor(t[:train_end, 0].copy())
# requires_grad on the state so the vector field can differentiate E and S w.r.t. it.
data_qp = torch.tensor(data[:train_end].copy(), requires_grad=True)
# Initial guess for the unobserved third channel: a linear ramp from 0 to 1.
data_s = torch.linspace(0., 1., train_end, requires_grad=True).unsqueeze(-1)

data_qps = torch.cat((data_qp, data_s), axis=-1)


def get_batch(data_qps, batch_len=120, batch_size=20):
	"""Sample `batch_size` random windows of `batch_len` consecutive rows.

	Returns (batch_y0, batch_t, batch_y), each moved to `device`:
	  batch_y0 -- rows of `data_qps` at the window starts, shape (batch_size, D)
	  batch_t  -- the first `batch_len` entries of the global time grid `t_q`
	  batch_y  -- first two channels along each window, (batch_len, batch_size, 2)
	"""
	# Window starts are drawn without replacement from [0, train_end - batch_len).
	candidates = np.arange(train_end - batch_len, dtype=np.int64)
	starts = torch.from_numpy(np.random.choice(candidates, batch_size, replace=False))
	# rows[i, j] = starts[j] + i, so a single gather yields the (T, M, D) window stack.
	rows = torch.arange(batch_len, dtype=starts.dtype).unsqueeze(1) + starts.unsqueeze(0)
	windows = data_qps[rows][:, :, :2]
	initial = data_qps[starts]
	times = t_q[:batch_len]
	return initial.to(device), times.to(device), windows.to(device)

# Vector field on the 3-channel state (two observed channels + the extra channel).
# .to(device) also relocates any parameter that was not created on `device`.
odefunc = ODEfunc(3).to(device)

best_loss = 1e30
best_loss_outer = 1e30
params = odefunc.parameters()
optimizer = optim.Adamax(params, lr=1e-2)
# Validation rollout starts from the first stored sample (all three channels);
# requires_grad because the vector field differentiates E and S w.r.t. the state.
x0 = mat_data[:1,:3].copy()
x0 = torch.tensor(x0, requires_grad = True).to(device)
t_val = torch.tensor(t[:train_end*2].copy()).to(device).squeeze()
# Keep the reference on the same device as the rollout it is compared against.
data_val = torch.tensor(data[:train_end*2].copy()).to(device)
frame = 0


def _save_validation_figure(t_plot, ref_qp, pred_qps, save_file):
	"""Save a 1x3 panel figure comparing reference and predicted trajectories.

	Panels 0 and 1 overlay the reference (black) and the prediction (cyan,
	dashed). Panel 2 shows only the predicted third channel, which has no
	reference. All inputs are host-side NumPy arrays.
	"""
	fig = plt.figure(figsize=(12, 4))
	for i in range(3):
		ax = fig.add_subplot(1, 3, i + 1)
		if i < 2:
			ax.plot(t_plot, ref_qp[:, i], lw=2, color='k')
		ax.plot(t_plot, pred_qps[:, i], lw=2, color='c', ls='--')
	fig.tight_layout()
	fig.savefig(save_file)
	plt.close(fig)


for itr in range(1, 30001):
	optimizer.zero_grad()
	batch_y0, batch_t, batch_y = get_batch(data_qps)
	# pred_y: (batch_len, batch_size, 3); only the first two channels are observed.
	pred_y = odeint(odefunc, batch_y0, batch_t, rtol=1e-5, atol=1e-6, method='dopri5')
	# L1 error summed over the feature axis (-1), then averaged over time and batch,
	# mirroring the per-sample reduction of the validation metric below.
	loss = torch.mean(torch.abs(pred_y[:, :, :2] - batch_y).sum(axis=-1))
	print(itr, loss.item())

	loss.backward()
	optimizer.step()

	if itr % 500 == 0:
		# The vector field calls torch.autograd.grad internally, so this rollout
		# cannot run under no_grad. Detach immediately so its autograd graph is
		# released instead of being kept alive by pred_x until the next validation.
		pred_x = odeint(odefunc, x0, t_val, rtol=1e-5, atol=1e-6, method='dopri5').squeeze().detach()
		val_loss = torch.mean((pred_x[:, :2] - data_val.to(pred_x.device)).pow(2).sum(axis=1)).item()
		print('Iter {:04d} | Total Loss {:.6f}'.format(itr, val_loss))
		# Replace the training third channel with the model's current rollout.
		data_s = pred_x[:train_end, 2:].clone().to(data_qp.device).requires_grad_(True)
		data_qps = torch.cat((data_qp, data_s), axis=-1)

		if best_loss > val_loss:
			print('saving ode...', val_loss)
			torch.save({
				'state_dict': odefunc.state_dict(),
				}, ckpt_path)
			best_loss = val_loss
		print(val_loss)

		save_file = os.path.join(fig_save_path, "image_{:03d}.png".format(frame))
		_save_validation_figure(t_val.cpu().numpy(), data_val.cpu().numpy(), pred_x.cpu().numpy(), save_file)
		frame += 1
