import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import numpy as np

# Network layout: one (planes, stride) tuple per stage. For example, (128, 2)
# is a stage with 128 output planes whose conv runs with stride 2; a stride of
# 1 means no spatial downsampling. Entry 0 configures the stem (firstconv3x3),
# every later entry a BasicBlock whose input width is the previous entry's planes.
cfg = [
    (32, 2),
    (64, 1),
    (128, 2), (128, 1),
    (256, 2), (256, 1),
    (512, 2), (512, 1), (512, 1), (512, 1), (512, 1), (512, 1),
    (1024, 2), (1024, 1),
]


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 QuantizedConv2d with one pixel of zero padding."""
    layer = QuantizedConv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) QuantizedConv2d."""
    layer = QuantizedConv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class firstconv3x3(nn.Module):
    """Full-precision stem: a plain 3x3 ``nn.Conv2d`` (padding 1, no bias) + BatchNorm.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: convolution stride.
    """

    def __init__(self, inp, oup, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(oup)

    def coroutine(self, x):
        """Generator form of forward.

        Yields ``(x, self.conv1)`` once. The value sent back is what conv1
        actually receives. The generator's return value is the block output.
        """
        conv_in = yield x, self.conv1
        return self.bn1(self.conv1(conv_in))

    def forward(self, x):
        # Drive the coroutine, echoing every yielded input back unchanged.
        gen = self.coroutine(x)
        sent = None
        while True:
            try:
                sent, _layer = gen.send(sent)
            except StopIteration as stop:
                return stop.value


class Sign(torch.autograd.Function):
    """sign(x) in the forward pass with a piecewise-linear surrogate gradient.

    Backward multiplies the incoming gradient by ``max(0, 2 * (1 - |x|))``.
    That factor peaks at 2 for x == 0 and is zero once |x| >= 1.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, g):
        (inp,) = ctx.saved_tensors
        surrogate = (1 - inp.abs()).mul(2).clamp(min=0.0)
        return g * surrogate


class QuantizedConv2d(nn.Conv2d):
    """Conv2d that can swap its full-precision weight for a binarized one.

    In ``'original'`` mode it is an ordinary ``nn.Conv2d`` using ``self.weight``.
    In any other mode ``weight`` is removed and replaced by two Parameters:
    ``scores`` (same shape as the weight) and ``alpha`` (one value per output
    filter, shape (out_channels, 1, 1, 1)). The effective weight is then
    ``Sign(scores) * alpha``.
    """

    def __init__(
            self, in_channels: int, out_channels: int, kernel_size,
            stride=1, padding=0, dilation=1, groups: int = 1,
            bias: bool = True, padding_mode: str = 'zeros',
            device=None, dtype=None):
        super(QuantizedConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, groups, bias, padding_mode, device, dtype)
        self.mode = 'original'

    def set_mode(self, mode, param=None):
        """Switch between the full-precision and the binarized parameterization.

        Args:
            mode: ``'original'`` restores a full-precision ``weight``. Any other
                value switches to the binarized ``scores`` / ``alpha`` form.
            param: only used when switching to ``'original'``. This is the
                Parameter to reinstall as ``weight``, typically the one returned
                by the previous call. Its data is overwritten with
                ``alpha * sign(scores)``. If None, a new Parameter holding those
                values is created.

        Returns:
            The removed ``weight`` Parameter when switching away from
            ``'original'``; otherwise None.
        """
        if mode == 'original':
            binarized = self.alpha.data * torch.sign(self.scores.data)
            if param is None:
                # Previously this path crashed on ``None.data``.
                self.weight = nn.Parameter(binarized)
            else:
                self.weight = param
                self.weight.data.copy_(binarized)
            self.weight.grad = None
            del self.scores, self.alpha
            self.mode = mode
            return None

        weight = self.weight
        alpha = weight.data.abs().mean([1, 2, 3], keepdim=True)
        # An all-zero filter has alpha == 0; dividing by it would produce NaN
        # scores. Divide by a tiny positive floor instead, so zero weights map
        # to zero scores. alpha itself is stored unchanged.
        divisor = alpha.clamp_min(torch.finfo(alpha.dtype).tiny)
        self.scores = nn.Parameter(weight.data / divisor)
        self.alpha = nn.Parameter(alpha)
        del self.weight
        self.mode = mode
        return weight

    def forward(self, x):
        if self.mode == 'original':
            w = self.weight
        else:
            # Sign's backward passes a clipped surrogate gradient to scores.
            w = Sign.apply(self.scores) * self.alpha

        return F.conv2d(x, w,
                        self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


class BinaryActivation(nn.Module):
    """Binarize activations with ``Sign`` (forward: sign(x); backward: clipped surrogate)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        binarized = Sign.apply(x)
        return binarized


class LearnableBias(nn.Module):
    """Add a trainable per-channel offset, initialised to zero.

    The offset has shape (1, out_chn, 1, 1). ``x`` must be a tensor that this
    shape can ``expand_as`` to, e.g. an (N, out_chn, H, W) activation.
    """

    def __init__(self, out_chn):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, out_chn, 1, 1))

    def forward(self, x):
        shift = self.bias.expand_as(x)
        return x + shift


class BasicBlock(nn.Module):
    """Binary residual block: a binarized 3x3 conv stage, then a binarized 1x1 stage.

    Stage 1:
        move11 -> sign -> binary_3x3 -> bn1, plus a shortcut from ``x``
        (average-pooled when stride == 2), then move12 -> prelu1 -> move13.
    Stage 2:
        move21 -> sign -> 1x1 conv(s) -> BN, plus a shortcut from the stage-1
        output, then move22 -> prelu2 -> move23.
        If ``inplanes == planes`` a single ``binary_pw`` is used. Otherwise
        ``planes`` must equal ``2 * inplanes``. In that case two parallel
        convs, ``binary_pw_down1`` and ``binary_pw_down2``, each produce
        ``inplanes`` channels and their outputs are concatenated.

    Args:
        inplanes: input channel count.
        planes: output channel count (``inplanes`` or ``2 * inplanes``).
        stride: stride of the 3x3 conv. With stride 2 the shortcut is 2x2
            average-pooled.
    """

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d

        self.move11 = LearnableBias(inplanes)
        self.binary_3x3 = conv3x3(inplanes, inplanes, stride=stride)
        self.bn1 = norm_layer(inplanes)

        self.move12 = LearnableBias(inplanes)
        self.prelu1 = nn.PReLU(inplanes)
        self.move13 = LearnableBias(inplanes)

        self.move21 = LearnableBias(inplanes)

        if inplanes == planes:
            self.binary_pw = conv1x1(inplanes, planes)
            self.bn2 = norm_layer(planes)
        else:
            self.binary_pw_down1 = conv1x1(inplanes, inplanes)
            self.binary_pw_down2 = conv1x1(inplanes, inplanes)
            self.bn2_1 = norm_layer(inplanes)
            self.bn2_2 = norm_layer(inplanes)

        self.move22 = LearnableBias(planes)
        self.prelu2 = nn.PReLU(planes)
        self.move23 = LearnableBias(planes)

        self.binary_activation = BinaryActivation()
        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes

        # When False, the downsampling branch skips the sign() step, so
        # binary_pw_down1/2 receive real-valued activations.
        self.quantize_down = True

        # The coroutine pools the shortcut whenever stride == 2, so the pooling
        # layer must exist in that case even if inplanes == planes. Previously
        # it was only created for inplanes != planes, and BasicBlock(c, c, 2)
        # crashed. (The inplanes != planes condition is kept so the attribute
        # is present in every configuration where it used to be.)
        if self.stride == 2 or self.inplanes != self.planes:
            self.pooling = nn.AvgPool2d(2, 2)

    def coroutine(self, x):
        """Generator form of forward.

        Before each tracked conv it yields ``(binarized_input, conv)``. The
        value sent back is what that conv actually receives. Tracked convs are
        ``binary_3x3``, then either ``binary_pw`` or ``binary_pw_down1``. In
        the downsampling branch, ``binary_pw_down2`` consumes the same input
        and is not yielded separately. Returns the block output.
        """
        # ---- stage 1: binarized 3x3 conv with residual shortcut ----
        out1 = self.move11(x)
        out1 = self.binary_activation(out1)
        out1 = (yield out1, self.binary_3x3)
        out1 = self.binary_3x3(out1)
        out1 = self.bn1(out1)

        if self.stride == 2:
            # Match the spatial size of the strided 3x3 conv output.
            x = self.pooling(x)

        out1 = x + out1

        out1 = self.move12(out1)
        out1 = self.prelu1(out1)
        out1 = self.move13(out1)

        # ---- stage 2: binarized 1x1 conv(s) with residual shortcut ----
        out2 = self.move21(out1)

        if self.inplanes == self.planes:
            out2 = self.binary_activation(out2)
            out2 = (yield out2, self.binary_pw)
            out2 = self.binary_pw(out2)
            out2 = self.bn2(out2)
            out2 += out1

        else:
            assert self.planes == self.inplanes * 2

            if self.quantize_down:
                out2 = self.binary_activation(out2)
            out2 = (yield out2, self.binary_pw_down1)
            # Two parallel 1x1 convs, each with its own shortcut; concatenating
            # them doubles the channel count.
            out2_1 = self.binary_pw_down1(out2)
            out2_2 = self.binary_pw_down2(out2)
            out2_1 = self.bn2_1(out2_1)
            out2_2 = self.bn2_2(out2_2)
            out2_1 += out1
            out2_2 += out1
            out2 = torch.cat([out2_1, out2_2], dim=1)

        out2 = self.move22(out2)
        out2 = self.prelu2(out2)
        out2 = self.move23(out2)

        return out2

    def forward(self, x):
        # Drive the coroutine, echoing every yielded input back unchanged.
        coro = self.coroutine(x)
        out = None
        while True:
            try:
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value


class reactnet(nn.Module):
    """Binary network built from the module-level ``cfg`` list.

    ``feature[0]`` is the full-precision stem (``firstconv3x3``). ``feature[1:]``
    are ``BasicBlock``s. The blocks are followed by global average pooling and
    a linear classifier. Every sub-module exposes a ``coroutine`` that yields
    ``(input, layer)`` right before a tracked layer runs. ``inputs_for`` uses
    this to capture the exact tensor a given layer receives.
    """

    def __init__(self, num_classes=1000):
        super(reactnet, self).__init__()
        self.feature = nn.ModuleList()
        for i, (planes, stride) in enumerate(cfg):
            if i == 0:
                self.feature.append(firstconv3x3(3, planes, stride))
            else:
                self.feature.append(BasicBlock(cfg[i - 1][0], planes, stride))
        self.pool1 = nn.AdaptiveAvgPool2d(1)
        # The classifier consumes the channel count of the last stage.
        self.fc = nn.Linear(cfg[-1][0], num_classes)

    def coroutine(self, x):
        """Generator form of forward.

        Yields ``(input, layer)`` before every tracked layer: those yielded by
        each feature block, then ``fc``. The value sent back is what that
        layer receives. Returns the logits.
        """
        for block in self.feature:
            x = (yield from block.coroutine(x))

        x = self.pool1(x)
        x = x.view(x.size(0), -1)

        x = (yield x, self.fc)
        x = self.fc(x)

        return x

    def forward(self, x):
        # Drive the coroutine, echoing every yielded input back unchanged.
        coro = self.coroutine(x)
        out = None
        while True:
            try:
                out, _ = coro.send(out)
            except StopIteration as ex:
                return ex.value

    def _pw_name(self, idx):
        """Return the name of block ``idx``'s 1x1 stage.

        This is ``'feature.{idx}.binary_pw'`` for equal-width blocks. For
        channel-doubling blocks it is the pseudo-name
        ``'feature.{idx}.binary_pw_down'``, which stands for the pair
        binary_pw_down1/2.
        """
        block = self.feature[idx]
        kind = 'binary_pw' if block.inplanes == block.planes else 'binary_pw_down'
        return f'feature.{idx}.{kind}'

    def dependent_layers(self, layer_name):
        """Return the names of the conv layers upstream of ``layer_name``, nearest first.

        The chain ends at ``feature.1.binary_3x3``. The stem conv
        ``feature.0.conv1`` is not reported. A channel-doubling block's 1x1
        stage is reported as ``feature.{i}.binary_pw_down``. As input, that
        pseudo-name and the real names ``binary_pw_down1`` / ``binary_pw_down2``
        are all accepted.

        Raises:
            ValueError: if ``layer_name`` is not a recognised layer name.
        """
        if layer_name == 'fc':
            last_pw = self._pw_name(len(self.feature) - 1)
            return [last_pw, *self.dependent_layers(last_pw)]
        if layer_name == 'feature.0.conv1':
            return []

        parts = layer_name.split('.')
        if len(parts) == 3 and parts[0] == 'feature' and parts[1].isdigit():
            idx, kind = int(parts[1]), parts[2]
            if kind in ('binary_pw', 'binary_pw_down',
                        'binary_pw_down1', 'binary_pw_down2'):
                # A block's 1x1 stage is fed by the same block's 3x3 conv.
                conv = f'feature.{idx}.binary_3x3'
                return [conv, *self.dependent_layers(conv)]
            if kind == 'binary_3x3' and idx == 1:
                # Stop at the first binary block; the stem is not reported.
                return []
            if kind == 'binary_3x3' and idx >= 2:
                # A 3x3 conv is fed by the previous block's 1x1 stage.
                prev_pw = self._pw_name(idx - 1)
                return [prev_pw, *self.dependent_layers(prev_pw)]

        # Explicit raise (not assert) so invalid names fail even under -O, and
        # feature.0.binary_3x3 no longer wraps around to feature[-1].
        raise ValueError(f'unknown layer name: {layer_name!r}')

    def inputs_for(self, layer, x):
        """Return the tensor that ``layer`` receives when the network runs on ``x``.

        Execution stops just before ``layer`` runs. Only layers yielded by
        ``coroutine`` can be located.

        Raises:
            ValueError: if ``layer`` is never yielded during the forward pass
                (for example ``binary_pw_down2``, which shares its input with
                ``binary_pw_down1``).
        """
        coro = self.coroutine(x)
        out = None
        try:
            while True:
                out, cur_layer = coro.send(out)
                if cur_layer is layer:
                    return out
        except StopIteration:
            raise ValueError(
                'layer is not yielded by reactnet.coroutine') from None
        finally:
            # Release the suspended generator chain instead of waiting for GC.
            coro.close()


def test():
    """Smoke test: run one random 3x32x32 input through reactnet and print the output size."""
    model = reactnet()
    logits = model(torch.randn(1, 3, 32, 32))
    print(logits.size())


# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()
