import torch

from tabicl.config.config_run import ConfigRun
from tabicl.core.enums import ModelName
from tabicl.core.trainer import Trainer
from tabicl.core.trainer_finetune import TrainerFinetune
from tabicl.models.autogluon.trainer import Trainer as TrainerAutoGluon


def get_trainer(cfg: ConfigRun, model: torch.nn.Module, n_classes: int):
    """Construct the trainer that matches ``cfg.model_name``.

    The trainer class is selected as follows:

    * ``ModelName.FT_TRANSFORMER`` -> ``Trainer``
    * ``ModelName.AUTOGLUON`` -> ``TrainerAutoGluon``
    * ``ModelName.TABPFN``, ``ModelName.FOUNDATION``,
      ``ModelName.FOUNDATION_FLASH`` and ``ModelName.TAB2D`` ->
      ``TrainerFinetune``

    Whichever class is chosen is instantiated as
    ``trainer_cls(cfg, model, n_classes)``.

    Args:
        cfg: Run configuration. Only ``cfg.model_name`` is inspected here;
            the whole object is forwarded to the trainer.
        model: Model instance forwarded to the trainer.
        n_classes: Number of classes, forwarded to the trainer.

    Returns:
        A freshly constructed trainer instance.

    Raises:
        NotImplementedError: If ``cfg.model_name`` is not one of the model
            names listed above.
    """
    model_name = cfg.model_name

    # Pick the class first, then construct it at a single call site below.
    if model_name == ModelName.FT_TRANSFORMER:
        trainer_cls = Trainer
    elif model_name == ModelName.AUTOGLUON:
        trainer_cls = TrainerAutoGluon
    elif model_name in (
        ModelName.TABPFN,
        ModelName.FOUNDATION,
        ModelName.FOUNDATION_FLASH,
        ModelName.TAB2D,
    ):
        # These models all share the fine-tuning trainer.
        trainer_cls = TrainerFinetune
    else:
        raise NotImplementedError(f"Model {cfg.model_name} not implemented yet")

    return trainer_cls(cfg, model, n_classes)