import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import lib.utils as utils

#####################################################################################################

class ODEfuncRegGC(nn.Module):
	"""ODE right-hand side built from one sub-network per state variable.

	``nvar = 3 + aug_dim`` networks are created with ``utils.create_net_reg``.
	Each one receives the full state ``y`` and is built to emit
	``dim // nvar`` values. ``forward`` interleaves the per-network outputs
	into one flat vector per batch element.
	"""

	def __init__(self, dim, nlayer, nunit, aug_dim=0, drop_rate=0.0, device = torch.device("cpu")):
		"""
		Args:
			dim: size of the state vector; must be divisible by ``3 + aug_dim``
				so the concatenated output has exactly ``dim`` features.
			nlayer: forwarded as ``n_layers`` to ``utils.create_net_reg``.
			nunit: forwarded as ``n_units`` to ``utils.create_net_reg``.
			aug_dim: number of extra variables on top of the base 3.
			drop_rate: accepted for interface compatibility; currently unused.
			device: device every sub-network is moved to.

		Raises:
			ValueError: if ``dim`` is not divisible by ``3 + aug_dim``.
		"""
		super(ODEfuncRegGC, self).__init__()
		self.nvar = 3 + aug_dim
		# Each sub-network emits dim // nvar values. If dim is not a multiple of
		# nvar, forward() would return fewer than dim features. ODE solvers
		# (e.g. torchdiffeq's odeint) expect the derivative to match the state's
		# shape, so fail early with a clear message instead.
		if dim % self.nvar != 0:
			raise ValueError(
				f"dim ({dim}) must be divisible by 3 + aug_dim ({self.nvar})")
		self.gradient_nets = nn.ModuleList([
			utils.create_net_reg(dim, dim // self.nvar, n_layers=nlayer, n_units=nunit, nonlinear=nn.Tanh).to(device)
			for _ in range(self.nvar)
		])
		# Function-evaluation counter. It is not incremented inside this class;
		# presumably the caller maintains it (TODO confirm).
		self.NFE = 0

	def forward(self, t, y):
		"""Evaluate the right-hand side at state ``y``.

		Args:
			t: current time; ignored (the dynamics do not depend on t).
			y: state tensor whose first dimension is treated as the batch
				dimension. Presumably shaped (batch, dim); TODO confirm.

		Returns:
			Tensor reshaped to (batch, -1). When every sub-network returns a
			2-D tensor of shape (batch, k), the result has shape
			(batch, k * nvar), and column ``j * nvar + i`` holds network
			``i``'s ``j``-th output (outputs are interleaved, not blocked).
		"""
		batch_size = y.shape[0]
		# unsqueeze(-1) followed by cat along dim 2 places the per-variable
		# outputs on a trailing axis (for 2-D outputs); reshape then flattens
		# that axis innermost, which interleaves the variables.
		per_var = [net(y).unsqueeze(-1) for net in self.gradient_nets]
		return torch.cat(per_var, dim=2).reshape(batch_size, -1)

class DDEfuncGC(nn.Module):
	"""Right-hand side that applies one network to the state.

	The network comes from ``utils.create_net`` with ``dim`` inputs and
	``dim`` outputs, ``Tanh`` nonlinearity, and is placed on ``device``.
	The time argument of ``forward`` is not used.
	"""

	def __init__(self, dim, nlayer, nunit, device=torch.device("cpu")):
		"""
		Args:
			dim: input and output width of the wrapped network.
			nlayer: forwarded as ``n_layers`` to ``utils.create_net``.
			nunit: forwarded as ``n_units`` to ``utils.create_net``.
			device: device the network is moved to.
		"""
		super().__init__()
		self.dim = dim
		# Function-evaluation counter; never updated by this class itself.
		self.NFE = 0
		net = utils.create_net(dim, dim, n_layers=nlayer, n_units=nunit, nonlinear=nn.Tanh)
		self.gradient_net = net.to(device)

	def forward(self, t, y):
		"""Return ``gradient_net(y)``.

		``t`` is accepted (presumably to match a solver's ``func(t, y)``
		signature) and ignored.
		"""
		return self.gradient_net(y)
