import torch
import torch.nn as nn
import math

# Scratch: mean over dim 0 of a (2, 10) random tensor.
# a = torch.rand([2, 10])
# b = torch.rand([10])
# print(a)
# print(a.mean(dim=0))

# Reference value: entropy (in nats) computed by torch's Categorical distribution.
# torch.tensor is the documented factory for building a tensor from data;
# torch.Tensor(...) is the legacy type constructor.
prob = torch.tensor([[0.05, 0.1, 0.85]])
# Categorical.entropy() returns one entropy per row of `prob` (batch shape (1,)
# here), so .mean() reduces it to a single scalar averaged over the batch.
dist_p = torch.distributions.Categorical(prob).entropy()
print(dist_p.mean())

def entropy(input_tensor, dim=None):
    """Return the Shannon entropy (in nats) of a tensor of probabilities.

    Computes ``-sum(p * ln p)`` with the standard convention ``0 * ln 0 = 0``,
    so zero-probability entries contribute nothing instead of turning the
    result into NaN.

    Args:
        input_tensor: Tensor of probabilities. It is NOT normalized here;
            presumably the entries along the summed axis already sum to 1.
        dim: Axis (or axes) to reduce over. ``None`` (the default) sums over
            every element and returns a 0-dim tensor, as before. Pass
            ``dim=-1`` to get one entropy per row; for already-normalized
            rows this is comparable to
            ``torch.distributions.Categorical(...).entropy()``.

    Returns:
        Tensor holding the entropy; 0-dim when ``dim`` is None.
    """
    # torch.xlogy(x, x) is defined as 0 where x == 0, whereas
    # x * torch.log(x) evaluates 0 * -inf = NaN there.
    p_log_p = torch.xlogy(input_tensor, input_tensor)
    if dim is None:
        return -p_log_p.sum()
    return -p_log_p.sum(dim=dim)

# Hand-rolled entropy of the same `prob`; for this single-row input it should
# agree (up to float rounding) with the Categorical-based value printed earlier.
print(entropy(input_tensor=prob))