import torch
import numpy as np
import torch.nn as nn
import math
import torch.nn.functional as F
import sys

def choose_nonlinearity(name):
    """Return the activation callable registered under ``name``.

    Recognised names are 'tanh', 'relu', 'sigmoid', 'softplus', 'selu',
    'elu' and 'swish' (``x * sigmoid(x)``). Names are matched with ``==``
    in that order.

    Raises:
        ValueError: if ``name`` matches none of the recognised names.
    """
    candidates = (
        ('tanh', torch.tanh),
        ('relu', torch.relu),
        ('sigmoid', torch.sigmoid),
        ('softplus', torch.nn.functional.softplus),
        ('selu', torch.nn.functional.selu),
        ('elu', torch.nn.functional.elu),
        ('swish', lambda x: x * torch.sigmoid(x)),
    )
    for key, activation in candidates:
        if name == key:
            return activation
    raise ValueError("nonlinearity not recognized")


class Leap_Net_TB(torch.nn.Module):
    """Gravitational N-body vector field with a spectral mass estimator.

    ``forward`` evaluates the time derivative of a flattened N-body state
    (five numbers per body, read as ``[mass, px, py, vx, vy]`` by the
    indexing below) under mutual gravity with G = 1. ``predict_mass``
    converts a flattened trajectory into log-magnitude Fourier features,
    feeds them to ``MLP_Spec_m1`` and stores the result in ``self.r``.
    """

    def __init__(self, MLP_Spec_m1, MLP_Spec_m2,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 mean_ap=None, std_ap=None, mean_av=None, std_av=None):

        super(Leap_Net_TB, self).__init__()

        self.MLP_Spec_m1 = MLP_Spec_m1.to(device)
        self.MLP_Spec_m2 = MLP_Spec_m2.to(device)

        # Fixed constants (not used by the N-body dynamics in this class)
        self.mu = 4
        self.L = 4
        self.G = 10

        # Scale / offset constants; plu_output is added to the estimator output
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency DFT bins kept per channel
        self.cutoff_index = cutoff_index

        self.device = device

        # Stride between consecutive samples of one channel in the flat trajectory
        self.input_dim = input_dim

        # Normalization statistics (stored; not used in this class)
        self.mean_ap = mean_ap
        self.std_ap = std_ap

        self.mean_av = mean_av
        self.std_av = std_av

        # 1 -> predict_mass() runs the estimator; 0 -> keep the value given to reset_prediction()
        self.prediction = 1

        self.m1 = None
        self.m2 = None

        # Latest prediction and Fourier features. Initialized so the obtain_*
        # getters return None / [] instead of raising AttributeError before
        # the first predict_mass() or reset_prediction() call.
        self.r = None
        self.ss_list = []

    @staticmethod
    def _fft_real_imag(signal):
        """Unnormalized 1-D DFT of an ``[N, 2]`` (real, imag) tensor, returned as ``[N, 2]``.

        Uses the legacy ``torch.fft(x, 1, normalized=False)`` function when it
        exists (PyTorch < 1.8) and the ``torch.fft`` module otherwise; the
        legacy callable was removed in 1.8.
        """
        if callable(torch.fft):
            return torch.fft(signal, 1, normalized=False)
        spectrum = torch.fft.fft(torch.view_as_complex(signal.contiguous()))
        return torch.view_as_real(spectrum)

    def get_accelerations(self, state, epsilon=0):
        """Return the net gravitational acceleration on every body (G = 1).

        Args:
            state: ``[n_bodies, 5]`` tensor whose column 0 is mass and
                columns 1:3 are positions.
            epsilon: added to ``distance**3`` in the denominator.

        Returns:
            ``[n_bodies, 1, 2]`` tensor of accelerations.
        """
        net_accs = []
        for i in range(state.shape[0]):
            other_bodies = torch.cat((state[:i, :], state[i+1:, :]), axis=0)
            displacements = other_bodies[:, 1:3] - state[i, 1:3]
            distances = (displacements**2).sum(1, keepdims=True)**0.5
            masses = other_bodies[:, 0:1]
            pointwise_accs = masses * displacements / (distances**3 + epsilon)
            net_accs.append(pointwise_accs.sum(0, keepdims=True))
        net_accs = torch.stack(net_accs, axis=0)
        # retain_grad() raises on tensors that do not require grad (e.g. a
        # plain inference state), so only request it when a graph exists.
        if net_accs.requires_grad:
            net_accs.retain_grad()

        return net_accs

    def forward(self, t, x):
        """Return the flattened time derivative of the N-body state ``x`` (``t`` unused)."""
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # One row of five values per body
            x = x.view(-1, 5)
            deriv = torch.zeros_like(x)
            deriv[:, 1:3] = x[:, 3:5]  # d(position)/dt = velocity
            deriv[:, 3:5] = self.get_accelerations(x).squeeze()
            return deriv.flatten()

    def reset_prediction(self, prediction_enable, r=None):
        """Enable (1) or disable (0) the estimator; when disabling, fix ``self.r`` to ``r``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.r = r

    def obtain_predict_mass(self):
        """Return the latest mass prediction (None before any prediction)."""
        return self.r

    def obtain_FourierFeat(self):
        """Return the Fourier feature list built by the last ``predict_mass`` call."""
        return self.ss_list

    def predict_mass(self, x):
        """Compute Fourier features of ``x`` and, if enabled, predict ``self.r``.

        ``x`` is a 1-D flattened trajectory; channel ``ff`` is read as
        ``x[ff::input_dim]`` (assumed interleaved layout — TODO confirm).
        For channels 0-3 the feature is ``log(sqrt(|X_k|^2 + 1e-5) + 1e-5)``
        over the first ``cutoff_index`` DFT bins ``X_k``. When
        ``self.prediction`` is truthy, ``self.r = MLP_Spec_m1(f0, f1, f2, f3)
        + plu_output``; a NaN result is replaced by ones.
        """
        # NOTE(review): overrides the constructor device on every call.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        state = torch.unsqueeze(x, 1).to(self.device)
        # Append a zero imaginary column: [N, 2] = (real, imag)
        state = torch.cat((state, torch.zeros((state.shape[0], 1)).to(self.device)), axis=1)
        self.ss_list = []
        for ff in [0, 1, 2, 3]:
            ss = state[ff::self.input_dim, :]
            if torch.isnan(ss).any():
                print('In predict_mass(), ss@' + str(ff) + ' is NaN')
            ss_fft = self._fft_real_imag(ss)
            if torch.isnan(ss_fft).any():
                print('In predict_mass(), ss_fft@' + str(ff) + ' is NaN')
            # Squared magnitude per bin
            ss_fft_mag = ss_fft[:, 0] ** 2 + ss_fft[:, 1] ** 2
            if torch.isnan(ss_fft_mag).any():
                print('In predict_mass(), ss_fft_mag_1@' + str(ff) + ' is NaN')
            ss_fft_mag = (torch.abs(ss_fft_mag[0:self.cutoff_index]) + 1e-5) ** 0.5
            if torch.isnan(ss_fft_mag).any():
                print('In predict_mass(), ss_fft_mag_2@' + str(ff) + ' is NaN')
            ss_fft_mag = torch.log(torch.abs(ss_fft_mag) + 1e-5)
            if torch.isnan(ss_fft_mag).any():
                print('In predict_mass(), ss_fft_mag_3@' + str(ff) + ' is NaN')
            self.ss_list.append(ss_fft_mag)

        ss_list_grad = torch.stack(self.ss_list, axis=0)
        if ss_list_grad.requires_grad:
            ss_list_grad.retain_grad()

        if self.prediction:
            self.r = self.MLP_Spec_m1(ss_list_grad[0], ss_list_grad[1],
                                      ss_list_grad[2], ss_list_grad[3]) + self.plu_output
            if torch.isnan(self.r).any():
                print('In predict_mass(), self.r is NaN')
                self.r = torch.ones_like(self.r).float().to(self.device)


class Leap_Net(torch.nn.Module):
    """Damped pendulum vector field whose damping ``mu`` is predicted from spectra.

    The ODE state ``x`` is ``[theta, theta_dot, *series]``. ``series`` is a
    flattened time series read with stride ``input_dim``; channels 0, 1 and 2
    are treated as cos(theta), sin(theta) and theta_dot. Log-magnitude
    spectra of those channels feed ``MLP_Spec_mu``. ``forward`` returns the
    pendulum derivative followed by zeros, so the series part of the state is
    left unchanged by the solver.
    """

    def __init__(self, MLP_Spec_mu, MLP_Spec_L,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 mean_ap=None, std_ap=None, mean_av=None, std_av=None):

        super(Leap_Net, self).__init__()

        self.MLP_Spec_mu = MLP_Spec_mu.to(device)
        self.MLP_Spec_L = MLP_Spec_L.to(device)

        # Initial system parameters; mu and L are overwritten by forward()/reset_prediction()
        self.mu = 4
        self.L = 4
        self.G = 10

        # Scale / offset constants; plu_output is added to the estimator output
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency DFT bins kept per channel
        self.cutoff_index = cutoff_index

        self.device = device

        # Stride between consecutive samples of one channel in the flat series
        self.input_dim = input_dim

        # Normalization statistics (stored; not used in this class)
        self.mean_ap = mean_ap
        self.std_ap = std_ap

        self.mean_av = mean_av
        self.std_av = std_av

        # 1 -> forward() predicts mu; 0 -> keep the values given to reset_prediction()
        self.prediction = 1

    @staticmethod
    def _fft_real_imag(signal):
        """Unnormalized 1-D DFT of an ``[N, 2]`` (real, imag) tensor, returned as ``[N, 2]``.

        Uses the legacy ``torch.fft(x, 1, normalized=False)`` function when it
        exists (PyTorch < 1.8) and the ``torch.fft`` module otherwise.
        """
        if callable(torch.fft):
            return torch.fft(signal, 1, normalized=False)
        spectrum = torch.fft.fft(torch.view_as_complex(signal.contiguous()))
        return torch.view_as_real(spectrum)

    def _log_magnitude(self, signal):
        """Return ``log|X_k|`` for the first ``cutoff_index`` DFT bins of ``signal``.

        The squared magnitude is clamped at 1e-24 before sqrt/log so that an
        exactly-zero bin yields a finite feature instead of ``-inf``.
        Non-zero bins are numerically unchanged.
        """
        spectrum = self._fft_real_imag(signal)
        power = spectrum[:, 0] ** 2 + spectrum[:, 1] ** 2
        magnitude = torch.clamp(power[0:self.cutoff_index], min=1e-24) ** 0.5
        return torch.log(magnitude)

    def forward(self, t, x):
        """Return ``dx/dt`` for the solver (``t`` is unused).

        ``theta' = theta_dot`` and ``theta_dot' = -mu * theta_dot - (G / L) * sin(theta)``.
        All remaining entries are zero.
        """
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # Pendulum angle and angular velocity, each of shape [1]
            x1 = x[0:1]
            x2 = x[1:2]

            # Build the (real, imag) series once: [N, 2] with zero imaginary part
            series = torch.unsqueeze(x[2:], 1).to(self.device)
            state = torch.cat((series, torch.zeros((series.shape[0], 1)).to(self.device)), axis=1)

            cos_theta_mag = self._log_magnitude(state[0::self.input_dim, :])
            sin_theta_mag = self._log_magnitude(state[1::self.input_dim, :])
            # theta_dot is scaled by a hard-coded 1/7 before the transform
            theta_dot_mag = self._log_magnitude(state[2::self.input_dim, :] / 7.0)

            if self.prediction:
                self.mu = self.MLP_Spec_mu(cos_theta_mag, sin_theta_mag, theta_dot_mag) + self.plu_output
                # L is fixed to 1.0; MLP_Spec_L is not used here
                self.L = 1.0

            dx1 = x2
            dx2 = -self.mu * x2 - (self.G / self.L) * torch.sin(x1)

            dxdt = torch.cat((dx1, dx2), axis=0)
            zero_vec = torch.zeros(x.shape[0] - 2).to(self.device)
            return torch.cat((dxdt, zero_vec), dim=0)

    def reset_prediction(self, prediction_enable, mu=None, L=None):
        """Enable (1) or disable (0) prediction; when disabling, fix ``mu`` and ``L``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.mu = mu
            self.L = L

class Leap_Net_MSD(torch.nn.Module):
    """Two-mass spring-damper vector field with a spectral estimate of ``cPS``.

    ``forward`` evaluates the 4-state ODE below. It is driven by the
    tabulated excitation ``(T, U, Up)``, sampled at the entry of ``T``
    nearest to ``t``. ``predict_para`` builds six log-magnitude spectra from
    a flattened trajectory: four state channels, plus first differences of
    channels 1 and 3. It then sets ``cPS = exp(MLP_Spec_cPS(...) + plu_output)``.
    """

    def __init__(self, MLP_Spec_cPS,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 parameter,
                 TrackExci,
                 mean_ap=None, std_ap=None, mean_av=None, std_av=None):

        super(Leap_Net_MSD, self).__init__()

        self.MLP_Spec_cPS = MLP_Spec_cPS.to(device)

        # ODE coefficients in the order [mBG, mCB, cPS, dPS, cSS, dSS]
        self.mBG = torch.tensor(parameter[0])
        self.mCB = torch.tensor(parameter[1])
        self.cPS = torch.tensor(parameter[2])
        self.dPS = torch.tensor(parameter[3])
        self.cSS = torch.tensor(parameter[4])
        self.dSS = torch.tensor(parameter[5])

        # Tabulated excitation: sample times T and the inputs U, Up at those times
        self.T = torch.tensor(TrackExci[0])
        self.U = torch.from_numpy(np.array(TrackExci[1]))
        self.Up = torch.from_numpy(np.array(TrackExci[2]))

        # Scale / offset constants; plu_output is added before exp() in predict_para()
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency DFT bins kept per channel
        self.cutoff_index = cutoff_index

        self.device = device

        # The trajectory is always read with 4 interleaved channels; the
        # input_dim argument is accepted but ignored.
        self.input_dim = 4

        # Fourier features from the last predict_para() call
        self.ss_list = []

        # 1 -> predict_para() runs the estimator. This attribute was previously
        # only set by reset_prediction(), so predict_para() raised AttributeError.
        self.prediction = 1

    @staticmethod
    def _fft_real_imag(signal):
        """Unnormalized 1-D DFT of an ``[N, 2]`` (real, imag) tensor, returned as ``[N, 2]``.

        Uses the legacy ``torch.fft(x, 1, normalized=False)`` function when it
        exists (PyTorch < 1.8) and the ``torch.fft`` module otherwise.
        """
        if callable(torch.fft):
            return torch.fft(signal, 1, normalized=False)
        spectrum = torch.fft.fft(torch.view_as_complex(signal.contiguous()))
        return torch.view_as_real(spectrum)

    def _log_spectrum(self, signal):
        """Return ``log(sqrt(|X_k|^2 + 1e-5))`` for the first ``cutoff_index`` DFT bins."""
        spectrum = self._fft_real_imag(signal)
        power = spectrum[:, 0] ** 2 + spectrum[:, 1] ** 2
        return torch.log((power[0:self.cutoff_index] + 1e-5) ** 0.5)

    def forward(self, t, x):
        """Return ``dx/dt`` for the state ``x = [x1, x2, x3, x4]`` at time ``t``."""
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            x1 = x[0:1]
            x2 = x[1:2]
            x3 = x[2:3]
            x4 = x[3:4]

            # Excitation sample nearest in time to t
            index = torch.argmin(torch.abs(self.T - t.to(self.device)))
            u = self.U[index]
            up = self.Up[index]

            dx1 = x2
            dx2 = (self.dSS*(x4-x2) + self.cSS*(x3-x1) - self.dPS*(x2-up) - self.cPS*(x1-u)) / self.mBG
            dx3 = x4
            dx4 = (-self.dSS*(x4-x2) - self.cSS*(x3-x1)) / self.mCB

            return torch.cat((dx1, dx2, dx3, dx4), axis=0)

    def reset_U(self, TrackExci):
        """Replace the excitation table ``(T, U, Up)``; ``U``/``Up`` must be tensors."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.T = torch.tensor(TrackExci[0]).to(self.device)
        self.U = TrackExci[1].to(self.device)
        self.Up = TrackExci[2].to(self.device)

    def reset_prediction(self, prediction_enable, para=None):
        """Enable (1) or disable (0) the estimator; when disabling, fix ``cPS`` to ``para``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.cPS = para

    def obtain_predict_para(self):
        """Return the current ``cPS``."""
        return self.cPS

    def obtain_FourierFeat(self):
        """Return the Fourier feature list built by the last ``predict_para`` call."""
        return self.ss_list

    def predict_para(self, x):
        """Build six spectral features from the flat trajectory ``x`` and update ``cPS``."""
        state = torch.unsqueeze(x, 1).to(self.device)
        # Append a zero imaginary column: [N, 2] = (real, imag)
        state = torch.cat((state, torch.zeros((state.shape[0], 1)).to(self.device)), axis=1)

        # Spectra of the four raw channels
        self.ss_list = [self._log_spectrum(state[ff::self.input_dim, :]) for ff in range(4)]

        # Spectra of the first differences of channels 1 and 3 (the original
        # code labels these the accelerations of the primary / secondary).
        for ch in (1, 3):
            channel = state[ch::self.input_dim, :]
            self.ss_list.append(self._log_spectrum(channel[1:] - channel[0:-1]))

        self.ss_list_grad = torch.stack(self.ss_list, axis=0)
        # retain_grad() raises when the features do not require grad
        if self.ss_list_grad.requires_grad:
            self.ss_list_grad.retain_grad()

        if self.prediction:
            self.cPS = torch.exp(self.MLP_Spec_cPS(self.ss_list_grad[0], self.ss_list_grad[1],
                                                   self.ss_list_grad[2], self.ss_list_grad[3],
                                                   self.ss_list_grad[4], self.ss_list_grad[5]) + self.plu_output)


class Leap_Net_MSD_Class(torch.nn.Module):
    """Two-mass spring-damper vector field with a classified ``cPS``.

    Same dynamics as the regression variant. ``predict_para`` feeds six
    log-magnitude spectra to ``MLP_Spec_cPS``, whose output is used as
    logits for a hard Gumbel-softmax choice between ``cPS = 4e3`` and
    ``cPS = 4e6``.
    """

    def __init__(self, MLP_Spec_cPS,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 parameter,
                 TrackExci,
                 mean_ap=None, std_ap=None, mean_av=None, std_av=None):

        super(Leap_Net_MSD_Class, self).__init__()

        self.MLP_Spec_cPS = MLP_Spec_cPS.to(device)

        # ODE coefficients in the order [mBG, mCB, cPS, dPS, cSS, dSS]
        self.mBG = torch.tensor(parameter[0])
        self.mCB = torch.tensor(parameter[1])
        self.cPS = torch.tensor(parameter[2])
        self.dPS = torch.tensor(parameter[3])
        self.cSS = torch.tensor(parameter[4])
        self.dSS = torch.tensor(parameter[5])

        # Tabulated excitation: sample times T and the inputs U, Up at those times
        self.T = torch.tensor(TrackExci[0])
        self.U = torch.from_numpy(np.array(TrackExci[1]))
        self.Up = torch.from_numpy(np.array(TrackExci[2]))

        # Scale / offset constants (not used by the classifier path below)
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency DFT bins kept per channel
        self.cutoff_index = cutoff_index

        self.device = device

        # The trajectory is always read with 4 interleaved channels; the
        # input_dim argument is accepted but ignored.
        self.input_dim = 4

        # Fourier features from the last predict_para() call
        self.ss_list = []

        # 1 -> predict_para() runs the classifier. This attribute was previously
        # only set by reset_prediction(), so predict_para() raised AttributeError.
        self.prediction = 1

    @staticmethod
    def _fft_real_imag(signal):
        """Unnormalized 1-D DFT of an ``[N, 2]`` (real, imag) tensor, returned as ``[N, 2]``.

        Uses the legacy ``torch.fft(x, 1, normalized=False)`` function when it
        exists (PyTorch < 1.8) and the ``torch.fft`` module otherwise.
        """
        if callable(torch.fft):
            return torch.fft(signal, 1, normalized=False)
        spectrum = torch.fft.fft(torch.view_as_complex(signal.contiguous()))
        return torch.view_as_real(spectrum)

    def _log_spectrum(self, signal):
        """Return ``log(sqrt(|X_k|^2 + 1e-5))`` for the first ``cutoff_index`` DFT bins."""
        spectrum = self._fft_real_imag(signal)
        power = spectrum[:, 0] ** 2 + spectrum[:, 1] ** 2
        return torch.log((power[0:self.cutoff_index] + 1e-5) ** 0.5)

    def forward(self, t, x):
        """Return ``dx/dt`` for the state ``x = [x1, x2, x3, x4]`` at time ``t``."""
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            x1 = x[0:1]
            x2 = x[1:2]
            x3 = x[2:3]
            x4 = x[3:4]

            # Excitation sample nearest in time to t
            index = torch.argmin(torch.abs(self.T - t.to(self.device)))
            u = self.U[index]
            up = self.Up[index]

            dx1 = x2
            dx2 = (self.dSS*(x4-x2) + self.cSS*(x3-x1) - self.dPS*(x2-up) - self.cPS*(x1-u)) / self.mBG
            dx3 = x4
            dx4 = (-self.dSS*(x4-x2) - self.cSS*(x3-x1)) / self.mCB

            return torch.cat((dx1, dx2, dx3, dx4), axis=0)

    def reset_U(self, TrackExci):
        """Replace the excitation table ``(T, U, Up)``; ``U``/``Up`` must be tensors."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.T = torch.tensor(TrackExci[0]).to(self.device)
        self.U = TrackExci[1].to(self.device)
        self.Up = TrackExci[2].to(self.device)

    def reset_prediction(self, prediction_enable, para=None):
        """Enable (1) or disable (0) the classifier; when disabling, fix ``cPS`` to ``para``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.cPS = para

    def obtain_predict_para(self):
        """Return the current ``cPS``."""
        return self.cPS

    def obtain_FourierFeat(self):
        """Return the Fourier feature list built by the last ``predict_para`` call."""
        return self.ss_list

    def predict_para(self, x):
        """Build six spectral features from ``x`` and pick ``cPS`` from {4e3, 4e6}."""
        state = torch.unsqueeze(x, 1).to(self.device)
        # Append a zero imaginary column: [N, 2] = (real, imag)
        state = torch.cat((state, torch.zeros((state.shape[0], 1)).to(self.device)), axis=1)

        # Spectra of the four raw channels
        self.ss_list = [self._log_spectrum(state[ff::self.input_dim, :]) for ff in range(4)]

        # Spectra of the first differences of channels 1 and 3 (the original
        # code labels these the accelerations of the primary / secondary).
        for ch in (1, 3):
            channel = state[ch::self.input_dim, :]
            self.ss_list.append(self._log_spectrum(channel[1:] - channel[0:-1]))

        self.ss_list_grad = torch.stack(self.ss_list, axis=0)
        # retain_grad() raises when the features do not require grad
        if self.ss_list_grad.requires_grad:
            self.ss_list_grad.retain_grad()

        if self.prediction:
            logits = self.MLP_Spec_cPS(self.ss_list_grad[0], self.ss_list_grad[1],
                                       self.ss_list_grad[2], self.ss_list_grad[3],
                                       self.ss_list_grad[4], self.ss_list_grad[5])
            # Hard (straight-through) one-hot sample over the two candidate values
            one_hot = torch.nn.functional.gumbel_softmax(logits, tau=0.1, hard=True)
            self.cPS = torch.matmul(one_hot, torch.tensor([4e3, 4e6]).float().to(self.device)).unsqueeze(0)

# Define Neural Pendulum for neural ODE solver
class P_Neural_TIME_MultipleParameter(nn.Module):
    """Damped pendulum vector field whose ``mu`` is predicted from the raw time series.

    The ODE state is ``[theta, theta_dot, *series]``. The series is
    optionally normalized with ``(series * state_scale - mean) / std``.
    Channels 0-2 (stride ``input_dim``) are concatenated and fed to
    ``MLP_Spec_mu``.
    """

    def __init__(self, MLP_Spec_mu, MLP_Spec_L,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 mean=None, std=None):
        super(P_Neural_TIME_MultipleParameter, self).__init__()

        self.MLP_Spec_mu = MLP_Spec_mu.to(device)
        self.MLP_Spec_L = MLP_Spec_L.to(device)

        # Initial system parameters; mu and L are overwritten by forward()/reset_prediction()
        self.mu = 4
        self.L = 4
        self.G = 10

        # Scale / offset constants; plu_output is added to the estimator output
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency bins (stored; not used in this class)
        self.cutoff_index = cutoff_index

        self.device = device

        # Stride between consecutive samples of one channel in the flat series
        self.input_dim = input_dim

        # Optional normalization statistics for the series
        self.mean = mean
        self.std = std

        # Multiplier applied before normalization. forward() previously read
        # this attribute without it ever being defined (AttributeError whenever
        # mean was given). Defaults to identity; callers may overwrite it.
        # NOTE(review): confirm the intended value with the training code.
        self.state_scale = 1.0

        # 1 -> forward() predicts mu; 0 -> keep the values given to reset_prediction()
        self.prediction = 1

    def forward(self, t, x):
        """Return ``dx/dt`` for the solver (``t`` is unused).

        ``theta' = theta_dot`` and ``theta_dot' = -mu * theta_dot - (G / L) * sin(theta)``.
        All remaining entries are zero.
        """
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # Pendulum angle and angular velocity, each of shape [1]
            x1 = x[0:1]
            x2 = x[1:2]

            # `is not None` instead of `!= None`: the latter is elementwise for array-valued means
            if self.mean is not None:
                x_input = (x[2:] * self.state_scale - self.mean) / self.std
            else:
                x_input = x[2:]

            # Keep channels 0-2 (cos, sin, theta_dot); drop the remaining channel(s)
            cos_theta = x_input[0::self.input_dim]
            sin_theta = x_input[1::self.input_dim]
            theta_dot = x_input[2::self.input_dim]
            x_input = torch.cat((cos_theta, sin_theta, theta_dot), axis=0)

            if self.prediction:
                self.mu = self.MLP_Spec_mu(x_input) + self.plu_output
                # L is fixed to 1.0; MLP_Spec_L is not used here
                self.L = 1.0

            dx1 = x2
            dx2 = -self.mu * x2 - (self.G / self.L) * torch.sin(x1)

            dxdt = torch.cat((dx1, dx2), axis=0)
            zero_vec = torch.zeros(x.shape[0] - 2).to(self.device)
            return torch.cat((dxdt, zero_vec), dim=0)

    def reset_prediction(self, prediction_enable, mu=None, L=None):
        """Enable (1) or disable (0) prediction; when disabling, fix ``mu`` and ``L``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.mu = mu
            self.L = L



# Define neural two-mass spring-damper (MSD) model for neural ODE solver
class P_Neural_TIME_MultipleParameterMSD(nn.Module):
    """Two-mass spring-damper vector field with a time-domain ``cPS`` estimator.

    ``forward`` evaluates the 4-state ODE below, driven by the tabulated
    excitation ``(T, U, Up)`` sampled nearest to ``t``. ``predict_para``
    stores six log-magnitude spectra in ``self.ss_list`` for inspection. The
    estimate itself is ``cPS = exp(MLP_Spec_cPS(x) + plu_output)``, computed
    from the raw trajectory ``x``.
    """

    def __init__(self, MLP_Spec_cPS,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 parameter,
                 TrackExci,
                 mean_ap=None, std_ap=None, mean_av=None, std_av=None):

        super(P_Neural_TIME_MultipleParameterMSD, self).__init__()

        self.MLP_Spec_cPS = MLP_Spec_cPS.to(device)

        # ODE coefficients in the order [mBG, mCB, cPS, dPS, cSS, dSS]
        self.mBG = torch.tensor(parameter[0])
        self.mCB = torch.tensor(parameter[1])
        self.cPS = torch.tensor(parameter[2])
        self.dPS = torch.tensor(parameter[3])
        self.cSS = torch.tensor(parameter[4])
        self.dSS = torch.tensor(parameter[5])

        # Tabulated excitation: sample times T and the inputs U, Up at those times
        self.T = torch.tensor(TrackExci[0])
        self.U = torch.from_numpy(np.array(TrackExci[1]))
        self.Up = torch.from_numpy(np.array(TrackExci[2]))

        # Scale / offset constants; plu_output is added before exp() in predict_para()
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency DFT bins kept per channel
        self.cutoff_index = cutoff_index

        self.device = device

        # The trajectory is always read with 4 interleaved channels; the
        # input_dim argument is accepted but ignored.
        self.input_dim = 4

        # Previously missing: obtain_FourierFeat() and predict_para() raised
        # AttributeError unless predict_para()/reset_prediction() had run first.
        self.ss_list = []
        self.prediction = 1

    @staticmethod
    def _fft_real_imag(signal):
        """Unnormalized 1-D DFT of an ``[N, 2]`` (real, imag) tensor, returned as ``[N, 2]``.

        Uses the legacy ``torch.fft(x, 1, normalized=False)`` function when it
        exists (PyTorch < 1.8) and the ``torch.fft`` module otherwise.
        """
        if callable(torch.fft):
            return torch.fft(signal, 1, normalized=False)
        spectrum = torch.fft.fft(torch.view_as_complex(signal.contiguous()))
        return torch.view_as_real(spectrum)

    def _log_spectrum(self, signal):
        """Return ``log(sqrt(|X_k|^2 + 1e-5))`` for the first ``cutoff_index`` DFT bins."""
        spectrum = self._fft_real_imag(signal)
        power = spectrum[:, 0] ** 2 + spectrum[:, 1] ** 2
        return torch.log((power[0:self.cutoff_index] + 1e-5) ** 0.5)

    def forward(self, t, x):
        """Return ``dx/dt`` for the state ``x = [x1, x2, x3, x4]`` at time ``t``."""
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            x1 = x[0:1]
            x2 = x[1:2]
            x3 = x[2:3]
            x4 = x[3:4]

            # Excitation sample nearest in time to t
            index = torch.argmin(torch.abs(self.T - t))
            u = self.U[index]
            up = self.Up[index]

            dx1 = x2
            dx2 = (self.dSS*(x4-x2) + self.cSS*(x3-x1) - self.dPS*(x2-up) - self.cPS*(x1-u)) / self.mBG
            dx3 = x4
            dx4 = (-self.dSS*(x4-x2) - self.cSS*(x3-x1)) / self.mCB

            return torch.cat((dx1, dx2, dx3, dx4), axis=0)

    def reset_U(self, TrackExci):
        """Replace the excitation table ``(T, U, Up)``; ``U``/``Up`` must be tensors."""
        self.T = torch.tensor(TrackExci[0]).to(self.device)
        self.U = TrackExci[1].to(self.device)
        self.Up = TrackExci[2].to(self.device)

    def reset_prediction(self, prediction_enable, para=None):
        """Enable (1) or disable (0) the estimator; when disabling, fix ``cPS`` to ``para``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.cPS = para

    def obtain_predict_para(self):
        """Return the current ``cPS``."""
        return self.cPS

    def obtain_FourierFeat(self):
        """Return the Fourier feature list built by the last ``predict_para`` call."""
        return self.ss_list

    def predict_para(self, x):
        """Record six spectral features of ``x`` and, if enabled, update ``cPS`` from ``x``."""
        state = torch.unsqueeze(x, 1).to(self.device)
        # Append a zero imaginary column: [N, 2] = (real, imag)
        state = torch.cat((state, torch.zeros((state.shape[0], 1)).to(self.device)), axis=1)

        # Spectra of the four raw channels
        self.ss_list = [self._log_spectrum(state[ff::self.input_dim, :]) for ff in range(4)]

        # Spectra of the first differences of channels 1 and 3 (the original
        # code labels these the accelerations of the primary / secondary).
        for ch in (1, 3):
            channel = state[ch::self.input_dim, :]
            self.ss_list.append(self._log_spectrum(channel[1:] - channel[0:-1]))

        # The estimator works on the raw trajectory, not on the spectra above.
        if self.prediction:
            self.cPS = torch.exp(self.MLP_Spec_cPS(x) + self.plu_output)

# Define neural N-body (gravitational) model for neural ODE solver
class P_Neural_TIME_MultipleParameterTB(nn.Module):
    """Gravitational N-body vector field with a time-domain mass estimator.

    ``forward`` evaluates the derivative of a flattened N-body state (five
    numbers per body, read as ``[mass, px, py, vx, vy]`` by the indexing
    below) with G = 1. ``predict_mass`` sets
    ``self.r = MLP_Spec_m1(x) + plu_output`` from the raw trajectory.
    """

    def __init__(self, MLP_Spec_m1, MLP_Spec_m2,
                 mul_output,
                 plu_output,
                 cutoff_index,
                 device,
                 input_dim,
                 mean=None, std=None):
        super(P_Neural_TIME_MultipleParameterTB, self).__init__()

        self.MLP_Spec_m1 = MLP_Spec_m1.to(device)
        self.MLP_Spec_m2 = MLP_Spec_m2.to(device)

        # Fixed constants (not used by the N-body dynamics in this class)
        self.mu = 4
        self.L = 4
        self.G = 10

        # Scale / offset constants; plu_output is added to the estimator output
        self.mul_output = torch.tensor(float(mul_output)).to(device)
        self.plu_output = torch.tensor(float(plu_output)).to(device)

        # Number of low-frequency bins (stored; not used in this class)
        self.cutoff_index = cutoff_index

        self.device = device

        # Stride of the flat trajectory (stored; not used in this class)
        self.input_dim = input_dim

        # Normalization statistics (stored; not used in this class)
        self.mean = mean
        self.std = std

        # 1 -> predict_mass() runs the estimator; 0 -> keep the value given to reset_prediction()
        self.prediction = 1

        # Latest prediction; initialized so obtain_predict_mass() returns None
        # instead of raising AttributeError before the first prediction.
        self.r = None

    def get_accelerations(self, state, epsilon=0):
        """Return the net gravitational acceleration on every body (G = 1).

        Args:
            state: ``[n_bodies, 5]`` tensor whose column 0 is mass and
                columns 1:3 are positions.
            epsilon: added to ``distance**3`` in the denominator.

        Returns:
            ``[n_bodies, 1, 2]`` tensor of accelerations.
        """
        net_accs = []
        for i in range(state.shape[0]):
            other_bodies = torch.cat((state[:i, :], state[i+1:, :]), axis=0)
            displacements = other_bodies[:, 1:3] - state[i, 1:3]
            distances = (displacements**2).sum(1, keepdims=True)**0.5
            masses = other_bodies[:, 0:1]
            pointwise_accs = masses * displacements / (distances**3 + epsilon)
            net_accs.append(pointwise_accs.sum(0, keepdims=True))
        net_accs = torch.stack(net_accs, axis=0)
        # retain_grad() raises on tensors that do not require grad
        if net_accs.requires_grad:
            net_accs.retain_grad()
        return net_accs

    def forward(self, t, x):
        """Return the flattened time derivative of the N-body state ``x`` (``t`` unused)."""
        with torch.enable_grad():

            # NOTE(review): overrides the constructor device on every call.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # One row of five values per body
            x = x.view(-1, 5)
            deriv = torch.zeros_like(x)
            deriv[:, 1:3] = x[:, 3:5]  # d(position)/dt = velocity
            deriv[:, 3:5] = self.get_accelerations(x).squeeze()
            return deriv.flatten()

    def reset_prediction(self, prediction_enable, r=None):
        """Enable (1) or disable (0) the estimator; when disabling, fix ``self.r`` to ``r``."""
        self.prediction = prediction_enable
        if prediction_enable == 0:
            self.r = r

    def obtain_predict_mass(self):
        """Return the latest mass prediction (None before any prediction)."""
        return self.r

    def obtain_FourierFeat(self):
        """This time-domain variant computes no Fourier features; always returns None."""
        return None

    def predict_mass(self, x):
        """If enabled, set ``self.r = MLP_Spec_m1(x) + plu_output``."""
        if self.prediction:
            self.r = self.MLP_Spec_m1(x) + self.plu_output


# Define a simple MLP
class TF_Block_EXP_Residual_TV_2(torch.nn.Module):
    """Two-branch residual MLP with a bounded output.

    Branch one maps ``f`` (width ``input_size[0]``) and branch two maps
    ``dx`` (width ``input_size[1]``) through two residual blocks each. The
    results are concatenated and projected to ``output_size`` values, which
    are returned as ``tanh(.) * mag``, i.e. within ``(-mag, mag)``.
    """
    def __init__(self, input_size, input_layer=[100,50], output_size=1, nonlinearity=0, mag=4.0):
        """
        Args:
            input_size: pair ``(n_f, n_dx)`` of feature widths for ``f`` and ``dx``.
            input_layer: unused; kept for signature compatibility (never mutated).
            output_size: number of outputs.
            nonlinearity: 0 -> tanh, 1 -> softplus.
            mag: bound on the output magnitude.

        Raises:
            ValueError: if ``nonlinearity`` is not 0 or 1 (previously the
                attribute was silently left unset and ``forward`` failed later).
        """
        super(TF_Block_EXP_Residual_TV_2, self).__init__()
        if nonlinearity == 0:
            activation = torch.tanh
        elif nonlinearity == 1:
            activation = torch.nn.functional.softplus
        else:
            raise ValueError("nonlinearity must be 0 (tanh) or 1 (softplus), got %r" % (nonlinearity,))

        self.input_size = input_size
        self.output_size = output_size

        # Residual blocks for the f branch
        self.fc1_f = nn.Linear(self.input_size[0], self.input_size[0])
        self.fc2_f = nn.Linear(self.input_size[0], self.input_size[0])

        self.fc1_h1 = nn.Linear(self.input_size[0], self.input_size[0])
        self.fc2_h1 = nn.Linear(self.input_size[0], self.input_size[0])

        # The head consumes cat(h2, h2_xt), whose width is n_f + n_dx. The
        # previous n_f * 2 only worked when both widths were equal; for that
        # case the layer shape is unchanged, so existing checkpoints still load.
        self.fc3 = nn.Linear(self.input_size[0] + self.input_size[1], self.output_size)

        # Residual blocks for the dx branch
        self.fc1_f_xt = nn.Linear(self.input_size[1], self.input_size[1])
        self.fc2_f_xt = nn.Linear(self.input_size[1], self.input_size[1])

        self.fc1_h1_xt = nn.Linear(self.input_size[1], self.input_size[1])
        self.fc2_h1_xt = nn.Linear(self.input_size[1], self.input_size[1])

        self.mag = mag

        self.nonlinearity = activation

    def forward(self, f, dx):
        """Return ``tanh(fc3(cat(branch_f(f), branch_dx(dx)))) * mag``."""
        delta_f_1 = self.nonlinearity(self.fc1_f(f))
        delta_f_2 = self.fc2_f(delta_f_1)
        h1 = self.nonlinearity(delta_f_2 + f)

        delta_h1_1 = self.nonlinearity(self.fc1_h1(h1))
        delta_h1_2 = self.fc2_h1(delta_h1_1)
        h2 = self.nonlinearity(delta_h1_2 + h1)

        delta_f_1_xt = self.nonlinearity(self.fc1_f_xt(dx))
        delta_f_2_xt = self.fc2_f_xt(delta_f_1_xt)
        h1_xt = self.nonlinearity(delta_f_2_xt + dx)

        delta_h1_1_xt = self.nonlinearity(self.fc1_h1_xt(h1_xt))
        delta_h1_2_xt = self.fc2_h1_xt(delta_h1_1_xt)
        h2_xt = self.nonlinearity(delta_h1_2_xt + h1_xt)

        h = torch.cat((h2, h2_xt), axis=-1)

        # Multiplying by the Python scalar is equivalent to the former
        # torch.tensor(self.mag) without allocating a tensor per call.
        return torch.tanh(self.fc3(h)) * self.mag

# Six-branch residual network returning raw (unbounded) outputs.
class EstimatorNetworkMSD_Class(torch.nn.Module):
    """Six-branch residual network that returns raw, unbounded outputs.

    Each input ``feat_i`` (last dimension ``input_size[0]``) goes through its
    own pair of width-preserving residual sub-blocks
    ``h = act(fc_b(act(fc_a(x))) + x)``.  The six branch outputs are
    concatenated along the last axis and mapped by the linear layer ``fc9``
    to ``output_size`` values.  No ``tanh``/``mag`` scaling is applied, so
    ``mag`` is stored but not used by ``forward``.

    Layers are registered as ``fc{i}_{k}`` (branch ``i`` in 1..6, layer ``k``
    in 1..4) followed by ``fc9``, in the same order as the original explicit
    definitions, so ``state_dict`` keys and seeded initialisation match.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: stored as ``self.hidden_size``; not used to build layers.
            Defaults to ``[100, 50]``.
        output_size: number of outputs of ``fc9``.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.
        mag: stored as ``self.mag``; unused by this class.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(EstimatorNetworkMSD_Class, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size
        self.mag = mag

        width = self.input_size[0]
        # Creates fc1_1 .. fc6_4 in the original order.
        for branch in range(1, 7):
            for layer in range(1, 5):
                setattr(self, f"fc{branch}_{layer}", nn.Linear(width, width))
        self.fc9 = nn.Linear(width * 6, self.output_size)

    def _branch(self, idx, feat):
        """Apply the two residual sub-blocks of branch ``idx`` to ``feat``."""
        act = self.nonlinearity
        fc_a, fc_b, fc_c, fc_d = (getattr(self, f"fc{idx}_{k}") for k in range(1, 5))
        h = act(fc_b(act(fc_a(feat))) + feat)
        return act(fc_d(act(fc_c(h))) + h)

    def forward(self, feat_1, feat_2, feat_3, feat_4, feat_5, feat_6):
        """Return ``fc9`` applied to the concatenated branch outputs."""
        feats = (feat_1, feat_2, feat_3, feat_4, feat_5, feat_6)
        h_input = torch.cat(
            [self._branch(idx, feat) for idx, feat in enumerate(feats, start=1)], dim=-1)
        return self.fc9(h_input)


# Six-branch residual network with a tanh-bounded output in [-mag, mag].
class EstimatorNetworkMSD(torch.nn.Module):
    """Six-branch residual network whose output is ``tanh(fc9(.)) * mag``.

    Each input ``feat_i`` (last dimension ``input_size[0]``) goes through its
    own pair of width-preserving residual sub-blocks
    ``h = act(fc_b(act(fc_a(x))) + x)``.  The six branch outputs are
    concatenated along the last axis, mapped by ``fc9`` to ``output_size``
    values and bounded with ``tanh(.) * mag``.

    Layers are registered as ``fc{i}_{k}`` (branch ``i`` in 1..6, layer ``k``
    in 1..4) followed by ``fc9``, in the same order as the original explicit
    definitions, so ``state_dict`` keys and seeded initialisation match.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: stored as ``self.hidden_size``; not used to build layers.
            Defaults to ``[100, 50]``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.
        mag: half-width of the output range.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(EstimatorNetworkMSD, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size
        self.mag = mag

        width = self.input_size[0]
        # Creates fc1_1 .. fc6_4 in the original order.
        for branch in range(1, 7):
            for layer in range(1, 5):
                setattr(self, f"fc{branch}_{layer}", nn.Linear(width, width))
        self.fc9 = nn.Linear(width * 6, self.output_size)

    def _branch(self, idx, feat):
        """Apply the two residual sub-blocks of branch ``idx`` to ``feat``."""
        act = self.nonlinearity
        fc_a, fc_b, fc_c, fc_d = (getattr(self, f"fc{idx}_{k}") for k in range(1, 5))
        h = act(fc_b(act(fc_a(feat))) + feat)
        return act(fc_d(act(fc_c(h))) + h)

    def forward(self, feat_1, feat_2, feat_3, feat_4, feat_5, feat_6):
        """Return ``tanh(fc9(concatenated branch outputs)) * mag``."""
        feats = (feat_1, feat_2, feat_3, feat_4, feat_5, feat_6)
        h_input = torch.cat(
            [self._branch(idx, feat) for idx, feat in enumerate(feats, start=1)], dim=-1)
        # Multiply by the plain scalar: no new tensor is allocated per call.
        return torch.tanh(self.fc9(h_input)) * self.mag


# Three-layer MLP over four concatenated inputs, bounded to [-mag, mag].
class EstimatorNetworkTB_v2(torch.nn.Module):
    """Three-layer MLP over four concatenated inputs with output in ``[-mag, mag]``.

    The inputs (each with last dimension ``input_size[0]``) are concatenated
    and passed through ``fc1 -> act -> fc2 -> act -> fc3``.  The result is
    returned as ``tanh(.) * mag``.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: widths of the two hidden layers; defaults to
            ``[100, 50]``.  Stored as ``self.hidden_size``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.
        mag: half-width of the output range.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(EstimatorNetworkTB_v2, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size
        self.mag = mag

        self.fc1 = nn.Linear(self.input_size[0] * 4, self.hidden_size[0])
        self.fc2 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
        self.fc3 = nn.Linear(self.hidden_size[1], self.output_size)

    def forward(self, feat_1, feat_2, feat_3, feat_4):
        """Return ``tanh(fc3(act(fc2(act(fc1(cat(feats))))))) * mag``."""
        h = torch.cat((feat_1, feat_2, feat_3, feat_4), dim=-1)
        h = self.nonlinearity(self.fc1(h))
        h = self.nonlinearity(self.fc2(h))
        # The output layer is squashed only by tanh.  Applying the hidden
        # activation to fc3 first limited the range to +-tanh(1)*mag (tanh)
        # or (0, mag) (softplus) instead of the intended [-mag, mag].
        return torch.tanh(self.fc3(h)) * self.mag

# Four-branch residual network with a tanh-bounded output in [-mag, mag].
class EstimatorNetworkTB(torch.nn.Module):
    """Four-branch residual network whose output is ``tanh(fc9(.)) * mag``.

    Each input ``feat_i`` (last dimension ``input_size[0]``) goes through its
    own pair of width-preserving residual sub-blocks
    ``h = act(fc_b(act(fc_a(x))) + x)``.  The four branch outputs are
    concatenated along the last axis, mapped by ``fc9`` to ``output_size``
    values and bounded with ``tanh(.) * mag``.

    Layers are registered as ``fc{i}_{k}`` (branch ``i`` in 1..4, layer ``k``
    in 1..4) followed by ``fc9``, in the same order as the original explicit
    definitions, so ``state_dict`` keys and seeded initialisation match.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: stored as ``self.hidden_size``; not used to build layers.
            Defaults to ``[100, 50]``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.
        mag: half-width of the output range.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(EstimatorNetworkTB, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size
        self.mag = mag

        width = self.input_size[0]
        # Creates fc1_1 .. fc4_4 in the original order.
        for branch in range(1, 5):
            for layer in range(1, 5):
                setattr(self, f"fc{branch}_{layer}", nn.Linear(width, width))
        self.fc9 = nn.Linear(width * 4, self.output_size)

    def _branch(self, idx, feat):
        """Apply the two residual sub-blocks of branch ``idx`` to ``feat``."""
        act = self.nonlinearity
        fc_a, fc_b, fc_c, fc_d = (getattr(self, f"fc{idx}_{k}") for k in range(1, 5))
        h = act(fc_b(act(fc_a(feat))) + feat)
        return act(fc_d(act(fc_c(h))) + h)

    def forward(self, feat_1, feat_2, feat_3, feat_4):
        """Return ``tanh(fc9(concatenated branch outputs)) * mag``."""
        feats = (feat_1, feat_2, feat_3, feat_4)
        h_input = torch.cat(
            [self._branch(idx, feat) for idx, feat in enumerate(feats, start=1)], dim=-1)
        # Multiply by the plain scalar: no new tensor is allocated per call.
        return torch.tanh(self.fc9(h_input)) * self.mag


# Three-branch residual network with a tanh-bounded output in [-mag, mag].
class EstimatorNetwork(torch.nn.Module):
    """Three-branch residual network whose output is ``tanh(fc3(.)) * mag``.

    Each input ``feat_i`` (last dimension ``input_size[0]``) goes through its
    own pair of width-preserving residual sub-blocks
    ``h = act(fc_b(act(fc_a(x))) + x)``.  The three branch outputs are
    concatenated along the last axis, mapped by the head ``fc3`` to
    ``output_size`` values and bounded with ``tanh(.) * mag``.

    Layers are registered as ``fc{i}_{k}`` (branch ``i`` in 1..3, layer ``k``
    in 1..4) followed by the head ``fc3``, in the same order as the original
    explicit definitions, so ``state_dict`` keys and seeded initialisation
    match.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: stored as ``self.hidden_size``; not used to build layers.
            Defaults to ``[100, 50]``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.
        mag: half-width of the output range.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(EstimatorNetwork, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size
        self.mag = mag

        width = self.input_size[0]
        # Creates fc1_1 .. fc3_4 in the original order.
        for branch in range(1, 4):
            for layer in range(1, 5):
                setattr(self, f"fc{branch}_{layer}", nn.Linear(width, width))
        # Output head (its name ``fc3`` is kept for checkpoint compatibility).
        self.fc3 = nn.Linear(width * 3, self.output_size)

    def _branch(self, idx, feat):
        """Apply the two residual sub-blocks of branch ``idx`` to ``feat``."""
        act = self.nonlinearity
        fc_a, fc_b, fc_c, fc_d = (getattr(self, f"fc{idx}_{k}") for k in range(1, 5))
        h = act(fc_b(act(fc_a(feat))) + feat)
        return act(fc_d(act(fc_c(h))) + h)

    def forward(self, feat_1, feat_2, feat_3):
        """Return ``tanh(fc3(concatenated branch outputs)) * mag``."""
        feats = (feat_1, feat_2, feat_3)
        h_input = torch.cat(
            [self._branch(idx, feat) for idx, feat in enumerate(feats, start=1)], dim=-1)
        # Multiply by the plain scalar: no new tensor is allocated per call.
        return torch.tanh(self.fc3(h_input)) * self.mag

# Residual network over three concatenated inputs, bounded to [-mag, mag].
class EstimatorNetwork_v2(torch.nn.Module):
    """Residual network over three concatenated inputs with output in ``[-mag, mag]``.

    The inputs (each with last dimension ``input_size[0]``) are concatenated
    and passed through three stages:

    1. a width-preserving residual block ``act(fc1_2(act(fc1_1(x))) + x)``;
    2. a residual block projecting to ``hidden_size[0]`` features, whose
       shortcut is the linear map ``fc2_3``;
    3. an output block projecting to ``output_size`` features with linear
       shortcut ``fc3_3`` and no activation after the sum.

    The result is returned as ``tanh(.) * mag``.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: hidden widths; only ``input_layer[0]`` is used.  Defaults
            to ``[100, 50]``.  Stored as ``self.hidden_size``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.
        mag: half-width of the output range.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(EstimatorNetwork_v2, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size
        self.mag = mag

        in_width = self.input_size[0] * 3
        hidden = self.hidden_size[0]

        # Stage 1: width-preserving residual block.
        self.fc1_1 = nn.Linear(in_width, in_width)
        self.fc1_2 = nn.Linear(in_width, in_width)

        # Stage 2: projection to ``hidden`` with a linear shortcut (fc2_3).
        self.fc2_1 = nn.Linear(in_width, hidden)
        self.fc2_2 = nn.Linear(hidden, hidden)
        self.fc2_3 = nn.Linear(in_width, hidden)

        # Stage 3: projection to ``output_size`` with a linear shortcut (fc3_3).
        self.fc3_1 = nn.Linear(hidden, self.output_size)
        self.fc3_2 = nn.Linear(self.output_size, self.output_size)
        self.fc3_3 = nn.Linear(hidden, self.output_size)

    def forward(self, feat_1, feat_2, feat_3):
        """Return the three-stage residual output squashed to ``[-mag, mag]``."""
        act = self.nonlinearity
        h_input = torch.cat((feat_1, feat_2, feat_3), dim=-1)

        h1 = act(self.fc1_2(act(self.fc1_1(h_input))) + h_input)
        h2 = act(self.fc2_2(act(self.fc2_1(h1))) + self.fc2_3(h1))
        h_final = self.fc3_2(act(self.fc3_1(h2))) + self.fc3_3(h2)

        # Multiply by the plain scalar: no new tensor is allocated per call.
        return torch.tanh(h_final) * self.mag

# Three-layer MLP over three concatenated inputs with an unbounded output.
class EstimatorNetwork_NoScale(torch.nn.Module):
    """Three-layer MLP over three concatenated inputs with a raw (unbounded) output.

    The inputs (each with last dimension ``input_size[0]``) are concatenated
    and passed through ``fc1 -> act -> fc2 -> act -> fc3``.  No squashing or
    scaling is applied to the output.

    Args:
        input_size: sequence whose first entry is the width of every input.
        input_layer: widths of the two hidden layers; defaults to
            ``[100, 50]``.  Stored as ``self.hidden_size``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0):
        super(EstimatorNetwork_NoScale, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        self.input_size = input_size
        # Fresh list per instance instead of a shared mutable default.
        self.hidden_size = [100, 50] if input_layer is None else input_layer
        self.output_size = output_size

        self.fc1 = nn.Linear(self.input_size[0] * 3, self.hidden_size[0])
        self.fc2 = nn.Linear(self.hidden_size[0], self.hidden_size[1])
        self.fc3 = nn.Linear(self.hidden_size[1], self.output_size)

    def forward(self, feat_1, feat_2, feat_3):
        """Return ``fc3(act(fc2(act(fc1(cat(feats))))))``."""
        h = torch.cat((feat_1, feat_2, feat_3), dim=-1)
        h = self.nonlinearity(self.fc1(h))
        h = self.nonlinearity(self.fc2(h))
        return self.fc3(h)

# Two-hidden-layer MLP with a tanh-bounded output in [-mag, mag].
class MLP_Mag(nn.Module):
    """Two-hidden-layer MLP whose output is ``tanh(fc3(h)) * mag``.

    Args:
        input_size: number of input features (a plain int for this class).
        input_layer: widths of the two hidden layers; defaults to
            ``[100, 50]``.
        output_size: number of outputs.
        nonlinearity: hidden activation - ``0`` tanh, ``1`` softplus,
            ``2`` elu (added for consistency with the other networks).
        mag: half-width of the output range.

    Raises:
        ValueError: if ``nonlinearity`` is not 0, 1 or 2.
    """

    def __init__(self, input_size, input_layer=None, output_size=1, nonlinearity=0, mag=4.0):
        super(MLP_Mag, self).__init__()
        activations = {
            0: torch.tanh,
            1: torch.nn.functional.softplus,
            2: torch.nn.functional.elu,
        }
        # Fail at construction time instead of with an AttributeError on the
        # first forward() call.
        if nonlinearity not in activations:
            raise ValueError(
                f"nonlinearity must be one of {sorted(activations)}, got {nonlinearity!r}")
        self.nonlinearity = activations[nonlinearity]

        # Fresh list per call instead of a shared mutable default.
        if input_layer is None:
            input_layer = [100, 50]

        self.input_size = input_size
        self.output_size = output_size
        self.mag = mag

        self.fc1 = nn.Linear(self.input_size, input_layer[0])
        self.fc2 = nn.Linear(input_layer[0], input_layer[1])
        self.fc3 = nn.Linear(input_layer[1], self.output_size)

    def forward(self, x):
        """Return ``tanh(fc3(act(fc2(act(fc1(x)))))) * mag``."""
        h = self.nonlinearity(self.fc1(x))
        h = self.nonlinearity(self.fc2(h))
        # Multiply by the plain scalar: no new tensor is allocated per call.
        return torch.tanh(self.fc3(h)) * self.mag


# Adapted from PyTorch's TransformerEncoderLayer:
# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer(nn.Module):
    r"""CNN frame encoder followed by a single self-attention block.

    Loosely based on the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017), but unlike ``nn.TransformerEncoderLayer``:

    * every time step is one 64x64 single-channel image, reduced by two
      conv + pool stages and the linear layer ``fc1`` to ``d_model`` features;
    * all ``T`` frames form one sequence with batch size 1;
    * after the residual attention block and ``norm1`` the features go through
      ``linear1 -> activation_1 -> linear2`` to ``d_final`` outputs, without a
      second residual connection or normalisation.

    Args:
        d_model: feature width fed to the attention layer.
        nhead: number of attention heads; ``nn.MultiheadAttention`` requires
            it to divide ``d_model``.
        d_middle: unused; kept for signature compatibility.
        d_final: number of output features per time step.
        dropout: dropout probability inside ``self_attn`` and on the attention
            residual branch (``dropout1``).
        max_len: forwarded to ``PositionalEncoding`` as ``max_len``
            (presumably the longest supported sequence - confirm against its
            definition).
        pos_en_scale: forwarded to ``PositionalEncoding`` as ``scale``.
        activation: name passed to ``choose_nonlinearity``; applied after each
            conv layer.
        activation_1: name passed to ``choose_nonlinearity``; applied between
            ``linear1`` and ``linear2``.
        pooling: ``'max'`` selects 2x2 max pooling, anything else 2x2 average
            pooling.

    Usage: ``out, attn = layer(frames)`` with ``frames`` of shape
    ``(T, 1, 64 * 64)``.  Move the whole layer with ``layer.to(device)``.
    """

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101,
                 pos_en_scale=1.0, activation="tanh", activation_1="elu", pooling='max'):
        super(TransformerEncoderLayer, self).__init__()

        # Not used inside this class; kept in case external code reads it.
        # Sub-modules are no longer moved here individually.  Previously only
        # ``fc1`` was moved, so a layer that was not moved as a whole on a
        # CUDA host mixed CPU conv weights with a GPU ``fc1``.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that shrinks each 64x64 frame before the attention block.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)

        # Spatial size after each unpadded convolution + 2x2 pooling:
        # 64 -> 62 -> 31 (conv1), 31 -> 28 -> 14 (conv2).
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4  # 4 channels * 14 * 14 = 784
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Output head: d_model -> d_model -> d_final.
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_final)
        # Not applied in forward(); kept so the module layout is unchanged.
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of ``T`` frames.

        Args:
            src: tensor whose first dimension is ``T`` and which holds
                ``T * 64 * 64`` values (e.g. ``(T, 1, 64 * 64)``).  It is
                viewed as ``T`` single-channel 64x64 images.
            src_mask: optional ``attn_mask`` forwarded to ``self_attn``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``self_attn``.

        Returns:
            ``(out, attn_weights)``.  ``out`` is the ``linear2`` output
            (``d_final`` features per step, with a batch axis of size 1), and
            ``attn_weights`` is the weight tensor returned by ``self_attn``.
        """
        # (T, 1, 64*64) -> (T, 1, 64, 64); the singleton axis is the conv channel.
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))  # (T, 3, 31, 31)
        src = self.pool(self.activation(self.conv2(src)))  # (T, 4, 14, 14)
        src = src.view(-1, self.size_fc1)                  # (T, 784)
        src = self.fc1(src)                                # (T, d_model)

        # nn.MultiheadAttention (batch_first=False) expects (seq, batch, feature),
        # so the T frames become one sequence of batch size 1: (T, 1, d_model).
        src = src.unsqueeze(1)
        src = self.pos_encoder(src)
        src2, src2_attn_weight = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(src2))

        src2 = self.linear2(self.activation_1(self.linear1(src)))
        return src2, src2_attn_weight

# Adapted from PyTorch's TransformerEncoderLayer:
# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_Categoricals(nn.Module):
    r"""CNN frame encoder with a Gumbel-softmax categorical bottleneck and
    one post-norm self-attention encoder layer.

    Each 64x64 frame is reduced by two conv + pool stages.  Ten linear heads
    ``fc1`` .. ``fc10`` then each emit 10 logits, which are turned into hard
    one-hot samples with ``gumbel_softmax(..., hard=True)``.  The ten one-hot
    vectors are concatenated into a 100-dimensional code.  That code feeds an
    encoder layer made of two stages: attention, residual and ``norm1``; then
    a feed-forward, residual and ``norm2``.  The output projection ``linear3``
    follows.

    Args:
        d_model: ignored.  The attention width is always ``10 * 10 = 100``
            (the size of the concatenated one-hot code).  Kept for signature
            compatibility.
        nhead: number of attention heads; ``nn.MultiheadAttention`` requires
            it to divide 100.
        d_middle: unused; kept for signature compatibility.
        d_final: number of output features per time step.
        dropout: dropout probability inside attention and on both residual
            branches and the feed-forward hidden layer.
        max_len: forwarded to ``PositionalEncoding`` as ``max_len``.
        pos_en_scale: forwarded to ``PositionalEncoding`` as ``scale``.
        activation: name passed to ``choose_nonlinearity``; applied after each
            conv layer.
        activation_1: name passed to ``choose_nonlinearity``; used inside the
            feed-forward block.
        pooling: ``'max'`` selects 2x2 max pooling, anything else 2x2 average
            pooling.
        tau: Gumbel-softmax temperature (default 0.5, the previously
            hard-coded value).

    Move the whole layer with ``layer.to(device)``.
    """

    # Bottleneck layout: 10 categorical variables with 10 classes each.
    _NUM_CATEGORICALS = 10
    _NUM_CLASSES = 10
    # Class-level fallback so instances pickled before ``tau`` existed still work.
    tau = 0.5

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101,
                 pos_en_scale=1.0, activation="tanh", activation_1="elu", pooling='max', tau=0.5):
        super(TransformerEncoderLayer_v2_Categoricals, self).__init__()

        # Not used inside this class; kept in case external code reads it.
        # Sub-modules are no longer moved here individually.  Previously only
        # the heads were moved, so a layer that was not moved as a whole on a
        # CUDA host mixed CPU conv weights with GPU heads.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that shrinks each 64x64 frame before the categorical heads.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)

        # Spatial size after each unpadded convolution + 2x2 pooling:
        # 64 -> 62 -> 31 (conv1), 31 -> 28 -> 14 (conv2).
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4  # 4 channels * 14 * 14 = 784

        # Categorical heads fc1 .. fc10, created in the original order so the
        # state_dict keys and the seeded initialisation are unchanged.
        for k in range(1, self._NUM_CATEGORICALS + 1):
            setattr(self, f"fc{k}", nn.Linear(self.size_fc1, self._NUM_CLASSES))
        self.tau = tau

        # The ``d_model`` argument is overridden: attention runs on the
        # concatenated one-hot code.
        d_model = self._NUM_CATEGORICALS * self._NUM_CLASSES

        self.pos_encoder = PositionalEncoding(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block (d_model -> d_model -> d_model) and output head.
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.linear3 = nn.Linear(d_model, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-norm layer normalisation after each residual connection.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of ``T`` frames through the categorical bottleneck.

        Args:
            src: tensor whose first dimension is ``T`` and which holds
                ``T * 64 * 64`` values (e.g. ``(T, 1, 64 * 64)``).  It is
                viewed as ``T`` single-channel 64x64 images.
            src_mask: optional ``attn_mask`` forwarded to ``self_attn``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``self_attn``.

        Returns:
            ``(out, attn_weights)``.  ``out`` is the ``linear3`` output
            (``d_final`` features per step, with a batch axis of size 1), and
            ``attn_weights`` is the weight tensor returned by ``self_attn``.
            Sampling is stochastic because of ``gumbel_softmax``.
        """
        # (T, 1, 64*64) -> (T, 1, 64, 64); the singleton axis is the conv channel.
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))  # (T, 3, 31, 31)
        src = self.pool(self.activation(self.conv2(src)))  # (T, 4, 14, 14)
        src = src.view(-1, self.size_fc1)                  # (T, 784)

        # One hard one-hot sample per head, each (T, 10).  The heads are
        # sampled in order fc1 .. fc10 and concatenated to (T, 100).
        one_hots = [
            torch.nn.functional.gumbel_softmax(getattr(self, f"fc{k}")(src), tau=self.tau, hard=True)
            for k in range(1, self._NUM_CATEGORICALS + 1)
        ]
        src = torch.cat(one_hots, dim=-1)

        # nn.MultiheadAttention (batch_first=False) expects (seq, batch, feature),
        # so the T frames become one sequence of batch size 1: (T, 1, 100).
        src = src.unsqueeze(1)
        src = self.pos_encoder(src)
        src2, src2_attn_weight = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(src2))

        # Feed-forward block with its own residual connection and norm.
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))

        src_final = self.linear3(src)
        return src_final, src2_attn_weight

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_2C(nn.Module):
    r"""CNN frame encoder followed by one post-norm Transformer encoder layer (2-channel frames).

    Every time step is a 2-channel 64x64 image. A small CNN
    (3x3 conv -> 2x2 pool -> 4x4 conv -> 2x2 pool) and a linear layer map each
    frame to a ``d_model`` vector. The resulting length-T sequence (batch size 1)
    gets a scaled sinusoidal positional encoding, one self-attention block and
    one feed-forward block (each with a residual connection followed by
    LayerNorm), and a final linear head. The attention/feed-forward layout
    follows the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017).

    Args:
        d_model: width of the per-frame embedding and of the attention block
            (``nn.MultiheadAttention`` requires ``nhead`` to divide it).
        nhead: number of attention heads.
        d_middle: unused; kept for signature compatibility with sibling classes.
        d_final: number of output features per time step.
        dropout: dropout probability used in attention and feed-forward paths.
        max_len: longest sequence the positional encoding supports.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: name passed to ``choose_nonlinearity``; applied after each conv.
        activation_1: name passed to ``choose_nonlinearity``; used inside the
            feed-forward block.
        pooling: ``'max'`` selects max pooling; any other value selects average pooling.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu",pooling='max'):
        super(TransformerEncoderLayer_v2_2C, self).__init__()

        # Preferred device, kept as a public attribute. No layer is moved here:
        # the caller moves the whole module with ``.to(device)``. (Previously only
        # fc1 was moved, which could leave fc1 and the convs on different devices.)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # --- CNN frame encoder: 2 input channels, 64x64 frames ---
        # Submodules are created in the original order so that seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pool (floor):
        # 64 -> (64 - 3 + 1) // 2 = 31 -> (31 - 4 + 1) // 2 = 14; conv2 has 4 channels.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # --- Sequence model ---
        self.pos_encoder = PositionalEncoding(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block (d_model -> d_model -> d_model) and output head.
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.linear3 = nn.Linear(d_model, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-residual normalisation and residual-path dropout.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and return per-step outputs.

        Args:
            src: frames of one sequence. It must hold ``T * 2 * 64 * 64`` elements
                and is reshaped to ``(T, 2, 64, 64)`` with ``T = src.shape[0]``.
                NOTE(review): presumably passed as ``(T, 1, 2*64*64)``; confirm with callers.
            src_mask: optional ``attn_mask`` forwarded to ``nn.MultiheadAttention``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``nn.MultiheadAttention``.

        Returns:
            ``(output, attn_weights)``. ``output`` has shape ``(T, 1, d_final)``.
            ``attn_weights`` is the weight tensor returned by ``nn.MultiheadAttention``;
            with default settings that is ``(1, T, T)``.
        """
        n_frames = src.shape[0]
        frames = src.view(n_frames, self.conv1.in_channels, 64, 64)

        # CNN: (T, 2, 64, 64) -> (T, 3, 31, 31) -> (T, 4, 14, 14) -> (T, 4*14*14)
        x = self.pool(self.activation(self.conv1(frames)))
        x = self.pool(self.activation(self.conv2(x)))
        x = x.view(-1, self.size_fc1)
        x = self.fc1(x)  # (T, d_model); linear, no activation applied

        # The T frames form one sequence with batch size 1: (T, 1, d_model).
        x = x.unsqueeze(1)
        x = self.pos_encoder(x)

        # Self-attention sub-block: residual add, then LayerNorm.
        attn_out, attn_weights = self.self_attn(
            x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward sub-block: residual add, then LayerNorm.
        ff_out = self.linear2(self.dropout(self.activation_1(self.linear1(x))))
        x = self.norm2(x + self.dropout2(ff_out))

        return self.linear3(x), attn_weights

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_3C(nn.Module):
    r"""CNN frame encoder followed by one post-norm Transformer encoder layer (3-channel frames).

    Every time step is a 3-channel 64x64 image. A small CNN
    (3x3 conv -> 2x2 pool -> 4x4 conv -> 2x2 pool) and a linear layer map each
    frame to a ``d_model`` vector. The resulting length-T sequence (batch size 1)
    gets a scaled sinusoidal positional encoding, one self-attention block and
    one feed-forward block (each with a residual connection followed by
    LayerNorm), and a final linear head. The attention/feed-forward layout
    follows the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017).

    Args:
        d_model: width of the per-frame embedding and of the attention block
            (``nn.MultiheadAttention`` requires ``nhead`` to divide it).
        nhead: number of attention heads.
        d_middle: unused; kept for signature compatibility with sibling classes.
        d_final: number of output features per time step.
        dropout: dropout probability used in attention and feed-forward paths.
        max_len: longest sequence the positional encoding supports.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: name passed to ``choose_nonlinearity``; applied after each conv.
        activation_1: name passed to ``choose_nonlinearity``; used inside the
            feed-forward block.
        pooling: ``'max'`` selects max pooling; any other value selects average pooling.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu",pooling='max'):
        super(TransformerEncoderLayer_v2_3C, self).__init__()

        # Preferred device, kept as a public attribute. No layer is moved here:
        # the caller moves the whole module with ``.to(device)``. (Previously only
        # fc1 was moved, which could leave fc1 and the convs on different devices.)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # --- CNN frame encoder: 3 input channels, 64x64 frames ---
        # Submodules are created in the original order so that seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pool (floor):
        # 64 -> (64 - 3 + 1) // 2 = 31 -> (31 - 4 + 1) // 2 = 14; conv2 has 4 channels.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # --- Sequence model ---
        self.pos_encoder = PositionalEncoding(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block (d_model -> d_model -> d_model) and output head.
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.linear3 = nn.Linear(d_model, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-residual normalisation and residual-path dropout.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and return per-step outputs.

        Args:
            src: frames of one sequence. It must hold ``T * 3 * 64 * 64`` elements
                and is reshaped to ``(T, 3, 64, 64)`` with ``T = src.shape[0]``.
                NOTE(review): presumably passed as ``(T, 1, 3*64*64)``; confirm with callers.
            src_mask: optional ``attn_mask`` forwarded to ``nn.MultiheadAttention``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``nn.MultiheadAttention``.

        Returns:
            ``(output, attn_weights)``. ``output`` has shape ``(T, 1, d_final)``.
            ``attn_weights`` is the weight tensor returned by ``nn.MultiheadAttention``;
            with default settings that is ``(1, T, T)``.
        """
        n_frames = src.shape[0]
        frames = src.view(n_frames, self.conv1.in_channels, 64, 64)

        # CNN: (T, 3, 64, 64) -> (T, 3, 31, 31) -> (T, 4, 14, 14) -> (T, 4*14*14)
        x = self.pool(self.activation(self.conv1(frames)))
        x = self.pool(self.activation(self.conv2(x)))
        x = x.view(-1, self.size_fc1)
        x = self.fc1(x)  # (T, d_model); linear, no activation applied

        # The T frames form one sequence with batch size 1: (T, 1, d_model).
        x = x.unsqueeze(1)
        x = self.pos_encoder(x)

        # Self-attention sub-block: residual add, then LayerNorm.
        attn_out, attn_weights = self.self_attn(
            x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward sub-block: residual add, then LayerNorm.
        ff_out = self.linear2(self.dropout(self.activation_1(self.linear1(x))))
        x = self.norm2(x + self.dropout2(ff_out))

        return self.linear3(x), attn_weights

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2(nn.Module):
    r"""CNN frame encoder followed by one post-norm Transformer encoder layer (1-channel frames).

    Every time step is a single-channel 64x64 image. A small CNN
    (3x3 conv -> 2x2 pool -> 4x4 conv -> 2x2 pool) and a linear layer map each
    frame to a ``d_model`` vector. The resulting length-T sequence (batch size 1)
    gets a scaled sinusoidal positional encoding, one self-attention block and
    one feed-forward block (each with a residual connection followed by
    LayerNorm), and a final linear head. The attention/feed-forward layout
    follows the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017).

    Args:
        d_model: width of the per-frame embedding and of the attention block
            (``nn.MultiheadAttention`` requires ``nhead`` to divide it).
        nhead: number of attention heads.
        d_middle: unused; kept for signature compatibility with sibling classes.
        d_final: number of output features per time step.
        dropout: dropout probability used in attention and feed-forward paths.
        max_len: longest sequence the positional encoding supports.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: name passed to ``choose_nonlinearity``; applied after each conv.
        activation_1: name passed to ``choose_nonlinearity``; used inside the
            feed-forward block.
        pooling: ``'max'`` selects max pooling; any other value selects average pooling.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu",pooling='max'):
        super(TransformerEncoderLayer_v2, self).__init__()

        # Preferred device, kept as a public attribute. No layer is moved here:
        # the caller moves the whole module with ``.to(device)``. (Previously only
        # fc1 was moved, which could leave fc1 and the convs on different devices.)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # --- CNN frame encoder: 1 input channel, 64x64 frames ---
        # Submodules are created in the original order so that seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pool (floor):
        # 64 -> (64 - 3 + 1) // 2 = 31 -> (31 - 4 + 1) // 2 = 14; conv2 has 4 channels.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # --- Sequence model ---
        self.pos_encoder = PositionalEncoding(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward block (d_model -> d_model -> d_model) and output head.
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.linear3 = nn.Linear(d_model, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-residual normalisation and residual-path dropout.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and return per-step outputs.

        Args:
            src: frames of one sequence. It must hold ``T * 64 * 64`` elements
                and is reshaped to ``(T, 1, 64, 64)`` with ``T = src.shape[0]``.
                NOTE(review): presumably passed as ``(T, 1, 64*64)``; confirm with callers.
            src_mask: optional ``attn_mask`` forwarded to ``nn.MultiheadAttention``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``nn.MultiheadAttention``.

        Returns:
            ``(output, attn_weights)``. ``output`` has shape ``(T, 1, d_final)``.
            ``attn_weights`` is the weight tensor returned by ``nn.MultiheadAttention``;
            with default settings that is ``(1, T, T)``.
        """
        n_frames = src.shape[0]
        frames = src.view(n_frames, self.conv1.in_channels, 64, 64)

        # CNN: (T, 1, 64, 64) -> (T, 3, 31, 31) -> (T, 4, 14, 14) -> (T, 4*14*14)
        x = self.pool(self.activation(self.conv1(frames)))
        x = self.pool(self.activation(self.conv2(x)))
        x = x.view(-1, self.size_fc1)
        x = self.fc1(x)  # (T, d_model); linear, no activation applied

        # The T frames form one sequence with batch size 1: (T, 1, d_model).
        x = x.unsqueeze(1)
        x = self.pos_encoder(x)

        # Self-attention sub-block: residual add, then LayerNorm.
        attn_out, attn_weights = self.self_attn(
            x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward sub-block: residual add, then LayerNorm.
        ff_out = self.linear2(self.dropout(self.activation_1(self.linear1(x))))
        x = self.norm2(x + self.dropout2(ff_out))

        return self.linear3(x), attn_weights


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The table follows Vaswani et al. (2017). Even feature indices hold
    ``sin(pos / 10000**(2i / d_model))`` and odd indices hold the matching cosine.
    The table is stored as the buffer ``pe`` with shape ``(max_len, 1, d_model)``,
    so it broadcasts over the second (batch) dimension of ``(T, B, d_model)`` inputs.

    Args:
        d_model: feature size of the inputs; odd sizes are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence (first input dimension) that can be encoded.
        scale: multiplier applied to the encoding before it is added.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, scale=1.0):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than angles;
        # slicing avoids the shape-mismatch error the unsliced assignment raised.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)
        self.scale = scale

    def forward(self, x):
        """Return ``dropout(x + scale * pe[:T])`` for a sequence-first ``x`` of length ``T``.

        Raises:
            ValueError: if ``T = x.size(0)`` exceeds ``max_len``.
        """
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {max_len}"
            )
        x = x + self.scale * self.pe[:seq_len]
        return self.dropout(x)


# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_V(nn.Module):
    r"""CNN frame encoder + one self-attention block + an MLP head (single-channel frames).

    Every time step is a single-channel 64x64 image. It is encoded by
    (5x5 conv -> ReLU -> 2x2 max-pool) twice, followed by a linear layer with ReLU.
    The length-T sequence (batch size 1) gets a positional encoding and one
    self-attention block with a residual connection. No LayerNorm is applied;
    ``norm1`` is created but not used. A 3-layer MLP
    (d_model -> d_middle -> d_middle/2 -> d_final) then produces the outputs.
    The attention part is adapted from the encoder layer of
    "Attention Is All You Need" (Vaswani et al., 2017).

    Args:
        d_model: width of the per-frame embedding and of the attention block
            (``nn.MultiheadAttention`` requires ``nhead`` to divide it).
        nhead: number of attention heads.
        d_middle: hidden width of the MLP head.
        d_final: number of output features per time step.
        dropout: dropout probability for attention and the attention residual path.
        max_len: longest sequence the positional encoding supports.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: name passed to ``choose_nonlinearity``; used inside the MLP head.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="elu"):
        super(TransformerEncoderLayer_V, self).__init__()

        # Preferred device, kept as a public attribute. No layer is moved here:
        # the caller moves the whole module with ``.to(device)``. (Previously only
        # fc1 was moved, which could leave fc1 and the convs on different devices.)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # --- CNN frame encoder (submodules kept in the original creation order) ---
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=5)
        # 64 -> (64 - 5 + 1) // 2 = 30 -> (30 - 5 + 1) // 2 = 13; conv2 has 5 channels.
        self.size_fc1 = int(((64 - 5 + 1) / 2 - 5 + 1) / 2) ** 2 * 5
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # --- Sequence model ---
        # The positional-encoding dropout is fixed at 0.1, independent of ``dropout``.
        self.pos_encoder = PositionalEncoding_V(d_model, dropout=0.1, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # MLP head. ``self.dropout`` is created but not used in forward().
        self.linear1 = nn.Linear(d_model, d_middle)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_middle, int(d_middle / 2))
        self.linear3 = nn.Linear(int(d_middle / 2), d_final)
        # ``norm1`` is created (so it appears in state_dict) but not used in forward().
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and return per-step outputs.

        Args:
            src: frames of one sequence. It must hold ``T * 64 * 64`` elements
                and is reshaped to ``(T, 1, 64, 64)`` with ``T = src.shape[0]``.
                NOTE(review): presumably passed as ``(T, 1, 64*64)``; confirm with callers.
            src_mask: optional ``attn_mask`` forwarded to ``nn.MultiheadAttention``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``nn.MultiheadAttention``.

        Returns:
            ``(output, attn_weights)``. ``output`` has shape ``(T, 1, d_final)``.
            ``attn_weights`` is the weight tensor returned by ``nn.MultiheadAttention``.
        """
        frames = src.view(src.shape[0], 1, 64, 64)

        # CNN: (T, 1, 64, 64) -> (T, 3, 30, 30) -> (T, 5, 13, 13) -> (T, 5*13*13)
        x = self.pool(F.relu(self.conv1(frames)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self.size_fc1)
        x = F.relu(self.fc1(x))  # (T, d_model)

        # The T frames form one sequence with batch size 1: (T, 1, d_model).
        x = x.unsqueeze(1)
        x = self.pos_encoder(x)

        attn_out, attn_weights = self.self_attn(
            x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = x + self.dropout1(attn_out)

        # MLP head. The debug prints of the first two outputs were removed: they
        # flooded stdout every call and raised IndexError when T == 1.
        hidden = self.linear2(self.activation(self.linear1(x)))
        out = self.linear3(self.activation(hidden))
        return out, attn_weights

class PositionalEncoding_V(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    This class mirrors ``PositionalEncoding`` and is used by ``TransformerEncoderLayer_V``.
    Even feature indices hold ``sin(pos / 10000**(2i / d_model))`` and odd indices
    hold the matching cosine. The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``, so it broadcasts over the batch dimension of
    ``(T, B, d_model)`` inputs.

    Args:
        d_model: feature size of the inputs; odd sizes are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence (first input dimension) that can be encoded.
        scale: multiplier applied to the encoding before it is added.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, scale=1.0):
        super(PositionalEncoding_V, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than angles;
        # slicing avoids the shape-mismatch error the unsliced assignment raised.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)
        self.scale = scale

    def forward(self, x):
        """Return ``dropout(x + scale * pe[:T])`` for a sequence-first ``x`` of length ``T``.

        Raises:
            ValueError: if ``T = x.size(0)`` exceeds ``max_len``.
        """
        # NOTE(review): an earlier comment here read "max 0.1473"; its meaning is unclear.
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {max_len}"
            )
        x = x + self.scale * self.pe[:seq_len]
        return self.dropout(x)

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayerCategoricals(nn.Module):
    r"""CNN frame encoder + one self-attention block; velocity as a soft mix of fixed classes.

    Every time step is a single-channel 64x64 image. It is encoded by
    (3x3 conv -> tanh -> 2x2 avg-pool) and (4x4 conv -> tanh -> 2x2 avg-pool),
    followed by a linear layer. The length-T sequence (batch size 1) gets a
    positional encoding and one self-attention block (residual + LayerNorm).
    A 2-layer head then produces ``d_final`` logits per step, and their softmax
    weights the fixed class velocities in ``class_vel`` (``[-7, 7]``).

    Args:
        d_model: width of the per-frame embedding and of the attention block
            (``nn.MultiheadAttention`` requires ``nhead`` to divide it).
        nhead: number of attention heads.
        d_middle: unused; kept for signature compatibility with sibling classes.
        d_final: number of classes. forward() only works when this equals
            ``len(class_vel)`` (2), so the default of 3 makes forward() fail.
        dropout: dropout probability for attention and the attention residual path.
        max_len: longest sequence the positional encoding supports.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: name passed to ``choose_nonlinearity``; used between the two
            head layers (the convs always use tanh).
        activation_1: unused; kept for signature compatibility.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu"):
        super(TransformerEncoderLayerCategoricals, self).__init__()

        # Preferred device, kept as a public attribute. No layer is moved here:
        # the caller moves the whole module with ``.to(device)``. (Previously only
        # fc1 was moved, which could leave fc1 and the convs on different devices.)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # --- CNN frame encoder (submodules kept in the original creation order) ---
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.AvgPool2d(2, 2)
        # 64 -> (64 - 3 + 1) // 2 = 31 -> (31 - 4 + 1) // 2 = 14; conv2 has 4 channels.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # --- Sequence model and classification head ---
        self.pos_encoder = PositionalEncoding(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_final)
        # ``self.dropout`` is created but not used in forward().
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)

        # Softmax over the class (last) dimension of the (T, 1, d_final) logits.
        self.softmax = nn.Softmax(dim=2)

        # Velocity associated with each class. This is a plain attribute, not a
        # buffer, so module.to() does not move it; forward() moves it as needed.
        self.class_vel = torch.tensor([-7,7]).float().to(self.device)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Estimate one velocity per frame.

        Args:
            src: frames of one sequence. It must hold ``T * 64 * 64`` elements
                and is reshaped to ``(T, 1, 64, 64)`` with ``T = src.shape[0]``.
                NOTE(review): presumably passed as ``(T, 1, 64*64)``; confirm with callers.
            src_mask: optional ``attn_mask`` forwarded to ``nn.MultiheadAttention``.
            src_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``nn.MultiheadAttention``.

        Returns:
            ``(est_vel, attn_weights)``. ``est_vel`` has shape ``(T, 1)`` and holds
            the expected class velocity per step. ``attn_weights`` is the weight
            tensor returned by ``nn.MultiheadAttention``.
        """
        frames = src.view(src.shape[0], 1, 64, 64)

        # CNN: (T, 1, 64, 64) -> (T, 3, 31, 31) -> (T, 4, 14, 14) -> (T, 4*14*14)
        x = self.pool(torch.tanh(self.conv1(frames)))
        x = self.pool(torch.tanh(self.conv2(x)))
        x = x.view(-1, self.size_fc1)
        x = self.fc1(x)  # (T, d_model)

        # The T frames form one sequence with batch size 1: (T, 1, d_model).
        x = x.unsqueeze(1)
        x = self.pos_encoder(x)

        attn_out, attn_weights = self.self_attn(
            x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_out))

        logits = self.linear2(self.activation(self.linear1(x)))  # (T, 1, d_final)
        probs = self.softmax(logits)

        # Move/cast the class velocities to the probabilities' device and dtype, so
        # the model still works after .to(device) or .double(); previously matmul raised.
        class_vel = self.class_vel.to(device=probs.device, dtype=probs.dtype)
        est_vel = torch.matmul(probs, class_vel)  # (T, 1)

        return est_vel, attn_weights

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayerCatPos(nn.Module):
    r"""CNN frame encoder followed by one Transformer encoder layer with a
    concatenated (not added) sinusoidal positional encoding.

    Each time step is a 64x64 single-channel image. Two conv + pool stages and
    a linear layer (``fc1``) map every frame to a ``d_model`` embedding; a
    positional encoding of width ``d_model`` is concatenated, so attention and
    the feed-forward block operate on width ``2 * d_model``. The encoder layer
    is post-norm (residual + LayerNorm after both the self-attention and the
    feed-forward sub-layers), and ``linear3`` maps every step to ``d_final``
    outputs.

    Adapted from ``nn.TransformerEncoderLayer`` (Vaswani et al., 2017,
    "Attention Is All You Need").

    Args:
        d_model: width of the per-frame embedding; the attention width is
            ``2 * d_model``.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; kept for signature compatibility.
        d_final: number of outputs per time step.
        dropout: dropout probability for attention and residual branches.
        max_len: maximum sequence length supported by the positional encoding.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: non-linearity name (see ``choose_nonlinearity``); stored as
            ``self.activation`` but not used by ``forward``.
        activation_1: non-linearity name used inside the feed-forward block.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu"):
        super(TransformerEncoderLayerCatPos, self).__init__()

        # Device on which fc1 and class_vel are explicitly created.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that reduces each (1, 64, 64) frame to a small feature map.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool  = nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pooling:
        # 64 -> 62 -> 31, then 31 -> 28 -> 14, i.e. 4 * 14 * 14 = 784 features.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model).to(self.device)

        # Positional encoding of width d_model, concatenated to the embedding.
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the concatenated width 2 * d_model.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Feed-forward sub-layer and output head.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Residual-branch normalisation / dropout. norm2, dropout2 and
        # activation_1 are required by forward() and must be defined here.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation   = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

        # Not used by forward(); kept so code that references them still works.
        self.softmax = nn.Softmax(dim=2)
        self.class_vel = torch.tensor([-7,7]).float().to(self.device)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and return per-step outputs.

        Args:
            src: tensor with ``T * 64 * 64`` elements, viewed as ``T`` frames of
                shape (1, 64, 64) (e.g. ``[T, 1, 64*64]``).
            src_mask: optional attention mask forwarded to ``self_attn``.
            src_key_padding_mask: optional key padding mask forwarded to
                ``self_attn``.

        Returns:
            ``(src_final, attn_weights)``: ``src_final`` has shape
            ``[T, 1, d_final]``; ``attn_weights`` is the weight output of
            ``nn.MultiheadAttention``.
        """
        # [T, 1, 64*64] -> [T, 1, 64, 64]; the singleton axis becomes the channel.
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(F.relu(self.conv1(src)))   # [T, 3, 31, 31]
        src = self.pool(F.relu(self.conv2(src)))   # [T, 4, 14, 14]
        src = src.view(-1, self.size_fc1)          # [T, 784]
        src = self.fc1(src)                        # [T, d_model]

        # The T frames form one sequence of length T with batch size 1.
        src = src.unsqueeze(1)                     # [T, 1, d_model]
        src = self.pos_encoder(src)                # [T, 1, 2*d_model]

        # Self-attention sub-layer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed-forward sub-layer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)              # [T, 1, d_final]

        return src_final, src2_attn_weight

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayerCategoricalsCatPos(nn.Module):
    r"""CNN frame encoder + one Transformer encoder layer (concatenated
    positional encoding) that predicts a velocity as a probability-weighted
    combination of fixed class values.

    Each frame (a 64x64 single-channel image) is reduced by two conv + pool
    stages and ``fc1`` to a ``d_model`` embedding; a positional encoding of
    width ``d_model`` is concatenated (attention width ``2 * d_model``). A
    post-norm encoder layer follows, then ``linear3`` produces ``d_final``
    logits per step. A softmax over those logits weights ``class_vel``
    (``[-7, 7]``) to give the estimated velocity.

    NOTE: because ``class_vel`` has 2 entries, ``forward`` only works with
    ``d_final == 2``; the default ``d_final=3`` makes the final ``matmul`` fail.

    Adapted from ``nn.TransformerEncoderLayer`` (Vaswani et al., 2017).

    Args:
        d_model: width of the per-frame embedding; attention width is
            ``2 * d_model``.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; kept for signature compatibility.
        d_final: number of class logits per step (must be 2, see note).
        dropout: dropout probability for attention and residual branches.
        max_len: maximum sequence length supported by the positional encoding.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: non-linearity name; stored as ``self.activation`` but not
            used by ``forward``.
        activation_1: non-linearity name used inside the feed-forward block.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu"):
        super(TransformerEncoderLayerCategoricalsCatPos, self).__init__()

        # Device on which fc1 and class_vel are explicitly created.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that reduces each (1, 64, 64) frame to a small feature map.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool  = nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pooling:
        # 64 -> 62 -> 31, then 31 -> 28 -> 14, i.e. 4 * 14 * 14 = 784 features.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model).to(self.device)

        # Positional encoding of width d_model, concatenated to the embedding.
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the concatenated width 2 * d_model.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Feed-forward sub-layer and logit head.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Residual-branch normalisation / dropout. norm2, dropout2 and
        # activation_1 are required by forward() and must be defined here.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation   = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

        # Softmax over the logit axis of [T, 1, d_final].
        self.softmax = nn.Softmax(dim=2)

        # Velocity value associated with each class.
        self.class_vel = torch.tensor([-7,7]).float().to(self.device)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and estimate a velocity per step.

        Args:
            src: tensor with ``T * 64 * 64`` elements, viewed as ``T`` frames of
                shape (1, 64, 64) (e.g. ``[T, 1, 64*64]``).
            src_mask: optional attention mask forwarded to ``self_attn``.
            src_key_padding_mask: optional key padding mask forwarded to
                ``self_attn``.

        Returns:
            ``(est_vel, attn_weights)``: ``est_vel`` has shape ``[T, 1]``
            (requires ``d_final == 2``); ``attn_weights`` is the weight output
            of ``nn.MultiheadAttention``.
        """
        # [T, 1, 64*64] -> [T, 1, 64, 64]; the singleton axis becomes the channel.
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(F.relu(self.conv1(src)))   # [T, 3, 31, 31]
        src = self.pool(F.relu(self.conv2(src)))   # [T, 4, 14, 14]
        src = src.view(-1, self.size_fc1)          # [T, 784]
        src = self.fc1(src)                        # [T, d_model]

        # The T frames form one sequence of length T with batch size 1.
        src = src.unsqueeze(1)                     # [T, 1, d_model]
        src = self.pos_encoder(src)                # [T, 1, 2*d_model]

        # Self-attention sub-layer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed-forward sub-layer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_logit = self.linear3(src)              # [T, 1, d_final]
        src_prob = self.softmax(src_logit)

        # Expected velocity under the predicted class distribution.
        est_vel = torch.matmul(src_prob, self.class_vel)

        return est_vel, src2_attn_weight

class PositionalEncodingCat(nn.Module):
    """Sinusoidal positional encoding that is concatenated to the input.

    Unlike the additive encoding of "Attention Is All You Need", the table is
    appended along the feature axis, so an input of shape ``[T, B, F]`` becomes
    ``[T, B, F + d_model]``.

    Args:
        d_model: width of the positional-encoding table (odd widths allowed).
        dropout: dropout probability applied to the concatenated output.
        max_len: maximum sequence length ``T`` supported.
        scale: multiplier applied to the encoding before concatenation.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, scale=1.0):
        super(PositionalEncodingCat, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns hold sin, odd columns hold cos. For odd d_model there is
        # one fewer odd column than frequencies, so the cos half uses only the
        # first d_model // 2 frequencies (otherwise the assignment fails).
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # [max_len, d_model] -> [max_len, 1, d_model] (sequence-first layout).
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        self.scale = scale

    def forward(self, x):
        """Concatenate the scaled positional encoding to ``x``.

        Args:
            x: tensor of shape ``[T, B, F]`` with ``T <= max_len``.

        Returns:
            Tensor of shape ``[T, B, F + d_model]`` after dropout.

        Raises:
            ValueError: if ``T`` exceeds ``max_len``.
        """
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {max_len}"
            )
        # Broadcast the [T, 1, d_model] table over the batch axis so B > 1 works.
        pos = self.scale * self.pe[:seq_len].expand(-1, x.size(1), -1)
        x = torch.cat((x, pos), dim=2)

        return self.dropout(x)


# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_CategoricalsCatPos(nn.Module):
    r"""CNN frame encoder with a discrete Gumbel-softmax bottleneck, followed by
    one Transformer encoder layer with a concatenated positional encoding.

    Each frame (a 64x64 single-channel image) goes through two conv + pool
    stages. Ten linear heads (``fc1`` ... ``fc10``) each produce 10 logits that
    are turned into a hard one-hot sample with
    ``gumbel_softmax(tau=0.5, hard=True)``; the one-hot vectors are concatenated
    into a 100-dim code. A positional encoding of width 100 is concatenated
    (attention width 200), then a post-norm encoder layer and ``linear3``
    produce ``d_final`` outputs per step.

    Adapted from ``nn.TransformerEncoderLayer`` (Vaswani et al., 2017).

    Args:
        d_model: ignored -- the code width is fixed at 10 heads x 10 categories
            = 100; kept for signature compatibility.
        nhead: number of attention heads; must divide 200.
        d_middle: unused; kept for signature compatibility.
        d_final: number of outputs per time step.
        dropout: dropout probability for attention and residual branches.
        max_len: maximum sequence length supported by the positional encoding.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: non-linearity name applied after each convolution.
        activation_1: non-linearity name used inside the feed-forward block.
        pooling: ``'max'`` for max-pooling, anything else for average pooling.
    """

    # Number of categorical heads (registered as fc1 ... fc10) and categories
    # per head. Class-level so previously pickled instances still work.
    _N_CAT_HEADS = 10
    _N_CATEGORIES = 10

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu",pooling='max'):
        super(TransformerEncoderLayer_v2_CategoricalsCatPos, self).__init__()

        # Device on which the categorical heads are explicitly created.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that reduces each (1, 64, 64) frame to a small feature map.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        if pooling == 'max':
            self.pool  = nn.MaxPool2d(2, 2)
        else:
            self.pool  = nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pooling:
        # 64 -> 62 -> 31, then 31 -> 28 -> 14, i.e. 4 * 14 * 14 = 784 features.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4

        # Categorical heads fc1 ... fc10, created and registered in that order
        # (same parameter names and initialisation order as before).
        for i in range(1, self._N_CAT_HEADS + 1):
            setattr(self, f'fc{i}', nn.Linear(self.size_fc1, self._N_CATEGORIES).to(self.device))
        # The concatenated one-hot code determines the model width, overriding
        # the d_model argument.
        d_model = self._N_CAT_HEADS * self._N_CATEGORIES

        # Positional encoding of width d_model, concatenated to the code.
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the concatenated width 2 * d_model.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Feed-forward sub-layer and output head.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Residual-branch normalisation / dropout.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation   = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames through the categorical bottleneck.

        Args:
            src: tensor with ``T * 64 * 64`` elements, viewed as ``T`` frames of
                shape (1, 64, 64) (e.g. ``[T, 1, 64*64]``).
            src_mask: optional attention mask forwarded to ``self_attn``.
            src_key_padding_mask: optional key padding mask forwarded to
                ``self_attn``.

        Returns:
            ``(src_final, attn_weights)``: ``src_final`` has shape
            ``[T, 1, d_final]``; ``attn_weights`` is the weight output of
            ``nn.MultiheadAttention``.
        """
        # [T, 1, 64*64] -> [T, 1, 64, 64]; the singleton axis becomes the channel.
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))   # [T, 3, 31, 31]
        src = self.pool(self.activation(self.conv2(src)))   # [T, 4, 14, 14]
        src = src.view(-1, self.size_fc1)                   # [T, 784]

        # One hard one-hot sample of shape [T, 10] per head, sampled in order
        # fc1 ... fc10. With hard=True the forward pass is discrete, while
        # gradients use the soft Gumbel-softmax sample.
        one_hots = [
            F.gumbel_softmax(getattr(self, f'fc{i}')(src), tau=0.5, hard=True)
            for i in range(1, self._N_CAT_HEADS + 1)
        ]
        src = torch.cat(one_hots, dim=-1)                   # [T, 100]

        # The T frames form one sequence of length T with batch size 1.
        src = src.unsqueeze(1)                              # [T, 1, 100]
        src = self.pos_encoder(src)                         # [T, 1, 200]

        # Self-attention sub-layer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed-forward sub-layer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)                       # [T, 1, d_final]

        return src_final, src2_attn_weight


# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_CatPos_MSD(nn.Module):
    r"""Three-stage CNN encoder for two-channel frames, followed by one
    Transformer encoder layer with a concatenated positional encoding.

    Each time step is a (2, 64, 64) image. Three conv + pool stages reduce it to
    11 channels of 5x5 (275 features), and ``fc1`` maps that to ``d_model``. A
    positional encoding of width ``d_model`` is concatenated (attention width
    ``2 * d_model``), then a post-norm encoder layer and ``linear3`` produce
    ``d_final`` outputs per step.

    Adapted from ``nn.TransformerEncoderLayer`` (Vaswani et al., 2017).

    Args:
        d_model: width of the per-frame embedding; attention width is
            ``2 * d_model``.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; kept for signature compatibility.
        d_final: number of outputs per time step.
        dropout: dropout probability for the positional encoding, attention and
            residual branches.
        max_len: maximum sequence length supported by the positional encoding.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: non-linearity name applied after each convolution.
        activation_1: non-linearity name used inside the feed-forward block.
        pooling: ``'max'`` for max-pooling, anything else for average pooling.
    """
    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu",pooling='max'):
        super(TransformerEncoderLayer_v2_CatPos_MSD, self).__init__()

        # Device on which fc1 is explicitly created.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that reduces each (2, 64, 64) frame to a small feature map.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=5, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=4)
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=11, kernel_size=5)
        if pooling == 'max':
            self.pool  = nn.MaxPool2d(2, 2)
        else:
            self.pool  = nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pooling:
        # 64 -> 62 -> 31, 31 -> 28 -> 14, 14 -> 10 -> 5, i.e. 11 * 5 * 5 = 275.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        size_after_conv3_pool = (size_after_conv2_pool - 5 + 1) // 2
        self.size_fc1 = (size_after_conv3_pool ** 2) * 11
        self.fc1 = nn.Linear(self.size_fc1, d_model).to(self.device)

        # Positional encoding of width d_model, concatenated to the embedding.
        # Unlike the sibling classes, dropout is applied here as well.
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=dropout, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the concatenated width 2 * d_model.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Feed-forward sub-layer and output head.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Residual-branch normalisation / dropout.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation   = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of two-channel frames and return per-step outputs.

        Args:
            src: tensor with ``T * 2 * 64 * 64`` elements, viewed as ``T``
                frames of shape (2, 64, 64).
            src_mask: optional attention mask forwarded to ``self_attn``.
            src_key_padding_mask: optional key padding mask forwarded to
                ``self_attn``.

        Returns:
            ``(src_final, attn_weights)``: ``src_final`` has shape
            ``[T, 1, d_final]``; ``attn_weights`` is the weight output of
            ``nn.MultiheadAttention``.
        """
        src = src.view(src.shape[0], 2, 64, 64)             # [T, 2, 64, 64]
        src = self.pool(self.activation(self.conv1(src)))   # [T, 5, 31, 31]
        src = self.pool(self.activation(self.conv2(src)))   # [T, 8, 14, 14]
        src = self.pool(self.activation(self.conv3(src)))   # [T, 11, 5, 5]
        src = src.view(-1, self.size_fc1)                   # [T, 275]
        src = self.fc1(src)                                 # [T, d_model]

        # The T frames form one sequence of length T with batch size 1.
        src = src.unsqueeze(1)                              # [T, 1, d_model]
        src = self.pos_encoder(src)                         # [T, 1, 2*d_model]

        # Self-attention sub-layer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed-forward sub-layer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)                       # [T, 1, d_final]

        return src_final, src2_attn_weight


# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_CatPos_TB(nn.Module):
    r"""Three-stage CNN frame encoder followed by one Transformer encoder layer
    with a concatenated positional encoding of independent width ``pos_size``.

    Each frame (a 64x64 single-channel image) is reduced by three conv + pool
    stages to 5 channels of 5x5 (125 features), and ``fc1`` maps that to
    ``d_model``. A positional encoding of width ``pos_size`` is concatenated,
    so attention and the feed-forward block operate on width
    ``d_model + pos_size``. A post-norm encoder layer and ``linear3`` then
    produce ``d_final`` outputs per step.

    Adapted from ``nn.TransformerEncoderLayer`` (Vaswani et al., 2017).

    Args:
        pos_size: width of the concatenated positional encoding.
        d_model: width of the per-frame embedding.
        nhead: number of attention heads; must divide ``d_model + pos_size``.
        d_middle: unused; kept for signature compatibility.
        d_final: number of outputs per time step.
        dropout: dropout probability for attention and residual branches.
        max_len: maximum sequence length supported by the positional encoding.
        pos_en_scale: multiplier applied to the positional encoding.
        activation: non-linearity name applied after each convolution.
        activation_1: non-linearity name used inside the feed-forward block.
        pooling: ``'max'`` for max-pooling, anything else for average pooling.
    """
    def __init__(self, pos_size=20,d_model=200, nhead=4, d_middle=2048, d_final=3, dropout=0.1, max_len=101, pos_en_scale=1.0, activation="tanh",activation_1="elu",pooling='max'):
        super(TransformerEncoderLayer_v2_CatPos_TB, self).__init__()

        # Device on which fc1 is explicitly created.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN that reduces each (1, 64, 64) frame to a small feature map.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.conv3 = nn.Conv2d(in_channels=4, out_channels=5, kernel_size=5)
        if pooling == 'max':
            self.pool  = nn.MaxPool2d(2, 2)
        else:
            self.pool  = nn.AvgPool2d(2, 2)
        # Spatial size after each valid conv + 2x2 pooling:
        # 64 -> 62 -> 31, 31 -> 28 -> 14, 14 -> 10 -> 5, i.e. 5 * 5 * 5 = 125.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        size_after_conv3_pool = (size_after_conv2_pool - 5 + 1) // 2
        self.size_fc1 = (size_after_conv3_pool ** 2) * 5
        print('size of embedding before mlp:',self.size_fc1)
        self.fc1 = nn.Linear(self.size_fc1, d_model).to(self.device)

        # Positional encoding of width pos_size, concatenated to the embedding.
        self.pos_encoder = PositionalEncodingCat(pos_size, dropout=0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the concatenated width d_model + pos_size.
        width = d_model + pos_size
        self.self_attn = nn.MultiheadAttention(width, nhead, dropout=dropout)
        # Feed-forward sub-layer and output head.
        self.linear1 = nn.Linear(width, width)
        self.linear2 = nn.Linear(width, width)
        self.linear3 = nn.Linear(width, d_final)
        self.dropout = nn.Dropout(dropout)
        # Residual-branch normalisation / dropout.
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation   = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames and return per-step outputs.

        Args:
            src: tensor with ``T * 64 * 64`` elements, viewed as ``T`` frames of
                shape (1, 64, 64) (e.g. ``[T, 1, 64*64]``).
            src_mask: optional attention mask forwarded to ``self_attn``.
            src_key_padding_mask: optional key padding mask forwarded to
                ``self_attn``.

        Returns:
            ``(src_final, attn_weights)``: ``src_final`` has shape
            ``[T, 1, d_final]``; ``attn_weights`` is the weight output of
            ``nn.MultiheadAttention``.
        """
        # [T, 1, 64*64] -> [T, 1, 64, 64]; the singleton axis becomes the channel.
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))   # [T, 3, 31, 31]
        src = self.pool(self.activation(self.conv2(src)))   # [T, 4, 14, 14]
        src = self.pool(self.activation(self.conv3(src)))   # [T, 5, 5, 5]
        src = src.view(-1, self.size_fc1)                   # [T, 125]
        src = self.fc1(src)                                 # [T, d_model]

        # The T frames form one sequence of length T with batch size 1.
        src = src.unsqueeze(1)                              # [T, 1, d_model]
        src = self.pos_encoder(src)                         # [T, 1, d_model + pos_size]

        # Self-attention sub-layer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed-forward sub-layer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)                       # [T, 1, d_final]

        return src_final, src2_attn_weight

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_CatPos(nn.Module):
    r"""Self-attention encoder layer over a sequence of 64x64 single-channel frames.

    Adapted from PyTorch's ``nn.TransformerEncoderLayer`` (Vaswani et al.,
    "Attention Is All You Need", NeurIPS 2017). A forward pass:

    1. Views each time step as a 1x64x64 image and compresses it with two
       (conv -> ``activation`` -> 2x2 pool) stages plus a linear layer to
       ``d_model`` features.
    2. Applies ``PositionalEncodingCat``. Every layer after it is sized for
       ``2 * d_model`` features, so the encoding is presumably concatenated
       onto the frame embedding (TODO confirm against ``PositionalEncodingCat``).
    3. Runs a post-norm encoder block: self-attention + residual + LayerNorm,
       then a two-layer feedforward (``activation_1``) + residual + LayerNorm.
    4. Projects linearly to ``d_final`` features per time step.

    Args:
        d_model: width of the per-frame embedding produced by the CNN front-end.
            Attention and feedforward layers operate on ``2 * d_model`` features.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; accepted for interface compatibility. The feedforward
            hidden width is ``2 * d_model``.
        d_final: number of output features per time step.
        dropout: dropout probability for the attention weights, the residual
            branches and the feedforward hidden layer.
        max_len: forwarded to ``PositionalEncodingCat``.
        pos_en_scale: forwarded to ``PositionalEncodingCat`` as ``scale``.
        activation: name (see ``choose_nonlinearity``) of the nonlinearity
            applied after each convolution.
        activation_1: name of the nonlinearity inside the feedforward sublayer.
        pooling: ``'max'`` selects 2x2 max-pooling; any other value selects
            2x2 average pooling.

    Device placement: all submodules are created on the default (CPU) device.
    Move the whole module with ``.to(device)``. ``self.device`` is kept only so
    existing callers that read it keep working.
    """

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3,
                 dropout=0.1, max_len=101, pos_en_scale=1.0,
                 activation="tanh", activation_1="elu", pooling='max'):
        super(TransformerEncoderLayer_v2_CatPos, self).__init__()

        # Preferred device for callers. Individual layers are deliberately NOT
        # moved here: moving only some of them would leave the module split
        # across devices until the caller moves the whole thing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN front-end: each time step is a 1-channel 64x64 image.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)

        # Spatial size after each (valid conv -> 2x2 pool) stage:
        # 64 -> (64-3+1)//2 = 31 -> (31-4+1)//2 = 14, i.e. 4 * 14 * 14 = 784 features.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # Positional encoding (no dropout of its own).
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the 2*d_model-wide sequence.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Position-wise feedforward and output projection.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-norm residual regularization.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of frames.

        Args:
            src: tensor whose leading dimension is the sequence length ``T`` and
                which holds 64*64 values per time step (e.g. ``[T, 1, 64*64]``).
            src_mask: optional mask passed to ``self_attn`` as ``attn_mask``.
            src_key_padding_mask: optional mask passed as ``key_padding_mask``.

        Returns:
            ``(output, attn_weights)``. ``output`` is ``[T, 1, d_final]``,
            assuming ``pos_encoder`` preserves the ``[T, 1]`` leading dims.
            ``attn_weights`` are the weights returned by ``self_attn``.
        """
        # Each time step becomes a 1-channel image: [T, 1, 64, 64].
        src = src.view(src.shape[0], 1, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))   # [T, 3, 31, 31]
        src = self.pool(self.activation(self.conv2(src)))   # [T, 4, 14, 14]
        src = src.view(-1, self.size_fc1)                   # [T, 784]
        src = self.fc1(src)                                 # [T, d_model]

        # nn.MultiheadAttention (batch_first=False) expects [seq, batch, feat];
        # the whole sequence is treated as a single batch element.
        src = src.unsqueeze(1)                              # [T, 1, d_model]
        src = self.pos_encoder(src)

        # Self-attention sublayer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Position-wise feedforward sublayer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)
        return src_final, src2_attn_weight


# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v3_CatPos_2C(nn.Module):
    r"""Self-attention encoder layer with a fully-connected front-end for 2x64x64 frames.

    Adapted from PyTorch's ``nn.TransformerEncoderLayer`` (Vaswani et al.,
    "Attention Is All You Need", NeurIPS 2017). A forward pass:

    1. Flattens each time step to 2*64*64 = 8192 values and maps them through
       an MLP 8192 -> 1000 -> 100 -> ``d_model``, with ``activation`` after the
       first two layers.
    2. Applies ``PositionalEncodingCat``. Every layer after it is sized for
       ``2 * d_model`` features, so the encoding is presumably concatenated
       onto the frame embedding (TODO confirm against ``PositionalEncodingCat``).
    3. Runs a post-norm encoder block: self-attention + residual + LayerNorm,
       then a two-layer feedforward (``activation_1``) + residual + LayerNorm.
    4. Projects linearly to ``d_final`` features per time step.

    Args:
        d_model: width of the per-frame embedding produced by the MLP front-end.
            Attention and feedforward layers operate on ``2 * d_model`` features.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; accepted for interface compatibility.
        d_final: number of output features per time step.
        dropout: dropout probability for the attention weights, the residual
            branches and the feedforward hidden layer.
        max_len: forwarded to ``PositionalEncodingCat``.
        pos_en_scale: forwarded to ``PositionalEncodingCat`` as ``scale``.
        activation: name (see ``choose_nonlinearity``) of the nonlinearity used
            in the MLP front-end.
        activation_1: name of the nonlinearity inside the feedforward sublayer.
        pooling: unused (this variant has no pooling); accepted for interface
            compatibility with the CNN variants.

    Device placement: all submodules are created on the default (CPU) device.
    Move the whole module with ``.to(device)``. ``self.device`` is kept only so
    existing callers that read it keep working.
    """

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3,
                 dropout=0.1, max_len=101, pos_en_scale=1.0,
                 activation="tanh", activation_1="elu", pooling='max'):
        super(TransformerEncoderLayer_v3_CatPos_2C, self).__init__()

        # Preferred device for callers. Individual layers are deliberately NOT
        # moved here: moving only the MLP would leave the module split across
        # devices until the caller moves the whole thing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # MLP front-end: 2*64*64 -> 1000 -> 100 -> d_model.
        self.fc1 = nn.Linear(64*64*2, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, d_model)

        # Positional encoding (no dropout of its own).
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the 2*d_model-wide sequence.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Position-wise feedforward and output projection.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-norm residual regularization.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of two-channel frames.

        Args:
            src: tensor whose leading dimension is the sequence length ``T`` and
                which holds 2*64*64 values per time step.
            src_mask: optional mask passed to ``self_attn`` as ``attn_mask``.
            src_key_padding_mask: optional mask passed as ``key_padding_mask``.

        Returns:
            ``(output, attn_weights)``. ``output`` is ``[T, 1, d_final]``,
            assuming ``pos_encoder`` preserves the ``[T, 1]`` leading dims.
            ``attn_weights`` are the weights returned by ``self_attn``.
        """
        # Flatten each time step: [T, 2*64*64].
        src = src.view(src.shape[0], 2*64*64)
        src = self.activation(self.fc1(src))   # [T, 1000]
        src = self.activation(self.fc2(src))   # [T, 100]
        src = self.fc3(src)                    # [T, d_model]

        # nn.MultiheadAttention (batch_first=False) expects [seq, batch, feat];
        # the whole sequence is treated as a single batch element.
        src = src.unsqueeze(1)                 # [T, 1, d_model]
        src = self.pos_encoder(src)

        # Self-attention sublayer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Position-wise feedforward sublayer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)
        return src_final, src2_attn_weight

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_CatPos_2C(nn.Module):
    r"""Self-attention encoder layer over a sequence of 2x64x64 frames (CNN front-end).

    Adapted from PyTorch's ``nn.TransformerEncoderLayer`` (Vaswani et al.,
    "Attention Is All You Need", NeurIPS 2017). A forward pass:

    1. Views each time step as a 2-channel 64x64 image and compresses it with
       three (conv -> ``activation`` -> 2x2 pool) stages (channels 2->5->8->11,
       kernels 3/4/5) plus a linear layer to ``d_model`` features.
    2. Applies ``PositionalEncodingCat``. Every layer after it is sized for
       ``2 * d_model`` features, so the encoding is presumably concatenated
       onto the frame embedding (TODO confirm against ``PositionalEncodingCat``).
    3. Runs a post-norm encoder block: self-attention + residual + LayerNorm,
       then a two-layer feedforward (``activation_1``) + residual + LayerNorm.
    4. Projects linearly to ``d_final`` features per time step.

    Args:
        d_model: width of the per-frame embedding produced by the CNN front-end.
            Attention and feedforward layers operate on ``2 * d_model`` features.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; accepted for interface compatibility.
        d_final: number of output features per time step.
        dropout: dropout probability for the attention weights, the residual
            branches and the feedforward hidden layer.
        max_len: forwarded to ``PositionalEncodingCat``.
        pos_en_scale: forwarded to ``PositionalEncodingCat`` as ``scale``.
        activation: name (see ``choose_nonlinearity``) of the nonlinearity
            applied after each convolution.
        activation_1: name of the nonlinearity inside the feedforward sublayer.
        pooling: ``'max'`` selects 2x2 max-pooling; any other value selects
            2x2 average pooling.

    Device placement: all submodules are created on the default (CPU) device.
    Move the whole module with ``.to(device)``. ``self.device`` is kept only so
    existing callers that read it keep working.
    """

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3,
                 dropout=0.1, max_len=101, pos_en_scale=1.0,
                 activation="tanh", activation_1="elu", pooling='max'):
        super(TransformerEncoderLayer_v2_CatPos_2C, self).__init__()

        # Preferred device for callers. Individual layers are deliberately NOT
        # moved here: moving only some of them would leave the module split
        # across devices until the caller moves the whole thing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN front-end: each time step is a 2-channel 64x64 image.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=5, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=4)
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=11, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)

        # Spatial size after each (valid conv -> 2x2 pool) stage:
        # 64 -> 31 -> 14 -> (14-5+1)//2 = 5, i.e. 11 * 5 * 5 = 275 features.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        size_after_conv3_pool = (size_after_conv2_pool - 5 + 1) // 2
        self.size_fc1 = (size_after_conv3_pool ** 2) * 11
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # Positional encoding (no dropout of its own).
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the 2*d_model-wide sequence.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Position-wise feedforward and output projection.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-norm residual regularization.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of two-channel frames.

        Args:
            src: tensor whose leading dimension is the sequence length ``T`` and
                which holds 2*64*64 values per time step.
            src_mask: optional mask passed to ``self_attn`` as ``attn_mask``.
            src_key_padding_mask: optional mask passed as ``key_padding_mask``.

        Returns:
            ``(output, attn_weights)``. ``output`` is ``[T, 1, d_final]``,
            assuming ``pos_encoder`` preserves the ``[T, 1]`` leading dims.
            ``attn_weights`` are the weights returned by ``self_attn``.
        """
        # Each time step becomes a 2-channel image: [T, 2, 64, 64].
        src = src.view(src.shape[0], 2, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))   # [T, 5, 31, 31]
        src = self.pool(self.activation(self.conv2(src)))   # [T, 8, 14, 14]
        src = self.pool(self.activation(self.conv3(src)))   # [T, 11, 5, 5]
        src = src.reshape(-1, self.size_fc1)                # [T, 275]
        src = self.fc1(src)                                 # [T, d_model]

        # nn.MultiheadAttention (batch_first=False) expects [seq, batch, feat];
        # the whole sequence is treated as a single batch element.
        src = src.unsqueeze(1)                              # [T, 1, d_model]
        src = self.pos_encoder(src)

        # Self-attention sublayer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Position-wise feedforward sublayer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)
        return src_final, src2_attn_weight

# https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class TransformerEncoderLayer_v2_CatPos_3C(nn.Module):
    r"""Self-attention encoder layer over a sequence of 3x64x64 frames (CNN front-end).

    Adapted from PyTorch's ``nn.TransformerEncoderLayer`` (Vaswani et al.,
    "Attention Is All You Need", NeurIPS 2017). A forward pass:

    1. Views each time step as a 3-channel 64x64 image and compresses it with
       two (conv -> ``activation`` -> 2x2 pool) stages (channels 3->3->4,
       kernels 3/4) plus a linear layer to ``d_model`` features.
    2. Applies ``PositionalEncodingCat``. Every layer after it is sized for
       ``2 * d_model`` features, so the encoding is presumably concatenated
       onto the frame embedding (TODO confirm against ``PositionalEncodingCat``).
    3. Runs a post-norm encoder block: self-attention + residual + LayerNorm,
       then a two-layer feedforward (``activation_1``) + residual + LayerNorm.
    4. Projects linearly to ``d_final`` features per time step.

    Args:
        d_model: width of the per-frame embedding produced by the CNN front-end.
            Attention and feedforward layers operate on ``2 * d_model`` features.
        nhead: number of attention heads; must divide ``2 * d_model``.
        d_middle: unused; accepted for interface compatibility.
        d_final: number of output features per time step.
        dropout: dropout probability for the attention weights, the residual
            branches and the feedforward hidden layer.
        max_len: forwarded to ``PositionalEncodingCat``.
        pos_en_scale: forwarded to ``PositionalEncodingCat`` as ``scale``.
        activation: name (see ``choose_nonlinearity``) of the nonlinearity
            applied after each convolution.
        activation_1: name of the nonlinearity inside the feedforward sublayer.
        pooling: ``'max'`` selects 2x2 max-pooling; any other value selects
            2x2 average pooling.

    Device placement: all submodules are created on the default (CPU) device.
    Move the whole module with ``.to(device)``. ``self.device`` is kept only so
    existing callers that read it keep working.
    """

    def __init__(self, d_model=200, nhead=4, d_middle=2048, d_final=3,
                 dropout=0.1, max_len=101, pos_en_scale=1.0,
                 activation="tanh", activation_1="elu", pooling='max'):
        super(TransformerEncoderLayer_v2_CatPos_3C, self).__init__()

        # Preferred device for callers. Individual layers are deliberately NOT
        # moved here: moving only some of them would leave the module split
        # across devices until the caller moves the whole thing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # CNN front-end: each time step is a 3-channel 64x64 image.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=4)
        self.pool = nn.MaxPool2d(2, 2) if pooling == 'max' else nn.AvgPool2d(2, 2)

        # Spatial size after each (valid conv -> 2x2 pool) stage:
        # 64 -> (64-3+1)//2 = 31 -> (31-4+1)//2 = 14, i.e. 4 * 14 * 14 = 784 features.
        size_after_conv1_pool = (64 - 3 + 1) // 2
        size_after_conv2_pool = (size_after_conv1_pool - 4 + 1) // 2
        self.size_fc1 = (size_after_conv2_pool ** 2) * 4
        self.fc1 = nn.Linear(self.size_fc1, d_model)

        # Positional encoding (no dropout of its own).
        self.pos_encoder = PositionalEncodingCat(d_model, dropout=0.0, max_len=max_len, scale=pos_en_scale)
        # Self-attention over the 2*d_model-wide sequence.
        self.self_attn = nn.MultiheadAttention(d_model*2, nhead, dropout=dropout)
        # Position-wise feedforward and output projection.
        self.linear1 = nn.Linear(d_model*2, d_model*2)
        self.linear2 = nn.Linear(d_model*2, d_model*2)
        self.linear3 = nn.Linear(d_model*2, d_final)
        self.dropout = nn.Dropout(dropout)
        # Post-norm residual regularization.
        self.norm1 = nn.LayerNorm(d_model*2)
        self.norm2 = nn.LayerNorm(d_model*2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = choose_nonlinearity(activation)
        self.activation_1 = choose_nonlinearity(activation_1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Encode a sequence of three-channel frames.

        Args:
            src: tensor whose leading dimension is the sequence length ``T`` and
                which holds 3*64*64 values per time step.
            src_mask: optional mask passed to ``self_attn`` as ``attn_mask``.
            src_key_padding_mask: optional mask passed as ``key_padding_mask``.

        Returns:
            ``(output, attn_weights)``. ``output`` is ``[T, 1, d_final]``,
            assuming ``pos_encoder`` preserves the ``[T, 1]`` leading dims.
            ``attn_weights`` are the weights returned by ``self_attn``.
        """
        # Each time step becomes a 3-channel image: [T, 3, 64, 64].
        src = src.view(src.shape[0], 3, 64, 64)
        src = self.pool(self.activation(self.conv1(src)))   # [T, 3, 31, 31]
        src = self.pool(self.activation(self.conv2(src)))   # [T, 4, 14, 14]
        src = src.view(-1, self.size_fc1)                   # [T, 784]
        src = self.fc1(src)                                 # [T, d_model]

        # nn.MultiheadAttention (batch_first=False) expects [seq, batch, feat];
        # the whole sequence is treated as a single batch element.
        src = src.unsqueeze(1)                              # [T, 1, d_model]
        src = self.pos_encoder(src)

        # Self-attention sublayer (post-norm residual).
        src2, src2_attn_weight = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Position-wise feedforward sublayer (post-norm residual).
        src2 = self.linear2(self.dropout(self.activation_1(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        src_final = self.linear3(src)
        return src_final, src2_attn_weight