






import torch

from torchcp.classification.score.lac import LAC


class TOPK(LAC):
    """Rank-based (Top-K) non-conformity score.

    The score of a label is its 1-based position when the classes are sorted
    by descending probability. With ``randomized=True`` a ``torch.rand`` draw
    in [0, 1) is subtracted from each score, so a label at rank ``r`` scores
    in ``(r - 1, r]``; with ``randomized=False`` the score is exactly ``r``.

    Args:
        score_type (str): forwarded unchanged to ``LAC.__init__``.
        randomized (bool): whether to subtract uniform noise from the ranks.
    """

    def __init__(self, score_type="softmax", randomized=True):
        super().__init__(score_type)
        self.randomized = randomized

    def _calculate_all_label(self, probs):
        """Score every label of every sample.

        Args:
            probs (torch.Tensor): prediction probabilities; the last dimension
                is treated as the class axis.

        Returns:
            torch.Tensor: tensor shaped like ``probs`` whose entry at class
            ``j`` is the (optionally randomized) rank of class ``j``.
        """
        order, _, ranks = self._sort_sum(probs)

        # One independent draw per (sample, class) entry when randomized.
        noise = (
            torch.rand(probs.shape, device=probs.device)
            if self.randomized
            else torch.zeros_like(probs)
        )
        ranked_scores = ranks - noise

        # ``order[..., k]`` is the class at sorted position k. It is a
        # permutation of the class axis, so scattering along it moves each
        # sorted-space score back to its class's original column.
        return torch.empty_like(ranked_scores).scatter_(-1, order, ranked_scores)

    def _sort_sum(self, probs):
        """Rank the classes of each sample by descending probability.

        Despite the name, no probabilities are summed. The cumulative sum is
        taken over a tensor of ones, which yields the 1-based ranks. Tied
        probabilities are ordered however ``torch.sort`` (non-stable by
        default) places them.

        Args:
            probs (torch.Tensor): prediction probabilities (class axis last).

        Returns:
            tuple: ``(indices, ones, cumsum)``.
            ``indices`` holds the class index at each sorted position.
            ``ones`` is a ones tensor shaped and typed like ``probs``.
            ``cumsum`` holds the ranks ``1..K`` along the last dimension.
        """
        sorted_probs, order = torch.sort(probs, dim=-1, descending=True)
        unit = torch.ones_like(sorted_probs)
        return order, unit, unit.cumsum(dim=-1)

    def _calculate_single_label(self, probs, label):
        """Score only the given label of each sample.

        Args:
            probs (torch.Tensor): prediction probabilities, presumably shaped
                ``(batch, num_classes)``. TODO: confirm against callers.
            label (torch.Tensor): one class index per sample.

        Returns:
            torch.Tensor: 1-D tensor holding the (optionally randomized) rank
            of each sample's label.
        """
        order, _, ranks = self._sort_sum(probs)
        n_samples = order.shape[0]

        # One draw per sample (not per class) on this path.
        if not self.randomized:
            noise = torch.zeros(n_samples, device=probs.device)
        else:
            noise = torch.rand(n_samples, device=probs.device)

        # For an in-range label, exactly one sorted position per row matches.
        # Boolean-mask indexing reads those ranks out in row-major, i.e.
        # per-sample, order.
        hit = order == label.view(-1, 1)
        return ranks[hit] - noise
