import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .Conv1d_Layer_with_Mask import Conv1d_Layer_with_Mask

class Masked_dilation_inception_layer(nn.Module):
    """Sequential stack of ``Conv1d_Layer_with_Mask`` layers.

    One ``Conv1d_Layer_with_Mask`` is built per entry of
    ``layer_parameter_list``, and the layers are chained with
    ``nn.Sequential``. Every layer except the last is built with
    ``use_relu=True``; the last one uses ``relu_or_not_at_last_layer``.

    Args:
        layer_parameter_list: Sequence of per-layer configurations. Each entry
            is passed unchanged as the first positional argument of
            ``Conv1d_Layer_with_Mask`` (its exact structure is defined by that
            class — not visible here). An empty sequence produces an empty
            ``nn.Sequential``, so ``forward`` then returns its input unchanged.
        relu_or_not_at_last_layer: ``use_relu`` flag for the final layer only.
    """

    def __init__(self, layer_parameter_list, relu_or_not_at_last_layer=True):
        super().__init__()
        self.layer_parameter_list = layer_parameter_list
        self.relu_or_not_at_last_layer = relu_or_not_at_last_layer

        last_index = len(layer_parameter_list) - 1
        # NOTE: deliberately a plain list, not nn.ModuleList. The layers are
        # registered (parameters, state_dict keys "net.<i>.*") through
        # ``self.net`` below; a ModuleList would register them a second time
        # and add duplicate "layer_list.<i>.*" entries to the state_dict.
        self.layer_list = [
            Conv1d_Layer_with_Mask(
                layer_params,
                use_relu=self.relu_or_not_at_last_layer if i == last_index else True,
            )
            for i, layer_params in enumerate(layer_parameter_list)
        ]

        self.net = nn.Sequential(*self.layer_list)

    def forward(self, X):
        """Run ``X`` through every layer in order and return the result.

        Args:
            X: Input tensor. Presumably ``(batch, channels, length)`` as usual
                for Conv1d-based layers — TODO confirm against
                ``Conv1d_Layer_with_Mask``.
        """
        return self.net(X)