import torch
import torch.nn as nn
# Architecture variant tag; the listed options are "pool", "strconv" and "nofred".
# NOTE(review): not referenced by any code visible in this chunk — confirm where it is read.
ARCH_TYPE = "nofred"

class ConvEmbedderVision(nn.Module):
    """
    edgeEEGNet-based convolutional patch embedder.

    A single convolution + batch norm + activation is applied to an EEG window.
    The resulting feature map is then cut into non-overlapping patches of
    ``patch_t`` time samples by ``patch_f`` filters (each patch spanning all C
    rows), and every flattened patch is projected to an ``embed_dim``-sized
    token by one shared linear layer.
    """

    def __init__(self, F1=8, D=2, C=64, T=1280, p_dropout=0.5, reg_rate=0.25,
                 activation='relu', constrain_w=False, dropout_type='TimeDropout2D',
                 permuted_flatten=False, *, patch_t=16, patch_f=4, embed_dim=16):
        """
        F1:           Number of convolution filters (output channels of conv1)
        D:            Depth multiplier; stored only, not used by the layers built here
        C:            Number of EEG channels (rows of the input)
        T:            Number of time samples per window; fixes the number of
                      time patches to T // patch_t
        p_dropout:    Dropout probability of self.dropout1
                      (NOTE(review): dropout1 is not applied in forward)
        reg_rate:     Regularization (L1) rate meant for the linear layer;
                      stored only, no regularization is applied in this module
        activation:   string, either 'elu' or 'relu' (case-insensitive)
        constrain_w:  bool; stored only, no weight constraint is applied here
        dropout_type: string, either 'dropout', 'SpatialDropout2d' or 'TimeDropout2D';
                      stored only, self.dropout1 is always a plain nn.Dropout
        permuted_flatten: bool; accepted for interface compatibility, unused
        patch_t:      Time samples per patch (keyword-only)
        patch_f:      Convolution filters per patch (keyword-only)
        embed_dim:    Size of each output token (keyword-only)

        With the defaults the module maps (s, 64, 1280) -> (s, 160, 16).

        Raises ValueError if no complete patch fits (patch sizes < 1,
        T < patch_t or F1 < patch_f).
        """
        super().__init__()

        # check the activation input
        activation = activation.lower()
        assert activation in ['elu', 'relu']

        if min(patch_t, patch_f) < 1 or T < patch_t or F1 < patch_f:
            raise ValueError(
                f"no complete patch fits: T={T}, patch_t={patch_t}, "
                f"F1={F1}, patch_f={patch_f}"
            )

        # store local values
        self.F1, self.D, self.C, self.T = (F1, D, C, T)
        self.p_dropout, self.reg_rate, self.activation = (p_dropout, reg_rate, activation)
        self.constrain_w, self.dropout_type = (constrain_w, dropout_type)
        self.patch_t, self.patch_f, self.embed_dim = (patch_t, patch_f, embed_dim)

        # Block 1: the kernel spans C rows; padding='same' keeps the (C, T) extent.
        self.conv1 = torch.nn.Conv2d(1, F1, (C, 1), bias=False, padding='same')
        # NOTE(review): upsample is never used in forward; kept so attribute access still works.
        self.upsample = torch.nn.Upsample(1280)
        self.batch_norm1 = torch.nn.BatchNorm2d(F1, momentum=0.01, eps=0.001)
        self.activation1 = torch.nn.ELU(inplace=True) if activation == 'elu' else torch.nn.ReLU(inplace=True)
        # NOTE(review): dropout1 is never applied in forward — confirm whether that is intended.
        self.dropout1 = torch.nn.Dropout(p=p_dropout)
        # One shared projection for every flattened (C, patch_t, patch_f) patch.
        self.linear = torch.nn.Linear(C * patch_t * patch_f, embed_dim)

    def forward(self, x):
        """
        Embed a batch of EEG windows as a sequence of patch tokens.

        x: tensor of shape (s, C, T_in) with T_in >= (T // patch_t) * patch_t;
           trailing samples beyond that length are ignored.

        Returns a tensor of shape (s, n_t * n_f, embed_dim), where
        n_t = T // patch_t and n_f = F1 // patch_f. Tokens are ordered
        time-patch-major: token k covers time patch k // n_f and filter
        group k % n_f.
        """
        # reshape from (s, C, T) to (s, 1, C, T): one input plane for conv1
        x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])

        # Block 1
        x = self.conv1(x)            # output dim: (s, F1, C, T)
        x = self.batch_norm1(x)
        x = self.activation1(x)

        tokens = self._patchify(x)   # (s, n_t * n_f, C * patch_t * patch_f)
        return self.linear(tokens)   # (s, n_t * n_f, embed_dim)

    def _patchify(self, x):
        """
        Cut an (s, F1, C, T_in) feature map into flattened, non-overlapping patches.

        Returns (s, n_t * n_f, C * patch_t * patch_f) with n_t = T // patch_t and
        n_f = F1 // patch_f (T and F1 from the constructor). Each token's features
        are laid out as (C, patch_t, patch_f); `linear`'s weights depend on this
        order, so keep it. The result has the same dtype/device as x.
        """
        s, _, c, _ = x.shape
        n_t = self.T // self.patch_t
        n_f = self.F1 // self.patch_f
        # Drop trailing filters/samples that do not fill a whole patch.
        x = x[:, :n_f * self.patch_f, :, :n_t * self.patch_t]
        # (s, n_f, patch_f, C, n_t, patch_t) -> (s, n_t, n_f, C, patch_t, patch_f)
        x = x.reshape(s, n_f, self.patch_f, c, n_t, self.patch_t)
        x = x.permute(0, 4, 1, 3, 5, 2)
        return x.reshape(s, n_t * n_f, c * self.patch_t * self.patch_f)


