import torch
import torch.nn.functional as F

@torch.enable_grad()
def entropy_minimization_loss(x, dim=1):
    """Return the mean Shannon entropy (in nats) of the softmax of ``x``.

    The softmax is taken over ``dim`` (default 1, e.g. the class axis of an
    ``(N, C)`` logits tensor). The per-sample entropies are then averaged
    over dimension 0 of the remaining tensor. For a 2-D input the result is
    a scalar; for inputs with more dimensions only the first remaining
    dimension is averaged, so the result is not reduced to a scalar.

    Decorated with ``torch.enable_grad()`` so an autograd graph is built even
    when called inside a ``torch.no_grad()`` context.

    Args:
        x: Tensor of unnormalized logits. Entries equal to ``-inf`` (e.g.
            masked-out classes) are allowed and contribute zero entropy.
        dim: Dimension holding the class logits.

    Returns:
        Tensor holding the averaged entropy.
    """
    probs = x.softmax(dim)
    log_probs = x.log_softmax(dim)
    # A probability of exactly 0 (e.g. from a -inf logit) pairs with a
    # log-probability of -inf, and 0 * -inf is NaN. Since p*log(p) -> 0 as
    # p -> 0, zero the log term there. Masking log_probs rather than the
    # product keeps the backward pass free of 0 * inf NaNs as well.
    log_probs = log_probs.masked_fill(probs == 0, 0.0)
    entropy = -(probs * log_probs).sum(dim)
    return entropy.mean(0)
