# ls_ood_detect_cea/quantization/apply.py

import torch
import copy
import logging
import torch.nn as nn
from tqdm import tqdm

from .utils import replace_layers_with_quantized_versions
from .modules import InstanceAwareGroupQuantizer, QuantizedLinearLayer, QuantizedConv2d, EnhancedFakeQuantizer
from .calibration import calibrate_model, get_smoothing_stats, collect_activation_stats, calibrate_qwt 
from .training import run_contrastive_finetuning_loop, run_qvit_finetuning_loop, run_cosine_matching_loop, run_qvlm_veo_loop
import scripts.config as config

MODULE_LOGGER = logging.getLogger(__name__)

# Per-architecture module-name patterns passed as ``exclude_patterns`` to the
# layer-replacement step. A key applies when it is a substring of the
# lower-cased architecture name; several keys may match one name.
EXCLUSION_ZOO = {
    'siglip': ['visual.head', 'text.head', 'logit_bias', 'visual.proj',
               'text.proj', 'text_projection', 'text.output'],
    'convnext': ['visual.head', 'visual.proj'],
    'resnet': ['visual.attnpool', 'visual.proj'],
    # CoCa: also exclude the attention pooler and the text decoder.
    'coca': ['attn_pool', 'visual.proj', 'text.proj', 'text_decoder', 'text.decoder'],
    'open_clip': ['visual.proj', 'text.proj'],
}

# Projection patterns excluded for every architecture.
_DEFAULT_EXCLUSIONS = ('visual.proj', 'text.proj')


def get_exclusions(model_arch_name):
    """Return the exclusion patterns for ``model_arch_name``.

    Every ``EXCLUSION_ZOO`` entry whose key occurs in the lower-cased
    architecture name contributes its patterns; ``visual.proj`` and
    ``text.proj`` are always included.

    Args:
        model_arch_name: Architecture string (callers in this module pass
            ``config.MODEL_CONFIG['arch']``). ``None`` is treated as ``""``.

    Returns:
        De-duplicated list of patterns in first-seen order. The order is
        deterministic across runs (``list(set(...))`` depends on string
        hash randomization).
    """
    arch = (model_arch_name or "").lower()
    patterns = []
    for key, values in EXCLUSION_ZOO.items():
        if key in arch:
            patterns.extend(values)
    patterns.extend(_DEFAULT_EXCLUSIONS)
    # dict preserves insertion order, giving an ordered de-duplication.
    return list(dict.fromkeys(patterns))

def _cache_calibration_data(dataloader, num_batches=50):
    """Materialize up to ``num_batches`` batches from ``dataloader``.

    Each cached item is ``(batch[0].cpu(), batch[1])``: the first batch
    element is moved to CPU and the second is stored as-is.

    Args:
        dataloader: Iterable of indexable batches, or ``None``.
        num_batches: Maximum number of batches to cache.

    Returns:
        List of cached batches; empty when ``dataloader`` is ``None``. If
        iteration raises, the error is logged with its traceback and the
        batches cached so far are returned (deliberate best effort).
    """
    from itertools import islice

    cache = []
    if dataloader is None:
        return cache
    MODULE_LOGGER.info(f"Caching {num_batches} batches of calibration data...")
    try:
        # islice stops after exactly ``num_batches`` items, so the loader is
        # never asked to produce a batch beyond the limit.
        limited = islice(dataloader, max(num_batches, 0))
        for batch in tqdm(limited, desc="Caching Data", total=num_batches, leave=False):
            cache.append((batch[0].cpu(), batch[1]))
    except Exception as e:
        # Keep the partial cache, but record the full traceback.
        MODULE_LOGGER.exception(f"Error during data caching: {e}")
    if not cache:
        MODULE_LOGGER.warning("No calibration batches were cached.")
    return cache

def _smoothing_target_weight(module):
    """Return the weight tensor a SmoothQuant scale is folded into, or None if unsupported."""
    if module is None:
        return None
    if hasattr(module, "original_linear_layer"):
        # Quantized wrapper: fold into the wrapped full-precision layer.
        return module.original_linear_layer.weight
    if isinstance(module, nn.Linear):
        return module.weight
    return None


def _fold_smoothing_scale(norm, consumers, scale, name):
    """Divide ``norm``'s affine params by ``scale`` and multiply consumers' input columns by it.

    Does nothing (and returns False) unless the norm has a weight and every
    consumer exposes a foldable weight, so a block is never left half-smoothed.
    """
    weights = [_smoothing_target_weight(m) for m in consumers]
    if getattr(norm, "weight", None) is None or any(w is None for w in weights):
        MODULE_LOGGER.warning(
            f"SmoothQuant: skipping '{name}' - norm or a consuming projection "
            f"has no foldable weight."
        )
        return False
    s = scale.to(norm.weight.device).view(-1)
    norm.weight.div_(s)
    if norm.bias is not None:
        norm.bias.div_(s)
    for w in weights:
        # Weight is (out, in): scale each input column so linear(norm(x)) is unchanged.
        w.mul_(scale.to(w.device).view(1, -1))
    return True


def _apply_smoothing_to_model(model, scales_map):
    """Fold SmoothQuant scales from ``scales_map`` into ``model`` in place.

    For each module with ``ln_2`` + ``mlp`` whose ``"<block>.mlp.c_fc"`` key
    is in ``scales_map``, ``ln_2`` is divided by the scale and ``c_fc``'s input
    columns are multiplied by it. Likewise, for ``ln_1`` + ``attn`` with key
    ``"<block>.attn.q_proj"``, the single scale is folded into ``q_proj``,
    ``k_proj`` and ``v_proj`` (all three read the ``ln_1`` output).

    Quantized wrappers (``original_linear_layer``) and plain ``nn.Linear``
    layers are both compensated. A block whose consumers cannot all be
    compensated is skipped with a warning rather than rescaled on one side.

    Args:
        model: Module to modify in place.
        scales_map: Mapping from linear-layer name to a scale tensor
            (presumably 1-D with one entry per input channel -- TODO confirm
            against ``get_smoothing_stats``).

    Returns:
        ``model``.
    """
    with torch.no_grad():
        for block_name, block in model.named_modules():
            if hasattr(block, "ln_2") and hasattr(block, "mlp"):
                fc_name = f"{block_name}.mlp.c_fc"
                if fc_name in scales_map:
                    fc = getattr(block.mlp, "c_fc", None)
                    _fold_smoothing_scale(block.ln_2, [fc], scales_map[fc_name], fc_name)

            if hasattr(block, "ln_1") and hasattr(block, "attn"):
                q_name = f"{block_name}.attn.q_proj"
                if q_name in scales_map:
                    projections = [
                        getattr(block.attn, attr, None)
                        for attr in ("q_proj", "k_proj", "v_proj")
                    ]
                    _fold_smoothing_scale(block.ln_1, projections, scales_map[q_name], q_name)
    return model

def _freeze_observers(model):
    """Disable observer updates on every activation fake-quantizer in ``model``.

    Only ``EnhancedFakeQuantizer`` instances whose ``is_weight_quantizer`` is
    falsy are touched. The number frozen is logged and ``model`` is returned.
    """
    candidates = (m for m in model.modules() if isinstance(m, EnhancedFakeQuantizer))
    frozen = 0
    for quantizer in candidates:
        if quantizer.is_weight_quantizer:
            continue
        quantizer.disable_observer_update()
        frozen += 1
    MODULE_LOGGER.info(f"Frozen observers for {frozen} activation quantizers.")
    return model

# ==============================================================================
# APPLY FUNCTIONS
# ==============================================================================

def apply_simple_ptq(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """Plain post-training quantization (``method='ptq'``).

    Replaces eligible layers (minus the architecture's exclusion patterns),
    runs ``calibrate_model`` on cached batches and returns the model in eval
    mode. If no calibration batches could be cached, ``fp32_model`` is
    returned before any layer replacement happens.
    """
    MODULE_LOGGER.info(f"--- Applying Simple PTQ (Text Quant: {quantize_text}) ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model
    calib_kwargs = dict(tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)
    quantized = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='ptq', **kwargs
    )
    quantized = quantized.to(target_device)
    quantized = calibrate_model(quantized, calib_cache, target_device, **calib_kwargs)
    return quantized.eval()

def apply_qwt_ptq(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """QwT PTQ: standard calibration followed by QwT sequential calibration.

    No ``method`` is forced here; whatever ``kwargs`` carries (the QwT config
    is expected to set ``method='qwt'``) is forwarded to layer replacement.
    ``num_calibration_batches`` (default 20) bounds the calibration cache.
    Returns ``fp32_model`` untouched if no calibration batches were cached.
    """
    MODULE_LOGGER.info("--- Applying QwT PTQ (Sequential Ridge Regression) ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    n_batches = kwargs.get('num_calibration_batches', 20)
    calib_cache = _cache_calibration_data(calibration_dataloader, num_batches=n_batches)
    if not calib_cache:
        return fp32_model

    q_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, **kwargs
    ).to(target_device)

    shared = {'tokenizer': tokenizer, 'prompts': prompts, 'quantize_text': quantize_text}
    # Stage 1: min/max scales. Stage 2: QwT sequential calibration on top.
    for stage in (calibrate_model, calibrate_qwt):
        q_model = stage(q_model, calib_cache, target_device, **shared)
    return q_model.eval()

def apply_apq_ptq(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """APQ-ViT post-training quantization (``method='apq'``).

    Replaces eligible layers, runs ``calibrate_model`` on cached batches and
    returns the model in eval mode. Returns ``fp32_model`` untouched when no
    calibration batches are available.
    """
    MODULE_LOGGER.info("--- Applying APQ-ViT PTQ ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    q_model = replace_layers_with_quantized_versions(
        fp32_model,
        quantize_text=quantize_text,
        exclude_patterns=exclusions,
        method='apq',
        **kwargs,
    )
    q_model = q_model.to(target_device)
    q_model = calibrate_model(
        q_model, calib_cache, target_device,
        tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text,
    )
    return q_model.eval()

def apply_rotation_ptq(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """Rotation-based PTQ (``method='rotation'``, logged as SpinQuant).

    Same pipeline as the other PTQ entry points: exclusions, cached
    calibration data, layer replacement, ``calibrate_model``, eval mode.
    Returns ``fp32_model`` untouched when nothing could be cached.
    """
    MODULE_LOGGER.info("--- Applying Rotation PTQ (SpinQuant) ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    replace_kwargs = dict(kwargs, quantize_text=quantize_text, exclude_patterns=exclusions)
    replace_kwargs['method'] = 'rotation'
    # Keep the original duplicate-keyword behaviour: a 'method' in kwargs must
    # still raise TypeError exactly as the explicit keyword call would.
    if 'method' in kwargs or 'quantize_text' in kwargs or 'exclude_patterns' in kwargs:
        return replace_layers_with_quantized_versions(
            fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='rotation', **kwargs
        )
    rotated = replace_layers_with_quantized_versions(fp32_model, **replace_kwargs).to(target_device)
    rotated = calibrate_model(
        rotated, calib_cache, target_device,
        tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text,
    )
    return rotated.eval()

def apply_outlier_aware_ptq(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """Outlier-aware PTQ (``method='outlier'``).

    Replaces eligible layers, calibrates them on cached batches and returns
    the model in eval mode; ``fp32_model`` is returned untouched when no
    calibration data is available.
    """
    MODULE_LOGGER.info("--- Applying OutlierAware PTQ ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='outlier', **kwargs
    )
    model = model.to(target_device)
    model = calibrate_model(
        model, calib_cache, target_device,
        tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text,
    )
    return model.eval()

def apply_smoothquant(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """SmoothQuant PTQ: fold smoothing scales into the model, then calibrate.

    Layers are replaced with ``method='ptq'``. Smoothing scales come from
    ``get_smoothing_stats`` with ``alpha = kwargs.get('alpha', 0.5)`` and are
    folded in by ``_apply_smoothing_to_model`` before ``calibrate_model``.
    Returns ``fp32_model`` untouched when no calibration data is cached.
    """
    MODULE_LOGGER.info("--- Applying SmoothQuant PTQ ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    q_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='ptq', **kwargs
    ).to(target_device)

    data_kwargs = {'tokenizer': tokenizer, 'prompts': prompts, 'quantize_text': quantize_text}
    smoothing_alpha = kwargs.get('alpha', 0.5)
    scales_map = get_smoothing_stats(q_model, calib_cache, target_device, smoothing_alpha, **data_kwargs)
    q_model = _apply_smoothing_to_model(q_model, scales_map)
    q_model = calibrate_model(q_model, calib_cache, target_device, **data_kwargs)
    return q_model.eval()

def apply_igq_vit(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """IGQ-ViT PTQ (``method='igq'``).

    After layer replacement, per-layer activation statistics are collected,
    and each ``QuantizedLinearLayer`` whose activation quantizer is an
    ``InstanceAwareGroupQuantizer`` is calibrated with its entry (if any).
    ``calibrate_model`` is not called on this path. Returns ``fp32_model``
    untouched when no calibration data is cached.
    """
    MODULE_LOGGER.info("--- Applying IGQ-ViT PTQ ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model
    q_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='igq', **kwargs
    ).to(target_device)
    act_maxes_map = collect_activation_stats(
        q_model, calib_cache, target_device,
        tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text,
    )
    for layer_name, layer in q_model.named_modules():
        if not isinstance(layer, QuantizedLinearLayer):
            continue
        act_quantizer = layer.activation_quantizer
        if isinstance(act_quantizer, InstanceAwareGroupQuantizer) and layer_name in act_maxes_map:
            act_quantizer.calibrate(act_maxes_map[layer_name])
    return q_model.eval()

def apply_rotation_lsq(fp32_model, *, training_dataloader, calibration_dataloader, target_device, tokenizer, prompts, teacher, quantize_text=False, **kwargs):
    """Rotation-initialized LSQ: replace layers (``method='rotation_lsq'``), calibrate, fine-tune.

    Args:
        fp32_model: Model to quantize.
        training_dataloader: Data for ``run_contrastive_finetuning_loop``.
        calibration_dataloader: Data cached for the initial calibration.
        target_device: Device the quantized model is moved to.
        tokenizer: Forwarded to calibration and fine-tuning.
        prompts: Forwarded to calibration and fine-tuning.
        teacher: Forwarded to ``run_contrastive_finetuning_loop``.
        quantize_text: Forwarded to layer replacement and calibration.
        **kwargs: Forwarded to ``replace_layers_with_quantized_versions``.
            Must contain ``lsq_learning_rate`` and ``learning_rate``; also
            reads ``total_steps`` (default 100), ``main_loss_weight`` (1.0)
            and ``distill_weight`` (0.0).

    Returns:
        The result of ``run_contrastive_finetuning_loop``, or ``fp32_model``
        when no calibration data could be cached.

    Raises:
        ValueError: If ``lsq_learning_rate`` or ``learning_rate`` is missing.
    """
    MODULE_LOGGER.info(f"--- Applying Rotation + LSQ (Geometric Init + Fine-tuning) ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    if 'lsq_learning_rate' not in kwargs:
        raise ValueError("Rotation+LSQ requires 'lsq_learning_rate' (use config.LSQ_KWARGS).")
    # The main learning rate is passed positionally to the fine-tuning loop;
    # without this check ``kwargs.get('learning_rate')`` would forward None.
    # The plain LSQ entry point enforces the same requirement.
    if 'learning_rate' not in kwargs:
        raise ValueError("Rotation+LSQ requires 'learning_rate' for other parameters.")

    s_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions,
        method='rotation_lsq', **kwargs
    ).to(target_device)

    s_model = calibrate_model(s_model, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)

    return run_contrastive_finetuning_loop(
        s_model, teacher, tokenizer, prompts, training_dataloader, target_device,
        kwargs.get('total_steps', 100), kwargs['learning_rate'], "Rotation+LSQ",
        kwargs['lsq_learning_rate'],
        main_loss_weight=kwargs.get('main_loss_weight', 1.0), distill_weight=kwargs.get('distill_weight', 0.0)
    )

def apply_quantization_aware_training(fp32_model, *, training_dataloader, calibration_dataloader, target_device, tokenizer, prompts, teacher, quantize_text=False, **kwargs):
    """Basic QAT (``method='basic_qat'``) with contrastive fine-tuning.

    Layers are replaced and calibrated, observer updates on activation
    quantizers are disabled, then ``run_contrastive_finetuning_loop`` runs
    with ``total_steps`` (default 100), ``learning_rate`` (default 1e-6),
    ``main_loss_weight`` (1.0) and ``distill_weight`` (0.0) from ``kwargs``.
    Returns ``fp32_model`` when no calibration data could be cached.
    """
    MODULE_LOGGER.info("--- Applying Basic QAT ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    student = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='basic_qat', **kwargs
    ).to(target_device)
    student = calibrate_model(student, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)
    student = _freeze_observers(student)

    total_steps = kwargs.get('total_steps', 100)
    lr = kwargs.get('learning_rate', 1e-6)
    loss_weights = dict(
        main_loss_weight=kwargs.get('main_loss_weight', 1.0),
        distill_weight=kwargs.get('distill_weight', 0.0),
    )
    return run_contrastive_finetuning_loop(
        student, teacher, tokenizer, prompts, training_dataloader, target_device,
        total_steps, lr, "QAT", **loss_weights,
    )

def apply_learned_step_size_quantization(fp32_model, *, training_dataloader, calibration_dataloader, target_device, tokenizer, prompts, teacher, quantize_text=False, **kwargs):
    """LSQ (``method='lsq'``): calibrate, then contrastive fine-tuning with learned step sizes.

    ``kwargs`` must provide ``lsq_learning_rate`` and ``learning_rate``
    (checked after data caching); ``total_steps`` (100), ``main_loss_weight``
    (1.0) and ``distill_weight`` (0.0) are optional. Returns ``fp32_model``
    when no calibration data could be cached.

    Raises:
        ValueError: If a required learning rate is missing.
    """
    MODULE_LOGGER.info("--- Applying LSQ ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    required = {
        'lsq_learning_rate': "LSQ requires 'lsq_learning_rate' to be specified in kwargs.",
        'learning_rate': "LSQ requires 'learning_rate' for other parameters.",
    }
    for key, message in required.items():
        if key not in kwargs:
            raise ValueError(message)

    student = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='lsq', **kwargs
    ).to(target_device)
    student = calibrate_model(student, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)

    return run_contrastive_finetuning_loop(
        student, teacher, tokenizer, prompts, training_dataloader, target_device,
        kwargs.get('total_steps', 100),
        kwargs.get('learning_rate'),
        "LSQ",
        kwargs.get('lsq_learning_rate'),
        main_loss_weight=kwargs.get('main_loss_weight', 1.0),
        distill_weight=kwargs.get('distill_weight', 0.0),
    )

def apply_qat_lora(fp32_model, *, training_dataloader, calibration_dataloader, target_device, tokenizer, prompts, teacher, quantize_text=False, **kwargs):
    """QAT-LoRA (``method='qat_lora'``) with contrastive fine-tuning.

    Layers are replaced, calibrated and their activation observers frozen.
    Fine-tuning uses ``total_steps`` (default 100) and ``learning_rate``
    (default 1e-6) from ``kwargs``, with loss weights fixed at main 1.0 and
    distill 0.0. Returns ``fp32_model`` when no calibration data is cached.
    """
    MODULE_LOGGER.info("--- Applying QAT-LoRA ---")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    lora_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='qat_lora', **kwargs
    )
    lora_model = lora_model.to(target_device)
    lora_model = calibrate_model(lora_model, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)
    lora_model = _freeze_observers(lora_model)

    steps = kwargs.get('total_steps', 100)
    base_lr = kwargs.get('learning_rate', 1e-6)
    return run_contrastive_finetuning_loop(
        lora_model, teacher, tokenizer, prompts, training_dataloader, target_device,
        steps, base_lr, "QAT-LoRA",
        main_loss_weight=1.0,
        distill_weight=0.0,
    )

def apply_qvit(fp32_model, *, training_dataloader, calibration_dataloader, target_device, tokenizer, prompts, teacher, quantize_text=False, **kwargs):
    """Q-ViT (``method='qvit'``): calibrate, freeze activation observers, then distill.

    When ``teacher`` is None, a deep copy of ``fp32_model`` taken *before*
    layer replacement serves as the teacher. Fine-tuning uses
    ``total_steps`` (default 100) and ``learning_rate`` (default 1e-6) from
    ``kwargs``, with loss weights fixed at main 1.0 / distill 1.0. Returns
    ``fp32_model`` when no calibration data could be cached.
    """
    MODULE_LOGGER.info("--- Applying Q-ViT ---")
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    if teacher is None:
        # Snapshot the full-precision weights before any layer replacement.
        teacher_model = copy.deepcopy(fp32_model).to(target_device)
    else:
        teacher_model = teacher

    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    student = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='qvit', **kwargs
    ).to(target_device)
    student = calibrate_model(student, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)
    student = _freeze_observers(student)

    steps = kwargs.get('total_steps', 100)
    lr = kwargs.get('learning_rate', 1e-6)
    return run_qvit_finetuning_loop(
        student, teacher_model, tokenizer, prompts, training_dataloader, target_device,
        steps, lr, "Q-ViT",
        main_loss_weight=1.0,
        distill_weight=1.0,
    )

def apply_cosine_qat(fp32_model, *, training_dataloader, calibration_dataloader, target_device, tokenizer, prompts, teacher, quantize_text=False, **kwargs):
    """Hybrid cosine-matching QAT (``method='qat'``) against a required teacher.

    Calibration data is cached *before* layer replacement, matching the other
    entry points. If no batches are available, ``fp32_model`` is returned
    without having been passed through the (possibly in-place) replacement.

    ``kwargs`` are forwarded to layer replacement. The fine-tuning loop reads
    ``total_steps`` (default 100), ``learning_rate`` (1e-5),
    ``contrastive_weight`` (1.0) and ``warmup_pct`` (0.5) from them.

    Returns:
        The result of ``run_cosine_matching_loop``, or ``fp32_model`` when no
        calibration data could be cached.

    Raises:
        ValueError: If ``teacher`` is None.
    """
    MODULE_LOGGER.info(f"--- Applying Hybrid Cosine QAT ---")
    if teacher is None:
        raise ValueError("CosQAT requires a teacher model.")
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model
    s_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, method='qat', **kwargs
    ).to(target_device)
    s_model = calibrate_model(s_model, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)
    return run_cosine_matching_loop(
        student=s_model, teacher=teacher, tokenizer=tokenizer, prompts=prompts, dataloader=training_dataloader,
        target_device=target_device, total_steps=kwargs.get('total_steps', 100), lr=kwargs.get('learning_rate', 1e-5),
        name="CosQAT", contrastive_weight=kwargs.get('contrastive_weight', 1.0), warmup_pct=kwargs.get('warmup_pct', 0.5)
    )

def apply_qvlm_ptq(fp32_model, *, calibration_dataloader, target_device, quantize_text=False, tokenizer=None, prompts=None, **kwargs):
    """Q-VLM: PTQ plus Visual Encoder Optimization (VEO) against an FP32 teacher.

    Pipeline:
      1. cache calibration batches (returns ``fp32_model`` if none);
      2. snapshot a frozen FP32 teacher;
      3. replace layers (``method`` taken from ``kwargs``) and calibrate;
      4. run ``run_qvlm_veo_loop`` on the cached calibration data;
      5. recalibrate after the weights changed.

    ``kwargs`` are forwarded to layer replacement. The VEO loop reads
    ``epochs`` (default 5), ``learning_rate`` (1e-5), ``distill_weight``
    (1.0) and ``quant_error_weight`` (0.5) from them.

    Returns:
        The optimized quantized model in eval mode, or ``fp32_model`` when
        no calibration data could be cached.
    """
    MODULE_LOGGER.info(f"--- Applying Q-VLM (PTQ + Visual Encoder Optimization) ---")

    # 1. Setup
    exclusions = get_exclusions(config.MODEL_CONFIG['arch'])
    calib_cache = _cache_calibration_data(calibration_dataloader)
    if not calib_cache:
        return fp32_model

    # 2. Snapshot the full-precision teacher BEFORE layer replacement. If
    # replace_layers_with_quantized_versions modifies fp32_model in place,
    # copying afterwards (as before) would yield a copy of the quantized,
    # calibrated student instead of an FP32 teacher. Copying first is correct
    # whether or not the replacement is in place.
    teacher_model = copy.deepcopy(fp32_model).to(target_device).eval()
    for p in teacher_model.parameters():
        p.requires_grad = False

    # 3. Replace layers with fake quantizers and get initial (min/max) scales
    # before any weights are optimized.
    q_model = replace_layers_with_quantized_versions(
        fp32_model, quantize_text=quantize_text, exclude_patterns=exclusions, **kwargs
    ).to(target_device)
    q_model = calibrate_model(q_model, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)

    # 4. Visual Encoder Optimization, driven by the cached calibration data
    # (PTQ setting: no separate training set).
    class CacheDataset(torch.utils.data.Dataset):
        """Map-style view over the cached calibration batches."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    # batch_size=None: each cached item is already a full batch.
    veo_loader = torch.utils.data.DataLoader(CacheDataset(calib_cache), batch_size=None)

    q_model = run_qvlm_veo_loop(
        q_model, teacher_model, veo_loader, target_device,
        epochs=kwargs.get('epochs', 5),
        lr=kwargs.get('learning_rate', 1e-5),
        name="Q-VLM VEO",
        distill_weight=kwargs.get('distill_weight', 1.0),
        quant_error_weight=kwargs.get('quant_error_weight', 0.5)
    )

    # 5. Final calibration: scales may need adjustment after the weights changed.
    q_model = calibrate_model(q_model, calib_cache, target_device, tokenizer=tokenizer, prompts=prompts, quantize_text=quantize_text)

    # Drop the local teacher reference before returning.
    del teacher_model

    return q_model.eval()