from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging


class Listener1Env(MultiAgentEnv):
	"""Synthetic multi-agent "listener" task behind the MultiAgentEnv API.

	At every step each agent ``i`` has a target integer ``action_now[i]`` and
	scores when its chosen action equals ``n_actions - 1 - action_now[i]``.
	Agent ``i`` does not observe its own target: it observes the target held
	in slot ``perm_now[i]`` together with a possibly off-by-one copy of that
	slot index, where ``perm_now`` is a random permutation of the agents that
	is redrawn on every step and reset.

	No StarCraft II process is started by this class.  Most constructor
	arguments are either ignored or only stored on the instance; they appear
	to be kept for signature compatibility with the SMAC StarCraft II
	environment (NOTE(review): confirm against the callers).
	"""

	def __init__(
			self,
			map_name="8m",
			step_mul=None,
			move_amount=2,
			difficulty="7",
			game_version=None,
			seed=None,
			continuing_episode=False,
			obs_all_health=True,
			obs_own_health=True,
			obs_last_action=False,
			obs_pathing_grid=False,
			obs_terrain_height=False,
			obs_instead_of_state=False,
			obs_timestep_number=False,
			state_last_action=True,
			state_timestep_number=False,
			reward_sparse=False,
			reward_only_positive=True,
			reward_death_value=10,
			reward_win=300,
			reward_defeat=0,
			reward_negative_scale=0.5,
			reward_scale=True,
			reward_scale_rate=20,
			replay_dir="",
			replay_prefix="",
			window_size_x=1920,
			window_size_y=1200,
			debug=False,
			n_agents=10,
			n_actions=10,
			episode_limit=50,
			grid_size=20,
			sight_range=1,
			random_start=True,
			noise1=False,
			noise2=False,
			number_fake_landmarks=None
	):
		"""Store the configuration and draw the first permutation and targets.

		Arguments that influence this class's behaviour:
			n_agents: number of agents (and of targets).
			n_actions: number of discrete actions; targets are drawn from
				``[0, n_actions)``.
			episode_limit: number of steps after which an episode terminates.
			continuing_episode: when true, ``step`` sets
				``info['episode_limit']`` on the terminating step.

		``map_name``, ``step_mul``, ``move_amount``, ``difficulty``,
		``game_version``, ``obs_all_health``, ``obs_own_health``,
		``obs_timestep_number``, ``state_timestep_number``, ``grid_size``,
		``random_start``, ``noise1``, ``noise2`` and ``number_fake_landmarks``
		are accepted but ignored.  The remaining arguments are stored on the
		instance and not read anywhere else in this class.
		"""
		# Map arguments
		self.n_agents = n_agents

		# Observations and state
		self.continuing_episode = continuing_episode
		self.obs_instead_of_state = obs_instead_of_state
		self.obs_last_action = obs_last_action
		self.obs_pathing_grid = obs_pathing_grid
		self.obs_terrain_height = obs_terrain_height
		self.state_last_action = state_last_action

		# Rewards args
		self.reward_sparse = reward_sparse
		self.reward_only_positive = reward_only_positive
		self.reward_negative_scale = reward_negative_scale
		self.reward_death_value = reward_death_value
		self.reward_win = reward_win
		self.reward_defeat = reward_defeat
		self.reward_scale = reward_scale
		self.reward_scale_rate = reward_scale_rate

		# Other
		self._seed = seed
		self.debug = debug
		self.window_size = (window_size_x, window_size_y)
		self.replay_dir = replay_dir
		self.replay_prefix = replay_prefix

		# Actions
		self.n_actions = n_actions

		# Bookkeeping attributes; only the ones touched in step()/reset()
		# ever change after construction.
		self._episode_count = 0
		self._episode_steps = 0
		self._total_steps = 0
		self.last_stats = None
		self.previous_ally_units = None
		self.previous_enemy_units = None
		self.last_action = np.zeros((self.n_agents, self.n_actions))
		self._min_unit_type = 0
		self.max_distance_x = 0
		self.max_distance_y = 0

		self.battles_game = 0
		self.battles_won = 0
		self.episode_limit = episode_limit
		self.sight_range = sight_range
		self.timeouts = 0
		self.t_step = 0

		# Permutation and targets for the first step.
		self._resample()

	def _resample(self):
		"""Draw a new agent permutation and a new target for every agent."""
		# The permutation is drawn first and the targets are drawn one scalar
		# at a time, in the same order as the original inline code.
		self.perm_now = np.random.permutation(self.n_agents)
		self.action_now = np.array(
			[np.random.randint(self.n_actions) for _ in range(self.n_agents)]
		)

	def step(self, actions):
		"""Score the joint action, resample the task and advance the clock.

		Args:
			actions: one action per agent; each entry is converted with
				``int()``.

		Returns:
			A ``(reward, terminated, info)`` tuple.  ``reward`` is the
			percentage (0-100) of agents ``i`` whose action satisfies
			``action_now[i] + action + 1 == n_actions``.  ``terminated`` is
			true once ``episode_limit`` steps have been taken since the last
			reset.  ``info`` contains ``'episode_limit': True`` on the
			terminating step when ``continuing_episode`` is set, and is empty
			otherwise.
		"""
		self.t_step += 1
		info = {}

		actions = [int(a) for a in actions]

		# Agent i is matched against its own slot i, not the permuted slot it
		# observed.
		hits = sum(
			int(self.action_now[i]) + actions[i] + 1 == self.n_actions
			for i in range(self.n_agents)
		)
		reward = 100. * hits / self.n_agents

		self._resample()

		terminated = self.t_step >= self.episode_limit
		if terminated:
			# The time limit is the only way an episode ends here.
			if self.continuing_episode:
				info['episode_limit'] = True
			self.timeouts += 1
			self.battles_game += 1

		return reward, terminated, info

	def get_index(self, index):
		"""Return ``index``, randomly shifted by at most one position.

		A first draw moves the index down by one with probability 1/3 (only
		when ``index > 0``).  If that branch is not taken, a second draw moves
		it up by one with probability 1/3 (only when the result stays below
		``n_agents``).  Otherwise the index is returned unchanged.
		"""
		if np.random.randint(3) == 0 and index > 0:
			index -= 1
		elif np.random.randint(3) == 0 and index + 1 < self.n_agents:
			index += 1
		return index

	def get_obs(self):
		"""Returns all agent observations in a list."""
		return [self.get_obs_agent(i, self.perm_now[i]) for i in range(self.n_agents)]

	def get_obs_agent(self, agent_id, perm_i=None):
		"""Returns observation for agent_id.

		Args:
			agent_id: index of the observing agent.
			perm_i: slot whose target the agent observes.  Defaults to
				``perm_now[agent_id]`` so the method can also be called with
				``agent_id`` alone.

		Returns:
			A length-2 array ``[noisy_slot_index, action_now[perm_i]]``, where
			the first entry comes from ``get_index(perm_i)``.
		"""
		if perm_i is None:
			perm_i = self.perm_now[agent_id]
		index = self.get_index(perm_i)
		obs = [index, self.action_now[perm_i]]
		return np.array(obs)

	def get_obs_size(self):
		"""Returns the size of the observation."""
		return 2

	def get_state(self):
		"""Returns the global state: the current target of every agent."""
		return self.action_now

	def get_state_size(self):
		"""Returns the size of the global state."""
		return self.n_agents

	def get_avail_actions(self):
		"""Returns the available actions of all agents in a list."""
		return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

	def get_avail_agent_actions(self, agent_id):
		"""Returns the available actions for agent_id (all of them)."""
		return [1] * self.n_actions

	def get_total_actions(self):
		"""Returns the total number of actions an agent could ever take."""
		return self.n_actions

	def reset(self):
		"""Resample the task, zero the step counter and return (obs, state)."""
		self._resample()

		self.t_step = 0
		return self.get_obs(), self.get_state()

	def render(self):
		"""No-op."""
		pass

	def close(self):
		"""No-op."""
		pass

	def seed(self):
		"""No-op; the constructor's ``seed`` is stored in ``_seed`` but is
		never applied to NumPy's random generator by this class.
		"""
		pass

	def save_replay(self):
		"""No-op."""
		pass

	def get_env_info(self):
		"""Returns the state/observation sizes, action and agent counts and
		the episode limit as a dict.
		"""
		env_info = {
			"state_shape": self.get_state_size(),
			"obs_shape": self.get_obs_size(),
			"n_actions": self.get_total_actions(),
			"n_agents": self.n_agents,
			"episode_limit": self.episode_limit,
		}
		return env_info

	def get_stats(self):
		"""Returns the episode counters as a dict.

		``win_rate`` is reported as 0.0 until at least one episode has
		finished, instead of dividing by zero.  ``battles_won`` is never
		incremented by this class.
		"""
		if self.battles_game:
			win_rate = self.battles_won / self.battles_game
		else:
			win_rate = 0.0
		stats = {
			"battles_won": self.battles_won,
			"battles_game": self.battles_game,
			"win_rate": win_rate,
			"timeouts": self.timeouts,
			"restarts": 0
		}
		return stats