






import torch

from torchcp.classification.score.base import BaseScore


class LAC(BaseScore):
    """
    Non-conformity score built from negated, transformed model outputs.

    The input is passed through ``self.transform`` (selected by
    ``score_type``) and negated. When ``self.prior`` is not ``None`` and
    ``score_type`` is the string ``"softmax"``, the negated value is also
    divided by the matching entry of ``self.prior``.

    NOTE(review): ``self.prior`` is never assigned in this class; it is
    presumably provided by ``BaseScore`` -- confirm.

    Args:
        score_type (str or callable): One of ``"softmax"``, ``"identity"``,
            ``"log_softmax"``, ``"log"``, or a callable applied to the logits.
            Default is ``"softmax"``.

    Raises:
        ValueError: If ``score_type`` is neither callable nor one of the
            supported names.
    """

    def __init__(self, score_type="softmax"):
        super().__init__()
        self.score_type = score_type

        if callable(score_type):
            # A user-supplied function is used as-is.
            self.transform = score_type
            return

        def _identity(x):
            return x

        def _softmax(x):
            return torch.softmax(x, dim=-1)

        def _log_softmax(x):
            return torch.log_softmax(x, dim=-1)

        def _log(x):
            return torch.log(x)

        # Checked in this order with ``==``; the first match wins.
        named_transforms = (
            ("identity", _identity),
            ("softmax", _softmax),
            ("log_softmax", _log_softmax),
            ("log", _log),
        )
        for name, fn in named_transforms:
            if score_type == name:
                self.transform = fn
                return

        raise ValueError(
            f"Score type '{score_type}' is not implemented. Options are 'softmax', 'identity', 'log_softmax', 'log', or a callable function.")

    def __call__(self, logits, label=None):
        """
        Compute non-conformity scores from model outputs.

        A 1-D input is promoted to a single-row batch before the transform
        is applied.

        Args:
            logits (torch.Tensor): Model outputs with at most 2 dimensions.
            label (torch.Tensor, optional): Labels to score. When ``None``
                (the default), scores for every class are returned.

        Returns:
            torch.Tensor: Scores for the given labels, or for all classes
            when ``label`` is ``None``.

        Raises:
            ValueError: If ``logits`` has more than 2 dimensions.
        """
        n_dims = len(logits.shape)
        if n_dims > 2:
            raise ValueError("dimension of logits are at most 2.")
        if n_dims == 1:
            logits = logits.unsqueeze(0)

        transformed = self.transform(logits)
        if label is not None:
            return self._calculate_single_label(transformed, label)
        return self._calculate_all_label(transformed)

    def _uses_prior(self):
        """Return whether scores should be divided by ``self.prior``."""
        return (self.prior is not None) and (self.score_type == "softmax")

    def _calculate_single_label(self, probs, label):
        """
        Score only the entry selected by ``label`` in each row.

        Args:
            probs (torch.Tensor): Transformed model outputs, one row per sample.
            label (torch.Tensor): Column index to pick in each row.

        Returns:
            torch.Tensor: Negated selected entries, divided by
            ``self.prior[label]`` when the prior applies.
        """
        apply_prior = self._uses_prior()
        row_index = torch.arange(probs.shape[0], device=probs.device)
        negated = -probs[row_index, label]
        if apply_prior:
            return negated / self.prior[label]
        return negated

    def _calculate_all_label(self, probs):
        """
        Score every class of every row.

        Args:
            probs (torch.Tensor): Transformed model outputs.

        Returns:
            torch.Tensor: Negated ``probs``, divided by ``self.prior`` when
            the prior applies.
        """
        apply_prior = self._uses_prior()
        negated = -probs
        if apply_prior:
            return negated / self.prior
        return negated


class EnergyLAC(BaseScore):
    """
    Score that reweights class probabilities by an energy or entropy term.

    The model output is first passed through ``self.transform``. From that
    transformed output:

    * the energy is ``-logsumexp(values * temp_cal / temp_e)`` over the last
      dimension;
    * the probabilities are ``softmax(values)`` when ``score_type`` is the
      string ``"identity"``, otherwise the transformed values are used
      unchanged.

    With ``ent`` falsy the score is ``probs / energy``; with ``ent`` truthy
    it is ``probs * (-entropy)``, where the entropy is that of the
    categorical distribution built from ``probs``.

    NOTE(review): the energy may be zero or of either sign, so the quotient
    can be infinite or change sign -- confirm the inputs keep it away from 0.

    Args:
        score_type (str or callable): One of ``"softmax"``, ``"identity"``,
            ``"log_softmax"``, ``"log"``, or a callable applied to the logits.
            Default is ``"identity"``.
        temp_cal (float): Multiplier applied to the values inside the energy.
            Default is 1.0.
        temp_e (float): Divisor applied to the values inside the energy.
            Default is 1.0.
        ent (bool): Use the entropy weighting instead of the energy weighting.
            Default is False.

    Raises:
        ValueError: If ``score_type`` is neither callable nor one of the
            supported names.
    """

    def __init__(self, score_type="identity", temp_cal=1.0, temp_e=1.0, ent=False):
        super().__init__()

        self.score_type = score_type
        self.temp_cal = temp_cal
        self.temp_e = temp_e
        self.ent = ent

        if callable(score_type):
            # A user-supplied function is used as-is.
            self.transform = score_type
            return

        builtin_transforms = {
            "identity": lambda x: x,
            "softmax": lambda x: torch.softmax(x, dim=-1),
            "log_softmax": lambda x: torch.log_softmax(x, dim=-1),
            "log": torch.log,
        }
        # The isinstance guard keeps unhashable inputs on the ValueError path.
        transform = builtin_transforms.get(score_type) if isinstance(score_type, str) else None
        if transform is None:
            raise ValueError(
                f"Score type '{score_type}' is not implemented. Options are 'softmax', 'identity', 'log_softmax', 'log', or a callable function.")
        self.transform = transform

    def __call__(self, logits, label=None):
        """
        Compute scores from model outputs.

        A 1-D input is promoted to a single-row batch before the transform
        is applied.

        Args:
            logits (torch.Tensor): Model outputs with at most 2 dimensions.
            label (torch.Tensor, optional): Labels to score. When ``None``
                (the default), scores for every class are returned.

        Returns:
            torch.Tensor: Scores for the given labels, or for all classes
            when ``label`` is ``None``.

        Raises:
            ValueError: If ``logits`` has more than 2 dimensions.
        """
        n_dims = len(logits.shape)
        if n_dims > 2:
            raise ValueError("dimension of logits are at most 2.")
        if n_dims == 1:
            logits = logits.unsqueeze(0)

        values = self.transform(logits)
        if label is None:
            return self._calculate_all_label(values)
        return self._calculate_single_label(values, label)

    def _as_probs(self, values):
        """Softmax ``values`` for the ``"identity"`` score type, else pass through."""
        if self.score_type == "identity":
            return torch.softmax(values, dim=-1)
        return values

    def _energy(self, values, keepdim=False):
        """Return ``-logsumexp(values * temp_cal / temp_e)`` over the last dim."""
        return -torch.logsumexp(values * self.temp_cal / self.temp_e, dim=-1, keepdim=keepdim)

    def _negative_entropy(self, probs):
        """Return the negated entropy of the categorical distribution per row."""
        return -torch.distributions.Categorical(probs).entropy()

    def _calculate_single_label(self, probs, label):
        """
        Score only the entry selected by ``label`` in each row.

        Args:
            probs (torch.Tensor): Output of ``self.transform``, one row per
                sample. For ``score_type="identity"`` these are the raw
                inputs and are softmaxed here.
            label (torch.Tensor): Column index to pick in each row.

        Returns:
            torch.Tensor: The selected probability of each row, multiplied by
            the negated entropy (``ent`` truthy) or divided by the energy.
        """
        values = probs
        probs = self._as_probs(values)
        row_index = torch.arange(probs.shape[0], device=probs.device)
        selected = probs[row_index, label]

        if self.ent:
            return selected * self._negative_entropy(probs)
        # The energy comes from the pre-softmax values and is only needed here.
        return selected / self._energy(values)

    def _calculate_all_label(self, probs):
        """
        Score every class of every row.

        Args:
            probs (torch.Tensor): Output of ``self.transform``. For
                ``score_type="identity"`` these are the raw inputs and are
                softmaxed here.

        Returns:
            torch.Tensor: Probabilities multiplied by the per-row negated
            entropy (``ent`` truthy) or divided by the per-row energy; the
            per-row term is broadcast across the last dimension.
        """
        values = probs
        probs = self._as_probs(values)

        if self.ent:
            return probs * self._negative_entropy(probs).unsqueeze(-1)
        # keepdim=True keeps a trailing axis so the division broadcasts per row.
        return probs / self._energy(values, keepdim=True)
