






import torch

from torchcp.classification.score.lac import LAC, EnergyLAC


class APS(LAC):
    """
    Adaptive Prediction Sets (APS) non-conformity score.

    The score of a label is the cumulative probability mass of every label
    ranked at or above it (probabilities sorted in descending order). When
    ``randomized`` is true, a uniformly drawn fraction of the label's own
    probability is subtracted from that cumulative mass.

    Args:
        score_type (str, optional): Forwarded unchanged to ``LAC.__init__``.
            Default is "softmax".
        randomized (bool, optional): Whether to subtract a random fraction of
            the label's own probability. Default is True.
    """

    def __init__(self, score_type="softmax", randomized=True):
        super().__init__(score_type)
        self.randomized = randomized

    def _calculate_all_label(self, probs):
        """
        Calculate non-conformity scores for all labels.

        Args:
            probs (torch.Tensor): The prediction probabilities, a 2D tensor
                whose last dimension indexes the labels.

        Returns:
            torch.Tensor: The non-conformity scores, same shape as ``probs``
                and in the original (unsorted) label order.

        Raises:
            ValueError: If ``probs`` is not a 2D tensor.
        """
        # ``!= 2`` also rejects 0-D tensors, which the message already promises.
        if probs.dim() != 2:
            raise ValueError("Input probabilities must be 2D.")
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            U = torch.rand(probs.shape, device=probs.device)
        else:
            U = torch.zeros_like(probs)

        ordered_scores = cumsum - ordered * U
        # Sorting the sort-permutation yields its inverse, which maps the
        # rank-ordered scores back to the original label positions.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        scores = ordered_scores.gather(dim=-1, index=sorted_indices)
        # NOTE(review): ``self.prior`` is not set in this class; presumably it
        # is provided by ``LAC`` -- confirm against the parent class.
        if (self.prior is not None) and self.score_type == "softmax":
            scores = scores * self.prior
        return scores

    def _sort_sum(self, probs):
        """
        Sort probabilities and calculate cumulative sum.

        Args:
            probs (torch.Tensor): The prediction probabilities.

        Returns:
            tuple: A tuple containing:
                - indices (torch.Tensor): The rank of ordered probabilities in descending order.
                - ordered (torch.Tensor): The ordered probabilities in descending order.
                - cumsum (torch.Tensor): The accumulation of sorted probabilities.
        """
        ordered, indices = torch.sort(probs, dim=-1, descending=True)
        cumsum = torch.cumsum(ordered, dim=-1)
        return indices, ordered, cumsum

    def _calculate_single_label(self, probs, label):
        """
        Calculate non-conformity score for a single label.

        Args:
            probs (torch.Tensor): The prediction probabilities, a 2D tensor
                whose last dimension indexes the labels.
            label (torch.Tensor): The ground truth label of every row. It is
                reshaped to a column and compared against the sorted label
                indices, so each entry must be a valid class index.

        Returns:
            torch.Tensor: The non-conformity score for the given label, one
                value per row of ``probs``.

        Raises:
            ValueError: If ``probs`` is not a 2D tensor, or if some label does
                not match any class index.
        """
        if probs.dim() != 2:
            raise ValueError("Input probabilities must be 2D.")
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            U = torch.rand(indices.shape[0], device=probs.device)
        else:
            U = torch.zeros(indices.shape[0], device=probs.device)

        matches = indices == label.view(-1, 1)
        # Every row must hit exactly one rank. Without this check an invalid
        # label drops its row from ``torch.where`` and the shorter result can
        # be silently broadcast against ``U``, yielding wrong scores.
        if not bool(matches.any(dim=-1).all()):
            raise ValueError("Each label must be a valid class index in [0, num_classes).")
        idx = torch.where(matches)
        scores = cumsum[idx] - U * ordered[idx]
        if (self.prior is not None) and self.score_type == "softmax":
            scores = scores * self.prior[label]
        return scores


class EnergyAPS(EnergyLAC):
    """
    APS-style non-conformity score reweighted by energy or entropy.

    The base score of a label is the cumulative mass of all labels ranked at
    or above it (optionally minus a random fraction of the label's own mass).
    That score is then multiplied by ``logsumexp(input * temp_cal / temp_e)``
    or, when ``ent`` is true, divided by the entropy of the categorical
    distribution built from the probabilities.

    Args:
        score_type (str, optional): Forwarded to ``EnergyLAC.__init__``. When
            it equals "identity", the scoring methods apply a softmax to their
            input themselves. Default is "identity".
        randomized (bool, optional): Whether to subtract a random fraction of
            the label's own probability. Default is True.
        temp_cal (float, optional): Forwarded to ``EnergyLAC.__init__``.
            Default is 1.0.
        temp_e (float, optional): Forwarded to ``EnergyLAC.__init__``.
            Default is 1.0.
        ent (bool, optional): Forwarded to ``EnergyLAC.__init__``.
            Default is False.

    NOTE(review): ``self.temp_cal``, ``self.temp_e`` and ``self.ent`` are read
    below but not assigned in this class; presumably ``EnergyLAC.__init__``
    stores them -- confirm against the parent class.
    """

    def __init__(self, score_type="identity", randomized=True, temp_cal=1.0, temp_e=1.0, ent=False):
        super().__init__(score_type, temp_cal, temp_e, ent)
        self.randomized = randomized

    def _energy_and_probs(self, probs):
        """Return the per-row energy term and the probabilities used for ranking."""
        # The energy is computed from the input as received, before any softmax.
        energy = torch.logsumexp(probs * self.temp_cal / self.temp_e, dim=-1)
        if self.score_type == "identity":
            probs = torch.softmax(probs, dim=-1)
        return energy, probs

    def _calculate_all_label(self, probs):
        """
        Calculate non-conformity scores for all labels.

        Args:
            probs (torch.Tensor): The model outputs, a 2D tensor whose last
                dimension indexes the labels. With ``score_type == "identity"``
                a softmax is applied to them here before ranking.

        Returns:
            torch.Tensor: The non-conformity scores, same shape as ``probs``
                and in the original (unsorted) label order.

        Raises:
            ValueError: If ``probs`` is not a 2D tensor.
        """
        # ``!= 2`` also rejects 0-D tensors, which the message already promises.
        if probs.dim() != 2:
            raise ValueError("Input probabilities must be 2D.")
        energy, probs = self._energy_and_probs(probs)
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            U = torch.rand(probs.shape, device=probs.device)
        else:
            U = torch.zeros_like(probs)
        ordered_scores = cumsum - ordered * U
        # Sorting the sort-permutation yields its inverse, which maps the
        # rank-ordered scores back to the original label positions.
        _, sorted_indices = torch.sort(indices, descending=False, dim=-1)
        scores = ordered_scores.gather(dim=-1, index=sorted_indices)

        if self.ent:
            entropy = torch.distributions.Categorical(probs).entropy()
            scores = scores / entropy.view(-1, 1)
        else:
            scores = scores * energy.view(-1, 1)

        return scores

    def _sort_sum(self, probs):
        """
        Sort probabilities and calculate cumulative sum.

        Args:
            probs (torch.Tensor): The prediction probabilities.

        Returns:
            tuple: A tuple containing:
                - indices (torch.Tensor): The rank of ordered probabilities in descending order.
                - ordered (torch.Tensor): The ordered probabilities in descending order.
                - cumsum (torch.Tensor): The accumulation of sorted probabilities.
        """
        ordered, indices = torch.sort(probs, dim=-1, descending=True)
        cumsum = torch.cumsum(ordered, dim=-1)
        return indices, ordered, cumsum

    def _calculate_single_label(self, probs, label):
        """
        Calculate non-conformity score for a single label.

        Args:
            probs (torch.Tensor): The model outputs, a 2D tensor whose last
                dimension indexes the labels. With ``score_type == "identity"``
                a softmax is applied to them here before ranking.
            label (torch.Tensor): The ground truth label of every row. It is
                reshaped to a column and compared against the sorted label
                indices, so each entry must be a valid class index.

        Returns:
            torch.Tensor: The non-conformity score for the given label, one
                value per row of ``probs``.

        Raises:
            ValueError: If ``probs`` is not a 2D tensor, or if some label does
                not match any class index.
        """
        if probs.dim() != 2:
            raise ValueError("Input probabilities must be 2D.")
        energy, probs = self._energy_and_probs(probs)
        indices, ordered, cumsum = self._sort_sum(probs)
        if self.randomized:
            U = torch.rand(indices.shape[0], device=probs.device)
        else:
            U = torch.zeros(indices.shape[0], device=probs.device)

        matches = indices == label.view(-1, 1)
        # Every row must hit exactly one rank. Without this check an invalid
        # label drops its row from ``torch.where`` and the shorter result can
        # be silently broadcast against ``U``, yielding wrong scores.
        if not bool(matches.any(dim=-1).all()):
            raise ValueError("Each label must be a valid class index in [0, num_classes).")
        idx = torch.where(matches)
        scores = cumsum[idx] - U * ordered[idx]

        if self.ent:
            entropy = torch.distributions.Categorical(probs).entropy()
            scores = scores / entropy.view(-1)
        else:
            scores = scores * energy.view(-1)

        return scores
