import torch
from torch.autograd import Function


class FakeQuantNF(Function):
    """
    This :class:`torch.autograd.Function` implements n-bit quantization following
    the TensorFlow Lite quantization scheme.

    The forward pass clamps the inputs to a "nudged" ``[min, max]`` range and snaps
    them onto a uniform grid of ``2**nbits`` levels. The backward pass is the
    Straight-Through Estimator: the gradient is passed through unchanged for inputs
    inside the nudged range and zeroed for inputs outside of it.
    """

    @staticmethod
    def forward(ctx, inputs, min_, max_, nbits):
        """
        Quantize ``inputs`` onto ``2**nbits`` levels and return the de-quantized floats.

        :param ctx: Autograd context used to stash tensors for the backward pass.
        :param torch.Tensor inputs: Inputs.
        :param torch.Tensor min_: Represent inputs' minimum value statistic.
        :param torch.Tensor max_: Represent inputs' maximum value statistic.
        :param int nbits: Number of bits used for representing values.
        :return: Tensor holding the fake-quantized values.

        .. note:: ``max_`` must differ from ``min_``; otherwise the scaling factor
           is zero and the computation below divides by zero.
        """
        # As we want 0 to be exactly representable, we need to find an interval
        # where 0 in float is an exact integer. We therefore adapt the [min, max] range
        # in a way that 0 is exactly representable. However, due to numerical error, it
        # might be mapped back to a tiny non-zero value such as 1e-8; this is acceptable.
        levels = 2**nbits - 1.0  # index of the highest quantized level
        s = (max_ - min_) / levels  # initial scaling factor
        zero_point = torch.round(-min_ / s)  # initial 'true' zero quantized
        nudged_min = -zero_point * s  # min quantized value (0) shifted by zero quantized
        nudged_max = (levels - zero_point) * s  # max quantized value shifted by zero quantized
        s = (nudged_max - nudged_min) / levels  # new scaling factor (based on new min and max)
        clamped = torch.min(torch.max(inputs, nudged_min), nudged_max)
        # The backward pass only needs to know which inputs fell outside the nudged range.
        ctx.save_for_backward(inputs, nudged_min, nudged_max)
        nudged_zero_point = torch.round(-nudged_min / s)  # new 'true' zero quantized
        return (torch.round((clamped - nudged_min) / s) - nudged_zero_point) * s

    @staticmethod
    def backward(ctx, grad_output):
        """
        Straight-Through Estimator.

        :param ctx: Autograd context holding the tensors saved by :meth:`forward`.
        :param torch.Tensor grad_output: Gradient with respect to the forward output.
        :return: Gradient with respect to ``inputs`` followed by ``None`` for
            ``min_``, ``max_`` and ``nbits`` (no gradient is propagated to them).
        """
        inputs, nudged_min, nudged_max = ctx.saved_tensors
        out_of_range = (inputs < nudged_min) | (inputs > nudged_max)
        # Build a new tensor instead of writing into ``grad_output``: the incoming
        # gradient may be shared with other consumers and must not be modified in place.
        grad_input = grad_output.masked_fill(out_of_range, 0)
        return grad_input, None, None, None


# Functional alias: ``fakequantn(inputs, min_, max_, nbits)`` calls ``FakeQuantNF.apply``
# with the same arguments, so callers do not have to reference the Function class.
fakequantn = FakeQuantNF.apply
