"""Environment for training vehicles to reduce congestion in a merge."""

from flow.envs.multiagent.base import MultiEnv
from flow.core.rewards import desired_velocity, local_desired_velocity
from gym.spaces.box import Box
import numpy as np


# Parameters that MultiAgentMergePOEnv requires in env_params.additional_params.
ADDITIONAL_ENV_PARAMS = dict(
	# maximum acceleration for autonomous vehicles, in m/s^2
	max_accel=3,
	# maximum deceleration for autonomous vehicles, in m/s^2
	max_decel=3,
	# desired velocity for all vehicles in the network, in m/s
	target_velocity=25,
)


class MultiAgentMergePOEnv(MultiEnv):
	"""Partially observable multi-agent merge environment.

	This environment is used to train autonomous vehicles to attenuate the
	formation and propagation of waves in an open merge network. Every
	autonomous (RL) vehicle currently in the network acts as its own agent.

	Required from env_params:

	* max_accel: maximum acceleration for autonomous vehicles, in m/s^2
	* max_decel: maximum deceleration for autonomous vehicles, in m/s^2
	* target_velocity: desired velocity for all vehicles in the network, in m/s

	Optional from env_params:

	* local: selects the training reward (see Rewards). Defaults to
	  "global", i.e. the network-wide velocity reward.

	States
		Each agent observes a 5-element vector: its own speed, the speed
		difference to and bumper-to-bumper headway from the vehicle
		immediately ahead, and the speed difference to and headway of the
		vehicle immediately behind. Speeds are normalized by the network's
		maximum speed and headways by the network length. A missing leader
		is treated as driving at the maximum speed one network length ahead;
		a missing follower as stopped one network length behind.

	Actions
		Each agent outputs a single acceleration bounded to
		[-|max_decel|, max_accel]. RL vehicles that have no entry in the
		action dict (e.g. because they just entered the network) receive no
		command for that step.

	Rewards
		During evaluation, every agent receives the mean speed of all
		vehicles in the network. During training, every agent receives 0 at
		a step in which a collision occurred; otherwise the reward is
		max(1.0 * velocity_term + 0.1 * headway_penalty, 0), where
		headway_penalty is non-positive and only active for time headways
		below 1, and the "local" parameter selects the terms:

		* "local": per-agent reward; the velocity term is
		  ``local_desired_velocity`` of that agent alone and the penalty
		  is that agent's own.
		* "partial_first" / "partial_last": the velocity term is
		  ``local_desired_velocity`` over the first / last three RL ids,
		  the penalty is summed over all RL vehicles, and every agent
		  receives the same reward.
		* anything else: the velocity term is the system-level
		  ``desired_velocity``, the penalty is summed over all RL
		  vehicles, and every agent receives the same reward.

	Termination
		A rollout is terminated if the time horizon is reached or if two
		vehicles collide into one another.
	"""

	def __init__(self, env_params, sim_params, network, simulator='traci'):
		for p in ADDITIONAL_ENV_PARAMS.keys():
			if p not in env_params.additional_params:
				raise KeyError(
					'Environment parameter "{}" not supplied'.format(p))

		# used for visualization: the vehicles behind and after RL vehicles
		# (ie the observed vehicles) will have a different color
		self.leader = []
		self.follower = []
		# reward mode; any value other than "local", "partial_first" or
		# "partial_last" selects the network-wide velocity reward
		self.local_reward = env_params.additional_params.get(
			"local", "global")

		super().__init__(env_params, sim_params, network, simulator)

	@property
	def action_space(self):
		"""See class definition."""
		return Box(
			low=-abs(self.env_params.additional_params["max_decel"]),
			high=self.env_params.additional_params["max_accel"],
			shape=(1,),
			dtype=np.float32)

	@property
	def observation_space(self):
		"""See class definition."""
		return Box(low=-1000, high=1000, shape=(5,), dtype=np.float32)

	def _apply_rl_actions(self, rl_actions):
		"""See class definition.

		Parameters
		----------
		rl_actions : dict
			maps RL vehicle ids to the acceleration chosen by that agent
		"""
		# iterate over the ids themselves (not enumerate(...), whose
		# (index, id) tuples never match the keys of rl_actions)
		for rl_id in self.k.vehicle.get_rl_ids():
			if rl_id not in rl_actions:
				# the vehicle just entered, so ignore
				continue
			self.k.vehicle.apply_acceleration(rl_id, rl_actions[rl_id])

	def get_state(self, rl_id=None, **kwargs):
		"""See class definition.

		Also records the observed leaders and followers in ``self.leader``
		and ``self.follower`` for visualization.

		Returns
		-------
		dict
			maps each RL vehicle id to its 5-element float32 observation
		"""
		observation = {}
		self.leader = []
		self.follower = []

		# normalizing constants
		max_speed = self.k.network.max_speed()
		max_length = self.k.network.length()

		for veh_id in self.k.vehicle.get_rl_ids():
			this_speed = self.k.vehicle.get_speed(veh_id)
			lead_id = self.k.vehicle.get_leader(veh_id)
			follower = self.k.vehicle.get_follower(veh_id)

			if lead_id in ["", None]:
				# in case leader is not visible
				lead_speed = max_speed
				lead_head = max_length
			else:
				self.leader.append(lead_id)
				lead_speed = self.k.vehicle.get_speed(lead_id)
				# NOTE(review): this subtracts the RL vehicle's own length
				# rather than the leader's; the two agree only when both
				# vehicles have the same length — confirm which is intended
				lead_head = self.k.vehicle.get_x_by_id(lead_id) \
					- self.k.vehicle.get_x_by_id(veh_id) \
					- self.k.vehicle.get_length(veh_id)

			if follower in ["", None]:
				# in case follower is not visible
				follow_speed = 0
				follow_head = max_length
			else:
				self.follower.append(follower)
				follow_speed = self.k.vehicle.get_speed(follower)
				follow_head = self.k.vehicle.get_headway(follower)

			# cast to float32 to match the declared observation_space dtype
			observation[veh_id] = np.array([
				this_speed / max_speed,
				(lead_speed - this_speed) / max_speed,
				lead_head / max_length,
				(this_speed - follow_speed) / max_speed,
				follow_head / max_length
			], dtype=np.float32)

		return observation

	def _headway_penalty(self, rl_id, t_min=1):
		"""Return the non-positive small-time-headway penalty of one vehicle.

		The penalty is min((h / v - t_min) / t_min, 0), where h is the
		vehicle's headway and v its speed. It is 0 when the vehicle has no
		visible leader, is not moving, or keeps a time headway of at least
		t_min.

		Parameters
		----------
		rl_id : str
			id of the RL vehicle
		t_min : float, optional
			smallest acceptable time headway
		"""
		lead_id = self.k.vehicle.get_leader(rl_id)
		if lead_id not in ["", None] \
				and self.k.vehicle.get_speed(rl_id) > 0:
			t_headway = max(
				self.k.vehicle.get_headway(rl_id) /
				self.k.vehicle.get_speed(rl_id), 0)
			return min((t_headway - t_min) / t_min, 0)
		return 0

	def compute_reward(self, rl_actions, **kwargs):
		"""See class definition.

		Parameters
		----------
		rl_actions : dict
			actions of the agents (not used by this reward)
		kwargs : dict
			must contain ``"fail"``: whether a collision occurred this step

		Returns
		-------
		dict
			maps each RL vehicle id to its scalar reward
		"""
		rl_ids = self.k.vehicle.get_rl_ids()

		if self.env_params.evaluate:
			# evaluation metric: mean speed of all vehicles in the network,
			# returned per agent so the result is keyed by RL id like the
			# training rewards below
			if not rl_ids:
				return {}
			mean_speed = np.mean(
				self.k.vehicle.get_speed(self.k.vehicle.get_ids()))
			return {rl_id: mean_speed for rl_id in rl_ids}

		fail = kwargs["fail"]
		# return a reward of 0 if a collision occurred
		if fail:
			return {rl_id: 0 for rl_id in rl_ids}

		# weights for the velocity term (cost1) and headway penalty (cost2)
		eta1, eta2 = 1.00, 0.10

		if self.local_reward == "local":
			# each agent is scored on its own velocity and headway only
			rewards = {}
			for rl_id in rl_ids:
				cost1 = local_desired_velocity(self, [rl_id], fail=fail)
				cost2 = self._headway_penalty(rl_id)
				rewards[rl_id] = max(eta1 * cost1 + eta2 * cost2, 0)
			return rewards

		if self.local_reward == "partial_first":
			cost1 = local_desired_velocity(self, rl_ids[:3], fail=fail)
		elif self.local_reward == "partial_last":
			cost1 = local_desired_velocity(self, rl_ids[-3:], fail=fail)
		else:
			# reward high system-level velocities
			cost1 = desired_velocity(self, fail=fail)

		# penalize small time headways of all RL vehicles
		cost2 = sum(self._headway_penalty(rl_id) for rl_id in rl_ids)

		reward = max(eta1 * cost1 + eta2 * cost2, 0)
		return {rl_id: reward for rl_id in rl_ids}

	def additional_command(self):
		"""See parent class.

		This method defines which vehicles are observed for visualization
		purposes.
		"""
		for veh_id in self.leader + self.follower:
			self.k.vehicle.set_observed(veh_id)

	def reset(self, new_inflow_rate=None):
		"""See parent class.

		In addition, a few variables that are specific to this class are
		emptied before they are used by the new rollout.

		Parameters
		----------
		new_inflow_rate : optional
			forwarded unchanged to the parent ``reset``
		"""
		self.leader = []
		self.follower = []
		return super().reset(new_inflow_rate)
