import torch

from ..layers.compression import QuantizSimple
from ..layers.init_parameters import init_conv, init_batchnorm, init_FC
from ..layers.layers import AsynchronousGenericLayer, AsynchronousFinal


class AVGFlattenFullyConnectedCE(AsynchronousFinal):
    """Final layer: global average pooling, flatten, linear map, trained with cross-entropy.

    Args:
        n_in: input feature count, passed to ``init_FC`` as its second argument.
        n_out: output feature count, passed to ``init_FC`` as its first argument.
        *args, **kwargs: forwarded unchanged to ``AsynchronousFinal.__init__``.
    """

    def __init__(self, n_in, n_out, *args, **kwargs):
        super(AVGFlattenFullyConnectedCE, self).__init__(*args, **kwargs)
        fc_weight, fc_bias = init_FC(n_out, n_in)
        # The linear parameters are registered under the '*_conv' names.
        for param_name, tensor in (('weight_conv', fc_weight), ('bias_conv', fc_bias)):
            self._register_parameters(param_name, tensor)

    def loss(self, x, y):
        """Return the cross-entropy between predictions ``x`` and targets ``y``."""
        functional = torch.nn.functional
        return functional.cross_entropy(x, y)

    def local_f(self, x, weight, bias, training):
        """Pool ``x`` to 1x1 spatially, flatten from dim 1, and apply the linear map.

        ``training`` is accepted but not used by this layer.
        """
        functional = torch.nn.functional
        pooled = functional.adaptive_avg_pool2d(x, (1, 1))
        features = pooled.flatten(start_dim=1)
        return functional.linear(features, weight, bias)


class ConvBNReLUMax(AsynchronousGenericLayer):
    """Convolution -> batch norm -> ReLU, optionally followed by a max pooling of kernel size 2.

    Args:
        n_in, n_out, kernel_size: passed to ``init_conv`` as ``(n_out, n_in, kernel_size)``.
        stride, padding: used by the convolution in ``local_f``.
        dilation, groups: accepted by the signature but not used anywhere in this class.
        eps_bn, momentum_bn: ``eps`` and ``momentum`` of the batch normalization.
        max_pool: when true, ``local_f`` ends with ``max_pool2d(x, 2)``.
        *args, **kwargs: forwarded unchanged to ``AsynchronousGenericLayer.__init__``.
    """

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, eps_bn=1e-05,
                 momentum_bn=0.1, max_pool=True, *args, **kwargs):
        super(ConvBNReLUMax, self).__init__(*args, **kwargs)
        self.stride, self.padding = stride, padding
        self.momentum_bn, self.eps_bn = momentum_bn, eps_bn
        self.max_pool = max_pool

        conv_w, conv_b = init_conv(n_out, n_in, kernel_size)
        bn_w, bn_b = init_batchnorm(n_out)
        # Running statistics reuse init_batchnorm: first returned value -> variance, second -> mean.
        stats_var, stats_mean = init_batchnorm(n_out)

        # Registration order is kept as-is (parameters first, then buffers), matching local_f's argument order.
        for param_name, tensor in (('weight_conv', conv_w), ('bias_conv', conv_b),
                                   ('weight_bn', bn_w), ('bias_bn', bn_b)):
            self._register_parameters(param_name, tensor)
        for buffer_name, tensor in (('running_mean', stats_mean), ('running_var', stats_var)):
            self._register_buffers(buffer_name, tensor)

    def local_f(self, x, weight_conv, bias_conv, weight_bn, bias_bn, running_mean, running_var, training):
        """Apply conv, batch norm, in-place ReLU and (if enabled) max pooling to ``x``."""
        functional = torch.nn.functional
        conv_out = functional.conv2d(x, weight_conv, bias=bias_conv, stride=self.stride, padding=self.padding)
        out = functional.batch_norm(conv_out, running_mean, running_var, weight=weight_bn, bias=bias_bn,
                                    training=training, momentum=self.momentum_bn, eps=self.eps_bn)
        out = functional.relu(out, inplace=True)
        if not self.max_pool:
            return out
        return functional.max_pool2d(out, 2)


class BottomBottleneck(AsynchronousGenericLayer):
    """First half of a residual block: conv -> batch norm -> ReLU, plus the shortcut tensor.

    ``local_f`` returns a pair ``(y, x_ds)``: ``y`` is the activated main branch and ``x_ds`` is either the
    untouched input (``downsample=False``) or the input passed through a stride-2 convolution and batch norm.

    Args:
        n_in, n_out, kernel_size: the main convolution is built with ``init_conv(n_out, n_in, kernel_size)``;
            the shortcut convolution with ``init_conv(n_out, n_in, 1)``.
        downsample: selects the convolutional shortcut in ``local_f``.
        stride, padding: used by the main convolution only; the shortcut always uses stride 2, padding 0.
        dilation, groups, max_pool: accepted by the signature but not used by ``local_f``
            (``max_pool`` is only stored on the instance).
        eps_bn, momentum_bn: ``eps`` and ``momentum`` shared by both batch normalizations.
        *args, **kwargs: forwarded unchanged to ``AsynchronousGenericLayer.__init__``.
    """

    def __init__(self, n_in, n_out, downsample=False, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
                 eps_bn=1e-05, momentum_bn=0.1, max_pool=True, *args, **kwargs):
        super(BottomBottleneck, self).__init__(*args, **kwargs)
        self.stride, self.padding = stride, padding
        self.momentum_bn, self.eps_bn = momentum_bn, eps_bn
        self.max_pool = max_pool
        self.downsample = downsample

        # The initializer calls below are kept in their original order.
        main_w, main_b = init_conv(n_out, n_in, kernel_size)
        main_bn_w, main_bn_b = init_batchnorm(n_out)
        main_var, main_mean = init_batchnorm(n_out)
        short_var, short_mean = init_batchnorm(n_out)
        short_w, short_b = init_conv(n_out, n_in, 1)
        short_bn_w, short_bn_b = init_batchnorm(n_out)

        # Main-branch tensors.
        for param_name, tensor in (('weight_conv', main_w), ('bias_conv', main_b),
                                   ('weight_bn', main_bn_w), ('bias_bn', main_bn_b)):
            self._register_parameters(param_name, tensor)
        for buffer_name, tensor in (('running_mean', main_mean), ('running_var', main_var)):
            self._register_buffers(buffer_name, tensor)

        # Shortcut tensors; they are created and registered even when downsample is False.
        for param_name, tensor in (('weight_conv_ds', short_w), ('bias_conv_ds', short_b),
                                   ('weight_bn_ds', short_bn_w), ('bias_bn_ds', short_bn_b)):
            self._register_parameters(param_name, tensor)
        for buffer_name, tensor in (('running_mean_ds', short_mean), ('running_var_ds', short_var)):
            self._register_buffers(buffer_name, tensor)

    def local_f(self, x, weight_conv, bias_conv, weight_bn, bias_bn, weight_conv_ds, bias_conv_ds, weight_bn_ds,
                bias_bn_ds, running_mean, running_var, running_mean_ds, running_var_ds, training):
        """Return ``(main_branch_output, shortcut)`` for input ``x``."""
        functional = torch.nn.functional
        main = functional.conv2d(x, weight_conv, bias=bias_conv, stride=self.stride, padding=self.padding)
        main = functional.batch_norm(main, running_mean, running_var, weight=weight_bn, bias=bias_bn,
                                     training=training, momentum=self.momentum_bn, eps=self.eps_bn)
        main = functional.relu(main, inplace=True)

        if not self.downsample:
            # Identity shortcut: hand the input through unchanged.
            return main, x

        shortcut = functional.conv2d(x, weight_conv_ds, bias=bias_conv_ds, stride=2, padding=0)
        shortcut = functional.batch_norm(shortcut, running_mean_ds, running_var_ds, weight=weight_bn_ds,
                                         bias=bias_bn_ds, training=training, momentum=self.momentum_bn,
                                         eps=self.eps_bn)
        return main, shortcut


class TopBottleneck(AsynchronousGenericLayer):
    """Second half of a residual block: conv -> batch norm, add the shortcut, then ReLU.

    Args:
        n_in, n_out, kernel_size: passed to ``init_conv`` as ``(n_out, n_in, kernel_size)``.
        stride, padding: used by the convolution in ``local_f``.
        dilation, groups, max_pool: accepted by the signature but not used by ``local_f``
            (``max_pool`` is only stored on the instance).
        eps_bn, momentum_bn: ``eps`` and ``momentum`` of the batch normalization.
        *args, **kwargs: forwarded unchanged to ``AsynchronousGenericLayer.__init__``.
    """

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, eps_bn=1e-05,
                 momentum_bn=0.1, max_pool=True, *args, **kwargs):
        super(TopBottleneck, self).__init__(*args, **kwargs)
        self.stride, self.padding = stride, padding
        self.momentum_bn, self.eps_bn = momentum_bn, eps_bn
        self.max_pool = max_pool

        conv_w, conv_b = init_conv(n_out, n_in, kernel_size)
        bn_w, bn_b = init_batchnorm(n_out)
        # Running statistics reuse init_batchnorm: first returned value -> variance, second -> mean.
        stats_var, stats_mean = init_batchnorm(n_out)

        for param_name, tensor in (('weight_conv', conv_w), ('bias_conv', conv_b),
                                   ('weight_bn', bn_w), ('bias_bn', bn_b)):
            self._register_parameters(param_name, tensor)
        for buffer_name, tensor in (('running_mean', stats_mean), ('running_var', stats_var)):
            self._register_buffers(buffer_name, tensor)

    def local_f(self, x, x_ds, weight_conv, bias_conv, weight_bn, bias_bn, running_mean, running_var, training):
        """Return ``relu(batch_norm(conv(x)) + x_ds)``.

        ``x_ds`` is the shortcut tensor; presumably the second output of the preceding
        ``BottomBottleneck`` -- confirm against the base-class dispatch.
        """
        functional = torch.nn.functional
        conv_out = functional.conv2d(x, weight_conv, bias=bias_conv, stride=self.stride, padding=self.padding)
        out = functional.batch_norm(conv_out, running_mean, running_var, weight=weight_bn, bias=bias_bn,
                                    training=training, momentum=self.momentum_bn, eps=self.eps_bn)
        # The residual sum is accumulated in place on the batch-norm output.
        out += x_ds
        return functional.relu(out, inplace=True)


def make_layers_resnet18(dataset, nclass=10, last_bn_zero_init=False, store_param=True, store_vjp=False,
                         quantizer=QuantizSimple, accumulation_steps=1, accumulation_averaging=False):
    """Build the layer list of the ResNet-18 variant.

    The list holds one ``ConvBNReLUMax`` stem, then for each width in ``[64, 128, 256, 512]`` two
    ``(BottomBottleneck, TopBottleneck)`` pairs, and finally one ``AVGFlattenFullyConnectedCE``.

    Args:
        dataset: ``'imagenet'`` selects a 7x7, stride-2, padding-3 stem with max pooling; any other value
            selects a 3x3, stride-1, padding-1 stem without max pooling.
        nclass: output size of the final classifier.
        last_bn_zero_init: when true, the ``weight_bn_forward`` and ``weight_bn_backward`` attributes of every
            ``TopBottleneck`` are replaced by zero tensors of the same shape.
        store_param, store_vjp, quantizer, accumulation_steps, accumulation_averaging: forwarded unchanged
            to every layer constructor.

    Returns:
        The list of layer objects, in forward order.
    """
    # Options handed unchanged to every layer constructor.
    shared = dict(store_param=store_param, store_vjp=store_vjp, quantizer=quantizer,
                  accumulation_steps=accumulation_steps, accumulation_averaging=accumulation_averaging)
    widths = [64, 128, 256, 512]

    if dataset == 'imagenet':
        stem = dict(kernel_size=7, stride=2, padding=3, max_pool=True)
    else:
        stem = dict(kernel_size=3, stride=1, padding=1, max_pool=False)
    stem['first_layer'] = True
    stem.update(shared)
    layers = [ConvBNReLUMax(3, widths[0], **stem)]

    previous = 64
    for width in widths:
        # Two pairs per width; only the first pair can see a change of width and therefore downsample.
        for block_in in (previous, width):
            if block_in != width:
                bottom = BottomBottleneck(block_in, width, stride=2, downsample=True, padding=1, **shared)
            else:
                bottom = BottomBottleneck(block_in, width, padding=1, **shared)
            top = TopBottleneck(width, width, padding=1, **shared)
            if last_bn_zero_init:
                for suffix in ('_forward', '_backward'):
                    attr = 'weight_bn' + suffix
                    setattr(top, attr, torch.zeros_like(getattr(top, attr)))
            layers += [bottom, top]
        previous = width

    # The final layer performs the global average pooling itself before the classifier.
    layers.append(AVGFlattenFullyConnectedCE(512, nclass, **shared))
    return layers


def make_layers_resnet34(dataset, nclass=10, store_param=True, store_vjp=False, quantizer=QuantizSimple,
                         accumulation_steps=1, accumulation_averaging=False):
    """Build the layer list of the ResNet-34 variant.

    The list holds one ``ConvBNReLUMax`` stem, then one ``(BottomBottleneck, TopBottleneck)`` pair per entry
    of the width schedule (3 x 64, 4 x 128, 6 x 256, 3 x 512), and finally one ``AVGFlattenFullyConnectedCE``.

    Args:
        dataset: ``'imagenet'`` selects a 7x7, stride-2, padding-3 stem with max pooling; any other value
            selects a 3x3, stride-1, padding-1 stem without max pooling.
        nclass: output size of the final classifier.
        store_param, store_vjp, quantizer, accumulation_steps, accumulation_averaging: forwarded unchanged
            to every layer constructor.

    Returns:
        The list of layer objects, in forward order.
    """
    # Options handed unchanged to every layer constructor.
    shared = dict(store_param=store_param, store_vjp=store_vjp, quantizer=quantizer,
                  accumulation_steps=accumulation_steps, accumulation_averaging=accumulation_averaging)
    in_channels = 3
    channels = [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3

    if dataset == 'imagenet':
        kernel_size, stride, padding, max_pool = 7, 2, 3, True
    else:
        kernel_size, stride, padding, max_pool = 3, 1, 1, False
    layers = [ConvBNReLUMax(in_channels, channels[0], kernel_size=kernel_size, padding=padding, stride=stride,
                            max_pool=max_pool, first_layer=True, **shared)]

    previous = 64
    for width in channels:
        if previous != width:
            # A change of width starts a new stage: stride-2 main branch with a convolutional shortcut.
            layers.append(BottomBottleneck(previous, width, stride=2, downsample=True, padding=1, **shared))
        else:
            layers.append(BottomBottleneck(previous, width, padding=1, **shared))
        layers.append(TopBottleneck(width, width, padding=1, **shared))
        previous = width

    # The final layer performs the global average pooling itself before the classifier.
    layers.append(AVGFlattenFullyConnectedCE(512, nclass, **shared))
    return layers
