from typing import Optional
import torch
from torch.autograd.function import InplaceFunction
import torch.nn as nn
from torch import Tensor


def get_rowwise_qparams(max_val, min_val, num_bits, signed, eps, symmetric, dep=False):
    """Derive affine quantization parameters from per-element value ranges.

    Args:
        max_val: Tensor of upper range bounds (cast to float32 internally).
        min_val: Tensor of lower range bounds (cast to float32 internally).
        num_bits: Bit width of the integer grid.
        signed: If true the grid is ``[-2**(num_bits-1), 2**(num_bits-1) - 1]``,
            otherwise ``[0, 2**num_bits - 1]``.
        eps: Lower bound applied to the scale so it never reaches zero.
        symmetric: Use a range centred on zero instead of ``[min_val, max_val]``.
        dep: Unused; accepted only so existing callers keep working.

    Returns:
        ``(qmin, qmax, zero_point, scale)``. ``qmin`` and ``qmax`` are Python
        floats and ``scale`` is a tensor. ``zero_point`` is a Python float in
        symmetric mode and a tensor otherwise.
    """
    upper = max_val.to(torch.float)
    lower = min_val.to(torch.float)

    # Zero must always be representable, so widen the range to include it.
    upper = torch.maximum(upper, torch.zeros(upper.shape, device=upper.device))
    lower = torch.minimum(lower, torch.zeros(lower.shape, device=lower.device))

    if signed:
        qmin = -(2.0 ** (num_bits - 1))
    else:
        qmin = 0.0
    qmax = qmin + 2.0 ** num_bits - 1

    # Tensor-valued copies of the scalar bounds, all placed on max_val's device.
    target_device = upper.device
    eps_full = torch.ones(lower.shape, device=target_device) * eps
    qmin_full = torch.ones(lower.shape, device=target_device) * qmin
    qmax_full = torch.ones(upper.shape, device=target_device) * qmax
    num_steps = (qmax_full - qmin_full).to(torch.float)

    # Width of the real-valued interval that is mapped onto the integer grid.
    if symmetric:
        span = 2 * torch.maximum(torch.abs(lower), upper)
    else:
        span = upper - lower
    scale = torch.max(span / num_steps, eps_full)

    if symmetric:
        # Fixed offset: zero for signed grids, mid-grid for unsigned ones.
        zero_point = 0.0 if signed else (2.0 ** (num_bits - 1))
    else:
        # Offset that maps `lower` onto qmin, kept inside the integer grid.
        raw_zero_point = qmin_full - torch.round(lower / scale)
        zero_point = torch.minimum(
            qmax_full, torch.maximum(qmin_full, raw_zero_point)
        )

    return qmin, qmax, zero_point, scale


class Quantize(InplaceFunction):
    """Fake quantization (quantize, then dequantize) with a custom backward.

    The forward pass maps ``input`` onto an integer grid derived from
    ``max_val`` / ``min_val`` and maps it straight back to real values. The
    backward pass is either a straight-through estimator (``ste=True``) or a
    clipped pass-through that blocks the gradient only for elements whose
    integer code fell outside ``[qmin, qmax]``.
    """

    @staticmethod
    def forward(
        ctx, input, max_val, min_val, num_bits, signed, eps, symmetric, ste, dep
    ):
        """Fake-quantize ``input``.

        Args:
            ctx: Autograd context supplied by ``Quantize.apply``.
            input: Tensor to quantize; it is cloned, never modified.
            max_val: Upper range bounds. They are combined element-wise with
                ``input``, so they must broadcast against it (``TableQuantizer``
                passes them already expanded to ``input``'s shape).
            min_val: Lower range bounds, same expectations as ``max_val``.
            num_bits: Bit width of the integer grid.
            signed: Whether the integer grid is signed.
            eps: Lower bound for the scale.
            symmetric: Whether to use a zero-centred range.
            ste: If true, backward passes the gradient through unchanged.
            dep: Forwarded to ``get_rowwise_qparams``.

        Returns:
            ``(output, (qmin, qmax, scale, zero_point))`` where ``output`` is
            the dequantized tensor.
        """
        output = input.clone()

        # compute qparams
        qmin, qmax, zero_point, scale = get_rowwise_qparams(
            max_val, min_val, num_bits, signed, eps, symmetric, dep
        )

        # The clipped backward needs the input and the qparams; STE needs nothing.
        ctx.STE = ste
        if not ste:
            ctx.save_for_backward(input)
            ctx.qmin = qmin
            ctx.qmax = qmax
            ctx.scale = scale
            ctx.zp = zero_point

        inv_scale = 1.0 / scale

        output.mul_(inv_scale).add_(zero_point)
        output.round_().clamp_(qmin, qmax)  # quantize
        output.add_(-zero_point).mul_(scale)  # dequantize

        return output, (qmin, qmax, scale, zero_point)

    @staticmethod
    def backward(ctx, grad_output, grad_params):
        """Return the gradient w.r.t. ``input`` and ``None`` for the other args.

        ``grad_params`` belongs to the non-tensor qparams tuple returned by
        ``forward`` and is ignored.
        """
        if ctx.STE:
            return grad_output, None, None, None, None, None, None, None, None

        # Gradient clipping as in PyTorch's fake-quantize kernels:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
        (input,) = ctx.saved_tensors

        # Recompute the unclamped integer codes with the same ops as forward.
        codes = input.clone()
        inv_scale = 1.0 / ctx.scale
        codes.mul_(inv_scale).add_(ctx.zp).round_()

        # Only codes strictly outside [qmin, qmax] were clamped in forward.
        # Codes equal to qmin or qmax are in range and keep their gradient.
        clipped = codes.gt(ctx.qmax) | codes.lt(ctx.qmin)
        grad_input = grad_output.masked_fill(clipped, 0)

        return grad_input, None, None, None, None, None, None, None, None


# Functional entry point: takes the arguments of Quantize.forward (without ctx)
# and returns (output, (qmin, qmax, scale, zero_point)).
quantize = Quantize.apply

class TableQuantizer(nn.Module):
    """Fake-quantize a 2-D table row by row.

    The module stores only the quantization configuration. The per-row ranges
    are passed in on every call and are maintained by the caller.
    """

    def __init__(
        self,
        num_bits: int,
        signed: bool,
        use_ste: bool = False,
        symmetric: bool = False,
    ):
        super().__init__()
        self.num_bits = num_bits
        self.signed = signed
        self.symmetric = symmetric
        self.ste = use_ste
        # Smallest scale the quantizer is allowed to produce.
        self.eps = torch.finfo(torch.float32).eps

    def forward(self, input, max_val, min_val, dependency=False):
        """Quantize ``input`` using one ``(min, max)`` range per row.

        Args:
            input: Tensor with at least two dimensions; rows are dimension 0.
            max_val: Per-row upper bounds. Expected to be 1-D with one entry
                per row of ``input``; only the leading dimension is checked.
            min_val: Per-row lower bounds, same expectations as ``max_val``.
            dependency: Passed through to ``quantize`` as its last argument.

        Returns:
            Whatever ``quantize`` returns for the expanded ranges.
        """
        num_rows = input.shape[0]
        assert num_rows == max_val.shape[0]
        assert num_rows == min_val.shape[0]

        num_cols = input.shape[1]

        def tile_across_columns(bounds):
            # Stack `num_cols` copies, then flip so each row holds its own bound.
            return bounds.repeat([num_cols, 1]).transpose(0, 1)

        row_min = tile_across_columns(min_val)
        row_max = tile_across_columns(max_val)

        return quantize(
            input.clone(),
            row_max,
            row_min,
            self.num_bits,
            self.signed,
            self.eps,
            self.symmetric,
            self.ste,
            dependency,
        )
