from .comlib import *

def accuracy_unreduced(out, target):
    """Per-sample top-1 correctness.

    The predicted class is the arg-max of ``out`` along dimension 1; it is
    compared element-wise with ``target``.

    Returns:
        A float tensor holding 1.0 where the prediction equals ``target`` and
        0.0 elsewhere.
    """
    best = out.max(dim=1)
    hits = best.indices.eq(target)
    return hits.float()

def accuracy_(reduction, out, target):
    """Reduce per-sample top-1 correctness with the requested reduction.

    Args:
        reduction: ``"mean"`` (fraction of correct samples), ``"sum"``
            (number of correct samples) or ``"none"`` (the unreduced 0/1
            float tensor from ``accuracy_unreduced``).
        out: forwarded to ``accuracy_unreduced``.
        target: forwarded to ``accuracy_unreduced``.

    Returns:
        A tensor: a scalar for ``"mean"``/``"sum"``, per-sample values for
        ``"none"``.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """
    # Validate before doing any work. Previously an unknown reduction fell
    # through and silently returned None.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(
            f"unknown reduction {reduction!r}; expected 'mean', 'sum' or 'none'"
        )
    unreduced = accuracy_unreduced(out, target)
    if reduction == "mean":
        return torch.mean(unreduced)
    if reduction == "sum":
        return torch.sum(unreduced)
    return unreduced

def accuracy(reduction):
    """Build an accuracy metric ``fn(out, target)`` for ``reduction``.

    ``"none"`` returns ``accuracy_unreduced`` itself (per-sample 0/1 floats).
    Any other value returns a closure that forwards to ``accuracy_`` with
    ``reduction`` bound as its first argument.
    """
    if reduction != "none":
        def metric(out, target):
            return accuracy_(reduction, out, target)
        return metric
    return accuracy_unreduced

# '''
# cross entropy: H(p,q) = -\sum_i p_i \log{q_i}
# '''
def crossEntropyLoss():
    """Return a mean-reduced ``nn.CrossEntropyLoss`` criterion.

    Per sample, the input is a vector x in R^C (batch shape ``(N, C)``).
    The label is either a class index in ``[0, C-1]`` (batch shape ``(N,)``)
    or a per-class vector (batch shape ``(N, C)``).

    https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
    """
    criterion = nn.CrossEntropyLoss(reduction="mean")
    return criterion

def crossEntropyLossUnreduced():
    """Return an ``nn.CrossEntropyLoss`` that keeps one loss value per sample.

    This uses ``reduction="none"``, so calling the criterion yields a tensor
    of per-sample losses instead of a single scalar.
    """
    per_sample = "none"
    return nn.CrossEntropyLoss(reduction=per_sample)

def twoClassCrossEntropyLoss():
    """Return a mean-reduced binary cross-entropy criterion (``nn.BCELoss``).

    ``nn.BCELoss`` expects its input to already be probabilities in [0, 1]
    (e.g. sigmoid outputs), not raw logits.

    https://docs.pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss
    """
    return nn.BCELoss(reduction="mean")

'''
Kullback-Leibler (KL) divergence loss
'''
def klDiv(reduction="mean", log_target=False):
    """Return an ``nn.KLDivLoss`` criterion.

    ``nn.KLDivLoss`` expects its *input* as log-probabilities. Its *target*
    is probabilities, or log-probabilities when ``log_target`` is True.

    Args:
        reduction: forwarded to ``nn.KLDivLoss``. The default ``"mean"``
            preserves this helper's historical behavior. However, ``"mean"``
            averages over every element (batch *and* classes), so it is not
            the true KL divergence, and PyTorch warns when it is used.
            ``"batchmean"`` divides by the batch size only and matches the
            mathematical definition.
        log_target: forwarded to ``nn.KLDivLoss``.
    """
    return nn.KLDivLoss(reduction=reduction, log_target=log_target)

def out_norm(out, target):
    """Return the mean Euclidean (L2) norm of ``out`` taken along dimension 1.

    The ``(out, target)`` signature matches the other metrics in this file,
    but ``target`` is unused.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm with ord=2 is the
    # documented replacement and gives the same per-row 2-norm.
    return torch.mean(torch.linalg.vector_norm(out, ord=2, dim=1))