from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import time
import enum
import math
from absl import logging
import random


class Trivial1Env(MultiAgentEnv):
	"""A toy multi-agent grid world exposing the ``MultiAgentEnv`` API.

	Each of ``n_agents`` agents lives on a ``grid_size`` x ``grid_size`` grid
	and is paired with its own landmark. Agents and landmarks are placed
	uniformly at random on construction and on every :meth:`reset`. The
	shared team reward at each step is the number of agents currently standing
	on their own landmark. Episodes end after ``episode_limit`` steps; this
	class never records a win, so ``battles_won`` stays at 0.
	"""

	# Action id -> (dx, dy) displacement. Any other action id (e.g. 4 when
	# n_actions == 5) leaves the agent where it is.
	_MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

	def __init__(
			self,
			map_name="8m",
			step_mul=None,
			move_amount=2,
			difficulty="7",
			game_version=None,
			seed=None,
			continuing_episode=False,
			obs_all_health=True,
			obs_own_health=True,
			obs_last_action=False,
			obs_pathing_grid=False,
			obs_terrain_height=False,
			obs_instead_of_state=False,
			obs_timestep_number=False,
			state_last_action=True,
			state_timestep_number=False,
			reward_sparse=False,
			reward_only_positive=True,
			reward_death_value=10,
			reward_win=300,
			reward_defeat=0,
			reward_negative_scale=0.5,
			reward_scale=True,
			reward_scale_rate=20,
			replay_dir="",
			replay_prefix="",
			window_size_x=1920,
			window_size_y=1200,
			debug=False,
			n_agents=10,
			n_actions=5,
			episode_limit=50,
			grid_size=20,
			prob=0.,
			random_start=True,
			full_obs=False,
			use_one_hot=False,
			is_print=False,
			explicit_credit_assignment=False,
			print_rew=False,
			print_steps=1000,
			noise1=False,
			noise2=False,
			number_fake_landmarks=None
	):
		"""Store configuration and place agents and landmarks at random.

		Only ``n_agents``, ``n_actions``, ``episode_limit``, ``grid_size``,
		``prob`` and ``continuing_episode`` influence the dynamics implemented
		in this class. The remaining arguments are stored or ignored.
		NOTE(review): they presumably exist so that configs written for SMAC's
		StarCraft2Env can be passed through unchanged; confirm with callers.

		Args:
			n_agents: number of agents (and of landmarks).
			n_actions: size of each agent's action space. Actions 0-3 move
				the agent by -x, +y, +x, -y respectively; other ids are no-ops.
			episode_limit: number of steps after which ``step`` terminates.
			grid_size: side length of the square grid.
			prob: probability of replacing an agent's chosen action with a
				uniformly random one.
			continuing_episode: if True, ``step`` sets
				``info['episode_limit'] = True`` on time-out.
		"""
		# Map arguments
		self.n_agents = n_agents

		# Observations and state
		self.continuing_episode = continuing_episode
		self.obs_instead_of_state = obs_instead_of_state
		self.obs_last_action = obs_last_action
		self.obs_pathing_grid = obs_pathing_grid
		self.obs_terrain_height = obs_terrain_height
		self.state_last_action = state_last_action

		# Rewards args (stored only; not read by this class)
		self.reward_sparse = reward_sparse
		self.reward_only_positive = reward_only_positive
		self.reward_negative_scale = reward_negative_scale
		self.reward_death_value = reward_death_value
		self.reward_win = reward_win
		self.reward_defeat = reward_defeat
		self.reward_scale = reward_scale
		self.reward_scale_rate = reward_scale_rate

		# Other
		self._seed = seed
		self.debug = debug
		self.window_size = (window_size_x, window_size_y)
		self.replay_dir = replay_dir
		self.replay_prefix = replay_prefix

		# Actions
		self.n_actions = n_actions

		# Episode / statistics counters
		self._episode_count = 0
		self._episode_steps = 0
		self._total_steps = 0
		self.last_stats = None
		self.previous_ally_units = None
		self.previous_enemy_units = None
		self.last_action = np.zeros((self.n_agents, self.n_actions))
		self._min_unit_type = 0
		self.max_distance_x = 0
		self.max_distance_y = 0

		self.battles_game = 0
		self.battles_won = 0
		self.episode_limit = episode_limit
		self.random_start = random_start
		self.timeouts = 0

		self.full_obs = full_obs
		self.prob = prob
		self.grid_size = grid_size
		# Row k of eye_grid is the one-hot encoding of coordinate k.
		self.eye_grid = np.eye(grid_size)
		# Row a of eye_actions is the one-hot encoding of action a.
		self.eye_actions = np.eye(n_actions)
		self.use_one_hot = use_one_hot
		self.is_print = is_print
		self.explicit_credit_assignment = explicit_credit_assignment

		# Initial positions: agents first, then landmarks (same RNG order
		# as reset()).
		self.agents = self._sample_positions()
		self.landmarks = self._sample_positions()

		self.t_step = 0

		self.print_rew = print_rew
		self.rew_gather = []
		self.print_steps = print_steps
		self.p_step = 0

	def _sample_positions(self):
		"""Return ``n_agents`` uniform random grid cells as (x, y) int arrays."""
		return [np.random.randint(self.grid_size, size=(2))
		        for _ in range(self.n_agents)]

	def step(self, actions):
		"""Advance the environment by one step.

		Args:
			actions: sequence of ``n_agents`` action ids (anything ``int()``
				accepts). With probability ``self.prob`` each agent's action
				is replaced by a uniformly random one before moving.

		Returns:
			(reward, terminated, info): ``reward`` is the number of agents
			standing on their own landmark after moving (a plain ``int``);
			``terminated`` is True once ``episode_limit`` steps have elapsed;
			``info`` always holds ``'battle_won': False``.
		"""
		self.t_step += 1
		info = {'battle_won': False}

		actions = [int(a) for a in actions]
		# One-hot of the *requested* actions (before noise is applied).
		self.last_action = self.eye_actions[np.array(actions)]

		# Move agents, clamping to the grid.
		upper = self.grid_size - 1
		for i in range(self.n_agents):
			if random.random() < self.prob:
				action = random.randint(0, self.n_actions - 1)
			else:
				action = actions[i]
			delta = self._MOVES.get(action, (0, 0))
			# np.clip returns a new array, so previous positions are not mutated.
			self.agents[i] = np.clip(self.agents[i] + delta, 0, upper)

		# Shared reward: how many agents sit exactly on their own landmark.
		reward = int(sum((pos == goal).all()
		                 for pos, goal in zip(self.agents, self.landmarks)))

		# Terminate on time-out only; there is no win condition.
		terminated = self.t_step >= self.episode_limit
		if terminated:
			if self.continuing_episode:
				info['episode_limit'] = True
			self.timeouts += 1
			self.battles_game += 1

		return reward, terminated, info

	def get_obs(self):
		"""Return a list with one observation per agent, in agent order."""
		return [self.get_obs_agent(i) for i in range(self.n_agents)]

	def get_obs_agent(self, agent_id):
		"""Return agent ``agent_id``'s observation.

		The observation is four concatenated one-hot vectors of length
		``grid_size``: agent x, agent y, landmark x, landmark y.
		"""
		agent_x, agent_y = self.agents[agent_id]
		goal_x, goal_y = self.landmarks[agent_id]
		return np.concatenate([self.eye_grid[agent_x],
		                       self.eye_grid[agent_y],
		                       self.eye_grid[goal_x],
		                       self.eye_grid[goal_y]], axis=0)

	def get_obs_size(self):
		"""Return the length of a single agent observation."""
		return self.grid_size * 4

	def get_state(self):
		"""Return the global state: all agent observations concatenated in order."""
		return np.concatenate(self.get_obs(), axis=0)

	def get_state_size(self):
		"""Return the length of the global state vector."""
		return self.grid_size * 4 * self.n_agents

	def get_avail_actions(self):
		"""Return the available-action masks of all agents in a list."""
		return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

	def get_avail_agent_actions(self, agent_id):
		"""Return agent ``agent_id``'s action mask; every action is always allowed."""
		return [1] * self.n_actions

	def get_total_actions(self):
		"""Return the total number of actions an agent could ever take."""
		return self.n_actions

	def reset(self):
		"""Re-sample agent and landmark positions and return initial observations."""
		# Agents first, then landmarks (same RNG order as __init__).
		self.agents = self._sample_positions()
		self.landmarks = self._sample_positions()

		self.t_step = 0
		self.last_action = np.zeros(shape=(self.n_agents, self.n_actions))

		return self.get_obs()

	def render(self):
		"""No-op: this environment has no renderer."""
		pass

	def close(self):
		"""No-op: there are no external resources to release."""
		pass

	def seed(self):
		"""No-op: positions use the global ``np.random``/``random`` state."""
		pass

	def save_replay(self):
		"""No-op: replays are not supported."""
		pass

	def get_env_info(self):
		"""Return the shape/size summary consumed by the training framework."""
		env_info = {"state_shape": self.get_state_size(),
		            "obs_shape": self.get_obs_size(),
		            "n_actions": self.get_total_actions(),
		            "n_agents": self.n_agents,
		            "episode_limit": self.episode_limit}
		return env_info

	def get_stats(self):
		"""Return episode statistics.

		``win_rate`` is 0.0 until at least one episode has terminated, instead
		of raising ``ZeroDivisionError``.
		"""
		if self.battles_game:
			win_rate = float(self.battles_won) / self.battles_game
		else:
			win_rate = 0.0
		stats = {
			"battles_won": self.battles_won,
			"battles_game": self.battles_game,
			"battles_draw": self.timeouts,
			"win_rate": win_rate,
			"timeouts": self.timeouts,
			"restarts": 0
		}
		return stats
