# scripts/evaluation.py

import logging
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
import scripts.config as config

# Module-level logger, named after this module's import path.
MODULE_LOGGER = logging.getLogger(__name__)

def _encode_class_prompts(model, tokenizer, class_names, template, device, chunk_size=256):
    """Encode ``template`` filled with every class name into unit-norm text features.

    Prompts are tokenized and passed through ``model.encode_text`` in chunks of
    ``chunk_size`` and the per-chunk results are concatenated along dim 0.
    """
    prompts = [template.format(c) for c in class_names]
    chunks = []
    for start in range(0, len(prompts), chunk_size):
        tokenized = tokenizer(prompts[start:start + chunk_size]).to(device)
        feat = model.encode_text(tokenized)
        chunks.append(feat / feat.norm(dim=-1, keepdim=True))
    return torch.cat(chunks, dim=0).to(device)


def tune_logit_scale(model, tuning_dataloader, tokenizer, class_names, device="cuda",
                     main_loss_weight=1.0, distill_weight=0.0, teacher_model=None,
                     template="a photo of a {}"):
    """Optimise ``model.logit_scale`` for ``config.LOGIT_TUNE_STEPS`` steps.

    During tuning only parameters whose name contains ``'logit_scale'`` have
    ``requires_grad=True``; the original flags are restored afterwards, even
    if tuning raises.

    Each batch from ``tuning_dataloader`` must unpack as ``(images, second)``:

    * ``second`` is a tensor and class text features were pre-computed:
      cross-entropy of image-vs-class logits against ``second`` as labels.
    * ``second`` is a list/tuple of texts: symmetric image/text
      cross-entropy over the batch.

    Args:
        model: Object exposing ``encode_image``, ``encode_text``,
            ``logit_scale`` and ``named_parameters``.
        tuning_dataloader: Iterable of batches; re-iterated when exhausted.
        tokenizer: Callable mapping a list of strings to a tensor-like
            object with ``.to(device)``.
        class_names: Class names used to fill ``template``; if falsy, no
            class features are pre-computed.
        device: Device the batches and tokens are moved to.
        main_loss_weight: Weight of the classification/contrastive loss.
        distill_weight: Weight of the MSE loss between student and teacher
            image features; forced to 0 when ``teacher_model`` is None.
        teacher_model: Optional model exposing ``encode_image``.
        template: Prompt template formatted with each class name.

    Returns:
        The result of ``model.eval()``.
    """
    MODULE_LOGGER.info("--- Phase 2: Logit Tuning ---")
    if teacher_model is None:
        distill_weight = 0.0

    original_grad_states = {name: param.requires_grad for name, param in model.named_parameters()}
    for name, param in model.named_parameters():
        param.requires_grad = ('logit_scale' in name)

    try:
        model.eval()
        all_class_features = None

        # Pre-compute target class features using AMP. This is best-effort:
        # on failure only the text-pair branch of the main loss stays usable.
        with torch.no_grad(), torch.amp.autocast('cuda', enabled=config.USE_AMP):
            try:
                if class_names:
                    all_class_features = _encode_class_prompts(
                        model, tokenizer, class_names, template, device
                    )
            except Exception as e:
                MODULE_LOGGER.warning(f"Could not pre-compute class features: {e}")

        # NOTE(review): train mode is kept from the original; if the model
        # contains dropout/batch-norm layers they will behave in training
        # mode during tuning — confirm this is intended.
        model.train()

        optimizer = torch.optim.AdamW([model.logit_scale], lr=config.LOGIT_TUNE_LR)
        scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

        tuning_iterator = iter(tuning_dataloader)
        pbar = tqdm(range(config.LOGIT_TUNE_STEPS), desc="Tuning", leave=False)
        warned_no_grad = False

        for _ in pbar:
            try:
                images, second_element = next(tuning_iterator)
            except StopIteration:
                # Dataloader exhausted: start another pass over it.
                tuning_iterator = iter(tuning_dataloader)
                images, second_element = next(tuning_iterator)

            images = images.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=config.USE_AMP):
                student_img_raw = model.encode_image(images)
                student_img_norm = F.normalize(student_img_raw, dim=-1)

                loss_main = torch.tensor(0.0, device=device)
                if main_loss_weight > 0:
                    if isinstance(second_element, torch.Tensor) and all_class_features is not None:
                        labels = second_element.to(device)
                        logits = model.logit_scale.exp() * student_img_norm @ all_class_features.t()
                        loss_main = F.cross_entropy(logits, labels)
                    elif isinstance(second_element, (list, tuple)):
                        texts = list(second_element)
                        tokenized = tokenizer(texts).to(device)
                        student_txt = model.encode_text(tokenized)
                        student_txt_norm = F.normalize(student_txt, dim=-1)
                        logits_img = model.logit_scale.exp() * student_img_norm @ student_txt_norm.t()
                        labels_idx = torch.arange(len(images), device=device)
                        loss_main = (
                            F.cross_entropy(logits_img, labels_idx)
                            + F.cross_entropy(logits_img.t(), labels_idx)
                        ) / 2

                loss_distill = torch.tensor(0.0, device=device)
                if distill_weight > 0 and teacher_model is not None:
                    with torch.no_grad():
                        teacher_img = teacher_model.encode_image(images)
                    # NOTE(review): only logit_scale is trainable here, so this
                    # term yields a gradient only if encode_image depends on
                    # logit_scale — confirm the intent.
                    loss_distill = F.mse_loss(student_img_raw, teacher_img)

                total_loss = main_loss_weight * loss_main + distill_weight * loss_distill

            if not total_loss.requires_grad:
                # No loss term is connected to a trainable parameter (e.g. class
                # features unavailable or unsupported batch format); calling
                # backward() would raise, so skip the step instead.
                if not warned_no_grad:
                    MODULE_LOGGER.warning(
                        "Logit tuning loss has no gradient path to logit_scale; skipping step(s)."
                    )
                    warned_no_grad = True
                continue

            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix({"loss": f"{total_loss.item():.4f}", "s": f"{model.logit_scale.item():.4f}"})
    finally:
        # Always hand the model back with its original trainability flags.
        for name, param in model.named_parameters():
            param.requires_grad = original_grad_states.get(name, True)

    return model.eval()

@torch.no_grad()
def calculate_calibration_metrics(logits, labels, n_bins=15):
    """Return the Expected Calibration Error (ECE) as a Python float.

    The softmax of ``logits`` (taken over dim 1) gives per-sample confidences
    (the maximum probability) and predictions (its index). Samples are grouped
    into ``n_bins`` equal-width confidence bins ``(lower, upper]`` over
    ``[0, 1]``; ECE is the sum over non-empty bins of
    ``|mean confidence - accuracy| * fraction of samples in the bin``.

    Args:
        logits: 2-D tensor of class scores, one row per sample.
        labels: Ground-truth class indices, one per sample.
        n_bins: Number of equal-width confidence bins.

    Returns:
        The ECE; ``0.0`` when no sample falls into any bin (for example when
        ``logits`` is empty).
    """
    if logits.numel() == 0:
        # Nothing to bin; the original crashed here calling .item() on a float.
        return 0.0

    probs = F.softmax(logits.float(), dim=1)
    confidences, predictions = torch.max(probs, 1)
    # Compare on the predictions' device so CPU labels work with GPU logits.
    targets = torch.as_tensor(labels, device=predictions.device)
    accuracies = predictions.eq(targets).float()

    # Tensor accumulator so .item() is valid even if every bin is empty.
    ece = torch.zeros((), dtype=confidences.dtype, device=confidences.device)
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin > 0:
            bin_acc = accuracies[in_bin].mean()
            bin_conf = confidences[in_bin].mean()
            ece += torch.abs(bin_conf - bin_acc) * prop_in_bin
    return ece.item()

@torch.no_grad()
def run_comprehensive_evaluation(model, fp32_model, loader, ensembled_text_features, device):
    """Evaluate ``model`` against ``fp32_model`` over every batch in ``loader``.

    For each ``(images, labels)`` batch, both models encode the images. The
    per-sample cosine similarity between their normalised features is
    recorded, and class logits are computed from ``model``'s features,
    ``model.logit_scale`` and ``ensembled_text_features``.

    Args:
        model: Model under evaluation; moved to ``device`` and set to eval.
        fp32_model: Reference model; moved to ``device`` and set to eval.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        ensembled_text_features: Text feature matrix multiplied (transposed)
            with the image features to produce class logits.
        device: Device used for the forward passes.

    Returns:
        Dict with Python-float values under the keys ``"Zero-Shot Accuracy"``,
        ``"Avg. Cosine Similarity"`` and ``"ECE"``.

    Raises:
        RuntimeError: If ``loader`` yields no batches.
    """
    model.to(device).eval()
    fp32_model.to(device).eval()
    ensembled_text_features = ensembled_text_features.to(device)

    logits_list, labels_list, sims = [], [], []

    for imgs, lbls in tqdm(loader, desc="Evaluation", leave=False):
        imgs = imgs.to(device, non_blocking=True)

        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            fp32_feat = F.normalize(fp32_model.encode_image(imgs), dim=-1)
            quant_feat = F.normalize(model.encode_image(imgs), dim=-1)

            batch_sims = F.cosine_similarity(fp32_feat, quant_feat)
            batch_logits = model.logit_scale.exp() * quant_feat @ ensembled_text_features.T

        sims.extend(batch_sims.float().cpu().numpy())
        logits_list.append(batch_logits.float().cpu())
        # Logits are gathered on CPU, so keep the labels there as well.
        labels_list.append(lbls.cpu())

    if not logits_list:
        raise RuntimeError("run_comprehensive_evaluation: loader yielded no batches")

    logits = torch.cat(logits_list)
    labels = torch.cat(labels_list)

    acc = (logits.argmax(dim=-1) == labels).float().mean().item()
    ece = calculate_calibration_metrics(logits, labels)

    return {
        "Zero-Shot Accuracy": acc,
        # Plain float keeps the result dict JSON-serialisable like the other entries.
        "Avg. Cosine Similarity": float(np.mean(sims)),
        "ECE": ece
    }