






import torch

from torchcp.classification.score.aps import APS


class Margin(APS):
    """Margin non-conformity score.

    For a sample with transformed scores ``probs`` and a candidate label ``y``
    the score is ``max_{j != y} probs[j] - probs[y]``: the strongest competing
    label's score minus the candidate's score. Larger values mean the
    candidate is less favoured relative to its best competitor.
    """

    def __init__(self, score_type="softmax"):
        # score_type is forwarded unchanged to APS; presumably it selects the
        # transform applied to logits before scoring — TODO confirm in APS.
        super().__init__(score_type)

    def _calculate_single_label(self, probs, label):
        """Return the margin score of the given label for each sample.

        Args:
            probs: 2-D tensor of transformed scores, one row per sample.
            label: integer tensor holding one class index per row of ``probs``.

        Returns:
            1-D tensor equal to ``max_{j != label} probs[:, j] - probs[:, label]``.
            If ``probs`` has a single column there is no competing label and
            the result is ``-inf``.

        ``probs`` is not modified.
        """
        row_indices = torch.arange(probs.size(0), device=probs.device)
        # Advanced indexing returns a new tensor, so no explicit clone is needed.
        target_prob = probs[row_indices, label]

        # Exclude the target label via an out-of-place mask: ``probs`` may alias
        # the caller's tensor (and may be needed for autograd), so it must not
        # be written to. -inf is used as the sentinel because scores are not
        # guaranteed to be non-negative (e.g. log-probabilities or raw logits),
        # so a finite sentinel such as -1 could win the max below.
        is_target = torch.zeros_like(probs, dtype=torch.bool)
        is_target[row_indices, label] = True
        competitors = probs.masked_fill(is_target, float("-inf"))

        largest_competitor = competitors.max(dim=-1).values
        return largest_competitor - target_prob

    def _calculate_all_label(self, probs):
        """Return the margin score of every label for each sample.

        Args:
            probs: 2-D tensor ``(batch, num_labels)`` of transformed scores.
                ``num_labels`` must be at least 2, since ``torch.topk`` is
                called with ``k=2``.

        Returns:
            Tensor of the same shape as ``probs`` whose entry ``[i, j]`` is
            ``max_{k != j} probs[i, k] - probs[i, j]``.
        """
        # The best competitor of any label is the row maximum, except for the
        # argmax label itself, whose best competitor is the runner-up. So the
        # top two values per row are all that is needed.
        top_values, top_indices = torch.topk(probs, k=2, dim=1)
        best = top_values[:, 0:1]       # (batch, 1)
        runner_up = top_values[:, 1:2]  # (batch, 1)

        # Allocate directly on the target device instead of CPU + copy.
        label_ids = torch.arange(probs.size(1), device=probs.device)
        is_argmax = label_ids.unsqueeze(0) == top_indices[:, 0:1]  # (batch, num_labels)

        # (batch, 1) operands broadcast against the (batch, num_labels) mask.
        best_competitor = torch.where(is_argmax, runner_up, best)
        return best_competitor - probs
