import torch
import torch.nn as nn

# The CUDA kernels are optional at import time: when the extension cannot be
# imported both names are set to None, and AnyPrecisionLinear_3456.__init__
# raises a descriptive error instead. Only ImportError (which includes
# ModuleNotFoundError) is caught so that KeyboardInterrupt, SystemExit and
# unrelated failures are not silently swallowed.
try:
    from any_precision_ext import matmul_kbit, dequant_kbit
except ImportError:
    matmul_kbit, dequant_kbit = None, None


class AnyPrecisionLinear_3456(nn.Module):
    """Linear layer over a packed quantized weight with a per-call bit-width.

    The layer stores one int32 ``qweight`` buffer plus one lookup table
    (``lut{bit}``) per supported bit-width, and computes ``x @ W.T`` either by
    dequantizing ``W`` (multi-token inputs) or with the fused ``matmul_kbit``
    kernel (single-token inputs).

    ``err_mode`` selects how the bit-width is chosen:

    * ``None`` / ``"decode"``: always ``self.precision`` for single-token calls.
    * ``"mqdecode"``: fixed ``self.mqd_prec`` for single-token calls.
    * ``"mq"``: fixed ``self.mq_prec``; only handled on the
      ``prefill_as_decode`` path.
    * ``"full"``, ``"prev"``, ``"intra"``: norm of a projection through
      ``self.e`` compared against ``self.targ``.
    * ``"lin"``, ``"prevlin"``, ``"intralin"``: ``norm * lin_slope + lin_inter``
      compared against ``self.targ``.
    * ``"oracle"``: norm of the actual difference between the ``bit_h`` and
      ``bit_l`` results compared against ``self.targ``.
    * ``"random"``: uniform random score per token compared against
      ``self.targ``; only handled on the ``prefill_as_decode`` path.

    When the score exceeds ``self.targ`` the result is computed with
    ``self.bit_h``, otherwise with ``self.bit_l``.

    The ``prev*`` / ``intra*`` modes read ``self.mother_layer.inter_x`` /
    ``self.mother_layer.intra_x`` through ``self.mother_ln``; both attributes
    are initialised to ``None`` here and must be assigned by the caller before
    ``forward`` is used in those modes.

    ``self.comp_count`` maps bit-width -> usage counter. Single-token calls and
    the uniform ``prefill_as_decode`` modes add 1 per call; the masked
    ``prefill_as_decode`` modes add the number of tokens routed to each
    bit-width. The plain prefill path does not update it.
    """

    # Modes whose score is produced by ``_estimate_error``.
    _ESTIMATED_MODES = ("full", "lin", "prev", "intra", "prevlin", "intralin")
    # Modes supported by the per-token (prefill_as_decode) path.
    _PER_TOKEN_MODES = _ESTIMATED_MODES + ("random", "oracle")

    def __init__(self, in_features, out_features, supported_bits, bias=True, precisions=None, device=None,
                 dtype=None, jl_path=None, err_mode="full", err_lin_param=None,
                 targ_path=None, my_name=None, my_layer=None, prefill_as_decode=False, maxmem=6):
        """Allocate the (uninitialised) buffers and load the optional side files.

        Args:
            in_features: Input width; ``qweight`` has ``in_features // 32``
                int32 columns.
            out_features: Output width.
            supported_bits: Bit-widths for which a ``lut{bit}`` buffer is
                allocated.
            bias: If true, register an (uninitialised) ``bias`` buffer.
            precisions: Bit-widths kept by ``prune_precisions``; defaults to
                ``supported_bits``. Must be a list.
            device, dtype: Passed to ``torch.empty`` for the buffers.
            jl_path: Optional path loaded with ``torch.load`` into ``self.e``
                for the ``"full"``/``"prev"``/``"intra"`` modes.
            err_mode: Bit-width selection mode (see the class docstring).
            err_lin_param: ``(slope, intercept)`` for the ``*lin`` modes; read
                only when ``jl_path`` is given.
            targ_path: Optional path loaded with ``torch.load``. Holds an int
                precision for ``"mqdecode"``/``"mq"``, otherwise a
                ``(bit_l, bit_h, targ)`` triple.
            my_name, my_layer: Stored as-is for external bookkeeping.
            prefill_as_decode: Route multi-token inputs through the per-token
                selection path.
            maxmem: Bit-width used for plain multi-token inputs when
                ``err_mode`` is not ``None``.

        Raises:
            ModuleNotFoundError: The CUDA kernel extension is not available.
            RuntimeError: ``precisions`` is not a list, or ``jl_path`` is given
                with an ``err_mode`` this constructor does not recognise.
        """
        super().__init__()
        if dequant_kbit is None or matmul_kbit is None:
            raise ModuleNotFoundError('Please install any precision CUDA kernel extension from modules/kernels.')
        if precisions is None:
            precisions = supported_bits
        if not isinstance(precisions, list):
            # The check is on ``precisions`` (which falls back to
            # ``supported_bits``), so the message names that argument.
            raise RuntimeError('precisions must be a list of integers (defaults to supported_bits).')

        self.in_features = in_features
        self.out_features = out_features
        self.precisions = precisions
        self.precision = max(self.precisions)
        self.supported_bits = supported_bits

        self.register_buffer(
            'qweight',
            torch.empty((max(supported_bits), out_features, in_features // 32), dtype=torch.int32, device=device)
        )

        for bit in supported_bits:
            self.register_buffer(
                f'lut{bit}',
                torch.empty((out_features, 2 ** bit), dtype=dtype, device=device)
            )

        if bias:
            self.register_buffer(
                "bias",
                torch.empty((out_features,), dtype=dtype, device=device)
            )
        else:
            self.bias = None

        self.err_mode = err_mode

        if jl_path is not None:
            if err_mode in ("full", "prev", "intra"):
                # SECURITY: weights_only=False unpickles arbitrary objects;
                # only point jl_path at trusted files.
                # ``e`` is a plain attribute (not a buffer) placed on cuda:0;
                # forward() moves it to the input's device/dtype on demand.
                self.e = torch.load(jl_path, weights_only=False).to(dtype).to("cuda:0")
            elif err_mode in ("lin", "prevlin", "intralin"):
                self.lin_slope = err_lin_param[0]
                self.lin_inter = err_lin_param[1]
            elif err_mode is None or err_mode in ("mqdecode", "oracle", "random"):
                pass
            else:
                # NOTE(review): "mq" and "decode" land here when jl_path is
                # given, although forward() has branches for them -- confirm
                # whether that rejection is intended.
                raise RuntimeError(f"Unknown Error Mode : {err_mode}")

        self.jl_path = jl_path

        if targ_path:
            # SECURITY: same weights_only=False caveat as for jl_path above.
            if err_mode == "mqdecode":
                self.mqd_prec = torch.load(targ_path, weights_only=False)  # precision int
            elif err_mode == "mq":
                self.mq_prec = torch.load(targ_path, weights_only=False)  # precision int
            else:
                self.bit_l, self.bit_h, self.targ = torch.load(targ_path, weights_only=False)  # (b_l, b_h, targ)
                self.targ = self.targ.item()

        self.targ_path = targ_path
        self.comp_count = {bit: 0 for bit in range(2, 9)}

        self.mother_layer = None
        self.my_name = my_name
        self.my_layer = my_layer
        self.mother_ln = None

        self.prefill_as_decode = prefill_as_decode
        self.maxmem = maxmem

    def prune_precisions(self):
        """Drop state not needed by ``self.precisions``.

        Keeps the first ``max(self.precisions)`` slices of ``qweight`` and
        deletes every ``lut{bit}`` whose bit-width is not in ``self.precisions``.
        """
        self.qweight = self.qweight[:max(self.precisions)]
        for bit in self.supported_bits:
            if bit not in self.precisions:
                delattr(self, f'lut{bit}')

    def _dequantize(self, bits):
        """Return the dense weight dequantized at ``bits`` via ``dequant_kbit``."""
        return dequant_kbit(self.qweight, self._buffers[f'lut{bits}'], bits)

    def _estimate_error(self, x, per_token):
        """Compute the score compared against ``self.targ`` for the estimated modes.

        Must only be called with ``self.err_mode`` in ``_ESTIMATED_MODES``.

        Args:
            x: Layer input.
            per_token: If true, reduce with ``norm(dim=-1)`` (one score per
                leading index); otherwise reduce with a global ``norm()``.

        Returns:
            The score tensor, or ``None`` for ``"full"`` mode without
            ``jl_path`` (no projection available, so no selection happens).
        """
        mode = self.err_mode
        if mode == "full" and not self.jl_path:
            return None

        # Pick the activations the estimate is computed from.
        if mode in ("prev", "prevlin"):
            source = self.mother_ln(self.mother_layer.inter_x)
        elif mode in ("intra", "intralin"):
            source = self.mother_ln(self.mother_layer.intra_x)
        else:
            source = x

        if mode in ("full", "prev", "intra"):
            projected = source @ self.e.T
            return projected.norm(dim=-1) if per_token else projected.norm()

        magnitude = source.norm(dim=-1) if per_token else source.norm()
        return magnitude * self.lin_slope + self.lin_inter

    def _forward_mixed_prefill(self, x):
        """Multi-token path used when ``prefill_as_decode`` is set.

        Computes the output at the low bit-width and, for the selecting modes,
        overwrites the rows whose score exceeds ``self.targ`` with the
        ``bit_h`` result.

        Raises:
            RuntimeError: ``err_mode`` is not supported on this path. This is
                raised before any weight is dequantized.
        """
        mode = self.err_mode
        if mode == "mq":
            low_bits = self.mq_prec
            self.comp_count[low_bits] += 1
        elif mode is None:
            low_bits = self.precision
            self.comp_count[low_bits] += 1
        elif mode in self._PER_TOKEN_MODES:
            low_bits = self.bit_l
        else:
            raise RuntimeError(f"Unknown Error Mode during prefill: {mode}")

        w_low = self._dequantize(low_bits)
        y = torch.matmul(x, w_low.T)
        if mode is None or mode == "mq":
            return y

        w_high = None
        if mode == "oracle":
            # Oracle scores with the true high-vs-low weight difference; the
            # high-precision weight is reused for the upgrade below.
            w_high = self._dequantize(self.bit_h)
            scores = (x @ (w_high - w_low).T).norm(dim=-1)
        elif mode == "random":
            # One uniform score per leading index of x, whatever its rank.
            scores = torch.rand(x.shape[:-1], dtype=x.dtype, device=x.device)
        else:
            scores = self._estimate_error(x, per_token=True)
        del w_low  # not needed any more; release before the second matmul

        if scores is None or not self.targ_path:
            return y

        mask = scores > self.targ
        if w_high is None:
            w_high = self._dequantize(self.bit_h)
        y_high = torch.matmul(x, w_high.T)
        y[mask] = y_high[mask]

        upgraded = mask.count_nonzero().item()
        self.comp_count[self.bit_h] += upgraded
        self.comp_count[self.bit_l] += mask.numel() - upgraded
        return y

    def _select_decode_bits(self, x):
        """Choose the bit-width for a single-token call.

        Raises:
            RuntimeError: ``err_mode`` is not supported for single-token calls.
        """
        mode = self.err_mode
        if mode is None or mode == "decode":
            return self.precision
        if mode == "mqdecode":
            return self.mqd_prec
        if mode == "oracle":
            # Start low; forward compares against the bit_h result afterwards.
            return self.bit_l
        if mode in self._ESTIMATED_MODES:
            score = self._estimate_error(x, per_token=False)
            if score is None or not self.targ_path:
                return self.precision
            return self.bit_h if score > self.targ else self.bit_l
        raise RuntimeError(f"Unknown Error Mode during decoding: {mode}")

    def _forward_decode(self, x):
        """Single-token path: fused ``matmul_kbit`` at the selected bit-width."""
        w_bits = self._select_decode_bits(x)

        # Make the input's GPU current before calling the extension kernel
        # (previously hard-coded to device 0, which is wrong when x lives on
        # another GPU).
        if x.is_cuda:
            torch.cuda.set_device(x.device)
        y = matmul_kbit(x, self.qweight, self._buffers[f'lut{w_bits}'], w_bits)

        if self.err_mode == "oracle":
            y_high = matmul_kbit(x, self.qweight, self._buffers[f'lut{self.bit_h}'], self.bit_h)
            if (y_high - y).norm() > self.targ:
                y = y_high
                w_bits = self.bit_h

        self.comp_count[w_bits] += 1
        return y

    def forward(self, x, **kwargs):
        """Apply the quantized linear map to ``x``.

        Inputs with more than one row (``x.numel() // x.shape[-1] > 1``) are
        treated as prefill: the weight is dequantized and multiplied with
        ``torch.matmul``. Single-row inputs use the fused ``matmul_kbit``
        kernel. ``**kwargs`` is accepted and ignored.

        Returns:
            ``x @ W.T`` (plus ``bias`` when present).

        Raises:
            RuntimeError: ``err_mode`` is not supported on the path taken.
        """
        # ``e`` is not a registered buffer, so module.to() does not move it;
        # keep it aligned with the input here.
        if self.jl_path is not None and hasattr(self, 'e'):
            if self.e.device != x.device:
                self.e = self.e.to(x.device)
            if self.e.dtype != x.dtype:
                self.e = self.e.to(x.dtype)

        if x.numel() // x.shape[-1] > 1:
            if self.prefill_as_decode:
                y = self._forward_mixed_prefill(x)
            else:
                # Plain prefill: one bit-width for the whole input.
                bits = self.precision if self.err_mode is None else self.maxmem
                y = torch.matmul(x, self._dequantize(bits).T)
        else:
            y = self._forward_decode(x)

        if self.bias is not None:
            y += self.bias

        return y

    def set_precision(self, precision):
        """Set the default bit-width.

        No validation against ``self.precisions`` is performed; the value is
        later used to look up ``lut{precision}``.
        """
        self.precision = precision

    def extra_repr(self) -> str:
        """Return the feature sizes and whether a bias is present, for ``repr``."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
