# RL algorithms
from .sac_agent import SACAgent
from .ppo_agent import PPOAgent
from .ddpg_agent import DDPGAgent
from .dqn_agent import DQNAgent

# IL algorithms
from .bc_agent import BCAgent
from .mt_bc_agent import MTBCAgent
from .gail_agent import GAILAgent
from .dac_agent import DACAgent
from .airl_agent import AIRLAgent
from .acgail_agent import ACGAILAgent
from .gail_agent_v2 import GAILAgentV2
from .reachable_gail_agent import ReachableGAILAgent
from .reachable_gail_agent_v0 import ReachableGAILAgentV0
from .iqlearn_agent import IQLearnAgent
from .sqil_agent import SQILAgent
from .prox_agent import ProxAgent

# Registry of RL agents: maps an algorithm name string to the agent class
# imported above. get_agent_by_name() consults this table first.
RL_ALGOS = {
    "sac": SACAgent,
    "ppo": PPOAgent,
    "ddpg": DDPGAgent,
    # "td3" resolves to the same class as "ddpg".
    # NOTE(review): presumably DDPGAgent switches to TD3 behavior based on
    # its config -- confirm in ddpg_agent.
    "td3": DDPGAgent,
    "dqn": DQNAgent,
}


# Registry of IL agents: maps an algorithm name string to the agent class
# imported above. get_agent_by_name() falls back to this table when the name
# is not found in RL_ALGOS.
# NOTE: keys are matched exactly and mix "-" and "_" as separators
# (e.g. "mt-bc", "gail-v2" vs. "reachable_gail", "reachable_gail-v0").
IL_ALGOS = {
    "bc": BCAgent,
    "mt-bc": MTBCAgent,
    "gail": GAILAgent,
    "dac": DACAgent,
    "airl": AIRLAgent,
    "acgail": ACGAILAgent,
    "gail-v2": GAILAgentV2,
    "iqlearn": IQLearnAgent,
    "reachable_gail": ReachableGAILAgent,
    "reachable_gail-v0": ReachableGAILAgentV0,
    "sqil": SQILAgent,
    "prox": ProxAgent,
}


def get_agent_by_name(algo):
    """
    Look up the agent class registered under the name `algo`.

    RL_ALGOS is searched first, then IL_ALGOS; the value stored in the first
    registry that contains `algo` is returned unchanged.

    Raises:
        ValueError: if `algo` is a key of neither registry.
    """
    # Tuple order fixes the precedence: RL entries win over IL entries.
    for registry in (RL_ALGOS, IL_ALGOS):
        if algo in registry:
            return registry[algo]
    raise ValueError("--algo %s is not supported" % algo)
