from typing import Tuple, Union

import torch.nn as nn
import torch
import torch.nn.functional as F

from .util.quantization import weight_quantize_fn, act_quantization


class QuantConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight and input activation are quantized in ``forward``.

    The weight is passed through ``weight_quantize_fn(w_bit=bit)`` and the input
    through ``act_quantization(act_bit)`` together with the learnable clipping
    threshold ``act_alpha``; the results feed a standard ``F.conv2d`` call.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Forwarded positionally to ``nn.Conv2d``. Note that
            ``bias`` defaults to ``False`` here, unlike ``nn.Conv2d``.
        bit: Bit width handed to the weight quantizer.
        act_bit: Bit width handed to the activation quantizer.
    """

    def __init__(self, in_channels, out_channels, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1,
                 bias: bool = False, bit: int = 4, act_bit: int = 4):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.layer_type = 'QuantConv2d'
        self.bit = bit
        self.act_bit = act_bit
        self.weight_quant = weight_quantize_fn(w_bit=self.bit)
        self.act_alq = act_quantization(self.act_bit)
        # Learnable clipping threshold for the input activations, initialised to 8.0.
        self.act_alpha = torch.nn.Parameter(torch.tensor(8.0), requires_grad=True)

    def forward(self, x):
        """Quantize the weight and the input ``x``, then apply the 2-D convolution.

        NOTE(review): ``x`` is presumably ``(N, C_in, H, W)`` as for ``nn.Conv2d`` — confirm.
        """
        weight_q = self.weight_quant(self.weight)
        x = self.act_alq(x, self.act_alpha)
        return F.conv2d(x, weight_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def show_params(self):
        """Print the weight and activation clipping thresholds with 3 decimal places."""
        wgt_alpha = self.weight_quant.wgt_alpha.detach().item()
        act_alpha = self.act_alpha.detach().item()
        # '.3f' (precision 3); a bare '2f' would only set a minimum width and print 6 decimals.
        print('clipping threshold weight alpha: {:.3f}, activation alpha: {:.3f}'.format(wgt_alpha, act_alpha))

    def change_bit(self, bit, act_bit):
        """Rebuild the weight and activation quantizers for new bit widths.

        ``act_alpha`` is kept as-is.

        NOTE(review): a fresh ``weight_quantize_fn`` is constructed, so any state
        held by the previous one (e.g. its ``wgt_alpha``) is replaced; if that
        state is an ``nn.Parameter``, an optimizer created earlier still refers
        to the old one — confirm this is intended.
        """
        self.bit = bit
        self.act_bit = act_bit
        self.weight_quant = self._to_weight_device(weight_quantize_fn(w_bit=self.bit))
        self.act_alq = self._to_weight_device(act_quantization(self.act_bit))

    def _to_weight_device(self, quantizer):
        """Return ``quantizer`` moved to ``self.weight``'s device if it is an ``nn.Module``, else unchanged."""
        # A newly constructed module sits on the default device (normally CPU);
        # without this, calling change_bit after .to()/.cuda() mixes devices in forward.
        if isinstance(quantizer, nn.Module):
            return quantizer.to(self.weight.device)
        return quantizer
