import torch.nn as nn
import math
from collections import OrderedDict


# Public API: star-imports of this module expose only the MobileNetV2 model;
# the helper builders and InvertedResidual must be imported by name.
__all__ = ['MobileNetV2']


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def conv_3x3_bn(inp, oup, stride):
    """Build a 3x3 Conv2d -> BatchNorm2d -> ReLU6 stage as a named Sequential.

    The convolution has no bias, padding 1 and the given ``stride``.
    Submodules are registered as ``conv``, ``bn`` and ``relu6`` (these
    names appear in the state_dict keys).

    Args:
        inp: Input channel count.
        oup: Output channel count.
        stride: Convolution stride.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(OrderedDict(conv=conv, bn=norm, relu6=act))


def conv_1x1_bn(inp, oup):
    """Build a pointwise (1x1) Conv2d -> BatchNorm2d -> ReLU6 stage.

    The convolution has no bias, stride 1 and no padding, so the spatial
    size is unchanged. Submodules are registered as ``conv``, ``bn`` and
    ``relu6`` (these names appear in the state_dict keys).

    Args:
        inp: Input channel count.
        oup: Output channel count.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    return nn.Sequential(OrderedDict(
        conv=pointwise,
        bn=nn.BatchNorm2d(oup),
        relu6=nn.ReLU6(inplace=True),
    ))


class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    With ``expand_ratio == 1`` the block is a depthwise 3x3 convolution
    followed by a 1x1 projection to ``oup`` channels. Otherwise a 1x1
    expansion to ``round(inp * expand_ratio)`` channels comes first. Every
    convolution is bias-free and followed by BatchNorm2d. ReLU6 follows
    every BatchNorm except the last one, so the projection output is linear.

    A skip connection adds the input to the result when ``stride == 1``
    and ``inp == oup``.

    Args:
        inp: Input channel count.
        oup: Output channel count.
        stride: Stride of the depthwise convolution; asserted to be 1 or 2.
        expand_ratio: Multiplier giving the hidden (expanded) channel count.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # A residual sum only lines up when neither resolution nor width changes.
        self.identity = stride == 1 and inp == oup

        if expand_ratio == 1:
            # hidden_dim equals inp here, so the depthwise conv reads the input directly.
            stages = [
                ('conv1', nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride,
                                    padding=1, groups=hidden_dim, bias=False)),
                ('bn1', nn.BatchNorm2d(hidden_dim)),
                ('relu6', nn.ReLU6(inplace=True)),
                ('conv2', nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1,
                                    padding=0, bias=False)),
                ('bn2', nn.BatchNorm2d(oup)),
            ]
        else:
            stages = [
                # 1x1 expansion
                ('conv1', nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=1,
                                    padding=0, bias=False)),
                ('bn1', nn.BatchNorm2d(hidden_dim)),
                ('relu6_1', nn.ReLU6(inplace=True)),
                # 3x3 depthwise (one group per channel)
                ('conv2', nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride,
                                    padding=1, groups=hidden_dim, bias=False)),
                ('bn2', nn.BatchNorm2d(hidden_dim)),
                ('relu6_2', nn.ReLU6(inplace=True)),
                # 1x1 linear projection
                ('conv3', nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1,
                                    padding=0, bias=False)),
                ('bn3', nn.BatchNorm2d(oup)),
            ]
        self.conv = nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run the block, adding ``x`` back onto the result when ``self.identity``."""
        out = self.conv(x)
        return x + out if self.identity else out


class MobileNetV2(nn.Module):
    """MobileNetV2 image classifier with a configurable stem stride.

    The stem 3x3 convolution uses stride 1 by default (the reference
    MobileNetV2 stem uses stride 2). This keeps more spatial resolution
    for small inputs such as 64x64 images.

    Args:
        dataset: Accepted for call-site compatibility; not used by the model.
        num_classes: Number of logits produced by the final linear layer.
        width_mult: Channel-width multiplier applied to every stage. The
            1280-channel head is only widened (``width_mult > 1.0``), never
            narrowed.
        first_stride: Stride of the stem convolution. Defaults to 1; pass 2
            to get the reference MobileNetV2 stem.
    """

    def __init__(self, dataset, num_classes=200, width_mult=1., first_stride=1):
        super(MobileNetV2, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = [
            # t (expand_ratio), c (output channels), n (repeats), s (stride of first repeat)
            [1,  16, 1, 1],
            [6,  24, 2, 2],
            [6,  32, 3, 2],
            [6,  64, 4, 2],
            [6,  96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # Channel counts are rounded to multiples of this divisor:
        # 4 for the 0.1 width multiplier, 8 otherwise.
        divisor = 4 if width_mult == 0.1 else 8

        # building first layer
        input_channel = _make_divisible(32 * width_mult, divisor)
        layers = [conv_3x3_bn(3, input_channel, first_stride)]
        # building inverted residual blocks; only the first repeat of each
        # stage applies the stage stride, later repeats use stride 1.
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, divisor)
            for i in range(n):
                layers.append(InvertedResidual(input_channel, output_channel, s if i == 0 else 1, t))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
        output_channel = _make_divisible(1280 * width_mult, divisor) if width_mult > 1.0 else 1280
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)

    def forward(self, x):
        """Map an image batch to class logits of shape ``(N, num_classes)``.

        Assumes ``x`` is a batched 3-channel tensor ``(N, 3, H, W)``; the stem
        convolution expects 3 input channels.
        """
        x = self.features(x)
        x = self.conv(x)
        x = self.avgpool(x)
        # flatten(1) also handles an empty batch, where view(0, -1) would raise.
        x = x.flatten(1)
        return self.classifier(x)