#coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

def Entropy(input_):
    """Return the mean entropy of the values in ``input_``.

    Evaluates ``-p * log(p + 1e-5)`` element-wise, sums the result over
    dimension 1, and averages the summed values into a single scalar.

    Args:
        input_: Tensor with at least two dimensions. Presumably each slice
            along dim 1 is a probability distribution, e.g. softmax output
            shaped (batch, classes) -- TODO confirm against callers. No
            normalisation is applied here.

    Returns:
        A 0-dim tensor: the mean over the per-slice entropy sums.
    """
    # Small offset so log() stays finite where an entry is exactly 0.
    epsilon = 1e-5
    pointwise = -input_ * torch.log(input_ + epsilon)
    return pointwise.sum(dim=1).mean()

def Entropylogits(input, redu='mean'):
    """Return the entropy of the softmax distribution derived from logits.

    Applies ``softmax`` over dimension 1, evaluates ``-p * log(p + 1e-5)``
    element-wise and sums it over dimension 1. The summed values are then
    reduced according to ``redu``.

    Args:
        input: Tensor of unnormalised scores with at least two dimensions;
            softmax and the entropy sum are both taken over dim 1.
            Presumably shaped (batch, classes) -- TODO confirm against
            callers.
        redu: Reduction mode. ``'mean'`` averages the summed entropies into
            a scalar. ``'None'`` (also accepted: ``'none'`` or ``None``)
            returns the summed entropies without averaging.

    Returns:
        A 0-dim tensor for ``'mean'``; otherwise a tensor with dim 1 of
        ``input`` summed out.

    Raises:
        ValueError: If ``redu`` is not a recognised reduction mode.
            Previously an unknown mode silently returned the un-summed
            element-wise tensor, which has a different shape from either
            documented mode.
    """
    probs = F.softmax(input, dim=1)
    # Small offset so log() stays finite if a probability underflows to 0.
    epsilon = 1e-5
    summed = torch.sum(-probs * torch.log(probs + epsilon), dim=1)

    if redu is None or redu in ('None', 'none'):
        return summed
    if redu == 'mean':
        return torch.mean(summed)
    raise ValueError(
        f"Unsupported redu {redu!r}; expected 'mean' or 'None'."
    )
