from typing import Union, Tuple

import torch
from torch import nn as nn
from torch.nn import functional as F

from .util.quantization import weight_quantize_fn, act_quantization
from .util.wavelet import DwtCompress, create_filter, inverse_wavelet_transform_init


class DwtQuantConv2d1x1(nn.Conv1d):
    """Quantized 1x1 convolution applied to compressed Haar ('db1') wavelet coefficients.

    ``forward`` performs these steps:

    1. Pads the two spatial dims of a 4-D input up to a multiple of ``2 ** level``.
    2. Compresses the input with ``DwtCompress``, which returns kept coefficients
       (``topk``) and their flat spatial indices (``ids``).
    3. Quantizes those coefficients.
    4. Mixes channels with a ``conv1d`` whose kernel size is 1, using quantized weights.
    5. Scatters the result back into a dense map.
    6. Applies the inverse wavelet transform and crops the padding off again.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        level: Wavelet decomposition level; spatial dims are padded to a
            multiple of ``2 ** level``.
        compression: Forwarded to ``DwtCompress`` as ``compress_rate``.
        weight_bit: Bit width handed to the weight quantizer.
        act_bit: Activation bit width; ``act_bit - 1`` is passed to a signed
            activation quantizer because wavelet coefficients can be negative.
        stride, padding, dilation, groups, bias: Forwarded to ``nn.Conv1d`` and
            reused by the ``conv1d`` over the kept coefficients.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 level: int,
                 compression: float,
                 weight_bit: int,
                 act_bit: int,
                 stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0,
                 dilation: Union[int, Tuple] = 1,
                 groups: int = 1,
                 bias: bool = False):
        super().__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
        self.layer_type = 'QuantConv2d'
        self.weight_bit = weight_bit
        self.act_bit = act_bit
        self.weight_quant = weight_quantize_fn(w_bit=self.weight_bit)
        # After the wavelet transform there are negative values, so the activation
        # quantizer is signed.
        self.act_alq = act_quantization(self.act_bit - 1, signed=True)
        # Learnable activation clipping threshold (passed to act_alq in forward).
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0), requires_grad=True)
        # The level is stored on the module (rather than captured in a lambda) so the
        # module stays picklable and get_pad always reflects the current level.
        self.level = level
        self.wt_quant = DwtCompress(in_size=in_channels, level=level, compress_rate=compression, wave='db1')
        # 'db1' filter for the inverse transform; frozen (requires_grad=False).
        self.iwt_weight = nn.Parameter(create_filter(wave='db1', in_size=out_channels), requires_grad=False)
        self.iwt = inverse_wavelet_transform_init(weight=self.iwt_weight, in_size=out_channels, level=level)

    def get_pad(self, n: int) -> int:
        """Return how many elements to append so that ``n`` is a multiple of ``2 ** level``.

        Returns 0 when ``n`` is already aligned, so no empty block is added.
        """
        return (-n) % (2 ** self.level)

    def forward(self, x):
        """Run the compressed-wavelet quantized 1x1 convolution.

        Args:
            x: 4-D tensor; its shape is unpacked as (batch, channels, height, width).

        Returns:
            Tensor with ``out_channels`` channels, cropped back to the input's
            height and width. This assumes ``self.iwt`` preserves the padded
            spatial size (TODO confirm).
        """
        _, _, h, w = x.shape
        # F.pad pads from the last dim backwards: (w_left, w_right, h_top, h_bottom).
        pads = (0, self.get_pad(w), 0, self.get_pad(h))
        x = F.pad(x, pads)
        weight_q = self.weight_quant(self.weight)
        topk, ids = self.wt_quant(x)
        topk = self.act_alq(topk, self.act_alpha)
        topk = F.conv1d(topk, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return self.iwt_decompress(x.shape, ids, topk, x.device, pads)

    def iwt_decompress(self, shape, ids, topk, device, pads):
        """Scatter kept coefficients into a dense map, invert the wavelet transform, and crop.

        Args:
            shape: Padded input shape (batch, channels, height, width).
            ids: Flat spatial indices of the kept coefficients; repeated along
                dim 1 ``out_channels`` times. Presumably shaped (batch, 1, k);
                TODO confirm against DwtCompress.
            topk: Coefficients after the 1x1 conv, presumably (batch, out_channels, k).
            device: Device on which to allocate the dense map.
            pads: Padding tuple in ``F.pad`` layout ``(w_left, w_right, h_top, h_bottom)``.

        Returns:
            The inverse-transformed tensor with the right/bottom padding removed.
        """
        b, _, h, w = shape
        ids = ids.repeat(1, self.out_channels, 1)
        # Gradients reach topk through scatter's ``src``, so the zero canvas itself
        # does not need requires_grad. Match topk's dtype so scatter accepts it
        # (e.g. under autocast or float64).
        x = torch.zeros((b, self.out_channels, h * w), dtype=topk.dtype, device=device)
        x = x.scatter(dim=2, index=ids, src=topk)
        x = x.reshape((b, self.out_channels, h, w))
        x = self.iwt(x)
        # pads[3] is the bottom (height) pad, pads[1] the right (width) pad.
        return x[:, :, :h - pads[3], :w - pads[1]]

    def show_params(self):
        """Print the weight and activation clipping thresholds, rounded to 3 decimals."""
        wgt_alpha = round(self.weight_quant.wgt_alpha.data.item(), 3)
        act_alpha = round(self.act_alpha.data.item(), 3)
        print('clipping threshold weight alpha: {:.3f}, activation alpha: {:.3f}'.format(wgt_alpha, act_alpha))

    def change_wt_params(self, compression, level):
        """Rebuild the wavelet compressor and inverse transform for a new compression rate and level."""
        # NOTE(review): the new modules are created fresh; if DwtCompress holds
        # tensors they may not be on the same device as the rest of the model.
        # Confirm, and move them explicitly if needed.
        self.level = level
        self.wt_quant = DwtCompress(in_size=self.in_channels, level=level, compress_rate=compression, wave='db1')
        self.iwt = inverse_wavelet_transform_init(weight=self.iwt_weight, in_size=self.out_channels, level=level)

    def change_bit(self, bit, act_bit):
        """Replace the weight and activation quantizers with ones for new bit widths."""
        # NOTE(review): recreating weight_quant re-initialises its internal state
        # (e.g. wgt_alpha read by show_params) and may place it on a different
        # device than the model. Confirm this is intended.
        self.weight_bit = bit
        self.act_bit = act_bit
        self.weight_quant = weight_quantize_fn(w_bit=self.weight_bit)
        # After the wavelet transform there are negative values, so keep it signed.
        self.act_alq = act_quantization(self.act_bit - 1, signed=True)