import tensorflow as tf
# from .dagger.utils import *
import argparse
import numpy as np
import sys
sys.path.append('..')
sys.path.append('../openai_baselines')
sys.path.append('/home/jianingq/research_tool/visualization/')

from aril_utils import *
from norm_env import *


def calculate_grad(student_obs, expert_action, student_action):
    """Build the sign of the action-mismatch gradient w.r.t. the observation.

    The mismatch is the mean squared difference between ``student_action``
    and ``expert_action``. Its gradient is taken with respect to
    ``student_obs``, and only the element-wise sign is kept; this is the
    direction used to perturb the student's observation.

    Args:
        student_obs: input tensor (a placeholder in this script) from which
            ``student_action`` was built.
        expert_action: tensor/placeholder holding the reference action.
        student_action: tensor produced by the student policy.

    Returns:
        A tensor shaped like ``student_obs`` with entries in {-1, 0, 1}.
    """
    mismatch = tf.reduce_mean(tf.square(student_action - expert_action))
    # tf.gradients returns one gradient per entry of xs; there is exactly one.
    (obs_grad,) = tf.gradients(mismatch, [student_obs])
    return tf.math.sign(obs_grad)


# Per-environment settings, keyed by gym id:
#   prefix  - short name used to build the default model/plot paths
#   pos_len - with --zero_order, entries of the sign vector from index pos_len
#             onward are zeroed, so only the first pos_len observation entries
#             are perturbed (presumably the position part of the observation
#             -- TODO confirm against the env definitions)
_ENV_CONFIGS = {
    "Swimmer-v2": ("swimmer", 3),
    "Hopper-v2": ("hopper", 5),
    "Walker2d-v2": ("walker", 8),
    "HalfCheetah-v2": ("cheetah", 8),
}


def _str2bool(value):
    """Parse a command-line boolean value.

    argparse's ``type=bool`` turns every non-empty string (including
    "False" and "0") into True, so flag values are parsed explicitly.
    The empty string still maps to False, as it did with ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def _parse_args(argv=None):
    """Parse command-line options, including per-environment path defaults.

    Parsing happens in two passes. The first reads ``--env`` while
    tolerating options not registered yet. The second registers the path
    options, with defaults derived from the chosen environment, and parses
    everything, so the path options can be overridden on the command line.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--render", type=_str2bool, default=False)
    parser.add_argument("--num_samples", type=int, default=50)
    parser.add_argument(
        "--env", type=str, default="Swimmer-v2", choices=sorted(_ENV_CONFIGS))
    parser.add_argument("--epsilon", type=float, default=0.01)
    parser.add_argument('--seed', type=int, default=2222)
    parser.add_argument('--hacked', type=_str2bool, default=False)
    parser.add_argument('--zero_order', type=_str2bool, default=False)

    # First pass: only --env is needed; the path flags are not registered yet.
    known_args, _ = parser.parse_known_args(argv)
    prefix, _ = _ENV_CONFIGS[known_args.env]

    parser.add_argument(
        "--expert_path",
        type=str,
        default="../openai_baselines/models/{}_20M_ppo2_norm".format(prefix))
    parser.add_argument(
        "--student_path",
        type=str,
        default="../dagger/student_models/{}_student_new".format(prefix))
    parser.add_argument(
        "--plot_save_path", type=str, default='./plot/{}.png'.format(prefix))
    parser.add_argument(
        "--model_save_path",
        type=str,
        default='./models/{}_attack'.format(prefix))

    return parser.parse_args(argv)


if __name__ == '__main__':
    args = _parse_args()
    print("Environment:", args.env)
    print('epsilon:', args.epsilon)
    _, pos_len = _ENV_CONFIGS[args.env]

    # Create environments. The expert env loads running observation stats
    # from the expert model, and its get_obs() maps raw states to expert
    # inputs. The student and the attacked student act through VanillaEnv.
    expert_env = NormalizeEnv(args.env, update=False)
    student_env = VanillaEnv(args.env)
    attack_env = VanillaEnv(args.env)

    attack_env.seed(args.seed)
    student_env.seed(args.seed)
    expert_env.seed(args.seed)
    tf.random.set_random_seed(args.seed)
    np.random.seed(args.seed)

    expert_env.obs_rms.load(args.expert_path)

    args.obs_dim = attack_env.observation_space.shape[0]
    args.act_dim = attack_env.action_space.shape[0]

    # define placeholders
    expert_obs = tf.placeholder(tf.float32, shape=[None, args.obs_dim])
    student_obs = tf.placeholder(tf.float32, shape=[None, args.obs_dim])
    expert_action = tf.placeholder(tf.float32, shape=[None, args.act_dim])

    # expert action std, read from the saved expert parameters
    para = joblib.load(args.expert_path)
    expert_pi_std = np.exp(para['ppo2_model/pi/logstd:0'])

    # Stochastic student action: mean + exp(logstd) * N(0, 1). The logstd
    # variable's name contains "student", so it is among the variables
    # passed to load_variables below.
    student_action_mean = student_pi(student_obs, args.act_dim, 'student')
    student_action_logstd = tf.Variable(
        np.zeros(args.act_dim), name='student/logstd', dtype=tf.float32)
    student_action_std = tf.exp(student_action_logstd)
    student_action = student_action_mean + student_action_std * tf.random_normal(
        tf.shape(student_action_mean))

    # Stochastic expert action. The noise is drawn inside the graph so each
    # sess.run samples afresh. A np.random.normal(...) term here would be
    # evaluated once at graph-construction time and add the same constant
    # offset at every step.
    expert_action_mean = pi(expert_obs, args.act_dim, 'ppo2_model/pi')
    cal_expert_action = expert_action_mean + tf.constant(
        expert_pi_std, dtype=tf.float32) * tf.random_normal(
            tf.shape(expert_action_mean))

    # sign of d(action MSE)/d(student obs): the perturbation direction
    sign = calculate_grad(student_obs, expert_action, student_action)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # load weights in student policy
        variables = [
            v for v in tf.trainable_variables() if "student" in v.name
        ]
        load_variables(args.student_path, variables=variables, sess=sess)

        # load weights in expert policy
        variables = [
            v for v in tf.trainable_variables() if "ppo2_model/pi/" in v.name
        ]
        load_variables(args.expert_path, variables=variables, sess=sess)

        student_reward_buffer = np.zeros(args.num_samples)
        attack_reward_buffer = np.zeros(args.num_samples)
        action_dist_buffer = np.zeros(args.num_samples)

        # Per-episode frames of the attacked env, taken before/after the
        # optional set_state() perturbation. Only filled when --render is set,
        # because keeping every frame of every episode grows without bound.
        student_frames_bf_a = dict()
        student_frames_af_a = dict()

        for i in range(args.num_samples):
            student_state = student_env.reset()
            attack_state = attack_env.reset()
            student_done, attack_done = False, False
            student_rew = 0
            attack_rew = 0
            action_dist = 0.0
            student_frames_bf_a[i] = []
            student_frames_af_a[i] = []

            # step both rollouts until each of them has finished its episode
            while not (student_done and attack_done):
                if not student_done:
                    student_action_val = sess.run(
                        student_action,
                        feed_dict={student_obs: [student_state]})
                    student_state, rew, student_done, __ = student_env.step(
                        student_action_val[0])
                    student_rew += rew

                if not attack_done:
                    cal_expert_action_val = sess.run(
                        cal_expert_action,
                        feed_dict={
                            expert_obs: [expert_env.get_obs(attack_state)]
                        })
                    sign_val = sess.run(
                        sign,
                        feed_dict={
                            student_obs: [attack_state],
                            expert_action: cal_expert_action_val
                        })

                    if args.zero_order:
                        sign_val[0, pos_len:] = 0

                    # FGSM-style step: observation + epsilon * sign(gradient)
                    attacked_obs_val = attack_state + args.epsilon * sign_val
                    if args.render:
                        student_frames_bf_a[i].append(
                            attack_env.render(mode='rgb_array'))
                    if args.hacked:
                        # NOTE(review): passes only the perturbation, not
                        # attack_state + perturbation -- confirm the semantics
                        # of VanillaEnv.set_state.
                        attack_env.set_state(args.epsilon * sign_val)
                    if args.render:
                        student_frames_af_a[i].append(
                            attack_env.render(mode='rgb_array'))

                    attacked_student_action_val = sess.run(
                        student_action,
                        feed_dict={student_obs: attacked_obs_val})

                    action_dist += np.mean(
                        np.square(cal_expert_action_val -
                                  attacked_student_action_val))

                    attack_state, rew, attack_done, __ = attack_env.step(
                        attacked_student_action_val[0])
                    attack_rew += rew

            student_reward_buffer[i] = student_rew
            attack_reward_buffer[i] = attack_rew
            action_dist_buffer[i] = action_dist

        print('student reward mean {} std {}'.format(
            np.mean(student_reward_buffer), np.std(student_reward_buffer)))
        print('attacked student reward mean {} std {}'.format(
            np.mean(attack_reward_buffer), np.std(attack_reward_buffer)))
        print('attack action distance mean {} std {}'.format(
            np.mean(action_dist_buffer), np.std(action_dist_buffer)))

