import torch
from .functions.fakequantn import fakequantn
from .qlinear import QLinear


class FakeQuantNLinear(QLinear):
    """
        FakeQuantN Linear module.

        Quantization applies :func:`fakequantn` with ``nbits`` bits over the range
        ``[min(0, x.min()), max(0, x.max())]``, i.e. the observed range of the input
        widened so that it always contains zero.
    """
    def __init__(self, *args, nbits, quantize=False, **kwargs):
        """
        :param list args: args for :class:`QLinear`.
        :param int nbits: number of quantization bits.
        :param bool quantize: forwarded unchanged to :class:`QLinear` as ``quantize``.
        :param dict[str, any] kwargs: kwargs for :class:`QLinear`.
        """
        # Set before the base initializer so the quantization function handed to
        # QLinear can be called as soon as QLinear holds a reference to it.
        self._nbits = nbits
        # A bound method is used instead of a lambda closure: bound methods can be
        # pickled (e.g. by torch.save of the whole module), and copy.deepcopy
        # re-binds them to the copy instead of the original instance.
        super().__init__(*args, quantize=quantize, qfunc=self._fakequantn_qfunc, **kwargs)

    def _fakequantn_qfunc(self, x):
        """
        Fake-quantize ``x`` with :func:`fakequantn` over its zero-inclusive range.

        :param torch.Tensor x: tensor to quantize.
        :return: the result of ``fakequantn(x, lo, hi, self.nbits)`` where
                 ``lo = min(0, x.min())`` and ``hi = max(0, x.max())``.
        """
        # 0-dim zero with x's dtype on x's device; avoids building a host tensor
        # and transferring it on every call.
        zero = x.new_zeros(())
        lo = torch.min(zero, torch.min(x))
        hi = torch.max(zero, torch.max(x))
        return fakequantn(x, lo, hi, self._nbits)

    @property
    def nbits(self):
        """Number of quantization bits given at construction."""
        return self._nbits

class FakeQuant8Linear(FakeQuantNLinear):
    """
        FakeQuantN Linear module with the bit width fixed to 8.
    """
    def __init__(self, *args, **kwargs):
        """
        :param list args: positional args forwarded to :class:`FakeQuantNLinear`.
        :param dict[str, any] kwargs: keyword args forwarded to :class:`FakeQuantNLinear`;
            must not contain ``nbits``, which is always 8 here.
        """
        # Only the bit width is pinned; everything else passes through untouched.
        super(FakeQuant8Linear, self).__init__(*args, nbits=8, **kwargs)
