
# quantization/training.py

import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import scripts.config as config
from .modules import QViTMultiheadAttention, QuantizableMultiheadAttention, EnhancedFakeQuantizer, QuantizedLinearLayer

MODULE_LOGGER = logging.getLogger(__name__)

def _process_contrastive_batch(student, batch, tokenizer, class_names, target_device):
    """Return the symmetric image/text contrastive loss for one batch.

    ``batch`` is an ``(images, second_element)`` pair.  ``second_element`` may be:

    * a non-empty list/tuple of strings -- used directly as captions;
    * a ``torch.Tensor`` of labels -- each label indexes ``class_names`` and is
      rendered through ``config.CIFAR100_TEMPLATE``;
    * anything else -- the batch is skipped and a constant ``0.0`` tensor
      (without autograd history) is returned.

    In both supported cases the i-th image is paired with the i-th caption,
    so the targets are simply ``0 .. len(images) - 1``.

    Args:
        student: Model exposing ``encode_image``, ``encode_text`` and ``logit_scale``.
        batch: Two-element batch as described above.
        tokenizer: Callable turning a list of strings into a tensor.
        class_names: Indexable mapping label -> class name (label batches only).
        target_device: Device on which tensors are placed.

    Returns:
        Scalar tensor: mean of the image->text and text->image cross-entropies.
    """
    images, payload = batch

    if isinstance(payload, torch.Tensor):
        # Label batch: build one caption per label from the prompt template.
        captions = [
            config.CIFAR100_TEMPLATE.format(class_names[idx])
            for idx in payload.cpu().numpy()
        ]
    elif isinstance(payload, (list, tuple)) and len(payload) > 0 and isinstance(payload[0], str):
        captions = list(payload)
    else:
        # Unsupported batch layout: contribute nothing to the loss.
        return torch.tensor(0.0, device=target_device)

    # Matching pairs sit on the diagonal of the similarity matrix.
    targets = torch.arange(len(images), dtype=torch.long, device=target_device)

    images = images.to(target_device, non_blocking=True)
    tokens = tokenizer(captions).to(target_device)

    img_emb = F.normalize(student.encode_image(images), dim=-1)
    txt_emb = F.normalize(student.encode_text(tokens), dim=-1)

    scale = student.logit_scale.exp()
    sim_image_to_text = (scale * img_emb) @ txt_emb.t()
    sim_text_to_image = sim_image_to_text.t()

    image_side = F.cross_entropy(sim_image_to_text, targets)
    text_side = F.cross_entropy(sim_text_to_image, targets)
    return (image_side + text_side) / 2

def run_contrastive_finetuning_loop(student, teacher, tokenizer, class_names, dataloader, target_device, total_steps, lr, name, lsq_lr=None, main_loss_weight=1.0, distill_weight=0.0):
    """Fine-tune ``student`` with a contrastive loss plus optional image-feature distillation.

    Every step draws one batch (restarting the dataloader when it is exhausted)
    and accumulates the gradient of
    ``main_loss_weight * contrastive + distill_weight * MSE(student_img, teacher_img)``
    divided by ``config.GRAD_ACCUM_STEPS``.  The optimizer steps once per
    accumulation window and once more at the very last step so that a trailing
    partial window is not discarded.

    Args:
        student: Model to train; needs ``encode_image``, ``logit_scale`` and
            whatever ``_process_contrastive_batch`` calls on it.
        teacher: Frozen reference model, or ``None`` to disable distillation.
        tokenizer: Forwarded to ``_process_contrastive_batch``.
        class_names: Forwarded to ``_process_contrastive_batch``.
        dataloader: Iterable yielding ``(images, second_element)`` batches.
        target_device: Device used for models and tensors.
        total_steps: Number of batches (micro-steps) to process.
        lr: Learning rate for all parameters except LSQ scales.
        name: Label used in log messages and the progress bar.
        lsq_lr: If given, parameters whose name contains
            ``'activation_quantizer.scale'`` get this learning rate instead.
        main_loss_weight: Weight of the contrastive term (skipped when <= 0).
        distill_weight: Weight of the MSE distillation term (skipped when <= 0).

    Returns:
        The student, switched to eval mode.
    """
    has_teacher = teacher is not None

    student.to(target_device).train()
    if has_teacher:
        teacher.to(target_device).eval()

    student.logit_scale.requires_grad = True

    # Split trainable parameters: LSQ activation scales may get their own LR.
    lsq_scale_params = []
    other_params = []
    is_lsq_run = lsq_lr is not None

    for param_name, param in student.named_parameters():
        if not param.requires_grad:
            continue
        if is_lsq_run and 'activation_quantizer.scale' in param_name:
            lsq_scale_params.append(param)
        else:
            other_params.append(param)

    if lsq_scale_params:
        MODULE_LOGGER.info(
            "Found %d LSQ scale parameters. Using separate LR: %s",
            len(lsq_scale_params), lsq_lr,
        )
        param_groups = [
            {'params': other_params, 'lr': lr, 'name': 'other_params'},
            {'params': lsq_scale_params, 'lr': lsq_lr, 'name': 'lsq_scales'},
        ]
    else:
        param_groups = [{'params': other_params, 'lr': lr, 'name': 'all_params'}]

    # Never hand the optimizer a group without parameters.
    param_groups = [group for group in param_groups if group['params']]
    if not param_groups:
        MODULE_LOGGER.warning("No trainable parameters for %s, skipping training.", name)
        return student.eval()

    optimizer = torch.optim.AdamW(param_groups)
    optimizer.zero_grad(set_to_none=True)

    scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

    # Guard against a zero/negative setting, which would break the modulo below.
    accum_steps = max(1, config.GRAD_ACCUM_STEPS)
    use_distill = distill_weight > 0 and has_teacher

    pbar = tqdm(range(total_steps), desc=f"{name} Steps")
    data_iterator = iter(dataloader)
    grads_pending = False  # True once backward() ran since the last optimizer step

    for step in pbar:
        try:
            batch = next(data_iterator)
        except StopIteration:
            # Dataloader exhausted: start another pass.
            data_iterator = iter(dataloader)
            batch = next(data_iterator)

        images, _ = batch
        images = images.to(target_device, non_blocking=True)

        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            loss_main = torch.tensor(0.0, device=target_device)
            if main_loss_weight > 0:
                loss_main = _process_contrastive_batch(student, batch, tokenizer, class_names, target_device)

            loss_distill = torch.tensor(0.0, device=target_device)
            if use_distill:
                student_img_features = student.encode_image(images)
                with torch.no_grad():
                    teacher_img_features = teacher.encode_image(images)
                loss_distill = F.mse_loss(student_img_features, teacher_img_features.detach())

            total_loss = (main_loss_weight * loss_main + distill_weight * loss_distill) / accum_steps

        # A batch skipped by the helper (or all-zero weights) yields a constant
        # loss with no graph; calling backward() on it would raise.
        if total_loss.requires_grad:
            scaler.scale(total_loss).backward()
            grads_pending = True

        # Step at each window boundary and also at the final step, otherwise the
        # gradients of a trailing partial window would be silently dropped.
        # NOTE: a partial window is still scaled by 1/accum_steps.
        at_boundary = (step + 1) % accum_steps == 0 or (step + 1) == total_steps
        if at_boundary and grads_pending:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            grads_pending = False

        # Undo the accumulation scaling for display purposes.
        disp_loss = total_loss.item() * accum_steps
        pbar.set_postfix({"loss": f"{disp_loss:.4f}"})

    return student.eval()

def run_cosine_matching_loop(student, teacher, tokenizer, prompts, dataloader, target_device, total_steps, lr, name, contrastive_weight=1.0, warmup_pct=0.5):
    """Train ``student`` to match ``teacher`` embeddings by cosine similarity.

    The loss is the cosine-embedding distance between student and teacher
    image features (and text features when captions are available), plus a
    CLIP-style contrastive term whose weight ``alpha`` stays at 0 for the first
    ``warmup_pct`` fraction of steps and then ramps linearly towards
    ``contrastive_weight``.  ``student.logit_scale`` is frozen for this run.

    Args:
        student: Model to train (``encode_image``, ``encode_text``, ``logit_scale``).
        teacher: Frozen reference model with the same encode methods.
        tokenizer: Callable turning a list of strings into a tensor.
        prompts: Indexable label -> class-name mapping used when the batch
            carries label tensors; ``None`` disables the text branch for those.
        dataloader: Iterable yielding ``(images, second_element)`` batches.
        target_device: Device used for models and tensors.
        total_steps: Number of optimizer steps.
        lr: Initial learning rate (cosine-annealed to ``1e-7``).
        name: Progress-bar label.
        contrastive_weight: Final weight of the contrastive term.
        warmup_pct: Fraction of training during which ``alpha`` is 0.

    Returns:
        The student, switched to eval mode.
    """
    MODULE_LOGGER.info(f"Starting CosQAT (Target Contrastive W: {contrastive_weight}, Warmup: {warmup_pct*100}%)...")

    student.to(target_device).train()
    teacher.to(target_device).eval()
    student.logit_scale.requires_grad = False

    params = [p for p in student.parameters() if p.requires_grad]
    if not params:
        return student.eval()

    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-7)
    scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

    pbar = tqdm(range(total_steps), desc=f"{name}")
    data_iterator = iter(dataloader)
    target_ones = None  # cached "+1" targets for cosine_embedding_loss

    for step in pbar:
        try:
            batch = next(data_iterator)
        except StopIteration:
            # Dataloader exhausted: start another pass.
            data_iterator = iter(dataloader)
            batch = next(data_iterator)

        # Contrastive weight: 0 during warm-up, then a linear ramp.
        progress = step / total_steps
        alpha = 0.0
        if progress >= warmup_pct:
            ramp_factor = (progress - warmup_pct) / (1.0 - warmup_pct)
            alpha = contrastive_weight * ramp_factor

        images, second_element = batch
        images = images.to(target_device, non_blocking=True)

        # Re-allocate only when the batch size changes (e.g. last partial batch).
        if target_ones is None or target_ones.size(0) != images.size(0):
            target_ones = torch.ones(images.size(0), device=target_device)

        texts = None
        if (
            isinstance(second_element, (list, tuple))
            and len(second_element) > 0  # avoid IndexError on an empty sequence
            and isinstance(second_element[0], str)
        ):
            texts = list(second_element)
        elif isinstance(second_element, torch.Tensor) and prompts is not None:
            labels = second_element.cpu().numpy()
            texts = [config.CIFAR100_TEMPLATE.format(prompts[label]) for label in labels]

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            s_img_raw = student.encode_image(images)
            tokenized = None
            s_txt_raw = None
            if texts is not None:
                tokenized = tokenizer(texts).to(target_device)
                s_txt_raw = student.encode_text(tokenized)

            with torch.no_grad():
                t_img_raw = teacher.encode_image(images)
                t_txt_raw = teacher.encode_text(tokenized) if tokenized is not None else None

            loss_distill_img = F.cosine_embedding_loss(s_img_raw, t_img_raw, target_ones)

            loss_distill_txt = torch.tensor(0.0, device=target_device)
            if s_txt_raw is not None and t_txt_raw is not None:
                loss_distill_txt = F.cosine_embedding_loss(s_txt_raw, t_txt_raw, target_ones)

            loss_clip = torch.tensor(0.0, device=target_device)
            if alpha > 1e-4 and s_txt_raw is not None:
                img_norm = s_img_raw / s_img_raw.norm(dim=-1, keepdim=True)
                txt_norm = s_txt_raw / s_txt_raw.norm(dim=-1, keepdim=True)
                logit_scale = student.logit_scale.exp()
                logits_per_image = logit_scale * img_norm @ txt_norm.t()
                # i-th image pairs with i-th caption; build the targets once.
                clip_targets = torch.arange(len(images), device=target_device)
                loss_clip = (
                    F.cross_entropy(logits_per_image, clip_targets)
                    + F.cross_entropy(logits_per_image.t(), clip_targets)
                ) / 2

            total_loss = (loss_distill_img + loss_distill_txt) + (alpha * loss_clip)

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        pbar.set_postfix({"L": f"{total_loss.item():.3f}", "α": f"{alpha:.2f}"})

    return student.eval()

def run_qvit_finetuning_loop(student, teacher, tokenizer, class_names, dataloader, target_device, total_steps, lr, name, main_loss_weight=1.0, distill_weight=1.0):
    """Fine-tune ``student`` with a contrastive loss plus attention-map (Q·Kᵀ) distillation.

    Forward hooks on attention modules that exist under the same name in both
    models (``QViTMultiheadAttention`` in the student,
    ``QuantizableMultiheadAttention`` in the teacher) capture each module's
    ``q_for_dgd`` / ``k_for_dgd`` attributes.  For every captured pair the
    row-normalised ``q @ kᵀ`` maps are compared with MSE; the mean over layers
    is the distillation term.  Gradients are accumulated over
    ``config.GRAD_ACCUM_STEPS`` batches.

    Args:
        student: Model to train (``encode_image``, ``logit_scale``, ...).
        teacher: Frozen reference model, or ``None`` (distillation disabled).
        tokenizer: Forwarded to ``_process_contrastive_batch``.
        class_names: Forwarded to ``_process_contrastive_batch``.
        dataloader: Iterable yielding batches whose first element is the images.
        target_device: Device used for models and tensors.
        total_steps: Number of batches (micro-steps) to process.
        lr: AdamW learning rate.
        name: Progress-bar label.
        main_loss_weight: Weight of the contrastive term (skipped when <= 0).
        distill_weight: Weight of the attention distillation term.

    Returns:
        The student, switched to eval mode.  Hooks are always removed, even
        when training raises.
    """
    has_teacher = teacher is not None
    if not has_teacher and distill_weight > 0:
        distill_weight = 0.0

    student.to(target_device).train()
    student.logit_scale.requires_grad = True
    if has_teacher:
        teacher.to(target_device).eval()

    params_to_train = [p for p in student.parameters() if p.requires_grad]
    if not params_to_train:
        return student.eval()

    optimizer = torch.optim.AdamW(params_to_train, lr=lr)
    scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)
    # Guard against a zero/negative setting, which would break the modulo below.
    accum_steps = max(1, config.GRAD_ACCUM_STEPS)

    student_activations, teacher_activations = {}, {}

    def get_hook(storage, key, detach=False):
        """Build a forward hook that stores the module's q/k under ``key``."""
        def hook_fn(module, inputs, outputs):
            q = getattr(module, 'q_for_dgd', None)
            if q is None:
                return
            k = module.k_for_dgd
            if detach:
                q, k = q.detach(), k.detach()
            storage[key] = {'q': q, 'k': k}
        return hook_fn

    hooks = []
    try:
        if has_teacher:
            s_mha = {n for n, m in student.named_modules() if isinstance(m, QViTMultiheadAttention)}
            t_mha = {n for n, m in teacher.named_modules() if isinstance(m, QuantizableMultiheadAttention)}

            # Use a dedicated loop variable: reusing `name` here would clobber
            # the run label shown in the progress bar below.
            for module_name in s_mha & t_mha:
                hooks.append(
                    student.get_submodule(module_name).register_forward_hook(
                        get_hook(student_activations, module_name, False)))
                hooks.append(
                    teacher.get_submodule(module_name).register_forward_hook(
                        get_hook(teacher_activations, module_name, True)))

        pbar = tqdm(range(total_steps), desc=f"{name} Steps")
        data_iterator = iter(dataloader)
        optimizer.zero_grad(set_to_none=True)
        grads_pending = False  # True once backward() ran since the last optimizer step

        for step in pbar:
            try:
                batch = next(data_iterator)
            except StopIteration:
                # Dataloader exhausted: start another pass.
                data_iterator = iter(dataloader)
                batch = next(data_iterator)

            # Drop captures from the previous step so a skipped forward pass can
            # never pair fresh teacher maps with stale student tensors.
            student_activations.clear()
            teacher_activations.clear()

            images = batch[0].to(target_device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=config.USE_AMP):
                loss_main = torch.tensor(0.0, device=target_device)
                if main_loss_weight > 0:
                    loss_main = _process_contrastive_batch(student, batch, tokenizer, class_names, target_device)
                else:
                    # Forward pass only to trigger the student hooks.
                    student.encode_image(images)

                loss_dgd = torch.tensor(0.0, device=target_device)
                if distill_weight > 0 and has_teacher:
                    with torch.no_grad():
                        teacher.encode_image(images)

                    dgd_terms = []
                    for layer_name, s_act in student_activations.items():
                        t_act = teacher_activations.get(layer_name)
                        if t_act is None:
                            continue
                        # torch.bmm needs 3-D operands; assumes q/k are
                        # (batch, tokens, dim) -- TODO confirm against modules.py.
                        attn_s = torch.bmm(s_act['q'], s_act['k'].transpose(1, 2))
                        attn_t = torch.bmm(t_act['q'], t_act['k'].transpose(1, 2))

                        attn_s = F.normalize(attn_s, p=2, dim=-1)
                        attn_t = F.normalize(attn_t, p=2, dim=-1)

                        dgd_terms.append(F.mse_loss(attn_s, attn_t))

                    if dgd_terms:
                        loss_dgd = torch.stack(dgd_terms).mean()

                total_loss = (main_loss_weight * loss_main + distill_weight * loss_dgd) / accum_steps

            # A constant loss (skipped batch, all-zero weights) has no graph.
            if total_loss.requires_grad:
                scaler.scale(total_loss).backward()
                grads_pending = True

            # Also step at the final iteration so a trailing partial
            # accumulation window is applied rather than dropped.
            at_boundary = (step + 1) % accum_steps == 0 or (step + 1) == total_steps
            if at_boundary and grads_pending:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                grads_pending = False

            pbar.set_postfix({"loss": f"{total_loss.item()*accum_steps:.4f}"})
    finally:
        # Always detach the hooks so the models are left clean on error too.
        for handle in hooks:
            handle.remove()

    return student.eval()
def run_qvlm_veo_loop(student, teacher, dataloader, target_device, epochs, lr, name, distill_weight=1.0, quant_error_weight=1.0):
    """
    Q-VLM Visual Encoder Optimization (VEO).

    Optimizes the student's trainable weights to minimize
    ``distill_weight * MSE(normalized student image features, normalized teacher image features)
    + quant_error_weight * mean weight-quantization error``.
    The quantization term is a proxy for the activation entropy / cross-layer
    dependency objective: it penalizes ``|| W - Q(W) ||^2`` of every
    ``QuantizedLinearLayer``.

    Args:
        student: Quantized model to train (``encode_image``; ``logit_scale`` is
            frozen when present).
        teacher: Frozen reference model providing target image features.
        dataloader: Iterable yielding ``(images, _)`` batches.
        target_device: Device used for models and tensors.
        epochs: Number of full passes over ``dataloader``.
        lr: AdamW learning rate.
        name: Progress-bar label.
        distill_weight: Weight of the feature distillation term.
        quant_error_weight: Weight of the weight-quantization error term.

    Returns:
        The student, switched to eval mode.
    """
    MODULE_LOGGER.info(f"Starting Q-VLM VEO (Epochs: {epochs}, LR: {lr})...")

    # Student trains; teacher stays frozen in eval mode.
    student.to(target_device).train()
    teacher.to(target_device).eval()

    # Logit scale is not optimized here; the focus is on encoder weights.
    if hasattr(student, 'logit_scale'):
        student.logit_scale.requires_grad = False

    params = [p for p in student.parameters() if p.requires_grad]
    if not params:
        # AdamW rejects an empty parameter list; nothing to train anyway.
        MODULE_LOGGER.warning("No trainable parameters for %s, skipping training.", name)
        return student.eval()

    optimizer = torch.optim.AdamW(params, lr=lr)
    scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

    def compute_quantization_loss(model):
        """Mean ``MSE(W, Q(W))`` over all ``QuantizedLinearLayer`` modules.

        Activation rounding error is not observable without hooking into the
        layers, so the weight rounding error is used instead: pulling weights
        towards the quantization grid is the only degree of freedom here.
        Always returns a tensor so callers can use ``.item()``.
        """
        per_layer = []
        for module in model.modules():
            if isinstance(module, QuantizedLinearLayer):
                weight = module.original_linear_layer.weight
                per_layer.append(F.mse_loss(weight, module.weight_quantizer(weight)))
        if not per_layer:
            # No quantized layers: a zero tensor, not a Python float.
            return torch.zeros((), device=target_device)
        return torch.stack(per_layer).mean()

    for epoch in range(epochs):
        # Iterating the loader directly also works for loaders without __len__.
        pbar = tqdm(dataloader, desc=f"{name} Epoch {epoch+1}")

        for images, _ in pbar:
            images = images.to(target_device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=config.USE_AMP):
                # A. Student forward (quantized).
                student_features = F.normalize(student.encode_image(images), dim=-1)

                # B. Teacher forward, no gradients.
                with torch.no_grad():
                    teacher_features = F.normalize(teacher.encode_image(images), dim=-1)

                # C. Losses: feature distillation + weight quantization error.
                loss_distill = F.mse_loss(student_features, teacher_features)
                loss_quant = compute_quantization_loss(student)

                total_loss = (distill_weight * loss_distill) + (quant_error_weight * loss_quant)

            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix({
                "L_tot": f"{total_loss.item():.4f}",
                "L_dst": f"{loss_distill.item():.4f}",
                "L_q": f"{loss_quant.item():.4f}"
            })

    return student.eval()