"""Registry of algorithm names."""


def _import_ppo():
    """Load ``PPOTrainer`` and ``PPO_DEFAULT_CONFIG`` from ``agents.ppo``.

    The imports live inside the function body, so ``agents.ppo`` is only
    imported when this loader is actually called.

    Returns:
        A ``(PPOTrainer, PPO_DEFAULT_CONFIG)`` tuple.
    """
    from agents.ppo import PPOTrainer as trainer_cls
    from agents.ppo import PPO_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_ppo_com():
    """Load ``PPOComTrainer`` together with ``PPO_DEFAULT_CONFIG``.

    The trainer comes from ``agents.ppo.communication`` while the config is
    the one exported by ``agents.ppo``. Both imports run only when this
    loader is called.

    Returns:
        A ``(PPOComTrainer, PPO_DEFAULT_CONFIG)`` tuple.
    """
    from agents.ppo.communication import PPOComTrainer as trainer_cls
    from agents.ppo import PPO_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_ppo_hrl():
    """Load ``PPOHRLTrainer`` and ``PPO_HRL_DEFAULT_CONFIG``.

    Both names are imported from ``agents.ppo.hierarchical`` at call time
    rather than when this module is imported.

    Returns:
        A ``(PPOHRLTrainer, PPO_HRL_DEFAULT_CONFIG)`` tuple.
    """
    from agents.ppo.hierarchical import PPOHRLTrainer as trainer_cls
    from agents.ppo.hierarchical import PPO_HRL_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_ppo_curriculum():
    """Load ``PPOCurriculumTrainer`` and ``PPO_CURRICULUM_DEFAULT_CONFIG``.

    Both names are imported from ``agents.curriculum`` at call time rather
    than when this module is imported.

    Returns:
        A ``(PPOCurriculumTrainer, PPO_CURRICULUM_DEFAULT_CONFIG)`` tuple.
    """
    from agents.curriculum import PPOCurriculumTrainer as trainer_cls
    from agents.curriculum import PPO_CURRICULUM_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_ppo_com_curriculum():
    """Load ``PPOComCurriculumTrainer`` and ``PPO_CURRICULUM_DEFAULT_CONFIG``.

    Both names are imported from ``agents.curriculum`` at call time rather
    than when this module is imported.

    Returns:
        A ``(PPOComCurriculumTrainer, PPO_CURRICULUM_DEFAULT_CONFIG)`` tuple.
    """
    from agents.curriculum import PPOComCurriculumTrainer as trainer_cls
    from agents.curriculum import PPO_CURRICULUM_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_ppo_hrl_curriculum():
    """Load ``PPOHRLCurriculumTrainer`` and its curriculum default config.

    Both the trainer and ``PPO_HRL_CURRICULUM_DEFAULT_CONFIG`` are imported
    from ``agents.curriculum`` at call time rather than when this module is
    imported.

    Returns:
        A ``(PPOHRLCurriculumTrainer, PPO_HRL_CURRICULUM_DEFAULT_CONFIG)``
        tuple.
    """
    from agents.curriculum import PPOHRLCurriculumTrainer as trainer_cls
    from agents.curriculum import (
        PPO_HRL_CURRICULUM_DEFAULT_CONFIG as default_config,
    )

    return trainer_cls, default_config


def _import_league():
    """Load ``LeagueTrainer`` and ``LEAGUE_DEFAULT_CONFIG``.

    Both names are imported from ``agents.league`` at call time rather than
    when this module is imported.

    Returns:
        A ``(LeagueTrainer, LEAGUE_DEFAULT_CONFIG)`` tuple.
    """
    from agents.league import LeagueTrainer as trainer_cls
    from agents.league import LEAGUE_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_population_entropy():
    """Load ``PopulationEntropyTrainer`` with ``LEAGUE_DEFAULT_CONFIG``.

    The trainer comes from ``agents.league.population_entropy`` while the
    config is the one exported by ``agents.league``. Both imports run only
    when this loader is called.

    Returns:
        A ``(PopulationEntropyTrainer, LEAGUE_DEFAULT_CONFIG)`` tuple.
    """
    from agents.league.population_entropy import (
        PopulationEntropyTrainer as trainer_cls,
    )
    from agents.league import LEAGUE_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


def _import_qmix():
    """Load ``QMixTrainer`` and ``QMIX_DEFAULT_CONFIG`` from ``agents.qmix``.

    The imports live inside the function body, so ``agents.qmix`` is only
    imported when this loader is actually called.

    Returns:
        A ``(QMixTrainer, QMIX_DEFAULT_CONFIG)`` tuple.
    """
    from agents.qmix import QMixTrainer as trainer_cls
    from agents.qmix import QMIX_DEFAULT_CONFIG as default_config

    return trainer_cls, default_config


# Maps each supported algorithm name to a zero-argument loader function.
# A loader performs its imports only when it is called and returns a
# two-item tuple that `get_trainer_class` unpacks as ``(class_, config)``.
# Storing loaders instead of the classes themselves means this module does
# not import any `agents.*` package until a name is actually looked up.
ALGORITHMS = {
    "PPO": _import_ppo,
    "League": _import_league,
    "PPO-com": _import_ppo_com,
    "PPO-hrl": _import_ppo_hrl,
    "PPO-curriculum": _import_ppo_curriculum,
    "PPO-com-curriculum": _import_ppo_com_curriculum,
    "PPO-hrl-curriculum": _import_ppo_hrl_curriculum,
    "QMIX": _import_qmix,
    "PopulationEntropy": _import_population_entropy,
}


def get_trainer_class(alg: str, return_config: bool = False):
    """Return the trainer class registered under the name ``alg``.

    Args:
        alg: Either a key of ``ALGORITHMS`` or the literal string
            ``"script"``.
        return_config: When true, also return the config object that
            accompanies the trainer class.

    Returns:
        The trainer class, or a ``(class, config)`` tuple when
        ``return_config`` is true. For ``"script"`` the config is an empty
        dict.

    Raises:
        ValueError: If ``alg`` is neither a key of ``ALGORITHMS`` nor
            ``"script"``.
    """
    if alg in ALGORITHMS:
        # Registry values are loader callables; invoking one triggers the
        # actual import of the trainer module.
        class_, config = ALGORITHMS[alg]()
    elif alg == "script":
        # Imported here so ray.tune is only required for the "script" name.
        from ray.tune import script_runner

        class_, config = script_runner.ScriptRunner, {}
    else:
        # NOTE: a lookup in a CONTRIBUTED_ALGORITHMS table was sketched here
        # in the past but is not enabled; only ALGORITHMS and "script" are
        # recognised.
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working while new callers can be more specific.
        raise ValueError(f"Unknown algorithm {alg}.")

    if return_config:
        return class_, config
    return class_
