import argparse
import sys

import gymnasium as gym
import mo_gymnasium
import ale_py

from algorithms import DDPG_ALGO, DQN_ALGO, RAINBOW_ALGO, TD3_ALGO, NECSA_DDPG_ALGO, \
    NECSA_DQN_ALGO, NECSA_RAINBOW_ALGO, NECSA_TD3_ALGO, NECSA_ADV_TD3_ALGO, MPEC_ALGO, DQN_DISCRETE_ALGO, \
    NECSA_DQN_DISCRETE_ALGO, MPEC_DISCRETE_ALGO, ALGORITHMS
from ddpg import get_ddpg_argparser, get_necsa_ddpg_argparser, test_ddpg
from dqn import get_dqn_argparser, get_necsa_dqn_argparser, test_dqn
from dqn_discrete import get_dqn_discrete_argparser, get_necsa_dqn_discrete_argparser, \
    test_dqn_discrete
from mpec import get_mpec_argparser, test_mpec
from mpec_discrete import get_mpec_discrete_argparser, test_mpec_discrete
from rainbow import get_rainbow_argparser, get_necsa_rainbow_argparser, test_rainbow
from td3 import get_td3_argparser, get_necsa_td3_argparser, get_necsa_adv_td3_argparser, \
    test_td3
from utils.environment_util import ENV_ARGPARSER_MAP, get_env_category

# The ALE (Atari) and MO-Gymnasium environments are registered when their
# modules are imported; gymnasium documents ``register_envs`` as a no-op whose
# purpose is to mark those imports as used.
gym.register_envs(ale_py)
gym.register_envs(mo_gymnasium)

# Register the project's custom environment under a gym id. ``gym.register``
# is gymnasium's top-level alias for ``gymnasium.envs.registration.register``,
# so no extra mid-file import is needed.
gym.register(
    id='breakable-bottles-v0',
    entry_point='env_.task:CustomBreakableBottles',
)


def _peek_options(argv):
    """Return ``(task, resume_path)`` as given in *argv*, without consuming it.

    Only ``--task`` and ``--resume-path`` are recognised here; every other
    argument is ignored and left for the algorithm-specific parser to
    validate. Both ``--opt value`` and ``--opt=value`` are accepted and the
    last occurrence wins, matching how the real parser will read them.
    Options absent from *argv* come back as ``None``.
    """
    # add_help=False so -h/--help reaches the real parser; allow_abbrev=False
    # so unrelated options such as --task-foo are never mistaken for --task.
    peeker = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    peeker.add_argument('--task')
    peeker.add_argument('--resume-path')
    known, _unknown = peeker.parse_known_args(argv)
    return known.task, known.resume_path


def _split_resume_path(resume_path):
    """Split *resume_path* into ``(task, algorithm, seed, experiment_name)``.

    The path must consist of exactly four non-empty ``/``-separated
    components; trailing ``/`` characters are tolerated. Exits with an error
    message (status 1) otherwise.
    """
    parts = resume_path.rstrip('/').split('/')
    if len(parts) != 4 or not all(parts):
        sys.exit(
            "Error: --resume-path must have the form "
            f"'task/algorithm/seed/experiment_name', got '{resume_path}'"
        )
    return tuple(parts)


def main():
    """Dispatch ``<algorithm> [options]`` to the matching argparser and runner.

    When ``--resume-path task/algorithm/seed/experiment_name`` is given, the
    algorithm, task and seed are taken from that path and injected into
    ``sys.argv`` before the algorithm-specific parser runs.

    All usage errors are reported on stderr with exit status 1.
    """
    # One table for both lookups so an algorithm's parser and runner cannot
    # drift apart.
    registry = {
        DDPG_ALGO: (get_ddpg_argparser, test_ddpg),
        DQN_ALGO: (get_dqn_argparser, test_dqn),
        RAINBOW_ALGO: (get_rainbow_argparser, test_rainbow),
        TD3_ALGO: (get_td3_argparser, test_td3),
        NECSA_DDPG_ALGO: (get_necsa_ddpg_argparser, test_ddpg),
        NECSA_DQN_ALGO: (get_necsa_dqn_argparser, test_dqn),
        NECSA_RAINBOW_ALGO: (get_necsa_rainbow_argparser, test_rainbow),
        NECSA_TD3_ALGO: (get_necsa_td3_argparser, test_td3),
        NECSA_ADV_TD3_ALGO: (get_necsa_adv_td3_argparser, test_td3),
        MPEC_ALGO: (get_mpec_argparser, test_mpec),
        DQN_DISCRETE_ALGO: (get_dqn_discrete_argparser, test_dqn_discrete),
        NECSA_DQN_DISCRETE_ALGO: (get_necsa_dqn_discrete_argparser, test_dqn_discrete),
        MPEC_DISCRETE_ALGO: (get_mpec_discrete_argparser, test_mpec_discrete),
    }

    if len(sys.argv) < 2:
        algorithms = "\n".join(ALGORITHMS)
        sys.exit(f'Error: No algorithm specified. \nPick one of: \n{algorithms}')

    algorithm = sys.argv[1]
    task, resume_path = _peek_options(sys.argv[1:])

    if resume_path:
        task, algorithm, seed, _ = _split_resume_path(resume_path)
        # Rebuild argv in the normal "<algorithm> [options]" layout. The task
        # and seed from the path are appended last so they take precedence
        # over any earlier --task/--seed (argparse keeps the last value of a
        # store-style option), and `task` above already agrees with them.
        sys.argv = sys.argv[:1] + [algorithm] + sys.argv[1:]
        sys.argv.extend(['--task', task, '--seed', seed])

    if algorithm not in registry:
        available = "\n".join(registry)
        sys.exit(f"Error: Unknown algorithm '{algorithm}'\nPick one of: \n{available}")

    if not task:
        sys.exit("Error: No task specified. Pass --task <task> or --resume-path.")

    category = get_env_category(task)
    try:
        env_argparser = ENV_ARGPARSER_MAP[category]
    except KeyError:
        sys.exit(
            f"Error: No argument parser for environment category "
            f"'{category}' (task '{task}')"
        )

    make_parser, run = registry[algorithm]
    args = make_parser(env_argparser).parse_args()
    run(args)


if __name__ == "__main__":
    main()
