from torch import nn
from losses.cross_entry_smooth import CrossEntropyWithLabelSmooth
from losses.conface import CosFaceLoss
from losses.triplet_loss import TripletLoss, ContrastiveLoss4
from losses.cicl import ContrastiveLoss
def build_losses(config, num_train_pids):
    """Instantiate every loss criterion used for training.

    Args:
        config: Configuration object. Reads ``config.LOSS.CLA_LOSS`` to pick
            the classification loss, plus ``config.LOSS.CLA_S``,
            ``config.LOSS.MOMENTUM``, ``config.LOSS.EPSILON`` and
            ``config.MODEL.FEATURE_DIM``. ``config.LOSS.CLA_M`` is read only
            when the classification loss is ``'cosface'``.
        num_train_pids: Forwarded to ``ContrastiveLoss4`` as ``num_pids``.

    Returns:
        A 5-tuple ``(criterion_cla, criterion_mm, criterion_clothes,
        criterion_cicl, criterion_pair)``. Note that this order differs from
        the order in which the criteria are constructed below.

    Raises:
        KeyError: If ``config.LOSS.CLA_LOSS`` is not one of
            ``'crossentropy'``, ``'crossentropylabelsmooth'`` or
            ``'cosface'``.
    """
    loss_cfg = config.LOSS
    cla_name = loss_cfg.CLA_LOSS

    # Build identity classification loss; reject unknown names up front,
    # before any criterion is constructed.
    if cla_name not in ('crossentropy', 'crossentropylabelsmooth', 'cosface'):
        raise KeyError("Invalid classification loss: '{}'".format(cla_name))

    if cla_name == 'crossentropy':
        criterion_cla = nn.CrossEntropyLoss()
    elif cla_name == 'crossentropylabelsmooth':
        criterion_cla = CrossEntropyWithLabelSmooth()
    else:
        criterion_cla = CosFaceLoss(scale=loss_cfg.CLA_S, margin=loss_cfg.CLA_M)

    # The triplet and both contrastive criteria use the same literal margin.
    shared_margin = 0.3

    criterion_pair = TripletLoss(margin=shared_margin)
    criterion_mm = ContrastiveLoss4(
        num_pids=num_train_pids,
        feat_dim=config.MODEL.FEATURE_DIM,
        margin=shared_margin,
        momentum=loss_cfg.MOMENTUM,
        scale=loss_cfg.CLA_S,
        epsilon=loss_cfg.EPSILON,
    )
    criterion_cicl = ContrastiveLoss(margin=shared_margin)
    # Same scale as the classification config, but with a zero margin.
    criterion_clothes = CosFaceLoss(scale=loss_cfg.CLA_S, margin=0)

    criteria = (
        criterion_cla,
        criterion_mm,
        criterion_clothes,
        criterion_cicl,
        criterion_pair,
    )
    return criteria
