import torch
import torch.nn as nn
import copy


def quant_rtn(w,
              bits=2,
              pertensor=False,
              group_size=-1,
              fake_quant=True
              ):
    """Round-to-nearest asymmetric quantization of a weight tensor.

    The min/max range of ``w`` is mapped onto the signed integer grid
    ``[-2**(bits-1), 2**(bits-1) - 1]`` using a scale and a zero-point.

    Args:
        w: Tensor to quantize.
        bits: Width of the integer grid.
        pertensor: If True, use a single scale/zero-point for the whole
            tensor (``group_size`` is ignored in that case). Otherwise one
            scale/zero-point is computed per row along the last dimension.
        group_size: If positive (and ``pertensor`` is False), ``w`` is first
            reshaped to ``(-1, group_size)`` so that every group gets its own
            scale/zero-point.
        fake_quant: If True, return the dequantized tensor in the original
            shape. If False, return the integer codes together with the
            scales and zero-points.

    Returns:
        ``fake_quant=True``: tensor with the shape of ``w`` holding the
        dequantized values ``(q - zeros) * scales``.
        ``fake_quant=False``: tuple ``(q, scales, zeros)``. ``q`` has the
        shape of ``w``; ``scales`` and ``zeros`` are viewed as
        ``(w.shape[0], -1)``, or have shape ``(1, 1)`` when ``pertensor``
        is True.
    """
    qmax = 2 ** (bits - 1) - 1
    qmin = -(2 ** (bits - 1))
    wshape = w.shape

    if pertensor:
        # A full reduction returns a single 0-dim tensor, not a
        # (values, indices) pair, so it must not be tuple-unpacked.
        w_min = w.min()
        w_max = w.max()
    else:
        if group_size > 0:
            w = w.reshape(-1, group_size)
        w_min = w.min(dim=-1, keepdim=True).values
        w_max = w.max(dim=-1, keepdim=True).values

    scales = (w_max - w_min) / (qmax - qmin)
    # The epsilon keeps the division finite when a row/tensor is constant
    # (scale == 0). Dequantization below uses the un-padded scale.
    step = scales + 1e-6
    zeros = torch.clamp(torch.round(qmin - w_min / step), min=qmin, max=qmax)
    w = torch.clamp(torch.round(w / step + zeros), min=qmin, max=qmax)

    if fake_quant:
        return ((w - zeros) * scales).reshape(wshape)

    w = w.reshape(wshape)
    if pertensor:
        # One scalar pair for the whole tensor; keep it 2-D like the
        # per-row case so callers can broadcast it the same way.
        return w, scales.reshape(1, 1), zeros.reshape(1, 1)
    return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)


def sign(x):
    """Return a binary sign of ``x``: +1 where ``x >= 0``, -1 elsewhere.

    Unlike ``torch.sign``, zeros map to +1, so the result never contains 0.
    The result is cast back to the dtype of ``x``.
    """
    non_negative = x >= 0
    plus_minus_one = torch.where(non_negative, 1.0, -1.0)
    return plus_minus_one.to(x.dtype)

def quant_residual(w,
                   bits=2,
                   pertensor=False,
                   group_size=-1,
                   ):
    """Approximate ``w`` as a sum of ``bits`` binary (+/-alpha) components.

    At every step the current residual is replaced by ``+alpha`` where it is
    non-negative and ``-alpha`` elsewhere, with ``alpha`` the mean absolute
    value of the residual. That component is added to the result and
    subtracted from the residual.

    Args:
        w: Tensor to quantize. It is not modified.
        bits: Number of binary components to accumulate.
        pertensor: If True, ``alpha`` is a single mean over the whole tensor
            (``group_size`` is ignored). Otherwise ``alpha`` is computed per
            row along the last dimension.
        group_size: If positive (and ``pertensor`` is False), ``w`` is first
            reshaped to ``(-1, group_size)`` so each group has its own alpha.

    Returns:
        Tensor with the shape of ``w`` holding the accumulated components.
        It is computed on a detached copy and carries no autograd graph.
    """
    wshape = w.shape
    if not pertensor and group_size > 0:
        w = w.reshape(-1, group_size)

    # Private, graph-free working copy: the in-place updates below must not
    # touch the caller's tensor, and detach().clone() (unlike copy.deepcopy)
    # also works for non-leaf tensors that require grad.
    residual = w.detach().clone()
    W = torch.zeros_like(residual)
    for _ in range(bits):
        magnitude = residual.abs()
        if pertensor:
            alpha = magnitude.mean()
        else:
            alpha = magnitude.mean(dim=-1, keepdim=True)
        # Binary code in both modes: exact zeros are assigned +alpha rather
        # than 0, so every element is one of exactly two levels per step.
        q = torch.where(residual >= 0, alpha, -alpha)
        residual -= q
        W += q
    return W.reshape(wshape)


class residual_quantize(torch.autograd.Function):
    """Autograd wrapper around ``quant_residual`` with a pass-through gradient.

    The forward pass returns the residual-quantized tensor; the backward pass
    hands the incoming gradient to ``input`` unchanged.
    """

    @staticmethod
    def forward(ctx, input, bit, groupsize):
        # Nothing is stored on ctx: backward does not need the forward values.
        quantized = quant_residual(input, bits=bit, group_size=groupsize)
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward argument; bit and groupsize get no gradient.
        passthrough = (grad_output, None, None)
        return passthrough


class rtn_quantize(torch.autograd.Function):
    """Autograd wrapper around ``quant_rtn`` with a pass-through gradient.

    The forward pass returns the fake-quantized (dequantized) tensor produced
    by ``quant_rtn``; the backward pass hands the incoming gradient to
    ``input`` unchanged.
    """

    @staticmethod
    def forward(ctx, input, bit, groupsize):
        # Nothing is stored on ctx: backward does not need the forward values.
        quantized = quant_rtn(input, bits=bit, group_size=groupsize)
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward argument; bit and groupsize get no gradient.
        passthrough = (grad_output, None, None)
        return passthrough


class QLinear(nn.Linear):
    """``nn.Linear`` whose weight is quantized on the fly in every forward.

    Args:
        mode: ``'residual'`` or ``'rtn'`` selects the quantizer; any other
            value leaves the weight untouched.
        bit: Bit width handed to the quantizer.
        groupsize: Group size handed to the quantizer.
        *args, **kwargs: Forwarded to ``nn.Linear.__init__``.
    """

    def __init__(self, mode='residual', bit=2, groupsize=-1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quant_mode = None
        self.groupsize = groupsize
        self.bit = bit
        self.mode = mode

    def forward(self, x):
        """Apply the linear map using the (optionally quantized) weight."""
        weight = self.weight
        if self.mode == 'residual':
            weight = residual_quantize.apply(weight, self.bit, self.groupsize)
        elif self.mode == 'rtn':
            weight = rtn_quantize.apply(weight, self.bit, self.groupsize)
        return nn.functional.linear(x, weight, self.bias)


def replace_linear(model, mode, bit=2, groupsize=-1, info=False):
    """Recursively swap every ``nn.Linear`` child of ``model`` for a ``QLinear``.

    The replacement layer reuses the original layer's weight (and bias, if
    present) parameter objects rather than copying them. ``model`` is
    modified in place and nothing is returned.

    Args:
        model: Module whose children are traversed.
        mode: Quantization mode passed to ``QLinear``.
        bit: Bit width passed to ``QLinear``.
        groupsize: Group size passed to ``QLinear``.
        info: If True, print the name of every layer being replaced.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, nn.Linear):
            # Only descend into containers that actually have sub-modules.
            if list(child.children()):
                replace_linear(child, mode, bit, groupsize, info)
            continue

        if info:
            print(f'replacing...{child_name}') 

        has_bias = child.bias is not None
        quantized = QLinear(
            mode=mode,
            bit=bit,
            groupsize=groupsize,
            in_features=child.in_features,
            out_features=child.out_features,
            bias=has_bias,
        )

        # Share the existing parameters instead of the freshly initialised ones.
        quantized.weight = child.weight
        if has_bias:
            quantized.bias = child.bias

        setattr(model, child_name, quantized)
        torch.cuda.empty_cache()


def fake_quant(model, mode='residual', bit=2, groupsize=-1, info=False):
    """Recursively overwrite the weights of every ``nn.Linear`` in ``model``.

    Each linear layer's ``weight.data`` is replaced by its quantized version.
    ``model`` is modified in place and nothing is returned.

    Args:
        model: Module whose children are traversed.
        mode: ``'residual'`` uses ``quant_residual``, ``'rtn'`` uses
            ``quant_rtn``; any other value leaves the weights unchanged.
        bit: Bit width passed to the quantizer.
        groupsize: Group size passed to the quantizer.
        info: If True, print the name of every layer being processed.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, nn.Linear):
            # Only descend into containers that actually have sub-modules.
            if list(child.children()):
                fake_quant(child, mode, bit, groupsize, info)
            continue

        if info:
            print(f'fakequanting...{child_name}') 

        weight = child.weight
        if mode == 'residual':
            weight.data = quant_residual(weight.data, bits=bit, group_size=groupsize)
        elif mode == 'rtn':
            weight.data = quant_rtn(weight.data, bits=bit, group_size=groupsize)
        torch.cuda.empty_cache()
