import torch
from .functions.fakequantn import fakequantn
from .qconv2d import QConv2d


class FakeQuantNConv2d(QConv2d):
    """
        FakeQuantN Conv2d module.

        A :class:`QConv2d` whose quantization function is :func:`fakequantn`
        with ``nbits`` bits. For an input ``x`` the quantization range is
        ``[min(0, x.min()), max(0, x.max())]``, so it always contains zero.
    """
    def __init__(self, *args, nbits, quantize=False, **kwargs):
        """
        :param list args: args for :class:`antgine.modules.quantization.qconv2d.QConv2d`.
        :param int nbits: number of quantization bits.
        :param bool quantize: forwarded unchanged to :class:`QConv2d`.
        :param dict[str, any] kwargs: kwargs for :class:`antgine.modules.quantization.qconv2d.QConv2d`.
        """
        # Assigned before the parent constructor so ``_fakequant`` is usable even
        # if QConv2d happens to call ``qfunc`` while it is being constructed.
        # (Presumably QConv2d is an nn.Module; plain non-Parameter/non-Module
        # attributes may be set before nn.Module.__init__.)
        self._nbits = nbits
        # A bound method instead of a lambda closing over ``self``: bound methods
        # can be pickled (e.g. torch.save of the whole module) and copy.deepcopy
        # rebinds them to the copy, whereas a lambda is unpicklable and, after a
        # deepcopy, would keep referring to the original module.
        super().__init__(*args, quantize=quantize, qfunc=self._fakequant, **kwargs)

    def _fakequant(self, x):
        """
        Fake-quantize ``x`` to ``self.nbits`` bits.

        :param torch.Tensor x: tensor to quantize.
        :return: the result of :func:`fakequantn` applied to ``x`` over the
            zero-inclusive range ``[min(0, x.min()), max(0, x.max())]``.
        """
        zero = x.new_zeros(())  # 0-dim tensor with x's dtype and device
        lo = torch.min(zero, torch.min(x))
        hi = torch.max(zero, torch.max(x))
        return fakequantn(x, lo, hi, self._nbits)

    @property
    def nbits(self):
        """int: number of quantization bits."""
        return self._nbits

class FakeQuant8Conv2d(FakeQuantNConv2d):
    """
        FakeQuantN Conv2d module with the bit width fixed to 8.
    """
    def __init__(self, *args, **kwargs):
        """
        :param list args: positional args forwarded to :class:`FakeQuantNConv2d`.
        :param dict[str, any] kwargs: keyword args forwarded to :class:`FakeQuantNConv2d`;
            passing ``nbits`` here raises ``TypeError`` (it is supplied as 8).
        """
        bit_width = 8
        super().__init__(*args, nbits=bit_width, **kwargs)
