from .dece_trainer import CausalDeCETrainer
from .diverse_trainer import DivTrainer
from ...train.sft.trainer import CustomSeq2SeqTrainer

from .workflow import run_sft


# Only ``run_sft`` is exported by ``from <package> import *``; the other
# module-level names remain importable by explicit name.
__all__ = ["run_sft"]

# Registry mapping a trainer name to the trainer it selects.
# ``select_trainer`` resolves names against this mapping.
TRAINERS = {
    "dece": CausalDeCETrainer,
    "decouple": DivTrainer,
    "sft": CustomSeq2SeqTrainer,
}

def select_trainer(trainer_name: str):
    """Look up the trainer registered under ``trainer_name``.

    Args:
        trainer_name: Key to resolve in the module-level ``TRAINERS`` mapping.

    Returns:
        The value stored in ``TRAINERS`` for ``trainer_name``.

    Raises:
        ValueError: If ``trainer_name`` is not a key of ``TRAINERS``.
    """
    if trainer_name in TRAINERS:
        return TRAINERS[trainer_name]
    # Unknown names surface as ValueError rather than a bare KeyError.
    raise ValueError(f"Trainer {trainer_name} not found.")