# Code adapted from https://github.com/deeplearning-wisc/MCM/blob/main/utils/detection_util.py

import torch
from torch import Tensor, IntTensor

from ..clip import CLIP


# Names exported by `from <module> import *`; only the scoring entry point is public.
__all__ = ["clip_score"]


def _encode_prompts(
    model: CLIP,
    prompts: IntTensor | Tensor | list[IntTensor] | list[str] | list[list[str]],
) -> Tensor:
    """Return text features for ``prompts``, encoding them with ``model`` when needed.

    Lists (strings, nested string lists, or lists of token tensors) and
    non-floating-point tensors (tokenized text) are passed to
    ``model.encode_text``. A floating-point tensor is taken to already hold text
    features and is returned unchanged.
    """
    # Integer tensors are token ids, never valid features for a similarity computation.
    if isinstance(prompts, list) or not prompts.is_floating_point():
        return model.encode_text(prompts)
    return prompts


def clip_score(
    model: CLIP,
    images: Tensor,  # images or image_features
    normal_prompts: IntTensor | Tensor | list[IntTensor] | list[str] | list[list[str]],                 # texts or text_features
    anomaly_prompts: IntTensor | Tensor | list[IntTensor] | list[str] | list[list[str]] | None = None,  # texts or text_features
    *,
    temperature: float | None = 1.0,
) -> Tensor:
    """Compute zero-shot CLIP anomaly scores for a batch of images.

    The scoring rule is selected by ``anomaly_prompts``:

    * ``None`` -- MCM-style: the score is the negated maximum of the image/normal
      prompt similarities.
    * otherwise -- ZOC-style: normal and anomaly prompts are scored together by
      ``model.similarity`` and the outputs over the anomaly prompts are summed.

    In both branches a larger score indicates a more anomalous image (lower
    similarity to the normal prompts, or more output mass on the anomaly prompts).

    Args:
        model: CLIP wrapper providing ``encode_image``, ``encode_text`` and
            ``similarity``.
        images: A 4-D batch of images, which is encoded with
            ``model.encode_image``. A tensor of any other rank is used as
            precomputed image features.
        normal_prompts: Prompts describing normal data. Lists and integer
            (token) tensors are encoded with ``model.encode_text``.
            Floating-point tensors are used as precomputed text features.
        anomaly_prompts: Prompts describing anomalies, handled like
            ``normal_prompts``. ``None`` selects MCM scoring.
        temperature: Forwarded to ``model.similarity`` in the ZOC branch only.
            It has no effect when ``anomaly_prompts`` is ``None``.

    Returns:
        One score per image (the similarity outputs are reduced over dim 1).
    """
    # NOTE(review): a single un-batched image (C, H, W) is not 4-D, so it would be
    # treated as features rather than encoded -- callers presumably batch first.
    image_features = model.encode_image(images) if images.dim() == 4 else images
    normal_text_features = _encode_prompts(model, normal_prompts)

    if anomaly_prompts is None:
        # MCM scoring on raw similarities: softmax disabled, temperature unused.
        # NOTE(review): published MCM takes the max of a temperature-scaled
        # softmax; confirm that skipping it here is intended.
        normal_outputs = model.similarity(image_features, normal_text_features, softmax=False)
        return -torch.amax(normal_outputs, dim=1)

    # ZOC scoring: score all prompts in one call (using similarity's default
    # softmax setting -- presumably a softmax across normal + anomaly prompts,
    # confirm), then keep only the columns that belong to the anomaly prompts.
    anomaly_text_features = _encode_prompts(model, anomaly_prompts)
    text_features = torch.cat((normal_text_features, anomaly_text_features), dim=0)
    normal_and_anomaly_outputs = model.similarity(image_features, text_features, temperature=temperature)
    num_normal = normal_text_features.size(0)
    anomaly_outputs = normal_and_anomaly_outputs[:, num_normal:]
    return torch.sum(anomaly_outputs, dim=1)
