import sys
import os
# Put the absolute path of the current working directory on the module search
# path before the `env.*` imports below run.
# NOTE(review): presumably the script is meant to be launched from the
# directory that contains `env/` -- confirm.
sys.path.append(os.path.abspath('./'))

from env.runAtari import rollout_Atari_main
from env.runClassic import rollout_Classic_main
from env.runMPE import rollout_MPE_main


def main(rollout):
    """Run the rollout entry point of every task named in ``rollout``.

    Tasks are visited in a fixed order (tag, push, spread, connect4, holdem,
    box, tennis) regardless of their order in ``rollout``. Names that match
    no known task are ignored.

    Args:
        rollout: Container of task names, queried with the ``in`` operator.
            Note that passing a plain string turns ``in`` into a substring
            test.
    """
    # Each entry: (task name, runner family, keyword arguments other than
    # env_name). The tuple order is the execution order.
    schedule = (
        ("tag", "mpe", {
            "agent_weight_path": "./env/tag/weight/30000_4p.pt",
            "all_episodes": [50, 100, 250],
            "max_steps": 26,
        }),
        ("push", "mpe", {
            "agent_weight_path": "./env/push/weight/30000.pt",
            "all_episodes": [50, 100, 250],
            "max_steps": 26,
        }),
        ("spread", "mpe", {
            "agent_weight_path": "./env/spread/weight/30000.pt",
            "all_episodes": [50, 100, 250],
            "max_steps": 26,
        }),
        ("connect4", "classic", {
            "agent_weight_path": "./env/connect4/weight/lesson4_agent.pt",
            "all_episodes": [50, 100, 250],
            # Values listed by the author: "random", "weak", "strong", "self".
            "opponent_difficulty": "strong",
            "max_steps": 42,
        }),
        ("holdem", "classic", {
            # No max_steps is passed for this task.
            "agent_weight_path": "./env/holdem/result/model_4500.pth",
            "all_episodes": [50, 100, 250],
        }),
        ("box", "atari", {
            "agent_weight_path": "./env/box/weight/20000.pth",
            "all_episodes": [10, 50, 100],
            "max_steps": 128,
        }),
        ("tennis", "atari", {
            "agent_weight_path": "./env/tennis/weight/20000.pth",
            "all_episodes": [10, 50, 100],
            "max_steps": 128,
        }),
    )

    for task, family, options in schedule:
        if task not in rollout:
            continue
        # The runner name is looked up only once a task is actually requested.
        if family == "mpe":
            runner = rollout_MPE_main
        elif family == "classic":
            runner = rollout_Classic_main
        else:
            runner = rollout_Atari_main
        runner(env_name=task, **options)

    
    
if __name__ == "__main__":
    """
    Enter the task name to get its dataset.
    """
    # Trim this list to run only some tasks. Names handled by main():
    # 'tag', 'push', 'spread', 'connect4', 'holdem', 'box', 'tennis'.
    rollout = [
        'tag',
        'push',
        'spread',
        'connect4',
        'holdem',
        'box',
        'tennis',
    ]
    main(rollout)