from .base_trainer import BaseTrainer
from .mvgrl import *
from .mvgrl_multi_aug import MVGRL_SVIAug_Pareto_Trainer
from .mvgrl_spectral import MVGRL_Spectral_Trainer
from .utils import *


def choose_trainer(model_name: str) -> BaseTrainer:
    """Build a fresh trainer instance for the given model identifier.

    Args:
        model_name: Key selecting the trainer. Accepted values are
            ``'mvgrl'``, ``'mvgrl_sviaug'``, ``'mvgrl_sviaug_spectral'``,
            ``'mvgrl_sviaug_spectral_shuf'``, ``'mvgrl_sviaug_pareto'`` and
            ``'mvgrl_loss'``.

    Returns:
        A newly constructed trainer, created with no constructor arguments.

    Raises:
        ValueError: If ``model_name`` is not one of the accepted values.
    """
    if model_name == 'mvgrl':
        trainer_cls = MVGRLTrainer
    elif model_name == 'mvgrl_sviaug':
        trainer_cls = MVGRL_SVIAugTrainer
    elif model_name in ('mvgrl_sviaug_spectral', 'mvgrl_sviaug_spectral_shuf'):
        # Both spectral variants construct the same trainer class here.
        # NOTE(review): any difference for the "_shuf" variant must be applied
        # elsewhere (not visible in this function) — confirm.
        trainer_cls = MVGRL_Spectral_Trainer
    elif model_name == 'mvgrl_sviaug_pareto':
        trainer_cls = MVGRL_SVIAug_Pareto_Trainer
    elif model_name == 'mvgrl_loss':
        trainer_cls = MVGRL_LossTrainer
    else:
        raise ValueError(f'Unknown model name: {model_name}')
    return trainer_cls()
