import sys
sys.path.append('../')
sys.path.append('../openai_baselines')
sys.path.append('../openai_baselines/baselines/common/vec_env')

import re
import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
import joblib
import os

from aril_env import AttackEnv
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
# from baselines.common.vec_env.my_attack_env import AttackEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session, get_available_gpus, save_variables
from baselines import logger
from importlib import import_module

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

# Map each env "type" to the set of gym env ids registered under it. The type
# is the last dotted component of the module part (text before ':') of the
# env's entry_point string, e.g. 'gym.envs.mujoco:HopperEnv' -> 'mujoco'.
# NOTE(review): this loop leaves `env` and `env_type` bound as module-level
# names; get_env_type() repeats the same parse to pick up later registrations.
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env.entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)

# reading benchmark names directly from retro requires
# importing retro here, and for some reason that crashes tensorflow
# in ubuntu
# The hard-coded set below replaces (rather than extends) any 'retro' entry
# collected from the registry above.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    """Build the attack environment and run the attack algorithm's ``learn``.

    Args:
        args: parsed command-line namespace (reads ``env``, ``env_type``,
            ``alg``, ``network``, ``num_timesteps``, ``seed``, ``epsilon``,
            ``save_video_interval``, ``save_video_length`` and the fields
            ``build_env`` needs).
        extra_args: extra keyword arguments that override the algorithm's
            per-env-type defaults before being passed to ``learn``.

    Returns:
        ``(model, policy, holders, env)`` -- the three values returned by
        ``learn`` plus the environment it was run on.
    """
    env_type, _env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))

    timesteps = int(args.num_timesteps)
    run_seed = args.seed

    learn_fn = get_learn_function(args.alg)
    learn_kwargs = get_learn_function_defaults(args.alg, env_type)
    learn_kwargs.update(extra_args)

    env = build_env(args)
    if args.save_video_interval != 0:
        def _should_record(step):
            # Evaluated lazily, so it always uses the current interval.
            return step % args.save_video_interval == 0

        env = VecVideoRecorder(
            env,
            osp.join(logger.get_dir(), "videos"),
            record_video_trigger=_should_record,
            video_length=args.save_video_length,
        )

    # An explicit --network wins; otherwise fall back to the env-type default
    # only when the algorithm defaults did not already choose one.
    if args.network:
        learn_kwargs['network'] = args.network
    elif learn_kwargs.get('network') is None:
        learn_kwargs['network'] = get_default_network(env_type)

    print("trainable variables")
    print(tf.trainable_variables())

    model, policy, holders = learn_fn(
        env=env,
        seed=run_seed,
        total_timesteps=timesteps,
        epsilon=args.epsilon,
        **learn_kwargs
    )

    return model, policy, holders, env


def my_train(env, policy, holders, args, extra_args):
    """Run another round of the attack algorithm's ``learn`` on an existing env.

    Unlike ``train``, no environment is built here. ``policy`` and ``holders``
    (as previously returned by ``learn``) are passed back into ``learn`` --
    presumably so it continues with the already-built attacker instead of
    creating a new one; confirm against the algorithm's ``learn``.

    Args:
        env: the environment returned by ``train``.
        policy: policy object from the previous ``learn`` call.
        holders: placeholder object from the previous ``learn`` call.
        args: parsed command-line namespace.
        extra_args: keyword overrides applied on top of the algorithm defaults.

    Returns:
        ``(model, policy, holders)`` as returned by ``learn``.
    """
    env_type, _env_id = get_env_type(args)

    timesteps = int(args.num_timesteps)
    run_seed = args.seed

    learn_fn = get_learn_function(args.alg)
    learn_kwargs = get_learn_function_defaults(args.alg, env_type)
    learn_kwargs.update(extra_args)

    if args.network:
        learn_kwargs['network'] = args.network
    elif learn_kwargs.get('network') is None:
        learn_kwargs['network'] = get_default_network(env_type)

    model, policy, holders = learn_fn(
        env=env,
        seed=run_seed,
        policy=policy,
        holders=holders,
        total_timesteps=timesteps,
        epsilon=args.epsilon,
        **learn_kwargs
    )
    return model, policy, holders


def build_env(args):
    """Create the ``AttackEnv`` wrapping the pretrained expert for ``args.env``.

    Args:
        args: parsed command-line namespace; reads ``env``, ``env_type`` (via
            ``get_env_type``), ``hacked``, ``epsilon``, ``zero_order`` and
            ``seed``.

    Returns:
        The constructed ``AttackEnv``.

    Raises:
        ValueError: if ``args.env`` has no expert checkpoint listed below.
    """
    # Expert checkpoints; relative paths are resolved against the current
    # working directory by os.path.abspath.
    expert_models = {
        "Swimmer-v2": "openai_baselines/models/swimmer_20M_ppo2_norm",
        "Hopper-v2": "openai_baselines/models/hopper_20M_ppo2_norm",
        "Walker2d-v2": "openai_baselines/models/walker_20M_ppo2_norm",
        "HalfCheetah-v2": "openai_baselines/models/cheetah_20M_ppo2_norm",
    }

    # Called for its validation side effect: unless args.env_type is given,
    # get_env_type refreshes the env registry and asserts args.env is known.
    get_env_type(args)

    if args.env not in expert_models:
        raise ValueError('no expert model for env {!r}; supported envs: {}'.format(
            args.env, sorted(expert_models)))
    expert_path = os.path.abspath(expert_models[args.env])

    return AttackEnv(get_session(), args.env, expert_path, None, hacked=args.hacked,
                     epsilon=args.epsilon, zero_order=args.zero_order, random_seed=args.seed)


def get_env_type(args):
    """Resolve ``(env_type, env_id)`` for the env named by ``args.env``.

    Resolution order:
      * if ``args.env_type`` is set, it is returned unchanged with ``args.env``;
      * if ``args.env`` is itself an env type (a key of ``_game_envs``), the
        first env id of that type (set iteration order) is used as ``env_id``;
      * otherwise the type whose id set contains ``args.env`` is used, and an
        id of the form ``'module:Name'`` overrides it with ``'module'``.

    Side effect: re-reads the gym registry into the module-level
    ``_game_envs`` (except when ``args.env_type`` is given).

    Raises:
        AssertionError: if no env type can be determined for ``args.env``.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since we could have new envs since last time.
    for env in gym.envs.registry.all():
        env_type = env.entry_point.split(':')[0].split('.')[-1]
        _game_envs[env_type].add(env.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        # args.env named a whole env type; pick one id of that type.
        env_type = env_id
        env_id = next(iter(_game_envs[env_type]))
    else:
        env_type = None
        for g, e in _game_envs.items():
            if env_id in e:
                env_type = g
                break
        if ':' in env_id:
            env_type = re.sub(r':.*', '', env_id)
        # The original message had no placeholder for the known types, so they
        # were silently dropped from the error text.
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(
            env_id, list(_game_envs.keys()))

    return env_type, env_id


def get_default_network(env_type):
    """Return the default policy network name for ``env_type``.

    The 'atari' and 'retro' families use 'cnn'; every other type uses 'mlp'.
    """
    convolutional_types = {'atari', 'retro'}
    return 'cnn' if env_type in convolutional_types else 'mlp'

def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    ``submodule`` defaults to ``alg``. The ``baselines`` package is tried
    first; if that raises ImportError, the ``rl_algs`` package is tried and
    any ImportError from it propagates to the caller.
    """
    leaf = submodule or alg
    try:
        return import_module('.'.join(('baselines', alg, leaf)))
    except ImportError:
        return import_module('.'.join(('rl_' + 'algs', alg, leaf)))


def get_learn_function(alg):
    """Return the ``learn`` callable from algorithm ``alg``'s main module."""
    alg_module = get_alg_module(alg)
    learn_fn = alg_module.learn
    return learn_fn


def get_learn_function_defaults(alg, env_type):
    """Return default ``learn`` kwargs for ``alg`` on ``env_type``.

    Calls the function named ``env_type`` in ``alg``'s ``defaults`` module.
    Returns an empty dict when that module or function cannot be found (any
    ImportError/AttributeError, including one raised by the call itself).
    """
    try:
        defaults_module = get_alg_module(alg, 'defaults')
        defaults_factory = getattr(defaults_module, env_type)
        return defaults_factory()
    except (ImportError, AttributeError):
        return {}



def parse_cmdline_kwargs(args):
    '''
    Turn leftover '='-spaced command-line arguments into a kwargs dict.

    Each value is evaluated as a Python expression when possible (so numbers,
    lists, lambdas, ... become Python objects); values that raise NameError or
    SyntaxError are kept as the raw string.
    '''
    def to_python(v):
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line. Acceptable only because the invoking user controls argv; never
        # route untrusted input through here. The parameter is deliberately
        # named `v` and no other locals exist here, because eval sees this
        # frame's locals.
        assert isinstance(v, str)
        try:
            return eval(v)
        except (NameError, SyntaxError):
            return v

    parsed = {}
    for key, raw_value in parse_unknown_args(args).items():
        parsed[key] = to_python(raw_value)
    return parsed



def _fill_expert_buffer(env, buffer_state, buffer_action):
    """Fill every buffer row with (observation, expert action) pairs from expert-driven episodes.

    Returns the list of returns of the episodes that finished while filling.
    """
    obs = env.reset()
    episode_rew = 0.0
    episode_rews = []
    for idx in range(buffer_state.shape[0]):
        act = env.get_expert_action(obs)
        # Record the observation the action was computed from *before*
        # stepping; storing the post-step observation would pair every action
        # with the wrong state.
        buffer_state[idx, :] = obs
        buffer_action[idx, :] = act
        obs, rew, done, __ = env.expert_step(act)
        episode_rew += rew
        if done:
            episode_rews.append(episode_rew)
            obs = env.reset()
            episode_rew = 0.0
    return episode_rews


def _train_student(env, buffer_state, buffer_action, num_steps, batch_size):
    """Run ``num_steps`` student updates on uniformly sampled buffer rows, then print the whole-buffer loss."""
    buffer_size = buffer_state.shape[0]
    for _ in range(num_steps):
        batch_i = np.random.randint(0, buffer_size, size=batch_size)
        env.train_step.run(feed_dict={env.student_obs: buffer_state[batch_i, :],
                                      env.teacher_y: buffer_action[batch_i, :]})
    print('obj value is ', env.loss_l2.eval(feed_dict={env.student_obs: buffer_state,
                                                       env.teacher_y: buffer_action}))


def _evaluate_student(env, attack_dim, num_episodes=10):
    """Return the env returns of ``num_episodes`` episodes stepped with an all-zero attack."""
    no_attack = np.zeros(attack_dim)
    returns = []
    for _ in range(num_episodes):
        env.reset()
        done = False
        tot_env_r = 0.
        while not done:
            __, __, done, __ = env.step([no_attack])
            tot_env_r += env.get_env_reward()
        returns.append(tot_env_r)
    return returns


def _collect_rollouts(env, model, args):
    """Run ``args.num_rollouts`` episodes, each driven (50/50) by the expert or by the attacked env.

    Every visited clean observation and its attacked version are recorded.
    Returns ``(states, attacked_states, env_returns, attack_returns)``; the
    two return lists only cover the attacked (non-expert) episodes.
    """
    # Zero-order attacks perturb only the first attack_dim observation
    # entries; pad them with zeros to the full observation size.
    pad_width = (0, args.obs_dim - args.attack_dim)
    states, attacked_states = [], []
    env_returns, attack_returns = [], []
    for _ in range(args.num_rollouts):
        expert_driven = np.random.random() < 0.5
        obs = env.reset()
        done = False
        tot_env_r = 0.
        tot_att_r = 0.
        while not done:
            attack_act, __, __, __ = model.step(obs)
            attack_act = attack_act[0]
            if args.zero_order:
                perturbation = np.pad(attack_act, pad_width, 'constant')
            else:
                perturbation = attack_act
            states.append(obs)
            attacked_states.append(obs + perturbation)
            if expert_driven:
                act = env.get_expert_action(obs)
                obs, rew, done, __ = env.expert_step(act)
            else:
                # The env receives the unpadded attack action.
                obs, rew, done, __ = env.step([attack_act])
                tot_env_r += env.get_env_reward()
                tot_att_r += rew
        if not expert_driven:
            env_returns.append(tot_env_r)
            attack_returns.append(tot_att_r)
    return states, attacked_states, env_returns, attack_returns


def main(args):
    """Train the student in DAgger manner while an RL attacker learns to perturb its observations (ARIL).

    Flow:
      1. Parse ``args`` and set several hyperparameters to fixed values.
      2. Build the attack environment and TRPO attack agent (0 training steps).
      3. Fill a buffer with expert (state, action) pairs and pretrain the
         student on it for 2000 steps.
      4. For ``args.dagger_itr`` iterations: evaluate the student without
         attack, train the attacker, collect rollouts, label every recorded
         attacked state with the expert's action on the clean state, and
         train the student on the aggregated buffer.
      5. Save TF variables (if ``--save_path`` is given) and the
         per-iteration (mean, std) return statistics with ``np.savez``.

    Args:
        args: argv-style list of strings (the entry point passes
            ``sys.argv``), parsed with ``common_arg_parser``. Unknown
            ``key=value`` options are forwarded to the algorithm's ``learn``.

    Returns:
        The ``model`` returned by the most recent ``train``/``my_train`` call.

    Raises:
        ValueError: if ``args.env`` is not one of the envs configured below.
    """
    tf.reset_default_graph()

    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    np.set_printoptions(precision=3)

    # Hyperparameters set unconditionally (not taken from the command line).
    args.buffer_size = int(5e6)
    args.dagger_itr = int(700)
    # num of student training iterations (supervised learning) per DAgger iteration
    args.student_steps = 1000
    args.batch_size = 64
    # num of rollouts for expert labeling
    args.num_rollouts = 40

    # env id -> (obs_dim, act_dim, attack_dim used in zero-order mode)
    env_dims = {
        "Swimmer-v2": (8, 2, 3),
        "Hopper-v2": (11, 3, 5),
        "Walker2d-v2": (17, 6, 8),
        "HalfCheetah-v2": (17, 6, 8),
    }
    if args.env not in env_dims:
        raise ValueError('unsupported env {!r}; expected one of {}'.format(
            args.env, sorted(env_dims)))
    args.obs_dim, args.act_dim, zero_order_dim = env_dims[args.env]
    args.attack_dim = zero_order_dim if args.zero_order else args.obs_dim

    # initialize trpo agent and environment without training (0 timesteps)
    args.num_timesteps = 0
    model, policy, holders, env = train(args, extra_args)
    # num of timesteps for each subsequent TRPO attack training round
    args.num_timesteps = int(1e6)

    buffer_state = np.empty((args.buffer_size, args.obs_dim))
    buffer_action = np.empty((args.buffer_size, args.act_dim))

    # Per-iteration (mean, std) statistics, written out at the end.
    student_returns = np.zeros((args.dagger_itr, 2))
    attack_returns = np.zeros((args.dagger_itr, 2))
    attacked_student_returns = np.zeros((args.dagger_itr, 2))

    print("---------------- initializing buffer -----------------")
    episode_rews = _fill_expert_buffer(env, buffer_state, buffer_action)
    print("Expert reward mean {} std {}".format(
        np.mean(episode_rews), np.std(episode_rews)
    ))

    print("---------------- initializing student ----------------")
    _train_student(env, buffer_state, buffer_action, 2000, args.batch_size)

    # The buffer is full now; aggregated data overwrites it as a ring buffer
    # starting again at row 0.
    cur_buffer_idx = 0
    for i in range(args.dagger_itr):
        print("DAGGER ITERATION", i)
        print("------------------ testing student -------------------")
        student_test_return = _evaluate_student(env, args.attack_dim)
        print("student reward (without attack) mean {} std {}".format(
            np.mean(student_test_return), np.std(student_test_return)
        ))
        student_returns[i] = np.mean(student_test_return), np.std(student_test_return)

        # train attack (RL step)
        print("---------------- training attack pi ----------------")
        model, policy, holders = my_train(env, policy, holders, args, extra_args)

        print("---------------- generating rollouts ----------------")
        (student_states, student_attacked_states,
         student_env_returns, student_att_returns) = _collect_rollouts(env, model, args)
        print("environment return mean {} std {}".format(np.mean(student_env_returns), np.std(student_env_returns)))
        print("attack return mean {} std {}".format(np.mean(student_att_returns), np.std(student_att_returns)))
        attacked_student_returns[i] = np.mean(student_env_returns), np.std(student_env_returns)
        attack_returns[i] = np.mean(student_att_returns), np.std(student_att_returns)

        # expert labeling & data aggregation: pair each attacked observation
        # with the expert's action on the corresponding clean observation.
        print("------------------ expert labeling ------------------")
        for obs, attacked_obs in zip(student_states, student_attacked_states):
            buffer_state[cur_buffer_idx, :] = attacked_obs
            buffer_action[cur_buffer_idx, :] = env.get_expert_action(obs)
            cur_buffer_idx = (cur_buffer_idx + 1) % args.buffer_size

        print("----------------- training student -----------------")
        _train_student(env, buffer_state, buffer_action, args.student_steps, args.batch_size)

    if args.save_path is not None:
        save_path = osp.expanduser(args.save_path)
        save_variables(save_path)

    env.close()

    # The last character of save_path (presumably a run index) tags the data
    # file; use an empty tag rather than crashing when no save_path was given.
    run_tag = args.save_path[-1] if args.save_path else ''
    data_path = './data' + args.env + '_data_' + run_tag
    np.savez(data_path, student_returns, attack_returns, attacked_student_returns)

    return model

if __name__ == '__main__':
    # Optional remote attach for debugging. Set ARIL_REMOTE_DEBUG=1 to make
    # the process listen for a ptvsd debugger on 0.0.0.0:6789 and pause until
    # one attaches. Opt-in because the listener accepts connections on all
    # network interfaces and the wait blocks the run until a debugger attaches.
    if os.environ.get('ARIL_REMOTE_DEBUG', '') not in ('', '0'):
        import ptvsd
        ip_address = ('0.0.0.0', 6789)
        print("Process: " + " ".join(sys.argv[:]))
        print("Is waiting for attach at address: %s:%d" % ip_address, flush=True)
        # Allow other computers to attach to ptvsd at this IP address and port.
        ptvsd.enable_attach(address=ip_address, redirect_output=True)
        # Pause the program until a remote debugger is attached
        ptvsd.wait_for_attach()
        print("Process attached, start running into experiment...", flush=True)
        ptvsd.break_into_debugger()
    main(sys.argv)
