from . losses import ClipLoss, ContrasLoss, SigLipLoss, SigLipLossTest

def create_model():
    """Stub model factory: builds nothing and returns ``None``."""
    return None

def create_loss(args):
    """Return a new ``ContrasLoss`` built with no constructor arguments.

    ``args`` is accepted but currently not read.
    """
    loss_fn = ContrasLoss()
    return loss_fn

def create_loss_gather(args):
    """Construct a ``ClipLoss`` from the distributed settings in ``args``.

    Reads ``args.train.dist.local_rank`` and ``args.world_size`` and forwards
    them as the loss's ``rank`` and ``world_size`` keyword arguments. No
    ``use_horovod`` argument is passed.
    """
    # NOTE(review): the *local* rank is passed as ``rank``. If ClipLoss uses
    # ``rank`` to offset labels across features gathered from every process,
    # multi-node runs presumably need the global rank instead — confirm.
    local_rank = args.train.dist.local_rank
    total_procs = args.world_size
    return ClipLoss(rank=local_rank, world_size=total_procs)

def create_multi_loss(args, type):
    """Build the loss named by ``type`` using settings read from ``args``.

    Args:
        args: Config object. Both losses read ``args.train.dist.local_rank``
            and ``args.world_size``; ``'SigLipLoss'`` additionally reads
            ``args.loss.positive_weight``.
        type: Loss name, ``'ClipLoss'`` or ``'SigLipLoss'``. (The parameter
            shadows the builtin ``type``; the name is kept so keyword callers
            keep working.)

    Returns:
        A ``ClipLoss`` or ``SigLipLoss`` instance.

    Raises:
        SystemExit: with code 1 when ``type`` is not recognised, after an
            error naming the bad value is printed to stderr.
    """
    import sys

    # NOTE(review): ``local_rank`` is passed as ``rank`` for both losses. If
    # they use ``rank`` to index features gathered across all processes,
    # multi-node runs presumably need the global rank — confirm.
    if type == 'ClipLoss':
        return ClipLoss(
            rank=args.train.dist.local_rank,
            world_size=args.world_size,
        )
    if type == 'SigLipLoss':
        return SigLipLoss(
            rank=args.train.dist.local_rank,
            world_size=args.world_size,
            weight=args.loss.positive_weight,
        )

    # Use sys.exit rather than the site-provided ``exit`` helper, which is
    # meant for the interactive shell and may be missing (e.g. ``python -S``).
    print(
        f"unknown loss function: {type!r} (expected 'ClipLoss' or 'SigLipLoss').",
        file=sys.stderr,
    )
    sys.exit(1)