"""
4-bit half-integer grid: the 16 points -15/2, -13/2, ..., 15/2.
"""
import torch
from torch import nn
import quiptools_cuda

from token_sublora.nn.quip.lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda


def get_grid():
    """Build the half-integer codebook as a column tensor.

    Returns:
        A floating-point tensor of shape (16, 1) holding
        -7.5, -6.5, ..., 7.5 in ascending order.
    """
    half_integers = torch.arange(-8, 8) + 0.5
    # Add a trailing axis of size 1 so every grid point is a 1-d codeword.
    return half_integers[:, None]


# Codebook tensors built once at import time.
# Column of grid points returned by get_grid(), shape (16, 1).
_HI4B1C_CACHED = get_grid()
# Squared norm of each grid point (the diagonal of grid @ grid.T), shape (16,).
_HI4B1C_NORM_CACHED = (_HI4B1C_CACHED * _HI4B1C_CACHED).sum(dim=-1)


class HI4B1C_codebook(nn.Module):
    """Scalar 4-bit codebook over the half-integer grid -7.5, ..., 7.5.

    Each codeword is one-dimensional (``codesz == 1``) and is addressed by a
    4-bit index, so eight indices fit in a single 32-bit word (``packsz == 8``).

    Args:
        inference: when true, only the ``grid`` buffer is registered; the
            ``grid_norm`` buffer used for rounding is left out.
    """

    # (lane, shift) pairs shared by maybe_pack_idxs and by_idxs: the index in
    # column ``lane`` of every group of eight occupies the 4 bits starting at
    # bit ``shift`` of the packed word.  Keeping one table guarantees that
    # packing and unpacking always agree on the interleaved layout.
    _PACK_LAYOUT = (
        (0, 28), (2, 24), (4, 20), (6, 16),
        (1, 12), (3, 8), (5, 4), (7, 0),
    )

    def __init__(self, inference=False):
        super().__init__()
        # NOTE(review): the original file carried a disabled snippet here that
        # searched for the scale minimising the quantisation error of 200000
        # standard-normal samples with scipy.optimize.minimize_scalar;
        # opt_scale presumably comes from that search -- confirm.
        self.opt_scale = 2.97
        self.codesz = 1  # length of one codeword; enforced in round()
        self.idx_dtype = torch.int32  # dtype of indices returned by quantize()
        self.packsz = 8  # number of 4-bit indices stored per packed word
        self.pack_out = False  # not read in this class
        self.version = 0  # not read in this class

        self.register_buffer('grid', _HI4B1C_CACHED)
        if not inference:
            self.register_buffer('grid_norm', _HI4B1C_NORM_CACHED)

    def round(self, X, grid, grid_norm):
        """Snap every codeword-sized entry of ``X`` to its nearest grid point.

        Args:
            X: tensor whose last dimension equals ``self.codesz``.
            grid: tensor of grid points, one per row.
            grid_norm: squared norm of each row of ``grid``.

        Returns:
            ``(values, indices)``: the selected grid rows and their row
            indices into ``grid``.
        """
        assert X.shape[-1] == self.codesz
        # argmax of 2<x, g> - |g|^2 over g is the argmin of |x - g|^2,
        # because the |x|^2 term does not depend on g.
        Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
        return grid[Xqidx], Xqidx

    def quantize(self, X, return_idx=True, **kwargs):
        """Quantize ``X`` onto the codebook grid.

        Args:
            X: tensor whose last dimension equals ``self.codesz``.
            return_idx: when false, only the quantized values are returned.
            **kwargs: accepted and ignored.

        Returns:
            The quantized values, or ``(values, indices)`` with the indices
            cast to ``self.idx_dtype`` when ``return_idx`` is true.
        """
        # A codebook built with inference=True has no grid_norm buffer;
        # derive the norms from the grid instead of failing with an
        # AttributeError.
        grid_norm = getattr(self, 'grid_norm', None)
        if grid_norm is None:
            grid_norm = (self.grid * self.grid).sum(dim=-1)
        vals, idx = self.round(X, self.grid, grid_norm)
        if not return_idx:
            return vals
        return vals, idx.to(self.idx_dtype)

    def maybe_pack_idxs(self, idxs):
        """Pack every run of eight 4-bit indices along dim 1 into one word.

        Args:
            idxs: 2-d integer tensor of indices in [0, 15].  The column count
                is expected to be a multiple of ``self.packsz``.

        Returns:
            A tensor with one column per group of eight input columns, laid
            out according to ``_PACK_LAYOUT``.
        """
        packed = None
        for lane, shift in self._PACK_LAYOUT:
            field = idxs[:, lane::self.packsz] << shift
            packed = field if packed is None else packed + field
        return packed

    def by_idxs(self, idxs, packed=False):
        """Look up grid values for ``idxs``.

        Args:
            idxs: integer index tensor; when ``packed`` is true, a 2-d tensor
                of words as produced by ``maybe_pack_idxs``.
            packed: whether ``idxs`` must be unpacked before the lookup.

        Returns:
            ``self.grid`` indexed by the (unpacked) indices.
        """
        if packed:
            # repeat_interleave returns a new tensor, so the in-place lane
            # writes below never touch the caller's packed words.
            idxs = idxs.repeat_interleave(self.packsz, dim=-1)
            for lane, shift in self._PACK_LAYOUT:
                # Masking with 15 also discards the sign extension a right
                # shift produces when the top nibble set the sign bit.
                idxs[:, lane::self.packsz] = \
                    (idxs[:, lane::self.packsz] >> shift) & 15

        return self.grid[idxs.int()]


class QuantizedHI4B1CLinear(nn.Module):
    """Linear layer whose weight is stored as HI4B1C codebook indices.

    Args:
        device: device the float16 inference codebook is moved to.
    """

    def __init__(self, device):
        super().__init__()
        self.codebook = HI4B1C_codebook(inference=True).to(torch.float16).to(device)

    def maybe_unpack_idxs(self, idxs):
        """Return ``idxs`` wrapped in a 1-tuple, without unpacking anything."""
        return (idxs,)

    def forward(self,
                input,
                Qidxs_list,
                SU,
                SV,
                had_left,
                had_right,
                K_left,
                K_right,
                rank=-1,
                A=None,
                B=None,
                rescale_WH=False,
                scaleWH=None,
                packed=False,
                **kwargs):
        """Apply the quantized linear map to ``input``.

        Args:
            input: tensor whose last dimension has ``len(SU)`` entries.  It is
                never modified.
            Qidxs_list: sequence whose first element holds the weight's
                codebook indices (packed words when ``packed`` is true).
            SU: per-input-feature multiplier; its length defines ``n``.
            SV: per-output-feature multiplier; its length defines ``m``.
            had_left, K_left: forwarded to ``matmul_hadUt_cuda``.
            had_right, K_right: forwarded to ``matmul_hadU_cuda``.
            rank: when > 0, the low-rank term ``(x @ B.T) @ A.T`` is added.
            A, B: low-rank factors; only read when ``rank > 0``.
            rescale_WH: when true, the input is first divided by ``scaleWH``.
            scaleWH: divisor used when ``rescale_WH`` is true.
            packed: selects the CUDA decompression kernel over the plain
                codebook lookup.
            **kwargs: accepted and ignored.

        Returns:
            float32 tensor shaped like ``input`` with the last dimension
            replaced by ``m``.
        """
        Qidxs = Qidxs_list[0]
        n, m = len(SU), len(SV)

        # reshape (rather than view) also accepts non-contiguous inputs.
        x = input.reshape(-1, n).to(torch.float32)
        if rescale_WH:
            # Divide out-of-place: for a float32 input the reshape and the
            # no-op .to() above alias the caller's storage, so an in-place
            # division would silently rescale the caller's tensor.  The cast
            # keeps x in float32 like the original in-place op did.
            x = (x / scaleWH).to(torch.float32)
        x = x * SU
        x = matmul_hadUt_cuda(x, had_left, K_left)

        if rank > 0:
            # Low-rank correction, computed in float32 before x is narrowed.
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)

        # x is scaled down before the float16 cast and scaled back up after
        # the matmul -- presumably to keep the half-precision product within
        # range; confirm.
        num_scale = 1024
        x = x / num_scale
        x = x.to(torch.float16)

        if packed:
            W_decompressed = torch.zeros(m, n, dtype=torch.float16, device=x.device)
            # Presumably fills W_decompressed in place from the packed words.
            quiptools_cuda.decompress_hi4b1c_packed(Qidxs, self.codebook.grid, W_decompressed)
        else:
            W_decompressed = self.codebook.by_idxs(Qidxs, packed=False).reshape(-1, n)

        z = x @ W_decompressed.t()

        x = z.to(torch.float32)
        x = x * num_scale

        if rank > 0:
            x = x + ABx.to(torch.float32)

        x = matmul_hadU_cuda(x, had_right, K_right)
        x = x * SV

        output = x.reshape(*input.shape[:-1], m)

        return output
