"""
    MobileNetV2 for ImageNet-1K, implemented in PyTorch.
    Original paper: 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381.
    Adapted from https://github.com/osmr/imgclsmob/blob/master/pytorch/pytorchcv/models/mobilenetv2.py
"""

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.modules import module
from Layers.layers import conv1x1, conv1x1_block, conv3x3_block, dwconv3x3_block, Conv2d, BatchNorm2d, sparse_initialize, PaddedIdentityAdd, Linear


class LinearBottleneck(nn.Module):
    """
    MobileNetV2 inverted-residual unit: 1x1 conv -> 3x3 depthwise conv -> 1x1 conv.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int
        Stride of the depthwise convolution.
    expansion : bool
        If true, the hidden width is 6 * in_channels; otherwise it equals in_channels.
    remove_exp_conv : bool
        Only affects `use_exp_conv`; `forward` asserts that flag, so a unit built with neither
        `expansion` nor an expansion conv cannot be run.
    groups : int
        Group count for the two 1x1 convolutions (conv1 and conv3).

    A skip connection (`self.identity`, a PaddedIdentityAdd) is created only when the unit keeps
    the channel count and uses stride 1.

    NOTE(review): `conv3` uses a relu6 activation, whereas the paper's projection layer is linear
    (no activation) -- confirm this deviation is intended.
    """

    def __init__(self, in_channels, out_channels, stride, expansion, remove_exp_conv, groups):
        super().__init__()
        self.residual = stride == 1 and in_channels == out_channels
        self.stride = stride
        self.groups = groups
        self.use_exp_conv = expansion or not remove_exp_conv
        if expansion:
            hidden = 6 * in_channels
        else:
            hidden = in_channels

        # Registration order (conv1, conv2, conv3, identity) determines state_dict keys and
        # module traversal order, so it is kept as-is.
        self.conv1 = conv1x1_block(in_channels=in_channels, out_channels=hidden, activation="relu6", groups=groups, shuffle=True)
        self.conv2 = dwconv3x3_block(in_channels=hidden, out_channels=hidden, stride=stride, activation="relu6")
        self.conv3 = conv1x1_block(in_channels=hidden, out_channels=out_channels, activation='relu6', groups=groups)
        if self.residual:
            self.identity = PaddedIdentityAdd()

    def forward(self, x):
        """Apply the three convolutions and, for residual units, add the unit input back."""
        assert self.use_exp_conv
        skip = x if self.residual else None
        out = self.conv3(self.conv2(self.conv1(x)))
        return self.identity(out, skip) if self.residual else out


class MobileNetV2(nn.Module):
    """
    MobileNetV2 classifier built from LinearBottleneck units, with bookkeeping for channel pruning.

    Parameters:
    ----------
    channels : list of list of int
        Output channels of each unit, grouped by stage. The first unit of every stage except the
        first uses stride 2.
    init_block_channels : int
        Output channels of the stride-2 3x3 stem convolution.
    final_block_channels : int
        Output channels of the 1x1 convolution applied before global pooling.
    remove_exp_conv : bool
        Passed through to every LinearBottleneck.
    in_channels : int, default 3
        Number of input image channels.
    in_size : tuple of two ints, default (224, 224)
        Spatial input size; also used for the dummy batch pushed through the model in `clean`.
    num_classes : int, default 1000
        Number of output logits.
    groups : int, default 1
        Group count for the 1x1 convolutions inside each LinearBottleneck.
    """

    def __init__(self,
                 channels,
                 init_block_channels,
                 final_block_channels,
                 remove_exp_conv,
                 in_channels=3,
                 in_size=(224, 224),
                 num_classes=1000,
                 groups=1):
        super(MobileNetV2, self).__init__()
        self.in_size = in_size
        self.num_classes = num_classes
        # Remembered so `clean` can rebuild the first unit and create a matching dummy batch
        # instead of assuming 32 stem channels and 3 input channels.
        self.input_channels = in_channels
        self.init_block_channels = init_block_channels

        self.features = nn.Sequential()
        self.features.add_module("init_block", conv3x3_block(in_channels=in_channels, out_channels=init_block_channels, stride=2, activation="relu6"))
        in_channels = init_block_channels
        for i, channels_per_stage in enumerate(channels):
            stage = nn.Sequential()
            for j, out_channels in enumerate(channels_per_stage):
                # Every stage but the first downsamples in its first unit.
                stride = 2 if (j == 0) and (i != 0) else 1
                # Only the very first unit of the network skips channel expansion.
                expansion = (i != 0) or (j != 0)
                stage.add_module(
                    "unit{}".format(j + 1),
                    LinearBottleneck(in_channels=in_channels,
                                     out_channels=out_channels,
                                     stride=stride,
                                     expansion=expansion,
                                     remove_exp_conv=remove_exp_conv,
                                     groups=groups))
                in_channels = out_channels
            self.features.add_module("stage{}".format(i + 1), stage)
        self.features.add_module("final_block", conv1x1_block(in_channels=in_channels, out_channels=final_block_channels, activation="relu6"))
        in_channels = final_block_channels
        # Global average pooling. It matches the former AvgPool2d(kernel_size=7) on the 7x7 map
        # that layer required, and still yields a 1x1 map for other input sizes, so `in_size` is
        # actually honoured and the flattened features always match `self.output`.
        self.features.add_module("final_pool", nn.AdaptiveAvgPool2d(output_size=1))

        self.output = Linear(in_features=in_channels, out_features=num_classes, bias=False)

        self._initialize_weights(sparse_init=False)
        self._set_prune_types()

    def _initialize_weights(self, sparse_init=True):
        """
        Initialize parameters.

        With `sparse_init` (the default) delegate to `sparse_initialize`. Otherwise apply
        Kaiming-uniform to the weight of every project `Conv2d` and zero its bias; all other
        modules keep their default initialization.
        """
        if sparse_init:
            sparse_initialize(self)
            return
        for m in self.modules():
            if isinstance(m, Conv2d):
                init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the feature extractor, flatten the pooled features and return the logits."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.output(x)
        return x

    def clean(self):
        """
        Re-sync the network after channel pruning, re-initialize it and run one dummy batch.

        For every LinearBottleneck without a residual connection, `conv1` is rebuilt in place
        (through its own `__init__`, so module identity and the `pruned_types` keys survive) with
        its input width set to the output width of the preceding unit. Residual units hand on
        their `identity.out_channels` instead. Afterwards the weights are re-initialized with
        `sparse_initialize` and a random batch of 64 inputs is pushed through the model.
        """
        # The first unit is fed by the stem. `_set_prune_types` never registers `init_block`, so
        # its width is the constructor value; the former hard-coded 32 was wrong whenever the
        # model was built with a width multiplier.
        prev_out = self.init_block_channels
        units = [m for m in self.modules() if isinstance(m, LinearBottleneck)]
        for unit in units:
            if not unit.residual and prev_out is not None:
                block = unit.conv1
                conv = block.conv
                activation = getattr(block, 'activ', None)
                # Pass a bias *flag*: handing over the bias Parameter itself breaks any
                # `if bias:` check, because a multi-element tensor has no truth value.
                block.__init__(prev_out, conv.out_channels, conv.kernel_size, conv.stride, conv.groups,
                               conv.padding, conv.bias is not None, activation, block.shuffle)
                prev_out = unit.conv3.conv.out_channels
            else:
                prev_out = unit.identity.out_channels

        # Prefer the GPU as before, but fall back to CPU instead of crashing without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            self.to(device)
            self._initialize_weights()
            # Sample on the CPU generator (as the original did) and then move, so the RNG
            # stream consumed here is unchanged.
            dummy = torch.randn(64, self.input_channels, *self.in_size).to(device)
            # NOTE(review): the model is not put in eval mode, so this random batch also updates
            # any BatchNorm running statistics inside the conv blocks -- confirm that is intended.
            self(dummy)

    def _set_prune_types(self):
        """
        Build `self.pruned_types`, mapping each prunable conv block to the pruning rule to apply.

        Per LinearBottleneck: `conv1` is tagged 'in' for residual units and 'nonresidual_in'
        otherwise (the latter have their input width re-synced in `clean`); `conv3` is tagged
        'vgg_out'. The final 1x1 block is tagged 'only_out'.
        """
        self.pruned_types = {}
        for unit in (m for m in self.modules() if isinstance(m, LinearBottleneck)):
            self.pruned_types[unit.conv1] = 'in' if unit.residual else 'nonresidual_in'
            self.pruned_types[unit.conv3] = 'vgg_out'
        self.pruned_types[self.features.final_block] = 'only_out'


def get_mobilenetv2(groups, width_scale, remove_exp_conv=False, depth_factor=1, **kwargs):
    """
    Create MobileNetV2 model with specific parameters.

    Parameters:
    ----------
    groups : int
        Group count for the 1x1 convolutions inside each LinearBottleneck.
    width_scale : float
        Multiplier for every unit's channels and for the stem channels (truncated to int).
        The final block channels are scaled only when width_scale > 1.
    remove_exp_conv : bool, default False
        Passed through to MobileNetV2 / LinearBottleneck.
    depth_factor : float, default 1
        Multiplier for the number of units in each layer group (rounded up).
    **kwargs
        Extra keyword arguments forwarded to MobileNetV2 (e.g. in_channels, in_size, num_classes).

    Returns:
    -------
    MobileNetV2
        The constructed network.
    """
    init_block_channels = 32
    final_block_channels = 1280
    # One entry per layer group: (unit channels, unit count, starts a new stride-2 stage).
    layer_specs = (
        (16, 1, 0),
        (24, 2, 1),
        (32, 3, 1),
        (64, 4, 1),
        (96, 3, 0),
        (160, 3, 1),
        (320, 1, 0),
    )

    # Groups flagged with 0 are appended to the current stage instead of opening a new one.
    stages = [[]]
    for width, repeats, new_stage in layer_specs:
        units = [width] * int(math.ceil(repeats * depth_factor))
        if new_stage:
            stages.append(units)
        else:
            stages[-1] = stages[-1] + units

    if width_scale != 1.0:
        def _scaled(value):
            return int(value * width_scale)

        stages = [list(map(_scaled, stage)) for stage in stages]
        init_block_channels = _scaled(init_block_channels)
        if width_scale > 1.0:
            final_block_channels = _scaled(final_block_channels)

    return MobileNetV2(channels=stages,
                       init_block_channels=init_block_channels,
                       final_block_channels=final_block_channels,
                       remove_exp_conv=remove_exp_conv,
                       groups=groups,
                       **kwargs)


def mobilenetV2(groups=1, width_factor=1.0, **kwargs):
    """
    Build a MobileNetV2 model from 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,'
    https://arxiv.org/abs/1801.04381 (1.0x width by default).

    `width_factor` is forwarded as `width_scale`; every other keyword argument goes straight to
    `get_mobilenetv2`.
    """
    return get_mobilenetv2(width_scale=width_factor, groups=groups, **kwargs)


def _test():
    """
    Smoke test: build the model, run one forward/backward pass and check the logits shape.

    Covers both the default width and a reduced width multiplier, so the channel-scaling path in
    `get_mobilenetv2` is exercised too.
    """
    configs = [
        dict(groups=1),
        dict(groups=1, width_factor=0.5),
    ]

    for config in configs:
        net = mobilenetV2(**config)
        x = torch.ones(1, 3, 224, 224)
        y = net(x)
        # Check the shape before backward so a wrong head fails with a clear message.
        assert tuple(y.size()) == (1, 1000), "unexpected logits shape {} for {}".format(tuple(y.size()), config)
        target = torch.tensor([1], dtype=torch.long)
        loss = nn.CrossEntropyLoss()(y, target)
        loss.backward()


if __name__ == "__main__":
    # Running this file directly executes the smoke test in `_test`.
    _test()
