#%%
import torch

class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        """Return ``input`` reshaped to two dimensions.

        Dimension 0 (usually the batch dimension) is kept as is and all
        remaining dimensions are merged, so a tensor of shape
        ``(N, d1, d2, ...)`` comes back with shape ``(N, d1 * d2 * ...)``.
        """
        leading = input.shape[0]
        # ``view`` needs a compatible memory layout; ``contiguous()`` provides
        # one, copying the data only when the tensor is not already contiguous.
        dense = input.contiguous()
        return dense.view(leading, -1)
    

class AlexNet(torch.nn.Module):
    """Small AlexNet-style CNN: five conv stages plus a two-layer classifier.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> MaxPool2d(2) ->
    ReLU``. The pooled feature map is flattened and passed through
    ``Linear -> ReLU -> Dropout(0.2) -> Linear``.

    Every stage halves the spatial size and the first ``Linear`` expects
    exactly as many features as the last stage has channels, so the feature
    map must be 1x1 after five poolings (input height/width of 32..63,
    e.g. 32x32 images).

    All modules live in ``self.layer`` (a ``Sequential``) in a fixed order,
    so ``state_dict`` keys are ``layer.<index>.<param>``.

    Args:
        output_dim: Number of output units of the final ``Linear`` layer.
        device: Stored as ``self.device``. The constructor does NOT move the
            parameters to this device; callers still need ``.to(device)``.
        in_channels: Number of channels of the input images (default 3).
        bn_momentum: ``momentum`` passed to every ``BatchNorm2d``.
            PyTorch updates running statistics as
            ``(1 - momentum) * running + momentum * batch``, so the default
            of 0.9 weights the newest batch at 90%.
            NOTE(review): if a conventional "0.9 decay" was intended, the
            PyTorch equivalent is ``bn_momentum=0.1`` -- confirm intent. The
            default is kept at 0.9 to preserve existing behaviour.
    """

    # (out_channels, conv_has_bias) for each conv stage, in order.
    _CONV_STAGES = (
        (64, False),
        (192, False),
        (384, False),
        # NOTE(review): the only conv with a bias. A bias right before
        # BatchNorm is redundant, but dropping it would remove the
        # ``layer.12.bias`` key and break loading of existing checkpoints.
        (256, True),
        (256, False),
    )

    def __init__(self, output_dim: int, device='cpu', *,
                 in_channels: int = 3, bn_momentum: float = 0.9):
        super().__init__()
        self.device = device

        # Modules are appended in execution order; keeping this order also
        # keeps the Sequential indices (and thus state_dict keys) stable.
        modules = []
        channels = in_channels
        for width, use_bias in self._CONV_STAGES:
            modules += [
                torch.nn.Conv2d(channels, width, kernel_size=3, stride=1,
                                padding=1, bias=use_bias),
                torch.nn.BatchNorm2d(width, momentum=bn_momentum),
                torch.nn.MaxPool2d(2),
                torch.nn.ReLU(),
            ]
            channels = width

        modules += [
            # Built-in flatten (start_dim=1): merges all non-batch dimensions.
            torch.nn.Flatten(),
            torch.nn.Linear(channels, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, output_dim),
        ]
        self.layer = torch.nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch of images through the network.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``; ``H`` and ``W``
                must reduce to 1 after five 2x poolings (32..63).

        Returns:
            Tensor of shape ``(N, output_dim)`` with the raw outputs of the
            last ``Linear`` layer (no softmax is applied here).
        """
        return self.layer(x)
    
def alexnet():
    """Build an ``AlexNet`` with a 10-unit output layer and default settings."""
    num_outputs = 10
    return AlexNet(output_dim=num_outputs)