import sys, os
import numpy as np
from copy import copy

import tensorflow as tf
import joblib

from baselines.common.vec_env import VecEnv

import gym
from gym.spaces import Box

from norm_env import *
from aril_env import MyBox

class RarlEnv(VecEnv):
	""" Single-environment wrapper used to train an attacker and a victim against each other.

	The attacker perturbs the observation that the victim sees; the victim acts in the
	wrapped gym environment. The setting is a bit different from the original RARL setting.

	NOTE: an instance is unusable until a mode is chosen. Use `as_attacker_env` or
	`as_victim_env` to get a copy configured for the policy you are training.
	NOTE: those copies are shallow. They share the same gym environment and session, but
	attributes assigned during stepping (`state`, `ret`, `eplen`, ...) are per-copy, so
	interleaving steps on both copies leaves the other copy's bookkeeping stale.
	"""
	def __init__(self,
			sess, env: str,
			attacker_space= 1e-4, # the action space for attacker somewhere call it `epsilon`
			random_seed= 666,
			be_attacked_prob= 0.5, # to make the victim robust, you have to decide the probability to attacking victim when the environment is for victim
			):
		""" Build the wrapped gym environment and the attacker/victim spaces.

		sess: session object used to evaluate the opponent's action tensor (`sess.run`).
		env: gym environment id passed to `gym.make`.
		attacker_space: `epsilon` handed to `MyBox` for the attacker action space.
		random_seed: only stored here; call `seed` to actually seed the wrapped environment.
		be_attacked_prob: probability that an observation given to the victim is perturbed.
		"""
		# initialize attributes
		self.sess = sess
		self.env = gym.make(env)
		self.attacker_space = attacker_space
		self.random_seed = random_seed
		self.be_attacked_prob = be_attacked_prob
		self.num_envs= 1 # for VecEnv api
		self.mode = None # choose between "for_attacker" or "for_victim"

		# the attacker observes the raw state and outputs a perturbation of the same shape
		self.attacker_observation_space = self.env.observation_space
		self.attacker_action_space = MyBox(
			epsilon = self.attacker_space,
			shape= self.env.observation_space.shape,
		)
		self.victim_observation_space = self.env.observation_space
		self.victim_action_space = self.env.action_space

		# reset environment
		self.state = self.env.reset()
		self.eplen = 0 # number of steps taken in the current episode
		self.ret = 0 # total env reward of the current episode

	def seed(self, random_seed):
		""" Seed the wrapped gym environment. """
		self.env.seed(random_seed)

	def as_attacker_env(self, victim_action_tensor, victim_obs_ph):
		""" Return a shallow copy of self configured for training the attacker.

		The copy shares attribute objects with self (as long as you don't change them by
		assignment). `victim_action_tensor` is evaluated with `victim_obs_ph` fed the
		perturbed observation to obtain the victim's action.
		"""
		attacker_env = copy(self)
		attacker_env.mode = "for_attacker"
		attacker_env.victim_action_tensor = victim_action_tensor
		attacker_env.victim_obs_ph = victim_obs_ph
		return attacker_env

	def as_victim_env(self, attacker_action_tensor, attacker_obs_ph):
		""" Return a shallow copy of self configured for training the victim.

		`attacker_action_tensor` is evaluated with `attacker_obs_ph` fed the clean
		observation to obtain the perturbation added to the victim's observation.
		"""
		victim_env = copy(self)
		victim_env.mode = "for_victim"
		victim_env.attacker_action_tensor = attacker_action_tensor
		victim_env.attacker_obs_ph = attacker_obs_ph
		return victim_env

	@staticmethod
	def _drop_batch_axis(value, reference):
		""" Remove a leading batch axis of size 1 from `value` if `reference` does not have it. """
		value = np.asarray(value)
		if value.ndim == np.ndim(reference) + 1 and value.shape[0] == 1:
			return value[0]
		return value

	def attack_victim(self, state):
		""" Randomly choose whether to attack the state which is given to the victim later.

		With probability `be_attacked_prob` the attacker's clipped action is added to
		`state`; otherwise `state` is returned unchanged. The result keeps the shape of
		`state` in both cases.
		"""
		assert self.mode == "for_victim", "Wrong state under current environment, please use `as_victim_env`"
		if np.random.uniform() >= self.be_attacked_prob:
			return state
		if not (isinstance(self.attacker_action_tensor, tf.Tensor) and isinstance(self.attacker_obs_ph, tf.Tensor)):
			print("WARNING: you did not provide attacker, but asked to attack victim observation")
			return state
		attacker_act = self.sess.run(
			self.attacker_action_tensor,
			feed_dict= {self.attacker_obs_ph: np.expand_dims(state, 0)}
		)
		attacker_act = np.clip(
			attacker_act,
			self.attacker_action_space.low,
			self.attacker_action_space.high,
		)
		# The attacker was fed a batch of one, so its output normally carries a leading
		# batch axis; drop it so attacked and clean observations have the same shape.
		return state + self._drop_batch_axis(attacker_act, state)

	def _match_mode_behavior(self, for_attacker, for_victim, *args, **kwargs):
		""" Call `for_attacker` or `for_victim` with the given arguments, depending on `self.mode`.

		Raises AssertionError when no mode is set and ValueError for an unknown mode.
		"""
		assert self.mode, "Error using environment, you should set the environment for agent or for attacker"
		if self.mode == "for_attacker":
			return for_attacker(*args, **kwargs)
		elif self.mode == "for_victim":
			return for_victim(*args, **kwargs)
		else:
			raise ValueError("Wrong env.mode for RarlEnv instance: {}".format(self.mode))

	def _record_step(self):
		""" Update episode statistics for the step just taken; return whether the episode ended.

		The reward and length of the final step are counted before the `episode` summary
		is written, so `r` and `l` cover the whole episode.
		"""
		if self.info is None: self.info = dict()
		self.ret += self.reward
		self.eplen += 1
		if self.done:
			self.info['episode'] = dict(
				r= self.ret,
				l= self.eplen
			)
		return self.done

	def step_async(self, action):
		""" Advance the wrapped environment by one step; results are read with `step_wait`.

		In attacker mode `action` is a perturbation of the current state and the victim
		picks the environment action. In victim mode `action` is applied to the environment
		directly and the next observation may be perturbed by the attacker.
		When the episode ends the environment is reset and the observation exposed by
		`step_wait` is the first observation of the new episode.
		"""
		def for_attacker(action):
			perturbation = self._drop_batch_axis(
				np.clip(
					action,
					self.attacker_action_space.low,
					self.attacker_action_space.high,
				),
				self.state,
			)
			attacked_state = self.state + perturbation
			# Feed the victim a batch of one, matching what `attack_victim` feeds the attacker.
			if np.ndim(attacked_state) == np.ndim(self.state):
				victim_input = np.expand_dims(attacked_state, 0)
			else:
				victim_input = attacked_state
			victim_attacked_action = self.sess.run(
				self.victim_action_tensor,
				feed_dict= {self.victim_obs_ph: victim_input}
			)
			self.state, self.reward, self.done, self.info = self.env.step(victim_attacked_action)
			# the attacker is rewarded for making the victim do badly
			self.attacker_reward = -self.reward
			if self._record_step():
				self.reset()

		def for_victim(action):
			self.state, self.reward, self.done, self.info = self.env.step(action)
			if self._record_step():
				# expose the (possibly attacked) first observation of the new episode
				self.attacked_state = self.reset()
			else:
				self.attacked_state = self.attack_victim(self.state)

		return self._match_mode_behavior(for_attacker, for_victim,
			action,
		)

	def step_wait(self):
		""" Return `(obs, reward, done, [info])` of the last `step_async`, each with a leading axis of 1.

		The attacker gets the clean state and the negated env reward; the victim gets the
		possibly attacked state and the env reward.
		"""
		def for_attacker():
			return np.expand_dims(self.state, 0), np.expand_dims(self.attacker_reward, 0), np.expand_dims(self.done, 0), [self.info]

		def for_victim():
			return np.expand_dims(self.attacked_state, 0), np.expand_dims(self.reward, 0), np.expand_dims(self.done, 0), [self.info]

		return self._match_mode_behavior(for_attacker, for_victim)

	@property
	def action_space(self):
		""" Action space of the policy being trained in the current mode. """
		def for_attacker():
			return self.attacker_action_space
		def for_victim():
			return self.victim_action_space
		return self._match_mode_behavior(for_attacker, for_victim)

	@property
	def observation_space(self):
		""" Observation space of the policy being trained in the current mode. """
		def for_attacker():
			return self.attacker_observation_space
		def for_victim():
			return self.victim_observation_space
		return self._match_mode_behavior(for_attacker, for_victim)

	def reset(self):
		""" Reset the wrapped environment and the episode statistics.

		Returns the new state, passed through `attack_victim` when in victim mode.
		"""
		self.state = self.env.reset()
		self.eplen = 0 # number of steps taken in the current episode
		self.ret = 0 # total env reward of the current episode
		if self.mode == "for_victim":
			return self.attack_victim(self.state)
		else:
			return self.state

	def render(self, mode= None):
		""" Render the wrapped environment.

		When `mode` is None the wrapped environment's own default render mode is used
		instead of forwarding None to it.
		"""
		if mode is None:
			return self.env.render()
		return self.env.render(mode)
