import torch.nn as nn
from losses import *

# Number of target classes for each supported dataset name (args.dataset).
_DATASET_NUM_CLASSES = {
    'cifar10': 10,
    'cifar100': 100,
    'webvision': 50,
    'clothing1m': 14,
}

# Loss names whose constructors require ``num_classes``.
_LOSSES_NEEDING_NUM_CLASSES = ('VEL', 'VMSE', 'NCEandVCE', 'NCEandVEL', 'NCEandVMSE')


def get_loss_config(args):
    """Build the loss object selected by ``args.loss``.

    Args:
        args: Namespace-like object (presumably parsed CLI arguments) that
            provides ``loss`` and ``para``/``beta``; ``dataset`` is read for
            every loss except ``'VCE'``, and ``alpha`` is read for the
            ``'NCEand*'`` losses.

    Returns:
        An instance of the loss class (imported from ``losses``) that
        corresponds to ``args.loss``.

    Raises:
        ValueError: if ``args.loss`` is not a supported loss name, or if the
            selected loss needs ``num_classes`` and ``args.dataset`` is not a
            known dataset.
    """
    loss = args.loss

    # VCE is the only loss whose constructor does not take num_classes, so it
    # is built before the dataset is inspected (any dataset is accepted).
    if loss == 'VCE':
        return VCELoss(a=args.para, scale=args.beta)

    if loss not in _LOSSES_NEEDING_NUM_CLASSES:
        raise ValueError(
            f"Unknown loss {loss!r}; expected one of "
            f"{['VCE', *_LOSSES_NEEDING_NUM_CLASSES]}"
        )

    num_classes = _DATASET_NUM_CLASSES.get(args.dataset)
    if num_classes is None:
        raise ValueError(
            f"Unknown dataset {args.dataset!r}: cannot determine num_classes "
            f"for loss {loss!r}; expected one of {sorted(_DATASET_NUM_CLASSES)}"
        )

    if loss == 'VEL':
        return VELoss(a=args.para, scale=args.beta, num_classes=num_classes)
    if loss == 'VMSE':
        return VMSELoss(a=args.para, scale=args.beta, num_classes=num_classes)

    # The NCEand* combinations weight their two terms with alpha and beta.
    if loss == 'NCEandVCE':
        return NCEandVCE(alpha=args.alpha, beta=args.beta, a=args.para, num_classes=num_classes)
    if loss == 'NCEandVEL':
        return NCEandVEL(alpha=args.alpha, beta=args.beta, a=args.para, num_classes=num_classes)
    # Only 'NCEandVMSE' remains after the membership check above.
    return NCEandVMSE(alpha=args.alpha, beta=args.beta, a=args.para, num_classes=num_classes)


 
