import numpy as np
import torch
import torch.nn as nn
import math
import os, sys
# sys.path.append(os.path.dirname(__file__))
import cmpq_cuda


def round_to_nearest_pole_sim(w, poles):
    """Snap every entry of ``w`` to the closest value in ``poles``.

    Args:
        w: tensor of weight values (used with a 1-D vector in this file).
        poles: sequence of candidate values to round to.

    Returns:
        Tensor shaped like ``w`` holding, for each entry, the pole with the
        smallest absolute distance. On a tie the earlier pole wins
        (``argmin`` returns the first minimal index).
    """
    # Row k holds |w - poles[k]|, so the argmin over dim 0 names the nearest pole.
    distances = torch.stack([(w - pole).abs() for pole in poles])
    nearest = distances.argmin(dim=0)

    # Rebuild the values by masking: exactly one pole index matches each entry.
    snapped = 0
    for k, pole in enumerate(poles):
        snapped += (nearest == k) * pole
    return snapped


# drop-in layer replacement class
class QuantLinearLUT(nn.Module):
    """Linear layer whose weights are stored as per-channel lookup-table indices.

    Each input channel ``c`` is quantized with ``bits[c]`` bits (2, 3 or 4).
    The centroid indices of all channels sharing a bit width are bit-packed
    into the int32 buffers ``qweight2`` / ``qweight3`` / ``qweight4``: one row
    per such input channel, packed along the output dimension.
    ``lookup_table[c]`` holds channel ``c``'s centroid values. An optional CSR
    matrix (``rows`` / ``cols`` / ``vals``) carries sparse outlier values.

    Each packed row has ``outfeatures // 32 * bits`` words, so ``outfeatures``
    is assumed to be a multiple of 32. Trailing outputs beyond the last full
    group of 32 are not packed.
    """

    def __init__(
            self,
            bits,
            infeatures,
            outfeatures,
            bias,
            include_sparse=False,
            numvals=0,
            topX=0,
            balanced=False,
            num_nonzero_per_thread=10,
    ):
        """Allocate zero-filled buffers sized for the given configuration.

        Args:
            bits: integer tensor with one bit width (2, 3 or 4) per input channel.
            infeatures: number of input features.
            outfeatures: number of output features.
            bias: whether to allocate a ``bias`` buffer.
            include_sparse: whether ``forward`` adds the sparse CSR term.
            numvals: number of nonzeros to pre-allocate in the CSR buffers.
            topX: number of "full rows" to pre-allocate buffers for.
            balanced: whether the sparse term uses the per-thread
                ``startrows`` layout.
            num_nonzero_per_thread: target nonzeros per thread in balanced
                mode. ``pack2`` also uses it when called without an explicit
                positive value.

        Raises:
            NotImplementedError: if any entry of ``bits`` is outside [2, 4].
        """
        super().__init__()
        if not torch.all((bits >= 2) & (bits <= 4)):
            raise NotImplementedError("Bits must be in the range [2, 4].")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.num_nonzero_per_thread = num_nonzero_per_thread

        # One packed buffer per bit width: a row for every input channel with
        # that width, ``outfeatures // 32 * bit_width`` int32 words per row.
        for bit_width in (2, 3, 4):
            self.register_buffer(
                f"qweight{bit_width}",
                torch.zeros(
                    ((bits == bit_width).sum().item(), outfeatures // 32 * bit_width),
                    dtype=torch.int32,
                ),
            )

        self.register_buffer("bitalloc", bits)
        if bias:
            self.include_bias = True
            self.register_buffer("bias", torch.zeros((outfeatures)))
        else:
            self.include_bias = False
            self.bias = None
        # One row of centroids per input channel, wide enough for the largest
        # bit width present in ``bits``.
        self.register_buffer(
            "lookup_table",
            torch.zeros((infeatures, 2 ** int(torch.max(bits).item())), dtype=torch.float32),
        )

        self.include_sparse = include_sparse
        self.numvals = numvals
        self.topX = topX
        if numvals > 0:
            # CSR placeholders: row offsets, column indices, values.
            self.register_buffer("rows", torch.zeros(outfeatures + 1, dtype=torch.int32))
            self.register_buffer("cols", torch.zeros(numvals, dtype=torch.int32))
            self.register_buffer("vals", torch.zeros(numvals, dtype=torch.float32))
        if topX > 0:
            self.register_buffer(
                "full_rows", torch.zeros((infeatures, topX), dtype=torch.float32)
            )
            self.register_buffer("full_row_indices", torch.zeros(topX, dtype=torch.int32))

        self.balanced = balanced
        if include_sparse and balanced and numvals > 0:
            print("use num_nonzero_per_thread")
            self.num_threads = self._round_up_threads(numvals, num_nonzero_per_thread)
            self.register_buffer("startrows", torch.zeros(self.num_threads, dtype=torch.int32))
            print("self.num_threads : ", self.num_threads)

    @staticmethod
    def _round_up_threads(numvals, num_nonzero_per_thread):
        """Return ceil(numvals / num_nonzero_per_thread), rounded up to a multiple of 128 (block size)."""
        threads = -(-numvals // num_nonzero_per_thread)
        return 128 * -(-threads // 128)

    @staticmethod
    def _pack_rows(values, bit_width):
        """Bit-pack each row of ``values`` (uint32 entries < 2**bit_width) into uint32 words.

        For 2 and 4 bits, ``32 // bit_width`` values go into each word, lowest
        bits first. For 3 bits, 32 values go into 3 words; values 10 and 21 of
        each group straddle a word boundary. Only ``values.shape[1] // 32``
        full groups are packed.
        """
        qweight = np.zeros(
            (values.shape[0], values.shape[1] // 32 * bit_width), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qweight.shape[1]:
            if bit_width in (2, 4, 8):
                per_word = 32 // bit_width
                for j in range(i, i + per_word):
                    qweight[:, col] |= values[:, j] << (bit_width * (j - i))
                i += per_word
                col += 1
            elif bit_width == 3:
                # Word 0: values 0-9, then the low 2 bits of value 10.
                for j in range(i, i + 10):
                    qweight[:, col] |= values[:, j] << (3 * (j - i))
                i += 10
                qweight[:, col] |= values[:, i] << 30
                col += 1
                # Word 1: high bit of value 10, values 11-20, low bit of value 21.
                qweight[:, col] |= (values[:, i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[:, col] |= values[:, j] << (3 * (j - i) + 1)
                i += 10
                qweight[:, col] |= values[:, i] << 31
                col += 1
                # Word 2: high 2 bits of value 21, then values 22-31.
                qweight[:, col] |= (values[:, i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[:, col] |= values[:, j] << (3 * (j - i) + 2)
                i += 10
                col += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        return qweight

    def _store_sparse(self, outliers, num_nonzero_per_thread):
        """Save a CSR outlier matrix and, in balanced mode, per-thread start rows.

        Raises:
            ValueError: balanced mode with no positive ``num_nonzero_per_thread``
                (neither passed in nor given to the constructor).
        """
        self.register_buffer("rows", outliers.crow_indices().to(torch.int32))
        self.register_buffer("cols", outliers.col_indices().to(torch.int32))
        self.register_buffer("vals", outliers.values().to(torch.float32))
        if not self.balanced:
            return

        if num_nonzero_per_thread <= 0:
            num_nonzero_per_thread = getattr(self, "num_nonzero_per_thread", -1)
        if num_nonzero_per_thread <= 0:
            raise ValueError("num_nonzero_per_thread must be positive for balanced packing.")

        self.numvals = self.vals.shape[0]
        print("self.numvals: ", self.numvals)
        print("self.rows: ", self.rows.shape[0])

        self.num_threads = self._round_up_threads(self.numvals, num_nonzero_per_thread)
        # Guard the empty-outlier case, where num_threads is 0.
        nnz_per_thread = -(-self.numvals // self.num_threads) if self.num_threads else 0
        start_rows = torch.zeros(self.num_threads, dtype=torch.int32)

        print("self.num_threads: ", self.num_threads)
        print("nnz_per_thread: ", nnz_per_thread)

        # For thread i, record the last row whose CSR offset is below
        # i * nnz_per_thread; -1 marks threads past the last nonzero. Each scan
        # resumes from the previous thread's row.
        row_offsets = self.rows.tolist()
        minidx = 0
        for i in range(self.num_threads):
            tmpmin = minidx
            for j in range(minidx, self.outfeatures):
                if nnz_per_thread * i > self.numvals:
                    start_rows[i] = -1
                    break
                elif row_offsets[j] < nnz_per_thread * i:
                    start_rows[i] = j
                    tmpmin = j
                else:
                    break
            minidx = tmpmin

        self.register_buffer("startrows", start_rows)

    def pack2(self, linear, lookup_table, include_sparse, num_nonzero_per_thread=-1):
        """Fill this module's buffers from ``linear`` and its quantization result.

        Args:
            linear: source layer. Column ``c`` of ``linear.weight`` is input
                channel ``c``. Its bias is copied when this module has one.
            lookup_table: tuple ``(lut, _, outliers)``. ``lut[c][0]`` is a
                ``(centroids, indices)`` pair of numpy arrays for input channel
                ``c``. ``outliers`` is a sparse tensor, presumably shaped like
                the weight; it is only used when ``include_sparse`` is true.
            include_sparse: whether to store ``outliers`` as CSR buffers.
            num_nonzero_per_thread: nonzeros per thread for the balanced
                layout. A non-positive value falls back to the constructor's.

        Raises:
            ValueError: balanced sparse packing without a positive
                ``num_nonzero_per_thread``.
        """
        if self.include_bias:
            # detach() so the buffer holds plain data rather than a tensor that
            # requires grad and references the source parameter's graph.
            self.bias = linear.bias.detach().clone()

        lut, _, outliers = lookup_table
        intweight = linear.weight.data.clone()
        if include_sparse:
            outliers = outliers.to_dense()

        for channel in range(len(lut)):
            centroid, indices = lut[channel][0]  # [0] selects group 0
            intweight[:, channel] = torch.from_numpy(indices)
            self.lookup_table[channel][:len(centroid)] = torch.from_numpy(centroid)

            if include_sparse:
                # Subtract, from this channel's nonzero outliers, the centroid
                # that 0.0 rounds to.
                zero_mapping = round_to_nearest_pole_sim(torch.zeros(1), centroid)
                nonzero_vals = torch.nonzero(outliers[:, channel])
                outliers_channel = outliers[:, channel]
                outliers_channel[nonzero_vals] -= zero_mapping
                outliers[:, channel] = outliers_channel

        if include_sparse:
            self._store_sparse(
                outliers.to_sparse(layout=torch.sparse_csr), num_nonzero_per_thread
            )

        # Transpose to (infeatures, outfeatures): each row is one input channel.
        intweight = intweight.to(torch.int).cpu().t().contiguous().numpy().astype(np.uint32)
        for bit_width in (2, 3, 4):
            mask = (self.bits == bit_width).cpu().numpy()
            packed = self._pack_rows(intweight[mask], bit_width)
            setattr(self, f"qweight{bit_width}", torch.from_numpy(packed.astype(np.int32)))

    def forward(self, x):
        """Linear-layer forward pass computed from the packed LUT weights.

        Accumulates in float32 via the ``cmpq_cuda`` kernels (assumed to need
        CUDA tensors) and returns a tensor of ``x``'s dtype shaped
        ``x.shape[:-1] + (outfeatures,)``.
        """
        if not x.is_contiguous():
            x = x.contiguous()

        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape(-1, x.shape[-1])
        dtype = x.dtype
        x = x.float()
        # float32 accumulator on x's device; every kernel below receives it.
        result = torch.zeros(
            (x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32
        )

        if self.include_sparse:
            cmpq_cuda.vecquant_spmv(self.rows, self.cols, self.vals, x, result, self.outfeatures)

        kernels = {
            2: (cmpq_cuda.vecquant2matmul_nuq_perchannel_batched, self.qweight2),
            3: (cmpq_cuda.vecquant3matmul_nuq_perchannel_batched, self.qweight3),
            4: (cmpq_cuda.vecquant4matmul_nuq_perchannel_batched, self.qweight4),
        }
        for bit_width, (kernel, qweight) in kernels.items():
            if qweight.shape[0] == 0:
                # No input channel uses this width; skip the empty launch.
                continue
            mask = self.bitalloc == bit_width
            x_selected = x[:, mask].contiguous()
            lut = self.lookup_table[mask, :2 ** bit_width].float().contiguous()
            kernel(x_selected, qweight, result, lut)

        # Add the bias in float32 before casting, so a bias buffer of another
        # dtype cannot promote the output away from the input dtype.
        if self.bias is not None:
            result += self.bias
        return result.to(dtype).reshape(out_shape)


def make_quant_lut(
        module,
        names,
        bits,
        name="",
        include_sparse=False,
        numvals=None,
        topX=0,
        balanced=False,
        num_nonzero_per_thread=10,
):
    """Recursively replace selected child layers of ``module`` with ``QuantLinearLUT``.

    Args:
        module: module to rewrite in place.
        names: mapping from dotted submodule name to a 3-tuple whose middle
            element is that layer's per-input-channel bit-width tensor.
        bits: not used directly; forwarded unchanged to recursive calls.
        name: dotted prefix of ``module`` within the root model ("" for root).
        include_sparse: forwarded to ``QuantLinearLUT``.
        numvals: optional mapping from dotted name to the number of sparse
            nonzeros to pre-allocate; 0 is used when it is None.
        topX: forwarded to ``QuantLinearLUT``.
        balanced: forwarded to ``QuantLinearLUT``.
        num_nonzero_per_thread: forwarded to ``QuantLinearLUT``.
    """
    if isinstance(module, QuantLinearLUT):
        return
    prefix = name + "." if name != "" else ""
    # Only registered child modules can be the layers to replace. Walking them
    # (rather than getattr on every name in dir()) avoids evaluating arbitrary
    # properties. Snapshot the list because entries are replaced during the loop.
    for attr, child in list(module.named_children()):
        full_name = prefix + attr
        if full_name not in names:
            continue
        num = numvals[full_name] if numvals is not None else 0
        _, bitseq, _ = names[full_name]
        # Assigning a Module to an existing child name replaces it in place.
        setattr(
            module,
            attr,
            QuantLinearLUT(
                bitseq,
                child.in_features,
                child.out_features,
                child.bias is not None,
                include_sparse=include_sparse,
                numvals=num,
                topX=topX,
                balanced=balanced,
                num_nonzero_per_thread=num_nonzero_per_thread,
            ),
        )
    for attr, child in module.named_children():
        make_quant_lut(
            child,
            names,
            bits,
            prefix + attr,
            include_sparse=include_sparse,
            numvals=numvals,
            topX=topX,
            balanced=balanced,
            num_nonzero_per_thread=num_nonzero_per_thread,
        )
