import torch.nn as nn
import torch.nn.functional as F


class QLinear(nn.Linear):
    """
        Linear layer whose weight can optionally pass through a quantization
        function before the affine transform is applied.

        The stored ``self.weight`` parameter is never modified. When
        quantization is enabled, ``qfunc`` is applied to it on every forward
        pass, and the result is used in place of the raw weight. The bias is
        always used as-is.
    """
    def __init__(self, *args, qfunc, quantize=True, **kwargs):
        """
        :param list args: positional arguments forwarded to :class:`torch.nn.Linear`.
        :param function qfunc: callable applied to ``self.weight`` on each forward
            pass while quantization is enabled; its return value is passed to
            :func:`torch.nn.functional.linear` as the weight.
        :param bool quantize: initial state of the quantization switch; a falsy
            value makes ``forward`` use the raw full-precision weight.
        :param dict[str, any] kwargs: keyword arguments forwarded to :class:`torch.nn.Linear`.
        """
        super().__init__(*args, **kwargs)

        # Stored as plain attributes; the public switch is exposed through the
        # ``quantize`` property below.
        self._qfunc = qfunc
        self._quantize = quantize

    @property
    def quantize(self):
        """Current quantization switch (whatever value was last assigned)."""
        return self._quantize

    @quantize.setter
    def quantize(self, enabled):
        # No validation/coercion: the value is stored verbatim and tested for
        # truthiness in ``forward``.
        self._quantize = enabled

    def forward(self, inputs):
        """
        :param inputs: input tensor accepted by :func:`torch.nn.functional.linear`.
        :return: ``F.linear(inputs, w, self.bias)``, where ``w`` is
            ``self._qfunc(self.weight)`` when quantization is on, otherwise
            ``self.weight``.
        """
        effective_weight = self._qfunc(self.weight) if self._quantize else self.weight
        return F.linear(inputs, effective_weight, self.bias)
