import torch
import torch.nn as nn
import torch.nn.functional as F


class MobileNet(nn.Module):
    """Small MobileNet-style CNN built from depthwise-separable convolutions.

    Layout: a plain 3x3 conv stem, then three depthwise-separable blocks that
    double the channel count each time (n_kernels -> 2n -> 4n -> 8n). A 2x2
    max-pool follows every stage. Global average pooling feeds a three-layer
    MLP head (8n -> 2000 -> 500 -> out_dim).

    Args:
        in_channels: Number of channels in the input images.
        n_kernels: Channel width of the stem; later stages use multiples of it.
        out_dim: Size of the final output vector. No activation is applied to
            it (``fc3`` output is returned as-is).
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        self.conv2 = self.depthwise_separable_conv(n_kernels, 2 * n_kernels)
        self.conv3 = self.depthwise_separable_conv(2 * n_kernels, 4 * n_kernels)
        self.conv4 = self.depthwise_separable_conv(4 * n_kernels, 8 * n_kernels)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(8 * n_kernels, 2000)
        self.fc2 = nn.Linear(2000, 500)
        self.fc3 = nn.Linear(500, out_dim)

    def depthwise_separable_conv(self, in_channels, out_channels):
        """Build a depthwise-separable convolution block.

        The block is a per-channel 3x3 conv (``groups=in_channels``), then a
        1x1 pointwise conv that mixes channels. Each conv is followed by
        BatchNorm and ReLU. Spatial size is preserved (stride 1, padding 1 on
        the 3x3 conv).

        Args:
            in_channels: Channels entering the block.
            out_channels: Channels produced by the pointwise conv.

        Returns:
            An ``nn.Sequential`` whose last layer is a ReLU, so its output is
            already non-negative.
        """
        # NOTE(review): the conv biases are redundant because BatchNorm follows
        # each conv; they are kept so existing state_dict keys stay loadable.
        return nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Batched image tensor of shape (N, in_channels, H, W). Each of
                the four 2x2 max-pools halves H and W (rounding down), so
                H and W must be at least 16.

        Returns:
            Tensor of shape (N, out_dim) taken directly from ``fc3``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        # conv2..conv4 already end in ReLU (see depthwise_separable_conv), so
        # applying another ReLU here would be a no-op; it is skipped.
        x = self.pool(self.conv2(x))
        x = self.pool(self.conv3(x))
        x = self.pool(self.conv4(x))

        x = self.global_pool(x)
        x = torch.flatten(x, 1)  # (N, C, 1, 1) -> (N, C)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
