import tensorflow as tf
# from .dagger.utils import *
import argparse
import numpy as np
import sys
sys.path.append('..')
sys.path.append('../openai_baselines')
sys.path.append('/home/jianingq/research_tool/visualization/')

from save import save_sequences
from aril_utils import *
from norm_env import *
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def calculate_grad(student_obs, student_action_mean_val,
                   student_action_std_val, student_action_logstd_val,
                   student_action):
    """Return the sign of d(-log p(action)) / d(student_obs).

    The log-density is that of a Gaussian with independent dimensions:
    mean ``student_action_mean_val``, standard deviation
    ``student_action_std_val`` and log standard deviation
    ``student_action_logstd_val``, evaluated at ``student_action``.

    Args:
        student_obs: tensor the gradient is taken with respect to.
        student_action_mean_val: mean used in the log-density.
        student_action_std_val: standard deviation used in the log-density.
        student_action_logstd_val: log standard deviation; summed over the
            last axis for the normalisation term.
        student_action: action tensor at which the density is evaluated.

    Returns:
        Tensor holding the element-wise sign of the gradient of the
        negative log-density with respect to ``student_obs``.
    """
    # Number of entries along the last axis of the action, as a float.
    action_dim = tf.to_float(tf.shape(student_action)[-1])

    # Squared, std-normalised distance between action and mean.
    scaled_diff = (student_action - student_action_mean_val) / student_action_std_val
    quadratic_term = - 0.5 * tf.reduce_sum(tf.square(scaled_diff), axis=-1)

    # Normalisation terms of the Gaussian log-density.
    constant_term = 0.5 * np.log(2.0 * np.pi) * action_dim
    logstd_term = tf.reduce_sum(student_action_logstd_val, axis=-1)

    log_probs = quadratic_term - constant_term - logstd_term

    # Only the direction of the gradient is kept.
    obs_gradient = tf.gradients(-log_probs, [student_obs])[0]
    return tf.math.sign(obs_gradient)


# Per-environment settings used by the attack sweep:
#   student_path       - default checkpoint of the student policy
#   epsilon_max        - largest epsilon swept when --hacked is off
#   epsilon_max_hacked - largest epsilon swept when --hacked is on
#   pos_len            - with --zero_order, perturbation entries from this
#                        index onward are zeroed
ENV_CONFIGS = {
    "Swimmer-v2": {
        "student_path": "../dagger/student_models/swimmer_student_new",
        "epsilon_max": 0.4,
        "epsilon_max_hacked": 0.1,
        "pos_len": 3,
    },
    "Hopper-v2": {
        "student_path": "../dagger/student_models/hopper_student_new",
        "epsilon_max": 0.06,
        "epsilon_max_hacked": 0.01,
        "pos_len": 5,
    },
    "Walker2d-v2": {
        "student_path": "../dagger/student_models/walker_student_new",
        "epsilon_max": 0.06,
        "epsilon_max_hacked": 0.005,
        "pos_len": 8,
    },
    "HalfCheetah-v2": {
        "student_path": "../dagger/student_models/cheetah_student_new",
        "epsilon_max": 0.2,
        "epsilon_max_hacked": 0.05,
        "pos_len": 8,
    },
}


def str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` in argparse maps every non-empty string (including
    "False") to True; this parser interprets the text instead.

    Args:
        value: a bool, or a string such as "True", "false", "1", "0".

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching bool("").
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def parse_cli_args(argv=None):
    """Parse the command line and fill in the per-environment defaults.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The argparse namespace. ``student_path`` is set to the default
        checkpoint of the selected environment when it was not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--render", type=str2bool, default=False)
    parser.add_argument("--num_samples", type=int, default=50)
    # Restricting the choices rejects an unsupported environment up front
    # instead of failing later on a missing setting.
    parser.add_argument("--env", type=str, default="Swimmer-v2",
                        choices=sorted(ENV_CONFIGS))
    parser.add_argument('--seed', type=int, default=2222)
    parser.add_argument('--hacked', type=str2bool, default=False)
    parser.add_argument('--zero_order', type=str2bool, default=False)
    # Registered before parsing so that it can actually be passed on the
    # command line; its default depends on --env and is resolved below.
    parser.add_argument("--student_path", type=str, default=None)

    args = parser.parse_args(argv)
    if args.student_path is None:
        args.student_path = ENV_CONFIGS[args.env]["student_path"]
    return args


if __name__ == '__main__':
    import os

    args = parse_cli_args()
    env_config = ENV_CONFIGS[args.env]

    # Range of attack magnitudes to sweep.
    epsilon_min = 0
    if args.hacked:
        epsilon_max = env_config["epsilon_max_hacked"]
    else:
        epsilon_max = env_config["epsilon_max"]
    pos_len = env_config["pos_len"]

    print("Environment:", args.env)

    # create environment: one copy runs the policy on clean observations,
    # the other on perturbed observations
    student_env = VanillaEnv(args.env)
    attack_env = VanillaEnv(args.env)

    attack_env.seed(args.seed)
    student_env.seed(args.seed)
    tf.random.set_random_seed(args.seed)
    np.random.seed(args.seed)

    obs_dim = attack_env.observation_space.shape[0]
    act_dim = attack_env.action_space.shape[0]
    args.obs_dim = obs_dim
    args.act_dim = act_dim

    # define placeholder and networks
    student_obs = tf.placeholder(tf.float32, shape=[None, args.obs_dim])
    student_action_mean_val = tf.placeholder(
        tf.float32, shape=[None, args.act_dim])

    # calculate student action
    student_action_mean = student_pi(student_obs, args.act_dim, 'student')
    student_action_logstd = tf.Variable(
        np.zeros(args.act_dim), name='student/logstd', dtype=tf.float32)
    student_action_std = tf.exp(student_action_logstd)
    student_action = student_action_mean + student_action_std * tf.random_normal(
        tf.shape(student_action_mean))

    # sign of the loss gradient w.r.t. the observation (attack direction)
    sign = calculate_grad(student_obs, student_action_mean_val,
                          student_action_std, student_action_logstd,
                          student_action)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # load weights in student policy
        variables = [
            v for v in tf.trainable_variables() if "student" in v.name
        ]
        load_variables(args.student_path, variables=variables, sess=sess)

        attack_mean = []
        attack_std = []
        student_mean = []
        student_std = []
        x = np.linspace(epsilon_min, epsilon_max, num=15)

        for e in x:

            student_state = student_env.reset()
            attack_state = attack_env.reset()

            student_rew = 0
            attack_rew = 0
            student_done, attack_done = False, False

            # Episode returns collected for this epsilon.
            student_reward_buffer = np.zeros(args.num_samples)
            attack_reward_buffer = np.zeros(args.num_samples)

            for i in range(args.num_samples):
                # Step both environments until each has finished its episode.
                while True:
                    if not student_done:
                        student_action_val = sess.run(
                            student_action,
                            feed_dict={student_obs: [student_state]})
                        student_state, rew, student_done, __ = student_env.step(
                            student_action_val[0])
                        student_rew += rew

                    if not attack_done:
                        # Policy mean on the clean observation, fed back as
                        # the mean of the log-density being differentiated.
                        cal_student_action_mean_val = sess.run(
                            student_action_mean,
                            feed_dict={student_obs: [attack_state]})
                        sign_val = sess.run(
                            sign,
                            feed_dict={
                                student_obs: [attack_state],
                                student_action_mean_val:
                                cal_student_action_mean_val
                            })

                        if args.zero_order:
                            # Perturb only the first pos_len entries.
                            sign_val[0, pos_len:] = 0

                        attacked_obs_val = attack_state + e * sign_val
                        if args.hacked:
                            # NOTE(review): set_state is defined by the env
                            # wrapper; its effect is not visible here.
                            attack_env.set_state(e * sign_val)
                        attacked_student_action_val = sess.run(
                            student_action,
                            feed_dict={student_obs: attacked_obs_val})

                        attack_state, rew, attack_done, __ = attack_env.step(
                            attacked_student_action_val[0])
                        attack_rew += rew

                    if attack_done and student_done:
                        student_reward_buffer[i] = student_rew
                        attack_reward_buffer[i] = attack_rew
                        student_rew = 0
                        attack_rew = 0
                        student_state, attack_state = student_env.reset(
                        ), attack_env.reset()
                        student_done, attack_done = False, False

                        break
            attack_mean.append(np.mean(attack_reward_buffer))
            attack_std.append(np.std(attack_reward_buffer))
            student_mean.append(np.mean(student_reward_buffer))
            student_std.append(np.std(student_reward_buffer))

        attack_mean = np.array(attack_mean)
        attack_std = np.array(attack_std)
        student_mean = np.array(student_mean)
        student_std = np.array(student_std)

        fig = plt.figure(figsize=(17, 10))
        # Labels are attached to the lines themselves so the legend does not
        # depend on how the filled band is ordered among the axes' artists.
        plt.plot(x, attack_mean, color='#014f00',
                 label='FGSM_RL_{}'.format(args.hacked))
        plt.fill_between(
            x,
            attack_mean - attack_std,
            attack_mean + attack_std,
            alpha=0.5,
            edgecolor='#08ff8b',
            facecolor='#08ff8b')
        plt.plot(x, student_mean, 'k', label='No Attack')
        plt.plot(x, student_mean * 0.8, 'k', label='0.8')
        plt.plot(x, student_mean * 0.7, 'k', label='0.7')
        plt.legend()
        plt.xlabel('Epsilon')
        plt.ylabel('Reward: {}'.format(args.env))

        # Create the output directory so the sweep's result is not lost to a
        # missing folder.
        plot_dir = './plot'
        if not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)
        # NOTE(review): the file name hard-codes "seed_7777" regardless of
        # --seed; kept unchanged so existing output paths stay the same.
        fig.savefig('./plot/{}_{}_rl_seed_7777.png'.format(args.env, args.hacked))
