import sys, re, os, multiprocessing
import os.path as osp

import gym
from collections import defaultdict

import numpy as np

import tensorflow as tf
import joblib

from rarl_env import RarlEnv
from baselines import logger
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.tf_util import get_session, get_available_gpus, save_variables
from importlib import import_module

# import `exptools` stuff
from exptools.collections import AttrDict
from exptools.logging import logger as expLogger
from exptools.logging import context as expContext
from exptools.launching.variant import load_variant
from exptools.launching.affinity import affinity_from_code
# Disable exptools' TensorBoard writer: per the original note, it does not
# work with the TensorFlow version used by this script.
expLogger.tb_writer = None


try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

# Map "env type" -> set of registered gym env ids. The env type is the last
# dotted component of the module part of an entry point string, e.g.
# "pkg.envs.atari:AtariEnv" -> "atari".
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env.entry_point.partition(':')[0].rpartition('.')[2]
    _game_envs[env_type].add(env.id)

def get_env_type(args):
    """Resolve the ``(env_type, env_id)`` pair for ``args.env``.

    If ``args.env_type`` is set it is returned unchanged together with
    ``args.env``. Otherwise the gym registry is re-scanned into
    ``_game_envs`` and:

    * if ``args.env`` is itself an env type, that type is returned together
      with an arbitrary env id registered under it;
    * otherwise the type whose id set contains ``args.env`` is used, and an
      id of the form ``"module:EnvName"`` overrides it with ``"module"``.

    Raises:
        AssertionError: if no env type can be determined.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since we could have new envs since last time.
    for env in gym.envs.registry.all():
        env_type = env.entry_point.split(':')[0].split('.')[-1]
        _game_envs[env_type].add(env.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        # args.env names an env type: pick one (arbitrary) id of that type.
        env_type = env_id
        env_id = next(iter(_game_envs[env_type]))
    else:
        env_type = None
        for g, e in _game_envs.items():
            if env_id in e:
                env_type = g
                break
        if ':' in env_id:
            env_type = re.sub(r':.*', '', env_id)
        # The message previously had one placeholder for two arguments, so
        # the list of known env types was never shown.
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(
            env_id, list(_game_envs.keys()))

    return env_type, env_id

def build_env(args):
    """Construct the ``RarlEnv`` for ``args.env``.

    The current default TF session is passed first, and ``args.env_kwargs``
    is forwarded as keyword arguments.
    """
    session = get_session()
    return RarlEnv(session, args.env, **args.env_kwargs)

def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    ``baselines`` is tried first. On ``ImportError`` the same path under
    ``rl_algs`` is imported instead, and any ``ImportError`` from that
    second attempt propagates. ``submodule`` defaults to ``alg``.
    """
    module_leaf = submodule or alg
    try:
        return import_module('.'.join(['baselines', alg, module_leaf]))
    except ImportError:
        # Fall back to the rl_algs package layout.
        return import_module('.'.join(['rl_' + 'algs', alg, module_leaf]))

def get_learn_function(alg):
    """Return the ``learn`` attribute of algorithm ``alg``'s main module."""
    alg_module = get_alg_module(alg)
    return alg_module.learn


def get_learn_function_defaults(alg, env_type):
    """Return default ``learn`` kwargs for ``alg`` on ``env_type``.

    Looks up the attribute named ``env_type`` in ``<alg>.defaults`` and
    returns the result of calling it. Returns ``{}`` when the defaults
    module cannot be imported or has no attribute named ``env_type``.

    Only the import and the lookup are guarded. An exception raised while
    *executing* the defaults function propagates, instead of silently
    producing ``{}`` and training without the intended hyperparameters.
    """
    try:
        alg_defaults = get_alg_module(alg, 'defaults')
        defaults_fn = getattr(alg_defaults, env_type)
    except (ImportError, AttributeError):
        return {}
    return defaults_fn()

def get_default_network(env_type):
    """Return ``'cnn'`` for env types ``'atari'``/``'retro'``, else ``'mlp'``."""
    return 'cnn' if env_type in {'atari', 'retro'} else 'mlp'

def train(args, extra_args):
    """Build the RARL env and create the attacker and victim models.

    Each model is created by the algorithm's ``learn`` with
    ``total_timesteps=0``, inside its own variable scope
    (``args.attacker_name`` / ``args.victim_name``, ``reuse=False``). It is
    built against a placeholder env from ``as_attacker_env(None, None)`` /
    ``as_victim_env(None, None)``. ``extra_args`` overrides the
    algorithm's per-env-type defaults. If ``args.save_video_interval`` is
    non-zero, the env is wrapped in ``VecVideoRecorder``.

    Returns:
        tuple: ``(env, attacker_model, victim_model)``.
    """
    env_type, _ = get_env_type(args)
    print('env_type: {}'.format(env_type))

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args)
    if args.save_video_interval != 0:
        env = VecVideoRecorder(
            env,
            osp.join(logger.get_dir(), "videos"),
            record_video_trigger=lambda x: x % args.save_video_interval == 0,
            video_length=args.save_video_length,
        )

    if not args.network:
        if alg_kwargs.get('network') is None:
            alg_kwargs['network'] = get_default_network(env_type)
    else:
        alg_kwargs['network'] = args.network

    def _build_in_scope(scope_name, make_placeholder_env):
        # The placeholder env is created inside the scope, as in a direct call.
        # total_timesteps=0 presumably only constructs the model without
        # running updates -- TODO confirm against the learn implementation.
        with tf.variable_scope(scope_name, reuse=False):
            return learn(
                env=make_placeholder_env(None, None),
                total_timesteps=0,
                **alg_kwargs
            )

    attacker_model = _build_in_scope(args.attacker_name, env.as_attacker_env)
    victim_model = _build_in_scope(args.victim_name, env.as_victim_env)

    return env, attacker_model, victim_model

def contine_train(env, model_name, total_timesteps, args, extra_args):
    """Run ``learn`` on ``env`` for ``total_timesteps`` and return its model.

    ``learn`` is called inside ``tf.variable_scope(model_name, reuse=True)``.
    Presumably this picks up the variables that ``train`` created under the
    same scope name rather than creating new ones (confirm against
    ``learn``). Kwargs and network selection follow the same rules as
    ``train``: per-env-type defaults, updated with ``extra_args``, and
    ``args.network`` when set.
    """
    env_type = get_env_type(args)[0]

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    if not args.network:
        if alg_kwargs.get('network') is None:
            alg_kwargs['network'] = get_default_network(env_type)
    else:
        alg_kwargs['network'] = args.network

    with tf.variable_scope(model_name, reuse=True):
        return learn(env=env, total_timesteps=total_timesteps, **alg_kwargs)


def _prune_snapshots(files_dir, keep_name, n_file_kept):
    """Delete the oldest ``snapshot-<itr>`` files so at most ``n_file_kept`` remain.

    Only files named exactly ``snapshot-<integer>`` are counted or removed.
    Other files in the run directory (logs, csv, params) therefore neither
    count toward the limit nor cause an IndexError once every removable
    snapshot is gone. "Oldest" means the smallest iteration number, not
    ``os.listdir`` order (which is arbitrary). ``keep_name`` is never deleted.
    """
    names = os.listdir(files_dir)
    removable = []
    for fname in names:
        match = re.fullmatch(r"snapshot-(\d+)", fname)
        if match and fname != keep_name:
            removable.append((int(match.group(1)), fname))
    removable.sort()

    n_snapshots = len(removable) + (1 if keep_name in names else 0)
    n_excess = max(n_snapshots - n_file_kept, 0)
    for _, fname in removable[:n_excess]:
        file_to_remove = os.path.join(files_dir, fname)
        expLogger.log_text("Remove file: {}".format(file_to_remove))
        os.remove(file_to_remove)


def run_experiment(args, log_dir, run_ID):
    """Alternate RARL training between the victim and the attacker.

    Builds the env and both models via ``train``, then wires each role's env
    to the other model's action tensor and observation placeholder. For
    ``args.alternate_itr`` iterations it trains the victim for
    ``args.victim_inner_timesteps``, then the attacker for
    ``args.attacker_inner_timesteps``.

    A snapshot is saved to ``<log_dir>/run_<run_ID>/snapshot-<itr>`` every
    ``args["log_kwargs"]["save_interval"]`` iterations and after the final
    iteration. Only the newest ``args["log_kwargs"]["n_file_kept"]``
    snapshots are kept.

    Args:
        args: exptools variant; both attribute and item access are used.
        log_dir: experiment log directory.
        run_ID: run identifier used in the snapshot directory name.
    """
    np.set_printoptions(precision=3)

    # initialize agent and environment
    env, attacker_model, victim_model = train(args, args.algo_kwargs)

    # Each role's env is connected to the opposing model's action tensor and
    # observation placeholder.
    attacker_env = env.as_attacker_env(
        victim_action_tensor=victim_model.act_model.action,
        victim_obs_ph=victim_model.act_model.X,
    )
    victim_env = env.as_victim_env(
        attacker_action_tensor=attacker_model.act_model.action,
        attacker_obs_ph=attacker_model.act_model.X,
    )

    files_dir = os.path.join(log_dir, f"run_{run_ID}")

    for itr in range(args.alternate_itr):
        expLogger.log_text("RARL iteration {}".format(itr))

        # train victim
        expLogger.log_text("RARL train victim")
        logger.logkv("Switching", 1.0)
        contine_train(
            victim_env, args.victim_name, args.victim_inner_timesteps,
            args, args.algo_kwargs,
        )

        # train attacker
        expLogger.log_text("RARL train attacker")
        logger.logkv("Switching", -1.0)
        contine_train(
            attacker_env, args.attacker_name, args.attacker_inner_timesteps,
            args, args.algo_kwargs,
        )

        # NOTE (original author): save_variables stores variable values keyed
        # by their names, so the snapshot stays compatible with downstream tasks.
        # Save on every save_interval-th iteration and on the last iteration
        # of this loop (bounded by args.alternate_itr).
        is_last_itr = itr == args.alternate_itr - 1
        if itr % args["log_kwargs"]["save_interval"] == 0 or is_last_itr:
            snapshot_name = "snapshot-{:d}".format(itr)
            save_variables(os.path.join(files_dir, snapshot_name))
            _prune_snapshots(
                files_dir, snapshot_name, args["log_kwargs"]["n_file_kept"],
            )


def main(affinity_code, log_dir, run_ID, **kwargs):
    """exptools entry point: load the variant, select the GPU, set up logging, run.

    Args:
        affinity_code: exptools affinity code. Its ``"cuda_idx"`` is written
            to ``CUDA_VISIBLE_DEVICES``.
        log_dir: directory holding the variant; also used as the log dir.
        run_ID: run identifier forwarded to the logger context and experiment.
        **kwargs: accepted but unused.
    """
    ### exptools starting
    args = load_variant(log_dir)
    affinity = affinity_from_code(affinity_code)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(affinity["cuda_idx"])

    # Rank 0 (or no MPI) logs with suffix "openai"; other ranks append their rank.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    log_suffix = "openai" if rank == 0 else "openai{}".format(rank)
    logger.configure(
        dir=log_dir,
        format_strs=["stdout", "tensorboard", "csv"],
        log_suffix=log_suffix,
    )

    experiment_name = "ARIL-{}".format(args.env)
    with expContext.logger_context(log_dir, run_ID, experiment_name, log_params=args):
        run_experiment(args, log_dir, run_ID)

    

if __name__ == '__main__':
    # Positional CLI arguments map onto main(affinity_code, log_dir, run_ID, ...).
    # Presumably these are supplied by an exptools launcher -- confirm.
    main(*sys.argv[1:])
