import sys
sys.path.append('../')
sys.path.append('../openai_baselines')
sys.path.append('../openai_baselines/baselines/common/vec_env')

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from baselines.common.vec_env import VecEnv
import gym
import joblib
import os
from gym.spaces import Box
from aril_utils import *
from norm_env import *

from exptools.logging import logger

class MyBox(Box):
    """Box space that is symmetric around zero: every component lies in [-epsilon, epsilon]."""

    def __init__(self, epsilon, shape=None, dtype=np.float32):
        # Bounds are passed positionally, exactly as the base Box expects them.
        lower_bound = -epsilon
        upper_bound = epsilon
        super(MyBox, self).__init__(lower_bound, upper_bound, shape, dtype)


class AttackEnv(VecEnv):
    """Single-environment VecEnv whose agent is an observation attacker.

    Each step the attacker emits a perturbation (clipped to [-epsilon, epsilon])
    that is added to the current state before it is shown to a "student"
    policy. The student acts in ``attack_env``; an "expert" policy is rolled
    out in parallel in a separate ``expert_env``.

    The reward returned by ``step_wait`` depends on ``mode``:
      * ``'rl'``: the negated environment reward earned by the student.
      * ``'il'``: squared distance between the student's action on the attacked
        state and the expert's action on the clean state, minus
        ``expert_panelty`` times the squared change in the expert's own action
        caused by the attack.
    """

    # Number of perturbed observation entries when ``zero_order`` is True.
    # These leading entries are the ones ``set_state`` maps onto qpos; the rest
    # of the perturbation is zero-padded.
    _ZERO_ORDER_ATTACK_DIMS = {
        "Swimmer-v2": 3,
        "Hopper-v2": 5,
        "Walker2d-v2": 8,
        "HalfCheetah-v2": 8,
        "Ant-v2": 13,
    }

    # Number of perturbed observation entries when ``zero_order`` is False
    # (the full observation size of each env).
    _FULL_ATTACK_DIMS = {
        "Swimmer-v2": 8,
        "Hopper-v2": 11,
        "Walker2d-v2": 17,
        "HalfCheetah-v2": 17,
        "Ant-v2": 111,
    }

    # Layout used by ``set_state``: env -> (number of leading qpos entries left
    # untouched, number of qpos entries perturbed by delta[:n], end index of the
    # delta slice added to qvel; None means "to the end").
    # NOTE(review): the untouched entries are presumably the ones excluded from
    # the observation (e.g. root x position) — confirm against the env source.
    _STATE_LAYOUT = {
        "Swimmer-v2": (2, 3, None),
        "Hopper-v2": (1, 5, None),
        "Walker2d-v2": (1, 8, None),
        "HalfCheetah-v2": (1, 8, None),
        # It seems only Ant's observation has something more than qpos, qvel,
        # hence the explicit end of the velocity slice.
        "Ant-v2": (2, 13, 27),
    }

    def __init__(self,
                 sess,
                 env,
                 expert_path,
                 student_path,
                 hacked,
                 epsilon,
                 zero_order,
                 mode='rl',
                 optimizer_kwargs=None,
                 random_seed=6666,
                 expert_panelty=0,  # weight of the penalty on the expert's own action change
                 build_backup_student=False,  # if True, also build a `backup_student`
                 ):
        """Create both environments and build the expert/student policy graphs.

        Args:
            sess: TensorFlow session used to load weights and run the policies.
            env: gym environment id; must be a key of the attack-dim tables.
            expert_path: expert checkpoint, read with ``joblib.load`` and
                ``load_variables``.
            student_path: student checkpoint, loaded on the first ``reset``;
                if None, no student weights are loaded.
            hacked: if True, the perturbation is also written into the
                simulator state via ``set_state``.
            epsilon: bound on every component of the perturbation.
            zero_order: if True, only the first ``attack_dim`` observation
                entries are perturbed (the perturbation is zero-padded).
            mode: ``'rl'`` or ``'il'``; selects the reward of ``step_wait``.
            optimizer_kwargs: keyword arguments for the student's
                ``tf.train.AdamOptimizer`` (None means defaults).
            random_seed: seed passed to both environments.
            expert_panelty: weight of the penalty on the change in the expert's
                action caused by the perturbation (``'il'`` reward only).
            build_backup_student: if True, build a second student graph under
                the ``backup_student`` scope; when ``student_path`` is given it
                receives a copy of the loaded student weights.

        Raises:
            RuntimeError: if ``env`` is not one of the supported environments.
        """
        # Attacked (student) environment plus a separate expert environment.
        self.env_name = env
        self.attack_env = gym.make(env)
        self.expert_env = NormalizeEnv(env, update=False, random_seed=random_seed)
        self.set_seed(random_seed)
        self.mode = mode
        print('mode', self.mode)
        self.num_envs = 1
        self.expert_panelty = expert_panelty

        obs_dim = self.attack_env.observation_space.shape[0]
        self.state = np.empty(obs_dim)
        self.e_s = np.empty(obs_dim)
        self.e_o = np.empty(obs_dim)

        self.rew = 0          # attack reward of the last step
        self.env_rew = 0      # environment reward of the last step
        self.done = False
        self.expert_done = False
        self.info = None
        self.first_reset = True
        self.eplen = 0.0
        self.attack_return = 0.0
        self.expert_return = 0.0
        self.student_return = 0.0

        self.sess = sess
        self.hacked = hacked
        self.epsilon = epsilon
        self.zero_order = zero_order
        self.student_path = student_path
        self.expert_path = expert_path

        dim_table = self._ZERO_ORDER_ATTACK_DIMS if zero_order else self._FULL_ATTACK_DIMS
        if self.env_name not in dim_table:
            raise RuntimeError("Wrong env selected")
        self.attack_dim = dim_table[self.env_name]

        self.observation_space = self.attack_env.observation_space
        self.action_space = MyBox(epsilon, [self.attack_dim])
        self._max_episode_steps = self.attack_env._max_episode_steps

        act_dim = self.attack_env.action_space.shape[0]
        self.obs_dim = obs_dim
        self.act_dim = act_dim

        self.student_obs = tf.placeholder(tf.float32, shape=[None, obs_dim])
        self.expert_obs = tf.placeholder(tf.float32, shape=[None, obs_dim])
        self.teacher_y = tf.placeholder(tf.float32, shape=[None, act_dim])

        # Expert policy. The checkpoint's logstd is only needed by the
        # (currently disabled) stochastic expert below.
        para = joblib.load(expert_path)
        expert_pi_logstd = para['expert/ppo2_model/pi/logstd:0']
        expert_pi_std = np.exp(expert_pi_logstd)
        self.expert_action = pi(self.expert_obs, act_dim, 'expert/ppo2_model/pi')
        # self.expert_action += tfd.Normal(loc= np.zeros(act_dim, dtype= np.float32), scale= expert_pi_std).sample()
        expert_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="expert")
        load_variables(expert_path, variables=expert_variables, sess=self.sess)

        # Student policy (and optional backup copy).
        self.student_action = self._build_stochastic_policy(self.student_obs, act_dim, 'student')
        if build_backup_student:
            self.backup_student_obs = tf.placeholder(tf.float32, shape=[None, obs_dim])
            self.backup_student_action = self._build_stochastic_policy(
                self.backup_student_obs, act_dim, 'backup_student')

        # DAgger loss: L2 between the (sampled) student action and teacher labels.
        self.loss_l2 = tf.reduce_mean(tf.square(self.student_action - self.teacher_y))
        optimizer = tf.train.AdamOptimizer(**(optimizer_kwargs or {}))
        self.train_step = optimizer.minimize(self.loss_l2)

    @staticmethod
    def _build_stochastic_policy(obs_ph, act_dim, scope):
        """Build ``mean + exp(logstd) * N(0, 1)`` with the mean from ``student_pi``.

        ``logstd`` is a zero-initialised variable named ``<scope>/logstd`` that
        does not depend on the observation. Returns the sampled-action tensor.
        """
        action_mean = student_pi(obs_ph, act_dim, scope)
        action_logstd = tf.Variable(
            np.zeros(act_dim), name=scope + '/logstd', dtype=tf.float32)
        action_std = tf.exp(action_logstd)
        return action_mean + action_std * tf.random_normal(tf.shape(action_mean))

    def set_seed(self, random_seed):
        """Seed both the attacked and the expert environment."""
        self.attack_env.seed(random_seed)
        self.expert_env.seed(random_seed)

    def reset(self):
        """Reset both environments and the per-episode statistics.

        On the first call, load the student weights (when ``student_path`` is
        given, copying them into the backup student if it was built) and load
        the expert's ``expert/ppo2_model/pi/`` trainable variables again.
        Returns the attacked env's observation with a leading axis of size 1.
        """
        if self.first_reset:
            self.first_reset = False
            # NOTE(review): weights are (re)loaded here presumably because the
            # caller initialises global variables after construction — confirm.
            if self.student_path is not None:
                variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="student")
                load_variables(self.student_path, variables=variables, sess=self.sess)
                if hasattr(self, "backup_student_action"):
                    copy_variables(
                        from_variables=tf.trainable_variables(scope="student"),
                        to_variables=tf.trainable_variables(scope="backup_student"),
                        sess=self.sess,
                    )

            variables = [
                v for v in tf.trainable_variables()
                if "expert/ppo2_model/pi/" in v.name
            ]
            load_variables(self.expert_path, variables=variables, sess=self.sess)

        self.e_s, self.e_o = self.expert_env.reset()
        self.state = self.attack_env.reset()
        self.eplen = 0.0
        self.expert_return = 0.0
        self.student_return = 0.0
        self.attack_return = 0.0
        # A fresh expert episode starts here, whoever called reset().
        self.expert_done = False

        return np.expand_dims(self.state, 0)

    def _run_expert(self, raw_state):
        """Run the expert on ``raw_state`` after ``expert_env.get_obs`` preprocessing."""
        return self.sess.run(
            self.expert_action,
            feed_dict={self.expert_obs: [self.expert_env.get_obs(raw_state)]})

    def step_async(self, action):
        """Apply the attacker's perturbation and advance both environments.

        ``action`` is a batch holding one perturbation; it is clipped to
        [-epsilon, epsilon] and, for zero-order attacks, zero-padded to the
        observation size. The student acts on the attacked state in
        ``attack_env``; the expert acts on its own observation in
        ``expert_env`` until its episode ends. When the student's episode ends,
        ``info['episode']`` gets ``r`` (attack return) and ``l`` (student env
        return), and both environments are reset.
        """
        action = np.clip(action[0], -self.epsilon, self.epsilon)
        expert_action_val = self._run_expert(self.state)
        if self.zero_order:
            action = np.pad(action, (0, self.obs_dim - self.attack_dim), 'constant')
        attacked_state = self.state + action
        student_action_val = self.sess.run(
            self.student_action,
            feed_dict={self.student_obs: [attacked_state]})

        divergence = np.square(student_action_val - expert_action_val)
        if self.expert_panelty:
            # Feed the attacked state through the same preprocessing the expert
            # sees everywhere else (previously the raw state was fed here).
            expert_att_action_val = self._run_expert(attacked_state)
            divergence = divergence - self.expert_panelty * np.square(
                expert_action_val - expert_att_action_val)
        self.rew = np.sum(divergence)

        if self.hacked:
            # Also write the perturbation into the simulator state itself.
            self.set_state(action)

        self.state, self.env_rew, self.done, self.info = self.attack_env.step(
            student_action_val)
        if self.info is None:
            self.info = dict()

        # Count this step toward the returns, including the episode's last step.
        self.student_return += self.env_rew
        self.attack_return += self.rew

        if not self.done:
            # Do not keep stepping the expert env once its own episode ended.
            if not self.expert_done:
                true_expert_action_val = self.sess.run(
                    self.expert_action, feed_dict={self.expert_obs: [self.e_o]})
                self.e_s, self.e_o, e_r, self.expert_done, _ = self.expert_env.step(
                    true_expert_action_val)
                self.expert_return += e_r
        else:
            self.info['episode'] = dict(
                r=self.attack_return,
                l=self.student_return,
            )
            self.reset()

    def step_wait(self):
        """Return ``(obs, reward, done, infos)``, each batched for one env.

        Raises:
            ValueError: if ``mode`` is neither ``'rl'`` nor ``'il'``.
        """
        if self.mode == 'rl':
            reward = -self.env_rew
        elif self.mode == 'il':
            reward = self.rew
        else:
            raise ValueError("Unknown mode %r; expected 'rl' or 'il'" % (self.mode,))
        return (np.expand_dims(self.state, 0), np.expand_dims(reward, 0),
                np.expand_dims(self.done, 0), [self.info])

    def get_env_reward(self):
        """Return the last environment reward with a leading axis of size 1."""
        return np.expand_dims(self.env_rew, 0)

    def render(self, mode=None):
        """Render the attacked environment, passing ``mode`` through unchanged."""
        return self.attack_env.render(mode)

    def expert_step(self, action):
        """Step the expert env with ``action``; returns batched (state, rew, done, infos)."""
        self.e_s, self.e_o, rew, done, info = self.expert_env.step(action)
        return np.expand_dims(self.e_s, 0), np.expand_dims(rew, 0), np.expand_dims(done, 0), [info]

    def get_expert_action(self, state):
        """Run the expert on ``state`` (a single state or a batch of states).

        The preprocessed observation is promoted to rank 2, because the
        ``expert_obs`` placeholder expects ``[None, obs_dim]``.
        """
        expert_obs_np = np.atleast_2d(self.expert_env.get_obs(state))
        expert_action = self.sess.run(
            self.expert_action,
            feed_dict={self.expert_obs: expert_obs_np})
        return expert_action

    def set_state(self, delta):
        """Add ``delta`` to the simulator's qpos (after the untouched prefix) and qvel.

        Raises:
            RuntimeError: if the environment has no entry in ``_STATE_LAYOUT``.
        """
        if self.env_name not in self._STATE_LAYOUT:
            raise RuntimeError("set_state is not supported for env %r" % (self.env_name,))
        n_fixed, n_pos, vel_end = self._STATE_LAYOUT[self.env_name]

        position = self.attack_env.sim.data.qpos.copy()
        velocity = self.attack_env.sim.data.qvel.copy()
        new_pos = np.concatenate((position.flat[:n_fixed],
                                  position.flat[n_fixed:] + delta[:n_pos]))
        new_vel = velocity.flat + delta[n_pos:vel_end]
        self.attack_env.set_state(new_pos, new_vel)

