import torch
import sys

sys.path.append(".")
from src.tools.sharpness_tools.utils import get_device
from torch.cuda.amp import autocast

def shannon_entropy(model, data_loader):
    """Return the mean Shannon entropy (in nats) of the model's predictions.

    Runs ``model`` over every batch of ``data_loader`` without gradients,
    turns each output into a distribution with a softmax over dim 1, and
    averages ``-sum(p * log p)`` over all samples seen (the sample count is
    taken from ``len(inputs)`` of each batch).

    Args:
        model: callable whose device is resolved via ``get_device``.
            Presumably returns logits with the class axis at dim 1 — TODO
            confirm for non-classification heads.
        data_loader: iterable yielding 3-tuples whose first element is the
            input batch; the remaining two elements are ignored.

    Returns:
        The average entropy per sample as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no samples at all.
    """
    device = get_device(model)
    entropy_sum = 0.0
    total_num = 0
    for inputs, _, _ in data_loader:
        inputs = inputs.to(device)
        total_num += len(inputs)

        with torch.no_grad():
            # Mixed precision only for the forward pass. The CUDA autocast is a
            # no-op for non-CUDA tensors (and warns if CUDA is unavailable), so
            # only enable it when the batch actually lives on a CUDA device.
            with autocast(enabled=inputs.is_cuda):
                outputs = model(inputs)

            # Entropy math in float32, outside autocast. Compute log-probs once
            # and derive the probabilities from them.
            log_probs = torch.log_softmax(outputs.float(), dim=1)
            probs = log_probs.exp()
            # p * log p -> 0 as p -> 0; without this mask a -inf logit would give
            # 0 * -inf = NaN. NaNs coming from the model itself are not masked.
            plogp = torch.where(probs == 0, torch.zeros_like(probs), probs * log_probs)
            entropy_sum += plogp.sum()

    if total_num == 0:
        raise ValueError("shannon_entropy: data_loader yielded no samples")
    return float(-entropy_sum / total_num)
