"""Functions of logits to take gradients with respect to."""
import dataclasses
from typing import Callable, Optional

import torch

###############################################################################


@dataclasses.dataclass
class LogitFunctionInput:
    """Example-specific input to a function of logits to compute gradients of.

    Attributes:
        log_probs: Per-class log-probabilities for a single example, as the
            field name indicates (normalization is not checked here).
        label: Target class index for the example, or ``None`` when no label
            is available. Defaults to ``None`` so label-free callers do not
            have to pass it explicitly.
    """

    # shape = [n_classes]
    log_probs: torch.Tensor
    # shape = [], dtype=torch.int32
    # Defaulting to None is backward-compatible: positional and keyword
    # construction with an explicit label behave exactly as before.
    label: Optional[torch.Tensor] = None


# Signature shared by the per-example functions in this module: each takes one
# LogitFunctionInput and returns a tensor to take gradients of.
LogitFunctionType = Callable[
    [LogitFunctionInput],
    torch.Tensor,
]


###############################################################################


def cross_entropy_loss_logits_fn(logit_fn_input: LogitFunctionInput) -> torch.Tensor:
    """Return the negated log-probability assigned to the example's label.

    When ``log_probs`` holds normalized log-probabilities, this is the
    per-example cross-entropy loss against the true label.

    Args:
        logit_fn_input: Per-example input; its ``label`` must not be ``None``.

    Returns:
        ``-log_probs[label]``.

    Raises:
        ValueError: If ``logit_fn_input.label`` is ``None``.
    """
    label = logit_fn_input.label
    if label is None:
        # Indexing a tensor with None inserts a new axis, so without this
        # check the function would silently return a [1, n_classes] tensor
        # instead of a scalar loss.
        raise ValueError("cross_entropy_loss_logits_fn requires a label, got None")
    return -logit_fn_input.log_probs[label]


def cross_entropy_top_prediction_logits_fn(logit_fn_input: LogitFunctionInput) -> torch.Tensor:
    """Return the negated log-probability of the highest-scoring class.

    The example's ``label`` is not read; the target is the argmax over
    ``log_probs`` itself. Because the result is produced by indexing,
    gradients reach only the selected entry of ``log_probs``.
    """
    scores = logit_fn_input.log_probs
    # Pick one class index; torch.argmax returns the first maximum on ties.
    top_class = torch.argmax(scores)
    return scores[top_class].neg()
