






import torch

from torchcp.classification.score.aps import APS, EnergyAPS


class SAPS(APS):
    """Sorted Adaptive Prediction Sets (SAPS) non-conformity score.

    Only the first entry of the sorted probabilities is kept; every
    lower-ranked label contributes the constant ``weight`` instead of its own
    probability, so a label's score grows linearly with its rank.

    For a label at 0-based rank ``k`` (``p_max`` = first sorted entry, ``U``
    = random offset or 0):

    * ``k == 0``: ``p_max * (1 - U)``
    * ``k >= 1``: ``p_max + weight * (k - U)``

    Both ``_calculate_all_label`` and ``_calculate_single_label`` implement
    this same formula, so a label's single-label score equals its entry in the
    all-label scores for the same ``U``.

    Args:
        score_type: Forwarded to :class:`APS`.
        randomized: If ``True`` draw ``U ~ Uniform[0, 1)`` per score,
            otherwise use ``U = 0``.
        weight: Positive per-rank penalty for labels below the top one.

    Raises:
        ValueError: If ``weight`` is not positive or ``randomized`` is not a
            bool.
    """

    def __init__(self, score_type="softmax", randomized=True, weight=0.2):
        super().__init__(score_type, randomized)
        if weight <= 0:
            raise ValueError("The parameter 'weight' must be a positive value.")
        if not isinstance(randomized, bool):
            raise ValueError("The parameter 'randomized' must be a boolean.")
        self.weight = weight

    def _calculate_all_label(self, probs):
        """Return the SAPS score of every label.

        Args:
            probs: Score tensor; assumes shape (batch, num_classes) since the
                code indexes ``[:, 1:]`` — TODO confirm. Also assumes
                ``_sort_sum`` sorts in descending order (defined in APS).

        Returns:
            Tensor shaped like ``probs``, in the original label order,
            optionally multiplied by ``self.prior`` when a prior is set and
            ``score_type == "softmax"``.
        """
        indices, ordered, _ = self._sort_sum(probs)
        # Every entry after the first is replaced by the constant rank penalty.
        ordered[:, 1:] = self.weight
        cumsum = torch.cumsum(ordered, dim=-1)
        if self.randomized:
            U = torch.rand(probs.shape, device=probs.device)
        else:
            U = torch.zeros_like(probs)
        ordered_scores = cumsum - ordered * U
        # Sorting the permutation yields its inverse, mapping ranks back to labels.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        scores = ordered_scores.gather(dim=-1, index=sorted_indices)
        # `prior` is not assigned in this class; tolerate a parent that lacks it.
        prior = getattr(self, "prior", None)
        if prior is not None and self.score_type == "softmax":
            scores = scores * prior
        return scores

    def _calculate_single_label(self, probs, label):
        """Return the SAPS score of ``label`` for each sample.

        Args:
            probs: Score tensor; assumes shape (batch, num_classes) — TODO
                confirm.
            label: Tensor with one class index per sample.

        Returns:
            1-D tensor with one score per sample, optionally multiplied by
            ``self.prior[label]`` when a prior is set and
            ``score_type == "softmax"``.
        """
        indices, ordered, _ = self._sort_sum(probs)
        batch = indices.shape[0]
        if self.randomized:
            U = torch.rand(batch, device=probs.device)
        else:
            U = torch.zeros(batch, device=probs.device)
        # Column index of each sample's label in the sorted order = its rank.
        _, rank = torch.where(indices == label.view(-1, 1))
        top = ordered[:, 0]
        # Must agree with `cumsum - ordered * U` at rank 0 in
        # _calculate_all_label; in particular with U == 0 the top label
        # scores p_max rather than 0.
        scores_first_rank = top - top * U
        scores_usual = self.weight * (rank - U) + top
        scores = torch.where(rank == 0, scores_first_rank, scores_usual)
        prior = getattr(self, "prior", None)
        if prior is not None and self.score_type == "softmax":
            scores = scores * prior[label]
        return scores

class EnergySAPS(EnergyAPS):
    """SAPS score rescaled by a per-sample energy or entropy factor.

    The base score uses the SAPS rank formula (0-based rank ``k``,
    ``p_max`` = first sorted entry, ``U`` = random offset or 0):

    * ``k == 0``: ``p_max * (1 - U)``
    * ``k >= 1``: ``p_max + weight * (k - U)``

    It is then either multiplied by ``logsumexp(x * temp_cal / temp_e)`` of
    the raw input ``x`` (default), or, when ``self.ent`` is true, divided by
    the entropy of the class distribution.

    Args:
        score_type: Forwarded to :class:`EnergyAPS`. When ``"identity"``, the
            input is treated as logits and softmaxed here before ranking.
        randomized: If ``True`` draw ``U ~ Uniform[0, 1)``, otherwise use
            ``U = 0``.
        weight: Positive per-rank penalty for labels below the top one.
        temp_cal, temp_e, ent: Forwarded to :class:`EnergyAPS`; read here as
            ``self.temp_cal``, ``self.temp_e`` and ``self.ent`` (assumes the
            parent stores them under these names — TODO confirm).

    Raises:
        ValueError: If ``weight`` is not positive or ``randomized`` is not a
            bool.
    """

    def __init__(self, score_type="identity", randomized=True, weight=0.2, temp_cal=1.0, temp_e=1.0, ent=False):
        super().__init__(score_type, randomized, temp_cal, temp_e, ent)
        if weight <= 0:
            raise ValueError("The parameter 'weight' must be a positive value.")
        if not isinstance(randomized, bool):
            raise ValueError("The parameter 'randomized' must be a boolean.")
        self.weight = weight

    def _calculate_all_label(self, probs):
        """Return the rescaled SAPS score of every label.

        Args:
            probs: Input tensor; assumes shape (batch, num_classes) — TODO
                confirm. Raw logits when ``score_type == "identity"``.

        Returns:
            Tensor shaped like ``probs``, in the original label order.
        """
        logits = probs
        if self.score_type == "identity":
            probs = torch.softmax(logits, dim=-1)
        indices, ordered, _ = self._sort_sum(probs)
        # Every entry after the first is replaced by the constant rank penalty.
        ordered[:, 1:] = self.weight
        cumsum = torch.cumsum(ordered, dim=-1)
        if self.randomized:
            U = torch.rand(probs.shape, device=probs.device)
        else:
            U = torch.zeros_like(probs)
        ordered_scores = cumsum - ordered * U
        # Sorting the permutation yields its inverse, mapping ranks back to labels.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        scores = ordered_scores.gather(dim=-1, index=sorted_indices)

        if self.ent:
            # NOTE(review): a row with zero entropy (one-hot) divides by zero.
            entropy = torch.distributions.Categorical(probs).entropy()
            return scores / entropy.view(-1, 1)
        # Energy uses the raw input, not the softmaxed probabilities.
        energy = torch.logsumexp(logits*self.temp_cal/self.temp_e, dim=-1)
        return scores * energy.view(-1, 1)

    def _calculate_single_label(self, probs, label):
        """Return the rescaled SAPS score of ``label`` for each sample.

        Args:
            probs: Input tensor; assumes shape (batch, num_classes) — TODO
                confirm. Raw logits when ``score_type == "identity"``.
            label: Tensor with one class index per sample.

        Returns:
            1-D tensor with one score per sample.
        """
        logits = probs
        if self.score_type == "identity":
            probs = torch.softmax(logits, dim=-1)

        indices, ordered, _ = self._sort_sum(probs)
        batch = indices.shape[0]
        if self.randomized:
            U = torch.rand(batch, device=probs.device)
        else:
            U = torch.zeros(batch, device=probs.device)

        # Column index of each sample's label in the sorted order = its rank.
        _, rank = torch.where(indices == label.view(-1, 1))
        top = ordered[:, 0]
        # Must agree with `cumsum - ordered * U` at rank 0 in
        # _calculate_all_label; in particular with U == 0 the top label
        # scores p_max rather than 0.
        scores_first_rank = top - top * U
        scores_usual = self.weight * (rank - U) + top
        scores = torch.where(rank == 0, scores_first_rank, scores_usual)

        if self.ent:
            # NOTE(review): a row with zero entropy (one-hot) divides by zero.
            entropy = torch.distributions.Categorical(probs).entropy()
            return scores / entropy.view(-1)
        # Energy uses the raw input, not the softmaxed probabilities.
        energy = torch.logsumexp(logits*self.temp_cal/self.temp_e, dim=-1)
        return scores * energy.view(-1)