import torch.nn as nn
import torch.nn.functional as F

class ImageNetShallowMLP(nn.Module):
    """Fully connected classifier built from Linear -> BatchNorm1d -> ReLU blocks.

    With the default arguments this is exactly the original fixed
    architecture: ``150528 -> 8192 -> 8192 -> 8192 -> 1000``. That is three
    hidden blocks followed by a plain linear head that emits raw logits.

    The sizes are exposed as keyword-only arguments so smaller variants can
    be built, e.g. for tests (the default model is very large). The submodule
    layout is unchanged, so ``state_dict`` keys (``layers.0.weight`` ...
    ``layers.9.bias`` for the defaults) stay compatible. The construction
    order is also unchanged, so seeded initialization is unchanged too.

    Args:
        in_features: Number of features per sample after flattening. The
            default 150528 equals 3 * 224 * 224; presumably this is a
            flattened 224x224 RGB image. Confirm against the data pipeline.
        hidden_features: Width of every hidden block.
        num_classes: Number of output logits.
        num_hidden: Number of Linear -> BatchNorm1d -> ReLU blocks before the
            head. ``0`` yields a single linear layer.

    Raises:
        ValueError: If ``num_hidden`` is negative.
    """

    def __init__(
        self,
        *,
        in_features: int = 150528,
        hidden_features: int = 8192,
        num_classes: int = 1000,
        num_hidden: int = 3,
    ):
        super().__init__()
        if num_hidden < 0:
            raise ValueError(f"num_hidden must be >= 0, got {num_hidden}")
        self.flatten = nn.Flatten()

        modules = []
        width = in_features
        for _ in range(num_hidden):
            modules += [
                nn.Linear(width, hidden_features),
                nn.BatchNorm1d(hidden_features),
                nn.ReLU(),
            ]
            width = hidden_features
        # Head has no norm/activation: it produces unnormalized logits.
        # Its input width follows the last block (or the raw input when
        # there are no hidden blocks).
        modules.append(nn.Linear(width, num_classes))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return logits of shape (N, num_classes).

        The flattened size must equal ``in_features``. BatchNorm1d in
        training mode needs more than one sample per batch.
        """
        x = self.flatten(x)
        return self.layers(x)

class ImageNetNarrowMLP(nn.Module):
    """Deep, narrow fully connected classifier.

    With the default arguments this reproduces the original fixed
    architecture, where every non-head block is Linear -> BatchNorm1d -> ReLU:

    - an initial ``150528 -> 2048`` block;
    - 47 intermediate blocks (the first maps 2048 -> 1024, the rest
      1024 -> 1024);
    - a ``1024 -> 1000`` linear head.

    The sizes are keyword-only arguments so smaller variants can be built
    (e.g. for tests). Submodule names (``initial_layer``,
    ``intermediate_layers``, ``output_layer``), their nesting and indices,
    and the construction order are unchanged. Existing ``state_dict``
    checkpoints and seeded initialization therefore stay compatible.

    Args:
        in_features: Number of features per sample after flattening. The
            default 150528 equals 3 * 224 * 224; presumably this is a
            flattened 224x224 RGB image. Confirm against the data pipeline.
        initial_features: Output width of the initial block.
        hidden_features: Width of every intermediate block.
        num_classes: Number of output logits.
        num_intermediate: Number of intermediate blocks. ``0`` connects the
            initial block directly to the head.

    Raises:
        ValueError: If ``num_intermediate`` is negative.
    """

    def __init__(
        self,
        *,
        in_features: int = 150528,
        initial_features: int = 2048,
        hidden_features: int = 1024,
        num_classes: int = 1000,
        num_intermediate: int = 47,
    ):
        super().__init__()
        if num_intermediate < 0:
            raise ValueError(
                f"num_intermediate must be >= 0, got {num_intermediate}"
            )
        self.flatten = nn.Flatten()
        self.initial_layer = nn.Sequential(
            nn.Linear(in_features, initial_features),
            nn.BatchNorm1d(initial_features),
            nn.ReLU(),
        )
        # Only the first intermediate block sees the initial width; the rest
        # keep the hidden width.
        self.intermediate_layers = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Linear(
                        initial_features if i == 0 else hidden_features,
                        hidden_features,
                    ),
                    nn.BatchNorm1d(hidden_features),
                    nn.ReLU(),
                )
                for i in range(num_intermediate)
            ]
        )
        # The head consumes whatever width the last block produced. With no
        # intermediate blocks that is the initial width, not the hidden width.
        head_in = hidden_features if num_intermediate > 0 else initial_features
        # Kept as a one-element Sequential so the key stays "output_layer.0.*".
        self.output_layer = nn.Sequential(
            nn.Linear(head_in, num_classes),
        )

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return logits of shape (N, num_classes).

        The flattened size must equal ``in_features``. BatchNorm1d in
        training mode needs more than one sample per batch.
        """
        x = self.flatten(x)
        x = self.initial_layer(x)
        x = self.intermediate_layers(x)
        return self.output_layer(x)