# ls_ood_detect_cea/quantization/modules.py

import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict, Optional, Tuple, List, Union

try:
    from sklearn.cluster import KMeans
except ImportError:
    KMeans = None

# --- Helpers ---
def get_bit_config(quant_cfg, default=8):
    """Return ``(weight_bits, act_bits)`` from a quantization config mapping.

    ``'weight_bits'`` / ``'act_bits'`` take precedence; either one falls back to
    the shared ``'number_of_bits'`` entry, and finally to ``default``.
    """
    shared_bits = quant_cfg.get('number_of_bits', default)
    weight_bits = quant_cfg.get('weight_bits', shared_bits)
    act_bits = quant_cfg.get('act_bits', shared_bits)
    return weight_bits, act_bits

# --- LSQ Implementation (Updated from Reference) ---
class LSQ(torch.autograd.Function):
    """
    Learned Step-size Quantization (LSQ) as a custom autograd function.

    Forward: ``round(clamp(x / scale, q_min, q_max)) * scale``.

    Backward:
      * ``x`` gets a straight-through gradient inside the clamp range and zero
        outside it.
      * ``scale`` gets ``round(x/s) - x/s`` inside the range and the clamped
        level (``q_min`` / ``q_max``) outside it. This is reduced over every
        dimension along which ``scale`` was broadcast, then multiplied by
        ``1 / sqrt(x.numel() * q_max)``.

    ``scale`` may have any shape that broadcasts against ``x`` under PyTorch's
    trailing-aligned broadcasting rules, for example ``(1,)`` for per-tensor
    scales or ``(C,)`` / ``(C, 1)`` for per-channel scales.
    """

    @staticmethod
    def forward(
        ctx: Any,
        x: torch.Tensor,
        scale: torch.Tensor,
        q_min: int,
        q_max: int
    ) -> torch.Tensor:
        x_div_s = x / scale
        x_quant = torch.clamp(x_div_s, q_min, q_max).round()
        x_dequant = x_quant * scale

        ctx.save_for_backward(x, scale, x_quant)
        ctx.q_min, ctx.q_max = q_min, q_max
        return x_dequant

    @staticmethod
    def backward(
        ctx: Any,
        grad_output: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], ...]:
        x, scale, x_quant = ctx.saved_tensors
        q_min, q_max = ctx.q_min, ctx.q_max

        # Gradient for input 'x' (Straight-Through Estimator).
        x_div_s = x / scale
        in_range_mask = (x_div_s >= q_min) & (x_div_s <= q_max)
        grad_x = torch.where(
            in_range_mask, grad_output, torch.zeros_like(grad_output)
        )

        # Gradient for the learnable 'scale'. Out of range, x_quant is exactly
        # the clamped level q_min / q_max.
        grad_scale_map = torch.where(
            in_range_mask,
            grad_output * (x_quant - x_div_s),
            grad_output * x_quant,
        )

        # Undo broadcasting of `scale`. sum_to_size follows PyTorch's
        # trailing-aligned rules, so per-channel scales such as (C,) against
        # (B, L, C) are reduced correctly. A 0-dim scale sums everything.
        if scale.dim() == 0:
            grad_scale = grad_scale_map.sum()
        else:
            grad_scale = grad_scale_map.sum_to_size(scale.shape)

        # LSQ gradient normalization. max(q_max, 1) avoids a division by zero
        # for 1-bit signed quantization, where q_max == 0.
        grad_scale = grad_scale / math.sqrt(x.numel() * max(q_max, 1))

        return grad_x, grad_scale, None, None

# --- Enhanced Fake Quantizer (Updated from Reference) ---
class EnhancedFakeQuantizer(nn.Module):
    """
    Symmetric fake quantizer for weights or activations.

    Weight mode (``is_weight_quantizer=True``): the scale is recomputed on every
    forward from ``max|w|``, per output channel (dim 0) when ``per_channel``.
    Gradients pass through with a straight-through estimator (STE).

    Activation mode: a per-tensor scale is derived from running min/max
    statistics (first batch copied, then an EMA with ``observer_momentum``).
    When ``learnable_scale`` is set, the scale is an ``nn.Parameter`` trained
    through :class:`LSQ` in training mode. Until the observer has seen data
    (``initialized``), activations pass through unchanged.
    """
    def __init__(
        self,
        bits: int = 8,
        observer_momentum: float = 0.1,
        is_weight_quantizer: bool = False,
        per_channel: bool = False,
        num_channels: Optional[int] = None,
        learnable_scale: bool = False
    ):
        super().__init__()
        self.bits = bits
        self.is_weight_quantizer = is_weight_quantizer
        # Per-channel quantization is only supported for weights.
        self.per_channel = per_channel and self.is_weight_quantizer
        # Learnable (LSQ) scales are only used for activations. Weight scales
        # are always recomputed from the current weights in forward().
        self.learnable_scale = learnable_scale and not self.is_weight_quantizer

        if self.is_weight_quantizer:
            self.num_channels = num_channels if self.per_channel else 1
        else:  # Activation quantizer
            self.num_channels = 1
            self.observer_momentum = observer_momentum
            # Running min/max of observed activations.
            self.register_buffer('min_val', torch.full((1,), float('inf')))
            self.register_buffer('max_val', torch.full((1,), float('-inf')))

            if self.learnable_scale:
                self.scale = nn.Parameter(torch.ones(1))
            else:
                self.register_buffer('scale', torch.ones(1))

            # Set after the first observed batch. Until then forward() is a pass-through.
            self.register_buffer('initialized', torch.tensor(False, dtype=torch.bool))
            # When True, statistics are collected even in eval mode (PTQ calibration).
            self.calibration_mode = False

        # Compatibility flag. When False, training-mode observer updates stop;
        # calibration_mode still collects statistics.
        self.observer_enabled = True

    def disable_observer_update(self):
        """Stop training-mode observer/scale updates (calibration mode is unaffected)."""
        self.observer_enabled = False

    def _get_qmin_qmax(self) -> Tuple[int, int]:
        """Return the signed integer range ``[-2^(b-1), 2^(b-1) - 1]``."""
        q_min = -(2**(self.bits - 1))
        q_max = (2**(self.bits - 1)) - 1
        return q_min, q_max

    @torch.no_grad()
    def update_observer_stats(self, x: torch.Tensor) -> None:
        """Fold the global min/max of ``x`` into the running statistics (activations only)."""
        x_detached = x.detach().float()
        current_min = torch.min(x_detached)
        current_max = torch.max(x_detached)

        if not self.initialized.item():
            # The first batch initializes the statistics directly.
            self.min_val.copy_(current_min)
            self.max_val.copy_(current_max)
            self.initialized.fill_(True)
        else:
            # Exponential moving average with weight `observer_momentum` on the new batch.
            self.min_val.mul_(1 - self.observer_momentum).add_(
                current_min * self.observer_momentum
            )
            self.max_val.mul_(1 - self.observer_momentum).add_(
                current_max * self.observer_momentum
            )

    @torch.no_grad()
    def update_qparams(self) -> None:
        """Set the (non-learnable) scale to ``max(|min|, |max|) / q_max``."""
        if self.learnable_scale:
            return
        max_abs = torch.max(torch.abs(self.min_val), torch.abs(self.max_val))
        _, q_max = self._get_qmin_qmax()
        self.scale.data.copy_((max_abs / q_max).clamp(min=1e-8))

    @torch.no_grad()
    def init_learnable_scale(self) -> None:
        """Seed the learnable LSQ scale from observed statistics (no-op if none were observed)."""
        if not self.learnable_scale or not self.initialized.item():
            return
        max_abs = torch.max(torch.abs(self.min_val), torch.abs(self.max_val))
        _, q_max = self._get_qmin_qmax()
        self.scale.data.copy_((max_abs / q_max).clamp(min=1e-8))

    @torch.no_grad()
    def init_weight_scale(self, x):
        """Compatibility no-op kept for external callers.

        Weight scales are not stored. They are recomputed from the weights on
        every forward pass, so there is nothing to initialize here.
        """
        return None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fake-quantized ``x``; gradients use STE (or LSQ for learnable scales)."""
        if self.bits is None:
            return x

        q_min, q_max = self._get_qmin_qmax()

        if self.is_weight_quantizer:
            if self.per_channel:
                dims_to_reduce = tuple(range(1, x.ndim))
                max_abs = x.abs().amax(dim=dims_to_reduce, keepdim=True)
            else:
                max_abs = x.abs().max()

            current_scale = (max_abs / q_max).clamp(min=1e-8)
            x_dequant = torch.clamp(torch.round(x / current_scale), q_min, q_max) * current_scale
            # Straight-Through Estimator (STE) for gradients.
            return x + (x_dequant - x).detach()

        # Activation quantization.
        should_update = self.calibration_mode or (
            self.training and not self.learnable_scale and self.observer_enabled
        )
        if should_update:
            self.update_observer_stats(x)
            # Keep the scale in sync with the statistics it is derived from.
            # Without this, calibration in eval mode flips `initialized` but
            # leaves the placeholder scale of 1.0, so later calibration batches
            # are quantized with a meaningless scale.
            # (update_qparams is a no-op for learnable scales.)
            self.update_qparams()

        if not self.initialized.item():
            return x

        if self.learnable_scale and self.training:
            # Use the LSQ autograd function for training.
            return LSQ.apply(x, self.scale.clamp(min=1e-8), q_min, q_max)

        # Standard PTQ/QAT inference path. Quantization math runs in float32;
        # the result is cast back so the output dtype matches the input's.
        current_scale = self.scale.float()
        x_dequant = torch.clamp(torch.round(x / current_scale), q_min, q_max) * current_scale
        x_dequant = x_dequant.to(x.dtype)
        return x + (x_dequant - x).detach()

# --- Quantized Linear Layer (Updated) ---
class QuantizedLinearLayer(nn.Module):
    """
    Fake-quantized drop-in replacement for an existing ``nn.Linear``.

    Feature flags are parsed as substrings of ``quant_cfg['method']``
    (default ``'ptq'``):

    * ``lsq``      - learnable activation step size.
    * ``igq``      - instance-aware group activation quantizer (takes precedence over ``outlier``).
    * ``outlier``  - outlier-aware wrapper around the activation quantizer.
    * ``qwt``      - zero-initialized linear compensation branch applied to the quantized input.
    * ``rotation`` - random orthogonal rotation folded into the weights at construction.
    * ``lora``     - low-rank adapter on the un-rotated input; freezes the wrapped layer.

    Bit widths come from ``get_bit_config(quant_cfg)``. Attributes not found on
    this module (e.g. ``in_features``) are looked up on the wrapped layer.
    """
    def __init__(self, original_linear_layer: nn.Linear, quant_cfg: Dict[str, Any]):
        super().__init__()
        self.original_linear_layer = original_linear_layer
        self.quant_cfg = quant_cfg

        # --- CONFIGURATION FLAGS ---
        method = quant_cfg.get('method', 'ptq')
        is_lsq = 'lsq' in method
        is_lora = 'lora' in method
        self.is_qwt = 'qwt' in method
        is_igq = 'igq' in method
        is_rotation = 'rotation' in method
        is_outlier_aware = 'outlier' in method

        w_bits, a_bits = get_bit_config(quant_cfg)

        # --- WEIGHT QUANTIZER (per output channel) ---
        self.weight_quantizer = EnhancedFakeQuantizer(
            bits=w_bits,
            is_weight_quantizer=True,
            per_channel=True,
            num_channels=original_linear_layer.out_features,
            learnable_scale=is_lsq
        )

        # --- ACTIVATION QUANTIZER ---
        base_act_quant = EnhancedFakeQuantizer(
            bits=a_bits,
            is_weight_quantizer=False,
            learnable_scale=is_lsq
        )

        if is_igq:
            self.activation_quantizer = InstanceAwareGroupQuantizer(
                bits=a_bits, num_channels=original_linear_layer.in_features, num_groups=quant_cfg.get('num_groups', 8)
            )
        elif is_outlier_aware:
            self.activation_quantizer = OutlierAwareFakeQuantizer(
                base_act_quant, percentile=quant_cfg.get('outlier_percentile', 0.01)
            )
        else:
            self.activation_quantizer = base_act_quant

        # --- QwT SETUP (zero-initialized, so it starts as an identity no-op) ---
        if self.is_qwt:
            self.qwt_compensation = nn.Linear(
                original_linear_layer.in_features,
                original_linear_layer.out_features,
                bias=True
            )
            nn.init.zeros_(self.qwt_compensation.weight)
            nn.init.zeros_(self.qwt_compensation.bias)

        # Additive per-output correction, only applied after enable_bias_correction().
        self.register_buffer('bias_correction', torch.zeros(original_linear_layer.out_features))
        self.use_bias_correction = False
        # Optional per-input-channel divisor applied before activation quantization (set externally).
        self.register_buffer('smoothing_scale', None)

        # --- ROTATION ---
        self.is_rotation = is_rotation
        self.register_buffer('rotation_matrix', None)
        if self.is_rotation:
            self._init_rotation()

        # --- LoRA ---
        self.is_lora = is_lora
        if self.is_lora:
            self._init_lora(quant_cfg)

    # --- EXPOSE ORIGINAL WEIGHTS (Safe properties) ---
    @property
    def weight(self):
        """The wrapped layer's (possibly rotated) full-precision weight."""
        return self.original_linear_layer.weight

    @property
    def bias(self):
        """The wrapped layer's bias (may be None)."""
        return self.original_linear_layer.bias

    # --- SIGLIP COMPATIBILITY (only triggered on missing attrs/ops) ---

    def __getattr__(self, name):
        """Pass through attributes like 'in_features' to the inner layer."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If the wrapped layer itself is not registered (e.g. on a partially
            # constructed or unpickled object), looking it up again would recurse
            # forever. Surface the AttributeError instead.
            if name == 'original_linear_layer':
                raise
            return getattr(self.original_linear_layer, name)

    def __rmatmul__(self, x):
        """
        Handle ``x @ layer``.

        Mirrors the pre-processing and quantization of ``forward`` and returns
        ``quant(x) @ quant(W).T``. As with a bare matrix product, no bias,
        QwT compensation, bias correction or LoRA term is added.
        """
        # 1. Apply rotation (if active).
        if self.is_rotation and self.rotation_matrix is not None:
            rot = self.rotation_matrix.to(x.device, dtype=x.dtype)
            x = x @ rot

        # 2. Apply smoothing (if active).
        if self.smoothing_scale is not None:
            x = x / self.smoothing_scale

        # 3. Quantize activation (must match forward behavior).
        qx = self.activation_quantizer(x)

        # 4. Quantize weight.
        qw = self.weight_quantizer(self.original_linear_layer.weight)

        # 5. Use the *quantized* activation. Previously this returned
        #    `x @ qw.T`, silently skipping activation quantization.
        return qx @ qw.T

    # --- STANDARD METHODS ---

    def _init_rotation(self):
        """Fold a random orthogonal matrix Q into the weights: W <- W @ Q.

        forward() applies x <- x @ Q, so (x Q)(W Q)^T = x W^T in exact arithmetic.
        """
        in_f = self.original_linear_layer.in_features
        weight = self.original_linear_layer.weight
        device = weight.device
        dtype = weight.dtype
        R = torch.randn(in_f, in_f, device=device, dtype=dtype)
        q, _ = torch.linalg.qr(R)
        self.register_buffer('rotation_matrix', q.detach())
        with torch.no_grad():
            self.original_linear_layer.weight.data = self.original_linear_layer.weight.data @ self.rotation_matrix

    def _init_lora(self, quant_cfg):
        """Freeze the wrapped layer and create LoRA factors A (rank x in) and B (out x rank).

        B is zero-initialized, so the adapter contributes nothing until trained.
        """
        self.original_linear_layer.weight.requires_grad = False
        if self.original_linear_layer.bias is not None:
            self.original_linear_layer.bias.requires_grad = False
        lora_rank = quant_cfg.get('lora_rank', 4)
        lora_alpha = quant_cfg.get('lora_alpha', 4)
        self.lora_scaling = lora_alpha / lora_rank
        in_f = self.original_linear_layer.in_features
        out_f = self.original_linear_layer.out_features
        # torch.empty is the modern equivalent of the legacy torch.Tensor(*sizes)
        # constructor; both values are overwritten by the initializers below.
        self.lora_A = nn.Parameter(torch.empty(lora_rank, in_f))
        self.lora_B = nn.Parameter(torch.empty(out_f, lora_rank))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def enable_bias_correction(self):
        """Start adding the `bias_correction` buffer to the output."""
        self.use_bias_correction = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantized linear forward: rotate -> smooth -> quantize -> linear -> corrections."""
        # LoRA operates on the original, un-rotated/un-smoothed input.
        input_for_lora = x

        # 1. Pre-processing.
        if self.is_rotation and self.rotation_matrix is not None:
            rot = self.rotation_matrix.to(x.device, dtype=x.dtype)
            x = x @ rot
        if self.smoothing_scale is not None:
            x = x / self.smoothing_scale

        # 2. Quantization.
        qx = self.activation_quantizer(x)
        qw = self.weight_quantizer(self.original_linear_layer.weight)

        # 3. Main linear op.
        main_path = F.linear(qx, qw, self.original_linear_layer.bias)

        # 4. Post-processing.
        if self.is_qwt:
            main_path = main_path + self.qwt_compensation(qx)
        if self.use_bias_correction:
            main_path = main_path + self.bias_correction.to(main_path.device)
        if self.is_lora:
            lora_path = (input_for_lora @ self.lora_A.T @ self.lora_B.T) * self.lora_scaling
            return main_path + lora_path

        return main_path

class QuantizedConv2d(nn.Module):
    """Fake-quantized wrapper around an existing ``nn.Conv2d``.

    Weights are quantized per output channel and inputs per tensor, both via
    ``EnhancedFakeQuantizer``. A learnable activation step size is used when
    ``quant_cfg['method']`` contains ``'lsq'``. The convolution itself is
    delegated to the wrapped layer's ``_conv_forward``.
    """

    def __init__(self, original_conv_layer, quant_cfg):
        super().__init__()
        self.original_conv_layer = original_conv_layer
        self.quant_cfg = quant_cfg

        w_bits, a_bits = get_bit_config(quant_cfg)
        use_lsq = 'lsq' in quant_cfg.get('method', '')

        self.weight_quantizer = EnhancedFakeQuantizer(
            bits=w_bits,
            is_weight_quantizer=True,
            per_channel=True,
            num_channels=original_conv_layer.out_channels,
            learnable_scale=use_lsq,
        )
        self.activation_quantizer = EnhancedFakeQuantizer(
            bits=a_bits,
            is_weight_quantizer=False,
            learnable_scale=use_lsq,
        )

    # External code may read .weight / .bias as if this were a plain conv.
    @property
    def weight(self):
        """Full-precision weight of the wrapped convolution."""
        return self.original_conv_layer.weight

    @property
    def bias(self):
        """Bias of the wrapped convolution (may be None)."""
        return self.original_conv_layer.bias

    def forward(self, x):
        """Convolve the fake-quantized input with the fake-quantized weight."""
        conv = self.original_conv_layer
        quantized_input = self.activation_quantizer(x)
        quantized_weight = self.weight_quantizer(conv.weight)
        return conv._conv_forward(quantized_input, quantized_weight, conv.bias)

class QuantizableMultiheadAttention(nn.Module):
    """
    Multi-head attention re-expressed with four separate ``nn.Linear``
    projections, so each projection can later be swapped for a quantized layer.

    Inputs are treated as ``(batch, seq, embed)``. Query and key/value lengths
    may differ (cross-attention). Returns ``(output, None)`` to mimic the
    ``nn.MultiheadAttention`` return tuple.
    """
    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.q_proj = nn.Linear(n_embd, n_embd)
        self.k_proj = nn.Linear(n_embd, n_embd)
        self.v_proj = nn.Linear(n_embd, n_embd)
        self.out_proj = nn.Linear(n_embd, n_embd)
        # Projected Q/K from the last forward, captured for distillation (DGD) hooks.
        self.q_for_dgd = None
        self.k_for_dgd = None

    @classmethod
    def from_multihead_attention(cls, original_mha: nn.MultiheadAttention) -> 'QuantizableMultiheadAttention':
        """Build an instance carrying the weights of ``original_mha``.

        Handles both the packed ``in_proj_weight`` layout and separate
        ``q/k/v_proj_weight`` tensors (CoCa / some OpenCLIP models). Missing
        biases are zero-filled.
        """
        n_embd, n_head = original_mha.embed_dim, original_mha.num_heads
        new_mha = cls(n_embd, n_head)

        q_w, k_w, v_w = None, None, None
        q_b, k_b, v_b = None, None, None

        if original_mha.in_proj_weight is not None:
            q_w, k_w, v_w = original_mha.in_proj_weight.chunk(3)
        else:
            # Fallback for separate weights (CoCa / PyTorch with separate inputs).
            q_w = getattr(original_mha, 'q_proj_weight', None)
            k_w = getattr(original_mha, 'k_proj_weight', None)
            v_w = getattr(original_mha, 'v_proj_weight', None)

        if original_mha.in_proj_bias is not None:
            q_b, k_b, v_b = original_mha.in_proj_bias.chunk(3)
        else:
            # Fallback for separate biases.
            q_b = getattr(original_mha, 'q_proj_bias', None)
            k_b = getattr(original_mha, 'k_proj_bias', None)
            v_b = getattr(original_mha, 'v_proj_bias', None)

        # Assign weights.
        if q_w is not None: new_mha.q_proj.weight.data.copy_(q_w)
        if k_w is not None: new_mha.k_proj.weight.data.copy_(k_w)
        if v_w is not None: new_mha.v_proj.weight.data.copy_(v_w)

        # Assign biases (zero when absent).
        if q_b is not None: new_mha.q_proj.bias.data.copy_(q_b)
        else: new_mha.q_proj.bias.data.zero_()

        if k_b is not None: new_mha.k_proj.bias.data.copy_(k_b)
        else: new_mha.k_proj.bias.data.zero_()

        if v_b is not None: new_mha.v_proj.bias.data.copy_(v_b)
        else: new_mha.v_proj.bias.data.zero_()

        new_mha.out_proj.weight.data.copy_(original_mha.out_proj.weight)
        if original_mha.out_proj.bias is not None:
            new_mha.out_proj.bias.data.copy_(original_mha.out_proj.bias)
        else:
            new_mha.out_proj.bias.data.zero_()

        return new_mha

    @staticmethod
    def _to_additive_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Convert a mask to an additive float mask.

        A boolean mask follows ``nn.MultiheadAttention`` semantics (True means
        "not allowed" and becomes -inf). Any other mask is treated as additive.
        """
        if mask.dtype == torch.bool:
            return torch.zeros(mask.shape, dtype=dtype, device=mask.device).masked_fill(mask, float('-inf'))
        return mask.to(dtype)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, **kwargs: Any) -> Tuple[torch.Tensor, None]:
        """Scaled dot-product attention over ``(B, L, C)`` queries and ``(B, S, C)`` keys/values.

        ``attn_mask`` may be ``(L, S)``, anything broadcastable to
        ``(B, H, L, S)``, or a CoCa-style fused ``(B*H, L, S)`` mask.
        An optional ``key_padding_mask`` of shape ``(B, S)`` is read from
        ``kwargs``. Other kwargs (e.g. ``need_weights``) are ignored.
        """
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Capture for teacher DGD hooks (distillation).
        self.q_for_dgd = q
        self.k_for_dgd = k

        B, L, C = q.shape
        # Key/value length may differ from the query length (cross-attention).
        S = k.shape[1]
        head_dim = C // self.n_head
        q = q.view(B, L, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, S, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_head, head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

        if attn_mask is not None:
            # CoCa-style fused masks are (B*H, L, S) and would not broadcast
            # against (B, H, L, S). Standard (L, S) / (B, 1, L, S) masks skip this.
            if attn_mask.dim() == 3 and attn_mask.shape[0] == B * self.n_head:
                attn_mask = attn_mask.reshape(B, self.n_head, attn_mask.shape[1], attn_mask.shape[2])
            scores = scores + self._to_additive_mask(attn_mask, scores.dtype)

        key_padding_mask = kwargs.get('key_padding_mask')
        if key_padding_mask is not None:
            # (B, S): mask whole key positions for every head and query.
            scores = scores + self._to_additive_mask(key_padding_mask, scores.dtype).reshape(B, 1, 1, S)

        y = torch.matmul(F.softmax(scores, dim=-1), v)
        y = y.transpose(1, 2).reshape(B, L, C)

        return self.out_proj(y), None

# --- Existing Experimental Modules (Kept for compatibility) ---
class PowerOfTwoFakeQuantizer(nn.Module):
    """Fake-quantize inputs to the nearest power-of-two level.

    Levels are ``2^0, 2^-1, ..., 2^-(2^bits - 1)`` (for ``bits >= 8`` the
    smallest of these underflow to 0 in float32). Inputs are first clamped to
    ``[1e-9, 1]``. On an exact tie between two levels, the larger level wins.
    Gradients pass straight through (STE).
    """
    def __init__(self, bits=4):
        super().__init__()
        self.bits = bits
        num_levels = 2 ** bits
        self.register_buffer('pot_levels', torch.tensor([2**(-i) for i in range(num_levels)]))

    def forward(self, x):
        levels = self.pot_levels.to(x.device)
        # Compare in the promoted dtype, exactly as |x - level| would be computed.
        compute_dtype = torch.promote_types(x.dtype, levels.dtype)
        x_clamped = x.clamp(1e-9, 1.0).to(compute_dtype)

        # Strictly increasing grid (duplicate underflowed zeros removed) for bucketize.
        grid = torch.unique(levels).to(compute_dtype)

        # Locate the two neighbouring levels instead of building an
        # (numel x num_levels) distance matrix, which costs 2^bits times the
        # memory of x. bucketize (right=False) gives grid[i-1] < x <= grid[i].
        upper_idx = torch.bucketize(x_clamped, grid).clamp(max=grid.numel() - 1)
        lower_idx = (upper_idx - 1).clamp(min=0)
        upper = grid[upper_idx]
        lower = grid[lower_idx]

        # Strict '<' resolves ties toward the larger level.
        x_quant = torch.where((x_clamped - lower).abs() < (upper - x_clamped).abs(), lower, upper)
        return x + (x_quant - x).detach()

class OutlierAwareFakeQuantizer(nn.Module):
    """Wrap a quantizer so the largest-magnitude ``percentile`` fraction of
    values bypasses quantization at inference time.

    In training mode the base quantizer is applied unchanged. In eval mode the
    top ``int(numel * percentile)`` values by ``|x|`` are returned in the
    input's original precision, and everything else comes from
    ``base_quantizer``.
    """
    def __init__(self, base_quantizer, percentile=0.01):
        super().__init__()
        self.base_quantizer = base_quantizer
        self.percentile = percentile

    def forward(self, x):
        if self.training:
            return self.base_quantizer(x)

        numel = x.numel()
        k_total = int(numel * self.percentile)
        if k_total <= 0:
            return self.base_quantizer(x)

        target_sample_size = 5000
        if numel > target_sample_size:
            # Strided sampling (~5000 elements) gives an approximate threshold
            # without sorting the whole tensor. .flatten() (not .view) also
            # handles non-contiguous inputs.
            step = numel // target_sample_size
            sample = x.flatten()[::step].abs()
            k = max(int(sample.numel() * self.percentile), 1)
        else:
            # Small tensors: exact threshold over all elements.
            sample = x.abs().flatten()
            k = k_total
        k = min(k, sample.numel())  # guard percentile > 1

        # k-th largest |x| (kthvalue counts from the smallest).
        threshold = torch.kthvalue(sample, sample.numel() - k + 1).values

        # '>=' keeps the k largest values. A strict '>' against the k-th largest
        # kept only k-1, and none at all when k == 1.
        mask_outlier = x.abs() >= threshold

        x_quant = self.base_quantizer(x)

        # Outliers keep their original value and precision; the rest is quantized.
        return torch.where(mask_outlier, x, x_quant)

class InstanceAwareGroupQuantizer(nn.Module):
    """Symmetric activation quantizer with one scale per *group* of channels.

    ``calibrate`` clusters channels by their observed activation maxima with
    k-means (scikit-learn) and gives each group the scale ``max / q_max``.
    ``forward`` broadcasts the per-channel scales over the last dimension of
    the input. Until calibrated, inputs pass through unchanged.
    """
    def __init__(self, bits=8, num_channels=None, num_groups=8):
        super().__init__()
        self.bits = bits
        self.num_channels = num_channels
        self.num_groups = num_groups
        self.register_buffer('group_scales', torch.zeros(self.num_groups))
        self.register_buffer('channel_groups', torch.zeros(self.num_channels, dtype=torch.long))
        self.register_buffer('calibrated', torch.tensor(False))

    def _get_qmin_qmax(self):
        """Signed integer range for ``self.bits``."""
        half_range = 2 ** (self.bits - 1)
        return -half_range, half_range - 1

    @torch.no_grad()
    def calibrate(self, act_maxes):
        """Assign channels to groups and compute group scales from ``act_maxes``.

        Silently returns if ``act_maxes`` does not have ``num_channels`` elements.
        """
        if KMeans is None:
            raise ImportError("scikit-learn is required for IGQ.")
        if act_maxes.numel() != self.num_channels:
            return

        features = act_maxes.cpu().numpy().reshape(-1, 1)
        cluster_count = min(self.num_groups, len(features))
        clustering = KMeans(n_clusters=cluster_count, random_state=0, n_init=10).fit(features)
        self.channel_groups.copy_(torch.from_numpy(clustering.labels_).long())

        q_max = self._get_qmin_qmax()[1]
        groups_on_cpu = self.channel_groups.cpu()
        for group_id in range(self.num_groups):
            members = groups_on_cpu == group_id
            if not torch.any(members):
                continue  # fewer clusters than groups: leave this scale untouched
            self.group_scales[group_id] = (act_maxes[members].max() / q_max).clamp(min=1e-8)
        self.calibrated.fill_(True)

    def forward(self, x):
        if self.bits is None or not self.calibrated:
            return x
        q_min, q_max = self._get_qmin_qmax()
        channel_scales = self.group_scales[self.channel_groups].to(x.device)
        # Scales line up with the last (channel) dimension of x.
        channel_scales = channel_scales.reshape(*([1] * (x.ndim - 1)), -1)
        x_dequant = (x / channel_scales).round().clamp(q_min, q_max) * channel_scales
        return x + (x_dequant - x).detach()

class IRM(nn.Module):
    """Standardize a tensor with its global mean/std plus a learnable shift and scale.

    Computes ``(x - mean(x) + beta) / (gamma * (std(x) + eps))``, where mean
    and (unbiased) std are taken over *all* elements of ``x``.
    """
    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(0.0))
        self.eps = 1e-6

    def forward(self, x):
        shifted = x - x.mean() + self.beta
        denominator = self.gamma * (x.std() + self.eps)
        return shifted / denominator

class QViTMultiheadAttention(nn.Module):
    """
    Fully fake-quantized multi-head attention built from an ``nn.MultiheadAttention``.

    Quantizes the Q/K/V projection weights (per channel) and inputs, applies
    ``IRM`` to Q and K, quantizes Q/K/V activations and the softmax scores
    (power-of-two levels when the method contains ``'apq'``), and routes the
    output through a ``QuantizedLinearLayer``. Inputs are treated as
    ``(batch, seq, embed)``. Query and key/value lengths may differ.
    """
    def __init__(self, original_mha, quant_cfg):
        super().__init__()
        self.n_embd = original_mha.embed_dim
        self.n_head = original_mha.num_heads

        # --- ROBUST WEIGHT LOADING FOR CoCa / OpenCLIP (packed or separate Q/K/V) ---
        q_w, k_w, v_w = None, None, None
        q_b, k_b, v_b = None, None, None

        if original_mha.in_proj_weight is not None:
            q_w, k_w, v_w = original_mha.in_proj_weight.chunk(3)
        else:
            q_w = getattr(original_mha, 'q_proj_weight', None)
            k_w = getattr(original_mha, 'k_proj_weight', None)
            v_w = getattr(original_mha, 'v_proj_weight', None)

        if original_mha.in_proj_bias is not None:
            q_b, k_b, v_b = original_mha.in_proj_bias.chunk(3)
        else:
            q_b = getattr(original_mha, 'q_proj_bias', None)
            k_b = getattr(original_mha, 'k_proj_bias', None)
            v_b = getattr(original_mha, 'v_proj_bias', None)

        self.q_proj_orig = nn.Linear(self.n_embd, self.n_embd)
        self.k_proj_orig = nn.Linear(self.n_embd, self.n_embd)
        self.v_proj_orig = nn.Linear(self.n_embd, self.n_embd)

        # Load weights/biases; missing biases become zero.
        if q_w is not None: self.q_proj_orig.weight.data.copy_(q_w)
        if q_b is not None: self.q_proj_orig.bias.data.copy_(q_b)
        else: self.q_proj_orig.bias.data.zero_()

        if k_w is not None: self.k_proj_orig.weight.data.copy_(k_w)
        if k_b is not None: self.k_proj_orig.bias.data.copy_(k_b)
        else: self.k_proj_orig.bias.data.zero_()

        if v_w is not None: self.v_proj_orig.weight.data.copy_(v_w)
        if v_b is not None: self.v_proj_orig.bias.data.copy_(v_b)
        else: self.v_proj_orig.bias.data.zero_()

        w_bits, a_bits = get_bit_config(quant_cfg)

        self.q_weight_quantizer = EnhancedFakeQuantizer(bits=w_bits, is_weight_quantizer=True, per_channel=True, num_channels=self.n_embd)
        self.k_weight_quantizer = EnhancedFakeQuantizer(bits=w_bits, is_weight_quantizer=True, per_channel=True, num_channels=self.n_embd)
        self.v_weight_quantizer = EnhancedFakeQuantizer(bits=w_bits, is_weight_quantizer=True, per_channel=True, num_channels=self.n_embd)

        self.q_in_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)
        self.k_in_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)
        self.v_in_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)

        self.irm_q = IRM()
        self.irm_k = IRM()

        self.q_act_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)
        self.k_act_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)
        self.v_act_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)

        is_apq = 'apq' in quant_cfg.get('method', '')
        if is_apq:
            self.attn_score_quantizer = PowerOfTwoFakeQuantizer(bits=a_bits)
        else:
            self.attn_score_quantizer = EnhancedFakeQuantizer(bits=a_bits, is_weight_quantizer=False)

        self.out_proj = QuantizedLinearLayer(copy.deepcopy(original_mha.out_proj), quant_cfg)
        # Projected Q/K from the last forward, captured for distillation (DGD) hooks.
        self.q_for_dgd = None
        self.k_for_dgd = None

    @staticmethod
    def _to_additive_mask(mask, dtype):
        """Convert a mask to an additive float mask.

        A boolean mask follows ``nn.MultiheadAttention`` semantics (True means
        "not allowed" and becomes -inf). Any other mask is treated as additive.
        """
        if mask.dtype == torch.bool:
            return torch.zeros(mask.shape, dtype=dtype, device=mask.device).masked_fill(mask, float('-inf'))
        return mask.to(dtype)

    def forward(self, query, key, value, attn_mask=None, **kwargs):
        """Quantized attention over ``(B, L, C)`` queries and ``(B, S, C)`` keys/values.

        ``attn_mask`` may be ``(L, S)``, broadcastable to ``(B, H, L, S)``, or a
        CoCa-style fused ``(B*H, L, S)`` mask. An optional ``key_padding_mask``
        of shape ``(B, S)`` is read from ``kwargs``. Returns ``(output, None)``.
        """
        qw_q = self.q_weight_quantizer(self.q_proj_orig.weight)
        qw_k = self.k_weight_quantizer(self.k_proj_orig.weight)
        qw_v = self.v_weight_quantizer(self.v_proj_orig.weight)

        q_in = self.q_in_quantizer(query)
        k_in = self.k_in_quantizer(key)
        v_in = self.v_in_quantizer(value)

        q = F.linear(q_in, qw_q, self.q_proj_orig.bias)
        k = F.linear(k_in, qw_k, self.k_proj_orig.bias)
        v = F.linear(v_in, qw_v, self.v_proj_orig.bias)

        # Capture projected Q/K for DGD loss (post-linear, pre-attention).
        self.q_for_dgd = q
        self.k_for_dgd = k

        q_rect = self.irm_q(q)
        k_rect = self.irm_k(k)

        qq = self.q_act_quantizer(q_rect)
        qk = self.k_act_quantizer(k_rect)
        qv = self.v_act_quantizer(v)

        B, L, C = qq.shape
        # Key/value length may differ from the query length (cross-attention).
        S = qk.shape[1]
        head_dim = C // self.n_head
        qq = qq.reshape(B, L, self.n_head, head_dim).transpose(1, 2)
        qk = qk.reshape(B, S, self.n_head, head_dim).transpose(1, 2)
        qv = qv.reshape(B, S, self.n_head, head_dim).transpose(1, 2)

        scores = torch.matmul(qq, qk.transpose(-2, -1)) / math.sqrt(head_dim)
        if attn_mask is not None:
            # Unfold CoCa-style fused (B*H, L, S) masks so they broadcast against (B, H, L, S).
            if attn_mask.dim() == 3 and attn_mask.shape[0] == B * self.n_head:
                attn_mask = attn_mask.reshape(B, self.n_head, attn_mask.shape[1], attn_mask.shape[2])
            scores = scores + self._to_additive_mask(attn_mask, scores.dtype)

        key_padding_mask = kwargs.get('key_padding_mask')
        if key_padding_mask is not None:
            # (B, S): mask whole key positions for every head and query.
            scores = scores + self._to_additive_mask(key_padding_mask, scores.dtype).reshape(B, 1, 1, S)

        quantized_softmax_scores = self.attn_score_quantizer(F.softmax(scores, dim=-1))

        y = torch.matmul(quantized_softmax_scores, qv)
        y = y.transpose(1, 2).reshape(B, L, C)

        return self.out_proj(y), None