"""Registry of environment names."""


def _import_open_spiel(env_name):
    """Load the OpenSpiel game ``env_name`` and wrap it in ``OpenSpielEnv``.

    Both imports happen at call time rather than at module import time
    (presumably so this registry can be imported without ``pyspiel``
    installed -- TODO confirm).

    Args:
        env_name: Game name handed to ``pyspiel.load_game``.

    Returns:
        An ``OpenSpielEnv`` constructed from the loaded game.
    """
    import pyspiel
    from env.open_spiel.base_env import OpenSpielEnv

    game = pyspiel.load_game(env_name)
    return OpenSpielEnv(game)


def _import_gfootball_qmix(cfg):
    """Create a ``FootballPvEEnv`` with all of its agents in one group.

    The environment is built from ``cfg``, every agent is named
    ``agent_<index>`` and placed in the single group ``"group_1"``, and the
    group's observation/action spaces are ``Tuple`` spaces holding one copy
    of the environment's per-agent space for each agent.

    Args:
        cfg: Mapping unpacked as keyword arguments into ``FootballPvEEnv``.

    Returns:
        Whatever ``FootballPvEEnv.with_agent_groups`` returns for that
        grouping (presumably a grouped multi-agent wrapper -- verify against
        the environment class).
    """
    # Imported at call time, matching the other factories in this module.
    from gym.spaces import Tuple
    from env.football.multi_agent_env import FootballPvEEnv

    football_env = FootballPvEEnv(**cfg)
    agent_ids = [f"agent_{idx}" for idx in range(football_env.num_agents)]

    # One entry per agent; the space attribute is read once per agent.
    joint_obs_space = Tuple(
        [football_env.observation_space for _ in agent_ids]
    )
    joint_act_space = Tuple(
        [football_env.action_space for _ in agent_ids]
    )

    return football_env.with_agent_groups(
        {"group_1": agent_ids},
        obs_space=joint_obs_space,
        act_space=joint_act_space,
    )


# Maps each registered environment name to a one-argument factory callable.
# The OpenSpiel entries ignore their argument (named ``_``) and always load
# the game whose name matches the key; the "gfootball_qmix" entry receives the
# argument as ``cfg`` and unpacks it into the football environment constructor.
ENVIRONMENTS = {
    "connect_four": lambda _: _import_open_spiel("connect_four"),
    "markov_soccer": lambda _: _import_open_spiel("markov_soccer"),
    "gfootball_qmix": _import_gfootball_qmix,
}


def get_env_class(env: str) -> type:
    """Return the registry entry for a known environment name.

    Note that the values stored in ``ENVIRONMENTS`` are factory callables
    taking a single argument, not classes, despite this function's name.

    Args:
        env: Name of the environment; must be a key of ``ENVIRONMENTS``.

    Returns:
        The object registered under ``env`` in ``ENVIRONMENTS``.

    Raises:
        Exception: If ``env`` is not a registered environment name. The
            message includes the requested name.
    """
    try:
        return ENVIRONMENTS[env]
    except KeyError:
        # The original message lacked the f-prefix and printed "{env}"
        # literally; interpolate the actual name. ``from None`` keeps the
        # traceback free of the internal KeyError.
        raise Exception(f"Unknown environment {env}.") from None
