"""
    MobileNetV2 for ImageNet-1K, implemented in PyTorch.
    Original paper: 'MobileNetV2: Inverted Residuals and Linear Bottlenecks,' https://arxiv.org/abs/1801.04381.
    Adapted from https://github.com/osmr/imgclsmob/blob/master/pytorch/pytorchcv/models/mobilenetv2.py
"""

import os
import math
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.modules import module
from Layers.layers import conv1x1, conv1x1_block, conv3x3_block, dwconv3x3_block, Conv2d, sparse_initialize, Linear
from Models.imagenet_mobilenetv2 import LinearBottleneck


class MobileNetV2(nn.Module):
    """
    MobileNetV2 classifier for small images (32x32 by default).

    The network is ``features`` (stride-1 3x3 stem, stages of LinearBottleneck
    units, 1x1 final block, 4x4 average pool) followed by a bias-free ``Linear``
    head applied to the flattened features.

    Parameters:
    ----------
    channels : list of list of int
        Output channels of each LinearBottleneck unit, grouped by stage. The first
        unit of every stage after the first uses stride 2; every unit except the
        very first one uses expansion.
    init_block_channels : int
        Output channels of the stem 3x3 convolution block.
    final_block_channels : int
        Output channels of the final 1x1 convolution block (and the number of
        input features of the classifier).
    remove_exp_conv : bool
        Forwarded to every LinearBottleneck unit.
    in_channels : int, default 3
        Number of input image channels.
    in_size : tuple of two ints, default (32, 32)
        Stored as ``self.in_size``; not used to size any layer here.
    num_classes : int, default 10
        Number of output logits.
    groups : int, default 1
        Forwarded to every LinearBottleneck unit.
    """

    def __init__(self,
                 channels,
                 init_block_channels,
                 final_block_channels,
                 remove_exp_conv,
                 in_channels=3,
                 in_size=(32, 32),
                 num_classes=10,
                 groups=1):
        super(MobileNetV2, self).__init__()
        self.in_size = in_size
        self.num_classes = num_classes
        # Remembered so clean() can rebuild the first bottleneck and build its
        # shape-check input without hard-coding these widths.
        self._input_channels = in_channels
        self._init_block_channels = init_block_channels

        self.features = nn.Sequential()
        self.features.add_module("init_block", conv3x3_block(in_channels=in_channels, out_channels=init_block_channels, stride=1, activation="relu6"))
        in_channels = init_block_channels
        for i, channels_per_stage in enumerate(channels):
            stage = nn.Sequential()
            for j, out_channels in enumerate(channels_per_stage):
                # Downsample at the start of every stage except the first one.
                stride = 2 if (j == 0) and (i != 0) else 1
                # Only the very first unit of the network skips channel expansion.
                expansion = (i != 0) or (j != 0)
                stage.add_module(
                    "unit{}".format(j + 1),
                    LinearBottleneck(in_channels=in_channels,
                                     out_channels=out_channels,
                                     stride=stride,
                                     expansion=expansion,
                                     remove_exp_conv=remove_exp_conv,
                                     groups=groups))
                in_channels = out_channels
            self.features.add_module("stage{}".format(i + 1), stage)
        self.features.add_module("final_block", conv1x1_block(in_channels=in_channels, out_channels=final_block_channels, activation="relu6"))
        in_channels = final_block_channels
        # Fixed 4x4 pool: assumes the final feature map is 4x4 (e.g. a 32x32 input
        # through three stride-2 stages) so the flattened size equals
        # final_block_channels, which is what the Linear head expects.
        self.features.add_module("final_pool", nn.AvgPool2d(kernel_size=4, stride=1))

        self.output = Linear(in_features=in_channels, out_features=num_classes, bias=False)

        self._initialize_weights()
        self._set_prune_types()

    def _initialize_weights(self, sparse_init=True):
        """
        Initialize the model's weights.

        With ``sparse_init=True`` (the default) this delegates to the project's
        ``sparse_initialize``. Otherwise every project ``Conv2d`` gets Kaiming-uniform
        weights and zero bias; other layer types are left untouched.
        """
        if sparse_init:
            sparse_initialize(self)
            return
        for m in self.modules():
            if isinstance(m, Conv2d):
                init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Return class logits of shape (batch, num_classes).

        ``x`` is an image batch, presumably (batch, in_channels, 32, 32); confirm
        against the data pipeline.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.output(x)
        return x

    def _set_prune_types(self):
        """
        Populate ``self.pruned_types``, a dict that maps prunable submodules to a
        pruning-mode tag consumed by the pruning code.

        - For each LinearBottleneck, ``conv1`` is tagged 'in' (residual unit) or
          'nonresidual_in'.
        - Each unit's ``conv3`` is tagged 'vgg_out'.
        - ``features.final_block`` is tagged 'only_out'.

        Insertion order follows module traversal order.
        """
        self.pruned_types = {}
        for unit in self.modules():
            if not isinstance(unit, LinearBottleneck):
                continue
            self.pruned_types[unit.conv1] = 'in' if unit.residual else 'nonresidual_in'
            self.pruned_types[unit.conv3] = 'vgg_out'
        self.pruned_types[self.features.final_block] = 'only_out'

    def clean(self, device=None):
        """
        Re-wire non-residual bottlenecks after channel counts have changed.

        For every non-residual LinearBottleneck, ``conv1`` is re-constructed in place
        with ``in_channels`` equal to the output width of the preceding block. That
        block is the stem for the first unit, ``conv3`` of a preceding non-residual
        unit, or ``identity`` of a preceding residual unit. The model is then moved
        to ``device``, its weights are re-initialized via ``_initialize_weights()``,
        and a dummy batch of 64 random 32x32 images is pushed through it,
        presumably as a shape check.

        Parameters:
        ----------
        device : torch.device or str, optional
            Target device. Defaults to CUDA when available, otherwise CPU (the
            previous implementation always used ``.cuda()``).
        """
        # The first bottleneck is fed by init_block, which is never registered in
        # pruned_types, so its width is the constructor's init_block_channels.
        # Fall back to 32 for instances created before this attribute existed.
        prev_out = getattr(self, '_init_block_channels', 32)
        for m in self.modules():
            if not isinstance(m, LinearBottleneck):
                continue
            if not m.residual and prev_out is not None:
                in_channels = prev_out
                out_channels = m.conv1.conv.out_channels
                activation = getattr(m.conv1, 'activ', None)
                # NOTE(review): the bias Parameter (or None) is passed positionally
                # as-is; confirm the conv block's constructor expects that.
                m.conv1.__init__(in_channels, out_channels, m.conv1.conv.kernel_size, m.conv1.conv.stride, m.conv1.conv.groups,
                                 m.conv1.conv.padding, m.conv1.conv.bias, activation, m.conv1.shuffle)
                prev_out = m.conv3.conv.out_channels
            else:
                prev_out = m.identity.out_channels

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        input_channels = getattr(self, '_input_channels', 3)
        with torch.no_grad():
            # Move everything, including the freshly constructed conv1 parameters.
            self.to(device)
            self._initialize_weights()
            # 32x32 matches the fixed 4x4 final pool / Linear head sizing.
            self(torch.randn(64, input_channels, 32, 32, device=device))


def get_mobilenetv2(groups, width_scale, remove_exp_conv=False, depth_factor=1, **kwargs):
    """
    Create a MobileNetV2 model for 32x32 inputs with the given width/depth scaling.

    Parameters:
    ----------
    groups : int
        Forwarded to every LinearBottleneck unit.
    width_scale : float
        Multiplier applied to every bottleneck width and to the stem width
        (truncated with ``int``). The final 1x1 block is widened only when
        ``width_scale > 1``.
    remove_exp_conv : bool, default False
        Forwarded to every LinearBottleneck unit.
    depth_factor : float, default 1
        Multiplier for the number of units in each layer group (rounded up).
    **kwargs
        Extra keyword arguments passed to ``MobileNetV2`` (e.g. ``num_classes``).

    Returns:
    -------
    MobileNetV2
    """
    init_block_channels = 32
    final_block_channels = 1280
    layers = [1, 2, 3, 4, 3, 3, 1]
    # 1 marks a layer group that starts a new stride-2 stage. Relative to the
    # standard MobileNetV2 layout, the second downsample (at the 24-channel group)
    # is removed for small inputs.
    downsample = [0, 0, 1, 1, 0, 1, 0]
    channels_per_layers = [16, 24, 32, 64, 96, 160, 320]

    layers = [int(math.ceil(n * depth_factor)) for n in layers]

    # Group the per-unit widths into stages: a layer group flagged in `downsample`
    # opens a new stage, otherwise its units are appended to the current stage.
    channels = [[]]
    for width, count, starts_stage in zip(channels_per_layers, layers, downsample):
        units = [width] * count
        if starts_stage != 0:
            channels.append(units)
        else:
            channels[-1].extend(units)

    if width_scale != 1.0:
        channels = [[int(cij * width_scale) for cij in ci] for ci in channels]
        init_block_channels = int(init_block_channels * width_scale)
        if width_scale > 1.0:
            final_block_channels = int(final_block_channels * width_scale)

    net = MobileNetV2(channels=channels,
                      init_block_channels=init_block_channels,
                      final_block_channels=final_block_channels,
                      remove_exp_conv=remove_exp_conv,
                      groups=groups,
                      **kwargs)

    return net


def mobilenetV2(groups=1, width_factor=1.0, **kwargs):
    """
    MobileNetV2 for 32x32 inputs, based on 'MobileNetV2: Inverted Residuals and
    Linear Bottlenecks,' https://arxiv.org/abs/1801.04381.

    Parameters:
    ----------
    groups : int, default 1
        Forwarded to every LinearBottleneck unit.
    width_factor : float, default 1.0
        Width multiplier, passed to ``get_mobilenetv2`` as ``width_scale``.
    **kwargs
        Forwarded to ``get_mobilenetv2`` (e.g. ``remove_exp_conv``,
        ``depth_factor``, ``num_classes``).
    """
    return get_mobilenetv2(groups=groups, width_scale=width_factor, **kwargs)


def _test():
    """Smoke test: forward + backward of a default model on one 32x32 image."""
    models = [mobilenetV2]

    for model in models:
        net = model(groups=1)
        x = torch.ones(1, 3, 32, 32)
        y = net(x)
        target = torch.tensor([1], dtype=torch.long)
        loss = nn.CrossEntropyLoss()(y, target)
        loss.backward()
        print(y.shape)
        assert tuple(y.size()) == (1, 10)


if __name__ == "__main__":
    _test()
