""" Launch RL/IL training and evaluation. """

import sys
import signal
import os
import json
import logging

import numpy as np
import torch
import random
from six.moves import shlex_quote
from mpi4py import MPI

from config import create_parser
from trainer import Trainer
from utils.logger import logger
from utils.mpi import mpi_sync
import glob

# Print numpy arrays with 3 decimals and without scientific notation.
np.set_printoptions(precision=3, suppress=True)


def _install_shutdown_handlers():
    """Exit with status ``128 + signum`` on SIGHUP / SIGINT / SIGTERM.

    Signals missing on the current platform (e.g. SIGHUP on Windows) are skipped.
    """

    def shutdown(signum, frame):
        logger.warning("Received signal %s: exiting", signum)
        sys.exit(128 + signum)

    for sig_name in ("SIGHUP", "SIGINT", "SIGTERM"):
        sig = getattr(signal, sig_name, None)
        if sig is not None:
            signal.signal(sig, shutdown)


def _seed_everything(config):
    """Pick the torch device and seed torch, CUDA, NumPy and ``random`` with ``config.seed``.

    Sets ``config.device``. cuDNN is disabled and put in deterministic,
    non-benchmark mode.

    Raises:
        RuntimeError: if ``config.gpu`` is set but CUDA is not available.
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False

    if config.gpu is not None:
        # Restrict visible GPUs before any torch CUDA call, including the
        # seeding below. Comment out when using brain server.
        os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(config.gpu)

    torch.manual_seed(config.seed)
    if config.gpu is not None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError(
                "config.gpu={} was requested but CUDA is not available".format(config.gpu)
            )
        config.device = torch.device("cuda")
        torch.cuda.manual_seed_all(config.seed)
    else:
        config.device = torch.device("cpu")

    np.random.seed(config.seed)
    random.seed(config.seed)


def _apply_short_command_defaults(config):
    """Fill in default settings so that short command lines work.

    The cnn-encoder overrides are applied before ``set_demo_path``.
    ``set_demo_path`` samples ``task_num_in_dataset`` metaworld tasks and
    derives ``config.num_task`` from them, so it must see the cnn value.
    """
    if hasattr(config, "encoder_type") and config.encoder_type == "cnn":
        config.ob_norm = False
        config.rollout_length = 1000  # 2000 needs too much memory
        # config.pretrained_encoder = 'vae'
        # config.pretrain_n_epochs = 5
        if hasattr(config, "task_num_in_dataset"):
            config.task_num_in_dataset = 10
    if hasattr(config, "demo_path") and config.demo_path is None:
        set_demo_path(config)


def run(parser=None):
    """Parse arguments, set up MPI workers, logging and seeds, then run the Trainer.

    Args:
        parser: argument parser to use; defaults to ``create_parser()``.

    Unrecognized command-line arguments are logged as an error and the
    function returns without training or evaluating.
    """
    if parser is None:
        parser = create_parser()

    config, unparsed = parser.parse_known_args()
    if len(unparsed):
        logger.error("Unparsed argument is detected:\n%s", unparsed)
        return

    rank = MPI.COMM_WORLD.Get_rank()
    config.rank = rank
    config.is_chef = rank == 0
    config.num_workers = MPI.COMM_WORLD.Get_size()
    # Called before the per-rank seed offset below, so run_name/log_dir use
    # the base seed on every rank.
    set_log_path(config)

    config.seed = config.seed + rank
    if hasattr(config, "port"):
        config.port = config.port + rank * 2  # training env + evaluation env

    if config.is_chef:
        logger.warning("Run a base worker.")
        make_log_files(config)
    else:
        logger.warning("Run worker %d and disable logger.", config.rank)
        logger.setLevel(logging.CRITICAL)

    # synchronize all processes
    mpi_sync()

    _install_shutdown_handlers()
    _seed_everything(config)
    # Must run after seeding: set_demo_path draws from np.random.
    _apply_short_command_defaults(config)

    # build a trainer
    trainer = Trainer(config)
    if config.is_train:
        trainer.train()
        logger.info("Finish training")
    else:
        # Evaluation currently only visualizes (the method name must match
        # the spelling used by Trainer). Alternatives:
        # trainer.visualize_average(config.algo)
        # trainer.visulize_critic(config.algo)
        # trainer.evaluate() # TODO: for now, we just visualize the discriminator, not run evaluation
        trainer.visulize_8directons(config.algo)
        logger.info("Finish evaluating")

def _demo_file(env_name, suffix, num_demo):
    """Return the demo pickle path ``demos/demos-<env_name><suffix>/rollout_<num_demo>_trajs.pkl``."""
    return os.path.join("demos", "demos-{}/rollout_{}_trajs.pkl".format(env_name + suffix, num_demo))


def set_demo_path(config):
    """
    Fill in demo paths and the task count for ``config.env``.

    ``config.env`` is looked up in the env groups below. Its group defines the
    task set; for the first stack group, a list chosen by ``config.env_type``
    may replace it.

    Sets on ``config``:
      - ``num_task``: number of envs in the task set, target included;
      - ``target_demo_path`` and ``demo_path``: the demo file of ``config.env``;
      - ``other_demo_path``: the '#'-joined demo files of the other envs. Set
        for reachable_gail-v0 / sqil / prox / bc, and for gail with
        ``pretrain_discriminator == 2``;
      - for reachable_gail / gail-v2, the other envs' demo files are instead
        appended to ``demo_path``, separated by '#';
      - group-specific settings: ``is_visual_reachability``, ``is_metaworld``
        and ``traj_length``.

    Draws from ``np.random`` (metaworld task sub-sampling and shuffling), so
    the result depends on the global NumPy seed.

    Raises:
        ValueError: if ``config.env`` belongs to no known group, is missing
            from the task set selected by ``config.env_type``, or no other env
            remains when ``other_demo_path`` is needed.
    """
    env1_list = ["BaseSawyerPushForwardGoalRightUpEnv-v0", "BaseSawyerPushForwardGoalRightDownEnv-v0", "BaseSawyerPushForwardGoalLeftUpEnv-v0", "BaseSawyerPushForwardGoalLeftDownEnv-v0"]
    env2_list = ["maze2d-large-blue-v3", "maze2d-large-magenta-v3", "maze2d-large-yellow-v3", "maze2d-large-red-v3"] # fixed target position
    env3_list = ["maze2d-large-red-v0", "maze2d-large-blue-v0", "maze2d-large-magenta-v0", "maze2d-large-yellow-v0"] # randomized target position
    env4_list = ["stack-blue-magenta-v0", "stack-green-magenta-v0", "stack-blue-black-v0", "stack-white-green-v0"]
    env5_list = ['door-lock-v2', 'door-unlock-v2', 'hand-insert-v2', 'box-close-v2', 'push-wall-v2', 'reach-wall-v2', 'button-press-wall-v2', 'handle-pull-side-v2', 'button-press-topdown-v2', 'plate-slide-back-v2','window-close-v2', 'coffee-button-v2', 'drawer-open-v2', 'door-open-v2', 'shelf-place-v2', 'sweep-v2', 'lever-pull-v2', 'door-close-v2']
    env6_list = ['ReachGoal01-v0', 'ReachGoal32-v0', 'ReachGoal21-v0', 'ReachGoal12-v0']
    env7_list = ['stack-blue-magenta-black-v0', 'stack-black-white-green-v0', 'stack-magenta-green-white-v0', 'stack-green-black-magenta-v0', 'stack-white-magenta-blue-v0']
    env8_list = ["maze2d-large-red-v1", "maze2d-large-blue-v1", "maze2d-large-magenta-v1", "maze2d-large-yellow-v1"] # randomized target and distractor position
    env9_list = ["stack-green-magenta-v0", "stack-white-green-v0", "stack-blue-black-v0", "stack-magenta-blue-v0", "stack-blue-white-v0"]
    env10_list =["stack-black-magenta-v0", "stack-blue-green-v0", "stack-green-black-v0", "stack-white-blue-v0", "stack-magenta-white-v0"]

    # TODO: add more envs
    # NOTE: env4/env9 overlap; the first matching branch wins.
    if config.env in env1_list:
        env_list = list(env1_list)
        if config.algo in ["reachable_gail", "reachable_gail-v0"]:
            config.is_visual_reachability = False
    elif config.env in env2_list:
        env_list = list(env2_list)
    elif config.env in env3_list:
        env_list = list(env3_list)
    elif config.env in env4_list:
        env_list = list(env4_list)

        general_list    = ["stack-blue-magenta-v0", "stack-magenta-green-v0", "stack-white-black-v0", "stack-black-blue-v0"]
        close_env_list1 = ["stack-blue-magenta-v0", "stack-black-magenta-v0", "stack-white-magenta-v0", "stack-green-magenta-v0"]
        close_env_list2 = ["stack-blue-magenta-v0", "stack-blue-black-v0", "stack-blue-green-v0", "stack-blue-white-v0"]
        far_away_list   = ["stack-blue-magenta-v0", "stack-blue-green-v0", "stack-black-green-v0", "stack-magenta-green-v0", "stack-white-green-v0",]
        if config.env_type is not None:
            if config.env_type == "close1":
                env_list = close_env_list1
            elif config.env_type == "faraway":
                env_list = far_away_list
            elif config.env_type == "close2":
                env_list = close_env_list2
            elif config.env_type == "general":
                env_list = general_list
            elif config.env_type == "task8":
                env_list = ["stack-blue-magenta-v0", "stack-magenta-green-v0", "stack-white-black-v0", "stack-black-blue-v0", "stack-green-white-v0", "stack-blue-green-v0", "stack-green-black-v0", "stack-white-blue-v0", "stack-magenta-white-v0", ] #(0,2)(2,1)(4,0)(3,4)
            elif config.env_type == "task12":
                env_list = ["stack-blue-magenta-v0", "stack-magenta-green-v0", "stack-white-black-v0", "stack-black-blue-v0", "stack-green-white-v0", "stack-blue-green-v0", "stack-green-black-v0", "stack-white-blue-v0", "stack-magenta-white-v0", "stack-magenta-blue-v0","stack-black-magenta-v0","stack-white-green-v0","stack-blue-black-v0", ] #(0,2)(2,1)(4,0)(3,4)|(1,3)(4,2)(0,1)(3,0)

    elif config.env in env5_list:
        env_list = ['reach-v2', 'push-v2', 'pick-place-v2', 'dial-turn-v2', 'drawer-close-v2', 'button-press-v2', 'peg-insert-side-v2', 'window-open-v2', 'sweep-into-v2', 'basketball-v2',
                    'door-close-v2', 'faucet-open-v2', 'hammer-v2', 'handle-press-side-v2', 'pick-out-of-hole-v2', 'plate-slide-v2', 'plate-slide-side-v2', 'handle-pull-v2']#, 'soccer-v2', 'stick-push-v2',]
        # keep a random subset of task_num_in_dataset training tasks
        if hasattr(config, "task_num_in_dataset"):
            if config.task_num_in_dataset < len(env_list):
                env_indx = np.random.choice(len(env_list), size=config.task_num_in_dataset, replace=False)
                env_list = [env_list[i] for i in env_indx]

        # 'door-close-v2' is both a training task above and a valid target;
        # append the target only when absent so it is not counted twice.
        if config.env not in env_list:
            env_list.append(config.env)
        config.is_metaworld = True
        config.traj_length = 10

        # config.dense_reward_scale = 0.0004
        # config.backward_relabel_threshold = 0.6
        # config.reach_reward_constant = 10.0
    elif config.env in env6_list:
        env_list = list(env6_list)
        config.traj_length = 10
    elif config.env in env7_list:
        env_list = list(env7_list)
        config.traj_length = 10
    elif config.env in env8_list:
        env_list = list(env8_list)
        config.traj_length = 10
    elif config.env in env9_list:
        env_list = list(env9_list)
    elif config.env in env10_list:
        env_list = list(env10_list)
    else:
        raise ValueError(
            "set_demo_path: unknown env {!r}; add it to one of the env groups".format(config.env)
        )

    config.num_task = len(env_list)

    # update demo folder name
    suffix = ''
    num_demo = 200
    if config.env in env4_list or config.env in env9_list or config.env in env10_list:
        suffix += '_no_in'

    config.target_demo_path = _demo_file(config.env, suffix, num_demo)
    if config.env not in env_list:
        raise ValueError(
            "set_demo_path: env {!r} is not in the task set selected by env_type={!r}".format(
                config.env, getattr(config, "env_type", None)
            )
        )
    env_list.remove(config.env)
    config.demo_path = config.target_demo_path

    if (config.algo in ["reachable_gail-v0", "sqil", "prox", "bc"]) or (config.algo == "gail" and config.pretrain_discriminator == 2):
        if not env_list:
            raise ValueError(
                "set_demo_path: algo {!r} needs demos from at least one other env, "
                "but none remain for env {!r}".format(config.algo, config.env)
            )
        np.random.shuffle(env_list)
        config.other_demo_path = "#".join(_demo_file(env_name, suffix, num_demo) for env_name in env_list)

    if config.algo in ["reachable_gail", "gail-v2"]:
        np.random.shuffle(env_list)
        for env_name in env_list:
            config.demo_path += '#' + _demo_file(env_name, suffix, num_demo)


def set_log_path(config):
    """
    Compute the run name and log directory paths and store them on ``config``.

    ``run_name`` is ``"<env>.<algo>.<run_prefix>.<seed>"``, and ``log_dir`` is
    ``<log_root_dir>/<run_name>``. ``record_dir``, ``visual_dir`` and
    ``pretrain_dir`` are the ``video``, ``visual`` and ``pretrain``
    sub-directories of ``log_dir``. Only paths are computed here; nothing is
    created on disk.
    """
    name_fields = (config.env, config.algo, config.run_prefix, config.seed)
    config.run_name = "{}.{}.{}.{}".format(*name_fields)

    base_dir = os.path.join(config.log_root_dir, config.run_name)
    config.log_dir = base_dir
    config.record_dir = os.path.join(base_dir, "video")
    config.visual_dir = os.path.join(base_dir, "visual")
    config.pretrain_dir = os.path.join(base_dir, "pretrain")



def _log_git_state(git_path):
    """Append the HEAD commit hash (plus newline) and the ``git diff`` output to ``git_path``.

    If git is missing or the command fails, an empty line or empty diff is
    written instead of raising.
    """
    import subprocess

    def _git(*args):
        try:
            # Bytes output avoids decode errors on diffs in arbitrary encodings.
            return subprocess.run(("git",) + args, stdout=subprocess.PIPE, check=False).stdout
        except OSError:  # git executable not found
            return b""

    head = _git("rev-parse", "HEAD").strip()
    diff = _git("diff")
    with open(git_path, "ab") as fp:
        fp.write(head + b"\n")
        fp.write(diff)


def _log_command(cmd_path):
    """Append ``python -m rl <args>`` (args shell-quoted) to ``cmd_path`` so the run can be replayed."""
    args = " ".join(shlex_quote(arg) for arg in sys.argv[1:])
    with open(cmd_path, "a") as fp:
        fp.write("python -m rl {}\n".format(args))


def make_log_files(config):
    """
    Create the run's log directories and, for training runs, record the git
    state, the launching command and the config.

    Creates ``log_dir``, ``record_dir`` and ``visual_dir``. It also creates
    ``pretrain_dir`` for reachable_gail / reachable_gail-v0 when
    ``pretrain_prox`` is set.

    Existing directories are accepted only when ``config.resume or not
    config.is_train``; otherwise ``os.makedirs`` raises ``FileExistsError``.

    When ``config.is_train`` is set, it also:
      - appends the HEAD hash and ``git diff`` to ``git.txt``;
      - appends the command line to ``cmd.sh``;
      - writes ``config.__dict__`` to ``params.json``.
    """
    exist_ok = config.resume or not config.is_train

    logger.info("Create log directory: %s", config.log_dir)
    os.makedirs(config.log_dir, exist_ok=exist_ok)

    logger.info("Create video directory: %s", config.record_dir)
    os.makedirs(config.record_dir, exist_ok=exist_ok)

    logger.info("Create visual directory: %s", config.visual_dir)
    os.makedirs(config.visual_dir, exist_ok=exist_ok)

    if config.algo in ["reachable_gail", "reachable_gail-v0"] and config.pretrain_prox:
        logger.info("Create pretrain directory: %s", config.pretrain_dir)
        os.makedirs(config.pretrain_dir, exist_ok=exist_ok)

    if config.is_train:
        # Written from Python rather than via an os.system shell string:
        # unquoted paths broke on spaces, and echo '...' stripped the
        # shlex-quoting of the recorded arguments.
        _log_git_state(os.path.join(config.log_dir, "git.txt"))
        _log_command(os.path.join(config.log_dir, "cmd.sh"))

        # log config
        param_path = os.path.join(config.log_dir, "params.json")
        logger.info("Store parameters in %s", param_path)
        with open(param_path, "w") as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
