#!/usr/bin/env python3

"""
Main entry point for starting a training job.
"""

import os
import sys
import argparse
import subprocess
import shutil
import logging
import logging.config
import torch
from training.parameter_spec import Parameters

# printing
def _ansi_printer(color_code):
    """Build a print function that wraps its argument in the given ANSI
    foreground color code (with a leading space, matching the original
    output format exactly)."""
    def _printer(prt):
        print("\033[{}m {}\033[00m".format(color_code, prt))
    return _printer


prRed = _ansi_printer(91)
prGreen = _ansi_printer(92)
prYellow = _ansi_printer(93)
prLightPurple = _ansi_printer(94)
prPurple = _ansi_printer(95)
prCyan = _ansi_printer(96)
prLightGray = _ansi_printer(97)
prBlack = _ansi_printer(98)

# Environment/task types recognized by this launcher. When --env_type is
# not supplied, the job name (the basename of the data directory) is
# searched for one of these strings to infer the environment type.
env_types = [
    'append-still',
    'append-spawn',
    'append-dynamic',
    'prune-still',
    'prune-still-hard',
    'prune-dynamic',
    'prune-spawn',
    'navigate',
]

# Command-line arguments: which JSON config file to load and which indexed
# configuration within it to use. The parsed namespace is handed to
# Parameters, which exposes the full training configuration as `args`.
parser = argparse.ArgumentParser(description="Config file")
parser.add_argument('--config_file', type=str, default='sample_config.json', help='Configuration file for the chosen model')
# NOTE(review): type=str with an int default means config_idx is the int 1
# when omitted but a str when supplied — confirm Parameters handles both.
parser.add_argument('--config_idx', type=str, default=1, help='Configuration index')
exp_info = parser.parse_args()
print(exp_info.config_file)
args = Parameters(exp_info)

# args = parser.parse_args()

# Setup the directories

# Root of the safelife repository (the parent directory of this script)
# and the data directory where all artifacts for this job are written.
safety_dir = os.path.realpath(os.path.join(__file__, '../'))
data_dir = os.path.realpath(args.data_dir)
# The job name (directory basename) is used for env-type inference below
# and as the tensorboard run label.
job_name = os.path.basename(data_dir)
sys.path.insert(1, safety_dir)  # ensure current directory is on the path
# All relative paths used later (level files, setup.py) assume this cwd.
os.chdir(safety_dir)


# If the data directory already exists, decide whether to overwrite it,
# append to it, or abort. Benchmark runs always append; scratch jobs whose
# name starts with "tmp" are silently overwritten; otherwise the user is
# asked interactively.
# (Fixed: the None-check now short-circuits before touching the
# filesystem, and the dead `response = None` assignment is gone.)
if args.data_dir is not None and os.path.exists(data_dir):
    print("The directory '%s' already exists. "
          "Would you like to overwrite the old data, append to it, or abort?" %
          data_dir)
    response = 'append' if args.run_benchmarks else \
        'overwrite' if job_name.startswith('tmp') else None
    while response not in ('overwrite', 'append', 'abort'):
        response = input("(overwrite / append / abort) > ")
    if response == 'overwrite':
        print("Overwriting old data.")
        shutil.rmtree(data_dir)
    elif response == 'abort':
        print("Aborting.")
        # sys.exit rather than the site-provided exit() builtin, which is
        # not guaranteed to exist (e.g. under `python -S`).
        sys.exit()

os.makedirs(data_dir, exist_ok=True)
# Plain-text log for this run; wired into logging.config below.
logfile = os.path.join(data_dir, 'training.log')

# Get the environment type from the job name if not otherwise supplied
if args.env_type:
    env_type = args.env_type
else:
    # Try longer names first so that a job name containing e.g.
    # "prune-still-hard" resolves to 'prune-still-hard' rather than to its
    # substring 'prune-still' (which precedes it in the env_types list).
    for env_type in sorted(env_types, key=len, reverse=True):
        if env_type in job_name:
            break
    else:
        # No recognized type in the job name; fall back to a default.
        env_type = 'prune-still'

assert env_type in env_types

# Setup logging

if not os.path.exists(logfile):
    open(logfile, 'w').close()  # write an empty file
# Send records from the 'training' and 'safelife' loggers (and, via the
# root logger, any other WARNING+ messages) to both the console and the
# per-job logfile. The console uses a terse format; the file adds
# timestamps.
logging_config = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'simple': {
            'format': '{levelname:8s} {message}',
            'style': '{',
        },
        'dated': {
            # NOTE(review): the literal "(unknown)" below looks like a
            # mangled placeholder (perhaps originally a module or file
            # name) — confirm the intended logfile record format.
            'format': '{asctime} {levelname} ((unknown):{lineno}) {message}',
            'style': '{',
            'datefmt': '%Y-%m-%d %H:%M:%S',
        },
    },
    'handlers': {
        'console': {
            'class': 'logging.StreamHandler',
            'level': 'INFO',
            'stream': 'ext://sys.stdout',
            'formatter': 'simple',
        },
        'logfile': {
            'class': 'logging.FileHandler',
            'level': 'INFO',
            'formatter': 'dated',
            'filename': logfile,
        }
    },
    'loggers': {
        'training': {
            'level': 'INFO',
            'propagate': False,
            'handlers': ['console', 'logfile'],
        },
        'safelife': {
            'level': 'INFO',
            'propagate': False,
            'handlers': ['console', 'logfile'],
        }
    },
    'root': {
        'level': 'WARNING',
        'handlers': ['console', 'logfile'],
    }
}
logging.config.dictConfig(logging_config)

# Build the safelife C extensions.
# By making the build lib the same as the base folder, the extension
# should just get built into the source directory.
# Use sys.executable instead of a bare "python3" so the extension is built
# by the exact interpreter (and virtualenv) running this script — a PATH
# "python3" can be a different installation entirely.
subprocess.run([
    sys.executable, os.path.join(safety_dir, "setup.py"),
    "build_ext", "--build-lib", safety_dir
])

# Start tensorboard

# Optionally launch tensorboard pointed at this job's data directory.
# The process is killed in the `finally` block at the end of the script.
if args.port:
    tb_proc = subprocess.Popen([
        "tensorboard", "--logdir", job_name + ':' + data_dir, '--port', str(args.port)])

# Start training!

try:
    import numpy as np
    from training.env_factory import (
        LinearSchedule,
        SafeLifeLevelIterator,
        SwitchingLevelIterator,
        safelife_env_factory,
        safelife_env_factory_nav_reward
    )
    from safelife.safelife_logger import SafeLifeLogger
    from safelife.random import set_rng

    logger = logging.getLogger('training')

    # Derive two independent child seeds from the user-supplied seed: one
    # seeds the global safelife RNG, the other the level iterators below.
    seed1, seed2 = np.random.SeedSequence(args.seed).spawn(2)
    logger.info("SETTING GLOBAL SEED: %i", seed1.entropy)
    set_rng(np.random.default_rng(seed1))

    # As written, this loop runs for exactly one penalty value.
    for penalty in [args.impact_penalty]:
        runtag = 'safe_actor'
        tripletag = 'triple_actor'
        # Three output subdirectories: the main agent, the "safe actor",
        # and the "triple actor", each with its own SafeLifeLogger.
        subdir = os.path.join(data_dir, "penalty_{:0.2f}".format(penalty))
        subdir_2 = os.path.join(data_dir, "penalty_{:0.2f}".format(penalty) + '_' + runtag)
        subdir_3 = os.path.join(data_dir, "penalty_{:0.2f}".format(penalty) + '_' + tripletag)
        os.makedirs(subdir, exist_ok=True)
        os.makedirs(subdir_2, exist_ok=True)
        os.makedirs(subdir_3, exist_ok=True)

        if args.run_benchmarks:
            # Benchmark runs: no tensorboard summaries or training logs —
            # just per-level videos and a JSON results file.
            data_logger_1 = SafeLifeLogger(
                subdir,
                summary_writer=False,
                training_log=False,
                testing_video_name="benchmark-{level_name}",
                testing_log="benchmark-data.json")

            data_logger_2 = SafeLifeLogger(subdir_2, runtag=runtag,
                                            summary_writer=False,
                                            training_log=False,
                                            testing_video_name="benchmark-{level_name}",
                                            testing_log="benchmark-data.json")
            data_logger_3 = SafeLifeLogger(subdir_3, runtag=tripletag,
                                           summary_writer=False,
                                           training_log=False,
                                           testing_video_name="benchmark-{level_name}",
                                           testing_log="benchmark-data.json"
                                           )

        else:
            # Normal training: loggers with full default output.
            data_logger_1 = SafeLifeLogger(subdir)
            data_logger_2 = SafeLifeLogger(subdir_2, runtag=runtag)
            data_logger_3 = SafeLifeLogger(subdir_3, runtag=tripletag)

        # Choose level sources and annealing schedules for the selected
        # environment type. The original code contained two nearly
        # identical if/elif chains (static vs. procedurally generated
        # environments) plus four factory calls duplicated verbatim in
        # both branches; the per-type constants are factored into lookup
        # tables instead, and the factories are built once at the end.
        #
        # t_penalty / t_performance give the timestep ranges over which
        # the impact penalty and the minimum-performance requirement are
        # linearly annealed (see the LinearSchedule calls below).
        env_schedules = {
            # env_type: (t_penalty, t_performance)
            'append-still': ([1.0e6, 2.0e6], [1.0e6, 2.0e6]),
            'append-spawn': ([2.0e6, 3.5e6], [1.0e6, 2.0e6]),
            'append-dynamic': ([2.0e6, 3.5e6], [1.0e6, 2.0e6]),
            'prune-still': ([0.5e6, 1.5e6], [0.5e6, 1.5e6]),
            'prune-still-hard': ([0.5e6, 1.5e6], [0.5e6, 1.5e6]),
            'prune-dynamic': ([0.5e6, 1.5e6], [0.5e6, 1.5e6]),
            'prune-spawn': ([1.5e6, 2.5e6], [0.5e6, 1.5e6]),
            # t_performance is not actually relevant for navigate.
            'navigate': ([1.0e6, 2.0e6], [1.0e6, 2.0e6]),
        }
        # Canonical level names: benchmark file stems, which are also the
        # static training directory names. ('navigate' levels live under
        # the name 'navigation'.)
        level_names = {
            'append-still': 'append-still',
            'append-spawn': 'append-spawn',
            'append-dynamic': 'append-dynamic',
            'prune-still': 'prune-still',
            'prune-still-hard': 'prune-still-hard',
            'prune-dynamic': 'prune-dynamic',
            'prune-spawn': 'prune-spawn',
            'navigate': 'navigation',
        }
        # Procedurally generated level definition files.
        random_level_files = {
            'append-still': 'random/append-still-easy.yaml',
            'append-dynamic': 'random/append-dynamic.yaml',
            'prune-still': 'random/prune-still-easy.yaml',
            'prune-still-hard': 'random/prune-still-hard.yaml',
            'prune-dynamic': 'random/prune-dynamic.yaml',
            'navigate': 'random/navigation.yaml',
        }

        if env_type not in env_schedules:
            # Unreachable in practice (env_type is validated near the top
            # of the script), but fail loudly instead of hitting a
            # NameError on level_iterator later.
            logging.error("Unexpected environment type '%s'", env_type)
            raise ValueError("unexpected environment type")

        t_penalty, t_performance = env_schedules[env_type]
        test_levels = 'benchmarks/v1.0/{}.npz'.format(level_names[env_type])

        # When training in spawn environments, we first pre-train in the
        # corresponding still-life environments for a couple million time
        # steps. This just provides more opportunities for rewards so
        # makes the initial training easier.
        spawn_pretrain = {
            # env_type: (pretrain levels, main levels, t_switch)
            'append-spawn': ('append-still', 'append-spawn', 1.5e6),
            'prune-spawn': ('prune-still', 'prune-spawn', 2.0e6),
        }

        if args.static_envs:
            if env_type in spawn_pretrain:
                pre, main, t_switch = spawn_pretrain[env_type]
                level_iterator = SwitchingLevelIterator(
                    'static_train_envs/v1.0/' + pre,
                    'static_train_envs/v1.0/' + main,
                    t_switch=t_switch,
                    logger=data_logger_1,
                    seed=seed2, repeat_levels=True
                )
            else:
                level_iterator = SafeLifeLevelIterator(
                    'static_train_envs/v1.0/' + level_names[env_type],
                    seed=seed2, repeat_levels=True
                )
        else:
            if env_type in spawn_pretrain:
                pre, main, t_switch = spawn_pretrain[env_type]
                # Procedural pre-training uses the "-easy" still variants.
                level_iterator = SwitchingLevelIterator(
                    'random/' + pre + '-easy.yaml',
                    'random/' + main + '.yaml',
                    t_switch=t_switch,
                    logger=data_logger_1,
                    seed=seed2,
                )
            else:
                level_iterator = SafeLifeLevelIterator(
                    random_level_files[env_type], seed=seed2,
                )

        # Main training environments.
        training_envs = safelife_env_factory(
            data_logger=data_logger_1, num_envs=16,
            impact_penalty=LinearSchedule(data_logger_1, t_penalty, [0, penalty]),
            penalty_baseline=args.penalty_baseline,
            min_performance=LinearSchedule(data_logger_1, t_performance, [0.01, 0.5]),
            level_iterator=level_iterator,
        )

        # Environments for the "safe" actor.
        training_envs_safe = safelife_env_factory(
            data_logger=data_logger_2, num_envs=16,
            impact_penalty=LinearSchedule(data_logger_2, t_penalty, [0, penalty]),
            penalty_baseline=args.penalty_baseline,
            min_performance=LinearSchedule(data_logger_2, t_performance, [0.01, 0.5]),
            level_iterator=level_iterator,
        )

        # Environments for the "triple" actor (fewer parallel envs).
        training_envs_triple = safelife_env_factory(
            data_logger=data_logger_3, num_envs=8,
            impact_penalty=LinearSchedule(data_logger_3, t_penalty, [0, penalty]),
            penalty_baseline=args.penalty_baseline,
            min_performance=LinearSchedule(data_logger_3, t_performance, [0.01, 0.5]),
            level_iterator=level_iterator,
        )

        # Navigation-reward variant: fixed small impact penalty and no
        # minimum-performance requirement.
        # NOTE(review): these log to data_logger_2 but take their schedule
        # loggers from data_logger_1, exactly as in the original code —
        # confirm that is intended.
        training_envs_nav_reward = safelife_env_factory_nav_reward(
            data_logger=data_logger_2, num_envs=16,
            impact_penalty=LinearSchedule(data_logger_1, t_penalty, [0.01, 0.01]),
            penalty_baseline=args.penalty_baseline,
            min_performance=LinearSchedule(data_logger_1, t_performance, [0.00, 0.00]),
            level_iterator=level_iterator,
        )

        if args.run_benchmarks:
            # Benchmark evaluation: run on the full fixed benchmark set.
            testing_envs = safelife_env_factory(
                data_logger=data_logger_1, num_envs=100, testing=True,
                level_iterator=SafeLifeLevelIterator(
                    test_levels, repeat_levels=True)
            )

            testing_envs_safe = safelife_env_factory(
                data_logger=data_logger_2, num_envs=100, testing=True,
                level_iterator=SafeLifeLevelIterator(
                    test_levels, repeat_levels=True)
            )
        else:
            # Periodic testing during training; distinct_levels caps how
            # many unique benchmark levels are cycled through.
            testing_envs = safelife_env_factory(
                data_logger=data_logger_1, num_envs=100, testing=True,
                level_iterator=SafeLifeLevelIterator(
                    test_levels, distinct_levels=100, repeat_levels=True)
            )

            testing_envs_safe = safelife_env_factory(
                data_logger=data_logger_2, num_envs=100, testing=True,
                level_iterator=SafeLifeLevelIterator(
                    test_levels, distinct_levels=100, repeat_levels=True)
            )

            testing_envs_triple = safelife_env_factory(
                data_logger=data_logger_3, num_envs=5, testing=True,
                level_iterator=SafeLifeLevelIterator(
                    test_levels, distinct_levels=5, repeat_levels=True)
            )

            # NOTE(review): testing_envs_triple and testing_envs_nav_reward
            # are built only in this branch and are not passed to any
            # algorithm below — confirm whether they are still needed.
            testing_envs_nav_reward = safelife_env_factory_nav_reward(
                data_logger=data_logger_2, num_envs=5, testing=True,
                level_iterator=SafeLifeLevelIterator(
                    test_levels, distinct_levels=5, repeat_levels=True)
            )

        # Instantiate the training algorithm. All algorithms share the
        # same policy network; the risk-aware (SARL) variants additionally
        # train a separate "safe" policy. This lookup table replaces seven
        # nearly identical if/elif branches:
        #   algo name -> (module providing PPO_Risk, checkpoint_safe_agent)
        # A checkpoint value of None means the algorithm takes no
        # checkpoint_safe_agent argument at all.
        risk_algos = {
            'ppo-ppo-sarl': ('training.ppo_ppo_sarl', None),
            'ppo-ppo-sarl-bgrl': ('training.ppo_ppo_sarl_bgrl', None),
            # NOTE(review): the checkpoint paths below were empty strings
            # in the original code — presumably placeholders to fill in.
            'ppo-ppo-sarl-generalize': ('training.ppo_ppo_generalize', ''),
            'ppo-ppo-sarl-generalize-bgrl': ('training.ppo_ppo_generalize_bgrl', ''),
            'ppo-ppo-sarl-generalize-append-agent': ('training.ppo_ppo_generalize', ''),
            'ppo-ppo-sarl-generalize-bgrl-append-agent': ('training.ppo_ppo_generalize_bgrl', ''),
        }

        if args.algo == 'ppo':
            from training.models import SafeLifePolicyNetwork
            from training.ppo import PPO

            obs_shape = training_envs[0].observation_space.shape
            model = SafeLifePolicyNetwork(obs_shape)
            algo = PPO(
                model, args,
                training_envs=training_envs,
                testing_envs=testing_envs,
                data_logger=data_logger_1)

        elif args.algo in risk_algos:
            import importlib
            from training.models import SafeLifePolicyNetwork

            module_name, checkpoint_safe_agent = risk_algos[args.algo]
            PPO_Risk = importlib.import_module(module_name).PPO_Risk

            obs_shape = training_envs[0].observation_space.shape
            model = SafeLifePolicyNetwork(obs_shape)
            train_model_safe = SafeLifePolicyNetwork(obs_shape)

            # Only the "generalize" variants accept a safe-agent checkpoint.
            extra_kwargs = {}
            if checkpoint_safe_agent is not None:
                extra_kwargs['checkpoint_safe_agent'] = checkpoint_safe_agent

            algo = PPO_Risk(
                model, train_model_safe, args,
                training_envs=training_envs,
                testing_envs=testing_envs,
                training_envs_safe=training_envs_safe,
                testing_envs_safe=testing_envs_safe,
                data_logger=data_logger_1,
                **extra_kwargs)

        else:
            logging.error("Unexpected algorithm type '%s'", args.algo)
            raise ValueError("unexpected algorithm type")

        # Either evaluate on the benchmark levels or run a full training
        # loop for args.steps timesteps.
        if args.run_benchmarks:
            algo.run_episodes(testing_envs, num_episodes=100)
        else:
            algo.train(args.steps)


except Exception:
    # Catch-all so any failure is recorded in the logfile (with traceback)
    # before the cleanup below runs.
    logging.exception("Ran into an unexpected error. Aborting training.")
finally:
    # Shut down the tensorboard subprocess, if one was started above.
    if args.port:
        tb_proc.kill()
    # if args.shutdown:
    #     # Shutdown in 3 minutes.
    #     # Enough time to recover if it crashed at the start.
    #     subprocess.run("sudo shutdown +3".split())
    #     logging.critical("Shutdown commenced. Exiting to bash.")
    #     subprocess.run(["bash", "-il"])
