from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import time
import enum
import math
from absl import logging


class K_coloring1Env(MultiAgentEnv):
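	"""A grid-world graph-colouring environment for decentralised multi-agent
	learning: agents wander randomly on a square grid and each one picks a
	"colour" (action) that should differ from every agent within its sight range.
	"""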

	def __init__(
			self,
			map_name="8m",
			step_mul=None,
			move_amount=2,
			difficulty="7",
			game_version=None,
			seed=None,
			continuing_episode=False,
			obs_all_health=True,
			obs_own_health=True,
			obs_last_action=False,
			obs_pathing_grid=False,
			obs_terrain_height=False,
			obs_instead_of_state=False,
			obs_timestep_number=False,
			state_last_action=True,
			state_timestep_number=False,
			reward_sparse=False,
			reward_only_positive=True,
			reward_death_value=10,
			reward_win=300,
			reward_defeat=0,
			reward_negative_scale=0.5,
			reward_scale=True,
			reward_scale_rate=20,
			replay_dir="",
			replay_prefix="",
			window_size_x=1920,
			window_size_y=1200,
			debug=False,
			n_agents=10,
			n_actions=10,
			episode_limit=50,
			grid_size=20,
			sight_range=3,
			random_start=True,
			full_obs=False,
			use_one_hot=False,
			is_print=False,
			explicit_credit_assignment=False,
			print_rew=False,
			print_steps=1000,
			noise1=False,
			noise2=False,
			number_fake_landmarks=None
	):
		# Map arguments
		self.n_agents = n_agents

		# Observations and state
		self.continuing_episode = continuing_episode
		self.obs_instead_of_state = obs_instead_of_state
		self.obs_last_action = obs_last_action
		self.obs_pathing_grid = obs_pathing_grid
		self.obs_terrain_height = obs_terrain_height
		self.state_last_action = state_last_action

		# Rewards args
		self.reward_sparse = reward_sparse
		self.reward_only_positive = reward_only_positive
		self.reward_negative_scale = reward_negative_scale
		self.reward_death_value = reward_death_value
		self.reward_win = reward_win
		self.reward_defeat = reward_defeat
		self.reward_scale = reward_scale
		self.reward_scale_rate = reward_scale_rate

		# Other
		self._seed = seed
		self.debug = debug
		self.window_size = (window_size_x, window_size_y)
		self.replay_dir = replay_dir
		self.replay_prefix = replay_prefix

		# Actions
		self.n_actions = n_actions

		# Map info
		# self._agent_race = map_params["a_race"]
		# self._bot_race = map_params["b_race"]
		# self.map_type = map_params["map_type"]

		self._episode_count = 0
		self._episode_steps = 0
		self._total_steps = 0
		self.last_stats = None
		self.previous_ally_units = None
		self.previous_enemy_units = None
		self.last_action = np.zeros((self.n_agents, self.n_actions))
		self._min_unit_type = 0
		self.max_distance_x = 0
		self.max_distance_y = 0

		self.battles_game = 0
		self.battles_won = 0
		self.episode_limit = episode_limit
		self.random_start = random_start
		self.timeouts = 0

		self.full_obs = full_obs
		self.sight_range = sight_range
		self.grid_size = grid_size
		self.eye_grid = np.eye(grid_size)
		self.eye_actions = np.eye(n_actions)
		self.use_one_hot = use_one_hot
		self.is_print = is_print
		self.explicit_credit_assignment = explicit_credit_assignment

		# Try to avoid leaking SC2 processes on shutdown
		# atexit.register(lambda: self.close())

		# Create Grid
		self.grid = np.zeros((self.grid_size + self.sight_range * 2, self.grid_size + self.sight_range * 2))
		self.grid[:, np.arange(self.sight_range)] = -1
		self.grid[np.arange(self.sight_range), :] = -1
		self.grid[:, np.arange(self.sight_range) - self.sight_range] = -1
		self.grid[np.arange(self.sight_range) - self.sight_range, :] = -1

		# Initial positions
		if self.random_start:
			# Agents
			self.agents = [np.random.randint(self.grid_size, size=(2)) + self.sight_range
			               for _ in range(self.n_agents)]
			for a in self.agents:
				self.grid[tuple(a)] = 1
		else:
			assert 1

		self.t_step = 0

		self.print_rew = print_rew
		self.rew_gather = []
		self.print_steps = print_steps
		self.p_step = 0

	def draw(self):
		self.draw_grid = self.grid.copy()
		for i, a in enumerate(self.agents):
			self.draw_grid[tuple(a)] = -i - 2
		for i, agent in enumerate(self.agents):
			for x in range(3, 0, -1):
				if self.sight_range >= x:
					region = self.draw_grid[agent[0] - x: agent[0] + x + 1,
					         agent[1] - x: agent[1] + x + 1]
					draw_list = region == 0
					for y in range(x + 2, 5):
						draw_list += region == y
					region[draw_list] = x + 1
		for i, agent_i in enumerate(self.agents):
			for j, agent_j in enumerate(self.agents):
				if i != j and self.check(i, j):
					if i < j:
						print('pair %c, %c' % (self.change(i + 1), self.change(j + 1)))
					self.used_agent[i] += 1

	def _print(self, actions):
		print('------------------')
		dict = {-1: '#', 0: '.', 2: '-', 3: '=', 4: '*'}
		self.used_agent = np.zeros(self.n_agents)
		self.draw()
		for i in range(self.n_agents):
			dict[-i - 2] = self.change(i + 1)
		for i in range(self.grid_size):
			res = ''
			for j in range(self.grid_size):
				res += dict[int(self.draw_grid[i + self.sight_range, j + self.sight_range])]
			print(res)
		for i, used_i in enumerate(self.used_agent):
			if used_i:
				print('agent %c : %d, action : %d' % (self.change(i + 1), used_i, actions[i]))

	def change(self, i):
		char = str(i) if i < 10 else chr(ord('A') + i - 10)
		return char

	def check(self, i, j):
		return abs(self.agents[i][0] - self.agents[j][0]) <= self.sight_range \
		       and abs(self.agents[i][1] - self.agents[j][1]) <= self.sight_range

	def step(self, actions):
		"""Returns reward, terminated, info."""
		self.t_step += 1
		info = {'battle_won': False}

		# print
		if self.is_print:
			self._print(actions)
			time.sleep(5)

		# run step

		actions = [int(a) for a in actions]
		self.last_action = np.eye(self.n_actions)[np.array(actions)]

		# Rewards and Dones
		reward = 0

		for i, agent_i in enumerate(self.agents):
			flag = 1
			for j, agent_j in enumerate(self.agents):
				if i != j and actions[i] == actions[j] and abs(agent_i[0] - agent_j[0]) <= self.sight_range \
						and abs(agent_i[1] - agent_j[1]) <= self.sight_range:
					flag = 0
					break
			if flag:
				reward += 1

		useful_reward = reward
		base_rew = 0
		for i, agent_i in enumerate(self.agents):
			flag = 1
			for j, agent_j in enumerate(self.agents):
				if i != j and abs(agent_i[0] - agent_j[0]) <= self.sight_range \
						and abs(agent_i[1] - agent_j[1]) <= self.sight_range:
					flag = 0
					break
			if flag:
				useful_reward -= 1
				base_rew += 1

		reward *= 1
		useful_reward *= 1
		base_rew *= 1
		info['reward'] = reward * self.episode_limit
		info['useful_reward'] = useful_reward * self.episode_limit
		info['base_rew'] = base_rew * self.episode_limit

		# Move agents:
		agents = self.agents.copy()
		for i in range(self.n_agents):
			new_x = self.agents[i][0]
			new_y = self.agents[i][1]

			action = np.random.randint(5)

			if action == 0:
				new_x = max(new_x - 1, self.sight_range)
			elif action == 1:
				new_y = min(new_y + 1, self.sight_range + self.grid_size - 1)
			elif action == 2:
				new_x = min(new_x + 1, self.sight_range + self.grid_size - 1)
			elif action == 3:
				new_y = max(new_y - 1, self.sight_range)

			self.agents[i] = np.array([new_x, new_y])

		for m_agent in agents:
			self.grid[tuple(m_agent)] = 0.
		for m_agent in self.agents:
			self.grid[tuple(m_agent)] = 1.

		# Terminate
		terminated = False

		if self.t_step >= self.episode_limit:
			terminated = True
			# print('Lose')
			if self.continuing_episode:
				info['episode_limit'] = True
			self.timeouts += 1

		if terminated:
			self.battles_game += 1

		if self.print_rew:
			self.p_step += 1
			self.rew_gather.append(reward)
			if self.p_step % self.print_steps == 0:
				print('steps: %d, average rew: %.3lf' % (self.p_step,
				                                         float(np.mean(self.rew_gather)) * self.episode_limit))

		return reward, terminated, info

	def get_obs(self):
		"""Returns all agent observations in a list."""
		return [self.get_obs_agent(i) for i in range(self.n_agents)]

	def get_obs_agent(self, agent_id):
		"""Returns observation for agent_id."""

		if self.full_obs:
			if self.use_one_hot:
				return np.concatenate([np.concatenate([self.eye_grid[state[0] - self.sight_range],
				                                       self.eye_grid[state[1] - self.sight_range]], axis=0)
				                       for state in self.agents], axis=0)
			else:
				return np.array(self.agents).reshape(-1)
		else:
			x = self.agents[agent_id][0]
			y = self.agents[agent_id][1]

			obs = []
			around = np.zeros((self.sight_range * 2 + 1, self.sight_range * 2 + 1))
			for i, agent_i in enumerate(self.agents):
				obs_i = around.copy()
				if self.check(agent_id, i):
					obs_i[agent_i[0] - x + self.sight_range][agent_i[1] - y + self.sight_range] = 1
				obs.append(obs_i.reshape(-1))
			obs.append(np.concatenate([self.eye_grid[x - self.sight_range],
			                           self.eye_grid[y - self.sight_range]], axis=0))
			obs = np.concatenate(obs, axis=0)

			return obs

	def get_obs_size(self):
		"""Returns the size of the observation."""
		if self.full_obs:
			if self.use_one_hot:
				return 2 * self.n_agents * self.grid_size
			else:
				return 2 * self.n_agents
		else:
			return (self.sight_range * 2 + 1) ** 2 * self.n_agents + self.grid_size * 2

	def get_state(self):
		"""Returns the global state."""
		graph = []
		for i, agent_i in enumerate(self.agents):
			for j, agent_j in enumerate(self.agents):
				if i != j and self.check(i, j):
					graph.append(1)
				else:
					graph.append(0)
		graph = np.array(graph)
		return graph

	def get_state_size(self):
		"""Returns the size of the global state."""
		return self.n_agents * self.n_agents

	def get_avail_actions(self):
		"""Returns the available actions of all agents in a list."""
		return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

	def get_avail_agent_actions(self, agent_id):
		"""Returns the available actions for agent_id."""
		return [1] * self.n_actions

	def get_total_actions(self):
		"""Returns the total number of actions an agent could ever take."""
		return self.n_actions

	def reset(self):
		"""Returns initial observations and states."""
		self.grid_size = self.grid_size
		self.grid = np.zeros((self.grid_size + self.sight_range * 2, self.grid_size + self.sight_range * 2))
		self.grid[:, np.arange(self.sight_range)] = -1
		self.grid[np.arange(self.sight_range), :] = -1
		self.grid[:, np.arange(self.sight_range) - self.sight_range] = -1
		self.grid[np.arange(self.sight_range) - self.sight_range, :] = -1

		# Initial positions
		if self.random_start:
			# Agents
			self.agents = [np.random.randint(self.grid_size, size=(2)) + self.sight_range
			               for _ in range(self.n_agents)]
			for a in self.agents:
				self.grid[tuple(a)] = 1
		else:
			assert 1

		# Others
		self.t_step = 0
		self.last_action = np.zeros(shape=(self.n_agents, self.n_actions))

		return self.get_obs()

	def render(self):
		pass

	def close(self):
		pass

	def seed(self):
		pass

	def save_replay(self):
		"""Save a replay."""
		pass

	def get_env_info(self):
		env_info = {"state_shape": self.get_state_size(),
		            "obs_shape": self.get_obs_size(),
		            "n_actions": self.get_total_actions(),
		            "n_agents": self.n_agents,
		            "episode_limit": self.episode_limit}
		return env_info

	def get_stats(self):
		stats = {
			"battles_won": self.battles_won,
			"battles_game": self.battles_game,
			"battles_draw": self.timeouts,
			"win_rate": float(self.battles_won) / self.battles_game,
			"timeouts": self.timeouts,
			"restarts": 0
		}
		return stats
