import tensorflow as tf
# from .dagger.utils import *
import argparse
import numpy as np
import sys
sys.path.append('..')
sys.path.append('../openai_baselines')

# from save import save_sequences
from aril_utils import *
from norm_env import *


def calculate_grad(student_obs, student_action_mean_val,
                   student_action_std_val, student_action_logstd_val,
                   student_action):
    """Build the sign of d(-log p(action)) / d(obs) for an observation attack.

    The log-likelihood is that of a diagonal Gaussian with mean
    ``student_action_mean_val`` and standard deviation
    ``student_action_std_val``, evaluated at ``student_action``. The gradient
    of its negation is taken with respect to ``student_obs``. Only the
    elementwise sign of that gradient is returned.

    Args:
        student_obs: Observation tensor to differentiate against. It must
            feed into ``student_action``, otherwise ``tf.gradients`` yields
            None.
        student_action_mean_val: Gaussian mean. In this script it is a
            placeholder fed with the action mean at the clean observation.
        student_action_std_val: Gaussian standard deviation.
        student_action_logstd_val: Log standard deviation. It only enters the
            normalising constant.
        student_action: Sampled action whose likelihood is evaluated.

    Returns:
        A tensor shaped like ``student_obs`` holding -1, 0 or +1.
    """
    num_dims = tf.cast(tf.shape(student_action)[-1], tf.float32)
    z = (student_action - student_action_mean_val) / student_action_std_val
    sq_dist = tf.reduce_sum(tf.square(z), axis=-1)
    log_norm_const = 0.5 * np.log(2.0 * np.pi) * num_dims
    log_std_sum = tf.reduce_sum(student_action_logstd_val, axis=-1)
    log_probs = -0.5 * sq_dist - log_norm_const - log_std_sum

    # tf.gradients returns one gradient per entry of the xs list.
    (obs_grad,) = tf.gradients(-log_probs, [student_obs])
    return tf.math.sign(obs_grad)


def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``, which is True
    for every non-empty string, including ``"False"``. This converter maps
    the usual spellings to real booleans instead.

    Args:
        value: The raw argument string. A bool is passed through unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}'.format(value))


# Per-environment settings for the attack script below:
#   short name -> substituted into the default --student_path,
#                 --plot_save_path and --model_save_path values;
#   pos_len    -> in --zero_order mode only the first pos_len entries of the
#                 perturbation are kept (entries from index pos_len on are
#                 zeroed).
ENV_CONFIGS = {
    "Swimmer-v2": ("swimmer", 3),
    "Hopper-v2": ("hopper", 5),
    "Walker2d-v2": ("walker", 8),
    "HalfCheetah-v2": ("cheetah", 8),
}


if __name__ == '__main__':
    # Evaluate a student policy with and without a sign-of-gradient
    # (FGSM-style) observation attack. Two copies of the environment are
    # stepped in lockstep: one acts on clean observations, the other on
    # obs + epsilon * sign(grad). The mean/std episode returns of both are
    # printed.
    parser = argparse.ArgumentParser()

    # type=bool would turn any non-empty string (including "False") into
    # True, so the boolean flags go through str2bool instead.
    parser.add_argument("--render", type=str2bool, default=False)
    parser.add_argument("--num_samples", type=int, default=50)
    parser.add_argument("--env", type=str, default="Swimmer-v2",
                        choices=sorted(ENV_CONFIGS))
    parser.add_argument("--epsilon", type=float, default=0.01)
    parser.add_argument('--seed', type=int, default=2222)
    parser.add_argument('--hacked', type=str2bool, default=False)
    parser.add_argument('--zero_order', type=str2bool, default=False)

    # The first pass only needs --env, to choose the env-specific defaults.
    # It uses parse_known_args rather than parse_args so that an explicit
    # --student_path / --plot_save_path / --model_save_path (not registered
    # yet) is not rejected as an unrecognized argument.
    known_args, _ = parser.parse_known_args()
    env_short_name, pos_len = ENV_CONFIGS[known_args.env]

    parser.add_argument(
        "--student_path",
        type=str,
        default="../dagger/student_models/{}_student_new".format(
            env_short_name))
    parser.add_argument(
        "--plot_save_path",
        type=str,
        default='./plot/{}.png'.format(env_short_name))
    parser.add_argument(
        "--model_save_path",
        type=str,
        default='./models/{}_attack'.format(env_short_name))

    args = parser.parse_args()
    print("Environment:", args.env)
    print('epsilon:', args.epsilon)

    # One environment for the clean rollout, one for the attacked rollout.
    student_env = VanillaEnv(args.env)
    attack_env = VanillaEnv(args.env)

    attack_env.seed(args.seed)
    student_env.seed(args.seed)
    # The graph-level seed is set before any random op is created below.
    tf.random.set_random_seed(args.seed)
    np.random.seed(args.seed)

    # Observation and action sizes, read from the environment.
    args.obs_dim = attack_env.observation_space.shape[0]
    args.act_dim = attack_env.action_space.shape[0]

    # Placeholders and networks.
    student_obs = tf.placeholder(tf.float32, shape=[None, args.obs_dim])
    # Fed with the action mean at the clean observation and treated as a
    # constant, so the attack gradient flows only through the sampled action.
    student_action_mean_val = tf.placeholder(
        tf.float32, shape=[None, args.act_dim])

    # Student policy: Gaussian with an observation-dependent mean and a
    # log-std variable that is loaded with the other 'student' variables.
    student_action_mean = student_pi(student_obs, args.act_dim, 'student')
    student_action_logstd = tf.Variable(
        np.zeros(args.act_dim), name='student/logstd', dtype=tf.float32)
    student_action_std = tf.exp(student_action_logstd)
    student_action = student_action_mean + student_action_std * tf.random_normal(
        tf.shape(student_action_mean))

    # Sign of the gradient used to perturb the observation.
    sign = calculate_grad(student_obs, student_action_mean_val,
                          student_action_std, student_action_logstd,
                          student_action)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Load the student policy weights (including student/logstd).
        variables = [
            v for v in tf.trainable_variables() if "student" in v.name
        ]
        load_variables(args.student_path, variables=variables, sess=sess)

        # Both envs share the same seed, so presumably they start from the
        # same initial state.
        student_state = student_env.reset()
        attack_state = attack_env.reset()

        student_rew = 0
        attack_rew = 0
        student_done, attack_done = False, False

        student_reward_buffer = np.zeros(args.num_samples)
        attack_reward_buffer = np.zeros(args.num_samples)

        # Rendered frames per episode, taken from attack_env. They are
        # collected but not saved anywhere by this script.
        student_frames = dict()

        for i in range(args.num_samples):
            student_frames[i] = []
            while True:
                # Clean rollout: act on the true observation.
                if not student_done:
                    student_action_val = sess.run(
                        student_action,
                        feed_dict={student_obs: [student_state]})
                    student_state, rew, student_done, __ = student_env.step(
                        student_action_val[0])
                    student_rew += rew

                # Attacked rollout: act on the perturbed observation.
                if not attack_done:
                    cal_student_action_mean_val = sess.run(
                        student_action_mean,
                        feed_dict={student_obs: [attack_state]})
                    sign_val = sess.run(
                        sign,
                        feed_dict={
                            student_obs: [attack_state],
                            student_action_mean_val:
                            cal_student_action_mean_val
                        })

                    if args.zero_order:
                        # Keep only the first pos_len entries of the
                        # perturbation.
                        sign_val[0, pos_len:] = 0

                    attacked_obs_val = attack_state + args.epsilon * sign_val
                    if args.hacked:
                        # NOTE(review): this passes only the perturbation,
                        # not attacked_obs_val. Its effect depends on
                        # VanillaEnv.set_state (norm_env); confirm.
                        attack_env.set_state(args.epsilon * sign_val)
                    attacked_student_action_val = sess.run(
                        student_action,
                        feed_dict={student_obs: attacked_obs_val})

                    attack_state, rew, attack_done, __ = attack_env.step(
                        attacked_student_action_val[0])
                    attack_rew += rew

                if args.render:
                    student_frames[i].append(
                        attack_env.render(mode='rgb_array'))

                # An episode ends only when both rollouts have terminated.
                if attack_done and student_done:
                    student_reward_buffer[i] = student_rew
                    attack_reward_buffer[i] = attack_rew
                    student_rew = 0
                    attack_rew = 0
                    student_state = student_env.reset()
                    attack_state = attack_env.reset()
                    student_done, attack_done = False, False
                    break

        print('student reward mean {} std {}'.format(
            np.mean(student_reward_buffer), np.std(student_reward_buffer)))
        print('attacked student reward mean {} std {}'.format(
            np.mean(attack_reward_buffer), np.std(attack_reward_buffer)))

