"""
MobileNetV2 for CIFAR in PyTorch.
See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
"""
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from dbq import DBQQuantizer
from quantized_modules import TernaryConv2d, TernaryLinear, weights_init
from wage_initializer import wage_init_, get_scale
from wage_quantizer import WAGEQuantizer


class Block(nn.Module):
    """Inverted-residual unit: 1x1 expansion, 3x3 depthwise, 1x1 pointwise.

    Every convolution and normalisation layer is created through a factory
    passed by the caller, so the same block can be built from plain or
    quantized layers. A residual branch is added only for ``stride == 1``;
    when the channel count changes it goes through a 1x1 projection,
    otherwise it is the identity (an empty ``nn.Sequential``).

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        expansion: Multiplier giving the hidden width ``expansion * in_planes``.
        stride: Stride of the depthwise convolution.
        expansion_layer, depthwise_layer, pointwise_layer, shortcut_layer:
            Conv factories called with ``nn.Conv2d``-style arguments.
        expansion_bn_layer, depthwise_bn_layer, pointwise_bn_layer,
        shortcut_bn_layer: Factories called with the channel count.
        activation: Callable applied after the first two normalisations.
    """

    def __init__(self, in_planes, out_planes, expansion, stride,
                 expansion_layer, expansion_bn_layer,
                 depthwise_layer, depthwise_bn_layer,
                 pointwise_layer, pointwise_bn_layer,
                 shortcut_layer, shortcut_bn_layer,
                 activation):
        super().__init__()
        self.stride = stride

        hidden = expansion * in_planes
        # Layers are created in a fixed order so parameter initialisation
        # consumes the random stream deterministically.
        self.conv1 = expansion_layer(
            in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = expansion_bn_layer(hidden)
        self.conv2 = depthwise_layer(
            hidden, hidden, kernel_size=3, stride=stride, padding=1,
            groups=hidden, bias=False)
        self.bn2 = depthwise_bn_layer(hidden)
        self.conv3 = pointwise_layer(
            hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = pointwise_bn_layer(out_planes)
        self.activation = activation

        needs_projection = stride == 1 and in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                shortcut_layer(in_planes, out_planes, kernel_size=1,
                               stride=1, padding=0, bias=False),
                shortcut_bn_layer(out_planes),
            )
        else:
            # Empty container acts as the identity mapping.
            self.shortcut = nn.Sequential()

    def coroutine(self, x):
        """Run the block as a generator exposing each conv layer's input.

        Before every convolution the generator yields ``(tensor, layer)`` and
        continues with whatever tensor is sent back, so a driver can inspect
        or replace the input of any layer. The generator's return value is
        the block output.
        """
        act = self.activation

        h = (yield x, self.conv1)
        h = act(self.bn1(self.conv1(h)))

        h = (yield h, self.conv2)
        h = act(self.bn2(self.conv2(h)))

        h = (yield h, self.conv3)
        h = self.bn3(self.conv3(h))

        # The projection conv is only announced when it actually exists.
        if len(self.shortcut) > 0:
            x = (yield x, self.shortcut[0])

        if self.stride != 1:
            return h
        return h + self.shortcut(x)

    def forward(self, x):
        """Drive ``coroutine`` to completion, echoing every yielded tensor."""
        gen = self.coroutine(x)
        payload = None
        try:
            while True:
                payload, _ = gen.send(payload)
        except StopIteration as finished:
            return finished.value


class MobileNetV2(nn.Module):
    """MobileNetV2 adapted to CIFAR-sized inputs with pluggable layer types.

    Every convolution, normalisation and linear layer is created through a
    factory argument so that plain, weight-standardised or quantized variants
    share this one definition.

    Args:
        num_classes: Size of the classifier output.
        first_layer_type / first_bn_layer_type: Factories for the stem.
        expansion_layer ... shortcut_bn_layer: Factories forwarded to every
            ``Block``.
        conv2_layer / conv2_bn_layer: Factories for the final 1x1 conv.
        last_layer_type: Factory for the classifier layer(s).
        decompose: If not ``None``, the classifier is factorised into two
            linear layers with this inner width.
        final_bias_trick: If true, the classifier has no bias parameter;
            instead a constant-one feature is appended to its input so the
            bias is learned as an extra weight column.
    """

    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1, 16, 1, 1),
           (6, 24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6, 32, 3, 2),
           (6, 64, 4, 2),
           (6, 96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10,
                 first_layer_type: Any = nn.Conv2d, first_bn_layer_type: Any = nn.BatchNorm2d,
                 expansion_layer: Any = nn.Conv2d, expansion_bn_layer: Any = nn.BatchNorm2d,
                 depthwise_layer: Any = nn.Conv2d, depthwise_bn_layer: Any = nn.BatchNorm2d,
                 pointwise_layer: Any = nn.Conv2d, pointwise_bn_layer: Any = nn.BatchNorm2d,
                 shortcut_layer: Any = nn.Conv2d, shortcut_bn_layer: Any = nn.BatchNorm2d,
                 conv2_layer: Any = nn.Conv2d, conv2_bn_layer: Any = nn.BatchNorm2d,
                 last_layer_type: Any = nn.Linear,
                 decompose=None, final_bias_trick=True):
        super().__init__()

        self.activation = F.relu

        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = first_layer_type(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = first_bn_layer_type(32)

        self.layers = self._make_layers(
            32,
            expansion_layer, expansion_bn_layer,
            depthwise_layer, depthwise_bn_layer,
            pointwise_layer, pointwise_bn_layer,
            shortcut_layer, shortcut_bn_layer,
            self.activation
        )

        self.conv2 = conv2_layer(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = conv2_bn_layer(1280)

        self.final_bias_trick = final_bias_trick
        # The classifier input width must match what ``coroutine`` feeds it:
        # 1280 pooled features, plus one constant feature when the bias trick
        # is on. Without the trick a real bias parameter is used instead.
        in_features = 1280 + 1 if final_bias_trick else 1280
        use_bias = not final_bias_trick
        if decompose is None:
            self.linear = last_layer_type(in_features, num_classes, bias=use_bias)
        else:
            # Factorised classifier; only the outer layer may carry a bias.
            self.linear = nn.Sequential(
                last_layer_type(in_features, decompose, bias=False),
                last_layer_type(decompose, num_classes, bias=use_bias)
            )

        self.apply(weights_init)

    def _make_layers(self, in_planes,
                     expansion_layer, expansion_bn_layer,
                     depthwise_layer, depthwise_bn_layer,
                     pointwise_layer, pointwise_bn_layer,
                     shortcut_layer, shortcut_bn_layer,
                     activation):
        """Build the stack of ``Block`` modules described by ``cfg``.

        In each stage only the first block uses the configured stride; the
        remaining blocks of the stage use stride 1.
        """
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            block_strides = [stride] + [1] * (num_blocks - 1)
            for block_stride in block_strides:
                layers.append(Block(
                    in_planes, out_planes, expansion, block_stride,
                    expansion_layer, expansion_bn_layer,
                    depthwise_layer, depthwise_bn_layer,
                    pointwise_layer, pointwise_bn_layer,
                    shortcut_layer, shortcut_bn_layer,
                    activation
                ))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def coroutine(self, x):
        """Run the network as a generator exposing each layer's input.

        Yields ``(tensor, layer)`` before the stem conv, before every conv
        inside each block, before ``conv2`` and before ``linear``; execution
        continues with the tensor sent back by the driver. For ``linear`` the
        yielded tensor is the pooled feature map *before* it is flattened and
        before the constant-one feature is appended. The generator's return
        value is the network output.
        """
        out = (yield x, self.conv1)
        out = self.activation(self.bn1(self.conv1(out)))
        for layer in self.layers:
            out = (yield from layer.coroutine(out))
        out = (yield out, self.conv2)
        out = self.activation(self.bn2(self.conv2(out)))
        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
        out = F.avg_pool2d(out, 4)
        out = (yield out, self.linear)
        out = out.view(out.size(0), -1)
        if self.final_bias_trick:
            # Constant-one column stands in for the bias term.
            out = torch.cat([out, out.new_ones((out.size(0), 1))], dim=1)
        out = self.linear(out)
        return out

    def forward(self, x):
        """Drive ``coroutine`` to completion, echoing every yielded tensor."""
        coro = self.coroutine(x)
        out = None
        while True:
            try:
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value

    def inputs_for(self, layer, x):
        """Return the tensor the coroutine yields right before ``layer``.

        Args:
            layer: One of the module objects announced by ``coroutine``
                (compared by identity).
            x: Network input.

        Raises:
            ValueError: If the forward pass finishes without ever announcing
                ``layer``.
        """
        coro = self.coroutine(x)
        out = None
        while True:
            try:
                out, cur_layer = coro.send(out)
            except StopIteration:
                # A bare StopIteration escaping a regular function is easy to
                # swallow by accident (and becomes RuntimeError inside a
                # generator), so report the real problem instead.
                raise ValueError(
                    "layer is not visited during the forward pass"
                ) from None
            if cur_layer is layer:
                return out

    def dependent_layers(self, layer_name: str, include_depthwise=False):
        """Return the names of the layers whose outputs feed ``layer_name``.

        Args:
            layer_name: Dotted module name such as ``'conv1'``, ``'conv2'``,
                ``'linear'`` or ``'layers.<idx>.<conv1|conv2|conv3|shortcut...>'``.
            include_depthwise: If true, the depthwise conv (``conv2``) of a
                block counts as a dependency of its ``conv3``; otherwise it is
                skipped and ``conv1`` is reported. Must be true when querying
                a depthwise layer itself.

        Raises:
            AssertionError: If the name is unknown, or a depthwise layer is
                queried with ``include_depthwise=False``.
            ValueError: If the block index in the name is not an integer.
        """
        if layer_name == 'linear':
            return ['conv2']
        if layer_name == 'conv2':
            # Index 16 is the last block produced by ``cfg`` (17 blocks).
            return ['layers.16.shortcut.0', 'layers.16.conv3']
        if layer_name == 'conv1':
            return []

        parts = layer_name.split('.')
        if layer_name.startswith('layers.') and len(parts) >= 3:
            idx = int(parts[1])
            kind = parts[2]
            if kind == 'conv3':
                source = 'conv2' if include_depthwise else 'conv1'
                return [f'layers.{idx}.{source}']
            if kind == 'conv2':
                # Explicit raise so the check survives ``python -O``.
                if not include_depthwise:
                    raise AssertionError(
                        f"{layer_name} requires include_depthwise=True")
                return [f'layers.{idx}.conv1']
            if kind in ('conv1', 'shortcut'):
                if idx == 0:
                    return ['conv1']
                prev_layer = self.layers[idx - 1]
                depends = [f'layers.{idx-1}.conv3']
                if prev_layer.stride == 1:
                    if len(prev_layer.shortcut) > 0:
                        depends.append(f'layers.{idx-1}.shortcut.0')
                    else:
                        # Identity shortcut: the previous block's own input
                        # also flows into this layer, so inherit its sources.
                        depends.extend(self.dependent_layers(f'layers.{idx-1}.conv1'))
                return depends
        raise AssertionError(f"No such layer {layer_name}")


class WeightStandardizedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardised per output filter on the fly.

    On every forward pass each output filter is shifted to zero mean and
    divided by its standard deviation before the convolution is applied. The
    stored parameter itself is left untouched. No epsilon is added to the
    standard deviation, so a constant filter produces non-finite weights.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardised version of ``self.weight``."""
        # Zero-centre each filter over (in_channels, kH, kW).
        centered = self.weight - self.weight.mean(dim=[1, 2, 3], keepdim=True)
        # One standard deviation per output filter, reshaped for broadcasting.
        per_filter_std = centered.view(centered.size(0), -1).std(dim=1)
        standardized = centered / per_filter_std.view(-1, 1, 1, 1)
        return F.conv2d(
            x, standardized, self.bias,
            stride=self.stride, padding=self.padding,
            dilation=self.dilation, groups=self.groups,
        )


def ws_mobilenetv2(num_classes=10, decompose=None):
    """Build a MobileNetV2 whose convolutions are all weight-standardised.

    Normalisation layers and the classifier keep the ``MobileNetV2`` defaults.

    Args:
        num_classes: Size of the classifier output.
        decompose: Forwarded unchanged to ``MobileNetV2``.
    """
    conv_slots = (
        'first_layer_type',
        'expansion_layer',
        'depthwise_layer',
        'pointwise_layer',
        'shortcut_layer',
        'conv2_layer',
    )
    conv_kwargs = dict.fromkeys(conv_slots, WeightStandardizedConv2d)
    return MobileNetV2(num_classes=num_classes, decompose=decompose, **conv_kwargs)


class ScaledConv2d(nn.Conv2d):
    """``nn.Conv2d`` that multiplies its kernel by a fixed factor when applied.

    The factor is ``1 / get_scale(weight, wl_weight, factor=1.0)``, computed
    once at construction from the freshly created weight and stored as the
    plain attribute ``scale`` (not as a parameter or buffer).

    NOTE: ``padding_mode`` is passed to the parent constructor, but
    ``forward`` calls ``F.conv2d`` directly with ``self.padding``, so any
    non-zero padding mode is not applied here.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size, stride=1, padding=0, dilation=1, groups: int = 1,
                 bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None,
                 wl_weight=2):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias, padding_mode=padding_mode, device=device, dtype=dtype,
        )
        self.scale = 1. / get_scale(self.weight, wl_weight, factor=1.0)

    def forward(self, x):
        """Convolve ``x`` with ``self.weight * self.scale``."""
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            x, scaled_weight, self.bias,
            stride=self.stride, padding=self.padding,
            dilation=self.dilation, groups=self.groups,
        )


class ScaledLinear(nn.Linear):
    """``nn.Linear`` that multiplies its weight by a fixed factor when applied.

    The factor is ``1 / get_scale(weight, wl_weight, factor=1.0)``, computed
    once at construction and stored as the plain attribute ``scale``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None,
                 wl_weight=2):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.scale = 1. / get_scale(self.weight, wl_weight, factor=1.0)

    def forward(self, x):
        """Apply the affine map using ``self.weight * self.scale``."""
        scaled_weight = self.weight * self.scale
        return F.linear(x, scaled_weight, self.bias)


def simpleq_mobilenetv2(
        num_classes=10,
        decompose=None,
        wl_weight=2,
        wl_activate=8,
):
    """Build a MobileNetV2 from scaled layers with WAGE activation quantizers.

    All convolutions become ``ScaledConv2d``, the classifier becomes
    ``ScaledLinear``, and every normalisation slot is filled with
    ``WAGEQuantizer(wl_activate, -1)``. After construction every conv/linear
    weight is re-initialised in place with ``wage_init_``.

    Args:
        num_classes: Size of the classifier output.
        decompose: Forwarded unchanged to ``MobileNetV2``.
        wl_weight: Passed to the scaled layers and to ``wage_init_``.
        wl_activate: First argument of each ``WAGEQuantizer``.

    Raises:
        AssertionError: If any conv or linear layer carries a bias.
    """
    conv_factory = partial(ScaledConv2d, wl_weight=wl_weight)
    linear_factory = partial(ScaledLinear, wl_weight=wl_weight)

    def quantizer_factory(x):
        # The channel count supplied by the model is intentionally unused.
        return WAGEQuantizer(wl_activate, -1)

    net = MobileNetV2(
        num_classes=num_classes,
        decompose=decompose,
        first_layer_type=conv_factory,
        first_bn_layer_type=quantizer_factory,
        expansion_layer=conv_factory,
        expansion_bn_layer=quantizer_factory,
        depthwise_layer=conv_factory,
        depthwise_bn_layer=quantizer_factory,
        pointwise_layer=conv_factory,
        pointwise_bn_layer=quantizer_factory,
        shortcut_layer=conv_factory,
        shortcut_bn_layer=quantizer_factory,
        conv2_layer=conv_factory,
        conv2_bn_layer=quantizer_factory,
        last_layer_type=linear_factory,
    )

    for module_name, module in net.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        wage_init_(module.weight, wl_weight, factor=1.0)
        assert module.bias is None, module_name
    return net


class QuantizedActivations(torch.autograd.Function):
    """Uniform activation quantizer with a pass-through gradient.

    Forward maps ``x`` onto ``2 ** bitwidth`` evenly spaced levels in
    ``[0, c]`` (values outside that range are clamped). Backward hands the
    incoming gradient to ``x`` unchanged and gives no gradient to ``c`` or
    ``bitwidth``.
    """

    @staticmethod
    def forward(ctx, x, c, bitwidth):
        top_level = 2 ** bitwidth - 1
        # Integer level index in [0, top_level].
        level = torch.round(x / c * top_level)
        level = torch.clamp(level, 0, top_level)
        # Back to the original value range.
        return level / top_level * c

    @staticmethod
    def backward(ctx, grad_output):
        # Rounding and clamping are treated as the identity for gradients.
        return grad_output, None, None


class QuantizedBatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` whose output is quantized to ``bitwidth`` bits.

    The batch-norm result is passed through ``QuantizedActivations`` with a
    clipping ceiling derived from the affine parameters, so the module
    requires ``affine=True`` (``weight`` and ``bias`` must exist).

    Args:
        num_features, eps, momentum, affine, track_running_stats, device,
        dtype: Forwarded to ``nn.BatchNorm2d``.
        bitwidth: Number of bits used for the quantized output.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, device=None, dtype=None, bitwidth=8):
        super().__init__(
            num_features, eps, momentum, affine,
            track_running_stats, device, dtype)
        self.bitwidth = bitwidth

    def forward(self, x):
        """Batch-normalise ``x`` and quantize the result to ``[0, c]``."""
        # Clipping ceiling: the largest value of ``weight * 6 + bias`` over
        # all channels. NOTE(review): the constant 6 presumably stands for six
        # standard deviations of the normalised activations; confirm.
        c = torch.max(self.weight * 6 + self.bias)
        # Plain ``super()`` walks the MRO from this class. The previous
        # ``super(nn.BatchNorm2d, self)`` skipped ``nn.BatchNorm2d`` itself and
        # would bypass any ``forward`` defined there.
        result = super().forward(x)
        return QuantizedActivations.apply(result, c, self.bitwidth)


def dbq_mobilenetv2(
        num_classes=10,
        num_branches=2, num_branches_first=2, num_branches_last=2,
        quantizer_type=DBQQuantizer,
        quantize_first_layer=False,
        quantize_last_layer=False,
        quantize_expansion_layer=True,
        quantize_depthwise_layer=False,
        quantize_pointwise_layer=True,
        quantize_shortcut_layer=False,
        quantize_conv2_layer=True,
        quantize_activations=False,
        decompose=None,
        gen_matrix_every_step=False
):
    """Build a MobileNetV2 with selectively ternary-quantized layers.

    Each ``quantize_*_layer`` flag switches one family of convolutions from
    ``nn.Conv2d`` to ``TernaryConv2d``. The normalisation following a
    quantized convolution becomes ``QuantizedBatchNorm2d`` when
    ``quantize_activations`` is set; a non-quantized convolution always keeps
    ``nn.BatchNorm2d``. The classifier is ``TernaryLinear`` when
    ``quantize_last_layer`` is set, otherwise ``nn.Linear``.

    Args:
        num_classes: Size of the classifier output.
        num_branches: ``num_branches`` for all quantized convs except the stem.
        num_branches_first: ``num_branches`` for a quantized stem conv.
        num_branches_last: ``num_branches`` for a quantized classifier.
        quantizer_type: Forwarded to every ternary layer.
        quantize_activations: See above.
        decompose: Forwarded unchanged to ``MobileNetV2``.
        gen_matrix_every_step: Forwarded to every ternary layer.
    """
    # Normalisation type used after any conv that is actually quantized.
    bn_after_quantized = QuantizedBatchNorm2d if quantize_activations else nn.BatchNorm2d

    def conv_and_bn(enabled, branches):
        """Return the (conv factory, bn factory) pair for one layer family."""
        if not enabled:
            return nn.Conv2d, nn.BatchNorm2d
        ternary_conv = partial(
            TernaryConv2d,
            num_branches=branches,
            quantizer_type=quantizer_type,
            gen_matrix_every_step=gen_matrix_every_step,
        )
        return ternary_conv, bn_after_quantized

    first_conv, first_bn = conv_and_bn(quantize_first_layer, num_branches_first)
    expansion_conv, expansion_bn = conv_and_bn(quantize_expansion_layer, num_branches)
    depthwise_conv, depthwise_bn = conv_and_bn(quantize_depthwise_layer, num_branches)
    pointwise_conv, pointwise_bn = conv_and_bn(quantize_pointwise_layer, num_branches)
    shortcut_conv, shortcut_bn = conv_and_bn(quantize_shortcut_layer, num_branches)
    final_conv, final_bn = conv_and_bn(quantize_conv2_layer, num_branches)

    if quantize_last_layer:
        classifier = partial(
            TernaryLinear,
            num_branches=num_branches_last,
            quantizer_type=quantizer_type,
            gen_matrix_every_step=gen_matrix_every_step,
        )
    else:
        classifier = nn.Linear

    return MobileNetV2(
        num_classes=num_classes,
        first_layer_type=first_conv, first_bn_layer_type=first_bn,
        expansion_layer=expansion_conv, expansion_bn_layer=expansion_bn,
        depthwise_layer=depthwise_conv, depthwise_bn_layer=depthwise_bn,
        pointwise_layer=pointwise_conv, pointwise_bn_layer=pointwise_bn,
        shortcut_layer=shortcut_conv, shortcut_bn_layer=shortcut_bn,
        conv2_layer=final_conv, conv2_bn_layer=final_bn,
        last_layer_type=classifier,
        decompose=decompose,
    )
