from .model_based_joint_trainer import ModelBasedJointTrainer
from .model_based_trainer import ModelBasedTrainer
from .model_free_trainer import ModelFreeTrainer
from .treeqn_trainer import TreeQNTrainer


def get_trainer(args):
    """Instantiate the trainer class named by ``args.trainer``.

    Args:
        args: Configuration object whose ``trainer`` attribute must equal the
            class name of one of the supported trainers (``ModelFreeTrainer``,
            ``ModelBasedJointTrainer``, ``ModelBasedTrainer`` or
            ``TreeQNTrainer``). The whole ``args`` object is forwarded to the
            selected trainer's constructor.

    Returns:
        A new instance of the selected trainer class.

    Raises:
        SystemExit: If ``args.trainer`` does not name a supported trainer. The
            error message is written to stderr and the exit status is 1.
    """
    trainer_classes = (
        ModelFreeTrainer,
        ModelBasedJointTrainer,
        ModelBasedTrainer,
        TreeQNTrainer,
    )
    # Key each class by its own __name__ so the accepted strings can never
    # drift out of sync with the class definitions.
    registry = {cls.__name__: cls for cls in trainer_classes}

    trainer_cls = registry.get(args.trainer)
    if trainer_cls is None:
        # Raising SystemExit with a string prints it to stderr and exits with a
        # non-zero status. The builtin ``exit`` comes from ``site`` and is not
        # meant for programs. Exiting with status 0 would hide the failure.
        raise SystemExit(
            f"Invalid trainer passed! Trainer: {args.trainer}. "
            f"Valid choices: {', '.join(registry)}"
        )
    return trainer_cls(args)
