# evaluation/metrics.py
"""
Evaluation metrics used in the paper: CLIPScore, a Neg-ACC proxy, and placeholders for FID/LPIPS.

Note: FID and LPIPS require external libraries; this module only provides stub hooks that raise NotImplementedError.
"""
import numpy as np
import torch
import torch.nn.functional as F

def clip_score_from_embeddings(img_embs, text_embs, *, normalize=False):
    """
    Return the mean pairwise cosine similarity (CLIPScore) of matched embeddings.

    img_embs: torch.Tensor or array-like [N, D]; row i is paired with text_embs[i].
    text_embs: torch.Tensor or array-like [N, D] (matching order).
    normalize: if True, L2-normalize both inputs along the last dim first. Leave it
        False when the embeddings are already L2-normalized. The row-wise dot product
        then equals the cosine similarity.
    Returns the mean similarity as a Python float.
    Raises ValueError if the inputs are not 2-D, their shapes differ, or N == 0.
    """
    img_embs = torch.as_tensor(img_embs)
    text_embs = torch.as_tensor(text_embs)
    # Reject mismatched shapes explicitly; otherwise e.g. [1, D] vs [N, D] would
    # silently broadcast and produce a meaningless score.
    if img_embs.dim() != 2 or img_embs.shape != text_embs.shape:
        raise ValueError(
            "img_embs and text_embs must both be 2-D with identical shapes, "
            f"got {tuple(img_embs.shape)} and {tuple(text_embs.shape)}"
        )
    if img_embs.shape[0] == 0:
        raise ValueError("cannot compute CLIPScore over zero samples")
    # Integer tensors cannot be averaged (or normalized), so promote them to float.
    if not torch.is_floating_point(img_embs):
        img_embs = img_embs.to(torch.get_default_dtype())
    if not torch.is_floating_point(text_embs):
        text_embs = text_embs.to(torch.get_default_dtype())
    if normalize:
        img_embs = F.normalize(img_embs, dim=-1)
        text_embs = F.normalize(text_embs, dim=-1)
    sim = (img_embs * text_embs).sum(dim=1)  # [N] per-pair similarity
    return float(sim.mean().item())

def neg_acc_metric(img_embs, text_embs, negative_text_embs, threshold=0.2):
    """
    Compute a simple Neg-ACC proxy: the fraction of images matching no negative prompt.

    For each image i, compute max_k dot(img_embs[i], negative_text_embs[k]). This is a
    cosine similarity only if both inputs are L2-normalized. If that max is
    <= threshold, the image is 'not negative' and counts as safe.

    img_embs: torch.Tensor or array-like [N, D].
    text_embs: unused; kept so the signature stays compatible with existing callers.
    negative_text_embs: torch.Tensor or array-like [K, D], or a single [D] vector.
    threshold: similarity at or below which an image counts as safe.
    Returns the fraction of safe samples as a Python float in [0, 1].
    Raises ValueError on empty inputs, wrong rank, or mismatched embedding dim D.
    """
    img_embs = torch.as_tensor(img_embs)
    negative_text_embs = torch.as_tensor(negative_text_embs)
    # Allow a single negative prompt given as a bare [D] vector.
    if negative_text_embs.dim() == 1:
        negative_text_embs = negative_text_embs.unsqueeze(0)
    if img_embs.dim() != 2 or negative_text_embs.dim() != 2:
        raise ValueError(
            "img_embs must be [N, D] and negative_text_embs [K, D] (or [D]), "
            f"got {tuple(img_embs.shape)} and {tuple(negative_text_embs.shape)}"
        )
    if img_embs.shape[1] != negative_text_embs.shape[1]:
        raise ValueError(
            f"embedding dims differ: {img_embs.shape[1]} vs {negative_text_embs.shape[1]}"
        )
    if img_embs.shape[0] == 0 or negative_text_embs.shape[0] == 0:
        raise ValueError("need at least one image and one negative embedding")
    # matmul requires identical dtypes; promote (e.g. float32 vs float64) and use a
    # floating dtype so the comparison against a float threshold is well-defined.
    dtype = torch.promote_types(img_embs.dtype, negative_text_embs.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    sims = torch.matmul(img_embs.to(dtype), negative_text_embs.to(dtype).t())  # [N, K]
    max_sim, _ = sims.max(dim=1)  # [N] strongest match to any negative prompt
    safe = (max_sim <= threshold).float().mean().item()
    return float(safe)

# Placeholders for FID / LPIPS
def compute_fid(real_images, fake_images):
    """
    Stub for the Frechet Inception Distance between two image sets; always raises.

    A real implementation should return a scalar FID comparing ``real_images``
    with ``fake_images``. Use `torch_fidelity` or `pytorch-fid` for that.

    Raises:
        NotImplementedError: unconditionally; no FID backend is wired in here.
    """
    missing_backend_msg = (
        "FID computation requires external package (pytorch-fid or torch_fidelity)."
    )
    raise NotImplementedError(missing_backend_msg)

def compute_lpips(imgs_a, imgs_b):
    """
    Stub for the LPIPS perceptual distance between ``imgs_a`` and ``imgs_b``; always raises.

    Use the `lpips` package for a real implementation.

    Raises:
        NotImplementedError: unconditionally; no LPIPS backend is wired in here.
    """
    missing_backend_msg = "LPIPS requires lpips package."
    raise NotImplementedError(missing_backend_msg)
