import torch

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.config.config_run import ConfigRun
from tabicl.core.enums import LossName, Task
from tabicl.core.losses import CrossEntropyLossExtraBatch


def get_loss(cfg: ConfigRun):
    """Return the loss module for a run, chosen from its task and hyperparameters.

    For ``Task.REGRESSION`` the ``cfg.hyperparams['regression_loss']`` value is
    compared against the ``.name`` string of ``LossName.MSE`` / ``LossName.MAE``.
    For ``Task.CLASSIFICATION`` a ``CrossEntropyLossExtraBatch`` is built from
    ``cfg.hyperparams['label_smoothing']``.

    Args:
        cfg: Run configuration exposing ``task`` and a ``hyperparams`` mapping.

    Returns:
        ``torch.nn.MSELoss()`` or ``torch.nn.L1Loss()`` for regression, or
        ``CrossEntropyLossExtraBatch(...)`` for classification.

    Raises:
        KeyError: if ``'regression_loss'`` is missing from ``cfg.hyperparams``
            (it is read for every task, not only regression).
        ValueError: if the task / regression-loss combination is unsupported.
    """
    task = cfg.task
    # Read up front for every task, so a missing key fails fast even for classification.
    loss_key = cfg.hyperparams['regression_loss']

    if task == Task.REGRESSION:
        if loss_key == LossName.MSE.name:
            return torch.nn.MSELoss()
        if loss_key == LossName.MAE.name:
            return torch.nn.L1Loss()
    elif task == Task.CLASSIFICATION:
        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])

    # Unknown task, or a regression task with an unrecognised loss name.
    raise ValueError(f"Unsupported task {task} and (regression) loss {loss_key}")
        

def get_loss_pretrain(cfg: ConfigPretrain):
    """Return the loss module used during pretraining.

    The task comes from ``cfg.data.task``. The regression loss and label
    smoothing come from ``cfg.optim``. Unlike the run-time variant,
    ``cfg.optim.regression_loss`` is compared against ``LossName`` members
    themselves rather than their ``.name`` strings.

    Args:
        cfg: Pretraining configuration with ``data`` and ``optim`` sections.

    Returns:
        ``torch.nn.MSELoss()`` or ``torch.nn.L1Loss()`` for regression, or
        ``CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)`` for
        classification.

    Raises:
        ValueError: if the task / regression-loss combination is unsupported.
    """
    task = cfg.data.task
    regression_loss = cfg.optim.regression_loss

    if task == Task.REGRESSION:
        if regression_loss == LossName.MSE:
            return torch.nn.MSELoss()
        if regression_loss == LossName.MAE:
            return torch.nn.L1Loss()
    elif task == Task.CLASSIFICATION:
        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)

    # Unknown task, or a regression task with an unrecognised loss member.
    raise ValueError(f"Unsupported task {task} and (regression) loss {regression_loss}")