import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style convolutional feature extractor.

    Layers, in order:
      conv_params: Conv2d(1->20, k=5) -> MaxPool2d(2) -> ReLU
                   -> Conv2d(20->50, k=5) -> Dropout2d(0.5) -> MaxPool2d(2) -> ReLU
      fc_params:   Linear(50*4*4 -> 256) -> ReLU -> Dropout(0.5)

    Each unpadded 5x5 convolution shrinks the spatial size by 4 and each pool
    halves it. The 50*4*4 input width of the first Linear layer therefore
    corresponds to single-channel 28x28 images (28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        self.conv_params = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.fc_params = nn.Sequential(
            nn.Linear(50 * 4 * 4, 256), nn.ReLU(), nn.Dropout(p=0.5)
        )
        # Width of the feature vector returned by forward(); exposed via output_num().
        self.__in_features = 256

    def forward(self, x, mask=None):
        """Compute the 256-d feature vector and the per-channel spatial maps.

        Args:
            x: input batch of shape (B, 1, 28, 28).
            mask: optional per-channel weights of shape (B, 50), or any shape
                broadcastable to it such as (1, 50). Each weight multiplies
                every spatial activation of its channel before the FC layers.

        Returns:
            A tuple ``(features, z)``:
            features: tensor of shape (B, 256), the output of ``fc_params``.
            z: tensor of shape (B, 50, H*W) (H*W == 16 for 28x28 inputs),
                holding the (masked) per-channel activations.

        Raises:
            RuntimeError: if the input size does not produce 50*4*4 features
                per sample. The first Linear layer rejects the mismatched width.
        """
        x = self.conv_params(x)
        # (B, C, H, W) -> (B, C, H*W): one row of spatial activations per channel.
        z = torch.flatten(x, start_dim=-2, end_dim=-1)

        if mask is not None:
            # (B, C) -> (B, C, 1) broadcasts over the spatial axis. This needs no
            # hard-coded spatial size and no materialised copy of the mask.
            z = z * mask.unsqueeze(-1)

        # Flatten per sample. view(-1, 800) would silently re-split the batch
        # when the input is not 28x28, whereas this keeps B and lets the
        # Linear layer reject a wrong feature width.
        x = self.fc_params(torch.flatten(z, start_dim=1))
        return x, z

    def output_num(self):
        """Return the width of the feature vector produced by forward() (256)."""
        return self.__in_features

    def is_patch_based(self):
        """Return False: forward() yields one feature vector per input image."""
        # NOTE(review): presumably queried by callers to tell this backbone apart
        # from patch-based ones; confirm against the call sites.
        return False




