






from abc import ABC, abstractmethod
import torch


class BaseScore(ABC):
    """
    Abstract base class for all score functions.

    Subclasses implement ``__call__`` to turn model outputs (and optionally
    ground-truth labels) into scores. A prior can be attached to the instance
    with ``set_prior``.

    Attributes:
        prior (torch.Tensor | None): Tensor built from the values passed to
            ``set_prior``, or ``None`` until ``set_prior`` has been called.
        device: The device passed to the last ``set_prior`` call, or ``None``.
    """

    def __init__(self) -> None:
        self.prior = None
        self.device = None

    def set_prior(self, prior, device) -> None:
        """Store ``prior`` as a tensor on ``device``.

        Args:
            prior (Mapping): Mapping whose values are the prior entries. Only
                the values are kept, in the mapping's iteration order
                (insertion order for a ``dict``); the keys are discarded.
            device: Device the prior tensor is created on (anything accepted
                by ``torch.tensor(..., device=...)``, e.g. ``"cpu"``).
        """
        # NOTE: dtype is inferred by torch.tensor from the values, so integer
        # values produce an integer tensor rather than a float one.
        self.prior = torch.tensor(list(prior.values()), device=device)
        self.device = device

    @abstractmethod
    def __call__(self, logits, labels=None):
        """Virtual method to compute scores for a data pair (x, y).

        Args:
            logits (torch.Tensor): The model outputs for the inputs x.
            labels (torch.Tensor, optional): The ground-truth labels y.
                Defaults to ``None``; how a missing label is handled is up to
                the concrete subclass.

        Returns:
            The scores, in a form defined by the concrete subclass.
        """
        raise NotImplementedError
