import torch

from ..layers.compression import QuantizSimple
from ..layers.init_parameters import init_conv, init_batchnorm, init_FC
from ..layers.layers import AsynchronousGenericLayer, AsynchronousFinal


class AVGFlattenFullyConnectedCE(AsynchronousFinal):
    """Classification head for the two-stream network.

    Concatenates both streams along the channel axis, global-average-pools them
    to 1x1, flattens, and applies a linear layer. ``loss`` is cross-entropy.
    """

    def __init__(self, n_in, n_out, *args, **kwargs):
        """Create the linear-layer parameters.

        Args:
            n_in: channel count after concatenating ``x`` and ``tilde_x``.
            n_out: number of output features of the linear layer.
            *args, **kwargs: forwarded to ``AsynchronousFinal``.
        """
        super().__init__(*args, **kwargs)
        fc_weight, fc_bias = init_FC(n_out, n_in)
        # The registered names say "conv", but these are the linear layer's weight and bias.
        self._register_parameters('weight_conv', fc_weight)
        self._register_parameters('bias_conv', fc_bias)

    def loss(self, x, y):
        """Return ``torch.nn.functional.cross_entropy`` of logits ``x`` against targets ``y``."""
        return torch.nn.functional.cross_entropy(x, y)

    def local_f(self, x, tilde_x, weight, bias, training):
        """Map the stream pair ``(x, tilde_x)`` to logits; ``training`` is not used here."""
        merged = torch.cat((x, tilde_x), dim=1)
        pooled = torch.nn.functional.adaptive_avg_pool2d(merged, (1, 1))
        features = torch.flatten(pooled, start_dim=1)
        return torch.nn.functional.linear(features, weight, bias)


class ConvBNReLUMax(AsynchronousGenericLayer):
    """Stem layer: conv2d -> batch norm -> ReLU -> optional 2x2 max-pool -> channel split.

    ``local_f`` splits the final activation into two equal halves along dim 1 and
    returns them as the stream pair ``(x, tilde_x)``.
    """

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, eps_bn=1e-05,
                 momentum_bn=0.1, max_pool=True, *args, **kwargs):
        """Create the convolution, batch-norm parameters and batch-norm running statistics.

        Args:
            n_in: number of input channels of the convolution.
            n_out: number of output channels; must be even because ``local_f`` splits
                them into two equal halves.
            kernel_size: convolution kernel size (passed to ``init_conv``).
            stride, padding, dilation: forwarded to ``torch.nn.functional.conv2d``.
            groups: accepted for signature compatibility but not supported; the
                convolution always runs with ``conv2d``'s default ``groups=1``.
            eps_bn, momentum_bn: forwarded to ``torch.nn.functional.batch_norm``.
            max_pool: if True, apply ``max_pool2d(x, 2)`` after the ReLU.
            *args, **kwargs: forwarded to ``AsynchronousGenericLayer``.

        Raises:
            ValueError: if ``n_out`` is odd.
        """
        # An odd channel count cannot be split into exactly two equal halves; fail here
        # rather than with an opaque unpacking error inside local_f.
        if n_out % 2 != 0:
            raise ValueError('ConvBNReLUMax needs an even n_out to split into two streams, got {}'.format(n_out))
        super(ConvBNReLUMax, self).__init__(*args, **kwargs)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.momentum_bn = momentum_bn
        self.eps_bn = eps_bn
        self.max_pool = max_pool

        # Keep the init-call order unchanged: the init helpers may draw random numbers.
        weight_conv, bias_conv = init_conv(n_out, n_in, kernel_size)
        weight_bn, bias_bn = init_batchnorm(n_out)
        # init_batchnorm is reused for the running statistics: its first output (the BN
        # weight init) seeds running_var and its second (the BN bias init) seeds
        # running_mean — presumably ones / zeros; TODO confirm in init_parameters.
        running_var, running_mean = init_batchnorm(n_out)

        # Registration order mirrors local_f's argument order (parameters, then buffers).
        self._register_parameters('weight_conv', weight_conv)
        self._register_parameters('bias_conv', bias_conv)
        self._register_parameters('weight_bn', weight_bn)
        self._register_parameters('bias_bn', bias_bn)
        self._register_buffers('running_mean', running_mean)
        self._register_buffers('running_var', running_var)

    def local_f(self, x, weight_conv, bias_conv, weight_bn, bias_bn, running_mean, running_var, training):
        """Run the stem on ``x`` and return the two channel halves ``(x, tilde_x)``.

        ``training`` is forwarded to ``batch_norm``: batch statistics are used and the
        running statistics are updated when it is True.
        """
        y = torch.nn.functional.conv2d(x, weight_conv, bias=bias_conv, stride=self.stride, padding=self.padding,
                                       dilation=self.dilation)
        x = torch.nn.functional.batch_norm(y, running_mean, running_var, weight=weight_bn, bias=bias_bn,
                                           training=training, momentum=self.momentum_bn, eps=self.eps_bn)
        x = torch.nn.functional.relu(x, inplace=True)
        if self.max_pool:
            x = torch.nn.functional.max_pool2d(x, 2)
        # First half of the channels -> x, second half -> tilde_x.
        x, tilde_x = torch.split(x, x.size(1) // 2, dim=1)
        return x, tilde_x


class BasicBlock(AsynchronousGenericLayer):
    """Two-stream residual block operating on a ``(x, tilde_x)`` pair.

    ``local_f`` computes ``F(x) = relu(bn2(conv2(relu(bn1(conv1(x))))))`` and returns
    ``(F(x) + tilde_x_ds, x_ds)``. Without downsampling, ``tilde_x_ds = tilde_x`` and
    ``x_ds = x``. With ``downsample=True``, ``tilde_x_ds`` is ``tilde_x`` passed through
    a stride-2 1x1 conv + batch norm, and ``x_ds`` is the intermediate activation
    ``relu(bn1(conv1(x)))``.
    """

    def __init__(self, n_in, n_out, downsample=False, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
                 eps_bn=1e-05, momentum_bn=0.1, max_pool=True, *args, **kwargs):
        """Create the weights and batch-norm statistics for both main-path convs and the shortcut.

        Args:
            n_in: input channels summed over both streams; the convolutions see ``n_in // 2``.
            n_out: output channels summed over both streams; the convolutions produce ``n_out // 2``.
            downsample: if True, ``local_f`` applies the 1x1 stride-2 shortcut to ``tilde_x``
                and uses the first conv's activation as the new second stream.
            kernel_size: kernel size of both main-path convolutions.
            stride: stride of the first main-path convolution (the second always uses 1).
            padding: padding of both main-path convolutions.
            dilation, groups: accepted but not used by this block.
            eps_bn, momentum_bn: forwarded to every ``batch_norm`` call.
            max_pool: stored on ``self`` but not used by this block.
            *args, **kwargs: forwarded to ``AsynchronousGenericLayer``.
        """
        super(BasicBlock, self).__init__(*args, **kwargs)
        # Per-stream channel counts: every convolution here acts on one stream (half the channels).
        n_in = n_in // 2
        n_out = n_out // 2
        self.stride = stride
        self.padding = padding
        self.momentum_bn = momentum_bn
        self.eps_bn = eps_bn
        self.max_pool = max_pool
        self.downsample = downsample

        # Keep the init-call order unchanged: the init helpers may draw random numbers.
        weight_conv_1, bias_conv_1 = init_conv(n_out, n_in, kernel_size)
        weight_bn_1, bias_bn_1 = init_batchnorm(n_out)
        # init_batchnorm is reused for running statistics: its first output (the BN weight
        # init) seeds running_var and its second (the BN bias init) seeds running_mean —
        # presumably ones / zeros; TODO confirm in init_parameters.
        running_var_1, running_mean_1 = init_batchnorm(n_out)
        # Shortcut (downsample) tensors are created and registered unconditionally; when
        # downsample=False, local_f receives them but does not use them.
        running_var_ds, running_mean_ds = init_batchnorm(n_out)
        weight_conv_ds, bias_conv_ds = init_conv(n_out, n_in, 1)
        weight_bn_ds, bias_bn_ds = init_batchnorm(n_out)

        # NOTE(review): local_f's signature lists the parameters, then the buffers, each in
        # the order they are registered in this method. Presumably the base class passes
        # them positionally — keep registration order and local_f's signature in sync.
        self._register_parameters('weight_conv_1', weight_conv_1)
        self._register_parameters('bias_conv_1', bias_conv_1)
        self._register_parameters('weight_bn_1', weight_bn_1)
        self._register_parameters('bias_bn_1', bias_bn_1)
        # NOTE(review): registered as 'running_mean', not 'running_mean_1' like its siblings.
        # Looks like a typo, but renaming changes the buffer key (and likely checkpoint
        # keys) — confirm before fixing.
        self._register_buffers('running_mean', running_mean_1)
        self._register_buffers('running_var_1', running_var_1)
        self._register_parameters('weight_conv_ds', weight_conv_ds)
        self._register_parameters('bias_conv_ds', bias_conv_ds)
        self._register_parameters('weight_bn_ds', weight_bn_ds)
        self._register_parameters('bias_bn_ds', bias_bn_ds)
        self._register_buffers('running_mean_ds', running_mean_ds)
        self._register_buffers('running_var_ds', running_var_ds)

        # Second main-path convolution: n_out -> n_out per stream.
        weight_conv_2, bias_conv_2 = init_conv(n_out, n_out, kernel_size)
        weight_bn_2, bias_bn_2 = init_batchnorm(n_out)
        running_var_2, running_mean_2 = init_batchnorm(n_out)

        self._register_parameters('weight_conv_2', weight_conv_2)
        self._register_parameters('bias_conv_2', bias_conv_2)
        self._register_parameters('weight_bn_2', weight_bn_2)
        self._register_parameters('bias_bn_2', bias_bn_2)
        self._register_buffers('running_mean_2', running_mean_2)
        self._register_buffers('running_var_2', running_var_2)

    def local_f(self, x, tilde_x, weight_conv_1, bias_conv_1, weight_bn_1, bias_bn_1,
                weight_conv_ds, bias_conv_ds, weight_bn_ds, bias_bn_ds,
                weight_conv_2, bias_conv_2, weight_bn_2, bias_bn_2,
                running_mean_1, running_var_1, running_mean_ds, running_var_ds, running_mean_2, running_var_2,
                training):
        """Apply the block to the stream pair ``(x, tilde_x)``.

        Args:
            x, tilde_x: the two input streams.
            weight_*, bias_*: convolution and batch-norm parameters (main path 1, shortcut, main path 2).
            running_*: batch-norm running statistics in the same grouping.
            training: forwarded to every ``batch_norm`` call (batch statistics are used and
                running statistics updated when True).

        Returns:
            ``(y, tilde_y)`` with ``y = F(x) + tilde_x_ds`` and ``tilde_y = x_ds``.
        """
        y = torch.nn.functional.conv2d(x, weight_conv_1, bias=bias_conv_1, stride=self.stride, padding=self.padding)
        y = torch.nn.functional.batch_norm(y, running_mean_1, running_var_1, weight=weight_bn_1, bias=bias_bn_1,
                                           training=training, momentum=self.momentum_bn, eps=self.eps_bn)
        # In-place ReLU: from here on y and y_ refer to the same (post-ReLU) tensor.
        y_ = torch.nn.functional.relu(y, inplace=True)
        if self.downsample:
            # Shortcut on the second stream: 1x1 conv with stride 2, then batch norm (no ReLU).
            tilde_x_ds = torch.nn.functional.conv2d(tilde_x, weight_conv_ds, bias=bias_conv_ds, stride=2, padding=0)
            tilde_x_ds = torch.nn.functional.batch_norm(tilde_x_ds, running_mean_ds, running_var_ds,
                                                        weight=weight_bn_ds, bias=bias_bn_ds, training=training,
                                                        momentum=self.momentum_bn, eps=self.eps_bn)
            # Post-ReLU first-conv activation (same object as y_). Spatial sizes only line up
            # with tilde_x_ds if conv_1 also halves the resolution; the builders in this file
            # pass stride=2, padding=1 together with downsample=True.
            x_ds = y
        else:
            tilde_x_ds = tilde_x
            x_ds = x

        # Second main-path conv always uses stride 1.
        y = torch.nn.functional.conv2d(y_, weight_conv_2, bias=bias_conv_2, stride=1, padding=self.padding)
        y = torch.nn.functional.batch_norm(y, running_mean_2, running_var_2, weight=weight_bn_2, bias=bias_bn_2,
                                           training=training, momentum=self.momentum_bn, eps=self.eps_bn)
        y = torch.nn.functional.relu(y, inplace=True)
        # Residual addition of the (possibly downsampled) second stream.
        y = y + tilde_x_ds
        # The (possibly downsampled) first stream becomes the new second stream.
        tilde_y = x_ds
        return y, tilde_y


def make_layers_revnet18(dataset, nclass=10, store_param=True, store_vjp=False, quantizer=QuantizSimple,
                         accumulation_steps=1, accumulation_averaging=False):
    """Assemble the layers of an 18-layer two-stream RevNet, in forward order.

    The network consists of three parts:

    - a ``ConvBNReLUMax`` stem;
    - four stages of two ``BasicBlock``s each, with widths 64, 128, 256 and 512;
    - an ``AVGFlattenFullyConnectedCE`` head.

    The first block of each stage whose width differs from the previous one is
    built with ``stride=2`` and ``downsample=True``.

    Args:
        dataset: ``'imagenet'`` selects a 7x7, stride-2, padding-3 stem with
            max-pooling; any other value selects a 3x3, stride-1, padding-1 stem
            without pooling.
        nclass: number of outputs of the head.
        store_param, store_vjp: forwarded to the stem and every ``BasicBlock``
            (the head does not receive them).
        quantizer, accumulation_steps, accumulation_averaging: forwarded to every layer.

    Returns:
        list: the constructed layers.
    """
    head_kwargs = dict(quantizer=quantizer, accumulation_steps=accumulation_steps,
                       accumulation_averaging=accumulation_averaging)
    body_kwargs = dict(store_param=store_param, store_vjp=store_vjp, **head_kwargs)

    imagenet_stem = dataset == 'imagenet'
    stem_kernel, stem_stride, stem_padding, stem_pool = (7, 2, 3, True) if imagenet_stem else (3, 1, 1, False)

    stage_widths = (64, 128, 256, 512)
    model = [ConvBNReLUMax(3, stage_widths[0], kernel_size=stem_kernel, padding=stem_padding, stride=stem_stride,
                           max_pool=stem_pool, first_layer=True, **body_kwargs)]

    prev_width = 64  # channels leaving the stem, summed over both streams
    for width in stage_widths:
        # Entering a wider stage: stride the first block and downsample its shortcut.
        transition = dict(stride=2, downsample=True) if width != prev_width else {}
        model.append(BasicBlock(prev_width, width, padding=1, **transition, **body_kwargs))
        model.append(BasicBlock(width, width, padding=1, **body_kwargs))
        prev_width = width

    # The head re-concatenates the two 256-channel streams into 512 channels.
    model.append(AVGFlattenFullyConnectedCE(512, nclass, **head_kwargs))
    return model


def make_layers_revnet34(dataset, nclass=10, store_param=True, store_vjp=False, quantizer=QuantizSimple,
                         accumulation_steps=1, accumulation_averaging=False):
    """Assemble the layers of a 34-layer two-stream RevNet, in forward order.

    The network consists of three parts:

    - a ``ConvBNReLUMax`` stem;
    - 3, 4, 6 and 3 ``BasicBlock``s at widths 64, 128, 256 and 512;
    - an ``AVGFlattenFullyConnectedCE`` head.

    Every block whose width differs from the previous block's width is built with
    ``stride=2`` and ``downsample=True``.

    Args:
        dataset: ``'imagenet'`` selects a 7x7, stride-2, padding-3 stem with
            max-pooling; any other value selects a 3x3, stride-1, padding-1 stem
            without pooling.
        nclass: number of outputs of the head.
        store_param, store_vjp: forwarded to the stem and every ``BasicBlock``
            (the head does not receive them).
        quantizer, accumulation_steps, accumulation_averaging: forwarded to every layer.

    Returns:
        list: the constructed layers.
    """
    head_kwargs = dict(quantizer=quantizer, accumulation_steps=accumulation_steps,
                       accumulation_averaging=accumulation_averaging)
    body_kwargs = dict(store_param=store_param, store_vjp=store_vjp, **head_kwargs)

    imagenet_stem = dataset == 'imagenet'
    stem_kernel, stem_stride, stem_padding, stem_pool = (7, 2, 3, True) if imagenet_stem else (3, 1, 1, False)

    # (width, number of blocks) per stage, expanded to one width per block.
    stage_plan = ((64, 3), (128, 4), (256, 6), (512, 3))
    block_widths = [width for width, depth in stage_plan for _ in range(depth)]

    model = [ConvBNReLUMax(3, block_widths[0], kernel_size=stem_kernel, padding=stem_padding, stride=stem_stride,
                           max_pool=stem_pool, first_layer=True, **body_kwargs)]

    prev_width = 64  # channels leaving the stem, summed over both streams
    for width in block_widths:
        # A width change marks a stage boundary: stride the block and downsample its shortcut.
        transition = dict(stride=2, downsample=True) if width != prev_width else {}
        model.append(BasicBlock(prev_width, width, padding=1, **transition, **body_kwargs))
        prev_width = width

    # The head re-concatenates the two 256-channel streams into 512 channels.
    model.append(AVGFlattenFullyConnectedCE(512, nclass, **head_kwargs))
    return model
