import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from .utils.Parallel_Conv1d_achieved_by_mask import Parallel_Conv1d_achieved_by_mask

class Flexible_1D_CNN(nn.Module):
    """Stack of parallel 1-D convolution layers followed by a linear head.

    Each entry of ``layer_parameter_list`` configures one
    ``Parallel_Conv1d_achieved_by_mask`` layer, which is followed by
    ``BatchNorm1d`` and ``ReLU``.  The output of the stack is average-pooled
    to length 1 and, unless ``out_put_feature`` is set, passed through a
    linear layer with ``n_class`` outputs.

    Args:
        layer_parameter_list: Sequence of layer configurations.  Each
            configuration is an iterable of kernel specs whose element at
            index 1 is used as that kernel's output-channel count.
            NOTE(review): specs are presumably
            ``(in_channels, out_channels, kernel_size)`` -- confirm against
            ``Parallel_Conv1d_achieved_by_mask``.
        n_class: Number of outputs of the final linear layer.
        out_put_feature: If true, ``forward`` returns the pooled features
            instead of the linear layer's output.

    Raises:
        IndexError: If ``layer_parameter_list`` is empty.
    """

    def __init__(self, layer_parameter_list, n_class, out_put_feature=False):
        super().__init__()
        self.out_put_feature = out_put_feature
        self.layer_parameter_list = layer_parameter_list

        modules = []
        for layer_parameter in layer_parameter_list:
            modules.append(Parallel_Conv1d_achieved_by_mask(layer_parameter))
            # BatchNorm must match the combined width of all parallel kernels.
            modules.append(
                nn.BatchNorm1d(
                    num_features=self._total_out_channels(layer_parameter)
                )
            )
            modules.append(nn.ReLU())
        self.layers = nn.Sequential(*modules)
        self.averagepool = nn.AdaptiveAvgPool1d(1)

        # The head consumes the channels produced by the last conv layer.
        final_channels = self._total_out_channels(layer_parameter_list[-1])
        self.hidden = nn.Linear(final_channels, n_class)

    @staticmethod
    def _total_out_channels(layer_parameter):
        """Return the sum of element 1 of every kernel spec in one layer."""
        return sum(kernel[1] for kernel in layer_parameter)

    def forward(self, X):
        """Run the network on ``X``.

        Args:
            X: Input tensor fed to the convolution stack.
                NOTE(review): presumably ``(batch, channels, length)`` --
                confirm against ``Parallel_Conv1d_achieved_by_mask``.

        Returns:
            The pooled features with the trailing length-1 dimension removed
            when ``out_put_feature`` is true, otherwise the output of the
            final linear layer applied to those features.
        """
        conv_out = self.layers(X)
        pooled = self.averagepool(conv_out)
        # Out-of-place squeeze: the in-place variant would reshape the pooling
        # module's output tensor for anyone else holding it (e.g. hooks).
        features = pooled.squeeze(-1)

        if self.out_put_feature:
            return features
        return self.hidden(features)
