import torch.nn as nn
import torch.nn.functional as F


class QConv2d(nn.Conv2d):
    """
    Quantized Conv2d module.

    Behaves like :class:`torch.nn.Conv2d`, except that when quantization is
    enabled the weight is passed through ``qfunc`` before every convolution.
    The stored full-precision ``self.weight`` itself is never modified.
    """
    def __init__(self, *args, qfunc, quantize=True, **kwargs):
        """
        :param list args: args for :class:`torch.nn.Conv2d`.
        :param function qfunc: quantization function applied to ``self.weight`` in
            :meth:`forward`; it receives the weight tensor and its result is used
            as the convolution weight.
        :param bool quantize: apply the quantization function (truthy) or perform
            full-precision operations (falsy).
        :param dict[str, any] kwargs: kwargs for :class:`torch.nn.Conv2d`
            (including ``padding_mode``, which :meth:`forward` honours).
        """
        super().__init__(*args, **kwargs)

        self._qfunc = qfunc
        self._quantize = quantize

    @property
    def quantize(self):
        """Whether :meth:`forward` passes the weight through ``qfunc``."""
        return self._quantize

    @quantize.setter
    def quantize(self, val):
        self._quantize = val

    def forward(self, inputs):
        """
        Convolve ``inputs`` with the (optionally quantized) weight.

        :param torch.Tensor inputs: input accepted by :func:`torch.nn.functional.conv2d`.
        :return: the convolution output.
        :rtype: torch.Tensor
        """
        weight = self._qfunc(self.weight) if self._quantize else self.weight

        # Mirror nn.Conv2d: non-'zeros' padding modes (reflect/replicate/circular)
        # are applied with an explicit F.pad, after which the convolution itself
        # uses no padding. Passing self.padding straight to F.conv2d would
        # silently zero-pad instead of using the configured mode.
        if self.padding_mode != "zeros":
            inputs = F.pad(inputs, self._reversed_padding_repeated_twice,
                           mode=self.padding_mode)
            padding = 0
        else:
            padding = self.padding

        return F.conv2d(inputs, weight, self.bias, self.stride,
                        padding, self.dilation, self.groups)
