import torch


class diff_aps(object):
    """
    Adaptive Prediction Sets (Romano et al., 2020)
    paper :https://proceedings.neurips.cc/paper/2020/file/244edd7e85dc81602b7615cd705545f5-Paper.pdf

    Computes APS non-conformity scores. For a label ``y`` the non-random score
    is the total probability mass of every label ranked at or above ``y`` when
    the labels of a row are sorted by descending probability. With
    ``random=True`` the mass of ``y`` itself is scaled by a uniform random
    variable, giving the randomized score.
    """

    def __call__(self, logits, label=None, softmax=True, random=True):
        """Score every label, or only ``label``, for each row of ``logits``.

        Args:
            logits: tensor of shape ``(C,)`` or ``(B, C)``; a 1-D input is
                treated as a batch of one row.
            label: optional label per row (tensor, int or sequence of ints).
                When ``None``, scores for every class are returned.
            softmax: apply ``torch.softmax`` over the last dim first; pass
                ``False`` when ``logits`` already holds probabilities.
            random: use the randomized APS score.

        Returns:
            ``(B, C)`` tensor of scores when ``label`` is ``None``, otherwise a
            1-D tensor with the score of each row's label.
        """
        assert logits.dim() <= 2, "The dimension of logits must be at most 2."
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        probs = torch.softmax(logits, dim=-1) if softmax else logits

        if label is None:
            return self._calculate_all_label(probs, random)
        return self._calculate_single_label(probs, label, random)

    def _calculate_all_label(self, probs, random=True):
        """Return the ``(B, C)`` APS score of every class, in original class order."""
        indices, ordered, cumsum = self._sort_sum(probs)
        if random:
            U = torch.rand(probs.shape, device=probs.device)
            ordered_scores = cumsum - ordered * U
        else:
            ordered_scores = cumsum
        # `indices` is a per-row permutation; its argsort is the inverse
        # permutation, which maps each class id back to its sorted position.
        inverse = torch.argsort(indices, dim=-1)
        return ordered_scores.gather(dim=-1, index=inverse)

    def _sort_sum(self, probs):
        """Sort probabilities per row and accumulate them.

        Returns:
            indices: class ids ordered by descending probability.
            ordered: the probabilities in descending order.
            cumsum: running sum of ``ordered`` along the last dim.
        """
        ordered, indices = torch.sort(probs, dim=-1, descending=True)
        cumsum = torch.cumsum(ordered, dim=-1)
        return indices, ordered, cumsum

    def _calculate_single_label(self, probs, label, random=True):
        """Return the APS score of ``label`` for each row of ``probs``."""
        indices, ordered, cumsum = self._sort_sum(probs)
        # Accept ints / sequences as well as tensors, on the matching device.
        label = torch.as_tensor(label, device=indices.device)
        # Locate each row's label in the descending order. Each row of
        # `indices` is a permutation, so there is exactly one match per row:
        # idx[0] is the row, idx[1] is the label's rank within that row.
        idx = torch.where(indices == label.view(-1, 1))
        if random:
            U = torch.rand(indices.shape[0], device=probs.device)
            # Clamp so a rank-0 label does not wrap to the last column; that
            # value is discarded by torch.where below anyway.
            idx_minus_one = (idx[0], (idx[1] - 1).clamp(min=0))
            scores_first_rank = U * cumsum[idx]
            scores_usual = U * ordered[idx] + cumsum[idx_minus_one]
            scores = torch.where(idx[1] == 0, scores_first_rank, scores_usual)
        else:
            # `cumsum` is in sorted order, so it must be indexed by the label's
            # rank, not by the label id itself.
            scores = cumsum[idx]
        return scores