import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import lib.utils as utils
#####################################################################################################

class ODEfunc_SPNN(nn.Module):
	"""Right-hand side ``dy/dt = L dE(y) + M dS(y)`` with constant ``L`` and ``M``.

	``E`` (energy) and ``S`` (entropy) are scalar networks built by
	``utils.create_net``.  ``L`` is the skew-symmetric part of
	``Poisson_matrix`` and ``M = friction_vector @ friction_vector^T`` is
	symmetric positive semi-definite.

	Args:
		dim: size of the state vector (last axis of ``y``).
		lE, lS: ``n_layers`` forwarded to ``utils.create_net`` for the
			energy / entropy network.
		nE, nS: ``n_units`` forwarded to ``utils.create_net`` for the
			energy / entropy network.
		device: device on which the networks and the matrix parameters are
			created.

	Attributes:
		NFE: number of ``forward`` evaluations performed so far.
		LdS, MdE: ``L dS`` and ``M dE`` from the most recent ``forward`` call
			(``None`` until ``forward`` has run).  Presumably used as a soft
			penalty by the training loss -- confirm against the caller.
	"""

	def __init__(self, dim, lE, lS, nE, nS, device = torch.device("cpu")):
		super(ODEfunc_SPNN, self).__init__()
		self.dim = dim
		self.energy = utils.create_net(dim, 1, n_layers=lE, n_units=nE, nonlinear = nn.Tanh).to(device)
		self.entropy = utils.create_net(dim, 1, n_layers=lS, n_units=nS, nonlinear = nn.Tanh).to(device)
		self.D1 = dim
		self.D2 = dim
		# Create the matrices directly on ``device`` so they live next to the
		# networks; nn.Parameter already enables requires_grad.
		self.friction_vector = nn.Parameter(torch.randn((self.D1, self.D2), device=device))
		self.Poisson_matrix = nn.Parameter(torch.randn((self.D1, self.D1), device=device))

		self.NFE = 0
		# Filled in by forward(); defined here so get_penalty() never raises.
		self.LdS = None
		self.MdE = None

	def Poisson_energy_matvec(self, y):
		"""Apply the skew-symmetric matrix ``L`` to every row of ``y`` ([batch, dim])."""
		L = (self.Poisson_matrix - torch.transpose(self.Poisson_matrix, 0, 1))/2.0
		return torch.einsum('ab,cb->ca', L, y)

	def friction_entropy_matvec(self, y):
		"""Apply the PSD matrix ``M = F F^T`` to every row of ``y`` ([batch, dim])."""
		D = self.friction_vector @ torch.transpose(self.friction_vector, 0, 1)
		return torch.einsum('ab,cb->ca', D, y)

	def get_penalty(self):
		"""Return ``(L dS, M dE)`` of the last ``forward`` call, or ``(None, None)``."""
		return self.LdS, self.MdE

	def forward(self, t, y):
		"""Evaluate ``L dE + M dS`` at state ``y``.

		Args:
			t: time; accepted for ODE-solver compatibility and not used.
			y: state batch of shape [batch, dim].

		Returns:
			Tensor with the same shape as ``y``.
		"""
		# The gradients of E and S w.r.t. the state are needed even when the
		# caller disabled autograd or passed a tensor that does not track
		# gradients, so make sure a graph through ``y`` exists.
		with torch.enable_grad():
			if not y.requires_grad:
				y = y.detach().requires_grad_(True)

			E = self.energy(y)
			S = self.entropy(y)

			# create_graph=True keeps the result differentiable w.r.t. the
			# network weights so the output can be trained through.
			dE = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
			dS = torch.autograd.grad(S.sum(), y, create_graph=True)[0]

			LdE = self.Poisson_energy_matvec(dE)
			MdS = self.friction_entropy_matvec(dS)
			output = LdE + MdS

			# compute penalty
			self.LdS = self.Poisson_energy_matvec(dS)
			self.MdE = self.friction_entropy_matvec(dE)

		self.NFE = self.NFE + 1
		return output

class ODEfunc_GNODEv1(nn.Module):
	"""Right-hand side ``dy/dt = L(dS) dE + M(dE) dS`` built from learned tensors.

	The reversible part contracts a totally antisymmetric 3-tensor ``xi`` with
	``dE`` and ``dS``; the irreversible part contracts a 4-tensor ``zeta``
	(assembled from ``friction_L`` and ``friction_D``) with ``dE``, ``dS`` and
	``dE``.  Because of the antisymmetries, ``Poisson_matvec(dS, dS)`` and
	``friction_matvec(dE, dE)`` vanish up to round-off.

	Args:
		dim: size of the state vector (last axis of ``y``).
		lE, lS: ``n_layers`` forwarded to ``utils.create_net`` for the
			energy / entropy network.
		nE, nS: ``n_units`` forwarded to ``utils.create_net`` for the
			energy / entropy network.
		D1, D2: shape of ``friction_D``; ``D1`` is also the last axis of
			``friction_L``.
		device: device on which the networks and tensor parameters are created.

	Attributes:
		NFE: number of ``forward`` evaluations performed so far.
	"""

	def __init__(self, dim, lE, lS, nE, nS, D1, D2, device = torch.device("cpu")):
		super(ODEfunc_GNODEv1, self).__init__()
		self.dim = dim
		self.D1 = D1
		self.D2 = D2
		# Parameters are created on ``device`` so they sit next to the networks.
		self.friction_D = nn.Parameter(torch.randn((self.D1, self.D2), device=device))
		self.friction_L = nn.Parameter(torch.randn((self.dim, self.dim, self.D1), device=device)) # [alpha, beta, m] or [mu, nu, n]

		self.poisson_xi = nn.Parameter(torch.randn((self.dim, self.dim, self.dim), device=device))

		self.energy = utils.create_net(dim, 1, n_layers=lE, n_units=nE, nonlinear = nn.Tanh).to(device)
		self.entropy = utils.create_net(dim, 1, n_layers=lS, n_units=nS, nonlinear = nn.Tanh).to(device)
		self.NFE = 0

		# Gradients from the latest forward(); get_penalty() evaluates the
		# penalty terms from them on demand instead of on every solver step.
		self._dE = None
		self._dS = None

	def Poisson_matvec(self, dE, dS):
		"""Return ``xi[a,b,c] dE[b] dS[c]`` per batch row; inputs are [batch, dim]."""
		# xi [alpha, beta, gamma]: signed average over all six index
		# permutations, i.e. the totally antisymmetric part of poisson_xi.
		p = self.poisson_xi
		xi = (p - p.permute(0,2,1) + p.permute(1,2,0) -
			p.permute(1,0,2) + p.permute(2,0,1) - p.permute(2,1,0))/6.0

		# dE and dS [batch, alpha]
		LdE = torch.einsum('abc, zb, zc -> za', xi, dE, dS)
		return LdE

	def friction_matvec(self, dE, dS):
		"""Return ``zeta[a,b,m,n] dE[b] dS[m] dE[n]`` per batch row; inputs are [batch, dim]."""
		# D [m,n] is symmetric PSD; L [alpha,beta,m] is antisymmetrised in its
		# first two indices.
		D = self.friction_D @ torch.transpose(self.friction_D, 0, 1)
		L = (self.friction_L - torch.transpose(self.friction_L, 0, 1))/2.0
		zeta = torch.einsum('abm,mn,cdn->abcd', L, D, L) # zeta [alpha, beta, mu, nu]
		MdS = torch.einsum('abmn,zb,zm,zn->za', zeta, dE, dS, dE)
		return MdS

	def get_penalty(self):
		"""Return ``(L dS, M dE)`` for the state of the last ``forward`` call.

		Both terms are zero by construction (up to round-off).  Returns
		``(None, None)`` when ``forward`` has not been called yet.
		"""
		dE = getattr(self, "_dE", None)
		dS = getattr(self, "_dS", None)
		if dE is None or dS is None:
			return None, None
		self.LdS = self.Poisson_matvec(dS, dS)
		self.MdE = self.friction_matvec(dE, dE)
		return self.LdS, self.MdE

	def forward(self, t, y):
		"""Evaluate the right-hand side at state ``y``.

		Args:
			t: time; accepted for ODE-solver compatibility and not used.
			y: state batch of shape [batch, dim].

		Returns:
			Tensor with the same shape as ``y``.
		"""
		# The state gradients are required even if the caller disabled
		# autograd or handed in a tensor that does not track gradients.
		with torch.enable_grad():
			if not y.requires_grad:
				y = y.detach().requires_grad_(True)

			E = self.energy(y)
			S = self.entropy(y)

			dE = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
			dS = torch.autograd.grad(S.sum(), y, create_graph=True)[0]

			LdE = self.Poisson_matvec(dE, dS)
			MdS = self.friction_matvec(dE, dS)
			output = LdE + MdS

		self.NFE = self.NFE + 1
		self._dE = dE
		self._dS = dS
		return output

class ODEfunc_GFINN(nn.Module):
	"""Right-hand side ``dy/dt = L(y) dE + M(y) dS`` with state-dependent ``L`` and ``M``.

	``L = B_S^T sigma_L B_S`` with ``sigma_L`` skew-symmetric and
	``M = B_E^T sigma_M B_E`` with ``sigma_M`` symmetric PSD.  Row ``k`` of
	``B_S`` (resp. ``B_E``) is ``dS @ Q_k`` (resp. ``dE @ Q_k``) for a
	skew-symmetric ``Q_k``, so ``L dS`` and ``M dE`` vanish up to round-off.

	Args:
		dim: size of the state vector (last axis of ``y``).
		lE, lS: ``n_layers`` forwarded to ``utils.create_net``; ``lE`` is used
			for the energy net and ``sigCompM``, ``lS`` for the entropy net
			and ``sigCompL``.
		nE, nS: ``n_units`` forwarded to ``utils.create_net`` (same pairing).
		K: number of skew-symmetric basis matrices; ``sigCompL`` /
			``sigCompM`` output ``K**2`` values per sample.
		device: device on which the networks and parameters are created.

	Attributes:
		LdS, MdE: ``dS @ L`` and ``dE @ M`` of the most recent ``forward``
			call, each of shape [batch, 1, dim] (``None`` before the first
			call).
	"""

	def __init__(self, dim, lE, lS, nE, nS, K, device = torch.device("cpu")):
		super(ODEfunc_GFINN, self).__init__()
		self.dim = dim
		self.K = K

		self.energy = utils.create_net(dim, 1, n_layers=lE, n_units=nE, nonlinear = nn.Tanh).to(device)
		self.entropy = utils.create_net(dim, 1, n_layers=lS, n_units=nS, nonlinear = nn.Tanh).to(device)

		self.sigCompL = utils.create_net(dim, self.K**2, n_layers=lS, n_units=nS, nonlinear = nn.Tanh).to(device)
		self.sigCompM = utils.create_net(dim, self.K**2, n_layers=lE, n_units=nE, nonlinear = nn.Tanh).to(device)

		# Created on ``device`` so they sit next to the networks.
		self.xiL = torch.nn.Parameter(torch.randn([self.K, self.dim, self.dim], device=device)*0.1)
		self.xiM = torch.nn.Parameter(torch.randn([self.K, self.dim, self.dim], device=device)*0.1)

		# Filled in by forward(); defined here so get_penalty() never raises.
		self.LdS = None
		self.MdE = None

	def _skew_rows(self, xi_raw, grad):
		"""Return ``B`` of shape [batch, K, dim] with ``B[z, k] = grad[z] @ Q_k``.

		``Q_k`` is the skew-symmetric matrix built from the strict upper
		triangle of ``xi_raw[k]``.
		"""
		upper = torch.triu(xi_raw, diagonal = 1)
		skew = upper - torch.transpose(upper, -1, -2)
		return torch.einsum('zj,kji->zki', grad, skew)

	def compute_dS_L(self, y):
		"""Return ``(dS, L)``: entropy gradient [batch, dim] and skew ``L`` [batch, dim, dim].

		``y`` must require grad (``forward`` takes care of this).
		"""
		sigComp = self.sigCompL(y).reshape(-1, self.K, self.K)
		sigma = sigComp - torch.transpose(sigComp, -1, -2)

		S = self.entropy(y)
		dS = torch.autograd.grad(S.sum(), y, create_graph=True)[0]
		B = self._skew_rows(self.xiL, dS)
		L = torch.transpose(B,-1,-2) @ sigma @ B
		return dS, L

	def compute_dE_M(self, y):
		"""Return ``(dE, M)``: energy gradient [batch, dim] and PSD ``M`` [batch, dim, dim].

		``y`` must require grad (``forward`` takes care of this).
		"""
		sigComp = self.sigCompM(y).reshape(-1, self.K, self.K)
		sigma = sigComp @ torch.transpose(sigComp, -1, -2)

		E = self.energy(y)
		dE = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
		B = self._skew_rows(self.xiM, dE)
		M = torch.transpose(B,-1,-2) @ sigma @ B
		return dE, M

	def get_penalty(self):
		"""Return ``(dS @ L, dE @ M)`` of the last ``forward`` call, or ``(None, None)``."""
		return self.LdS, self.MdE

	def forward(self, t, y):
		"""Evaluate the right-hand side at state ``y``.

		Args:
			t: time; accepted for ODE-solver compatibility and not used.
			y: state batch of shape [batch, dim].

		Returns:
			Tensor of shape [batch, dim] (also for ``batch == 1``).
		"""
		# The state gradients are required even if the caller disabled
		# autograd or handed in a tensor that does not track gradients.
		with torch.enable_grad():
			if not y.requires_grad:
				y = y.detach().requires_grad_(True)

			dE, M = self.compute_dE_M(y)
			dS, L = self.compute_dS_L(y)

			# Row vectors [batch, 1, dim] so they can left-multiply L and M.
			dE = dE.unsqueeze(1)
			dS = dS.unsqueeze(1)

			self.MdE = dE @ M
			self.LdS = dS @ L

			# L is skew, hence -(dE @ L) is (L dE)^T; M is symmetric, hence
			# dS @ M is (M dS)^T.  Only the inserted axis is squeezed so that
			# a batch of one (or dim == 1) keeps its [batch, dim] shape.
			output = (dS @ M - dE @ L).squeeze(1)
		return output
		
class ODEfunc_GNODEv2(nn.Module):
	"""Right-hand side ``dy/dt = L(y) dE + M(y) dS`` with network-valued operators.

	The reversible part uses a state-dependent skew-symmetric matrix ``A(y)``
	plus a correction term; the irreversible part uses matrix-valued networks
	``B(y)`` and ``C(y)``.  The formulas are arranged so that
	``poisson_product(y, dS, dS)`` and ``friction_product(y, dE, dE)`` vanish
	up to round-off.

	Args:
		dim: size of the state vector (last axis of ``y``).
		lE, lS, lA, lB, lD: ``n_layers`` forwarded to ``utils.create_net`` for
			the energy, entropy, ``poisson_A``, ``friction_B`` and
			``friction_C`` networks respectively.
		nE, nS, nA, nB, nD: ``n_units`` for the same networks.
		D: number of rows of ``B(y)`` and ``C(y)``.
		C2: number of columns of ``C(y)``.
		device: device the networks are moved to.

	Attributes:
		NFE: number of ``forward`` evaluations performed so far.
		LdS, MdE: penalty terms of the most recent ``forward`` call.
	"""

	def __init__(self, dim, lE, lS, lA, lB, lD, nE, nS, nA, nB, nD, D, C2, device = torch.device("cpu")):
		super(ODEfunc_GNODEv2, self).__init__()
		self.dim = dim
		self.D = D
		self.C2 = C2

		# Build strictly lower triangular mat-valued func for reversible part
		self.poisson_A = utils.create_net(
			dim, dim * (dim-1) // 2,
			n_layers=lA, n_units=nA,
			nonlinear = nn.Tanh).to(device)
		# (row, col) positions of the strictly lower triangle, shape [2, dim*(dim-1)/2]
		self.poisson_A_idx = torch.tril_indices(
			row=dim, col=dim, offset=-1)

		# Build mat-valued funcs for irreversible part
		self.friction_C = utils.create_net(
			dim, self.D * self.C2,
			n_layers=lD, n_units=nD,
			nonlinear = nn.Tanh).to(device)
		self.friction_B = utils.create_net(
			dim, self.D * dim,
			n_layers=lB, n_units=nB,
			nonlinear = nn.Tanh).to(device)

		# Build energy and entropy networks
		self.energy = utils.create_net(dim, 1, n_layers=lE,
				n_units=nE, nonlinear = nn.Tanh).to(device)
		self.entropy = utils.create_net(dim, 1, n_layers=lS,
				n_units=nS, nonlinear = nn.Tanh).to(device)
		self.NFE = 0

	def poisson_product(self, y, dE, dS):
		"""Return ``A dE + ((dE . A dS) dS - (dE . dS) A dS) / |dS|^2`` per batch row.

		Args:
			y: state batch [z, i], fed to the ``poisson_A`` network.
			dE, dS: gradient batches [z, i].

		Returns:
			Tensor [z, i].  NOTE: divides by ``|dS|^2`` and is therefore not
			finite for rows where ``dS`` is exactly zero.
		"""
		# A is [i,j], dE and dS are [z,i] where z is batch dim
		bdim = y.shape[0]
		coeffs = self.poisson_A(y)
		rows, cols = self.poisson_A_idx.to(coeffs.device)
		# Allocate with the dtype/device of the network output so float64 and
		# non-CPU inputs work.
		Amat = coeffs.new_zeros(bdim, self.dim, self.dim)
		Amat[:, rows, cols] = coeffs
		A      = Amat - Amat.transpose(1,2)
		A_T    = torch.transpose(A, 1,2)
		AdE    = torch.bmm(dE.unsqueeze(1), A_T).squeeze(1)  # [z,i]
		AdS    = torch.bmm(dS.unsqueeze(1), A_T).squeeze(1)  # [z,i]
		correc = (dE * AdS).sum(dim=-1).view(-1,1) * dS \
				  - (dE * dS).sum(dim=-1).view(-1,1) * AdS
		ndS2   = (dS**2).sum(dim=-1).view(-1,1)
		LdE    = AdE + correc / ndS2
		return LdE

	def friction_product(self, y, dE, dS):
		"""Return ``M(y) dS`` where ``M`` is built from ``B(y)``, ``C(y)`` and ``dE``.

		Args:
			y: state batch [z, i], fed to ``friction_B`` / ``friction_C``.
			dE, dS: gradient batches [z, i].

		Returns:
			Tensor [z, i].  NOTE: divides by ``|dE|^2`` and is therefore not
			finite for rows where ``dE`` is exactly zero.
		"""
		bdim = y.shape[0]
		# C is [m,n], B is [m,i]
		C      = self.friction_C(y).view(bdim, self.D, self.C2)
		B      = self.friction_B(y).view(bdim, self.D, self.dim)
		D      = torch.bmm(C, torch.transpose(C, 1,2))  # [z,m,m], symmetric PSD
		B_T    = torch.transpose(B, 1,2)
		BdotdE = torch.bmm(dE.unsqueeze(1), B_T).squeeze(1)  # [z,m]
		BdotdS = torch.bmm(dS.unsqueeze(1), B_T).squeeze(1)  # [z,m]
		ndE2   = (dE**2).sum(dim=-1)
		# Remove from every row of B its component along dE.
		BdE    = B - BdotdE.unsqueeze(-1) * dE.unsqueeze(1)/ndE2.view(-1,1,1)

		# BdE is [z,m,i]
		dEdS   = (dE * dS).sum(dim=-1)
		BdEdS  = BdotdS - BdotdE * dEdS.view(-1,1) / ndE2.view(-1,1)
		MdS    = (torch.matmul(BdE.permute(0,2,1), D) * BdEdS.unsqueeze(1)).sum(dim=-1)
		return MdS

	def get_penalty(self, y, dE, dS):
		"""Return ``(L dS, M dE)`` at state ``y``; both are zero up to round-off."""
		# compute penalty
		LdS = self.poisson_product(y, dS, dS)
		MdE = self.friction_product(y, dE, dE)
		return LdS, MdE

	def forward(self, t, y):
		"""Evaluate the right-hand side at state ``y``.

		Args:
			t: time; accepted for ODE-solver compatibility and not used.
			y: state batch of shape [z, i].

		Returns:
			Tensor with the same shape as ``y``.
		"""
		# The state gradients are required even if the caller disabled
		# autograd or handed in a tensor that does not track gradients.
		with torch.enable_grad():
			if not y.requires_grad:
				y = y.detach().requires_grad_(True)

			#dE, dS are [z,i]
			E   = self.energy(y)
			S   = self.entropy(y)
			dE  = torch.autograd.grad(E.sum(), y, create_graph=True)[0]
			dS  = torch.autograd.grad(S.sum(), y, create_graph=True)[0]
			LdE = self.poisson_product(y, dE, dS)
			MdS = self.friction_product(y, dE, dS)
			self.LdS, self.MdE = self.get_penalty(y, dE, dS)
			output = LdE + MdS

		self.NFE += 1
		return output

