# -*- coding: utf-8 -*-
import torch.nn as nn


# Public API of this module: the model class and its factory function.
__all__ = ['AlexNetBN', 'alexnet_bn']


class AlexNetBN(nn.Module):
    """An AlexNet-style network with Batch Normalization.

    The convolutional stack uses five convolutions (96-256-384-384-256
    output channels, with grouped convolutions in layers 2, 4 and 5), each
    followed by ReLU and ``BatchNorm2d``. The classifier is three fully
    connected layers with dropout.

    The classifier is hard-wired to a ``256 x 6 x 6`` feature map, which is
    what the feature extractor produces for ``3 x 227 x 227`` inputs
    (first conv has kernel 11, stride 4 and no padding).

    Args:
        num_classes (int): size of the final output layer.
    """

    def __init__(self, num_classes=1000):
        super(AlexNetBN, self).__init__()
        # Feature extractor. NOTE: the order of modules is kept stable so
        # that the indices used in ``state_dict`` keys do not change.
        self.features = nn.Sequential(
            # conv layer 1.
            nn.Conv2d(
                in_channels=3, out_channels=96, kernel_size=11, stride=4,
                padding=0),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=96),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # conv layer 2.
            nn.Conv2d(
                in_channels=96, out_channels=256, kernel_size=5,
                padding=2, groups=2),
            nn.ReLU(inplace=True),
            # num_features must equal the 256 channels produced by the
            # convolution above (it was 192, which fails at forward time).
            nn.BatchNorm2d(num_features=256),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # conv layer 3.
            nn.Conv2d(
                in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=384),
            # conv layer 4.
            nn.Conv2d(
                in_channels=384, out_channels=384, kernel_size=3,
                padding=1, groups=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=384),
            # conv layer 5. Unlike the earlier layers, pooling is applied
            # before ReLU/BN here; this ordering is preserved as-is.
            nn.Conv2d(
                in_channels=384, out_channels=256, kernel_size=3,
                padding=1, groups=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=256)
        )
        # Classifier operating on the flattened 256 x 6 x 6 feature map.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=num_classes)
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: input batch; the first dimension is treated as the batch
                dimension and is preserved in the output.

        Returns:
            Tensor with one row of ``num_classes`` scores per input sample.
        """
        x = self.features(x)
        # Flatten per sample, pinning the batch dimension. With
        # ``view(-1, 256 * 6 * 6)`` an input of the wrong spatial size could
        # be silently re-batched; now it fails in the first Linear layer
        # with an explicit shape mismatch instead.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def alexnet_bn(pretrained=False, **kwargs):
    r"""Build an :class:`AlexNetBN` model.

    Args:
        pretrained (bool): accepted for interface compatibility only. No
            pre-trained weights are loaded by this function; if True, a
            warning is emitted and a randomly initialized model is
            returned.
        **kwargs: forwarded unchanged to the ``AlexNetBN`` constructor
            (e.g. ``num_classes``).

    Returns:
        A freshly constructed ``AlexNetBN`` instance.
    """
    if pretrained:
        # Previously this flag was silently ignored; warn instead of
        # raising so existing callers keep working.
        import warnings
        warnings.warn(
            "alexnet_bn: pretrained weights are not available; "
            "returning a randomly initialized model.",
            UserWarning)
    model = AlexNetBN(**kwargs)
    return model
