import sys
sys.path.append('../')
sys.path.append('../openai_baselines')
sys.path.append('../openai_baselines/baselines/common/vec_env')

import re
from shutil import copyfile
import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
import joblib
import os

from aril_env import AttackEnv
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
# from baselines.common.vec_env.my_attack_env import AttackEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session, get_available_gpus, save_variables
from baselines import logger
from importlib import import_module

from exptools.logging import logger as expLogger
from exptools.logging import context as expContext

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

# Map each env "type" to the set of registered gym env ids of that type. The type
# is the last dotted component of the module part of the spec's entry point,
# e.g. an entry point of the form "pkg.envs.mujoco:SomeEnv" yields "mujoco".
# NOTE(review): assumes every spec's entry_point is a "module:attr" string — TODO
# confirm for specs registered with a callable entry point.
# NOTE(review): the loop variables `env` and `env_type` stay defined at module level.
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env.entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)

# Retro games are listed by hand: reading benchmark names directly from retro
# requires importing retro here, and for some reason that crashes TensorFlow
# on Ubuntu.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    """Build the attack environment and run the attacker's learning algorithm on it.

    The algorithm is ``args.alg``. Its per-env-type defaults are overridden by
    ``extra_args`` and, when set, by ``args.network``. Learning runs for
    ``int(args.num_timesteps)`` steps inside the ``"attacker"`` variable scope,
    opened with ``reuse=False``.

    Returns:
        (model, env): the value returned by ``learn`` and the environment it used.
    """
    env_type, _ = get_env_type(args)
    print('env_type: {}'.format(env_type))

    timesteps = int(args.num_timesteps)
    seed = args.seed

    learn_fn = get_learn_function(args.alg)
    kwargs = get_learn_function_defaults(args.alg, env_type)
    kwargs.update(extra_args)

    env = build_env(args)
    if args.save_video_interval != 0:
        # The recording trigger fires on steps that are multiples of
        # save_video_interval; videos go under <logger dir>/videos.
        env = VecVideoRecorder(
            env,
            osp.join(logger.get_dir(), "videos"),
            record_video_trigger=lambda step: step % args.save_video_interval == 0,
            video_length=args.save_video_length,
        )

    # An explicit --network wins; otherwise fall back to the env-type default
    # only if the algorithm defaults did not name one.
    if args.network:
        kwargs['network'] = args.network
    elif kwargs.get('network') is None:
        kwargs['network'] = get_default_network(env_type)

    # Debug aid: list the trainable variables that exist before the attacker is built.
    print("trainable variables")
    print(tf.trainable_variables())

    with tf.variable_scope("attacker", reuse=False):
        model = learn_fn(env=env, seed=seed, total_timesteps=timesteps, **kwargs)

    return model, env


def my_train(env, args, extra_args):
    """Run the attacker's learning algorithm again on an already-built ``env``.

    Unlike ``train``, no environment is constructed. The learner runs inside the
    ``"attacker"`` variable scope with ``reuse=True``, so its variables are
    expected to exist already (presumably created by ``train``).

    Returns:
        The value returned by the algorithm's ``learn`` function.
    """
    env_type = get_env_type(args)[0]
    timesteps = int(args.num_timesteps)
    seed = args.seed

    learn_fn = get_learn_function(args.alg)
    kwargs = get_learn_function_defaults(args.alg, env_type)
    kwargs.update(extra_args)

    # An explicit --network wins; otherwise fall back to the env-type default
    # only if the algorithm defaults did not name one.
    if args.network:
        kwargs['network'] = args.network
    elif kwargs.get('network') is None:
        kwargs['network'] = get_default_network(env_type)

    with tf.variable_scope("attacker", reuse=True):
        return learn_fn(env=env, seed=seed, total_timesteps=timesteps, **kwargs)


def build_env(args):
    """Construct the ``AttackEnv`` described by ``args``.

    Uses ``args.env`` as the environment id, and ``args.zero_order`` and
    ``args.seed`` as constructor options. It forwards ``args.env_kwargs`` as
    extra keyword arguments, and the env receives the session returned by
    ``get_session()``.
    """
    # The result is unused; the call is kept for its side effects: it refreshes
    # the env-type registry and rejects an unrecognized args.env when
    # args.env_type is not given.
    get_env_type(args)

    return AttackEnv(
        get_session(),
        args.env,
        zero_order=args.zero_order,
        random_seed=args.seed,
        **args.env_kwargs,
    )


def get_env_type(args):
    """Resolve ``(env_type, env_id)`` for ``args.env``.

    Resolution order:
      1. If ``args.env_type`` is not None it is returned unchanged with ``args.env``.
      2. If ``args.env`` is itself an env type (a key of ``_game_envs``), that type
         is used and ``env_id`` becomes an arbitrary id registered under it.
      3. Otherwise the type whose id set contains ``args.env`` is used. An id
         containing ``':'`` overrides this with the text before the first ``':'``.

    Raises:
        AssertionError: if no env type can be determined.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since new envs may have been registered since import.
    for spec in gym.envs.registry.all():
        spec_type = spec.entry_point.split(':')[0].split('.')[-1]
        _game_envs[spec_type].add(spec.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        env_type = env_id
        # Sets are unordered: this picks an arbitrary id of that type.
        env_id = next(iter(_game_envs[env_type]))
    else:
        env_type = next((g for g, ids in _game_envs.items() if env_id in ids), None)
        if ':' in env_id:
            env_type = re.sub(r':.*', '', env_id)
        # The message now actually includes the known env types (the old format
        # string had a single placeholder and silently dropped them).
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(
            env_id, list(_game_envs.keys()))

    return env_type, env_id


def get_default_network(env_type):
    """Return the default policy network name for ``env_type``.

    ``'cnn'`` for the ``'atari'`` and ``'retro'`` env types, ``'mlp'`` for anything else.
    """
    uses_cnn = env_type in {'atari', 'retro'}
    return 'cnn' if uses_cnn else 'mlp'

def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    ``submodule`` defaults to ``alg``. The ``baselines`` package is tried first;
    on ImportError the ``rl_algs`` package is tried instead.

    Raises:
        ImportError: if the fallback import fails as well.
    """
    leaf = submodule or alg
    try:
        return import_module('.'.join(['baselines', alg, leaf]))
    except ImportError:
        # Fall back to the rl_algs package layout.
        return import_module('.'.join(['rl_' + 'algs', alg, leaf]))


def get_learn_function(alg):
    """Return the ``learn`` callable exported by algorithm ``alg``'s main module."""
    alg_module = get_alg_module(alg)
    return alg_module.learn


def get_learn_function_defaults(alg, env_type):
    """Return default ``learn`` keyword arguments for ``alg`` on ``env_type``.

    The defaults come from calling the function named ``env_type`` in the
    algorithm's ``defaults`` module. If that module or function is missing, or
    the lookup or call raises ImportError or AttributeError, ``{}`` is returned.
    """
    try:
        defaults_module = get_alg_module(alg, 'defaults')
        make_defaults = getattr(defaults_module, env_type)
        return make_defaults()
    except (ImportError, AttributeError):
        return {}

# import `exptools` stuff
from exptools.collections import AttrDict
from exptools.launching.variant import load_variant
from exptools.launching.affinity import affinity_from_code

def main(affinity_code, log_dir, run_ID, **kwargs):
    """Entry point for one run: load the variant, pick the GPU, configure logging, run.

    Args:
        affinity_code: code decoded by ``affinity_from_code``; its ``"cuda_idx"``
            selects one entry of the current ``CUDA_VISIBLE_DEVICES`` list.
        log_dir: experiment directory the variant is loaded from and logs go to.
        run_ID: run identifier forwarded to the logger context and ``run_experiment``.
        **kwargs: accepted and ignored.
    """
    ### exptools starting
    args = load_variant(log_dir)
    args.save_path = os.path.join(log_dir, "snapshot_final")

    try:
        cuda_idx = affinity_from_code(affinity_code)["cuda_idx"]
        visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        os.environ["CUDA_VISIBLE_DEVICES"] = str(visible[cuda_idx])
    except Exception as err:
        # Deliberate best-effort: any failure here (e.g. the variable is unset)
        # hides all GPUs and the run continues on CPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        print("Assuming no GPU: ", str(err))

    # Rank 0 (or no MPI) logs with the plain suffix; other ranks append their rank.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    log_suffix = "openai" if rank == 0 else "openai{}".format(rank)
    logger.configure(dir=log_dir, format_strs=["tensorboard", "csv"], log_suffix=log_suffix)

    with expContext.logger_context(log_dir, run_ID, "ARIL-{}".format(args.env), log_params=args):
        run_experiment(args, log_dir, run_ID)

# (obs_dim, act_dim, attack_dim used when args.zero_order is set) per supported env.
_ENV_DIMS = {
    "Swimmer-v2": (8, 2, 3),
    "Hopper-v2": (11, 3, 5),
    "Walker2d-v2": (17, 6, 8),
    "HalfCheetah-v2": (17, 6, 8),
    "Ant-v2": (111, 8, 13),
}

# Hyper-parameters filled in when the loaded variant does not define them.
_RUN_DEFAULTS = (
    ("buffer_size", int(5e6)),
    ("dagger_itr", 700),
    # attacker training timesteps per my_train call (num of iterations for training TRPO attack)
    ("num_timesteps", int(1e6)),
    # num of supervised student updates per DAgger iteration
    ("student_steps", 1000),
    ("batch_size", 64),
    # num of rollouts per DAgger iteration collected for expert labeling
    ("num_rollouts", 40),
)


def _attack_action(model, obs, env, args):
    """Return the attacker's perturbation for ``obs``, clipped to [-env.epsilon, env.epsilon].

    When ``args.train_attacker`` is false, a zero vector of length
    ``args.attack_dim`` is returned instead.
    """
    if not args.train_attacker:
        return np.zeros(args.attack_dim)
    # The first element of model.step(...) holds the actions; [0] selects the
    # first (presumably only) environment.
    attack_act = model.step(obs)[0][0]
    return np.clip(attack_act, -env.epsilon, env.epsilon)


def _pad_attack(attack_act, args):
    """Zero-pad a zero-order attack from ``args.attack_dim`` up to ``args.obs_dim``."""
    if args.zero_order:
        return np.pad(attack_act, (0, args.obs_dim - args.attack_dim), 'constant')
    return attack_act


def _train_student(env, buffer_state, buffer_action, n_steps, batch_size):
    """Run ``n_steps`` student updates on minibatches sampled uniformly from the whole buffer."""
    for _ in range(n_steps):
        batch_i = np.random.randint(0, len(buffer_state), size=batch_size)
        env.train_step.run(feed_dict={
            env.student_obs: buffer_state[batch_i, :],
            env.teacher_y: buffer_action[batch_i, :],
        })


def _student_loss(env, buffer_state, buffer_action):
    """Evaluate ``env.loss_l2`` over the entire buffer."""
    return env.loss_l2.eval(feed_dict={env.student_obs: buffer_state, env.teacher_y: buffer_action})


def _fill_expert_buffer(env, buffer_state, buffer_action):
    """Fill both buffers with (observation, expert action) pairs from expert rollouts.

    Returns the returns of the episodes that finished while filling.
    """
    episode_rews = []
    episode_rew = 0.0
    obs = env.reset()
    for idx in range(len(buffer_state)):
        act = env.get_expert_action(obs)
        # Store the observation the expert acted on *before* stepping. Storing the
        # post-step observation would pair every state with the previous step's
        # action. The DAgger labeling below also records obs before stepping.
        buffer_state[idx, :] = obs
        buffer_action[idx, :] = act
        obs, rew, done, _ = env.expert_step(act)
        episode_rew += rew
        if done:
            episode_rews.append(episode_rew)
            obs = env.reset()
            episode_rew = 0.0
    return episode_rews


def _evaluate_student(env, model, args, attacked, n_episodes=10):
    """Roll out the student through ``env.step`` for ``n_episodes`` episodes.

    With ``attacked`` false, a zero perturbation is used instead of the attacker.
    Returns (per-episode sums of env.get_env_reward(), per-episode sums of the
    reward returned by env.step).
    """
    env_returns, att_returns = [], []
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        tot_env_r = 0.
        tot_att_r = 0.
        while not done:
            if attacked:
                attack_act = _attack_action(model, obs, env, args)
            else:
                attack_act = np.zeros(args.attack_dim)
            obs, rew, done, _ = env.step([attack_act])
            tot_env_r += env.get_env_reward()
            tot_att_r += rew
        env_returns.append(tot_env_r)
        att_returns.append(tot_att_r)
    return env_returns, att_returns


def _evaluate_backup_student(env, model, args, n_episodes=10):
    """Roll out ``env.backup_student_action`` on attacked observations.

    The backup student's actions are applied via ``env.attack_env.step``.
    Returns the per-episode environment returns.
    """
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        tot_env_r = 0.
        while not done:
            attack_act = _attack_action(model, obs, env, args)
            if args.train_attacker and args["env_kwargs"]["epsilon"] == 0.0:
                # Sanity check: with a zero budget the clipped attack must vanish.
                assert np.linalg.norm(attack_act) == 0.0, str(attack_act)
            attack_act = _pad_attack(attack_act, args)
            attacked_obs = obs + attack_act
            bk_student_action_val = env.sess.run(
                env.backup_student_action,
                feed_dict={env.backup_student_obs: attacked_obs},
            )
            if env.hacked:
                env.set_state(attack_act)
            obs, env_r, done, _ = env.attack_env.step(bk_student_action_val)
            # Re-add a leading batch dimension (presumably to match env.reset()'s output).
            obs = [obs]
            tot_env_r += env_r
        returns.append(tot_env_r)
    return returns


def _collect_rollouts(env, model, args):
    """Collect ``args.num_rollouts`` episodes for DAgger labeling.

    Each episode is driven by the expert or by the attacked student, with
    probability 0.5 each. Every clean observation and its attacked counterpart
    are recorded. Returns (states, attacked_states, attacked env returns,
    attack returns, expert returns).
    """
    states, attacked_states = [], []
    env_returns, att_returns, expert_returns = [], [], []
    for _ in range(args.num_rollouts):
        expert_driven = np.random.random() < 0.5
        obs = env.reset()
        done = False
        if expert_driven:
            expert_return = 0.0
            while not done:
                attack_act = _attack_action(model, obs, env, args)
                states.append(obs)
                attacked_states.append(obs + _pad_attack(attack_act, args))
                act = env.get_expert_action(obs)
                obs, rew, done, _ = env.expert_step(act)
                expert_return += rew[0]
            expert_returns.append(expert_return)
        else:
            tot_env_r = 0.
            tot_att_r = 0.
            while not done:
                attack_act = _attack_action(model, obs, env, args)
                states.append(obs)
                attacked_states.append(obs + _pad_attack(attack_act, args))
                # env.step receives the unpadded attack.
                obs, rew, done, _ = env.step([attack_act])
                tot_env_r += env.get_env_reward()
                tot_att_r += rew
            env_returns.append(tot_env_r)
            att_returns.append(tot_att_r)
    return states, attacked_states, env_returns, att_returns, expert_returns


def _aggregate_expert_labels(env, states, attacked_states, buffer_state, buffer_action, start_idx):
    """Label attacked observations with the expert's action on the clean ones.

    Pairs are written into the buffers as a ring starting at ``start_idx``.
    Returns (next write index, per-sample L2 norm of the change in the expert's
    action caused by the attack).
    """
    idx = start_idx
    expert_act_diff = []
    for obs, attacked_obs in zip(states, attacked_states):
        expert_act = env.get_expert_action(obs)
        # Only used to measure how much the attack changes the expert's own decision.
        attacked_expert_act = env.get_expert_action(attacked_obs)

        buffer_state[idx, :] = attacked_obs
        buffer_action[idx, :] = expert_act
        idx = (idx + 1) % len(buffer_state)

        expert_act_diff.append(np.linalg.norm(expert_act - attacked_expert_act))
    return idx, expert_act_diff


def _save_snapshot(run_dir, itr, log_kwargs):
    """Save all variables as ``snapshot-<itr>`` in ``run_dir`` and prune old snapshots.

    When ``run_dir`` holds more than ``n_file_kept`` entries, ``snapshot-<k>`` files
    with ``k <= itr - save_interval * n_file_kept`` are removed.
    """
    prefix = "snapshot-"
    snapshot_name = "snapshot-{:d}".format(itr)
    save_variables(os.path.join(run_dir, snapshot_name))

    n_file_kept = log_kwargs["n_file_kept"]
    if len(os.listdir(run_dir)) <= n_file_kept:
        return
    horizon = itr - log_kwargs["save_interval"] * n_file_kept
    for fname in os.listdir(run_dir):
        if fname == snapshot_name or not fname.startswith(prefix):
            continue
        suffix = fname[len(prefix):]
        # Ignore names that are not exactly "snapshot-<int>" instead of crashing in int().
        if suffix.isdigit() and int(suffix) <= horizon:
            file_to_remove = os.path.join(run_dir, fname)
            expLogger.log_text("Remove file: {}".format(file_to_remove))
            os.remove(file_to_remove)


def run_experiment(args, log_dir, run_ID):
    """Run one ARIL experiment: adversarial DAgger between an attacker and a student.

    Phases:
      1. Fill in missing hyper-parameters and, for known envs, the dimensions.
      2. Copy the expert snapshot into ``<log_dir>/run_<run_ID>``. Then build the
         attacker and env with a zero-timestep ``train`` call.
      3. Fill the buffer with expert (observation, action) pairs. Pretrain the
         student on it unless ``env_kwargs["student_path"]`` is set.
      4. For ``args.dagger_itr`` iterations:
         - evaluate the student;
         - from iteration 1 on, optionally train the attacker (``args.train_attacker``);
         - collect mixed expert/attacked rollouts;
         - from iteration 1 on, optionally (``args.train_student``) relabel the
           attacked observations with expert actions and retrain the student.
      5. Save periodic snapshots, and the per-iteration (mean, std) return summaries.

    Args:
        args: the variant (attribute- and item-accessible); mutated in place.
        log_dir: experiment directory.
        run_ID: run identifier; outputs go to ``<log_dir>/run_<run_ID>``.

    Returns:
        The last attacker model.
    """
    extra_args = args.algo_kwargs
    np.set_printoptions(precision=3)

    for name, value in _RUN_DEFAULTS:
        if not hasattr(args, name):
            setattr(args, name, value)

    if args.env in _ENV_DIMS:
        args.obs_dim, args.act_dim, zero_order_dim = _ENV_DIMS[args.env]
        args.attack_dim = zero_order_dim if args.zero_order else args.obs_dim

    run_dir = os.path.join(log_dir, "run_{}".format(run_ID))

    # Copy the pretrained expert into the run directory and point the env at the copy.
    expert_src = args["env_kwargs"]["expert_path"]
    expLogger.log_text("Copying expert snapshot from {}".format(expert_src))
    expert_file = os.path.join(run_dir, "expert_snapshot")
    copyfile(expert_src, expert_file)
    args["env_kwargs"]["expert_path"] = expert_file

    # Build the attacker agent and environment without training (0 timesteps),
    # then restore the configured budget for the later my_train calls.
    saved_timesteps = args.num_timesteps
    args.num_timesteps = 0
    model, env = train(args, extra_args)
    args.num_timesteps = saved_timesteps

    buffer_state = np.empty((args.buffer_size, args.obs_dim))
    buffer_action = np.empty((args.buffer_size, args.act_dim))

    # Per-iteration (mean, std) summaries, saved at the end of the run.
    student_returns = np.zeros((args.dagger_itr, 2))
    attack_returns = np.zeros((args.dagger_itr, 2))
    attacked_student_returns = np.zeros((args.dagger_itr, 2))

    print("---------------- initializing buffer -----------------")
    expert_rews = _fill_expert_buffer(env, buffer_state, buffer_action)
    # DAgger labels start overwriting the buffer from its beginning.
    cur_buffer_idx = 0
    expLogger.log_scalar_batch("expert_reward", expert_rews, 0)

    # Initialize the student unless `student_path` is given (then it is assumed trained).
    if args["env_kwargs"]["student_path"] is None:
        print("---------------- initializing student ----------------")
        _train_student(env, buffer_state, buffer_action, args.pretrain_iterations, args.batch_size)
    expLogger.log_scalar("student_init_loss", _student_loss(env, buffer_state, buffer_action), 0)

    log_kwargs = args["log_kwargs"]
    for i in range(args.dagger_itr):
        print("DAGGER ITERATION", i)
        print("------------------ testing student without attack -------------------")
        clean_returns, _ = _evaluate_student(env, model, args, attacked=False)
        print("student reward (without attack) mean {} std {}".format(
            np.mean(clean_returns), np.std(clean_returns)
        ))
        expLogger.log_scalar_batch("student_reward_wo_attack", clean_returns, i)
        student_returns[i] = np.mean(clean_returns), np.std(clean_returns)

        if args.train_attacker and i > 0:
            # train attack (RL step)
            print("---------------- training attack pi ----------------")
            model = my_train(env, args, extra_args)

        if hasattr(env, "backup_student_action"):
            print("------------------ test backup_student under attack -----------")
            bk_student_returns = _evaluate_backup_student(env, model, args)
            expLogger.log_scalar_batch("backup_return_w_attack", bk_student_returns, i)

        print("---------------- generating rollouts ----------------")
        (student_states, student_attacked_states, student_env_returns,
         student_att_returns, expert_returns) = _collect_rollouts(env, model, args)
        print("environment return mean {} std {}".format(np.mean(student_env_returns), np.std(student_env_returns)))
        print("attack return mean {} std {}".format(np.mean(student_att_returns), np.std(student_att_returns)))
        expLogger.log_scalar_batch("attacked_env_return", student_env_returns, i)
        expLogger.log_scalar_batch("attack_return", student_att_returns, i)
        expLogger.log_scalar_batch("expert_returns", expert_returns, i)
        attacked_student_returns[i] = np.mean(student_env_returns), np.std(student_env_returns)
        attack_returns[i] = np.mean(student_att_returns), np.std(student_att_returns)

        if args.train_student and i > 0:
            # expert labeling & data aggregation
            print("------------------ expert labeling ------------------")
            cur_buffer_idx, expert_act_diff = _aggregate_expert_labels(
                env, student_states, student_attacked_states,
                buffer_state, buffer_action, cur_buffer_idx,
            )
            expLogger.log_scalar_batch("expert_act_diff", expert_act_diff, i)

            print("----------------- training student -----------------")
            _train_student(env, buffer_state, buffer_action, args.student_steps, args.batch_size)
            expLogger.log_scalar("student_loss", _student_loss(env, buffer_state, buffer_action), i)

            print("---------------- testing student after defense -------------")
            defended_env_returns, defended_att_returns = _evaluate_student(env, model, args, attacked=True)
            expLogger.log_scalar_batch("defensed_env_return", defended_env_returns, i)
            expLogger.log_scalar_batch("defensed_attack_return", defended_att_returns, i)

        expLogger.dump_data()

        if i % log_kwargs["save_interval"] == 0 or i == (args.dagger_itr - 1):
            _save_snapshot(run_dir, i, log_kwargs)

    env.close()

    # Per-run path so different runs do not overwrite each other's results.
    # np.savez appends ".npz"; arrays are stored as arr_0, arr_1, arr_2.
    data_path = os.path.join(run_dir, "{}_data".format(args.env))
    np.savez(data_path, student_returns, attack_returns, attacked_student_returns)

    return model

if __name__ == "__main__":
    # Forward positional command-line arguments to main(affinity_code, log_dir, run_ID).
    main(*sys.argv[1:])
