import torch.nn as nn
def get_criterion(args):
    """Return the loss criterion to use for ``args.dataset``.

    Args:
        args: Any object exposing a ``dataset`` attribute (for example an
            ``argparse.Namespace``) holding the dataset name.

    Returns:
        An ``nn.CrossEntropyLoss`` instance on which ``.cuda()`` has been
        called, for the datasets 'MNIST', 'CIFAR10', 'CIFAR100' and
        'ImageNet'.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    supported = ('MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet')
    # Fail loudly on an unknown dataset: without this guard the function
    # would reach the return with no criterion bound and die with an
    # UnboundLocalError that says nothing about the offending dataset name.
    if args.dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            f"expected one of {', '.join(supported)}"
        )
    return nn.CrossEntropyLoss().cuda()