from typing import Any

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import numpy as np
import torch.distributed as dist

# Stage plan of the network, one (planes, stride) pair per stage:
#   planes - number of output channels of the stage
#   stride - stride of the stage's 3x3 convolution (1 unless stated otherwise)
# The first entry is the stem; every following entry becomes one block whose
# input width is the `planes` value of the entry before it.
cfg = [
    (32, 1),
    (64, 1),
    (128, 2),
    (128, 1),
    (256, 2),
    (256, 1),
    (512, 2),
    (512, 1),
    (512, 1),
    (512, 1),
    (512, 1),
    (512, 1),
    (1024, 2),
    (1024, 1),
]


class Sign(torch.autograd.Function):
    """Element-wise sign with a surrogate gradient.

    The forward pass returns ``torch.sign(x)``. The backward pass multiplies
    the incoming gradient by ``max(2 * (1 - |x|), 0)`` instead of the true
    derivative, so gradient only flows through entries with ``|x| < 1``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # The input itself is needed to evaluate the surrogate in backward.
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        window = (2 * (1 - x.abs())).clamp(min=0.0)
        return g * window


@torch.jit.script
def local_reparam_backward(grad, es2, V_x, x, ks: int, stride: int, padding: int):
    """Gradients of ``sum_i es2[i] * V_x[:, i] ** 2`` w.r.t. ``log_s`` and ``V``.

    Here ``es2 = exp(-2 * log_s)`` and ``V_x`` is the result of convolving
    ``x`` with the rows of ``V`` viewed as ``ks x ks`` kernels.

    Args:
        grad: upstream gradient of the reduced output, shape (B, H, W).
        es2: ``exp(-2 * log_s)``, shape (N,).
        V_x: projections of ``x`` onto the rows of ``V``, shape (B, N, H, W).
        x: input that was convolved to obtain ``V_x``.
        ks: kernel size used for the convolution.
        stride: stride used for the convolution.
        padding: padding used for the convolution.

    Returns:
        ``(grad_s, grad_V)`` with shapes (N,) and (N, D), where D is the
        number of rows produced by ``F.unfold`` (channels * ks * ks).
    """
    es2 = es2[None, :, None, None]  # (1, N, 1, 1)

    # d/dlog_s of exp(-2 * log_s) * V_x^2 is -2 * exp(-2 * log_s) * V_x^2.
    grad_s = torch.einsum(
        'bwh,biwh->i',
        grad,
        -2 * V_x ** 2 * es2  # (B, N, H, W)
    )

    # Gradient flowing into V_x itself.
    tmp = 2 * es2 * V_x  # (B, N, H, W)
    tmp = grad[:, None] * tmp  # (B, N, H, W)

    # Each output location of the convolution sees one unfolded patch of x.
    x_unf = F.unfold(x, ks, stride=stride, padding=padding)  # (B, D, H*W)
    # Restore the spatial layout of V_x while keeping the patch dimension D.
    # Using D (rather than N) here means V does not have to be square.
    x_unf = x_unf.reshape(
        x_unf.size(0), x_unf.size(1), V_x.size(2), V_x.size(3))  # (B, D, H, W)

    grad_V = torch.einsum(
        'biwh,bjwh->ij',
        tmp, x_unf
    )  # (N, D)

    return grad_s, grad_V


class LocalReparam(torch.autograd.Function):
    """Computes ``sum_i exp(-2 * log_s[i]) * conv(x, V_i) ** 2`` per location.

    ``V`` is viewed as a bank of ``ks x ks`` kernels over ``in_channels``
    input channels. The result has shape (B, H, W). Gradients are produced
    for ``log_s`` and ``V`` only; ``x`` and the integer arguments receive
    ``None``.
    """

    @staticmethod
    def forward(ctx, log_s, V, x,
                in_channels: int, ks: int, stride: int, padding: int):
        coeff = torch.exp(-2 * log_s)  # one coefficient per row of V

        kernels = V.reshape(-1, in_channels, ks, ks)
        projected = F.conv2d(x, kernels, stride=stride, padding=padding)

        # Everything backward needs: tensors via save_for_backward,
        # plain ints as attributes on the context.
        ctx.save_for_backward(coeff, projected, x)
        ctx.ks, ctx.stride, ctx.padding = ks, stride, padding

        return torch.einsum('i,biwh->bwh', coeff, projected ** 2)

    @staticmethod
    def backward(ctx, grad):
        coeff, projected, x = ctx.saved_tensors
        grad_log_s, grad_V = local_reparam_backward(
            grad, coeff, projected, x, ctx.ks, ctx.stride, ctx.padding)

        # One entry per forward argument; only log_s and V are differentiated.
        return (grad_log_s, grad_V) + (None,) * 5


class QuantizedConv2d(nn.Conv2d):
    """``nn.Conv2d`` that can switch to a binarized weight and inject noise.

    Modes (``self.mode``):
      * ``'original'``: the layer convolves with ``self.weight``.
      * any string containing ``'quantize'``: the layer convolves with
        ``Sign(self.scores) * self.alpha``; ``self.weight`` does not exist
        in this mode.

    Noise injection (``self.inject``) is only active in training mode. It
    convolves with ``self.original_w`` and adds Gaussian noise whose
    per-location scale is derived from ``log_s`` and ``V`` via
    ``LocalReparam``.

    Note: ``forward`` calls ``F.conv2d`` with ``self.padding`` directly, so
    ``padding_mode`` is not applied by this class.
    """

    def __init__(
            self, in_channels: int, out_channels: int, kernel_size,
            stride=1, padding=0, dilation=1, groups: int = 1,
            bias: bool = True, padding_mode: str = 'zeros',
            device=None, dtype=None):
        super().__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, groups, bias, padding_mode, device, dtype)
        self.mode = 'original'
        self.inject = False

    def set_mode(self, mode, param=None):
        """Switch between the full-precision and the binarized parametrization.

        Args:
            mode: ``'original'`` or a string containing ``'quantize'``.
            param: only used for ``'original'``; the parameter to reinstall
                as ``self.weight`` (typically the object returned by the
                earlier quantize call). Its data is overwritten with
                ``alpha * sign(scores)``.

        Returns:
            When entering a quantize mode, the removed ``weight`` parameter
            so the caller can hand it back later; otherwise ``None``.
        """
        self.mode = mode
        if mode == 'original':
            self.weight = param
            self.weight.data.copy_(self.alpha.data * torch.sign(self.scores.data))
            self.weight.grad = None
            del self.scores, self.alpha
        elif 'quantize' in self.mode:
            # Per-output-channel scale: mean absolute weight.
            alpha = self.weight.data.abs().mean([1, 2, 3], keepdim=True)
            self.scores = nn.Parameter(self.weight.data / alpha)
            self.alpha = nn.Parameter(alpha)  # (out, 1, 1, 1)
            weight = self.weight
            del self.weight
            return weight

    def set_noise_injection(self, inject: bool, original_w=None, log_s=None, V=None, alpha=None, scores=None):
        """Enable or disable noise injection.

        Enabling (``inject=True``) stores ``original_w``, ``log_s`` and ``V``
        on the layer, unregisters the ``alpha`` and ``scores`` parameters and
        returns them as ``(alpha, scores)``. Detached views of both stay on
        the layer as plain tensors so the deterministic path of ``forward``
        (used in eval mode) keeps working while injection is enabled. The
        ``alpha`` and ``scores`` arguments are ignored in this case.

        Disabling (``inject=False``) removes ``original_w``, ``log_s`` and
        ``V`` and reinstalls the given ``alpha`` and ``scores``.
        """
        if inject:
            self.inject = True
            self.original_w = original_w
            self.log_s = log_s
            self.V = V
            self.noise_scale = 0.01
            alpha, scores = self.alpha, self.scores
            del self.alpha, self.scores
            # Keep non-parameter views: `alpha` scales the noise, and both are
            # read by the deterministic branch of forward() in eval mode.
            self.alpha = alpha.data
            self.scores = scores.data
            return alpha, scores
        else:
            self.inject = False
            del self.original_w
            del self.log_s
            del self.V
            self.alpha = alpha
            self.scores = scores

    def forward(self, x):
        if not self.inject or not self.training:
            # Deterministic path.
            if self.mode == 'original':
                w = self.weight
            elif 'quantize' in self.mode:
                w = Sign.apply(self.scores) * self.alpha
            else:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise AssertionError(f"unsupported mode: {self.mode!r}")

            return F.conv2d(
                x, w, self.bias, self.stride, self.padding,
                self.dilation, self.groups)

        # Noisy path: the noise is drawn directly on the convolution output,
        # one sample per output element.
        wx = F.conv2d(
            x, self.original_w, self.bias, self.stride, self.padding,
            self.dilation, self.groups)

        t = LocalReparam.apply(
            self.log_s, self.V, x,
            self.in_channels, self.kernel_size[0], self.stride[0], self.padding[0]
        )
        t = t.unsqueeze(1)  # (B, 1, W, H)

        alpha = self.alpha.data.flatten()[None, :, None, None]
        # 0.753 is a fixed scaling constant kept unchanged; its derivation is
        # not documented in this file.
        std = torch.sqrt(t) * alpha * 0.753  # (B, out, W, H)

        noise = torch.randn_like(wx)  # (B, out, W, H)
        return wx + noise * std * self.noise_scale


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``QuantizedConv2d`` with padding 1."""
    layer = QuantizedConv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) ``QuantizedConv2d``."""
    layer = QuantizedConv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class firstconv3x3(nn.Module):
    """Stem of the network: a bias-free 3x3 convolution followed by batch norm.

    The computation is written as a generator (``coroutine``) that yields
    ``(input, layer)`` right before ``conv1`` is applied and resumes with the
    value sent back; ``forward`` simply echoes every yielded tensor.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()

        self.conv1 = nn.Conv2d(
            inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(oup)

    def coroutine(self, x):
        # Hand the pending input and the layer to the driver; continue with
        # whatever tensor the driver sends back.
        x = (yield x, self.conv1)
        return self.bn1(self.conv1(x))

    def forward(self, x):
        stepper = self.coroutine(x)
        reply = None  # the first send() must be None to start the generator
        try:
            while True:
                reply, _ = stepper.send(reply)
        except StopIteration as finished:
            return finished.value


class BinAct(torch.autograd.Function):
    """Binary activation: sign in the forward pass, surrogate gradient in backward.

    The forward pass returns ``torch.sign(x)``. The backward pass scales the
    incoming gradient by ``max(2 * (1 - |x|), 0)``, so gradient only flows
    through entries with ``|x| < 1``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, g):
        x, = ctx.saved_tensors
        # forward() has a single input, so exactly one gradient is returned.
        return g * torch.clamp(2 * (1 - torch.abs(x)), min=0.0)


class BinaryActivation(nn.Module):
    """Parameter-free module that applies ``BinAct`` to its input."""

    def forward(self, x):
        binarized = BinAct.apply(x)
        return binarized


class LearnableBias(nn.Module):
    """Adds a learnable per-channel offset, initialised to zero, to its input."""

    def __init__(self, out_chn):
        super().__init__()
        initial = torch.zeros(1, out_chn, 1, 1)
        self.bias = nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        # expand_as makes the (1, C, 1, 1) offset match the input's shape.
        offset = self.bias.expand_as(x)
        return x + offset


class BasicBlock(nn.Module):
    """Binary block: a 3x3 stage followed by a pointwise (1x1) stage.

    Each stage shifts its input with a learnable bias, binarizes it, applies
    a quantized convolution and batch norm, adds a shortcut and finishes with
    bias -> PReLU -> bias.

    When ``planes == inplanes`` a single pointwise conv is used. Otherwise
    ``planes`` must equal ``2 * inplanes`` (checked when the block runs): the
    pointwise conv produces ``2 * inplanes`` channels that are split in two
    halves, normalised separately, each added to the shortcut, and
    concatenated.

    The computation lives in ``coroutine``, a generator that yields
    ``(input, layer)`` right before every quantized convolution and resumes
    with the tensor sent back. ``forward`` echoes each yielded tensor.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``inplanes`` or ``2 * inplanes``).
        stride: stride of the 3x3 convolution; with ``stride == 2`` the
            shortcut is average-pooled to match the reduced resolution.
    """

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        norm_layer = nn.BatchNorm2d

        self.move11 = LearnableBias(inplanes)
        self.binary_3x3 = conv3x3(inplanes, inplanes, stride=stride)
        self.bn1 = norm_layer(inplanes)

        self.move12 = LearnableBias(inplanes)
        self.prelu1 = nn.PReLU(inplanes)
        self.move13 = LearnableBias(inplanes)

        self.move21 = LearnableBias(inplanes)

        if inplanes == planes:
            self.binary_pw = conv1x1(inplanes, planes)
            self.bn2 = norm_layer(planes)
        else:
            self.binary_pw_down = conv1x1(inplanes, inplanes * 2)
            self.bn2_1 = norm_layer(inplanes)
            self.bn2_2 = norm_layer(inplanes)

        self.move22 = LearnableBias(planes)
        self.prelu2 = nn.PReLU(planes)
        self.move23 = LearnableBias(planes)

        self.binary_activation = BinaryActivation()
        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes

        # The shortcut is pooled whenever stride == 2 (see coroutine), so the
        # pooling layer must exist in that case even if the width is unchanged.
        if self.inplanes != self.planes or self.stride == 2:
            self.pooling = nn.AvgPool2d(2, 2)

    def coroutine(self, x):
        # ---- 3x3 stage ----
        out1 = self.move11(x)

        out1 = self.binary_activation(out1)
        out1 = (yield out1, self.binary_3x3)
        out1 = self.binary_3x3(out1)
        out1 = self.bn1(out1)

        if self.stride == 2:
            # Match the shortcut's resolution to the strided conv output.
            x = self.pooling(x)

        out1 = x + out1

        out1 = self.move12(out1)
        out1 = self.prelu1(out1)
        out1 = self.move13(out1)

        # ---- pointwise stage ----
        out2 = self.move21(out1)

        if self.inplanes == self.planes:
            out2 = self.binary_activation(out2)
            out2 = (yield out2, self.binary_pw)
            out2 = self.binary_pw(out2)
            out2 = self.bn2(out2)
            out2 += out1

        else:
            assert self.planes == self.inplanes * 2

            out2 = self.binary_activation(out2)
            out2 = (yield out2, self.binary_pw_down)
            out2 = self.binary_pw_down(out2)
            # Two halves of `inplanes` channels, each with its own batch norm
            # and its own copy of the shortcut.
            out2_1, out2_2 = torch.split(out2, [self.inplanes, self.inplanes], dim=1)
            out2_1 = self.bn2_1(out2_1)
            out2_2 = self.bn2_2(out2_2)
            out2_1 += out1
            out2_2 += out1
            out2 = torch.cat([out2_1, out2_2], dim=1)

        out2 = self.move22(out2)
        out2 = self.prelu2(out2)
        out2 = self.move23(out2)

        return out2

    def forward(self, x):
        coro = self.coroutine(x)
        out = None  # the first send() must be None to start the generator
        while True:
            try:
                # Echo the yielded tensor back unchanged.
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value


class reactnet(nn.Module):
    """Network built from ``cfg``: a stem, a stack of blocks and a linear head.

    The forward computation is expressed as a generator (``coroutine``) that
    yields ``(input, layer)`` immediately before each instrumented layer and
    resumes with the tensor sent back. ``forward`` echoes every yielded
    tensor; ``inputs_for`` uses the same generator to capture the input of a
    given layer.

    Args:
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.feature = nn.ModuleList()
        for i, (planes, stride) in enumerate(cfg):
            if i == 0:
                # Stem: 3 input channels, stride fixed to 1.
                self.feature.append(firstconv3x3(3, planes, 1))
            else:
                self.feature.append(BasicBlock(cfg[i - 1][0], planes, stride))
        # Head width follows the last stage of cfg.
        self.fc = nn.Linear(cfg[-1][0], num_classes)

    @staticmethod
    def _pointwise_name(block):
        """Name of the pointwise conv that exists in ``block``."""
        return 'binary_pw' if block.inplanes == block.planes else 'binary_pw_down'

    def coroutine(self, x):
        for block in self.feature:
            x = (yield from block.coroutine(x))

        x = F.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = (yield x, self.fc)
        x = self.fc(x)

        return x

    def forward(self, x):
        coro = self.coroutine(x)
        out = None  # the first send() must be None to start the generator
        while True:
            try:
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value

    def dependent_layers(self, layer_name):
        """List the quantized layers that precede ``layer_name``, nearest first.

        Args:
            layer_name: ``'fc'``, ``'feature.0.conv1'`` or
                ``'feature.<idx>.<binary_3x3|binary_pw|binary_pw_down>'``.

        Returns:
            Names of all earlier ``binary_*`` layers, ordered from the one
            directly before ``layer_name`` back to ``feature.1.binary_3x3``.
            The stem conv ``feature.0.conv1`` is never reported.

        Raises:
            AssertionError: if ``layer_name`` is not recognised.
        """
        if layer_name == 'fc':
            # The head is fed by the pointwise conv of the last block.
            last_idx = len(self.feature) - 1
            last = f'feature.{last_idx}.{self._pointwise_name(self.feature[last_idx])}'
            return [last, *self.dependent_layers(last)]
        if layer_name.startswith('feature.'):
            parts = layer_name.split('.')
            idx = int(parts[1])
            kind = parts[2]
            if kind in ('binary_pw', 'binary_pw_down'):
                # A pointwise conv follows the 3x3 conv of the same block.
                conv = f'feature.{idx}.binary_3x3'
                return [conv, *self.dependent_layers(conv)]
            if kind == 'binary_3x3':
                if idx == 1:
                    return []
                # A 3x3 conv follows the pointwise conv of the previous block.
                prev = f'feature.{idx - 1}.{self._pointwise_name(self.feature[idx - 1])}'
                return [prev, *self.dependent_layers(prev)]
        if layer_name == 'feature.0.conv1':
            return []
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise AssertionError(layer_name)

    def inputs_for(self, layer, x):
        """Return the tensor that would be fed to ``layer`` for input ``x``.

        Runs the network up to (but not including) ``layer`` and returns the
        value yielded for it.

        Raises:
            ValueError: if the forward pass finishes without reaching
                ``layer`` (previously a bare ``StopIteration`` escaped).
        """
        coro = self.coroutine(x)
        out = None
        try:
            while True:
                out, cur_layer = coro.send(out)
                if cur_layer is layer:
                    return out
        except StopIteration:
            raise ValueError(
                'layer was not reached during the forward pass') from None


def test():
    """Smoke test: push one random 3x32x32 input through the net and print the output size."""
    model = reactnet()
    sample = torch.randn(1, 3, 32, 32)
    logits = model(sample)
    print(logits.size())


# Run the smoke test only when the file is executed as a script.
if __name__ == "__main__":
    test()
