import gym
import numpy as np
import random

from absl import logging

from smac.env.multiagentenv import MultiAgentEnv


class IslandEnv(MultiAgentEnv):
	"""Grid-world "Island" multi-agent environment.

	``n_agent`` agents start at cell (0, 0) of a ``size`` x ``size`` grid
	while wolves take random walks over it. An agent can attack a wolf
	whose attack square it stands in (action 5). Every alive wolf in turn
	damages the agents standing in its attack square. The team is rewarded
	300 for every wolf killed and 10 for every landmark first reached by an
	alive agent. An episode ends after ``episode_length`` steps or as soon
	as any agent's power reaches zero.

	Agent actions: 0 = row - 1, 1 = column + 1, 2 = row + 1,
	3 = column - 1, 4 = stay, 5 = attack (also stays in place). Any other
	action is treated as "stay". Moves are clamped to the grid.
	"""

	def __init__(
			self,
			agent_num=2,
			wolf_num=1,
			action_num=6,
			landmark_num=9,
			dim_num=2,
			size=10,
			attack_range=1,
			episode_length=300,
			partial_obs=True,
			wolf_recover_time=1000,
			agent_max_power=11,
			wolf_max_power=9,
			debug=False,
			state_last_action=True,
			seed=None):
		"""Create an IslandEnv environment.

		Args:
			agent_num: number of controllable agents.
			wolf_num: number of wolves.
			action_num: size of each agent's discrete action space.
			landmark_num: number of landmarks; must be 2, 3, 4 or 9.
			dim_num: number of coordinates drawn when spawning a wolf.
			size: side length of the square grid.
			attack_range: half-width of the square in which agents and
				wolves can hit each other.
			episode_length: step budget of an episode; also reported as
				``episode_limit`` by ``get_env_info``.
			partial_obs: if False, observations also include wolf liveness.
			wolf_recover_time: alive wolves regain one power point (capped at
				``wolf_max_power - 1``) whenever the step counter is a
				multiple of this value.
			agent_max_power: agents start with ``agent_max_power - 1`` power.
			wolf_max_power: wolves start with ``wolf_max_power - 1`` power.
			debug: log construction progress at debug level.
			state_last_action: accepted for interface compatibility; unused.
			seed: seed for NumPy's global RNG; a random one is drawn when
				None.

		Raises:
			ValueError: if ``landmark_num`` has no predefined layout.
		"""
		self.debug = debug

		# Honour an explicit seed; otherwise draw one and keep it in
		# self.seed so the run can still be reproduced. NOTE: this instance
		# attribute shadows the seed() method; the name is kept for
		# backward compatibility.
		self.seed = random.randint(0, 9999) if seed is None else seed
		np.random.seed(self.seed)

		if self.debug:
			logging.debug("Creating Island environment...")

		# Game arguments
		self.n_agent = agent_num
		self.n_wolf = wolf_num
		self.n_action = action_num
		self.n_landmark = landmark_num
		self.n_dim = dim_num
		self.size = size
		self.attack_range = attack_range
		self.episode_length = episode_length
		self.is_full_obs = not partial_obs
		self.wolf_recover_time = wolf_recover_time

		self.landmarks = self._build_landmarks()

		# One-hot lookup tables used to encode positions and power levels.
		self.eye = np.eye(self.size)
		self.flag = np.eye(2)
		self.agent_max_power = agent_max_power
		self.agent_power_eye = np.eye(self.agent_max_power)
		self.wolf_max_power = wolf_max_power
		self.wolf_power_eye = np.eye(self.wolf_max_power)

		# Per-episode state: positions, liveness, power and counters.
		self._reset_episode_state()
		self.agent_power_zeros_like = np.zeros_like(self.agent_power)

		# Bookkeeping reported through get_env_info() / get_stats().
		self.battles_game = 1
		self.battles_won = 0
		self.episode_steps = 0
		# Keep the reported horizon in sync with is_terminated().
		self.episode_limit = self.episode_length
		self.continuing_episode = True

		# OpenAI Gym spaces
		self.action_space = gym.spaces.Discrete(self.n_action)
		self.observation_space = gym.spaces.Box(
			low=-1, high=1, shape=(self.get_obs_size(),))

		if self.debug:
			logging.debug("The Island environment is available.")

	def _build_landmarks(self):
		"""Return the integer array of landmark cells, one ``[row, col]`` per landmark.

		Raises:
			ValueError: if ``self.n_landmark`` has no predefined layout.
		"""
		low = int(self.size * 0.2)
		mid = int(self.size // 2)
		high = int(self.size * 0.8)
		layouts = {
			2: [[high, low], [mid, high]],
			3: [[high, low], [mid, high], [mid, mid]],
			4: [[high, low], [mid, high], [high, high], [mid, low]],
			9: [[x, y] for x in (low, mid, high) for y in (low, mid, high)],
		}
		if self.n_landmark not in layouts:
			raise ValueError(
				"Unsupported landmark_num %r; expected one of %s"
				% (self.n_landmark, sorted(layouts)))
		return np.array(layouts[self.n_landmark])

	def _reset_episode_state(self):
		"""(Re)initialise all per-episode state, drawing new wolf positions."""
		self.landmark_visited = [False] * self.n_landmark

		# Agents all start at cell (0, 0) with agent_max_power - 1 power.
		self.state_n = [np.array([0, 0]) for _ in range(self.n_agent)]
		self.agent_alive_n = [True] * self.n_agent
		self.done_n = [False] * self.n_agent
		self.agent_power = np.array(
			[self.agent_max_power - 1 for _ in range(self.n_agent)])

		# Wolves are spawned away from the agents, so agents must be placed first.
		self.wolf_n = [self.random_wolf() for _ in range(self.n_wolf)]
		self.wolf_alive_n = [True] * self.n_wolf
		self.pre_wolf_alive_n = [True] * self.n_wolf
		self.wolf_power = np.array(
			[self.wolf_max_power - 1 for _ in range(self.n_wolf)])
		self.died_wolf_number = 0

		self.t_landmarks = 0
		self.t_kill = 0
		self.t_step = 0
		# agent_score[w][i]: attack credit agent i earned against wolf w.
		self.agent_score = [[0.] * self.n_agent for _ in range(self.n_wolf)]

	@staticmethod
	def _within(pos, center, radius):
		"""Return whether ``pos`` lies in the square of half-width ``radius``
		around ``center`` (first two coordinates only)."""
		return (center[0] - radius <= pos[0] <= center[0] + radius
				and center[1] - radius <= pos[1] <= center[1] + radius)

	def _move(self, pos, direction):
		"""Return the cell reached from ``pos`` by ``direction``, clamped to the grid.

		0 = row - 1, 1 = column + 1, 2 = row + 1, 3 = column - 1; any other
		value (including stay/attack) leaves the position unchanged.
		"""
		row, column = pos[0], pos[1]
		if direction == 0:
			row = max(row - 1, 0)
		elif direction == 1:
			column = min(column + 1, self.size - 1)
		elif direction == 2:
			row = min(row + 1, self.size - 1)
		elif direction == 3:
			column = max(column - 1, 0)
		return np.array([row, column])

	def random_wolf(self):
		"""Return a random wolf position not within one cell of any alive agent."""
		while True:
			wolf = np.random.randint(0, self.size, [self.n_dim])
			too_close = any(
				alive and self._within(state, wolf, 1)
				for state, alive in zip(self.state_n, self.agent_alive_n))
			if not too_close:
				return wolf

	def get_reward(self):
		"""Return ``(reward, info)`` for the current step.

		Newly killed wolves are recorded in ``pre_wolf_alive_n`` and newly
		reached landmarks in ``landmark_visited``. Each is therefore
		rewarded only once: 300 per wolf and 10 per landmark, shared by the
		whole team.
		"""
		num_kill = 0
		for w_i in range(self.n_wolf):
			if self.pre_wolf_alive_n[w_i] and not self.wolf_alive_n[w_i]:
				self.pre_wolf_alive_n[w_i] = False
				num_kill += 1

		landmark_num = 0
		time_length = []
		for agent_index in range(self.n_agent):
			if not self.agent_alive_n[agent_index]:
				time_length.append(0.)
				continue
			time_length.append(1.)
			position = self.state_n[agent_index]
			for i, landmark in enumerate(self.landmarks):
				if not self.landmark_visited[i] and (landmark == position).all():
					self.landmark_visited[i] = True
					landmark_num += 1

		# Every agent gets the same reward, so compute the scalar directly.
		# Unlike indexing a per-agent array, this also works with zero agents.
		reward = np.float64(300. * num_kill + 10. * landmark_num)

		info = {
			'kill': num_kill,
			'landmark': landmark_num,
			# Copy so later in-place updates of done_n do not leak to callers.
			'death': list(self.done_n),
			'time_length': time_length,
		}
		return reward, info

	def is_terminated(self):
		"""Return True once the step budget is spent or any agent has no power left."""
		if self.t_step >= self.episode_length or (self.agent_power == 0).any():
			return True
		return False

	def _resolve_agent_attacks(self, action_n):
		"""Apply attacks (action 5) from alive agents to alive wolves in range.

		Uses positions from before this step's movement. A lone attacker
		deals 1 damage and earns 1 score. A group of ``n > 1`` attackers
		deals ``2 * n`` damage in total and each member earns 2.0.
		"""
		for w_i, wolf in enumerate(self.wolf_n):
			if not self.wolf_alive_n[w_i]:
				continue
			attackers = [
				i for i, action in enumerate(action_n)
				if self.agent_alive_n[i] and action == 5
				and self._within(self.state_n[i], wolf, self.attack_range)]
			n_attackers = len(attackers)
			harm = n_attackers
			score = 1
			if n_attackers > 1:
				harm += n_attackers
				score = 1. * harm / n_attackers
			self.wolf_power[w_i] = max(self.wolf_power[w_i] - harm, 0)
			for i in attackers:
				self.agent_score[w_i][i] += score

	def _move_agents(self, action_n):
		"""Move every alive agent according to its action."""
		for i, action in enumerate(action_n):
			if self.agent_alive_n[i]:
				self.state_n[i] = self._move(self.state_n[i], action)

	def _kill_exhausted_wolves(self):
		"""Mark alive wolves whose power reached zero as dead."""
		for w_i in range(self.n_wolf):
			if self.wolf_alive_n[w_i] and self.wolf_power[w_i] == 0:
				self.wolf_alive_n[w_i] = False
				self.died_wolf_number += 1

	def _resolve_wolf_attacks(self):
		"""Let every alive wolf damage the alive agents inside its attack square.

		Each of the ``n`` agents in range loses ``2 // n`` power.
		NOTE(review): three or more agents in range therefore take no damage
		at all. This is preserved from the original; confirm it is intended.
		"""
		for w_i, wolf in enumerate(self.wolf_n):
			if not self.wolf_alive_n[w_i]:
				continue
			victims = [
				i for i in range(self.n_agent)
				if self.agent_alive_n[i]
				and self._within(self.state_n[i], wolf, self.attack_range)]
			if not victims:
				continue
			damage = 2 // len(victims)
			for i in victims:
				self.agent_power[i] = max(self.agent_power[i] - damage, 0)

	def _update_agent_deaths(self):
		"""Mark agents whose power hit zero as dead.

		``done_n[i]`` is True only on the step in which agent ``i`` dies.
		"""
		for i in range(self.n_agent):
			newly_dead = bool(self.agent_alive_n[i] and self.agent_power[i] == 0)
			if newly_dead:
				self.agent_alive_n[i] = False
			self.done_n[i] = newly_dead

	def _move_wolves(self):
		"""Move every alive wolf one uniformly random step (directions 0-4)."""
		for i in range(self.n_wolf):
			if self.wolf_alive_n[i]:
				self.wolf_n[i] = self._move(self.wolf_n[i], np.random.randint(5))

	def step(self, action_n):
		"""Advance the environment by one step.

		Args:
			action_n: one discrete action per agent (see the class docstring).

		Returns:
			``(reward, terminated, info)``. ``reward`` is the shared team
			reward for this step. ``info['state']`` holds, per agent, the
			vector ``[row, col, power, wolf_0_row, wolf_0_col, ...]``. On
			the terminal step ``info`` also carries the episode totals
			``'kill'``, ``'landmark'`` and ``'ext'``; otherwise these are 0.
		"""
		# Periodic wolf recovery, capped at wolf_max_power - 1.
		if self.t_step % self.wolf_recover_time == 0:
			for i in range(self.n_wolf):
				if self.wolf_alive_n[i]:
					self.wolf_power[i] = min(self.wolf_power[i] + 1, self.wolf_max_power - 1)

		self.t_step += 1

		# Order matters. Agent attacks use pre-move positions. Wolves then
		# strike the agents' new positions before the wolves themselves move.
		self._resolve_agent_attacks(action_n)
		self._move_agents(action_n)
		self._kill_exhausted_wolves()
		self._resolve_wolf_attacks()
		self._update_agent_deaths()
		self._move_wolves()

		wolf_info = np.concatenate(
			[[wolf[0], wolf[1]] for wolf in self.wolf_n], axis=0)
		state_n = [
			np.concatenate([state, [self.agent_power[i]], wolf_info], axis=0)
			for i, state in enumerate(self.state_n)]

		reward, reward_info = self.get_reward()
		self.t_landmarks += reward_info['landmark']
		self.t_kill += reward_info['kill']

		terminated = self.is_terminated()

		info = {
			"wolf": self.n_wolf,
			"battle_won": False,
			'state': state_n,
			'kill': self.t_kill if terminated else 0,
			'landmark': self.t_landmarks if terminated else 0,
			'ext': (self.t_kill * 300 + self.t_landmarks * 10) if terminated else 0,
		}
		return reward, terminated, info

	def get_obs(self):
		"""Returns all agent observations in a list."""
		return [self.get_obs_agent(i) for i in range(self.n_agent)]

	def get_obs_agent(self, agent_id):
		"""Returns observation for agent_id.

		Every agent currently receives the same vector, and ``agent_id`` is
		unused. The vector is the concatenation of:
		- one-hot row, one-hot column and one-hot power for every agent;
		- one-hot row and column for every wolf;
		- the landmark visited flags;
		- wolf liveness flags, only with full observability.
		"""
		agent_state = np.concatenate([
			np.concatenate([
				self.eye[state[0]],
				self.eye[state[1]],
				self.agent_power_eye[self.agent_power[i]],
			], axis=0)
			for i, state in enumerate(self.state_n)
		], axis=0)

		wolf_state = np.concatenate([
			np.concatenate([self.eye[wolf[0]], self.eye[wolf[1]]], axis=0)
			for wolf in self.wolf_n
		], axis=0)

		parts = [agent_state, wolf_state, self.landmark_visited]
		if self.is_full_obs:
			parts.append(self.wolf_alive_n)
		return np.concatenate(parts).copy()

	def get_obs_size(self):
		"""Returns the size of the observation (must match get_obs_agent)."""
		return (self.size * 2 + self.agent_max_power) * self.n_agent + \
			(self.size * 2) * self.n_wolf + \
			self.n_landmark + \
			self.n_wolf * int(self.is_full_obs)

	def get_state(self):
		"""Returns the global state (a constant placeholder)."""
		return np.array([0])

	def get_state_size(self):
		"""Returns the size of the global state."""
		return 1

	def get_avail_actions(self):
		"""Returns the available actions of all agents in a list."""
		return [self.get_avail_agent_actions(agent_id) for agent_id in range(self.n_agent)]

	def get_avail_agent_actions(self, agent_id):
		"""Returns the available actions for agent_id (every action is always allowed)."""
		return [1] * self.n_action

	def get_total_actions(self):
		"""Returns the total number of actions an agent could ever take."""
		return self.n_action

	def reset(self):
		"""Start a new episode and return ``(observations, state)``."""
		self._reset_episode_state()
		return self.get_obs(), self.get_state()

	def render(self):
		"""Not implemented."""
		raise NotImplementedError

	def close(self):
		"""Close Island environment by resetting its episode state."""
		self.reset()

	def seed(self):
		"""Not implemented.

		NOTE: on instances this method is shadowed by the ``seed`` attribute
		set in ``__init__``.
		"""
		raise NotImplementedError

	def save_replay(self):
		"""Save a replay."""
		raise NotImplementedError

	def get_env_info(self):
		"""Return the shape/size summary of the environment."""
		env_info = {"state_shape": self.get_state_size(),
					"obs_shape": self.get_obs_size(),
					"n_actions": self.get_total_actions(),
					"n_agents": self.n_agent,
					"episode_limit": self.episode_limit}
		return env_info

	def get_stats(self):
		"""Return win statistics."""
		stats = {
			"battles_won": self.battles_won,
			"battles_game": self.battles_game,
			"win_rate": self.battles_won / self.battles_game
		}
		return stats
