import torch
from torch import nn
import numpy as np
# new package
from braindecode.models import EEGNetv4

class EEGNet(nn.Module):
    """ The EEGNet model, split into a feature extractor and a classifier head.

    This is a very small model (the original author estimates ~3k parameters).

    This is from the Braindecode package:
        https://github.com/braindecode/braindecode

    The braindecode ``EEGNetv4`` is split in two:

    - ``self.model`` keeps the feature-extraction layers.
    - Its ``conv_classifier`` and ``permute_back`` layers are moved into ``self.classifier``.
    - Its ``softmax`` and ``squeeze`` layers (when present) are removed, so the
      outputs are the raw classifier outputs.

    Args:
        dataset (Multi_Domain_Dataset): dataset that we will be training on
        model_hparams (dict): The hyperparameters for the model; must contain a
            ``'device'`` key.

    Attributes:
        input_size (int): The size of the inputs to the model (for a single time step).
        output_size (int): The size of the outputs of the model (number of classes).
        seq_len (int): The length of the sequences.
        n_feature_maps (int): Number of feature maps (``F2``) produced by ``self.model``.
        model (nn.Module): ``EEGNetv4`` with its classification head removed.
        classifier (nn.Sequential): The classification head taken from ``EEGNetv4``.
    """

    def __init__(self, dataset, model_hparams):
        super().__init__()

        # Save stuff
        self.device = model_hparams['device']
        self.input_size = np.prod(dataset.INPUT_SHAPE)
        self.output_size = dataset.OUTPUT_SIZE
        self.seq_len = dataset.SEQ_LEN

        scale = 1
        F1 = 8 * scale
        D = 2 * scale
        # F2 = F1 * D, i.e. 16 * scale**2 (same value as the previous hard-coded
        # expression).
        F2 = F1 * D
        # Remembered so `classify` can restore the 4-D layout the classifier
        # consumes from (possibly flattened) features.
        self.n_feature_maps = F2

        self.model = EEGNetv4(
            self.input_size,
            self.output_size,
            input_window_samples=self.seq_len,
            final_conv_length='auto',
            F1=F1,
            D=D,
            F2=F2,
            kernel_length=64 * scale,
            third_kernel_size=(8, 4),
            drop_prob=0.05,
        )

        # Move the classification head into its own module so it can be
        # applied separately from the feature extractor.
        self.classifier = nn.Sequential(
            self.model.conv_classifier,
            self.model.permute_back,
        )
        # Strip the head layers from the feature extractor so that
        # `self.model(x)` yields feature maps. Layers that are absent are skipped
        # instead of raising (NOTE(review): layer names may vary across
        # braindecode versions — confirm against the pinned version).
        for layer_name in ('conv_classifier', 'softmax', 'permute_back', 'squeeze'):
            if hasattr(self.model, layer_name):
                delattr(self.model, layer_name)

    def forward(self, input):
        """Run the feature extractor and the classifier on a batch.

        Args:
            input (torch.Tensor): 3-D batch; dims 1 and 2 are swapped before
                being fed to EEGNetv4. Presumably the layout is
                (batch, time, channels); TODO confirm.

        Returns:
            tuple: ``(out, features)``.
                - ``out`` is the classifier output with shape ``(batch, 1, -1)``.
                - ``features`` is the flattened feature maps with shape
                  ``(batch, 1, -1)``.
                - The singleton dim 1 is the time-prediction dimension.
        """
        features = self.model(input.permute((0, 2, 1)))

        # `classify` already returns shape (batch, 1, n_outputs).
        out = self.classify(features)

        # Remove all extra dimensions and add the time-prediction dimension.
        features = torch.flatten(features, start_dim=1).unsqueeze(1)

        return out, features

    def classify(self, features):
        """Apply the classification head to feature maps.

        Args:
            features (torch.Tensor): Either the feature maps returned by
                ``self.model`` or their flattened form (e.g. the ``features``
                returned by ``forward``). The number of elements per sample must
                be a multiple of ``n_feature_maps``.

        Returns:
            torch.Tensor: Classifier output of shape ``(batch, 1, -1)``.
        """
        # Restore the (batch, F2, 1, time) layout the conv classifier expects.
        # The time length is inferred from the input size rather than fixed,
        # so any `seq_len` works.
        features = features.reshape(features.shape[0], self.n_feature_maps, 1, -1)
        output = self.classifier(features)
        return torch.flatten(output, start_dim=1).unsqueeze(1)

    def get_classifier_network(self):
        """Return the classification head (``conv_classifier`` + ``permute_back``)."""
        return self.classifier