import tensorflow as tf
# from .dagger.utils import *
import argparse
import numpy as np
import sys
sys.path.append('..')
sys.path.append('../openai_baselines')
sys.path.append('/home/jianingq/research_tool/visualization/')

from save import save_sequences
from aril_utils import *
from norm_env import *
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def calculate_grad(student_obs, expert_action, student_action):
    """Build the FGSM perturbation direction for the student's observation.

    The loss is the mean squared error between ``student_action`` and
    ``expert_action``; its gradient is taken with respect to ``student_obs``
    and reduced to its element-wise sign.

    Args:
        student_obs: tensor the student action is computed from.
        expert_action: target action tensor.
        student_action: student action tensor, a function of ``student_obs``.

    Returns:
        Tensor of the same shape as ``student_obs`` holding -1, 0 or +1.
    """
    squared_error = tf.square(student_action - expert_action)
    # tf.gradients returns one gradient per entry of xs; there is exactly one.
    (obs_grad,) = tf.gradients(tf.reduce_mean(squared_error), [student_obs])
    return tf.math.sign(obs_grad)


# Per-environment defaults: checkpoint paths, upper end of the epsilon sweep
# (with and without --hacked), and ``pos_len``: in --zero_order mode only the
# first ``pos_len`` observation entries keep a non-zero attack sign
# (presumably the position part of the observation -- confirm).
_ENV_CONFIG = {
    "Swimmer-v2": {
        "expert_path": "../openai_baselines/models/swimmer_20M_ppo2_norm",
        "student_path": "../dagger/student_models/swimmer_student_new",
        "epsilon_max_hacked": 0.02,
        "epsilon_max": 0.4,
        "pos_len": 3,
    },
    "Hopper-v2": {
        "expert_path": "../openai_baselines/models/hopper_20M_ppo2_norm",
        "student_path": "../dagger/student_models/hopper_student_new",
        "epsilon_max_hacked": 0.01,
        "epsilon_max": 0.06,
        "pos_len": 5,
    },
    "Walker2d-v2": {
        "expert_path": "../openai_baselines/models/walker_20M_ppo2_norm",
        "student_path": "../dagger/student_models/walker_student_new",
        "epsilon_max_hacked": 0.001,
        "epsilon_max": 0.01,
        "pos_len": 8,
    },
    "HalfCheetah-v2": {
        "expert_path": "../openai_baselines/models/cheetah_20M_ppo2_norm",
        "student_path": "../dagger/student_models/cheetah_student_new",
        "epsilon_max_hacked": 0.05,
        "epsilon_max": 0.2,
        "pos_len": 8,
    },
}

# Lower end of the epsilon sweep; identical for every environment.
_EPSILON_MIN = 0
# Number of evenly spaced epsilon values evaluated in [min, max].
_NUM_EPSILONS = 20


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into True, so ``--hacked False`` used to enable the option.
    The empty string still maps to False, as it did with ``bool``.

    Raises:
        argparse.ArgumentTypeError: for unrecognised spellings.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def _parse_args(argv=None):
    """Parse command-line options and fill in per-environment defaults.

    Args:
        argv: argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``(args, epsilon_min, epsilon_max)``. ``args.expert_path`` and
        ``args.student_path`` fall back to the ``_ENV_CONFIG`` entry for
        ``args.env`` when not given on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--render", type=_str2bool, default=False)
    parser.add_argument("--num_samples", type=int, default=50)
    parser.add_argument("--env", type=str, default="Swimmer-v2",
                        choices=sorted(_ENV_CONFIG))
    parser.add_argument('--seed', type=int, default=1111)
    parser.add_argument('--hacked', type=_str2bool, default=False)
    parser.add_argument('--zero_order', type=_str2bool, default=False)
    # Declared before parsing (default None) so explicit paths are accepted;
    # the env-specific defaults are applied below once --env is known.
    parser.add_argument("--expert_path", type=str, default=None)
    parser.add_argument("--student_path", type=str, default=None)

    args = parser.parse_args(argv)

    config = _ENV_CONFIG[args.env]
    if args.expert_path is None:
        args.expert_path = config["expert_path"]
    if args.student_path is None:
        args.student_path = config["student_path"]
    epsilon_max = (config["epsilon_max_hacked"] if args.hacked
                   else config["epsilon_max"])
    return args, _EPSILON_MIN, epsilon_max


def _build_graph(args):
    """Build placeholders, student/expert action ops and the FGSM sign op.

    Reads ``args.obs_dim``, ``args.act_dim`` and ``args.expert_path``.

    Returns:
        SimpleNamespace with ``expert_obs``, ``student_obs``,
        ``expert_action`` (placeholders), ``student_action``,
        ``cal_expert_action`` and ``sign`` (tensors).
    """
    from types import SimpleNamespace

    expert_obs = tf.placeholder(tf.float32, shape=[None, args.obs_dim])
    student_obs = tf.placeholder(tf.float32, shape=[None, args.obs_dim])
    expert_action = tf.placeholder(tf.float32, shape=[None, args.act_dim])

    # Expert log-std is read straight from the checkpoint.
    # NOTE(review): joblib is not imported explicitly in this file;
    # presumably it arrives via one of the star imports -- confirm.
    para = joblib.load(args.expert_path)
    expert_pi_std = np.exp(para['ppo2_model/pi/logstd:0']).astype(np.float32)

    student_action_mean = student_pi(student_obs, args.act_dim, 'student')
    student_action_logstd = tf.Variable(
        np.zeros(args.act_dim), name='student/logstd', dtype=tf.float32)
    student_action_std = tf.exp(student_action_logstd)
    student_action = student_action_mean + student_action_std * tf.random_normal(
        tf.shape(student_action_mean))

    # Sample the expert's Gaussian noise inside the graph so every sess.run
    # draws fresh noise (sampling with NumPy here froze a single draw).
    expert_action_mean = pi(expert_obs, args.act_dim, 'ppo2_model/pi')
    cal_expert_action = expert_action_mean + tf.random_normal(
        tf.shape(expert_action_mean)) * expert_pi_std

    sign = calculate_grad(student_obs, expert_action, student_action)

    return SimpleNamespace(
        expert_obs=expert_obs,
        student_obs=student_obs,
        expert_action=expert_action,
        student_action=student_action,
        cal_expert_action=cal_expert_action,
        sign=sign,
    )


def _run_episode_pair(sess, graph, student_env, attack_env, expert_env,
                      epsilon, args, pos_len):
    """Run one clean and one attacked student episode side by side.

    Both envs are reset and stepped in lockstep; an env that finishes first
    simply stops being stepped until the other is done.

    Returns:
        ``(student_return, attack_return, action_dist)``, where
        ``action_dist`` sums, over the attacked episode, the per-step mean
        squared difference between the sampled expert action and the
        attacked student action.
    """
    student_state = student_env.reset()
    attack_state = attack_env.reset()
    student_done = attack_done = False
    student_rew = attack_rew = 0
    action_dist = 0.0

    while not (student_done and attack_done):
        if not student_done:
            student_action_val = sess.run(
                graph.student_action,
                feed_dict={graph.student_obs: [student_state]})
            student_state, rew, student_done, _ = student_env.step(
                student_action_val[0])
            student_rew += rew

        if not attack_done:
            # The expert sees the unperturbed state, passed through
            # expert_env.get_obs (presumably normalisation with the loaded
            # obs_rms -- confirm).
            cal_expert_action_val = sess.run(
                graph.cal_expert_action,
                feed_dict={
                    graph.expert_obs: [expert_env.get_obs(attack_state)]
                })
            sign_val = sess.run(
                graph.sign,
                feed_dict={
                    graph.student_obs: [attack_state],
                    graph.expert_action: cal_expert_action_val
                })

            if args.zero_order:
                # Keep only the first pos_len entries of the attack sign.
                sign_val[0, pos_len:] = 0

            attacked_obs_val = attack_state + epsilon * sign_val
            if args.hacked:
                # NOTE(review): only the perturbation (not the perturbed
                # state) is passed to set_state -- confirm this matches
                # VanillaEnv.set_state's contract.
                attack_env.set_state(epsilon * sign_val)
            attacked_student_action_val = sess.run(
                graph.student_action,
                feed_dict={graph.student_obs: attacked_obs_val})

            action_dist += np.mean(
                np.square(cal_expert_action_val - attacked_student_action_val))

            attack_state, rew, attack_done, _ = attack_env.step(
                attacked_student_action_val[0])
            attack_rew += rew

    return student_rew, attack_rew, action_dist


def _plot_results(x, attack_mean, attack_std, student_mean, args):
    """Plot attacked vs. clean student returns against epsilon and save it.

    Writes ``./plot/<env>_<hacked>_il.png``, creating ``./plot`` if needed.
    """
    import os

    fig = plt.figure(figsize=(17, 10))
    # Colour is given only via the keyword: combining it with a 'k' format
    # string made matplotlib warn (the keyword already took precedence).
    plt.plot(x, attack_mean, color='#014f00')
    plt.fill_between(
        x,
        attack_mean - attack_std,
        attack_mean + attack_std,
        alpha=0.5,
        edgecolor='#08ff8b',
        facecolor='#08ff8b')
    plt.plot(x, student_mean, 'k')
    plt.plot(x, student_mean * 0.8, 'k')
    plt.plot(x, student_mean * 0.7, 'k')
    plt.legend(['FGSM_IL_{}'.format(args.hacked), 'No Attack', '0.8', '0.7'])
    plt.xlabel('Epsilon')
    plt.ylabel('Reward: {}'.format(args.env))
    os.makedirs('./plot', exist_ok=True)
    fig.savefig('./plot/{}_{}_il.png'.format(args.env, args.hacked))
    plt.close(fig)


def main(argv=None):
    """Sweep the FGSM attack strength against the student policy and plot it.

    For each of ``_NUM_EPSILONS`` evenly spaced epsilons in
    ``[epsilon_min, epsilon_max]``, run ``--num_samples`` clean/attacked
    episode pairs. Then plot the mean +/- std attacked return alongside the
    clean student return and its 0.8x / 0.7x reference lines.
    """
    args, epsilon_min, epsilon_max = _parse_args(argv)
    print("Environment:", args.env)

    # create environments
    expert_env = NormalizeEnv(args.env, update=False)
    student_env = VanillaEnv(args.env)
    attack_env = VanillaEnv(args.env)

    attack_env.seed(args.seed)
    student_env.seed(args.seed)
    expert_env.seed(args.seed)
    tf.random.set_random_seed(args.seed)
    np.random.seed(args.seed)

    expert_env.obs_rms.load(args.expert_path)

    args.obs_dim = attack_env.observation_space.shape[0]
    args.act_dim = attack_env.action_space.shape[0]

    graph = _build_graph(args)
    pos_len = _ENV_CONFIG[args.env]["pos_len"]

    x = np.linspace(epsilon_min, epsilon_max, num=_NUM_EPSILONS)
    attack_mean, attack_std, student_mean = [], [], []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore each policy from its own checkpoint, selecting trainable
        # variables by substring match on the variable name.
        student_vars = [
            v for v in tf.trainable_variables() if "student" in v.name
        ]
        load_variables(args.student_path, variables=student_vars, sess=sess)
        expert_vars = [
            v for v in tf.trainable_variables() if "ppo2_model/pi/" in v.name
        ]
        load_variables(args.expert_path, variables=expert_vars, sess=sess)

        for epsilon in x:
            student_returns = np.zeros(args.num_samples)
            attack_returns = np.zeros(args.num_samples)
            for i in range(args.num_samples):
                student_returns[i], attack_returns[i], _ = _run_episode_pair(
                    sess, graph, student_env, attack_env, expert_env,
                    epsilon, args, pos_len)
            attack_mean.append(np.mean(attack_returns))
            attack_std.append(np.std(attack_returns))
            student_mean.append(np.mean(student_returns))

    _plot_results(x, np.array(attack_mean), np.array(attack_std),
                  np.array(student_mean), args)


if __name__ == '__main__':
    main()
