# scripts/ood_evaluation.py

import logging
import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from tqdm import tqdm
import scripts.config as config

# Logger for this module, named after the module itself via __name__.
MODULE_LOGGER = logging.getLogger(__name__)

def get_safe_logit_scale(model, device):
    """
    Return the multiplier applied to cosine similarities to form logits.

    ``model.logit_scale`` is treated as a log-space value: it is exponentiated,
    moved to ``device`` and clamped to at most 100.0. When the model has no
    ``logit_scale`` attribute, it is None, or it cannot be used as a tensor,
    a constant ``torch.tensor(100.0)`` on ``device`` is returned instead.

    Args:
        model: object that may expose a ``logit_scale`` tensor/parameter.
        device: target device for the returned scalar tensor.

    Returns:
        A scalar tensor on ``device`` that carries no autograd history.
    """
    raw_scale = getattr(model, "logit_scale", None)
    if raw_scale is not None:
        try:
            # detach(): the scale is used only for scoring, so never carry autograd history.
            return raw_scale.detach().exp().to(device).clamp(max=100.0)
        except Exception as exc:
            # Best-effort fallback, but make it visible instead of silently swallowing it.
            MODULE_LOGGER.warning(
                "Could not use model.logit_scale (%s); falling back to 100.0", exc
            )
    return torch.tensor(100.0, device=device)

def compute_ood_scores(img_feats, raw_cos_id, logit_scale, text_feats_neg, text_feats_unk, method):
    """
    Compute one OOD score per sample for a single batch (higher = more in-distribution).

    Args:
        img_feats: (Batch, Dim) image embeddings; the evaluation loop in this
            module L2-normalizes them before calling.
        raw_cos_id: (Batch, Num_ID_Classes) cosine similarity matrix against the
            ID class text features.
        logit_scale: scalar tensor multiplied onto cosine similarities to form logits.
        text_feats_neg: (Num_Neg, Dim) features used by "CLIPN-A", or None.
        text_feats_unk: (Num_Unk, Dim) features used by "NegLabel" and "EOE", or None.
        method: one of "MSP", "Energy", "Entropy", "MCM", "NegLabel", "CLIPN-A", "EOE".

    Returns:
        1-D tensor of shape (Batch,), also for a batch of one. Methods whose
        required text features are None, and unrecognized method names, yield
        all-zero scores.
    """
    batch_size = raw_cos_id.shape[0]

    if method in ("MSP", "Energy", "Entropy"):
        logits = logit_scale * raw_cos_id
        if method == "Energy":
            return torch.logsumexp(logits, dim=-1)
        probs = F.softmax(logits, dim=-1)
        if method == "MSP":
            return probs.max(dim=-1).values
        # Entropy: negated so that confident (low-entropy) samples score higher.
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
        return -entropy

    if method == "MCM":
        tau = config.MCM_PARAMS["tau"]
        # tau=None reuses the model's own logit scale; otherwise divide by tau.
        scaled_logits = raw_cos_id * logit_scale if tau is None else raw_cos_id / tau
        return F.softmax(scaled_logits, dim=-1).max(dim=-1).values

    if method == "NegLabel":
        if text_feats_unk is None:
            return torch.zeros(batch_size, device=raw_cos_id.device)
        scaling = logit_scale.clamp(min=1.0)
        max_id = (raw_cos_id * scaling).max(dim=-1).values
        max_neg = ((img_feats @ text_feats_unk.T) * scaling).max(dim=-1).values
        # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b). The sigmoid form cannot
        # overflow exp() to inf (inf / inf = NaN) when scaled similarities exceed
        # ~88 in fp32; the fp32 cast keeps fp16 (AMP) inputs from losing precision.
        return torch.sigmoid(max_id.float() - max_neg.float())

    if method == "CLIPN-A":
        if text_feats_neg is None:
            return torch.zeros(batch_size, device=raw_cos_id.device)
        logits_id = logit_scale * raw_cos_id
        max_prob_id = F.softmax(logits_id, dim=-1).max(dim=-1).values

        logit_rej = logit_scale * (img_feats @ text_feats_neg.T)
        max_cos_id = raw_cos_id.max(dim=-1, keepdim=True).values
        logit_max_id = logit_scale * max_cos_id

        # Softmax over [rejection prompt logit(s) ..., best ID-class logit].
        probs_rej_model = F.softmax(torch.cat([logit_rej, logit_max_id], dim=-1), dim=-1)
        # NOTE(review): column 0 is only the FIRST row of text_feats_neg; any further
        # rows only enter the softmax normalizer — confirm this is intended.
        conf_rej_class = probs_rej_model[:, 0]
        return max_prob_id - conf_rej_class

    if method == "EOE":
        if text_feats_unk is None:
            return torch.zeros(batch_size, device=raw_cos_id.device)
        logits_id = logit_scale * raw_cos_id
        logits_outlier = logit_scale * (img_feats @ text_feats_unk.T)

        # Joint softmax over ID classes followed by outlier labels.
        probs_total = F.softmax(torch.cat([logits_id, logits_outlier], dim=-1), dim=-1)
        num_id = raw_cos_id.shape[1]
        max_prob_id = probs_total[:, :num_id].max(dim=-1).values
        max_prob_outlier = probs_total[:, num_id:].max(dim=-1).values

        beta = config.EOE_PARAMS.get("beta", 0.1)
        return max_prob_id - beta * max_prob_outlier

    # Unrecognized method name: all-zero scores.
    return torch.zeros(batch_size, device=raw_cos_id.device)

def calculate_ood_metrics(id_scores, ood_scores):
    """
    Compute AUROC and FPR at 95% TPR for separating ID from OOD samples by score.

    ID samples are the positive class (label 1), so higher scores are expected to
    mean "more in-distribution". If AUROC comes out below 0.5, the score direction
    is treated as inverted. AUROC is then reported as ``1 - AUROC`` and FPR95 is
    computed on the negated scores.

    Args:
        id_scores: iterable of per-sample scores for the ID set.
        ood_scores: iterable of per-sample scores for the OOD set.

    Returns:
        dict with keys "AUROC", "FPR95", "ID_Mean", "ID_Std", "OOD_Mean", "OOD_Std".
        If either set is empty or contains NaN/inf, a neutral fallback (AUROC 0.5,
        FPR95 1.0, zero means/stds) is returned with the same keys.
    """
    id_scores = np.asarray(id_scores).ravel()
    ood_scores = np.asarray(ood_scores).ravel()

    # Degenerate inputs: return a fixed fallback with the SAME schema as the normal path.
    if (
        id_scores.size == 0
        or ood_scores.size == 0
        or not np.isfinite(id_scores).all()
        or not np.isfinite(ood_scores).all()
    ):
        return {
            "AUROC": 0.5,
            "FPR95": 1.0,
            "ID_Mean": 0.0,
            "ID_Std": 0.0,
            "OOD_Mean": 0.0,
            "OOD_Std": 0.0,
        }

    y_true = np.concatenate([np.ones_like(id_scores), np.zeros_like(ood_scores)])
    y_scores = np.concatenate([id_scores, ood_scores])

    try:
        auroc = metrics.roc_auc_score(y_true, y_scores)
        if auroc < 0.5:
            # Lower score = ID for this method: invert. Exact negation preserves the
            # ranking, unlike 1.0 - x, which can merge nearly equal floats into ties.
            auroc = 1.0 - auroc
            y_scores = -y_scores

        fpr, tpr, _ = metrics.roc_curve(y_true, y_scores)
        # FPR at the first ROC point whose TPR reaches 95%.
        fpr95 = fpr[np.argmax(tpr >= 0.95)]
    except Exception as exc:
        MODULE_LOGGER.warning("Metric calc error: %s", exc)
        auroc = 0.5
        fpr95 = 1.0

    return {
        "AUROC": float(auroc),
        "FPR95": float(fpr95),
        "ID_Mean": float(np.mean(id_scores)),
        "ID_Std": float(np.std(id_scores)),
        "OOD_Mean": float(np.mean(ood_scores)),
        "OOD_Std": float(np.std(ood_scores)),
    }

def _encode_and_score_batch(model, images, tf_id, tf_neg, tf_unk, logit_scale, methods, amp_enabled):
    """Encode one image batch; return its ID cosine matrix and a 1-D numpy score array per method."""
    with torch.amp.autocast("cuda", enabled=amp_enabled):
        img_feats = F.normalize(model.encode_image(images), dim=-1)
        # Similarity with every ID class (shared by accuracy and most OOD methods).
        raw_cos_id = img_feats @ tf_id.T
        batch_scores = {}
        for method in methods:
            scores = compute_ood_scores(img_feats, raw_cos_id, logit_scale, tf_neg, tf_unk, method)
            # reshape(-1): a 1-sample batch must stay 1-D so list.extend() can iterate it.
            batch_scores[method] = scores.float().reshape(-1).cpu().numpy()
    return raw_cos_id, batch_scores


@torch.no_grad()
def run_ood_benchmark_eval(model, id_loader, ood_loader, text_feats_id, text_feats_neg, text_feats_unk, device):
    """
    Score an ID and an OOD loader with every method in ``config.OOD_METHODS``.

    Also measures top-1 zero-shot accuracy on the ID loader. Calls ``model.eval()``
    and does not restore the previous training mode.

    Args:
        model: model exposing ``encode_image``; ``logit_scale`` is optional.
        id_loader: iterable of ``(images, labels)`` batches from the ID set.
        ood_loader: iterable of ``(images, _)`` batches from the OOD set.
        text_feats_id: ID class text features, one row per class.
        text_feats_neg: negative-prompt features (used by "CLIPN-A"), or None.
        text_feats_unk: outlier-label features (used by "NegLabel"/"EOE"), or None.
        device: device the model and tensors run on.

    Returns:
        - results_dict: {Method: {metrics...}} as produced by ``calculate_ood_metrics``
        - id_accuracy: float (0.0 to 1.0); 0.0 when the ID loader yields no samples
    """
    model.eval()
    logit_scale = get_safe_logit_scale(model, device)
    methods = list(config.OOD_METHODS)

    tf_id = text_feats_id.to(device)
    tf_neg = text_feats_neg.to(device) if text_feats_neg is not None else None
    tf_unk = text_feats_unk.to(device) if text_feats_unk is not None else None

    # CUDA autocast only affects CUDA tensors, so only request it when running on CUDA.
    amp_enabled = bool(config.USE_AMP) and torch.device(device).type == "cuda"

    all_scores = {m: {"id": [], "ood": []} for m in methods}

    # --- 1. ID loop (scores + accuracy) ---
    correct_id = 0
    total_id = 0
    for images, labels in tqdm(id_loader, desc="Scoring ID + Acc", leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        raw_cos_id, batch_scores = _encode_and_score_batch(
            model, images, tf_id, tf_neg, tf_unk, logit_scale, methods, amp_enabled
        )
        # Standard CLIP inference: predicted class = argmax of scale * cosine.
        preds = (logit_scale * raw_cos_id).argmax(dim=-1)
        correct_id += (preds == labels).sum().item()
        total_id += labels.size(0)
        for method, scores in batch_scores.items():
            all_scores[method]["id"].extend(scores)

    id_accuracy = correct_id / total_id if total_id > 0 else 0.0

    # --- 2. OOD loop (scores only) ---
    for images, _ in tqdm(ood_loader, desc="Scoring OOD", leave=False):
        images = images.to(device, non_blocking=True)
        _, batch_scores = _encode_and_score_batch(
            model, images, tf_id, tf_neg, tf_unk, logit_scale, methods, amp_enabled
        )
        for method, scores in batch_scores.items():
            all_scores[method]["ood"].extend(scores)

    # --- 3. Compute metrics ---
    results_dict = {
        m: calculate_ood_metrics(all_scores[m]["id"], all_scores[m]["ood"]) for m in methods
    }
    return results_dict, id_accuracy