import sys
sys.path.append('../')
sys.path.append('../openai_baselines')
sys.path.append('../openai_baselines/baselines/common/vec_env')

import re
import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
import joblib
import os

from aril_env import AttackEnv
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
# from baselines.common.vec_env.my_attack_env import AttackEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session, get_available_gpus, save_variables
from baselines import logger
from importlib import import_module

# import `exptools` stuff
from exptools.collections import AttrDict
from exptools.launching.variant import load_variant
from exptools.launching.affinity import affinity_from_code
from exptools.logging import logger as expLogger
from exptools.logging import context as expContext

# from baselines.common.vec_normalize import VecNormalize, MyVecNormalize

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

# Map each gym "env type" to the set of env ids registered under it. The env
# type is the last dotted component of the entry point's module path, i.e. for
# an entry point "pkg.mod:Class" the type is "mod". get_env_type() consults
# this mapping to infer the type of an env id.
# NOTE: this loop runs at import time and leaves `env` and `env_type` bound as
# module-level globals.
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    # NOTE(review): assumes env.entry_point is a string; a callable entry point
    # would raise AttributeError here — confirm for the gym version in use.
    env_type = env.entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)

# Reading benchmark names directly from retro would require importing retro
# here, and for some reason that crashes tensorflow on Ubuntu, so the retro
# env ids are hard-coded instead.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    """Build the attack environment and train an attacker policy on it.

    Args:
        args: experiment variant. Reads ``num_timesteps``, ``seed``, ``alg``,
            ``network``, ``save_video_interval`` and ``save_video_length``,
            plus whatever ``get_env_type`` and ``build_env`` consume.
        extra_args: mapping of keyword overrides merged on top of the
            algorithm's per-env-type defaults before calling ``learn``.

    Returns:
        ``(model, env)``: the value returned by the algorithm's ``learn``
        function and the (possibly video-wrapped) environment it trained on.
    """
    env_type, env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))

    n_steps = int(args.num_timesteps)
    seed = args.seed

    learn = get_learn_function(args.alg)
    learn_kwargs = get_learn_function_defaults(args.alg, env_type)
    learn_kwargs.update(extra_args)

    env = build_env(args)
    if args.save_video_interval != 0:
        # Reads the interval on every call, so later changes to args apply.
        def record_trigger(step):
            return step % args.save_video_interval == 0

        video_dir = osp.join(logger.get_dir(), "videos")
        env = VecVideoRecorder(env, video_dir, record_video_trigger=record_trigger, video_length=args.save_video_length)

    # An explicit --network wins; otherwise keep the algorithm default, and
    # only fall back to the env-type default when none was provided.
    if args.network:
        learn_kwargs['network'] = args.network
    elif learn_kwargs.get('network') is None:
        learn_kwargs['network'] = get_default_network(env_type)

    print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, learn_kwargs))

    print("trainable variables")
    print(tf.trainable_variables())

    # All attacker variables are created under the "attacker" scope.
    with tf.variable_scope("attacker", reuse= False):
        model = learn(env=env, seed=seed, total_timesteps=n_steps, **learn_kwargs)

    return model, env


def build_env(args):
    """Create the ``AttackEnv`` the attacker policy is trained against.

    Args:
        args: experiment variant. Reads ``env``, ``zero_order``, ``seed`` and
            ``env_kwargs`` (a mapping forwarded as keyword arguments to
            ``AttackEnv``), plus whatever ``get_env_type`` consumes.

    Returns:
        A single ``AttackEnv`` constructed with the session from
        ``get_session()``.

    Raises:
        AssertionError: propagated from ``get_env_type`` when ``args.env``
            cannot be mapped to a known env type.
    """
    # Called only to validate ``args.env``; the (env_type, env_id) result is
    # not needed to build the attack env.
    get_env_type(args)

    return AttackEnv(
        get_session(),
        args.env,
        zero_order=args.zero_order,
        random_seed=args.seed,
        **args.env_kwargs
    )


def get_env_type(args):
    """Resolve the ``(env_type, env_id)`` pair for the experiment.

    If ``args.env_type`` is set, it is returned unchanged together with
    ``args.env``. Otherwise the type is inferred from ``args.env`` after
    refreshing ``_game_envs`` from the gym registry:

    * if ``args.env`` is itself an env-type key of ``_game_envs``, it becomes
      the type and an arbitrary env id of that type is returned as the id;
    * otherwise the type is the first key whose id set contains ``args.env``;
    * an id containing ``':'`` (``"module:EnvName"``) overrides that, using
      the text before the first ``':'`` as the type.

    Raises:
        AssertionError: if no env type can be determined.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since we could have new envs since last time.
    for env in gym.envs.registry.all():
        env_type = env.entry_point.split(':')[0].split('.')[-1]
        _game_envs[env_type].add(env.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        env_type = env_id
        # Sets are unordered: this picks an arbitrary id of that type.
        env_id = next(iter(_game_envs[env_type]))
    else:
        env_type = None
        for candidate_type, ids in _game_envs.items():
            if env_id in ids:
                env_type = candidate_type
                break
        if ':' in env_id:
            env_type = re.sub(r':.*', '', env_id)
        # The message previously had no placeholder for the known types, so
        # they were silently dropped from the error.
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, sorted(_game_envs))

    return env_type, env_id


def get_default_network(env_type):
    """Return the default policy network name for ``env_type``.

    ``'cnn'`` for the ``atari`` and ``retro`` env types, ``'mlp'`` for
    everything else.
    """
    cnn_env_types = {'atari', 'retro'}
    return 'cnn' if env_type in cnn_env_types else 'mlp'

def get_alg_module(alg, submodule=None):
    """Import and return ``<pkg>.<alg>.<submodule>``.

    ``submodule`` defaults to ``alg``, so ``get_alg_module('ppo2')`` imports
    ``baselines.ppo2.ppo2``. The ``baselines`` package is tried first; if that
    raises ``ImportError``, the same path is imported from ``rl_algs`` and any
    ``ImportError`` from that second attempt propagates to the caller.
    """
    leaf = submodule or alg
    fallback_pkg = 'rl_' + 'algs'
    try:
        return import_module('.'.join(['baselines', alg, leaf]))
    except ImportError:
        return import_module('.'.join([fallback_pkg, alg, leaf]))


def get_learn_function(alg):
    """Return the ``learn`` callable exported by the main module of ``alg``."""
    alg_module = get_alg_module(alg)
    return alg_module.learn


def get_learn_function_defaults(alg, env_type):
    """Return default ``learn`` keyword arguments for ``alg`` on ``env_type``.

    Looks up a function named ``env_type`` in the algorithm's ``defaults``
    module (imported via ``get_alg_module``) and returns the result of
    calling it. Returns ``{}`` when the algorithm has no ``defaults`` module
    or that module has no attribute named ``env_type``.
    """
    try:
        alg_defaults = get_alg_module(alg, 'defaults')
        defaults_for_env = getattr(alg_defaults, env_type)
    except (ImportError, AttributeError):
        return {}
    # Called outside the try so that errors raised *inside* the defaults
    # function surface instead of silently falling back to no defaults.
    return defaults_for_env()


def main(affinity_code, log_dir, run_ID, **kwargs):
    """exptools entry point: pin the GPU, configure logging, run the experiment.

    Args:
        affinity_code: exptools affinity code; its ``cuda_idx`` entry selects
            which of the currently visible CUDA devices this run uses.
        log_dir: experiment directory the variant is loaded from; also the
            baselines logger output directory.
        run_ID: run identifier forwarded to the exptools logger context and
            ``run_experiment``.
        **kwargs: accepted for launcher compatibility; unused.
    """
    ### exptools starting
    args = load_variant(log_dir)
    args.save_path = os.path.join(log_dir, "snapshot_final")

    # Restrict this process to one GPU. ``cuda_idx`` indexes the currently
    # visible devices; when CUDA_VISIBLE_DEVICES is unset every device is
    # visible, so the index is the device id itself (previously a KeyError).
    gpu_idx = affinity_from_code(affinity_code)["cuda_idx"]
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices.split(",")[gpu_idx]

    # Rank 0 (or the single process when mpi4py is unavailable) logs without a
    # suffix; other MPI ranks write to per-rank suffixed files.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    log_suffix = "" if rank == 0 else "run{}".format(rank)
    logger.configure(dir= log_dir, format_strs=["tensorboard", "csv"], log_suffix=log_suffix)

    with expContext.logger_context(log_dir, run_ID, "Attack-{}".format(args.env), log_params= args):
        run_experiment(args, log_dir, run_ID)

def _evaluate_attacker(model, env, n_episodes, render):
    """Roll out ``model`` on ``env`` for ``n_episodes`` episodes.

    Returns ``(attack_returns, env_returns, frames)``: per-episode sums of the
    reward returned by ``env.step`` and of ``env.get_env_reward()``, plus the
    ``env.render(mode='rgb_array')`` frames when ``render`` is truthy
    (otherwise an empty list).
    """
    attack_returns = []
    env_returns = []
    frames = []
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        tot_att_r = 0.
        tot_env_r = 0.
        while not done:
            print('obs:', obs)
            attack_act, _, _, _ = model.step(obs)
            print('attack action:', attack_act)
            obs, r, done, _ = env.step(attack_act)
            print('reward:', r)
            if render:
                frames.append(env.render(mode='rgb_array'))
            tot_att_r += r
            tot_env_r += env.get_env_reward()
        attack_returns.append(tot_att_r)
        env_returns.append(tot_env_r)
    return attack_returns, env_returns, frames


def run_experiment(args, log_dir, run_ID, n_eval_episodes=20):
    """Train an attacker, save it, evaluate it, then save the TF variables.

    Args:
        args: experiment variant (see ``train``); additionally reads ``seed``,
            ``algo_kwargs``, ``render`` and ``save_path``.
        log_dir: directory under which the attacker model is saved
            (``<log_dir>/run_0/attacker_model``).
        run_ID: run identifier; not used by this function.
        n_eval_episodes: number of evaluation rollouts after training.

    Returns:
        The trained model returned by ``train``.
    """
    tf.random.set_random_seed(args.seed)
    np.random.seed(args.seed)

    model, env = train(args, args.algo_kwargs)
    try:
        # Only the root process saves. mpi4py is optional (see the guarded
        # import), and without it this is the only process; previously this
        # line crashed with AttributeError when MPI was None.
        if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
            save_path = osp.expanduser(log_dir)
            # NOTE(review): "run_0" is hard-coded regardless of run_ID — confirm intended.
            model.save(os.path.join(save_path, "run_0", "attacker_model"))

        # test
        print('rollouts!')
        student_returns, env_returns, frames = _evaluate_attacker(model, env, n_eval_episodes, args.render)
        print('student return mean', np.mean(student_returns))
        print('student return std', np.std(student_returns))
        print('env return mean', np.mean(env_returns))
        print('env return std', np.std(env_returns))

        # if args.render:
        #     video_dir = args.env + "_attacked.mp4"
        #     save_sequences(frames, export_dir=video_dir, fps=30)
        #     print("video saved!")

        print('trainable variables:')
        for var in tf.trainable_variables():
            print(var.name)

        if args.save_path is not None:
            save_variables(osp.expanduser(args.save_path))
    finally:
        # Release the environment even if saving or evaluation fails.
        env.close()

    return model

if __name__ == '__main__':
    # Positional command-line arguments map onto main(affinity_code, log_dir, run_ID, ...).
    cli_args = sys.argv[1:]
    main(*cli_args)
