import math
import torch
import torch.nn as nn
from functools import partial

class ScaledActivation(nn.Module):
    """Apply an activation module, then divide its output by per-channel scales.

    Args:
        module: Activation module applied to the input first.
        scales: Tensor of scales. It is flattened and broadcast against the
            last dimension of the activation output.
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        # Built from ``scales.data``: shares storage with the given tensor but
        # carries no autograd history from it.
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        # A flat (channels,) view broadcasts over the last dimension for any
        # input rank; a (1, 1, channels) view would add leading dims to 1-D/2-D
        # inputs and change the output shape.
        return self.act(x) / self.scales.view(-1).to(x.device)
    
class WQLinear(nn.Module):
    """Linear layer that computes with fake-quantized weights and activations.

    The weight is held in the ``qweight`` buffer (filled by :meth:`from_linear`
    with an already quantized-and-dequantized copy of a float weight).
    Activations are passed through ``quantize_activation_per_token_absmax``
    on every forward pass before a regular ``linear`` call.

    Args:
        w_bit: Weight bit width (used by :meth:`from_linear`).
        a_bit: Activation bit width forwarded to the activation quantizer.
        group_size: Number of input features sharing one weight scale; ``-1``
            means the whole row (``in_features``).
        in_features: Size of the input feature dimension.
        out_features: Size of the output feature dimension.
        bias: Whether to allocate a bias buffer.
        dev: Device on which the placeholder buffers are allocated.
        alpha: Exponent forwarded to the activation quantizer.

    Raises:
        ValueError: if the resolved ``group_size`` does not divide ``in_features``.
    """

    def __init__(self, w_bit, a_bit, group_size, in_features, out_features, bias, dev, alpha=0.15):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.a_bit = a_bit
        self.alpha = alpha
        self.group_size = group_size if group_size != -1 else in_features

        if self.in_features % self.group_size != 0:
            raise ValueError(
                f"in_features ({self.in_features}) must be divisible by "
                f"group_size ({self.group_size})"
            )

        # Bind the activation bit width and alpha once so forward() is a plain call.
        self.act_quant = partial(
            quantize_activation_per_token_absmax, n_bits=a_bit, alpha=alpha
        )
        # Placeholder; from_linear() overwrites it with the quantized weight.
        self.register_buffer(
            "qweight",
            torch.zeros(
                (out_features, in_features),
                dtype=torch.float16,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear, w_bit, a_bit, group_size, alpha=0.15):
        """Build a ``WQLinear`` from a float ``nn.Linear``.

        The bias is copied and the weight is fake-quantized with ``w_bit`` bits
        and the layer's resolved group size.
        """
        awq_linear = cls(
            w_bit,
            a_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            alpha=alpha,
        )
        if linear.bias is not None:
            # detach() so the stored buffer does not keep the Parameter's autograd graph.
            awq_linear.bias = linear.bias.detach().clone()
        # Pass the resolved group size (``-1`` was already mapped to
        # ``in_features`` in __init__) so the weight quantizer always receives
        # a concrete group length.
        awq_linear.qweight = quantize_weight_per_channel_absmax(
            linear.weight, n_bits=w_bit, group_size=awq_linear.group_size
        )

        return awq_linear

    @torch.no_grad()
    def forward(self, x):
        """Fake-quantize ``x`` and apply the linear map with ``qweight``/``bias``."""
        inputs = self.act_quant(x)
        return nn.functional.linear(inputs, self.qweight, self.bias)

@torch.no_grad()
def quantize_activation_per_token_absmax(t, n_bits=8, alpha=None):
    """Fake-quantize activations using a mixed per-token / per-channel absmax scale.

    The quantization step for element ``t[..., i, j]`` is
    ``row_max[..., i] ** alpha / q_max * col_max[..., j] ** (1 - alpha)``:

    - ``row_max`` is the absolute maximum over the last dimension.
    - ``col_max`` is the absolute maximum over the second-to-last dimension,
      taken separately for every leading index.

    Values are divided by the step, clamped to ``[-q_max, q_max]``, rounded,
    and scaled back. The result therefore has the shape and dtype of ``t``.

    Args:
        t: Tensor with at least two dimensions. The reductions treat dim -1 as
            channels and dim -2 as tokens.
        n_bits: Bit width of the signed grid, ``q_max = 2 ** (n_bits - 1) - 1``.
        alpha: Exponent splitting the scale between the per-token maximum
            (``alpha``) and the per-channel maximum (``1 - alpha``).
            ``alpha=1`` reduces to plain per-token absmax quantization.

    Returns:
        A new tensor of fake-quantized values. ``t`` itself is not modified.

    Raises:
        ValueError: if ``alpha`` is not provided.
    """
    if alpha is None:
        raise ValueError("alpha must be provided")
    q_max = 2 ** (n_bits - 1) - 1
    abs_t = t.abs()
    # keepdim=True keeps the reduced axis, so per-channel maxima broadcast back
    # onto their own leading index. Otherwise a batched 3-D input would line up
    # the batch axis with the token axis.
    c_m = abs_t.max(dim=-2, keepdim=True)[0].pow(1 - alpha).clamp(min=1e-5)
    r_m = abs_t.max(dim=-1, keepdim=True)[0].pow(alpha).div(q_max).clamp(min=1e-5)
    # Out-of-place ops: the caller's activation tensor (which may feed other
    # layers) stays untouched.
    return t.div(r_m).div(c_m).clamp(min=-q_max, max=q_max).round().mul(c_m).mul(r_m)

@torch.no_grad()
def quantize_weight_per_channel_absmax(w, n_bits=4, group_size=None):
    """Fake-quantize a weight tensor using symmetric absmax scales.

    Each row of ``w`` gets its own scale ``absmax / q_max``. When
    ``group_size`` is given, each run of ``group_size`` consecutive elements
    along the last dimension gets its own scale instead. Values are divided
    by the scale, clamped to ``[-q_max, q_max]``, rounded, and multiplied
    back. The result has the same shape and dtype as ``w``.

    Args:
        w: Weight tensor. Groups are formed along its last dimension.
        n_bits: Bit width of the signed grid, ``q_max = 2 ** (n_bits - 1) - 1``.
        group_size: Number of elements sharing one scale. ``None`` or ``-1``
            means one scale per row (the whole last dimension).

    Returns:
        A new tensor holding the fake-quantized weights. ``w`` is not modified.

    Raises:
        ValueError: if ``group_size`` is not positive or does not divide the
            last dimension of ``w``.
    """
    w_shape = w.shape
    if group_size is None or group_size == -1:
        group_size = w_shape[-1]
    # Require groups to tile the last dim exactly, so no group straddles two rows.
    if group_size <= 0 or w_shape[-1] % group_size != 0:
        raise ValueError(
            f"group_size ({group_size}) must be positive and divide the last "
            f"dimension of the weight ({w_shape[-1]})"
        )
    q_max = 2 ** (n_bits - 1) - 1
    # reshape (unlike view) also handles non-contiguous weights.
    groups = w.reshape(-1, group_size)
    scales = groups.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5).div(q_max)
    quantized = groups.div(scales).clamp(min=-q_max, max=q_max).round().mul(scales)
    return quantized.reshape(w_shape)
