import torch.nn as nn

from .amsoftmax import AMSoftmax as AMS
from .arcface import Arcface as ARC
from .dual_focal_loss import Dual_Focal_loss as DF
from .label_smooth import LabelSmoothSoftmaxCE as LSS
from .sphereface import AngleLoss as AGL
# from .triplet_loss import TripletLoss as TL
from .dsam_loss import DSAMLoss as DSAM
from .softmax_nn import softmax as SMN
from .modified_softmax import MSoftmax as MSM
from .amm_loss import AMMLoss as AMM
from .jsd_loss import KLLoss as KL
from .jsd_loss import JSDivLoss as JS
from .jsd_loss import CELoss as CE
from .msa_loss import MSmoothArcLoss as MSA
from torch.nn import HuberLoss as HB
from torch.nn import MSELoss as MSE

# Registry of available loss functions, keyed by the name used to request
# them. Each value is the callable that get_loss_func() invokes with the
# caller's keyword arguments. Key casing is not uniform (e.g. 'ammloss' vs.
# 'KLloss'), so names must be supplied exactly as written here.
__all_loss = {
    'amsoftmax': AMS,
    'arcface': ARC,
    'dualfocal': DF,
    'labelsmooth': LSS,
    'angleloss': AGL,
    'softmax_nn': SMN,
    # 'triplet': TL,  # disabled: the TripletLoss import is commented out above
    'dsam': DSAM,
    'msoftmax': MSM,
    'ammloss': AMM,
    'KLloss': KL,
    'JSloss': JS,
    'msaloss': MSA,
    'CEloss': CE,
    'MSEloss': MSE,
    'Huberloss': HB,
}


def get_loss_func(loss_name: str, **kwargs):
    """Build the loss function registered under ``loss_name``.

    Args:
        loss_name: Key into the module-level ``__all_loss`` registry. It must
            match a registered name exactly, including case.
        **kwargs: Forwarded unchanged to the registered constructor.

    Returns:
        Whatever the registered callable returns when invoked with ``**kwargs``.

    Raises:
        KeyError: If ``loss_name`` is not a registered loss name. The message
            lists the valid names.
    """
    # Keep only the lookup inside ``try``. A KeyError raised by the loss
    # constructor itself must not be misreported as "unknown loss".
    try:
        loss_cls = __all_loss[loss_name]
    except KeyError:
        available = ", ".join(sorted(__all_loss))
        # Use a single formatted argument. Passing several positional args to
        # KeyError makes it render them as a tuple rather than a sentence.
        raise KeyError(
            f"Unknown Loss Function: {loss_name!r} (available: {available})"
        ) from None
    return loss_cls(**kwargs)
