# ls_ood_detect_cea/quantization/utils.py

import copy
import logging
import torch.nn as nn
from torch.nn import MultiheadAttention
from .modules import QuantizedLinearLayer, QuantizedConv2d, QuantizableMultiheadAttention

MODULE_LOGGER = logging.getLogger(__name__)

def replace_layers_with_quantized_versions(model, quantize_text=False, exclude_patterns=None, **quant_cfg):
    """Return a deep copy of ``model`` with selected layers swapped for quantized versions.

    The input ``model`` itself is never modified.

    Modules are selected by substring match on their qualified name
    (as reported by ``named_modules()``):

    * ``nn.MultiheadAttention`` whose name contains ``'visual'`` (or
      ``'transformer'`` / ``'text'`` when ``quantize_text`` is True) is
      replaced by ``QuantizableMultiheadAttention.from_multihead_attention``.
    * ``nn.Linear`` / ``nn.Conv2d`` whose name contains ``'visual'`` or
      ``'backbone'`` (or ``'transformer'`` / ``'text'`` / ``'bert'`` when
      ``quantize_text`` is True) is wrapped in ``QuantizedLinearLayer`` /
      ``QuantizedConv2d``. This pass runs after the attention pass, so any
      Linear/Conv2d submodules of the new attention modules (presumably the
      split q/k/v projections) are candidates too.

    Linear children of an ``nn.MultiheadAttention`` that was left native
    (e.g. its ``out_proj``) are NOT wrapped. The native forward reads
    ``out_proj.weight`` / ``out_proj.bias`` directly, so a wrapper there
    would be bypassed.

    Args:
        model: Source ``nn.Module``; deep-copied, never mutated.
        quantize_text: Also quantize modules whose names look like
            text-encoder modules (see the name filters above).
        exclude_patterns: Iterable of substrings. Any module whose qualified
            name contains one of them is left untouched. ``None`` means no
            exclusions.
        **quant_cfg: Collected into a dict and passed positionally to the
            quantized layer constructors.

    Returns:
        The modified deep copy of ``model``.
    """
    if exclude_patterns is None:
        exclude_patterns = []

    q_model = copy.deepcopy(model)

    # 1. Handle MultiheadAttention.
    # Record the MHAs that stay native so step 2 can avoid wrapping their
    # out_proj (see docstring).
    native_mha_names = set()
    for name, module in model.named_modules():
        if not isinstance(module, MultiheadAttention):
            continue

        is_visual = 'visual' in name
        is_text = 'transformer' in name or 'text' in name
        should_quantize = is_visual or (quantize_text and is_text)

        if _should_skip(name, exclude_patterns) or not should_quantize:
            native_mha_names.add(name)
            continue

        # Build the replacement from the deep copy, not from the caller's
        # module, so the returned model can never alias tensors of ``model``.
        copied_mha = q_model.get_submodule(name)
        # Replace with Quantizable version (FP32 but split q/k/v for QViT/APQ)
        _replace_module(q_model, name, QuantizableMultiheadAttention.from_multihead_attention(copied_mha))

    # 2. Collect Linear and Conv2d layers first, then replace them, so the
    # module tree is not mutated while named_modules() is iterating it.
    modules_to_replace = {}
    for name, mod in q_model.named_modules():
        if _should_skip(name, exclude_patterns):
            MODULE_LOGGER.debug("Skipping quantization for: %s", name)
            continue

        if not isinstance(mod, (nn.Linear, nn.Conv2d)):
            continue

        parent_name = name.rpartition('.')[0]
        if parent_name in native_mha_names:
            # Wrapping would be ineffective: native MHA bypasses out_proj.forward.
            MODULE_LOGGER.debug("Not quantizing %s: parent is a native nn.MultiheadAttention", name)
            continue

        is_visual = 'visual' in name or 'backbone' in name
        is_text = 'transformer' in name or 'text' in name or 'bert' in name

        if is_visual or (quantize_text and is_text):
            modules_to_replace[name] = mod

    for name, module in modules_to_replace.items():
        if isinstance(module, nn.Linear):
            quant_layer = QuantizedLinearLayer(copy.deepcopy(module), quant_cfg)
        else:
            # Only nn.Conv2d can reach here (see the isinstance filter above).
            # Rotation PTQ is usually not applicable to Conv2d in this context,
            # so fall back to the standard QuantizedConv2d.
            quant_layer = QuantizedConv2d(copy.deepcopy(module), quant_cfg)

        _replace_module(q_model, name, quant_layer)

    return q_model

def _should_skip(name, patterns):
    """Tell whether ``name`` contains any of the substrings in ``patterns``.

    Patterns are tested in iteration order; the scan stops at the first hit.
    An empty ``patterns`` never matches.
    """
    return any(pattern in name for pattern in patterns)

def _replace_module(root_model, path, new_module):
    """Install ``new_module`` at the dotted ``path`` inside ``root_model``.

    Args:
        root_model: Module tree that is modified in place.
        path: Qualified submodule name as produced by ``named_modules()``.
            May be a top-level child name with no dot (e.g. ``"fc"``).
        new_module: Module to assign at ``path``.

    Raises:
        ValueError: If ``path`` is empty; the root cannot be replaced in place.
        AttributeError: If the parent path does not exist (raised by
            ``get_submodule``).
    """
    if not path:
        raise ValueError("cannot replace the root module in place; path is empty")
    # rpartition yields parent_name == '' for top-level children (where
    # rsplit-unpacking would fail); get_submodule('') returns root_model itself.
    parent_name, _, child_name = path.rpartition('.')
    parent = root_model.get_submodule(parent_name)
    setattr(parent, child_name, new_module)