class GPTQ:
    """
    GPTQ quantizer for a single ``module``.

    Besides the standard GPTQ sweep, ``quantize`` has a BPDQ path (enabled by
    ``qcfg.bpdq_flag``) that additionally returns LUT-GEMM packed bit-planes
    ``B`` and coefficients ``c``.
    """

    def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
        # NOTE(review): not used by the code visible here; in particular
        # _get_bpdq_candidates mutates its cache without taking this lock.
        self.lock = threading.Lock()
        # ... (remaining initialization elided in this view)
        # Keep at the END of __init__: memo for _get_bpdq_candidates, keyed by
        # (device, k, dtype) -> (candidate_vectors, candidate_vectors_f).
        self._bpdq_candidate_cache = {}

# NOTE NOTE NOTE BPDQ NOTE NOTE NOTE
    def _get_bpdq_candidates(self, k: int, device: torch.device, dtype: torch.dtype):
        """
        Enumerate every binary vector of length ``k`` (the BPDQ search candidates).

        Row ``i`` is the binary expansion of ``i`` with the most significant bit in
        column 0, i.e. rows run 00..0, 00..1, ..., 11..1 -- the same order as
        ``torch.cartesian_prod`` over ``k`` copies of ``[0, 1]``.

        Args:
            k: vector length; it is normalized with ``int()`` once, so the cache key
                and the construction always agree.
            device: device the tensors are built on.
            dtype: dtype of the floating-point copy.

        Returns:
            ``(candidate_vectors, candidate_vectors_f)`` where
              candidate_vectors:   [2^k, k] uint8 (values 0/1)
              candidate_vectors_f: [2^k, k] ``dtype``
            or ``(None, None)`` when ``k <= 0``.

        Results are cached per ``(device, k, dtype)`` in
        ``self._bpdq_candidate_cache``. Every cache hit returns the same tensor
        objects, so callers must not modify them in place.
        """
        k = int(k)
        if k <= 0:
            return None, None

        key = (device, k, dtype)
        cached = self._bpdq_candidate_cache.get(key)
        if cached is not None:
            return cached

        # Extract bit j (MSB first) of every row index. Unlike cartesian_prod, which
        # returns a 1-D tensor for a single input, this yields [2^k, k] for every k >= 1.
        row_ids = torch.arange(1 << k, device=device, dtype=torch.int64)
        bit_shifts = torch.arange(k - 1, -1, -1, device=device, dtype=torch.int64)
        candidate_vectors = ((row_ids.unsqueeze(1) >> bit_shifts) & 1).to(torch.uint8)
        candidate_vectors_f = candidate_vectors.to(dtype=dtype)

        self._bpdq_candidate_cache[key] = (candidate_vectors, candidate_vectors_f)
        return candidate_vectors, candidate_vectors_f

    @torch.inference_mode()
    def _pack_lutgemm_int32(self, B_batch: torch.Tensor):
        """
        Pack 0/1 bit-planes into int32 words using the LUT-GEMM tile layout.

        Args:
            B_batch: [Bits, M, GroupSize] tensor of bit values. Entries are rounded
                and clamped to {0, 1} before packing.

        Returns:
            Contiguous int32 tensor of shape [GroupSize // 32, Bits, M]. Input column
            ``32 * t + j`` becomes bit ``j`` (j = 0 is the LSB) of word ``t``. A set
            bit 31 makes the word negative in two's complement.

        Raises:
            ValueError: if GroupSize is not a multiple of 32.
        """
        n_planes, n_rows, width = B_batch.shape
        if width % 32:
            raise ValueError(f"LUT-GEMM requires GroupSize multiple of 32, got {width}")

        # Snap to exact 0/1 integers so the shifted terms below never overlap.
        plane_bits = B_batch.round().clamp(0, 1).to(torch.int32).contiguous()

        # [Bits, M, GroupSize] -> [Bits, M, K_tiles, 32] -> [K_tiles, Bits, M, 32]
        n_tiles = width // 32
        tiles = plane_bits.reshape(n_planes, n_rows, n_tiles, 32).permute(2, 0, 1, 3)

        # Bits occupy disjoint positions, so summing the shifted values equals OR-ing them.
        bit_pos = torch.arange(32, device=B_batch.device, dtype=torch.int32)
        words = torch.sum(tiles << bit_pos, dim=-1, dtype=torch.int32)
        return words.contiguous()
    
    @torch.inference_mode()
    def _unpack_lutgemm_int32(self, packed_int32: torch.Tensor, original_k: int):
        """
        Inverse of ``_pack_lutgemm_int32``: expand int32 words back into 0/1 bit-planes.

        Args:
            packed_int32: [K_tiles, Bits, M] int32 words.
            original_k: number of columns to keep. Columns past it (tile padding)
                are dropped.

        Returns:
            float16 tensor of shape [Bits, M, min(K_tiles * 32, original_k)].
        """
        n_tiles, n_planes, n_rows = packed_int32.shape

        bit_pos = torch.arange(32, device=packed_int32.device, dtype=torch.int32)
        # An arithmetic right shift followed by "& 1" isolates bit j of each word,
        # including the sign bit.
        words = packed_int32.unsqueeze(-1)        # [K_tiles, Bits, M, 1]
        bits = (words >> bit_pos) & 1             # [K_tiles, Bits, M, 32]

        # [K_tiles, Bits, M, 32] -> [Bits, M, K_tiles, 32] -> [Bits, M, K_tiles * 32]
        restored = bits.permute(1, 2, 0, 3).contiguous().reshape(n_planes, n_rows, n_tiles * 32)

        if original_k < restored.shape[2]:
            restored = restored[:, :, :original_k]

        return restored.to(torch.float16)

    @torch.inference_mode()
    def _debug_verify_lutgemm(self, Q, c, B, group_size):
        """
        Debug check: rebuild Q from LUT-GEMM packed data and log the reconstruction error.

        Args:
            Q: dequantized weights, [M, K]. A [K, M] tensor is detected and transposed back.
            c: coefficients, [Groups, Bits+1, M]. Rows ``0..Bits-1`` are per-bit-plane
                scales (``alpha``); the last row is the per-group offset (``bias``).
            B: bit-planes packed by ``_pack_lutgemm_int32``, [K_tiles, Bits, M] int32 with
                ``K_tiles == ceil(K / 32)``. A [M, Bits, K_tiles] tensor is detected and
                permuted back.
            group_size: number of consecutive columns that share one row of ``c``.

        The reconstruction is
            Q_recon[m, k] = sum_b alpha[k // group_size, b, m] * bit[b, m, k]
                            + bias[k // group_size, m]
        and is computed in float32 on ``Q``'s device.

        Best-effort: every problem (bad shapes, exceptions) is reported through ``log``.
        The method always returns ``None`` and never raises.
        """
        try:
            log.info(f"--- [DEBUG] Verifying LUT-GEMM Packing for {self.name} ---")

            if Q is None or c is None or B is None:
                log.error("Q/c/B is None, cannot verify LUT-GEMM packing.")
                return
            if B.dim() != 3:
                log.error(f"B must be 3D [K_tiles, Bits, M], got shape={tuple(B.shape)}")
                return
            if Q.dim() != 2:
                log.error(f"Q must be 2D, got shape={tuple(Q.shape)}")
                return
            if c.dim() != 3:
                log.error(f"c must be 3D [Groups, Bits+1, M], got shape={tuple(c.shape)}")
                return
            if int(group_size) <= 0:
                # A non-positive group size would turn the column->group mapping into
                # negative indices.
                log.error(f"group_size must be positive for verification, got {group_size}")
                return
            group_size = int(group_size)

            def _layout_fits(b_transposed, q_transposed):
                # B as [K_tiles, Bits, M] (or [M, Bits, K_tiles] when transposed);
                # Q as [M, K] (or [K, M] when transposed).
                b_m, b_tiles = (B.shape[0], B.shape[2]) if b_transposed else (B.shape[2], B.shape[0])
                q_m, q_k = (Q.shape[1], Q.shape[0]) if q_transposed else (Q.shape[0], Q.shape[1])
                return b_m == q_m and b_tiles == (q_k + 31) // 32

            # Try the expected layout first, then a transposed Q, then a transposed B.
            # Q and B are checked together because M cannot be read from B alone:
            # reading it from B.shape[2] only detected a transposed B when M == K_tiles.
            layout = next(
                (
                    (b_t, q_t)
                    for b_t in (False, True)
                    for q_t in (False, True)
                    if _layout_fits(b_t, q_t)
                ),
                None,
            )
            if layout is None:
                log.error(
                    f"Dimension mismatch: Q shape={tuple(Q.shape)} and B shape={tuple(B.shape)} "
                    "do not fit Q=[M, K] with B=[ceil(K/32), Bits, M] (either may be transposed)."
                )
                return

            b_transposed, q_transposed = layout
            if q_transposed:
                log.warning("Detected transposed Q ([K, M]); transposing back to [M, K] for verification...")
                Q = Q.t().contiguous()
            if b_transposed:
                log.warning("Detected transposed B ([M, Bits, K_tiles]); permuting back to [K_tiles, Bits, M]...")
                B = B.permute(2, 1, 0).contiguous()

            M, K = Q.shape
            if M == 0 or K == 0:
                log.error(f"Q is empty (shape={tuple(Q.shape)}), nothing to verify.")
                return

            n_bits = B.shape[1]
            if c.shape[1] != n_bits + 1:
                # Otherwise alpha/bias would be split at the wrong row, and a single
                # bit-plane would silently broadcast against several alpha rows.
                log.error(
                    f"c has {c.shape[1]} coefficient rows but B has {n_bits} bit-planes "
                    f"(expected Bits+1={n_bits + 1})."
                )
                return

            # Q may live on a different device than B / c (for example after being
            # moved back to the module's device), so all math runs on Q's device.
            device = Q.device

            log.info(f" > Unpacking int32 weights (B dtype: {B.dtype})...")
            B_unpacked = self._unpack_lutgemm_int32(B.to(device), K).to(torch.float32)  # [Bits, M, K]

            log.info(f" > Preparing coefficients (c dtype: {c.dtype})...")
            c_fp32 = c.to(device=device, dtype=torch.float32)
            alpha = c_fp32[:, :-1, :]  # [Groups, Bits, M]
            bias = c_fp32[:, -1, :]    # [Groups, M]

            needed_groups = (K + group_size - 1) // group_size
            if alpha.shape[0] < needed_groups:
                log.error(f"Not enough groups in c: c.groups={alpha.shape[0]} < needed_groups={needed_groups} (K={K}, group_size={group_size})")
                return

            col_group = torch.arange(K, device=device, dtype=torch.long) // group_size  # [K]
            alpha_k = alpha.index_select(0, col_group).permute(1, 2, 0)  # [Bits, M, K]
            bias_k = bias.index_select(0, col_group).t()                  # [M, K]

            log.info(" > Reconstructing Q in float32...")
            Q_recon = (B_unpacked * alpha_k).sum(dim=0) + bias_k  # [M, K]

            diff = (Q.to(torch.float32) - Q_recon).abs()
            max_diff = diff.max().item()
            mean_diff = diff.mean().item()

            log.info(f"  > Q shape: {tuple(Q.shape)}")
            log.info(f"  > B shape: {tuple(B.shape)} ({B.dtype})")
            log.info(f"  > c shape: {tuple(c.shape)} ({c.dtype})")
            log.info(f"  > Max Error (fp32): {max_diff:.8f}")
            log.info(f"  > Mean Error (fp32): {mean_diff:.8f}")

            if mean_diff < 1e-3:
                log.info("  > [SUCCESS] LUT-GEMM packing verified.")
            elif mean_diff < 0.1:
                log.warning("  > [WARNING] Reconstruction acceptable but high error (likely quantization noise).")
            else:
                log.error("  > [FAILURE] Reconstruction mismatch significantly (Check permutations/shapes/group indexing).")

        except Exception as e:
            # Debug-only helper: report the failure instead of aborting the caller.
            log.error(f"--- [DEBUG] Verification FAILED: {e} ---")
            import traceback
            traceback.print_exc()
# NOTE NOTE NOTE BPDQ NOTE NOTE NOTE



    @torch.inference_mode()
    def quantize(
            self,
            blocksize=128,
    ):
        """
        Run the GPTQ column sweep over this layer's weights, with an optional BPDQ path.

        Args:
            blocksize: number of columns processed per block. With
                ``self.qcfg.bpdq_flag`` it is raised to at least ``group_size`` and
                must be a multiple of it.

        Returns:
            Without BPDQ: ``(Q, scale, zero, g_idx, duration, avg_loss, damp, nsamples)``.
            With BPDQ: the same tuple followed by ``c`` and ``B``.
              * ``scale`` / ``zero`` are placeholders: float16 ones / zeros of shape
                ``[Q.shape[0], 1]``.
              * ``g_idx`` is all zeros.
              * ``c`` holds the coefficients (cast to float16).
              * ``B`` holds the bit-planes packed by ``_pack_lutgemm_int32``.
              * ``c`` and ``B`` are only assembled when ``qcfg.act_group_aware`` is also
                set; otherwise they are returned as ``None``.

        Raises:
            RuntimeError: BPDQ combined with mock_quantization / the fail-safe first pass,
                or BPDQ without a Hessian inverse.
            AssertionError: BPDQ group_size / column / blocksize constraints violated.
            ValueError: NaN loss while ``fail_safe`` is disabled.

        NOTE(review): part of this body is elided in this view. ``W``, ``Hinv``, ``Q``,
        ``Losses``, ``scale``, ``zero``, ``now_idx``, ``perm``/``invperm``, ``groups``,
        ``global_perm``, ``local_perms``, ``final_perm``, ``damp`` and ``start`` are
        defined in the elided code.
        """
# NOTE BPDQ
        # BPDQ quantizes whole groups inside a block, so a block must span at least one group.
        if self.qcfg.bpdq_flag:
            gs = int(self.qcfg.group_size)
            if blocksize < gs:
                blocksize = gs
# NOTE BPDQ


# ... (code elided in this view)


# NOTE BPDQ
        # BPDQ outputs. c_list / B_list collect one coefficient chunk and one packed
        # bit-plane chunk per group, in processing (permuted) order. They are
        # concatenated into c / B after the sweep.
        c_list = []
        B_list = []
        c = None
        B = None
        inv_final = None
# NOTE BPDQ
        if self.qcfg.static_groups:
            import copy


        # Use simplified loop when mock_quantization is active
        if self.qcfg.mock_quantization or (self.fail_safe and self.fwd_counter == 0):
# NOTE BPDQ
            if self.qcfg.bpdq_flag:
                raise RuntimeError("BPDQ does not support mock_quantization / fail_safe first pass.")
# NOTE BPDQ
        else:
            # Original heavy loop for normal quantization

# NOTE NOTE NOTE BPDQ NOTE NOTE NOTE

            # Block copies and the error/loss accumulators are kept in float32,
            # regardless of W's dtype.
            dtype_store = torch.float32
            dtype_acc   = torch.float32
            for i1 in range(0, self.columns, blocksize):
                i2 = min(i1 + blocksize, self.columns)
                count = i2 - i1


                # W1 = W[:, i1:i2].clone()
                W1 = W[:, i1:i2].to(dtype_store).clone()
                Q1 = torch.zeros_like(W1)
                if self.qcfg.bpdq_flag:  # keep impact minimal
                    # BPDQ working copy: receives the in-block error propagation
                    # (bpdq_results["W_tail"]), while W1 receives the quantized values.
                    W1_acc = W[:, i1:i2].to(dtype_acc).clone()

                Err1 = torch.zeros(W1.shape, device=W1.device, dtype=dtype_acc)
                Losses1 = torch.zeros(W1.shape, device=W1.device, dtype=dtype_acc)
                # Err1 = torch.zeros_like(W1)
                # Losses1 = torch.zeros_like(W1)

                if Hinv is not None:
                    # Diagonal block of the inverse Hessian for this column range.
                    Hinv1 = Hinv[i1:i2, i1:i2].to(dtype_acc)
                    d_diag = Hinv1.diagonal()

                if self.qcfg.bpdq_flag:
                    if Hinv is None:
                        raise RuntimeError("BPDQ requires a non-None Hessian inverse (Hinv).")
                    # NOTE(review): these checks are asserts and are skipped under `python -O`.
                    assert self.qcfg.group_size in (32, 64, 128, 256), "BPDQ scheme requires group_size in {32,64,128, 256}."
                    assert self.columns % self.qcfg.group_size == 0, "columns must be multiple of group_size for BPDQ."
                    assert blocksize % self.qcfg.group_size == 0, "blocksize must be multiple of group_size in BPDQ scheme."



                    # candidate_vectors:   [ncand, k] uint8(0/1)
                    # candidate_vectors_f: [ncand, k] float32
                    candidate_vectors, candidate_vectors_f = self._get_bpdq_candidates(k=self.qcfg.msb_num, device=W1.device, dtype=dtype_acc,)
                    group_size = self.qcfg.group_size
                    # Quantize the block one group of columns at a time.
                    for g0 in range(0, count, group_size):
                        g1 = g0 + group_size
                        # Scale/zero are fitted on W, i.e. without the in-block error
                        # updates held in W1_acc (the W1-based call is kept commented out).
                        # self.quantizer.find_params(W1[:, g0:g1], weight=True)
                        self.quantizer.find_params(W[:, i1 + g0 : i1 + g1], weight=True)
                        curr_scale = self.quantizer.scale.to(dtype_acc)
                        curr_zero  = self.quantizer.zero.to(dtype_acc)

                        # Inputs for this group:
                        #   Wg_f        -- the current group's columns;
                        #   W_tail_init -- the block columns not yet quantized (this group onward);
                        #   U_rows, d_g -- the matching upper-triangular Hinv rows and diagonal.
                        Wg_f = W1_acc[:, g0:g1].contiguous().clone()
                        W_tail_init = W1_acc[:, g0:count].contiguous().clone()
                        U_rows = torch.triu(Hinv1[g0:g1, g0:count])            # [G, tail_len]
                        d_g = d_diag[g0:g1]                                    # [G]


                        bpdq_results = self.quantizer.quantize_bpdq(Wg_f, curr_scale, curr_zero, W_tail_init, U_rows, d_g,
                                                                    bplane_bits=self.qcfg.bits, msb_num=self.qcfg.msb_num,
                                                                    alpha=self.qcfg.alpha,
                                                                    n_iters=self.qcfg.n_iters, candidate_vectors=candidate_vectors, candidate_vectors_f=candidate_vectors_f,
                                                                    dtype_store=dtype_store, dtype_acc=dtype_acc,U_g=torch.triu(Hinv1[g0:g1, g0:g1]))
                        # Write back: updated tail, quantized group, per-column error and loss.
                        W1_acc[:, g0:count] = bpdq_results["W_tail"]
                        Q1[:, g0:g1]        = bpdq_results["Q_store"].to(Q1.dtype)
                        W1[:, g0:g1]        = bpdq_results["Q_store"].to(W1.dtype)
                        Err1[:, g0:g1]      = bpdq_results["S"]
                        Losses1[:, g0:g1]   = bpdq_results["S"].pow(2)
                        B_final = bpdq_results["B_final"]
                        c_bpdq  = bpdq_results["c_store"]

                        if (c_bpdq is not None) and (B_final is not None):
                            # Map this group's processing position to its original group
                            # index, then undo that group's column permutation (local_perms)
                            # so B columns are stored in original order before packing.
                            # NOTE(review): global_perm / local_perms come from elided code
                            # (presumably the act_group_aware reordering); confirm.
                            perm_group_idx = (i1 + g0) // group_size
                            orig_group_idx = int(global_perm[perm_group_idx].item())
                            lp = local_perms[orig_group_idx].to(dtype=torch.long)
                            inv_lp = torch.empty_like(lp)
                            inv_lp[lp] = torch.arange(group_size, device=lp.device, dtype=torch.long)

                            B_store = B_final.index_select(dim=2, index=inv_lp)
                            B_packed_chunk = self._pack_lutgemm_int32(B_store)
                            B_list.append(B_packed_chunk)
                            c_list.append(c_bpdq.permute(2, 1, 0))

                else:
                    # Standard GPTQ: quantize one column at a time and propagate its
                    # error to the remaining columns of the block.
                    for i in range(count):
                        w = W1[:, i]
                        if Hinv is not None:
                            d = Hinv1[i, i]

                        if self.qcfg.group_size != -1:
                            if not self.qcfg.static_groups:
                                if (i1 + i) % self.qcfg.group_size == 0:
                                    self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True)

                                if ((i1 + i) // self.qcfg.group_size) - now_idx == -1:
                                    scale.append(self.quantizer.scale)
                                    zero.append(self.quantizer.zero)
                                    now_idx += 1
                            else:
                                idx = i1 + i
                                if self.qcfg.desc_act:
                                    idx = perm[idx]

                                self.quantizer = groups[idx // self.qcfg.group_size]

                        q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                        Q1[:, i] = q
                        if Hinv is not None:
                            Losses1[:, i] = (w - q) ** 2 / d**2
                            err1 = (w - q) / d
                            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                            Err1[:, i] = err1

                Q[:, i1:i2] = Q1
                if Hinv is not None:
                    Losses[:, i1:i2] = (Losses1 / 2).to(Losses.dtype)

                    # Propagate this block's accumulated error to all later columns.
                    corr_out = Err1.matmul(Hinv[i1:i2, i2:].to(dtype_acc))
                    W[:, i2:] = (W[:, i2:].to(dtype_acc) - corr_out).to(W.dtype)
                    # W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:].float())
# NOTE NOTE NOTE BPDQ NOTE NOTE NOTE



        if Hinv is not None:
            del Hinv
            if self.nsamples != 0:
                avg_loss = torch.sum(Losses).item() / self.nsamples

                if math.isnan(avg_loss):
                    print("Losses sum item:", torch.sum(Losses).item())
                    if self.fail_safe:
                        log.info(f"Quantization: Failed due to `NaN` loss for `{self.name}`, use mock quantization retry for `{self.name}`")
                        # NOTE(review): with bpdq_flag set this retry cannot succeed,
                        # because the mock_quantization branch above raises RuntimeError for BPDQ.
                        self.qcfg.mock_quantization = True
                        return self.quantize(blocksize=blocksize)
                    else:
                        raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`, please try increasing calibration data samples or enable fail_safe=True")
            else:
                if self.fail_safe:
                    log.warn(f"Quantization: Module `{self.name}` -> using fail safe mode. Please check if calibration data is sufficient.")
                else:
                    log.warn(f"Quantization: `{self.name}` is not activated due to model inference logic (MoE)")
                avg_loss = 999999999
        else:
            avg_loss = 999999999

        del Losses
        del self.H

# NOTE NOTE NOTE BPDQ NOTE NOTE NOTE

        # Reorder the per-group chunks from processing order back to original group
        # order and assemble the final tensors. This is the only place c / B are built,
        # so with bpdq_flag but without act_group_aware they stay None.
        if self.qcfg.bpdq_flag and self.qcfg.act_group_aware:
            inv_global_perm_list = invert_perm(global_perm).tolist()
            c_list = [c_list[i] for i in inv_global_perm_list]
            B_list = [B_list[i] for i in inv_global_perm_list]
            B = torch.cat(B_list, dim=0)
            c = torch.cat(c_list, dim=0).to(torch.float16)
            log.info(f"Final B dtype: {B.dtype}, c dtype: {c.dtype}")


        group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns
        if self.qcfg.static_groups and self.qcfg.desc_act:
            g_idx = [perm[i] // group_size for i in range(self.columns)]
        else:
            g_idx = [i // group_size for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)

        # Undo the column reordering applied before the sweep.
        if self.qcfg.desc_act:
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        elif self.qcfg.bpdq_flag and self.qcfg.act_group_aware:
            # BPDQ: only Q is un-permuted here; c / B were already reordered above and
            # scale / zero are replaced by placeholders before returning.
            inv_final = invert_perm(final_perm)
            Q = Q[:, inv_final]

        elif self.qcfg.act_group_aware:
            inv_final = invert_perm(final_perm)
            Q = Q[:, inv_final]
            inv_global_perm = invert_perm(global_perm)
            inv_global_perm_list = inv_global_perm.tolist()
            temp_scale = [scale[i] for i in inv_global_perm_list]
            scale = temp_scale
            temp_zero = [zero[i] for i in inv_global_perm_list]
            zero = temp_zero

        if self._tp_pad_cols:
            valid_cols = self._original_columns
            Q = Q[:, :valid_cols]
            g_idx = g_idx[:valid_cols]

        if isinstance(self.module, transformers.Conv1D):
            Q = Q.t()

        if Q.shape != self.module.weight.shape:
            Q = Q.reshape(self.module.weight.shape).to(self.module.weight.dtype)
        else:
            Q = Q.to(self.module.weight.dtype)

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)

        if self._tp_pad_cols:
            valid_cols = self._original_columns
            scale = self.truncate_last_dim(scale, valid_cols)
            zero = self.truncate_last_dim(zero, valid_cols)

        Q = Q.to(device=self.module.weight.data.device, non_blocking=False)
        duration = time.time() - start

        # self._debug_verify_lutgemm(Q, c, B, self.qcfg.group_size)
        if self.qcfg.bpdq_flag:
            # BPDQ: dequantization data lives in c / B; scale, zero and g_idx are placeholders.
            scale = torch.ones((Q.shape[0], 1), device=Q.device, dtype=torch.float16)
            zero  = torch.zeros((Q.shape[0], 1), device=Q.device, dtype=torch.float16)

            g_idx = torch.zeros_like(g_idx, dtype=torch.int32, device=Q.device)

            return Q, scale, zero, g_idx, duration, avg_loss, damp, self.nsamples, c, B
        else:
            return Q, scale, zero, g_idx, duration, avg_loss, damp, self.nsamples

# NOTE NOTE NOTE BPDQ NOTE NOTE NOTE







