import torch

from tabicl.config.config_run import ConfigRun
from tabicl.core.enums import Task
from tabicl.core.losses import CrossEntropyLossExtraBatch


def get_loss(cfg: ConfigRun):

    match cfg.task:
        case Task.REGRESSION:
            return torch.nn.MSELoss()
        case Task.CLASSIFICATION:
            return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
