import torch.nn as nn





class TotalLoss(nn.Module):
    """Weighted sum of two binary cross-entropy-with-logits losses.

    The total loss is::

        lce_weight * BCEWithLogits(features, labels)
        + lce_srm_weight * BCEWithLogits(featuresOrginal, labels)

    With the default weights this equals the previous hard-coded
    ``0.1 * Lce + LceSRM``.

    Args:
        temperature: Stored on the module but not used by ``forward``;
            kept for interface compatibility.
        margin: Stored on the module but not used by ``forward``; kept for
            interface compatibility.
        lce_weight: Weight applied to the loss computed on ``features``.
        lce_srm_weight: Weight applied to the loss computed on
            ``featuresOrginal``.
    """

    def __init__(self, temperature=0.5, margin=1.0, lce_weight=0.1, lce_srm_weight=1.0):
        super().__init__()
        self.margin = margin
        self.temperature = temperature
        self.lce_weight = lce_weight
        self.lce_srm_weight = lce_srm_weight
        # Default reduction is 'mean', so each term is a scalar tensor.
        self.criterion = nn.BCEWithLogitsLoss()

    @staticmethod
    def _drop_dim1(logits):
        """Return ``logits`` with dim 1 squeezed, if the tensor has a dim 1.

        ``squeeze(1)`` is a no-op when that dim's size is not 1. It raises on
        1-D tensors, so 1-D logits are returned unchanged.
        """
        return logits.squeeze(1) if logits.dim() > 1 else logits

    def forward(self, features, featuresOrginal, labels, image_path, device):
        """Compute the weighted BCE-with-logits loss for both logit tensors.

        Args:
            features: Raw logits. A size-1 dim 1 is squeezed away
                (presumably shape (B, 1) -> (B,); TODO confirm with caller).
            featuresOrginal: Second set of raw logits, handled the same way.
                The ``SRM`` naming in the original code suggests a separate
                branch; that is unconfirmed here.
            labels: Binary targets. They must match the squeezed logits'
                shape. They are cast to each logits tensor's dtype, so
                integer or bool labels are accepted.
            image_path: Unused; kept for interface compatibility.
            device: Unused; kept for interface compatibility.

        Returns:
            Scalar tensor with the weighted sum of the two losses.
        """
        features = self._drop_dim1(features)
        featuresOrginal = self._drop_dim1(featuresOrginal)

        # BCEWithLogitsLoss needs a floating-point target. Cast per input so
        # each target matches the dtype of the logits it is compared with.
        lce = self.criterion(features, labels.to(features.dtype))
        lce_srm = self.criterion(featuresOrginal, labels.to(featuresOrginal.dtype))
        return self.lce_weight * lce + self.lce_srm_weight * lce_srm
