import torch.nn as nn
from .fakequantn import FakeQuantN

class FakeQuantNWrap(nn.Module):
    """
    Utility module that runs a wrapped module and passes its full-precision
    output through a ``FakeQuantN`` layer, producing the quantized version of
    that output.
    """
    def __init__(self, module, nbits, **quant_kwargs):
        """
        :param torch.nn.Module module: Module whose output should be quantized.
        :param int nbits: Number of bits used to represent output values.
        :param quant_kwargs: Extra keyword arguments forwarded unchanged to
            ``FakeQuantN``.
        """
        super().__init__()

        # The attribute names ``module`` and ``quantn`` become the submodule
        # names (and therefore state_dict key prefixes), so keep them stable.
        self.module = module
        self.quantn = FakeQuantN(nbits=nbits, **quant_kwargs)

    def forward(self, *inputs, **named_inputs):
        """
        Call the wrapped module with all given arguments and quantize its result.

        :return: ``self.quantn`` applied to the wrapped module's output.
        """
        full_precision = self.module(*inputs, **named_inputs)
        return self.quantn(full_precision)

class FakeQuant8Wrap(FakeQuantNWrap):
    """
    ``FakeQuantNWrap`` with the bit width fixed to 8.

    Every other argument (the wrapped module and any extra keyword options)
    is passed straight through to ``FakeQuantNWrap``. Supplying ``nbits``
    explicitly is an error, because this class always sets it.
    """
    def __init__(self, *positional, **options):
        super().__init__(*positional, nbits=8, **options)
