
import torch

from .base import BaseScore


class THR(BaseScore):
    """
    Threshold conformal predictors (Sadinle et al., 2016).
    paper : https://arxiv.org/abs/1609.00451.

    The non-conformity score of a label is ``1 - p[label]``, where ``p`` is the
    result of applying ``transform`` to the logits along the last dimension.

    :param transform: a callable that maps logits to per-label values, invoked as
        ``transform(logits, dim=-1)``. It must accept a ``dim`` keyword argument,
        e.g. ``torch.softmax`` or ``torch.log_softmax``. Plain strings such as
        ``"softmax"`` are not accepted.
    """

    def __init__(self, transform):
        # Stored as-is; it is only called later in ``__call__``.
        self.transform = transform

    def __call__(self, logits, label=None):
        """
        Compute THR non-conformity scores.

        :param logits: a tensor with at most 2 dimensions. A 1-D tensor is treated
            as a single sample and is promoted to a batch of size 1.
        :param label: optional tensor of label indices. It must be on the same
            device as ``logits``.
        :return: if ``label`` is None, ``1 - transform(logits)`` (the score of every
            label), with the batched (2-D) shape for shape-preserving transforms.
            Otherwise, a tensor with one score per row: the score of that row's label.
        :raises AssertionError: if ``logits`` has more than 2 dimensions, or if
            ``logits`` and ``label`` live on different devices.
        """
        assert len(logits.shape) <= 2, "dimension of logits are at most 2."
        if label is not None:
            assert logits.device == label.device, f"Expected logits and label to be on the same device, but found at least two devices, {logits.device} and {label.device}!"
        # Promote a single sample to a batch so the indexing below always sees rows.
        if len(logits.shape) == 1:
            logits = logits.unsqueeze(0)
        probs = self.transform(logits, dim=-1)
        if label is None:
            return self.__calculate_all_label(probs)
        else:
            return self.__calculate_single_label(probs, label)

    def __calculate_single_label(self, probs, label):
        """Return ``1 - probs[i, label[i]]`` for every row ``i`` of ``probs``."""
        # Pair each row index with its label to gather one value per row.
        # The arange is built on probs' device to avoid cross-device indexing.
        return 1 - probs[torch.arange(probs.shape[0], device=probs.device), label]

    def __calculate_all_label(self, probs):
        """Return ``1 - probs``, i.e. the score of every candidate label."""
        return 1 - probs
