from federatedscope.register import register_criterion


def call_my_criterion(type, device):
    """Build the loss object associated with the name ``'mycriterion'``.

    Args:
        type: Name of the requested criterion. Only the exact string
            ``'mycriterion'`` is handled by this builder. (The parameter
            shadows the builtin ``type``; the name is kept because it is
            part of the signature callers see.)
        device: Target passed to ``.to(...)`` on the created loss module.

    Returns:
        ``torch.nn.CrossEntropyLoss().to(device)`` when ``type`` is
        ``'mycriterion'`` and ``torch.nn`` can be imported; ``None`` in
        every other case (different name, or torch not importable).
    """
    # The import is attempted up front regardless of the requested name;
    # a missing torch is tolerated and simply yields no criterion.
    try:
        import torch.nn as torch_nn
    except ImportError:
        torch_nn = None

    # Not the name this builder is responsible for.
    if type != 'mycriterion':
        return None

    # Requested name matches, but the backend is unavailable.
    if torch_nn is None:
        return None

    loss_fn = torch_nn.CrossEntropyLoss()
    return loss_fn.to(device)


# Module-level side effect: hand the builder to FederatedScope's registry
# under the name 'mycriterion' as soon as this module is imported.
# NOTE(review): presumably the framework calls registered builders with
# (type, device) when resolving a criterion by name -- confirm against
# federatedscope.register.register_criterion.
register_criterion('mycriterion', call_my_criterion)
