from typing import Any, Dict

from .base_loss import CELoss, Loss
from .code_reconstruct_loss import CodebookReconstructLoss
from .dark_kg_loss import DKG_Loss


# Registry mapping the config ``name`` string to the loss class to construct.
# NOTE(review): CodebookReconstructLoss is imported in this module but has no
# entry here — confirm whether that is intentional (it may only be re-exported).
__REGISTERED_LOSS__ = dict(
    ce_loss=CELoss,
    dkg_loss=DKG_Loss,
)


def get_loss_fn(loss_cfg: Dict[str, Any], **kwargs) -> Loss:
    """Instantiate a registered loss from a config mapping.

    Args:
        loss_cfg: Mapping with a required ``"name"`` key that selects an entry
            of ``__REGISTERED_LOSS__``, and an optional ``"loss_cfg"`` mapping
            of keyword arguments for the loss constructor. A missing or null
            ``"loss_cfg"`` means "no config arguments".
        **kwargs: Extra keyword arguments forwarded to the loss constructor
            alongside the config arguments.

    Returns:
        An instance of the registered loss class.

    Raises:
        KeyError: If ``"name"`` is missing, or does not match a registered
            loss (the message lists the registered names).
        TypeError: If the same keyword appears both in the config arguments
            and in ``kwargs``, or the constructor rejects an argument.
    """
    name = loss_cfg["name"]
    try:
        loss_cls = __REGISTERED_LOSS__[name]
    except KeyError:
        # Keep KeyError for existing callers, but make the message actionable.
        available = ", ".join(sorted(__REGISTERED_LOSS__))
        raise KeyError(
            f"Unknown loss {name!r}; registered losses: {available}"
        ) from None
    # The key may be present with a null value (e.g. an empty config entry);
    # `**None` would raise TypeError, so treat it as no arguments.
    cfg = loss_cfg.get("loss_cfg") or {}
    return loss_cls(**cfg, **kwargs)

