import torch
import torch.nn as nn
import math
import random

# The CUDA extension is optional at import time: when it cannot be imported,
# both kernel entry points are set to None and the failure is reported later,
# when a layer that needs them is constructed.
try:
    from any_precision_ext import matmul_kbit, dequant_kbit
except ImportError:
    # ImportError (incl. ModuleNotFoundError) covers a missing or unloadable
    # extension without also swallowing KeyboardInterrupt / SystemExit.
    matmul_kbit, dequant_kbit = None, None


class deq_gemm(torch.autograd.Function):
    """Dual-precision dequantize-and-matmul with a hand-written backward pass.

    The forward pass dequantizes the same packed weight twice (once per
    look-up table / bit-width) and multiplies ``x`` by the transpose of each
    result.  The dequantized weights are not kept; they are recomputed in
    the backward pass.  Only ``x`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, qweight: torch.Tensor, bl: torch.Tensor, bh: torch.Tensor, x: torch.Tensor, b_l: int, b_h: int):
        """Return ``(x @ W_l.T, x @ W_h.T)``.

        Args:
            qweight: packed quantized weight handed to ``dequant_kbit``.
            bl: look-up table used together with ``b_l``.
            bh: look-up table used together with ``b_h``.
            x: activations; its last dimension is contracted with the
                dequantized weight's last dimension.
            b_l: bit-width passed to ``dequant_kbit`` for the first output.
            b_h: bit-width passed to ``dequant_kbit`` for the second output.
        """
        # Tensors go through save_for_backward (the supported channel for
        # tensors needed in backward); plain ints live directly on ctx.
        ctx.save_for_backward(qweight, bl, bh)
        ctx.b_l = b_l
        ctx.b_h = b_h
        # Let backward see None for an output that was never used instead of
        # a materialized zero tensor, so its matmul can be skipped.
        ctx.set_materialize_grads(False)

        with torch.no_grad():
            w_low = dequant_kbit(qweight, bl, b_l)
            y_low = torch.matmul(x, w_low.T)
            # Drop each dequantized weight before building the next one so
            # that only one of them is alive at a time.
            del w_low
            w_high = dequant_kbit(qweight, bh, b_h)
            y_high = torch.matmul(x, w_high.T)
            del w_high

        return y_low, y_high

    @staticmethod
    def backward(ctx, dy3, dy4):
        """Return the gradient w.r.t. ``x``; all other inputs get ``None``.

        ``dy3`` / ``dy4`` are the gradients of the first / second forward
        output.  Either may be ``None`` when that output did not take part
        in the loss.
        """
        # Nothing to compute when x itself does not require a gradient.
        if not ctx.needs_input_grad[3]:
            return None, None, None, None, None, None

        qweight, bl, bh = ctx.saved_tensors
        dx = None
        with torch.no_grad():
            for grad_out, lut, bits in ((dy3, bl, ctx.b_l), (dy4, bh, ctx.b_h)):
                if grad_out is None:
                    continue
                # Re-dequantize rather than having stored the dense weight.
                weight = dequant_kbit(qweight, lut, bits)
                contribution = grad_out @ weight
                del weight
                dx = contribution if dx is None else dx + contribution

        return None, None, None, dx, None, None

class AnyPrecisionLinear_train_whole3456(nn.Module):
    """Linear layer over packed multi-precision weights with a learnable mix.

    The layer owns a packed integer weight buffer (``qweight``), one look-up
    table buffer per supported bit-width (``lut{bit}``) and an optional bias
    buffer.  ``forward`` evaluates the layer at two adjacent bit-widths in
    the range 3..``maxmem`` and linearly blends the two outputs; the blend
    position is ``sigmoid(self.th)``.

    ``self.th`` is not created by ``__init__``: call :meth:`create_th` before
    the first ``forward`` with ``maxmem`` in ``{4, 5, 6}``.
    """

    def __init__(self, in_features, out_features, supported_bits, bias=True, precisions=None, device=None,
                 dtype=None, th_init=0.5, maxmem=6):
        """Allocate (uninitialized) buffers and record the configuration.

        Args:
            in_features: input feature count; ``qweight`` stores
                ``in_features // 32`` int32 words per output row and bit.
            out_features: output feature count.
            supported_bits: bit-widths for which a ``lut{bit}`` buffer of
                shape ``(out_features, 2 ** bit)`` is registered.
            bias: when true, register an ``(out_features,)`` bias buffer.
            precisions: list of active bit-widths; defaults to
                ``supported_bits``.
            device: device for all buffers.
            dtype: dtype of the look-up tables and the bias.
            th_init: initial value of ``sigmoid(th)``; stored as its logit,
                so it must lie strictly between 0 and 1.
            maxmem: highest bit-width ``forward`` may blend up to (3..6).

        Raises:
            ModuleNotFoundError: the CUDA kernel extension is not available.
            RuntimeError: ``precisions`` (or the ``supported_bits`` it
                defaults to) is not a list.
        """
        super().__init__()
        if dequant_kbit is None or matmul_kbit is None:
            raise ModuleNotFoundError('Please install any precision CUDA kernel extension from modules/kernels.')
        if precisions is None:
            precisions = supported_bits
        if not isinstance(precisions, list):
            raise RuntimeError('supported_bits must be a list of integers.')
        # NOTE: dtype is intentionally not validated here (a former
        # float16-only check is disabled).

        self.dtype = dtype

        self.in_features = in_features
        self.out_features = out_features
        self.precisions = precisions
        self.precision = min(self.precisions)
        self.supported_bits = supported_bits

        # Lowest / highest active bit-width.  b_l used to be assigned the
        # whole precision list, inconsistent with b_h and with the scalars
        # stored by set_precision_dual.
        self.b_l = min(self.precisions)
        self.b_h = max(self.precisions)
        self.maxmem = maxmem

        self.register_buffer(
            'qweight',
            torch.empty((max(supported_bits), out_features, in_features // 32), dtype=torch.int32, device=device)
        )

        for bit in supported_bits:
            self.register_buffer(
                f'lut{bit}',
                torch.empty((out_features, 2 ** bit), dtype=self.dtype, device=device)
            )

        if bias:
            self.register_buffer(
                "bias",
                torch.empty((out_features,), dtype=self.dtype, device=device)
            )
        else:
            self.bias = None

        # Logit of th_init, so that sigmoid(th) starts at th_init.
        self.th_inv = math.log(th_init / (1 - th_init))
        self.device = device
        self.sigmoid = torch.nn.Sigmoid()

    def create_th(self):
        """Create the learnable scalar ``th`` on the same device as ``qweight``."""
        self.th = torch.nn.Parameter(torch.tensor(self.th_inv, device=self.qweight.device, requires_grad=True))

    def prune_precisions(self):
        """Drop weight planes and look-up tables not needed by ``precisions``."""
        self.qweight = self.qweight[:max(self.precisions)]
        for bit in self.supported_bits:
            if bit not in self.precisions:
                delattr(self, f'lut{bit}')

    def _blend(self, x, low_bits, low_weight):
        """Blend the outputs at ``low_bits`` and ``low_bits + 1`` bits.

        Returns ``y_low * low_weight + y_high * (1 - low_weight)``.
        """
        high_bits = low_bits + 1
        y_low, y_high = deq_gemm.apply(
            self.qweight,
            self._buffers[f'lut{low_bits}'],
            self._buffers[f'lut{high_bits}'],
            x,
            low_bits,
            high_bits,
        )
        return y_low * low_weight + y_high * (1 - low_weight)

    def forward(self, x, **kwargs):
        """Apply the layer to ``x``; extra keyword arguments are ignored.

        ``sigmoid(th)`` is mapped onto the bit range 3..``maxmem``; the pair
        of adjacent bit-widths that brackets the position is evaluated and
        blended.  With ``maxmem == 3`` only the 3-bit output is returned.

        Raises:
            RuntimeError: ``maxmem`` is not one of 3, 4, 5, 6.
        """
        if self.maxmem == 6:
            # p in (-1.5, 1.5): three unit-wide segments 3/4, 4/5, 5/6.
            p = self.sigmoid(self.th) * 3 - 1.5
            if p < -0.5:
                y = self._blend(x, 3, -p - 0.5)
            elif p < 0.5:
                y = self._blend(x, 4, -p + 0.5)
            else:
                y = self._blend(x, 5, -p + 1.5)
        elif self.maxmem == 5:
            # p in (-1, 1): two unit-wide segments 3/4 and 4/5.
            p = self.sigmoid(self.th) * 2 - 1
            if p < 0:
                y = self._blend(x, 3, -p)
            else:
                y = self._blend(x, 4, 1 - p)
        elif self.maxmem == 4:
            p = self.sigmoid(self.th)
            y = self._blend(x, 3, 1 - p)
        elif self.maxmem == 3:
            # No blending: keep only the 3-bit output.
            y, _ = deq_gemm.apply(self.qweight, self._buffers[f'lut{3}'], self._buffers[f'lut{4}'], x, 3, 4)
        else:
            raise RuntimeError(f"Unknown maxmem {self.maxmem}")

        if self.bias is not None:
            y += self.bias

        return y

    def set_precision(self, precision):
        """Record ``precision`` as the current single bit-width."""
        self.precision = precision

    def set_precision_dual(self, b_l, b_h):
        """Record the low (``b_l``) and high (``b_h``) bit-widths."""
        self.b_l = b_l
        self.b_h = b_h

    def extra_repr(self) -> str:
        """Return the feature sizes and bias flag shown in the module repr."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'