import torch
import torch.nn as nn
# Time-axis feature-reduction variant used by ConvEmbedder:
#   "pool"    - depthwise temporal conv (stride 1) followed by AvgPool2d((1, 8))
#   "strconv" - depthwise temporal conv with stride (1, 8); no pooling layer
#   "nofred"  - depthwise temporal conv (stride 1) with no time reduction
ARCH_TYPE = "nofred" # "pool", "strconv", "nofred"

class ConvEmbedder(nn.Module):
    """
    edgeEEGNet-style convolutional embedder for multichannel EEG windows.

    Maps an input of shape (s, C, T) to an embedding of shape (s, D * F1, T'),
    where T' depends on the feature-reduction variant chosen by ``arch_type``
    (defaulting to the module-level ``ARCH_TYPE`` at construction time):

    * ``"nofred"``  - no time reduction:                  T' = T
    * ``"pool"``    - AvgPool2d over 8 samples:           T' = T // 8
    * ``"strconv"`` - depthwise conv with stride 8:       T' = (T - 1) // 8 + 1
    """

    # Feature-reduction variants understood by __init__ / forward.
    _ARCH_TYPES = ("pool", "strconv", "nofred")

    def __init__(self, F1=8, D=2, C=64, T=1280, p_dropout=0.5, reg_rate=0.25,
                 activation='relu', constrain_w=False, dropout_type='TimeDropout2D',
                 permuted_flatten=False, arch_type=None):
        """
        F1:           Number of filters; the network uses D * F1 channels throughout
        D:            Depth multiplier; channel count is D * F1
        C:            Number of EEG channels (height of the first conv kernel)
        T:            Number of time samples (stored only; layer sizes do not depend on it)
        p_dropout:    Dropout probability
        reg_rate:     Stored only; not used by this module (presumably consumed by a
                      downstream model - confirm)
        activation:   string, either 'elu' or 'relu' (case-insensitive)
        constrain_w:  Stored only; no weight constraint is applied in this module
        dropout_type: Stored only; plain torch.nn.Dropout is always used here
        permuted_flatten: Stored only; not used by this module
        arch_type:    One of "pool", "strconv", "nofred", or None to use the
                      module-level ARCH_TYPE at construction time

        Raises:
            AssertionError: if activation is not 'elu' or 'relu'.
            ValueError: if arch_type is not one of the known variants.
        """
        super(ConvEmbedder, self).__init__()

        # check the activation input
        activation = activation.lower()
        assert activation in ['elu', 'relu']

        # Resolve the variant once so __init__ and forward always agree, even if
        # the module-level ARCH_TYPE is modified after this instance is built.
        if arch_type is None:
            arch_type = ARCH_TYPE
        if arch_type not in self._ARCH_TYPES:
            raise ValueError(
                f"arch_type must be one of {self._ARCH_TYPES}, got {arch_type!r}")
        self.arch_type = arch_type

        # store local values
        self.F1, self.D, self.C, self.T = (F1, D, C, T)
        self.p_dropout, self.reg_rate, self.activation = (p_dropout, reg_rate, activation)
        self.constrain_w, self.dropout_type = (constrain_w, dropout_type)
        self.permuted_flatten = permuted_flatten

        n_ch = D * F1

        # Block 1
        # Kernel spans all C input rows, collapsing them: (s, 1, C, T) -> (s, D*F1, 1, T)
        self.conv1 = torch.nn.Conv2d(1, n_ch, (C, 1), bias=False)
        # NOTE(review): momentum=0.01 / eps=1e-3 look chosen to mirror Keras
        # BatchNormalization defaults (Keras momentum 0.99 == PyTorch 0.01) - confirm.
        self.batch_norm1 = torch.nn.BatchNorm2d(n_ch, momentum=0.01, eps=0.001)
        # 31 + 32 = 63 = kernel_size - 1, so the 64-tap stride-1 conv keeps length T.
        self.conv2_pad = torch.nn.ZeroPad2d((31, 32, 0, 0))

        if arch_type == "strconv":
            # Convolutional feature reduction: stride 8 in time replaces pool1.
            self.conv2 = torch.nn.Conv2d(n_ch, n_ch, (1, 64), stride=(1, 8),
                                         groups=n_ch, bias=False)
        else:
            # "pool" / "nofred": depthwise temporal conv without stride.
            self.conv2 = torch.nn.Conv2d(n_ch, n_ch, (1, 64), groups=n_ch, bias=False)
            # Applied in forward only for "pool" (parameter-free, so it adds no state).
            self.pool1 = torch.nn.AvgPool2d((1, 8))

        self.batch_norm2 = torch.nn.BatchNorm2d(n_ch, momentum=0.01, eps=0.001)
        self.activation1 = torch.nn.ELU(inplace=True) if activation == 'elu' else torch.nn.ReLU(inplace=True)
        self.dropout1 = torch.nn.Dropout(p=p_dropout)

    def forward(self, x):
        """
        Embed a batch of EEG windows.

        x: tensor of shape (s, C, T), or (s, 1, C, T) that already carries the
           singleton input-channel axis. Assumes exactly C rows, so that conv1
           collapses the height to 1.
        Returns a tensor of shape (s, D * F1, T') with T' as described in the
        class docstring.
        """
        # (s, C, T) -> (s, 1, C, T)
        if x.dim() == 3:
            x = x.unsqueeze(1)

        # Block 1
        x = self.conv1(x)            # (s, D * F1, 1, T)
        x = self.batch_norm1(x)
        x = self.conv2_pad(x)        # (s, D * F1, 1, T + 63)
        x = self.conv2(x)            # (s, D * F1, 1, T); "strconv": (s, D * F1, 1, (T - 1) // 8 + 1)
        x = self.batch_norm2(x)
        x = self.activation1(x)

        if self.arch_type == "pool":  # simple feature reduction
            x = self.pool1(x)        # (s, D * F1, 1, T // 8)

        x = self.dropout1(x)

        # Drop the singleton height axis: (s, D * F1, 1, T') -> (s, D * F1, T')
        x = x.reshape(x.shape[0], x.shape[1], x.shape[3])
        return x


