from typing import Tuple, Union

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .wavelet import DwtCompress, inverse_wavelet_transform_init, create_filter, DwtCompressDense

# Hard-coded CUDA device for tensors that this module allocates directly instead of
# deriving the device from its inputs or parameters.
# NOTE(review): tensors created with DEVICE are not moved by nn.Module.to()/.cpu(), and on
# CPU-only or multi-GPU setups the code paths that use it will fail or mix devices.
DEVICE: str = 'cuda:0'


def weight_quantization(b):
    """Build a signed uniform fake-quantizer for weights with a learnable clipping threshold.

    The returned callable has the form ``fn(weight, alpha) -> Tensor``. With
    ``u = weight / alpha`` and ``c = clamp(u, -1, 1)`` it returns
    ``alpha * sign(c) * round(|c| * (2**b - 1)) / (2**b - 1)``. The magnitude is
    rounded onto ``2**b - 1`` uniform steps and the sign is kept separately.
    The input tensor is never modified.

    Gradients:
        * ``weight`` receives ``grad_output`` unchanged (straight-through; it is
          deliberately not masked where ``|u| > 1``).
        * ``alpha`` receives ``sum(grad_output * g)``. Here ``g = sign(u)`` where
          ``|u| > 1`` and ``g = q(u) - u`` elsewhere, and ``q(u)`` is the
          quantized value before rescaling.

    Args:
        b: Number of magnitude bits (the sign is not counted). Must be >= 1,
            because ``b == 0`` would leave zero steps and divide by zero.

    Returns:
        The ``apply`` function of a custom ``torch.autograd.Function``.

    Raises:
        ValueError: If ``b < 1``.
    """
    if b < 1:
        raise ValueError(f"weight_quantization needs at least 1 magnitude bit, got b={b}")

    def uniform_quant(x, b):
        # Round x (here always in [0, 1]) onto 2**b - 1 uniform steps.
        levels = 2 ** b - 1
        return x.mul(levels).round().div(levels)

    class _pq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, weight, alpha):
            # Divide out of place so the caller's tensor is left untouched. An in-place
            # div_ would rewrite the input without ctx.mark_dirty.
            scaled = weight.div(alpha)
            clipped = scaled.clamp(min=-1, max=1)  # then clipped to [-1, 1]
            sign = clipped.sign()
            weight_q = uniform_quant(clipped.abs(), b).mul(sign)
            ctx.save_for_backward(scaled, weight_q)
            return weight_q.mul(alpha)  # rescale to the original range

        @staticmethod
        def backward(ctx, grad_output):
            scaled, weight_q = ctx.saved_tensors
            outside = (scaled.abs() > 1.).float()  # 1 where the value was clipped
            sign = scaled.sign()
            grad_alpha = (grad_output * (sign * outside + (weight_q - scaled) * (1 - outside))).sum()
            # Straight-through estimator: the weight gradient is not clipped.
            return grad_output.clone(), grad_alpha

    # apply is a classmethod; autograd Functions should not be instantiated.
    return _pq.apply


class weight_quantize_fn(nn.Module):
    """Standardize a weight tensor and fake-quantize it with a learnable clipping threshold.

    ``forward`` first subtracts the tensor's mean and divides by its standard
    deviation. Both statistics are taken from ``.data``, so no gradient flows
    through them. It then applies the signed quantizer built by
    ``weight_quantization`` with the learnable threshold ``wgt_alpha``, which is
    initialised to 3.0. The result stays in the standardized domain: it is not
    multiplied back by the std nor shifted by the mean. With ``w_bit == 32`` the
    weight is returned unchanged.

    Args:
        w_bit: Total bit width including the sign bit. Must be >= 2, or 32 to
            disable quantization.

    Raises:
        ValueError: If ``w_bit < 2`` (a 1-bit setting leaves no magnitude levels
            and would produce NaN weights).
    """

    # User-facing bit width that disables quantization entirely.
    _FULL_PRECISION_BITS = 32

    def __init__(self, w_bit):
        super(weight_quantize_fn, self).__init__()
        if w_bit < 2:
            raise ValueError(f"w_bit must be >= 2 (one bit is the sign), got {w_bit}")
        # Magnitude bits (one bit is reserved for the sign). The attribute name and
        # value are kept for backward compatibility.
        self.w_bit = w_bit - 1
        self.weight_q = weight_quantization(b=self.w_bit)
        self.register_parameter('wgt_alpha', Parameter(torch.tensor(3.0), requires_grad=True))

    def forward(self, weight):
        # self.w_bit holds w_bit - 1, so compare the user-facing width. Comparing
        # self.w_bit itself against 32 would never match for w_bit=32.
        if self.w_bit + 1 == self._FULL_PRECISION_BITS:
            return weight
        mean = weight.data.mean()
        std = weight.data.std()
        weight = weight.add(-mean).div(std)  # weights normalization
        return self.weight_q(weight, self.wgt_alpha)


def act_quantization(b, signed=False):
    """Build a uniform fake-quantizer for activations with a learnable clipping threshold.

    The returned callable has the form ``fn(x, alpha) -> Tensor``. With
    ``u = x / alpha`` it clips ``u`` to ``[-1, 1]`` when ``signed`` is True, or
    only from above at 1 otherwise. It then rounds onto a grid with step
    ``1 / (2**b - 1)`` and rescales by ``alpha``.

    NOTE(review): the unsigned branch has no lower clamp, so negative inputs are
    quantized to negative levels rather than zeroed. Presumably callers feed
    non-negative (e.g. post-ReLU) activations; confirm.

    Gradients:
        * ``x`` receives ``grad_output`` masked to zero where ``|u| > 1``.
        * ``alpha`` receives ``sum(grad_output * g)``. Here ``g = sign(u)`` where
          ``|u| > 1`` and ``g = q(u) - u`` elsewhere, and ``q(u)`` is the
          quantized value before rescaling.

    Args:
        b: Number of bits for the quantization grid. The signed callers in this
            file pass the total width minus one. Must be >= 1.
        signed: If True, also clip from below at -1 (symmetric range).

    Returns:
        The ``apply`` function of a custom ``torch.autograd.Function``.

    Raises:
        ValueError: If ``b < 1`` (zero steps would divide by zero and yield NaN).
    """
    if b < 1:
        raise ValueError(f"act_quantization needs at least 1 bit, got b={b}")

    def uniform_quant(x, b):
        # floor(v + [0, 1)) | round(v + [-.5, .5))
        levels = 2 ** b - 1
        return x.mul(levels).round().div(levels)

    class _uq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, alpha):
            scaled = x.div(alpha)
            clipped = scaled.clamp(min=-1, max=1) if signed else scaled.clamp(max=1)
            x_q = uniform_quant(clipped, b)
            ctx.save_for_backward(scaled, x_q)
            return x_q.mul(alpha)

        @staticmethod
        def backward(ctx, grad_output):
            scaled, x_q = ctx.saved_tensors
            outside = (scaled.abs() > 1.).float()  # 1 where the value was clipped
            sign = scaled.sign()
            grad_alpha = (grad_output * (sign * outside + (x_q - scaled) * (1 - outside))).sum()
            # Straight-through inside the clipping range, zero outside it.
            grad_input = grad_output * (1 - outside)
            return grad_input, grad_alpha

    # apply is a classmethod; autograd Functions should not be instantiated.
    return _uq.apply


class QuantConv2d(nn.Conv2d):
    """2-D convolution with fake-quantized weights and input activations.

    Weights go through ``weight_quantize_fn`` with ``bit`` bits. The input goes
    through an activation quantizer whose learnable clipping threshold is
    ``act_alpha`` (initialised to 8.0). When ``signed`` is True, one of the
    ``act_bit`` bits is spent on the sign.
    """

    def __init__(self, in_channels, out_channels, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1, bias=False, bit=4,
                 act_bit=4, signed=False):
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                          bias)
        self.layer_type = 'QuantConv2d'
        self.bit = bit
        self.act_bit = act_bit
        self.signed = signed
        self.weight_quant = weight_quantize_fn(w_bit=self.bit)
        self.act_alq = self._build_act_quantizer()
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0), requires_grad=True)

    def _build_act_quantizer(self):
        """Return an activation quantizer for the current ``act_bit`` / ``signed`` settings."""
        # A signed quantizer reserves one of the bits for the sign.
        grid_bits = self.act_bit - 1 if self.signed else self.act_bit
        return act_quantization(grid_bits, self.signed)

    def update_act_bit(self, act_bit):
        """Change the activation bit width and rebuild the activation quantizer."""
        self.act_bit = act_bit
        self.act_alq = self._build_act_quantizer()

    def forward(self, x):
        weight_q = self.weight_quant(self.weight)
        x = self.act_alq(x, self.act_alpha)
        return F.conv2d(x, weight_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def show_params(self):
        """Print the learned weight and activation clipping thresholds to 3 decimals."""
        wgt_alpha = self.weight_quant.wgt_alpha.item()
        act_alpha = self.act_alpha.item()
        print('clipping threshold weight alpha: {:.3f}, activation alpha: {:.3f}'.format(wgt_alpha, act_alpha))

    def extra_repr(self):
        return f"{super(QuantConv2d, self).extra_repr()}, Weight bits: {self.bit}, Activation bits: {self.act_bit}"


class DwtDenseQuantConv2d1x1(nn.Conv2d):
    """1x1 quantized convolution applied between a wavelet compressor and an inverse wavelet transform.

    ``forward`` passes the input through ``DwtCompressDense`` (wavelet ``'db1'``,
    ``level`` levels, ``compression`` rate). It then fake-quantizes the result
    with a signed ``act_bit``-bit activation quantizer and convolves it with
    ``weight_bit``-bit weights. Finally it applies the inverse wavelet transform
    built for ``out_channels`` channels.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 level: int,
                 compression: float,
                 weight_bit: int,
                 act_bit: int,
                 stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0,
                 dilation: Union[int, Tuple] = 1,
                 groups: int = 1,
                 bias: bool = False):
        super(DwtDenseQuantConv2d1x1, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
        self.layer_type = 'QuantConv2d'
        self.weight_bit = weight_bit
        self.act_bit = act_bit
        self.weight_quant = weight_quantize_fn(w_bit=self.weight_bit)
        # After the wavelet transform there are negative values, so use a signed
        # quantizer (one bit goes to the sign).
        self.act_alq = act_quantization(self.act_bit - 1, signed=True)
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0), requires_grad=True)
        self.level = level
        self.compression = compression
        # Fixed synthesis filter. It is registered on the module as a non-trainable
        # Parameter so it is saved and moved together with the layer.
        self.iwt_weight = nn.Parameter(create_filter(wave='db1', in_size=out_channels), requires_grad=False)
        self._build_wavelet_ops()

    def _build_wavelet_ops(self):
        """(Re)create the wavelet compressor and inverse transform from the current level/compression."""
        self.wt_quant = DwtCompressDense(in_size=self.in_channels, level=self.level, compress_rate=self.compression,
                                         wave='db1')
        self.iwt = inverse_wavelet_transform_init(weight=self.iwt_weight, in_size=self.out_channels, level=self.level)

    def update_compression(self, compression: float, level: int):
        """Switch to a new compression rate / number of levels and rebuild the wavelet operators."""
        self.compression = compression
        self.level = level
        self._build_wavelet_ops()

    def forward(self, x):
        weight_q = self.weight_quant(self.weight)
        coeffs = self.wt_quant(x)
        coeffs = self.act_alq(coeffs, self.act_alpha)
        coeffs = F.conv2d(coeffs, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return self.iwt(coeffs)

    def show_params(self):
        """Print the learned weight and activation clipping thresholds to 3 decimals."""
        wgt_alpha = self.weight_quant.wgt_alpha.item()
        act_alpha = self.act_alpha.item()
        print('clipping threshold weight alpha: {:.3f}, activation alpha: {:.3f}'.format(wgt_alpha, act_alpha))

    def extra_repr(self):
        return f"{super(DwtDenseQuantConv2d1x1, self).extra_repr()}, Weight bits: {self.weight_bit}, Activation bits: {self.act_bit}, Wavelet Compression: {self.compression}, Wavelet Levels: {self.level}"


class DwtQuantConv2d1x1(nn.Conv1d):
    """Pointwise quantized convolution on a sparse set of wavelet coefficients.

    ``forward`` runs ``DwtCompress`` on the input, which returns a pair used
    as (coefficients, flat spatial indexes). The coefficients are fake-quantized
    with a signed ``act_bit``-bit quantizer and mixed across channels by
    ``F.conv1d`` with ``weight_bit``-bit weights, hence the ``nn.Conv1d`` base.
    ``iwt_decompress`` then scatters them back into a dense map and applies the
    inverse wavelet transform.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 level: int,
                 compression: float,
                 weight_bit: int,
                 act_bit: int,
                 stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0,
                 dilation: Union[int, Tuple] = 1,
                 groups: int = 1,
                 bias: bool = False):
        super(DwtQuantConv2d1x1, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
        self.layer_type = 'QuantConv2d'
        self.weight_bit = weight_bit
        self.act_bit = act_bit
        self.weight_quant = weight_quantize_fn(w_bit=self.weight_bit)
        # After the wavelet transform there are negative values, so use a signed
        # quantizer (one bit goes to the sign).
        self.act_alq = act_quantization(self.act_bit - 1, signed=True)
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0), requires_grad=True)
        self.level = level
        self.compression = compression
        # Fixed synthesis filter. It is registered on the module as a non-trainable
        # Parameter so it is saved and moved together with the layer.
        self.iwt_weight = nn.Parameter(create_filter(wave='db1', in_size=out_channels), requires_grad=False)
        self._build_wavelet_ops()

    def _build_wavelet_ops(self):
        """(Re)create the wavelet compressor and inverse transform from the current level/compression."""
        self.wt_quant = DwtCompress(in_size=self.in_channels, level=self.level, compress_rate=self.compression,
                                    wave='db1')
        self.iwt = inverse_wavelet_transform_init(weight=self.iwt_weight, in_size=self.out_channels, level=self.level)

    def update_compression(self, compression: float, level: int):
        """Switch to a new compression rate / number of levels and rebuild the wavelet operators."""
        self.compression = compression
        self.level = level
        self._build_wavelet_ops()

    def forward(self, x):
        weight_q = self.weight_quant(self.weight)
        # NOTE(review): assumes DwtCompress returns coefficients shaped (batch, in_channels, k),
        # as F.conv1d requires, plus their flat h*w positions; confirm against DwtCompress.
        topk, indexes = self.wt_quant(x)
        topk = self.act_alq(topk, self.act_alpha)
        topk = F.conv1d(topk, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return self.iwt_decompress(x.shape, indexes, topk)

    def iwt_decompress(self, shape, indexes, topk):
        """Scatter coefficients into a dense (b, out_channels, h, w) map and invert the wavelet transform.

        Args:
            shape: Shape ``(b, c, h, w)`` of the original input; ``c`` is unused.
            indexes: Flat ``h * w`` positions of the coefficients. They are
                repeated across output channels, so presumably shaped
                ``(b, 1, k)``; confirm.
            topk: Convolved coefficients, shaped ``(b, out_channels, k)``.

        Returns:
            The inverse-wavelet-transformed map.
        """
        b, _, h, w = shape
        indexes = indexes.repeat(1, self.out_channels, 1)
        # Allocate on topk's device and dtype rather than a hard-coded GPU. This keeps
        # the layer working on CPU, on other GPUs and under mixed precision (scatter
        # requires matching dtypes). No requires_grad is needed: gradients reach topk
        # through scatter.
        dense = topk.new_zeros((b, self.out_channels, h * w))
        dense = dense.scatter(dim=2, index=indexes, src=topk)
        dense = dense.reshape((b, self.out_channels, h, w))
        return self.iwt(dense)

    def show_params(self):
        """Print the learned weight and activation clipping thresholds to 3 decimals."""
        wgt_alpha = self.weight_quant.wgt_alpha.item()
        act_alpha = self.act_alpha.item()
        print('clipping threshold weight alpha: {:.3f}, activation alpha: {:.3f}'.format(wgt_alpha, act_alpha))

    def extra_repr(self):
        return f"{super(DwtQuantConv2d1x1, self).extra_repr()}, Weight bits: {self.weight_bit}, Activation bits: {self.act_bit}, Wavelet Compression: {self.compression}, Wavelet Levels: {self.level}"


# 8-bit quantization for the first and the last layer
class first_conv(nn.Conv2d):
    """Conv2d whose weights are fake-quantized to symmetric signed 8-bit levels.

    The per-tensor scale is the largest absolute weight, mapped to +/-127.
    Gradients pass through the rounding unchanged (straight-through estimator).
    """

    def __init__(self, in_channels, out_channels, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1,
                 groups=1, bias=False):
        super(first_conv, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                         bias)
        self.layer_type = 'FConv2d'

    def forward(self, x):
        weight = self.weight
        # Use max |w| so every quantized value lies in [-127, 127]; a plain max() would
        # let negative weights below -max exceed the 8-bit range. Clamping away from zero
        # keeps an all-zero weight tensor from producing 0/0 = NaN without a host sync.
        scale = weight.detach().abs().max().clamp(min=torch.finfo(weight.dtype).tiny)
        weight_q = weight.div(scale).mul(127).round().div(127).mul(scale)
        # Straight-through estimator: forward uses weight_q, backward treats rounding as identity.
        weight_q = (weight_q - weight).detach() + weight
        return F.conv2d(x, weight_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class last_fc(nn.Linear):
    """Linear layer whose weights are fake-quantized to symmetric signed 8-bit levels.

    The per-tensor scale is the largest absolute weight, mapped to +/-127.
    Gradients pass through the rounding unchanged (straight-through estimator).
    """

    def __init__(self, in_features, out_features, bias=True):
        super(last_fc, self).__init__(in_features, out_features, bias)
        self.layer_type = 'LFC'

    def forward(self, x):
        weight = self.weight
        # Use max |w| so every quantized value lies in [-127, 127]; a plain max() would
        # let negative weights below -max exceed the 8-bit range. Clamping away from zero
        # keeps an all-zero weight tensor from producing 0/0 = NaN without a host sync.
        scale = weight.detach().abs().max().clamp(min=torch.finfo(weight.dtype).tiny)
        weight_q = weight.div(scale).mul(127).round().div(127).mul(scale)
        # Straight-through estimator: forward uses weight_q, backward treats rounding as identity.
        weight_q = (weight_q - weight).detach() + weight
        return F.linear(x, weight_q, self.bias)


class DwtQuantConv2d(nn.Conv2d):
    """Conv2d preceded by wavelet compression, activation quantization and an inverse wavelet transform.

    Weight and activation bit widths are fixed to 4 and the compression rate to
    0.125 (wavelet ``'db1'``, ``level`` levels).

    NOTE(review): this class looks inconsistent with the 1x1 variants in this
    file; see the notes in ``__init__`` and ``forward`` before relying on it.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, level=1,
                 bias=False):
        super(DwtQuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                             bias)
        self.layer_type = 'QuantConv2d'
        self.bit = 4
        self.weight_quant = weight_quantize_fn(w_bit=self.bit)
        # NOTE(review): unsigned quantizer. The 1x1 wavelet variants use a signed one
        # because wavelet coefficients can be negative; confirm unsigned is intended.
        self.act_alq = act_quantization(self.bit)
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0), requires_grad=True)
        self.wt_quant = DwtCompress(in_size=in_channels, level=level, wave='db1', compress_rate=0.125)
        # NOTE(review): the filter is a local tensor placed on the hard-coded DEVICE. Unless
        # inverse_wavelet_transform_init registers it, Module.to()/.cpu() will not move it.
        # The 1x1 variants store it as a non-trainable Parameter instead.
        iwt_weight = create_filter(wave='db1', in_size=out_channels).to(DEVICE)
        # NOTE(review): built for out_channels, but forward applies self.iwt to the conv
        # *input* (in_channels channels); this only lines up when in_channels == out_channels.
        self.iwt = inverse_wavelet_transform_init(weight=iwt_weight, in_size=out_channels, level=level)

    def forward(self, x):
        """Compress, quantize and reconstruct ``x`` in the wavelet domain, then convolve with quantized weights."""
        weight_q = self.weight_quant(self.weight)
        # NOTE(review): DwtQuantConv2d1x1.forward unpacks DwtCompress's output as
        # (topk, indexes), but here it is used as a single tensor. Verify the return
        # type, because this path may fail at runtime.
        x = self.wt_quant(x)
        x = self.act_alq(x, self.act_alpha)
        x = self.iwt(x)
        return F.conv2d(x, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def show_params(self):
        """Print the learned weight and activation clipping thresholds."""
        wgt_alpha = round(self.weight_quant.wgt_alpha.data.item(), 3)
        act_alpha = round(self.act_alpha.data.item(), 3)
        # NOTE(review): '{:2f}' sets a minimum width of 2, not a precision, so values
        # print with six decimals despite the round(..., 3) above.
        print('clipping threshold weight alpha: {:2f}, activation alpha: {:2f}'.format(wgt_alpha, act_alpha))
