import sys
sys.path.append('../')
sys.path.append('../openai_baselines')
sys.path.append('../openai_baselines/baselines/common/vec_env')

import re
from shutil import copyfile
import multiprocessing
import threading
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
import joblib
import os

from aril_env import AttackEnv
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
# from baselines.common.vec_env.my_attack_env import AttackEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session, get_available_gpus, save_variables
from baselines import logger
from importlib import import_module

from exptools.logging import context as expContext

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

# Module-level index of known environments: maps an env type (str) to the set
# of env ids belonging to it. It is filled once at import time from whatever is
# registered with gym at that moment.
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    # The env type is the last dotted component of the part of the entry point
    # that precedes ':' (e.g. a 'pkg.mod:Cls' entry point yields 'mod').
    env_type = env.entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)

# Reading benchmark names directly from retro would require importing retro
# here, and for some reason that crashes tensorflow on Ubuntu, so the retro
# ids are hard-coded instead. This assignment replaces any set already stored
# under the 'retro' key.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    """Build the attack environment and the attacker model.

    Despite the name, ``learn`` is always invoked with ``total_timesteps=0``,
    so this presumably only constructs the model (TODO confirm against the
    selected algorithm's ``learn``).

    Args:
        args: experiment configuration; this function reads ``alg``,
            ``network``, ``num_timesteps`` and ``seed`` and forwards the whole
            object to ``get_env_type`` and ``build_env``.
        extra_args: mapping merged on top of the algorithm's default keyword
            arguments before calling ``learn``.

    Returns:
        A ``(model, env)`` tuple: whatever ``learn`` returned and the
        environment produced by ``build_env``.
    """
    env_type, _env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))

    # The conversion is retained so a non-numeric value still fails at this
    # point; the converted value itself is not forwarded to ``learn``.
    _requested_timesteps = int(args.num_timesteps)
    seed = args.seed

    learn_fn = get_learn_function(args.alg)
    learn_kwargs = get_learn_function_defaults(args.alg, env_type)
    learn_kwargs.update(extra_args)

    attack_env = build_env(args)

    # An explicit --network wins; otherwise fall back to the env-type default
    # only when neither the algorithm defaults nor extra_args supplied one.
    if args.network:
        learn_kwargs['network'] = args.network
    elif learn_kwargs.get('network') is None:
        learn_kwargs['network'] = get_default_network(env_type)

    # print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

    # Dump the variables that exist after the environment was built and before
    # the attacker graph is created.
    print("trainable variables")
    print(tf.trainable_variables())

    # Disabled experiment: drop student / ppo2 variables from the trainable
    # collection before the attacker is built.
    # trainable_collections = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    # remove_list = []
    # for var in trainable_collections:
    #     if "student" in var.name or "ppo2" in var.name:
    #         remove_list.append(var)

    # for var in remove_list:
    #     trainable_collections.remove(var)

    # print("trainable variables")
    # print(tf.trainable_variables())

    with tf.variable_scope("attacker", reuse=False):
        # epsilon=args.epsilon is intentionally not forwarded.
        model = learn_fn(
            env=attack_env,
            seed=seed,
            total_timesteps=0,
            **learn_kwargs
        )

    return model, attack_env

def build_env(args):
    """Construct the ``AttackEnv`` used by this script.

    The CPU-count / ``num_env`` / ``alg`` bookkeeping left over from the
    upstream launcher was never used and has been removed.

    Args:
        args: experiment configuration; reads ``env``, ``zero_order``,
            ``seed`` and ``env_kwargs``.

    Returns:
        The newly created ``AttackEnv`` instance.

    Side effects:
        ``args.env_kwargs`` is modified in place: ``"build_backup_student"``
        is forced to ``True`` before the mapping is forwarded.
    """
    # Result unused; the call is kept because it refreshes the env registry
    # index and asserts that the env id can be mapped to an env type.
    get_env_type(args)

    args.env_kwargs["build_backup_student"] = True
    return AttackEnv(
        get_session(),
        args.env,
        zero_order=args.zero_order,
        random_seed=args.seed,
        **args.env_kwargs
    )


def get_env_type(args):
    """Resolve the ``(env_type, env_id)`` pair described by *args*.

    Resolution order:
      1. If ``args.env_type`` is not None it is returned together with
         ``args.env`` without consulting the registry.
      2. If ``args.env`` is itself a known env type, that type is used and one
         member of its id set (the first in iteration order) becomes the id.
      3. Otherwise the env type whose id set contains ``args.env`` is used. An
         id containing ':' overrides that with the id stripped of everything
         from the first ':' onward.

    Args:
        args: object exposing ``env`` and ``env_type`` attributes.

    Returns:
        Tuple ``(env_type, env_id)``.

    Raises:
        AssertionError: if no env type could be determined for ``args.env``.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since we could have new envs since last time.
    for spec in gym.envs.registry.all():
        spec_type = spec.entry_point.split(':')[0].split('.')[-1]
        _game_envs[spec_type].add(spec.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        # The caller passed an env type rather than a concrete id.
        env_type = env_id
        env_id = list(_game_envs[env_type])[0]
        return env_type, env_id

    env_type = next(
        (known_type for known_type, ids in _game_envs.items() if env_id in ids),
        None,
    )
    if ':' in env_id:
        env_type = re.sub(r':.*', '', env_id)

    # The message previously had a single placeholder for two arguments, so
    # the list of known env types never made it into the error text.
    assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(
        env_id, list(_game_envs.keys())
    )

    return env_type, env_id


def get_default_network(env_type):
    """Return the default network name for *env_type*.

    ``'cnn'`` is returned for the ``'atari'`` and ``'retro'`` env types and
    ``'mlp'`` for anything else.
    """
    cnn_env_types = {'atari', 'retro'}
    return 'cnn' if env_type in cnn_env_types else 'mlp'

def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    The ``baselines`` package is tried first; if that import raises
    ``ImportError`` the same dotted path is looked up under ``rl_algs``.

    Args:
        alg: algorithm name, used as the sub-package name.
        submodule: module name inside the algorithm package; a falsy value
            means "same as *alg*".

    Raises:
        ImportError: if the module cannot be imported from either package.
    """
    if not submodule:
        submodule = alg
    relative_path = '.'.join([alg, submodule])

    try:
        # first try to import the alg module from baselines
        return import_module('baselines' + '.' + relative_path)
    except ImportError:
        # then from rl_algs
        return import_module('rl_' + 'algs' + '.' + relative_path)


def get_learn_function(alg):
    """Return the ``learn`` attribute of the module resolved for *alg*."""
    alg_module = get_alg_module(alg)
    return alg_module.learn


def get_learn_function_defaults(alg, env_type):
    """Return the default ``learn`` keyword arguments for *alg* on *env_type*.

    The algorithm's ``defaults`` module is expected to expose a callable named
    after the env type. When the module cannot be imported, or the lookup or
    the call raises ``AttributeError``, an empty dict is returned instead.
    """
    try:
        defaults_module = get_alg_module(alg, 'defaults')
        make_defaults = getattr(defaults_module, env_type)
        return make_defaults()
    except (ImportError, AttributeError):
        return {}

# import `exptools` stuff
from exptools.collections import AttrDict
from exptools.launching.variant import load_variant
from exptools.launching.affinity import affinity_from_code

def main(log_dir, student_filename, expert_filename):
    """Load the experiment variant stored in *log_dir* and run the verification.

    Args:
        log_dir: directory passed to ``load_variant``.
        student_filename: stored as ``env_kwargs["student_path"]``.
        expert_filename: stored as ``env_kwargs["expert_path"]``.

    Side effects:
        Overwrites the ``CUDA_VISIBLE_DEVICES`` environment variable with
        ``"0,1,2,3"``.
    """
    ### exptools starting
    variant = load_variant(log_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    # Point the environment at the snapshots to verify.
    snapshot_paths = (
        ("student_path", student_filename),
        ("expert_path", expert_filename),
    )
    for key, path in snapshot_paths:
        variant["env_kwargs"][key] = path

    run_experiment(variant)

# (obs_dim, act_dim, attack_dim used when ``zero_order`` is truthy) for the
# environments this script has hard-coded dimensions for.
_ENV_DIMENSIONS = {
    "Swimmer-v2": (8, 2, 3),
    "Hopper-v2": (11, 3, 5),
    "Walker2d-v2": (17, 6, 8),
    "HalfCheetah-v2": (17, 6, 8),
    "Ant-v2": (111, 8, 13),
}


def _apply_experiment_defaults(args):
    """Fill in missing hyper-parameters and the per-environment dimensions."""
    defaults = (
        ("buffer_size", int(5e6)),
        ("dagger_itr", int(700)),
        # num of iterations for training TRPO attack
        ("num_timesteps", int(1e6)),
        # num of student training iteration in supervised learning manner
        ("student_steps", 1000),
        ("batch_size", 64),
        # num of rollouts for expert labeling
        ("num_rollouts", 40),
    )
    for name, value in defaults:
        if not hasattr(args, name):
            setattr(args, name, value)

    # Unknown environments are left without obs_dim / act_dim / attack_dim,
    # exactly as before.
    dims = _ENV_DIMENSIONS.get(args.env)
    if dims is not None:
        obs_dim, act_dim, zero_order_attack_dim = dims
        args.obs_dim = obs_dim
        args.act_dim = act_dim
        args.attack_dim = zero_order_attack_dim if args.zero_order else obs_dim


def _verify_expert(env):
    """Roll out the expert policy for 200 episodes and print reward statistics."""
    print("---------------- Verifying Expert -----------------")
    episode_rews = []
    for _ in range(int(2e2)):
        episode_rew = 0.0
        done = False
        obs = env.reset()
        while not done:
            act = env.get_expert_action(obs)
            obs, rew, done, __ = env.expert_step(act)
            episode_rew += rew
        episode_rews.append(episode_rew)
        # NOTE(review): redundant with the reset at the top of the loop; kept
        # so the sequence of calls made on the environment is unchanged.
        env.reset()
    print("Expert reward mean {} std {} min {} max {}".format(
        np.mean(episode_rews), np.std(episode_rews), np.min(episode_rews), np.max(episode_rews)
    ))


def _verify_student(env, args):
    """Roll out the student for 200 episodes under an all-zero attack action."""
    print("---------------- Verifying Student -----------------")
    episode_rews = []
    # ``attack_dim`` is read here, inside the worker, as the original closure did.
    zero_attack = np.zeros(args.attack_dim)
    for _ in range(int(2e2)):
        episode_rew = 0.0
        done = False
        env.reset()
        while not done:
            _, _, done, __ = env.step([zero_attack])
            # The reward reported by step() is ignored; the environment's own
            # reward accessor is accumulated instead.
            episode_rew += env.get_env_reward()
        episode_rews.append(episode_rew)
    print("student reward (without attack) mean {} std {} min {} max {}".format(
        np.mean(episode_rews), np.std(episode_rews), np.min(episode_rews), np.max(episode_rews)
    ))


def run_experiment(args):
    """Build the attacker/environment and verify the expert and the student.

    The function resets ``args.seed`` to ``None``, fills in default
    hyper-parameters, forces ``args.num_timesteps`` to 0, builds the model and
    environment through ``train`` and then starts one background thread per
    configured snapshot (``env_kwargs["expert_path"]`` and
    ``env_kwargs["student_path"]``). It returns once every started thread has
    finished.

    Previously both thread handles were joined unconditionally, which raised
    ``AttributeError`` on ``None`` whenever one of the two paths was empty;
    only threads that were actually started are joined now.

    Args:
        args: experiment variant supporting both attribute access and item
            access (``args["env_kwargs"]``); it is modified in place.
    """
    args.seed = None
    extra_args = args.algo_kwargs
    np.set_printoptions(precision=3)

    _apply_experiment_defaults(args)

    # initialize trpo agent and environment
    args.num_timesteps = 0
    _model, env = train(args, extra_args)

    # NOTE(review): both workers operate on the same ``env`` object
    # concurrently; confirm that AttackEnv tolerates that.
    workers = []

    # Verifying Expert
    if args["env_kwargs"]["expert_path"]:
        verify_expert_thread = threading.Thread(target=_verify_expert, args=(env,))
        verify_expert_thread.start()
        workers.append(verify_expert_thread)

    # Verifying Student without attack
    if args["env_kwargs"]["student_path"]:
        verify_student_thread = threading.Thread(target=_verify_student, args=(env, args))
        verify_student_thread.start()
        workers.append(verify_student_thread)

        # Manually toggled check of the backup student; disabled by default.
        if False:
            print("--------------- Verifying Backup Student")
            episode_rews = []
            act = np.zeros(args.attack_dim)
            for episode_i in range(100):
                episode_rew = 0.0
                done = False
                obs = env.reset()
                while not done:
                    bk_student_action_val = env.sess.run(
                        env.backup_student_action,
                        feed_dict= {env.backup_student_obs: obs},
                    )
                    obs, rew, done, _ = env.attack_env.step(bk_student_action_val)
                    obs = [obs]
                    episode_rew += env.get_env_reward()
                episode_rews.append(episode_rew)
            print("backup student reward (without attack) mean {} std {} min {} max {}".format(
                np.mean(episode_rews), np.std(episode_rews), np.min(episode_rews), np.max(episode_rews)
            ))

    # # Verifying Attacker and Student
    # att_episode_rews = []
    # stu_episode_rews = []
    # act = np.zeros(args.attack_dim)
    # for episode_i in range(100):
    #     att_episode_rew = 0.0
    #     stu_episode_rew = 0.0
    #     done = False
    #     obs = env.reset()
    #     while not done:
    #         obs, rew, done, __ = env.step([act])
    #         episode_rew += env.get_env_reward()
    #     episode_rews.append(episode_rew)
    # print("student reward (without attack) mean {} std {}".format(
    #     np.mean(episode_rews), np.std(episode_rews)
    # ))

    # Wait only for the workers that were started.
    for worker in workers:
        worker.join()

if __name__ == '__main__':
    # Flip this condition to True to block at start-up until a remote ptvsd
    # debugger has attached.
    if False:
        import ptvsd
        import sys
        debug_address = ('0.0.0.0', 6789)
        print("Process: " + " ".join(sys.argv[:]))
        print("Is waiting for attach at address: %s:%d" % debug_address, flush= True)
        # Allow other computers to attach to ptvsd at this IP address and port.
        ptvsd.enable_attach(address=debug_address, redirect_output= True)
        # Pause the program until a remote debugger is attached
        ptvsd.wait_for_attach()
        print("Process attached, start running into experiment...", flush= True)
        ptvsd.break_into_debugger()

    # Hard-coded run to verify: the student and expert snapshots are read from
    # inside the run directory.
    log_dir = "./data/slurm/aril_experiment/20210220/tAttackerFalsetStudentTrue/Ant-v2/seed-2048/run_0"
    student_snapshot = os.path.join(log_dir, "snapshot-475")
    # Alternative expert snapshot used previously:
    # "openai_baselines/models/expert_swimmer_20M_ppo2_norm",
    expert_snapshot = os.path.join(log_dir, "expert_snapshot")
    main(log_dir, student_snapshot, expert_snapshot)
