from k_level_policy_gradients.src.agents.setup_dqn_agent import setup_dqn_agent
from k_level_policy_gradients.src.agents.setup_gru_dqn_agent import setup_gru_dqn_agent
from k_level_policy_gradients.src.agents.setup_dqn_continuous_agent import (
    setup_dqn_continuous_agent,
)
from k_level_policy_gradients.src.agents.setup_ddpg_agent import setup_ddpg_agent
from k_level_policy_gradients.src.agents.setup_discrete_ddpg_agent import (
    setup_discrete_ddpg_agent,
)
from k_level_policy_gradients.src.agents.setup_gru_discrete_ddpg_agent import (
    setup_gru_discrete_ddpg_agent,
)
from k_level_policy_gradients.src.agents.setup_maddpg import setup_maddpg_agent
from k_level_policy_gradients.src.agents.setup_kmaddpg import setup_kmaddpg_agent
from k_level_policy_gradients.src.agents.setup_maddpg_discrete_agent import (
    setup_maddpg_discrete_agent,
)
from k_level_policy_gradients.src.agents.setup_qmix_agent import setup_qmix_agent
from k_level_policy_gradients.src.agents.setup_comix_agent import setup_comix_agent
from k_level_policy_gradients.src.agents.setup_facmac_agent import setup_facmac_agent
from k_level_policy_gradients.src.agents.setup_kfacmac_agent import setup_kfacmac_agent
from k_level_policy_gradients.src.agents.setup_facmac_continuous_agent import (
    setup_facmac_continuous_agent,
)
from k_level_policy_gradients.src.agents.setup_kfacmac_continuous_agent import (
    setup_kfacmac_continuous_agent,
)


def setup_agent(agent, mdp_info, idx_agent, **kwargs):
    match agent:
        case "dqn":
            return setup_dqn_agent(mdp_info, idx_agent, **kwargs)
        case "gru_dqn":
            return setup_gru_dqn_agent(mdp_info, idx_agent, **kwargs)
        case "dqn_continuous":
            return setup_dqn_continuous_agent(mdp_info, idx_agent, **kwargs)
        case "ddpg":
            return setup_ddpg_agent(mdp_info, idx_agent, **kwargs)
        case "discrete_ddpg":
            return setup_discrete_ddpg_agent(mdp_info, idx_agent, **kwargs)
        case "gru_discrete_ddpg":
            return setup_gru_discrete_ddpg_agent(mdp_info, idx_agent, **kwargs)
        case "maddpg":
            return setup_maddpg_agent(mdp_info, idx_agent, **kwargs)
        case "kmaddpg":
            return setup_kmaddpg_agent(mdp_info, idx_agent, **kwargs)
        case "maddpg_discrete":
            return setup_maddpg_discrete_agent(mdp_info, idx_agent, **kwargs)
        case "qmix":
            return setup_qmix_agent(mdp_info, idx_agent, **kwargs)
        case "comix":
            return setup_comix_agent(mdp_info, idx_agent, **kwargs)
        case "facmac":
            return setup_facmac_agent(mdp_info, idx_agent, **kwargs)
        case "kfacmac":
            return setup_kfacmac_agent(mdp_info, idx_agent, **kwargs)
        case "facmac_continuous":
            return setup_facmac_continuous_agent(mdp_info, idx_agent, **kwargs)
        case "kfacmac_continuous":
            return setup_kfacmac_continuous_agent(mdp_info, idx_agent, **kwargs)
        case _:
            raise ValueError(f"Agent {agent} not supported.")
