import torch
from torch import Tensor


def softsort(scores: Tensor, tau: float = 1.0, hard: bool = False, pow: float = 1.0) -> Tensor:
    """Return the SoftSort relaxation of the descending-sort permutation matrix.

    Row ``i`` of the result is a softmax over the input elements, scored by
    ``-|s_j - sorted_i| ** pow / tau``, so multiplying the result by
    ``scores.unsqueeze(-1)`` approximates the scores sorted in descending
    order.

    Args:
        scores: Elements to be sorted, shape ``(..., n)``. Sorting is along the
            last dimension; leading dimensions are treated as batch dimensions
            (typical shape: ``(batch_size, n)``).
        tau: Softmax temperature; smaller values give a sharper matrix.
        hard: If True, the forward pass returns a one-hot matrix (row-wise
            argmax) while gradients flow through the soft matrix
            (straight-through estimator). Rows are picked independently, so
            with tied scores the hard matrix may repeat a column instead of
            being a true permutation.
        pow: Exponent applied to the absolute score differences.

    Returns:
        Tensor of shape ``(..., n, n)`` whose entry ``[..., i, j]`` is the
        weight of input element ``j`` at sorted position ``i``; each row sums
        to 1.
    """
    scores = scores.unsqueeze(-1)  # (..., n, 1)
    # Sort along the element axis (-2 after unsqueeze) rather than a fixed
    # dim=1, so 1-D and higher-rank inputs are sorted along the right axis.
    sorted_scores = scores.sort(descending=True, dim=-2).values  # (..., n, 1)

    # (..., 1, n) against (..., n, 1) broadcasts to (..., n, n):
    # rows = sorted positions, columns = original elements.
    diff = (scores.mT - sorted_scores).abs()
    same = scores.mT == sorted_scores
    # d/dx x**pow is infinite at x == 0 for pow < 1; masking only the output
    # would still give 0 * inf = NaN in backward. Feed a harmless 1 into the
    # power at those positions, then force the logit to 0 as before.
    diff = torch.where(same, torch.ones_like(diff), diff)
    logits = diff.pow(pow).neg() / tau
    logits = torch.where(same, torch.zeros_like(logits), logits)
    P_hat = logits.softmax(-1)

    if hard:
        P = torch.zeros_like(P_hat)
        P.scatter_(-1, P_hat.argmax(-1, keepdim=True), value=1)
        # Straight-through: forward value is P, gradient is that of P_hat.
        P_hat = (P - P_hat).detach() + P_hat
    return P_hat
