from typing import Optional
import torch
from torch.autograd.function import InplaceFunction
import torch.nn as nn

def get_rowwise_qparams(max_val, min_val, num_bits, signed, eps, symmetric):
    """Compute uniform quantization parameters from per-element ranges.

    Args:
        max_val: Tensor of upper range bounds.
        min_val: Tensor of lower range bounds; expected to have the same
            shape as ``max_val``.
        num_bits: Bit width of the integer grid.
        signed: If True the grid is ``[-2**(num_bits-1), 2**(num_bits-1) - 1]``,
            otherwise ``[0, 2**num_bits - 1]``.
        eps: Smallest scale allowed; keeps ``1 / scale`` finite.
        symmetric: If True use symmetric quantization with a fixed zero
            point, otherwise affine quantization with a computed zero point.

    Returns:
        Tuple ``(qmin, qmax, zero_point, scale)``. ``qmin`` and ``qmax`` are
        Python floats and ``scale`` is a float tensor. ``zero_point`` is a
        Python float in symmetric mode and a float tensor in affine mode.
    """
    max_val = max_val.to(torch.float)
    min_val = min_val.to(torch.float)

    # Always include zero in the quantization range.
    max_val = torch.maximum(max_val, torch.zeros_like(max_val))
    min_val = torch.minimum(min_val, torch.zeros_like(min_val))

    qmin = -(2.0 ** (num_bits - 1)) if signed else 0.0
    qmax = qmin + 2.0 ** num_bits - 1

    # Tensor versions of the scalar limits, shaped like the ranges.
    eps = torch.ones_like(min_val) * eps
    _qmin = torch.ones_like(min_val) * qmin
    _qmax = torch.ones_like(max_val) * qmax

    if symmetric:
        scale = 2 * torch.max(torch.abs(min_val), max_val) / (_qmax - _qmin)
        # An all-zero range would otherwise give scale == 0, and the
        # caller's 1 / scale would turn the whole row into NaN.
        scale = torch.max(scale, eps)
        zero_point = 0.0 if signed else (2.0 ** (num_bits - 1))
    else:
        scale = (max_val - min_val) / (_qmax - _qmin).to(torch.float)
        scale = torch.max(scale, eps)
        zero_point = _qmin - torch.round(min_val / scale)
        # Keep the zero point on the integer grid.
        zero_point = torch.maximum(_qmin, zero_point)
        zero_point = torch.minimum(_qmax, zero_point)

    return qmin, qmax, zero_point, scale

class Quantize(InplaceFunction):
    """Fake quantization: quantize to an integer grid, then dequantize.

    The forward pass returns a tensor of the same shape as ``input`` whose
    values have been snapped to the grid described by the quantization
    parameters. The backward pass is either a straight-through estimator
    (gradient passed unchanged) or a clipped pass-through that zeroes the
    gradient of elements that fell outside ``[qmin, qmax]``.
    """

    @classmethod
    def forward(
        cls, ctx, input, max_val, min_val, num_bits, signed, eps, symmetric, ste
    ):
        """Quantize and dequantize ``input``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            input: Tensor to fake-quantize; it is cloned, not modified.
            max_val: Upper range bounds forwarded to ``get_rowwise_qparams``.
            min_val: Lower range bounds forwarded to ``get_rowwise_qparams``.
            num_bits: Bit width of the integer grid.
            signed: Whether the integer grid is signed.
            eps: Lower bound for the scale.
            symmetric: Symmetric (True) or affine (False) quantization.
            ste: If True, ``backward`` passes the gradient through unchanged.

        Returns:
            The fake-quantized tensor.
        """
        output = input.clone()

        # compute qparams
        qmin, qmax, zero_point, scale = get_rowwise_qparams(
            max_val, min_val, num_bits, signed, eps, symmetric,
        )

        # Only the clipped-gradient path needs the input and qparams later.
        ctx.STE = ste
        if not ste:
            ctx.save_for_backward(input)
            ctx.qmin = qmin
            ctx.qmax = qmax
            ctx.scale = scale
            ctx.zp = zero_point

        inv_scale = 1.0 / scale

        output.mul_(inv_scale).add_(zero_point)
        output.round_().clamp_(qmin, qmax)  # quantize
        output.add_(-zero_point).mul_(scale)  # dequantize

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input`` and ``None`` for the other args.

        With STE enabled the incoming gradient is returned unchanged.
        Otherwise the gradient is zeroed for elements whose unclamped
        quantized value lies strictly outside ``[qmin, qmax]``.
        """
        if ctx.STE:
            return grad_output, None, None, None, None, None, None, None

        # Gradient clipping as in PyTorch's fake-quantize kernel:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
        (input,) = ctx.saved_tensors

        # Recompute the unclamped integer level of every element.
        levels = input.clone()
        inv_scale = 1.0 / ctx.scale
        levels.mul_(inv_scale).add_(ctx.zp).round_()

        # Values that land exactly on qmin or qmax are representable and
        # were not clipped, so they keep their gradient; only strictly
        # out-of-range elements are zeroed.
        clipped = levels.gt(ctx.qmax) | levels.lt(ctx.qmin)
        grad_input = grad_output.masked_fill(clipped, 0)

        return grad_input, None, None, None, None, None, None, None


# Functional entry point: calling ``quantize(...)`` forwards the arguments
# to ``Quantize.apply``.
quantize = Quantize.apply


class RowwiseQuantizer(nn.Module):
    """Allows for rowwise or columnwise integer uniform (symmetric or asymmetric/affine) quantization.

    The module tracks one ``(min, max)`` range per row of its 2-D input
    (per column when ``columnwise`` is set) in the ``min_val`` / ``max_val``
    buffers and fake-quantizes the input with those ranges.

    Args:
        num_bits: Bit width of the integer grid; ``32`` disables
            quantization and makes ``forward`` return its input unchanged.
        signed: Whether the integer grid is signed.
        use_ste: Use a straight-through estimator in the backward pass.
        symmetric: Symmetric (True) or affine (False) quantization.
        use_momentum: Update ranges with an exponential moving average
            instead of a running min/max.
        momentum: Weight of the newest observation in the moving average.
        columnwise: Quantize per column by transposing the 2-D input.
        percentile: When ``> 0``, shrink each observed range by
            interpolating between its min and max at ``percentile`` and
            ``1 - percentile``.
    """

    def __init__(
        self,
        num_bits: int,
        signed: bool,
        use_ste: bool = False,
        symmetric: bool = True,
        use_momentum: bool = False,
        momentum: float = 0.01,
        columnwise: bool = False,
        percentile: float = -1,
    ):
        super().__init__()
        # Empty buffers mean "no range observed yet".
        self.register_buffer("min_val", torch.tensor([]))
        self.register_buffer("max_val", torch.tensor([]))
        self.columnwise = columnwise
        self.num_bits = num_bits
        self.signed = signed
        self.symmetric = symmetric
        self.eps = torch.finfo(torch.float32).eps
        self.percentile = percentile

        self.ste = use_ste
        self.momentum_min_max = use_momentum
        self.momentum = momentum

        # Whether ranges keep being refreshed outside of training mode.
        self.update = True

        # 32 bits is treated as "keep floating point".
        self.floating_point = self.num_bits == 32

    def update_ranges(self, input):
        """Refresh ``min_val`` / ``max_val`` from the rows of ``input``.

        Args:
            input: Tensor whose last dimension is reduced to get one min and
                one max per row.
        """
        current_min, _ = torch.min(input, dim=-1)  # rowwise min values
        current_max, _ = torch.max(input, dim=-1)  # rowwise max values

        if self.percentile > 0:
            ranges = torch.stack([current_min, current_max]).transpose(0, 1)
            # torch.quantile only accepts float32/float64 and requires q to
            # share the input's dtype and device, so build q to match and
            # upcast any other dtype to float32.
            if ranges.dtype in (torch.float32, torch.float64):
                compute_dtype = ranges.dtype
            else:
                compute_dtype = torch.float32
            q = torch.tensor(
                [self.percentile, 1 - self.percentile],
                dtype=compute_dtype,
                device=ranges.device,
            )
            current_min, current_max = torch.quantile(
                ranges.to(compute_dtype), q, dim=1
            )
            if input.is_floating_point():
                current_min = current_min.to(input.dtype)
                current_max = current_max.to(input.dtype)

        if self.min_val.numel() == 0 or self.max_val.numel() == 0:
            # First observation: adopt the current ranges as they are.
            new_min, new_max = current_min, current_max
        elif self.momentum_min_max:
            new_min = self.min_val + self.momentum * (current_min - self.min_val)
            new_max = self.max_val + self.momentum * (current_max - self.max_val)
        else:
            # Range update equivalent to PyTorch's MinMaxObserver
            # https://github.com/pytorch/pytorch/blob/9e5e5a7d9628f988a928969d09ff2bffe362c08c/torch/quantization/observer.py#L398
            new_min = torch.min(current_min, self.min_val)
            new_max = torch.max(current_max, self.max_val)

        self.min_val = new_min
        self.max_val = new_max

    def freeze_quantization_parameters(self):
        """Stop refreshing ranges outside of training mode.

        In training mode ``forward`` still updates the ranges regardless of
        this flag.
        """
        self.update = False

    def unfreeze_quantization_parameters(self):
        """Resume refreshing ranges on every forward call."""
        self.update = True

    def forward(self, input, *args, **kwargs):
        """Fake-quantize ``input`` with per-row (or per-column) ranges.

        Inputs with more than two dimensions are flattened to
        ``(shape[0], -1)`` for quantization and restored afterwards. Extra
        positional and keyword arguments are accepted and ignored.

        Raises:
            RuntimeError: If no range has been observed yet and range
                updates are disabled.
        """
        if self.floating_point:
            return input

        shape = input.shape if input.dim() > 2 else None
        if shape is not None:
            # reshape (not view) so non-contiguous inputs are accepted too.
            input = input.reshape(shape[0], -1)
        if self.columnwise:
            input = torch.t(input)
        if self.training or self.update:
            self.update_ranges(input.detach())

        if self.min_val.numel() == 0 or self.max_val.numel() == 0:
            raise RuntimeError(
                "RowwiseQuantizer has no quantization ranges yet; run a "
                "forward pass with range updates enabled first."
            )

        # Broadcast the per-row ranges across the columns without copying.
        out = quantize(
            input,
            self.max_val.unsqueeze(1).expand_as(input),
            self.min_val.unsqueeze(1).expand_as(input),
            self.num_bits,
            self.signed,
            self.eps,
            self.symmetric,
            self.ste,
        )

        if self.columnwise:
            out = torch.t(out)
        if shape is not None:
            out = out.reshape(shape)

        return out
        
