import torch



def Getloss(stragety):
    """Return a freshly constructed loss module selected by name.

    Args:
        stragety: Name of the loss. One of ``"MSE"``, ``"CrossEntropy"``,
            ``"NLL"``, ``"CrossEntropy_new"``, ``"CosineSimilarity"`` or
            ``"SmoothL1"``. (The parameter's spelling is kept so existing
            keyword callers keep working.)

    Returns:
        A new ``torch.nn.Module`` computing the requested loss, or ``None``
        when ``stragety`` is not one of the names above.
    """
    # Zero-argument factories: only the requested loss gets instantiated, and
    # the module-level custom loss classes are looked up lazily at call time.
    # ``reduction="mean"`` replaces the deprecated ``reduce=True,
    # size_average=True`` pair; PyTorch maps that pair to the same reduction
    # but emits a deprecation warning on every construction.
    factories = {
        "MSE": lambda: torch.nn.MSELoss(reduction="mean"),
        "CrossEntropy": lambda: torch.nn.CrossEntropyLoss(),
        "NLL": lambda: torch.nn.NLLLoss(reduction="mean"),
        "CrossEntropy_new": lambda: CrossEntropyLoss_new(),
        "CosineSimilarity": lambda: CosineSimilarityLoss(),
        "SmoothL1": lambda: torch.nn.SmoothL1Loss(),
    }
    factory = factories.get(stragety)
    # Unknown names fall through to None, as in the original implementation.
    return factory() if factory is not None else None


class CrossEntropyLoss_new(torch.nn.Module):
    """Cross-entropy ``-sum(y * log(x))`` over dim 1, then reduced.

    Unlike ``torch.nn.CrossEntropyLoss``, no softmax/log-softmax is applied:
    ``x`` is passed straight to ``log``, so it is expected to already hold
    probabilities (e.g. a softmax output). ``y`` is multiplied element-wise
    with ``log(x)``, so it presumably is a (soft or one-hot) target of a
    shape broadcastable to ``x`` — confirm against callers.

    Args:
        type: Reduction applied to the per-sample losses: ``'mean'``
            (default), ``'sum'``, or ``'none'`` (return them unreduced).
    """

    def __init__(self, type='mean'):
        super().__init__()
        self.type = type

    def forward(self, x, y):
        """Compute the reduced cross-entropy of ``x`` against ``y``.

        Raises:
            ValueError: if ``self.type`` is not a supported reduction.
        """
        # xlogy(y, x) == y * log(x), except it returns 0 where y == 0 even if
        # x == 0; the naive log(x) * y gives -inf * 0 = nan there, which would
        # poison the whole batch loss.
        per_sample = -torch.xlogy(y, x).sum(dim=1)
        if self.type == 'mean':
            return per_sample.mean()
        if self.type == 'sum':
            return per_sample.sum()
        if self.type == 'none':
            return per_sample
        raise ValueError(f"unsupported reduction type: {self.type!r}")

class CosineSimilarityLoss(torch.nn.Module):
    """Negated cosine similarity, so minimising the loss maximises similarity.

    Args:
        dim: Dimension along which ``torch.nn.CosineSimilarity`` compares
            ``x`` and ``y``.
        type: Reduction applied to the per-sample values: ``'mean'``
            (default), ``'sum'``, or ``'none'`` (return them unreduced).
        eps: Passed through to ``torch.nn.CosineSimilarity`` to avoid
            division by zero (default ``1e-8``, the previous fixed value).
    """

    def __init__(self, dim=1, type='mean', eps=1e-8):
        super().__init__()
        self.cos = torch.nn.CosineSimilarity(dim=dim, eps=eps)
        self.type = type

    def forward(self, x, y):
        """Return the reduced negative cosine similarity of ``x`` and ``y``.

        Raises:
            ValueError: if ``self.type`` is not a supported reduction.
        """
        similarity = self.cos(x, y)
        if self.type == 'mean':
            return -similarity.mean()
        if self.type == 'sum':
            return -similarity.sum()
        if self.type == 'none':
            return -similarity
        raise ValueError(f"unsupported reduction type: {self.type!r}")
    
        


