# This module builds an inception layer whose branches can be computed in parallel.
# The method it uses is a zero mask applied to one big convolution;
# for example, kernel sizes 3, 5 and 7 will look like:
#  0     0    value value value  0     0
#  0    value value value value value  0  
# value value value value value value value  
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn.parameter import Parameter
   
class SampaddingConv1D_BN(nn.Module):
    """1-D convolution with zero padding, optionally followed by batch norm.

    The input is zero-padded on its last dimension with
    ``int((kernel_size - 1) / 2)`` values on the left and
    ``int(kernel_size / 2)`` values on the right before the convolution,
    so for a positive integer ``kernel_size`` the output keeps the input
    length (odd and even kernel sizes alike).
    """

    def __init__(self,in_channels,out_channels,kernel_size,use_bias = True,use_batch_Norm = True):
        """Create the padding, convolution and batch-norm sub-modules.

        Args:
            in_channels: number of channels of the input.
            out_channels: number of channels produced by the convolution.
            kernel_size: width of the convolution kernel.
            use_bias: whether the convolution has a bias term.
            use_batch_Norm: whether ``forward`` applies batch normalisation.
        """
        super().__init__()
        self.use_bias = use_bias
        self.use_batch_Norm = use_batch_Norm

        # Asymmetric padding: the extra zero of an even kernel goes on the right.
        pad_left = int((kernel_size - 1) / 2)
        pad_right = int(kernel_size / 2)
        self.padding = nn.ConstantPad1d((pad_left, pad_right), 0)

        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            bias=self.use_bias,
        )
        # The batch-norm module is always created, even when it is not applied.
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, X):
        """Pad ``X``, convolve it, and batch-normalise when enabled."""
        out = self.conv1d(self.padding(X))
        return self.bn(out) if self.use_batch_Norm else out
    
class Inception_Layer(nn.Module):
    """Run several padded 1-D convolutions on one input and concatenate them.

    Each entry of ``layer_parameters`` describes one branch; every branch
    receives the same input and the branch outputs are concatenated along
    dimension 1, optionally followed by a ReLU.
    """

    def __init__(self,layer_parameters, use_bias = True, use_batch_Norm =True, use_relu =True): 
        """Build one ``SampaddingConv1D_BN`` branch per parameter entry.

        Args:
            layer_parameters: iterable of indexable entries whose first three
                items are ``(in_channels, out_channels, kernel_size)``.
            use_bias: whether each branch convolution has a bias term.
            use_batch_Norm: whether each branch applies batch normalisation.
            use_relu: whether ``forward`` applies ReLU to the concatenation.
        """
        super(Inception_Layer, self).__init__()
        self.conv_list = nn.ModuleList()
        self.use_bias = use_bias
        self.use_batch_Norm = use_batch_Norm
        self.use_relu = use_relu

        for spec in layer_parameters:
            # Pass the flags by keyword: the previous positional call put
            # use_relu into the use_bias slot, so use_bias was ignored.
            conv = SampaddingConv1D_BN(
                spec[0],
                spec[1],
                spec[2],
                use_bias=self.use_bias,
                use_batch_Norm=self.use_batch_Norm,
            )
            self.conv_list.append(conv)

    def forward(self, X):
        """Apply every branch to ``X`` and concatenate along dimension 1."""
        branch_outputs = [conv(X) for conv in self.conv_list]
        result = torch.cat(branch_outputs, 1)
        if self.use_relu:
            result = F.relu(result)
        return result
