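"""
MobileNetV2 for CIFAR in PyTorch.

This variant binarizes the inputs of each block's 1x1 convolutions, adds
parameter-free skip connections around every block stage, and provides
builder functions for weight-standardized, scaled and ternary-quantized layers.
See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
"""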
import math
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from dbq import DBQQuantizer
from quantized_modules import TernaryConv2d, TernaryLinear, weights_init
from wage_initializer import wage_init_, get_scale
from wage_quantizer import WAGEQuantizer


class BinAct(torch.autograd.Function):
    """Sign activation with a piecewise-linear surrogate gradient.

    Forward returns ``torch.sign(x)``, so exact zeros map to 0 (not +/-1).
    Backward scales the incoming gradient by ``max(0, 2 * (1 - |x|))``: a
    triangle that peaks at 2 for ``x == 0`` and vanishes for ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, g):
        # ``forward`` takes a single tensor input, so exactly one gradient is
        # returned (the previous extra ``None`` was only tolerated because
        # autograd trims trailing Nones).
        x, = ctx.saved_tensors
        return g * torch.clamp(2 * (1 - torch.abs(x)), min=0.0)


class LearnableBias(nn.Module):
    """Trainable per-channel shift added to an activation map.

    Holds a zero-initialised parameter of shape ``(1, out_chn, 1, 1)`` and
    adds it, expanded to the input's shape, to every position of the input.
    Intended for ``(N, out_chn, H, W)`` activations (``expand_as`` rejects
    inputs with fewer dimensions).
    """

    def __init__(self, out_chn):
        super().__init__()
        # Registered as ``bias`` so checkpoint keys stay ``<prefix>.bias``.
        shape = (1, out_chn, 1, 1)
        self.bias = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        shift = self.bias.expand_as(x)
        return x + shift


class Block(nn.Module):
    """Inverted-residual block: 1x1 expand + 3x3 depthwise + 1x1 pointwise.

    This variant binarizes (``BinAct``) the inputs of the expansion and
    pointwise convolutions, adds learnable per-channel shifts
    (``LearnableBias``) and PReLU activations, and wraps each of the three
    stages in a parameter-free skip connection (channel tiling and/or 2x2
    average pooling of an earlier tensor) instead of a learned shortcut.

    Args:
        in_planes: channels of the block input.
        out_planes: channels of the block output.
        expansion: width multiplier; the hidden width is ``expansion * in_planes``.
        stride: stride of the depthwise conv. The skip connections are 2x2
            average-pooled only when ``stride == 2`` (presumably only 1 and 2
            are supported).
        expansion_layer, expansion_bn_layer: factories for the 1x1 expansion
            conv and the layer applied to its output.
        depthwise_layer, depthwise_bn_layer: same for the 3x3 depthwise conv.
        pointwise_layer, pointwise_bn_layer: same for the 1x1 projection conv.
        shortcut_layer, shortcut_bn_layer: accepted for interface
            compatibility with the network builder but not used by this block.
    """

    def __init__(self, in_planes, out_planes, expansion, stride,
                 expansion_layer, expansion_bn_layer,
                 depthwise_layer, depthwise_bn_layer,
                 pointwise_layer, pointwise_bn_layer,
                 shortcut_layer, shortcut_bn_layer):
        super(Block, self).__init__()
        self.stride = stride

        # Stage 1: shift -> (binarize) -> 1x1 expansion conv -> norm -> shift/PReLU/shift.
        planes = expansion * in_planes
        self.move11 = LearnableBias(in_planes)
        self.conv1 = expansion_layer(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = expansion_bn_layer(planes)
        self.move12 = LearnableBias(planes)
        self.act1 = nn.PReLU(planes)
        self.move13 = LearnableBias(planes)

        # Stage 2: 3x3 depthwise conv (groups == channels) -> norm -> PReLU.
        # The shifts around it are currently disabled.
        #self.move21 = LearnableBias(planes)
        self.conv2 = depthwise_layer(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = depthwise_bn_layer(planes)
        #self.move22 = LearnableBias(planes)
        self.act2 = nn.PReLU(planes)
        #self.move23 = LearnableBias(planes)

        # Stage 3: shift -> (binarize) -> 1x1 projection conv -> norm.
        self.move31 = LearnableBias(planes)
        self.conv3 = pointwise_layer(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = pointwise_bn_layer(out_planes)

    def coroutine(self, x):
        """Generator form of the forward pass.

        Right before each convolution it yields ``(input, conv_module)`` and
        resumes with the tensor the driver sends back, which is what actually
        gets fed to that conv (so a driver may substitute the input). The
        block output is the generator's return value (``StopIteration.value``,
        or the value of a ``yield from`` expression).
        """
        # Binarize the shifted block input before the 1x1 expansion conv.
        out = BinAct.apply(self.move11(x))
        out = (yield out, self.conv1)
        out = self.conv1(out)
        out = self.bn1(out)

        # skip connection: repeat x along channels enough times to cover the
        # expanded width, then truncate. conv1 has stride 1 / kernel 1, so the
        # spatial sizes already match.
        reps = math.ceil(out.shape[1] / x.shape[1])
        shortcut_x = torch.tile(x, (1, reps, 1, 1))[:, :out.shape[1]]
        out = out + shortcut_x

        out = self.move12(out)
        out = self.act1(out)
        out = self.move13(out)

        # Input of the depthwise conv, reused as the identity skip around it.
        shortcut_x = out
        # The depthwise input is NOT binarized (disabled lines kept below).
        #out = self.move21(out)
        #out = binact(out)
        out = (yield out, self.conv2)
        out = self.conv2(out)
        out = self.bn2(out)

        # skip connection around the depthwise conv; average-pool the skip so
        # it matches the conv output when stride == 2.
        if self.stride == 2:
            shortcut_x = F.avg_pool2d(shortcut_x, 2, 2)
        out = out + shortcut_x

        #out = self.move22(out)
        out = self.act2(out)
        #out = self.move23(out)

        # Binarize the shifted activations before the 1x1 projection conv.
        out = BinAct.apply(self.move31(out))
        out = (yield out, self.conv3)
        out = self.bn3(self.conv3(out))

        # skip connection around the whole block: the block input, pooled for
        # stride 2, and tiled then truncated along channels when the widths
        # differ (truncation also handles out_planes < in_planes).
        shortcut_x = x
        if self.stride == 2:
            shortcut_x = F.avg_pool2d(shortcut_x, 2, 2)
        if shortcut_x.shape[1] != out.shape[1]:
            reps = math.ceil(out.shape[1] / shortcut_x.shape[1])
            shortcut_x = torch.tile(shortcut_x, (1, reps, 1, 1))[:, :out.shape[1]]
        out = out + shortcut_x
        return out

    def forward(self, x):
        """Run ``coroutine`` to completion, sending every yielded input back
        unchanged, and return the block output."""
        coro = self.coroutine(x)
        out = None
        while True:
            try:
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value


class MobileNetV2(nn.Module):
    """MobileNetV2-style classifier for 32x32 (CIFAR) inputs built from ``Block``.

    Every convolution, normalization and classifier type is injected through
    a factory argument, so the same topology can be instantiated with plain,
    weight-standardized, scaled or ternary layers. The forward pass is written
    as a generator (``coroutine``) that yields ``(input, layer)`` right before
    each convolution and before the classifier. ``forward`` drives it to
    completion; ``inputs_for`` stops at a chosen layer.

    Args:
        num_classes: number of output logits.
        first_layer_type, first_bn_layer_type: stem 3x3 conv (3 -> 32) and the
            layer applied to its output.
        expansion_layer, expansion_bn_layer: 1x1 expansion conv of each block.
        depthwise_layer, depthwise_bn_layer: 3x3 depthwise conv of each block.
        pointwise_layer, pointwise_bn_layer: 1x1 projection conv of each block.
        shortcut_layer, shortcut_bn_layer: forwarded to every ``Block``.
        conv2_layer, conv2_bn_layer: final 1x1 conv (320 -> 1280) and its norm.
        last_layer_type: classifier factory, called as
            ``last_layer_type(in_features, out_features, bias=...)``.
        decompose: if not None, the classifier is factored into two layers
            through a ``decompose``-wide bottleneck.
        final_bias_trick: if True, a constant-1 feature is appended to the
            pooled features and the classifier layers are created without a
            bias; if False, the (last) classifier layer carries the bias.
    """
    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1, 16, 1, 1),
           (6, 24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6, 32, 3, 2),
           (6, 64, 4, 2),
           (6, 96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10,
                 first_layer_type: Any = nn.Conv2d, first_bn_layer_type: Any = nn.BatchNorm2d,
                 expansion_layer: Any = nn.Conv2d, expansion_bn_layer: Any = nn.BatchNorm2d,
                 depthwise_layer: Any = nn.Conv2d, depthwise_bn_layer: Any = nn.BatchNorm2d,
                 pointwise_layer: Any = nn.Conv2d, pointwise_bn_layer: Any = nn.BatchNorm2d,
                 shortcut_layer: Any = nn.Conv2d, shortcut_bn_layer: Any = nn.BatchNorm2d,
                 conv2_layer: Any = nn.Conv2d, conv2_bn_layer: Any = nn.BatchNorm2d,
                 last_layer_type: Any = nn.Linear,
                 decompose=None, final_bias_trick=True):
        super(MobileNetV2, self).__init__()

        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = first_layer_type(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = first_bn_layer_type(32)

        self.layers = self._make_layers(
            32,
            expansion_layer, expansion_bn_layer,
            depthwise_layer, depthwise_bn_layer,
            pointwise_layer, pointwise_bn_layer,
            shortcut_layer, shortcut_bn_layer
        )

        self.conv2 = conv2_layer(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = conv2_bn_layer(1280)

        self.final_bias_trick = final_bias_trick
        # With the bias trick, coroutine() appends a constant-1 column to the
        # 1280 pooled features and the classifier carries no separate bias.
        # Both the plain and the decomposed classifier must agree with that,
        # otherwise decompose + final_bias_trick=False would get 1280 features
        # into a layer expecting 1281.
        in_features = 1280 + 1 if final_bias_trick else 1280
        if decompose is None:
            self.linear = last_layer_type(in_features, num_classes, bias=not final_bias_trick)
        else:
            self.linear = nn.Sequential(
                last_layer_type(in_features, decompose, bias=False),
                last_layer_type(decompose, num_classes, bias=not final_bias_trick)
            )

        self.apply(weights_init)

    def _make_layers(self, in_planes,
                     expansion_layer, expansion_bn_layer,
                     depthwise_layer, depthwise_bn_layer,
                     pointwise_layer, pointwise_bn_layer,
                     shortcut_layer, shortcut_bn_layer):
        """Instantiate one ``Block`` per entry implied by ``cfg``.

        Within a stage only the first block uses the stage stride; the rest use
        stride 1. Each block's input width is the previous block's output width.
        """
        layers = []
        for expansion, out_planes, num_blocks, first_stride in self.cfg:
            strides = [first_stride] + [1] * (num_blocks - 1)
            for stride in strides:
                layers.append(Block(
                    in_planes, out_planes, expansion, stride,
                    expansion_layer, expansion_bn_layer,
                    depthwise_layer, depthwise_bn_layer,
                    pointwise_layer, pointwise_bn_layer,
                    shortcut_layer, shortcut_bn_layer,
                ))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def coroutine(self, x):
        """Generator form of the forward pass.

        Yields ``(input, layer)`` before the stem conv, before every conv inside
        each block (delegated via ``yield from``), before ``conv2`` and before
        the classifier, and resumes with the tensor sent back by the driver.
        The logits are the generator's return value.
        """
        out = (yield x, self.conv1)
        out = F.relu(self.bn1(self.conv1(out)))
        for layer in self.layers:
            out = (yield from layer.coroutine(out))
        out = (yield out, self.conv2)
        out = F.relu(self.bn2(self.conv2(out)))
        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
        # Kernel 4 covers the whole 4x4 map that 32x32 inputs reach after the
        # three stride-2 stages in cfg.
        out = F.avg_pool2d(out, 4)
        # NOTE(review): unlike the other yields, this hands out the pooled 4-D
        # map *before* flattening and the bias-trick column, so
        # inputs_for(self.linear) is not literally the classifier's input.
        out = (yield out, self.linear)
        out = out.view(out.size(0), -1)
        if self.final_bias_trick:
            out = torch.cat([out, out.new_ones((out.size(0), 1))], dim=1)
        out = self.linear(out)
        return out

    def forward(self, x):
        """Run ``coroutine`` to completion, sending every yielded input back
        unchanged, and return the logits."""
        coro = self.coroutine(x)
        out = None
        while True:
            try:
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value

    def inputs_for(self, layer, x):
        """Return the tensor that ``coroutine`` yields as the input of ``layer``.

        ``layer`` is matched by identity against the yielded modules. Only
        layers that ``coroutine`` yields can be found; for any other module the
        generator finishes and its ``StopIteration`` propagates to the caller.
        """
        coro = self.coroutine(x)
        out = None
        while True:
            out, cur_layer = coro.send(out)
            if cur_layer is layer:
                return out

    def _block_input_dependencies(self, idx):
        """Return the layers whose outputs reach the input of ``layers[idx]``.

        Every ``Block`` adds a parameter-free copy of its own input (identity,
        2x2 average pool and/or channel tiling) to its ``conv3`` branch, for
        stride 1 and stride 2 alike. The input of block ``idx`` therefore
        depends on the ``conv3`` of every earlier block and on the stem
        ``conv1``. Names are listed from the nearest producer outward.
        """
        deps = [f'layers.{i}.conv3' for i in range(idx - 1, -1, -1)]
        deps.append('conv1')
        return deps

    def dependent_layers(self, layer_name: str, include_depthwise=False):
        """Return names of the layers whose outputs reach the input of ``layer_name``.

        Names follow ``named_modules()`` (e.g. ``'conv1'``, ``'layers.3.conv2'``,
        ``'linear'``).

        Args:
            layer_name: the layer to query. ``'layers.<i>.shortcut'`` is still
                accepted and treated like ``'layers.<i>.conv1'`` for backward
                compatibility, although ``Block`` defines no shortcut submodule.
            include_depthwise: if True, a block's depthwise ``conv2`` is
                reported as the producer of ``conv3``'s input; otherwise it is
                skipped and ``conv1`` is reported. Querying a depthwise conv
                itself requires ``include_depthwise=True``.

        Raises:
            AssertionError: (assert-based) for unknown layer names, or for a
                depthwise conv queried with ``include_depthwise=False``.
        """
        if layer_name == 'conv1':
            return []
        if layer_name == 'linear':
            return ['conv2']
        if layer_name == 'conv2':
            last = len(self.layers) - 1
            return [f'layers.{last}.conv3'] + self._block_input_dependencies(last)
        if layer_name.startswith('layers.'):
            parts = layer_name.split('.')
            idx = int(parts[1])
            sublayer = parts[2]
            # NOTE(review): within a block, conv3's input also receives conv1's
            # output through the skip around the depthwise conv; the
            # intra-block answers below report only the direct producer.
            if sublayer == 'conv3':
                if include_depthwise:
                    return [f'layers.{idx}.conv2']
                return [f'layers.{idx}.conv1']
            if sublayer == 'conv2':
                assert include_depthwise
                return [f'layers.{idx}.conv1']
            if sublayer in ('conv1', 'shortcut'):
                return self._block_input_dependencies(idx)
        assert False, f"No such layer {layer_name}"


class WeightStandardizedConv2d(nn.Conv2d):
    """Conv2d that standardizes each output filter on the fly.

    On every call, each filter (one slice along dim 0 of the weight) is shifted
    to zero mean and divided by its unbiased standard deviation before the
    convolution. No epsilon is added, so a constant filter divides by zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
        )

    def forward(self, x):
        raw = self.weight
        centered = raw - raw.mean(dim=[1, 2, 3], keepdim=True)
        per_filter_std = centered.view(centered.size(0), -1).std(dim=1)
        standardized = centered / per_filter_std.view(-1, 1, 1, 1)
        return F.conv2d(
            x, standardized, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )


def ws_mobilenetv2(num_classes=10, decompose=None):
    """Build a MobileNetV2 whose convolutions are all WeightStandardizedConv2d.

    The stem, expansion, depthwise, pointwise, shortcut and final 1x1 conv
    factories are replaced; norm layers and the classifier keep the
    MobileNetV2 defaults.
    """
    conv_roles = (
        'first_layer_type',
        'expansion_layer',
        'depthwise_layer',
        'pointwise_layer',
        'shortcut_layer',
        'conv2_layer',
    )
    overrides = dict.fromkeys(conv_roles, WeightStandardizedConv2d)
    return MobileNetV2(num_classes=num_classes, decompose=decompose, **overrides)


class ScaledConv2d(nn.Conv2d):
    """Conv2d whose weight is multiplied by a constant factor on every forward.

    The factor is ``1 / get_scale(self.weight, wl_weight, factor=1.0)``,
    computed once in ``__init__`` from the freshly created weight. It is kept
    as a plain attribute (not a parameter or buffer), so it is not part of
    ``state_dict``.
    NOTE(review): if ``get_scale`` depends on weight values rather than only
    on the weight's shape, re-initializing the weight afterwards leaves
    ``scale`` stale.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size, stride=1, padding=0, dilation=1, groups: int = 1,
                 bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None,
                 wl_weight=2):
        super(ScaledConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                           padding, dilation, groups, bias, padding_mode, device, dtype)
        self.scale = 1. / get_scale(self.weight, wl_weight, factor=1.0)

    def forward(self, x):
        # nn.Conv2d._conv_forward applies the non-'zeros' padding modes
        # (reflect / replicate / circular) that the constructor accepts; a
        # direct F.conv2d call silently fell back to zero padding. For
        # padding_mode='zeros' this is exactly
        # F.conv2d(x, w, bias, stride, padding, dilation, groups).
        return self._conv_forward(x, self.weight * self.scale, self.bias)


class ScaledLinear(nn.Linear):
    """Linear layer whose weight is multiplied by a constant factor on every call.

    The factor is ``1 / get_scale(self.weight, wl_weight, factor=1.0)``,
    evaluated once at construction and stored as a plain attribute (not a
    parameter or buffer).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None,
                 wl_weight=2):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        divisor = get_scale(self.weight, wl_weight, factor=1.0)
        self.scale = 1. / divisor

    def forward(self, x):
        effective_weight = self.weight * self.scale
        return F.linear(x, effective_weight, self.bias)


def simpleq_mobilenetv2(
        num_classes=10,
        decompose=None,
        wl_weight=2,
        wl_activate=8,
):
    """Build a MobileNetV2 with scaled weights and WAGE activation quantizers.

    Every conv is a ``ScaledConv2d`` and the classifier a ``ScaledLinear``
    (both with ``wl_weight``). Every norm slot is replaced by
    ``WAGEQuantizer(wl_activate, -1)``. Afterwards each conv/linear weight is
    re-initialized with ``wage_init_``, and the layers are asserted to have no
    bias.
    """
    conv_factory = partial(ScaledConv2d, wl_weight=wl_weight)

    def norm_factory(_num_features):
        # The channel count MobileNetV2 passes to norm factories is ignored.
        return WAGEQuantizer(wl_activate, -1)

    stage_types = {}
    for stage in ('expansion', 'depthwise', 'pointwise', 'shortcut', 'conv2'):
        stage_types[f'{stage}_layer'] = conv_factory
        stage_types[f'{stage}_bn_layer'] = norm_factory

    model = MobileNetV2(
        num_classes=num_classes,
        decompose=decompose,
        first_layer_type=conv_factory,
        first_bn_layer_type=norm_factory,
        last_layer_type=partial(ScaledLinear, wl_weight=wl_weight),
        **stage_types,
    )

    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        wage_init_(module.weight, wl_weight, factor=1.0)
        assert module.bias is None, module_name
    return model


class QuantizedActivations(torch.autograd.Function):
    """Unsigned uniform quantizer on ``[0, c]`` with a straight-through gradient.

    Forward maps ``x`` onto the ``2**bitwidth`` evenly spaced levels
    ``k * c / (2**bitwidth - 1)``, rounding with ``torch.round`` (ties to even)
    and clamping to ``[0, c]``. Backward passes the gradient to ``x``
    unchanged, including where the forward clamped; ``c`` and ``bitwidth``
    receive no gradient.
    """

    @staticmethod
    def forward(ctx, x, c, bitwidth):
        levels = 2 ** bitwidth - 1
        codes = torch.round(x / c * levels).clamp(0, levels)
        return codes / levels * c

    @staticmethod
    def backward(ctx, g):
        grad_x = g
        return grad_x, None, None


class QuantizedBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d followed by ``bitwidth``-bit unsigned activation quantization.

    The clipping ceiling is ``max(6 * weight + bias)`` taken over all channels.
    Normalized outputs are quantized with ``QuantizedActivations`` onto
    ``[0, ceiling]``, so negative values become 0. Requires ``affine=True``,
    since the ceiling reads ``weight`` and ``bias``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, device=None, dtype=None, bitwidth=8):
        super().__init__(
            num_features, eps=eps, momentum=momentum, affine=affine,
            track_running_stats=track_running_stats, device=device, dtype=dtype,
        )
        self.bitwidth = bitwidth

    def forward(self, x):
        normalized = super(nn.BatchNorm2d, self).forward(x)
        ceiling = (self.weight * 6 + self.bias).max()
        return QuantizedActivations.apply(normalized, ceiling, self.bitwidth)


def _dbq_layer_pair(quantize, num_branches, quantizer_type, gen_matrix_every_step,
                    quantize_activations):
    """Return the ``(conv_factory, norm_factory)`` pair for one layer role.

    Unquantized roles get plain ``nn.Conv2d`` / ``nn.BatchNorm2d``. Quantized
    roles get a ``TernaryConv2d`` partial with the given branch count, and
    ``QuantizedBatchNorm2d`` as the norm when ``quantize_activations`` is set.
    """
    if not quantize:
        return nn.Conv2d, nn.BatchNorm2d
    conv_factory = partial(
        TernaryConv2d, num_branches=num_branches, quantizer_type=quantizer_type,
        gen_matrix_every_step=gen_matrix_every_step
    )
    norm_factory = QuantizedBatchNorm2d if quantize_activations else nn.BatchNorm2d
    return conv_factory, norm_factory


def dbq_mobilenetv2(
        num_classes=10,
        num_branches=2, num_branches_first=2, num_branches_last=2,
        quantizer_type=DBQQuantizer,
        quantize_first_layer=False,
        quantize_last_layer=False,
        quantize_expansion_layer=True,
        quantize_depthwise_layer=False,
        quantize_pointwise_layer=True,
        quantize_shortcut_layer=False,
        quantize_conv2_layer=True,
        quantize_activations=False,
        decompose=None,
        gen_matrix_every_step=False
):
    """Build a MobileNetV2 whose selected layer roles use ternary layers.

    Args:
        num_classes: number of output logits.
        num_branches: branch count for quantized block convs and ``conv2``.
        num_branches_first: branch count for the stem conv when quantized.
        num_branches_last: branch count for the classifier when quantized.
        quantizer_type: forwarded to every ``TernaryConv2d`` / ``TernaryLinear``.
        quantize_first_layer, quantize_expansion_layer, quantize_depthwise_layer,
        quantize_pointwise_layer, quantize_shortcut_layer, quantize_conv2_layer:
            which conv roles use ``TernaryConv2d``. Other roles use
            ``nn.Conv2d`` followed by ``nn.BatchNorm2d``.
        quantize_last_layer: use ``TernaryLinear`` instead of ``nn.Linear``
            for the classifier.
        quantize_activations: for quantized conv roles, use
            ``QuantizedBatchNorm2d`` instead of ``nn.BatchNorm2d``.
        decompose: forwarded to ``MobileNetV2`` (factored classifier).
        gen_matrix_every_step: forwarded to every ternary layer.
    """
    role_pair = partial(
        _dbq_layer_pair,
        quantizer_type=quantizer_type,
        gen_matrix_every_step=gen_matrix_every_step,
        quantize_activations=quantize_activations,
    )
    first_layer_type, first_bn_layer_type = role_pair(quantize_first_layer, num_branches_first)
    expansion_layer, expansion_bn_layer = role_pair(quantize_expansion_layer, num_branches)
    depthwise_layer, depthwise_bn_layer = role_pair(quantize_depthwise_layer, num_branches)
    pointwise_layer, pointwise_bn_layer = role_pair(quantize_pointwise_layer, num_branches)
    shortcut_layer, shortcut_bn_layer = role_pair(quantize_shortcut_layer, num_branches)
    conv2_layer, conv2_bn_layer = role_pair(quantize_conv2_layer, num_branches)

    # The classifier has no norm layer; only its linear type depends on the flag.
    if quantize_last_layer:
        last_layer_type = partial(
            TernaryLinear, num_branches=num_branches_last, quantizer_type=quantizer_type,
            gen_matrix_every_step=gen_matrix_every_step
        )
    else:
        last_layer_type = nn.Linear

    return MobileNetV2(
        num_classes=num_classes,
        first_layer_type=first_layer_type, first_bn_layer_type=first_bn_layer_type,
        expansion_layer=expansion_layer, expansion_bn_layer=expansion_bn_layer,
        depthwise_layer=depthwise_layer, depthwise_bn_layer=depthwise_bn_layer,
        pointwise_layer=pointwise_layer, pointwise_bn_layer=pointwise_bn_layer,
        shortcut_layer=shortcut_layer, shortcut_bn_layer=shortcut_bn_layer,
        conv2_layer=conv2_layer, conv2_bn_layer=conv2_bn_layer,
        last_layer_type=last_layer_type,
        decompose=decompose
    )
