


import torch
from torch import nn
import device
import sys

# Get cpu or gpu device for training.
#device = "cuda" if torch.cuda.is_available() else "cpu"
#print(f"Using {device} device")

# Define model
class neuralEQ(nn.Module):
	"""Fully-connected equalizer network for NRZ or PAM4 signals.

	The architecture is chosen at construction time:

	* ``mod == 'nrz'``: a single MLP ``self.nn`` with two hidden layers of
	  width 32 (ReLU) and ``outSize`` outputs. ``nnSel`` does not affect the
	  architecture in this mode.
	* any other ``mod`` (the PAM4 path):

	  * ``nnSel == 0``: a single MLP ``self.nn`` with two hidden layers of
	    width 512 (ReLU) and ``outSize`` outputs.
	  * ``nnSel == 1``: four independent MLPs ``self.nn0`` .. ``self.nn3``,
	    each with two hidden layers of width 512 and ONE output. Their
	    outputs are concatenated along dim 1, so this variant always emits
	    4 columns; ``outSize`` is not used.

	Args:
		inSize: number of input features (last dimension of ``x``).
		outSize: number of outputs of the single-MLP variants.
		mod: ``'nrz'`` selects the small NRZ network; anything else selects
			the PAM4 networks.
		nnSel: PAM4 variant selector: 0 (single MLP) or 1 (four one-output
			heads).

	Raises:
		ValueError: if ``mod`` is not ``'nrz'`` and ``nnSel`` is not 0 or 1.
	"""

	# Hidden-layer widths of the two architectures.
	_NRZ_HIDDEN = 32
	_PAM4_HIDDEN = 512
	# Attribute names of the per-output heads used when nnSel == 1. Kept as
	# explicit names so state_dict keys are nn0.* .. nn3.*.
	_HEAD_NAMES = ('nn0', 'nn1', 'nn2', 'nn3')

	@staticmethod
	def _mlp(inSize, hidden, outSize):
		"""Build ``Linear -> ReLU -> Linear -> ReLU -> Linear``.

		The Linear layers sit at Sequential indices 0, 2 and 4, so parameter
		names are ``<attr>.0.*``, ``<attr>.2.*`` and ``<attr>.4.*``.
		"""
		return nn.Sequential(
			nn.Linear(inSize, hidden),
			nn.ReLU(),
			nn.Linear(hidden, hidden),
			nn.ReLU(),
			nn.Linear(hidden, outSize, bias=True),
		)

	def __init__(self, inSize, outSize, mod, nnSel=0):
		super().__init__()
		self.nnSel = nnSel
		self.mod_ = mod
		if mod == 'nrz':
			self.nn = self._mlp(inSize, self._NRZ_HIDDEN, outSize)
		elif nnSel == 0:
			self.nn = self._mlp(inSize, self._PAM4_HIDDEN, outSize)
		elif nnSel == 1:
			# Heads are created one at a time in order nn0..nn3, so parameter
			# initialisation draws from the RNG in that order.
			for name in self._HEAD_NAMES:
				setattr(self, name, self._mlp(inSize, self._PAM4_HIDDEN, 1))
		else:
			# Without this check the module would have no layers at all and
			# forward() would fail with an obscure UnboundLocalError.
			raise ValueError(
				f"nnSel must be 0 or 1 for mod={mod!r}, got {nnSel!r}"
			)

	def forward(self, x):
		"""Run the equalizer on ``x``.

		Args:
			x: input tensor whose last dimension is ``inSize``; presumably
				``(batch, inSize)`` -- TODO confirm with callers. The four-head
				variant concatenates along dim 1, so it needs at least 2-D input.

		Returns:
			For 2-D ``x``: ``(batch, outSize)`` for the single-MLP variants,
			``(batch, 4)`` for the four-head PAM4 variant.
		"""
		# The NRZ mode only ever builds ``self.nn``, whatever nnSel says, so the
		# head path is taken only for the PAM4 nnSel == 1 model.
		if self.mod_ != 'nrz' and self.nnSel == 1:
			heads = [getattr(self, name)(x) for name in self._HEAD_NAMES]
			return torch.cat(heads, dim=1)
		return self.nn(x)

