import gym
import numpy as np
import copy
import random

from absl import logging

from smac.env.multiagentenv import MultiAgentEnv


class PassEnv(MultiAgentEnv):
	"""Grid-world "Pass" environment.

	The map is a ``size x size`` grid split into a left and a right half by a
	vertical wall in column ``size // 2``.  A short segment of that wall (rows
	``int(0.45 * size)`` up to, but excluding, ``int(0.55 * size)``) is a door.
	The door opens when an agent stands on one of the two landmarks and closes
	again once it has been open for more than ``door_open_interval`` steps
	(it re-opens immediately if a landmark is still occupied).

	A reward of 1000 is given, and the episode ends, as soon as at least two
	agents are located strictly to the right of the wall.  The episode also
	ends after ``episode_length`` steps.

	Actions: 0 = row - 1, 1 = column + 1, 2 = row + 1, 3 = column - 1.
	Moves are clipped to the grid and rejected when the target cell is a wall.

	NOTE(review): the reward, termination and observation code assume exactly
	two agents; other values of ``agent_num`` are not handled specially.
	"""

	def __init__(
			self,
			agent_num=2,
			action_num=4,
			dim_num=2,
			size=30,
			door_open_interval=0,
			episode_length=300,
			fix_start=True,
			simple_env=False,
			debug=False,
			state_last_action=True,
			seed=None):
		"""Create a PassEnv environment.

		Args:
			agent_num: number of agents.
			action_num: number of discrete actions reported to the learner.
			dim_num: stored as ``n_dim``; not used elsewhere in this class.
			size: side length of the square grid.
			door_open_interval: number of extra steps the door stays open
				after it was opened.
			episode_length: maximum number of steps per episode.
			fix_start: if True every agent starts at ``[0, 0]``; otherwise
				agents start at random cells left of the wall.
			simple_env: if True, landmark ``k`` only reacts to agent ``k``.
			debug: emit debug log messages during construction.
			state_last_action: accepted for interface compatibility; unused.
			seed: seed for NumPy's global RNG.  When None a random seed in
				``[0, 9999]`` is drawn.
		"""
		self.debug = debug

		# Honour an explicitly supplied seed so that runs are reproducible;
		# fall back to a random one otherwise.
		if seed is None:
			seed = random.randint(0, 9999)
		# NOTE: this instance attribute shadows the ``seed`` method below; it
		# is kept under this name for backward compatibility.
		self.seed = seed
		np.random.seed(self.seed)

		if self.debug:
			logging.debug("Creating Pass environment...")

		# Game arguments
		self.n_agent = agent_num
		self.n_action = action_num

		self.n_dim = dim_num
		self.size = size
		self.fix_start = fix_start
		self.simple_env = simple_env
		self.episode_length = episode_length

		# Map: 0 = free cell, -1 = wall, 1 = landmark.
		self.landmarks = [
			np.array([int(self.size * 0.8), int(self.size * 0.2)]),
			np.array([int(self.size * 0.2), int(self.size * 0.8)])]

		self.map = np.zeros([self.size, self.size])
		self.map[:, self.size // 2] = -1
		for landmark in self.landmarks:
			self.map[landmark[0], landmark[1]] = 1

		# Door
		self.door_open_interval = door_open_interval
		self.door_open = False
		self.door_open_step_count = 0

		# Agent positions as [row, column] and one-hot lookup table.
		self.state_n = [np.array([0, 0]) for _ in range(self.n_agent)]
		self.eye = np.eye(self.size)

		self.t_step = 0
		self.battles_game = 1
		self.battles_won = 0
		self.episode_steps = 0
		# Keep the advertised limit in sync with the limit that
		# ``is_terminated`` actually enforces.
		self.episode_limit = episode_length
		self.continuing_episode = True

		# Used by OpenAI baselines
		self.action_space = gym.spaces.Discrete(self.n_action)
		self.observation_space = gym.spaces.Box(
			low=-1, high=1, shape=[self.get_obs_size()])

		if self.debug:
			logging.debug("The Pass environment is available.")

	def _set_door(self, is_open):
		"""Write the door segment of the wall: 0 when open, -1 when closed."""
		door_rows = slice(int(self.size * 0.45), int(self.size * 0.55))
		self.map[door_rows, self.size // 2] = 0 if is_open else -1

	def _count_passed(self):
		"""Return how many agents are strictly to the right of the wall."""
		wall_column = self.size // 2
		return sum(1 for state in self.state_n if state[1] > wall_column)

	def _landmark_triggered(self):
		"""Return True if an agent currently stands on a landmark.

		With ``simple_env`` enabled, landmark ``k`` only counts when it is
		occupied by agent ``k``.
		"""
		for landmark_id, landmark in enumerate(self.landmarks):
			for agent_id, state in enumerate(self.state_n):
				if not (landmark == state).all():
					continue
				if self.simple_env and landmark_id != agent_id:
					continue
				return True
		return False

	def get_reward(self):
		"""Return 1000 if at least two agents passed the wall, else 0."""
		return 1000 if self._count_passed() >= 2 else 0

	def is_terminated(self):
		"""Return True when the task is solved or the step budget is spent."""
		return self._count_passed() >= 2 or self.t_step >= self.episode_length

	def step(self, action_n, obs_d=False):
		"""Advance the environment by one step.

		Args:
			action_n: one action per agent (see the class docstring for the
				meaning of 0..3).  Any other value leaves the agent in place.
			obs_d: unused.

		Returns:
			A ``(reward, terminated, info)`` tuple.
		"""
		self.t_step += 1

		# Agents move one after another; there is no collision handling.
		for i, action in enumerate(action_n):
			row, column = self.state_n[i]

			if action == 0:
				new_row, new_column = max(row - 1, 0), column
			elif action == 1:
				new_row, new_column = row, min(column + 1, self.size - 1)
			elif action == 2:
				new_row, new_column = min(row + 1, self.size - 1), column
			elif action == 3:
				new_row, new_column = row, max(column - 1, 0)
			else:
				# Unknown action: stay put.  Without this guard the -1
				# placeholders would index the last grid cell and move the
				# agent to [-1, -1].
				continue

			# Walls (and the closed door) block the move.
			if self.map[new_row][new_column] != -1:
				self.state_n[i] = np.array([new_row, new_column])

		# Close the door once it has been open long enough ...
		if self.door_open:
			if self.door_open_step_count >= self.door_open_interval:
				self.door_open = False
				self._set_door(False)
				self.door_open_step_count = 0
			else:
				self.door_open_step_count += 1

		# ... and (re-)open it if a landmark is occupied right now.
		if not self.door_open and self._landmark_triggered():
			self.door_open = True
			self._set_door(True)
			self.door_open_step_count = 0

		reward = self.get_reward()
		terminated = self.is_terminated()

		info = {
			'door': self.door_open,
			'battle_won': False,
			'state': copy.deepcopy(self.state_n),
			'win': (reward > 0)
		}

		return reward, terminated, info

	def get_obs(self):
		"""Returns all agent observations in a list."""
		return [self.get_obs_agent(i) for i in range(self.n_agent)]

	def get_obs_agent(self, agent_id):
		"""Returns observation for agent_id.

		The observation is the same for every agent: the one-hot encoded row
		and column of agent 0 followed by those of agent 1 (length
		``4 * size``).
		"""
		first, second = self.state_n[0], self.state_n[1]
		# np.concatenate already allocates a new array; the copy is kept so
		# callers never receive a view into ``self.eye``.
		return np.concatenate([
			self.eye[first[0]], self.eye[first[1]],
			self.eye[second[0]], self.eye[second[1]],
		]).copy()

	def get_obs_size(self):
		"""Returns the size of the observation."""
		return self.size * 4

	def get_state(self):
		"""Returns the global state (a constant placeholder array ``[0]``)."""
		return np.array([0])

	def get_state_size(self):
		"""Returns the size of the global state."""
		return 1

	def get_avail_actions(self):
		"""Returns the available actions of all agents in a list."""
		return [
			self.get_avail_agent_actions(agent_id)
			for agent_id in range(self.n_agent)]

	def get_avail_agent_actions(self, agent_id):
		"""Returns the available actions for agent_id (all are available)."""
		return [1] * self.n_action

	def get_total_actions(self):
		"""Returns the total number of actions an agent could ever take."""
		return self.n_action

	def reset(self):
		"""Start a new episode.

		Resets the step counter, closes the door and places the agents at
		their start positions.

		Returns:
			A ``(observations, state)`` tuple.
		"""
		self.t_step = 0
		self.door_open = False
		self.door_open_step_count = 0

		if self.fix_start is True:
			self.state_n = [np.array([0, 0]) for _ in range(self.n_agent)]
		else:
			# Random start left of the wall.  Build fresh arrays instead of
			# mutating the old ones; the column is drawn before the row to
			# keep the RNG call order stable.
			new_states = []
			for _ in range(self.n_agent):
				column = np.random.randint(self.size // 2)
				row = np.random.randint(self.size)
				new_states.append(np.array([row, column]))
			self.state_n = new_states

		self._set_door(False)

		return self.get_obs(), self.get_state()

	def render(self):
		"""Not implemented."""
		raise NotImplementedError

	def close(self):
		"""Close Pass environment (simply resets it)."""
		self.reset()

	def seed(self):
		"""Not implemented.

		NOTE: on instances this method is shadowed by the integer ``seed``
		attribute assigned in ``__init__``.
		"""
		raise NotImplementedError

	def save_replay(self):
		"""Save a replay (not implemented)."""
		raise NotImplementedError

	def get_env_info(self):
		"""Return the shapes and limits a learner needs to build its buffers."""
		return {
			"state_shape": self.get_state_size(),
			"obs_shape": self.get_obs_size(),
			"n_actions": self.get_total_actions(),
			"n_agents": self.n_agent,
			"episode_limit": self.episode_limit,
		}

	def get_stats(self):
		"""Return the win statistics tracked in ``battles_won``/``battles_game``."""
		return {
			"battles_won": self.battles_won,
			"battles_game": self.battles_game,
			"win_rate": self.battles_won / self.battles_game,
		}
