import math
import os
from typing import Optional, Tuple

import torch
from torch import nn
from torch.utils.cpp_extension import load
from typing_extensions import Annotated, Doc

# JIT-build (or reuse the cached build of) the CUDA extension that provides
# ``packbits_cuda`` / ``unpackbits_cuda`` for Int3Linear.pack / .unpack.
# NOTE: this runs at import time.  The .cu source is looked up next to this
# module first, so importing works from any working directory; the original
# CWD-relative path is kept as a fallback for existing setups.
_PACK_CUDA_SRC = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "3bit_pack.cu"
)
if not os.path.isfile(_PACK_CUDA_SRC):
    _PACK_CUDA_SRC = "./3bit_pack.cu"

pack_cuda = load(
    name="cuda_packing",
    sources=[_PACK_CUDA_SRC],
    verbose=True,
)


class Int3Linear(nn.Module):
    """Linear layer whose integer weights and zero points are stored at 3 bits.

    The packed weights/zeros and the ``scales`` are registered as buffers.
    ``forward`` unpacks them, dequantizes group-wise and computes ``x @ W.T``.

    Two storage formats exist:

    * ``naive=False`` (default): ``pack`` / ``unpack`` via the compiled
      ``pack_cuda`` extension (layout defined in ``3bit_pack.cu``, not
      visible here).
    * ``naive=True``: ``naive_pack`` / ``naive_unpack`` in pure PyTorch, ten
      values per int32 word; the original column counts are kept in the plain
      attributes ``n_cols`` / ``n_cols_zeros`` (not buffers).

    NOTE(review): values to pack are assumed to lie in 0..7.  The pure-PyTorch
    packers do not mask higher bits, so out-of-range values would corrupt the
    neighbouring 3-bit slots.
    """

    def __init__(
        self,
        weights: torch.Tensor,
        scales: torch.Tensor,
        zeros: torch.Tensor,
        zero_point: bool,
        naive: bool = False,
    ) -> None:
        """Pack ``weights`` and ``zeros`` and register them as buffers.

        Args:
            weights: Integer-valued weight matrix; presumably
                ``(out_features, in_features)`` since ``forward`` computes
                ``x @ W.T`` — TODO confirm.
            scales: Group-wise scales; each column is repeated along dim 1 to
                cover ``weights.shape[-1] // scales.shape[-1]`` weight columns.
            zeros: Integer-valued zero points; presumably shaped like
                ``scales`` since the same repeat count is applied to both.
            zero_point: If True, zero points are subtracted before scaling.
            naive: Use the pure-PyTorch naive packing format.
        """
        super().__init__()

        self.naive = naive
        if self.naive:
            # naive_unpack needs the original width because the last int32
            # column may be only partially filled.
            pweights, self.n_cols = self.naive_pack(weights)
            self.register_buffer("weights", pweights)
            # Zeros must use the same format as the weights: forward()
            # unpacks both with naive_unpack.
            pzeros, self.n_cols_zeros = self.naive_pack(zeros)
            self.register_buffer("zeros", pzeros)
        else:
            self.register_buffer("weights", self.pack(weights))
            self.register_buffer("zeros", self.pack(zeros))

        self.register_buffer("scales", scales)
        self.zero_point = zero_point

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dequantize the packed weights and return ``x @ W.T``.

        ``x`` is cast to float16 for the matmul and the result is cast back
        to ``x``'s original dtype.  A 2-D result gets a leading dim of size 1.
        """
        input_dtype = x.dtype
        if input_dtype != torch.float16:
            x = x.half()

        # Unpack with the same format that __init__ used to pack.
        if self.naive:
            iweights = self.naive_unpack(self.weights, self.n_cols)
            izeros = self.naive_unpack(self.zeros, self.n_cols_zeros)
        else:
            iweights = self.unpack(self.weights)
            izeros = self.unpack(self.zeros)

        dequant = self.pseudo_dequantize(iweights, self.scales, izeros)
        out = (x @ dequant.T).to(dtype=input_dtype)

        # NOTE(review): presumably callers expect a batch dim — TODO confirm.
        if out.dim() == 2:
            out = out.unsqueeze(0)

        return out

    def pseudo_dequantize(
        self,
        w: torch.Tensor,
        scales: torch.Tensor,
        zeros: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Map integer-valued ``w`` back to real values with group-wise params.

        Each column of ``scales`` (and ``zeros``) is repeated along dim 1 to
        cover ``w.shape[-1] // scales.shape[-1]`` consecutive columns of ``w``.

        Returns:
            ``(w - zeros) * scales`` if ``self.zero_point`` is set (``zeros``
            is then required), otherwise ``w * scales``.
        """
        group_size = w.shape[-1] // scales.shape[-1]
        scales = scales.repeat_interleave(group_size, dim=1)

        if self.zero_point:
            zeros = zeros.repeat_interleave(group_size, dim=1)
            return (w - zeros) * scales
        return w * scales

    def pack(self, x: torch.Tensor, cuda=True) -> torch.Tensor:
        """Pack 3-bit values into int32 words.

        With ``cuda=True`` (default) this delegates to
        ``pack_cuda.packbits_cuda`` after casting ``x`` to uint8.

        PyTorch fallback: ``x`` is transposed and packed along its original
        dim 0, whose size must be a multiple of 32 (32 values -> 3 words).
        With ``W`` output words and ``n_full = 10 * W``:

        * column ``c < n_full`` goes to word ``c // 10``, bits
          ``3 * (c % 10)`` .. ``3 * (c % 10) + 2``;
        * the remaining columns are taken in pairs ``(a, b)``; pair ``k`` is
          spread over the top bits of words ``3k``, ``3k+1``, ``3k+2``:
          a[2:1] -> w0[31:30], a[0] -> w1[31], b[2] -> w1[30],
          b[1:0] -> w2[31:30].
        """
        if cuda:
            return pack_cuda.packbits_cuda(x.to(torch.uint8))

        x = x.t()
        # Every 32 values fill exactly three int32 words.
        assert x.shape[1] % 32 == 0, "packed dimension must be a multiple of 32"

        x = x.to(torch.int8)

        # 10 whole 3-bit values fit in the low 30 bits of an int32.
        pack_size = 32 // 3

        # Exact integer arithmetic (x.shape[1] is a multiple of 32) instead
        # of dividing by the inexact float 32 / 3.
        n_words = x.shape[1] * 3 // 32
        packed = torch.zeros(
            (x.shape[0], n_words), dtype=torch.int32, device=x.device
        )
        # Number of 3-bit columns stored in the low 30 bits of all words.
        n_full = n_words * pack_size
        for idx in range(pack_size):
            packed |= x[:, idx:n_full:pack_size].to(torch.int32) << (3 * idx)

        # Leftover columns go into bits 30/31; bit 31 is the int32 sign bit,
        # which unpack() reads back with a `< 0` test.
        first = x[:, n_full::2].to(torch.int32)
        second = x[:, (n_full + 1) :: 2].to(torch.int32)
        packed[:, 0::3] |= (first >> 1) << 30  # a bits 2..1 -> bits 31..30
        packed[:, 1::3] |= (first & 0x1) << 31  # a bit 0 -> bit 31
        packed[:, 1::3] |= (second >> 2) << 30  # b bit 2 -> bit 30
        packed[:, 2::3] |= (second & 0x3) << 30  # b bits 1..0 -> bits 31..30

        return packed

    def unpack(self, x: torch.Tensor, cuda=True) -> torch.Tensor:
        """Inverse of ``pack``; returns the unpacked values as float16.

        With ``cuda=True`` (default) this delegates to
        ``pack_cuda.unpackbits_cuda``.  The PyTorch fallback reverses the
        layout described in ``pack`` and transposes the result back.
        """
        if cuda:
            return pack_cuda.unpackbits_cuda(x).half()

        pack_size = 32 // 3

        # Three int32 words hold 32 values; integer arithmetic avoids the
        # float round-off of multiplying by 32 / 3.
        unpacked = torch.zeros(
            (x.shape[0], x.shape[1] * 32 // 3),
            dtype=torch.float16,
            device=x.device,
        )

        # Number of 3-bit columns stored in the low 30 bits of all words.
        n_full = x.shape[1] * pack_size

        for idx in range(pack_size):
            unpacked[:, idx:n_full:pack_size] = ((x >> (3 * idx)) & 0x7).half()

        word0 = x[:, 0::3]
        word1 = x[:, 1::3]
        word2 = x[:, 2::3]
        # a = w0[31] w0[30] w1[31]  (bit 31 is the sign bit, hence `< 0`)
        unpacked[:, n_full::2] = (
            ((word0 < 0).to(torch.int32) << 2)
            | ((word0 & 0x40000000) >> 29)
            | (word1 < 0).to(torch.int32)
        ).half()
        # b = w1[30] w2[31] w2[30]
        unpacked[:, (n_full + 1) :: 2] = (
            ((word1 >> 28) & 0x4)
            | ((word2 < 0).to(torch.int32) << 1)
            | ((word2 & 0x40000000) >> 30)
        ).half()

        return unpacked.t()

    def naive_pack(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pack 3-bit values along dim 1, ten per int32 word.

        Column ``c`` goes to word ``c // 10`` at bits ``3 * (c % 10)``; the top
        two bits of every word stay zero.

        Returns:
            ``(packed, n_cols)`` where ``n_cols`` is ``x.shape[1]``, needed by
            ``naive_unpack`` because the last word may be partially filled.
        """
        x = x.to(torch.int8)
        n_cols = x.shape[1]

        # 10 whole 3-bit values fit in an int32.
        pack_size = 32 // 3

        packed = torch.zeros(
            (x.shape[0], (n_cols + pack_size - 1) // pack_size),
            dtype=torch.int32,
            device=x.device,
        )
        for idx in range(min(pack_size, n_cols)):
            cols = torch.arange(idx, n_cols, pack_size, device=x.device)
            packed[:, : len(cols)] |= x[:, cols].to(torch.int32) << (3 * idx)

        return packed, n_cols

    def naive_unpack(
        self,
        x: torch.Tensor,
        n_cols: Annotated[int, Doc("Number of columns in original matrix")],
    ) -> torch.Tensor:
        """Inverse of ``naive_pack``; returns ``(x.shape[0], n_cols)`` float16."""
        pack_size = 32 // 3

        unpacked = torch.zeros(
            (x.shape[0], n_cols), dtype=torch.float16, device=x.device
        )

        # Slot `idx` of word k holds original column idx + 10 * k.
        for idx in range(min(pack_size, n_cols)):
            cols = torch.arange(idx, n_cols, pack_size, device=x.device)
            unpacked[:, cols] = ((x[:, : len(cols)] >> (3 * idx)) & 0x7).half()

        return unpacked
