import gym
import numpy as np
import random

from absl import logging

from smac.env.multiagentenv import MultiAgentEnv


class PushBallEnv(MultiAgentEnv):
	"""Grid world in which agents earn reward by pushing balls onto a wall.

	The board is a ``size`` x ``size`` grid whose outermost rows and columns
	are the walls. Agents are confined to the interior cells and all start on
	cell (1, 1). A ball moves one cell along an axis when the agents stepping
	onto it in that direction outnumber those stepping onto it from the
	opposite direction by at least two. A ball that ends a step on a wall
	pays ``reward_wall`` for that wall and is respawned on a random free cell.
	"""

	# (row, column) offset applied by each action index:
	# 0 = row - 1, 1 = column + 1, 2 = row + 1, 3 = column - 1, 4 = stay.
	_MOVES = ((-1, 0), (0, 1), (1, 0), (0, -1), (0, 0))

	def __init__(
			self,
			agent_num=2,
			action_num=5,
			ball_num=1,
			dim_num=2,
			size=15,
			episode_length=300,
			debug=False,
			state_last_action=True,
			seed=None):
		"""Create Pushball environment.

		Args:
			agent_num: Number of agents.
			action_num: Size of the discrete action space reported to callers.
				``step`` itself only understands the action indices 0-4.
			ball_num: Number of balls on the board (at least 1).
			dim_num: Number of board dimensions; only 2 is supported.
			size: Side length of the square board, walls included.
			episode_length: Number of steps after which an episode terminates.
			debug: If true, emit debug log messages during construction.
			state_last_action: Accepted for interface compatibility; unused.
			seed: Seed for ball placement. A random seed in [0, 9999] is
				drawn when this is None.

		Raises:
			ValueError: If ``dim_num`` is not 2, ``ball_num`` is below 1, or
				the board is too small to place all balls at reset.
		"""
		if dim_num != 2:
			raise ValueError(
				"PushBallEnv only supports dim_num=2, got {!r}".format(dim_num))
		if ball_num < 1:
			raise ValueError(
				"ball_num must be at least 1, got {!r}".format(ball_num))

		self.debug = debug

		# Kept in a private attribute so the seed() method stays callable, and
		# fed to a private generator so the process-wide NumPy RNG is untouched.
		self._seed = random.randint(0, 9999) if seed is None else seed
		self._rng = np.random.RandomState(self._seed)

		if self.debug:
			logging.debug("Creating Pushball environment...")

		# Game arguments
		self.n_agent = agent_num
		self.n_action = action_num
		self.n_ball = ball_num

		self.n_dim = dim_num
		self.size = size
		self.episode_length = episode_length
		self.map = np.zeros([self.size, self.size])

		# Half-open spawn window [low, high) per coordinate. It equals the
		# historical 4..11 window whenever size >= 13; on smaller boards it is
		# clipped so that balls always spawn on interior cells.
		self._spawn_high = min(12, self.size - 1)
		self._spawn_low = max(1, min(4, self._spawn_high - 4))
		span = max(0, self._spawn_high - self._spawn_low)
		free_cells = span ** 2
		if self._spawn_low == 1 and self.n_agent > 0:
			# Every agent starts on (1, 1), which then lies inside the window.
			free_cells -= 1
		if self.n_ball > free_cells:
			raise ValueError(
				"size={!r} leaves {} free spawn cells, cannot place {} balls".format(
					size, max(free_cells, 0), ball_num))

		# Reward per wall: row 0, column 0, last row, last column.
		self.reward_wall = [1000, 1000, 1000, 1000]

		# Something else
		self.state_n = [np.array([1, 1]) for _ in range(self.n_agent)]
		self.ball_n = self.random_start_n()
		# Row i of the identity matrix is the one-hot encoding of coordinate i.
		self.eye = np.eye(self.size)

		self.win = 0
		self.t_step = 0
		self.battles_game = 1
		self.battles_won = 0
		self.episode_steps = 0
		# Must agree with is_terminated(), which compares against episode_length.
		self.episode_limit = episode_length
		self.continuing_episode = True

		# Used by OpenAI baselines
		self.action_space = gym.spaces.Discrete(self.n_action)
		self.observation_space = gym.spaces.Box(
			low=-1, high=1, shape=(self.get_obs_size(),))

		if self.debug:
			logging.debug("The Pushball environment is available.")

	def random_start(self):
		"""Return a random (row, column) array drawn from the spawn window."""
		return self._rng.randint(
			self._spawn_low, self._spawn_high, [self.n_dim])

	def _is_occupied(self, cell, balls):
		"""Return True if `cell` equals any ball in `balls` or any agent position."""
		if any((cell == ball).all() for ball in balls):
			return True
		return any((cell == position).all() for position in self.state_n)

	def _spawn_ball(self, balls):
		"""Draw random cells until one is free of every ball in `balls` and every agent."""
		candidate = self.random_start()
		while self._is_occupied(candidate, balls):
			candidate = self.random_start()
		return candidate

	def random_start_n(self):
		"""Return ``n_ball`` random ball positions that avoid each other and the agents."""
		ball_list = []
		for _ in range(self.n_ball):
			ball_list.append(self._spawn_ball(ball_list))
		return ball_list

	def get_reward(self):
		"""Collect wall rewards and respawn every ball that is on a wall.

		Returns:
			A ``(reward, win_count)`` pair: the summed ``reward_wall`` entries
			for all walls touched, and a length-4 array counting the hits per
			wall in the order row 0, column 0, last row, last column.
		"""
		reward = 0
		win_count = np.zeros(4)
		last = self.size - 1
		for i, ball in enumerate(self.ball_n):
			walls_hit = (
				ball[0] == 0,
				ball[1] == 0,
				ball[0] == last,
				ball[1] == last,
			)
			for wall, hit in enumerate(walls_hit):
				if hit:
					reward += self.reward_wall[wall]
					win_count[wall] += 1
			# Respawn on the wall hit itself rather than on "reward increased",
			# so the rule does not depend on the sign of reward_wall.
			if any(walls_hit):
				self.ball_n[i] = self._spawn_ball(self.ball_n)

		return reward, win_count

	def is_terminated(self):
		"""Return True once ``episode_length`` steps have been taken."""
		return self.t_step >= self.episode_length

	def _check_action(self, action):
		"""Convert `action` to an int index and reject anything outside 0-4."""
		index = int(action)
		if not 0 <= index < len(self._MOVES):
			raise ValueError(
				"invalid action {!r}; expected an integer in [0, {}]".format(
					action, len(self._MOVES) - 1))
		return index

	def _target_cell(self, position, action):
		"""Return the (row, column) `action` leads to from `position`, kept off the walls."""
		d_row, d_col = self._MOVES[action]
		highest = self.size - 2
		row = min(max(position[0] + d_row, 1), highest)
		column = min(max(position[1] + d_col, 1), highest)
		return row, column

	@staticmethod
	def _bumps(origin, target, ball):
		"""Return True if moving from `origin` to `target` runs into `ball`."""
		return (
			bool((origin != ball).any())
			and target[0] == ball[0]
			and target[1] == ball[1]
		)

	@staticmethod
	def _net_push(forward, backward):
		"""Return +1 or -1 when one side leads the other by two or more pushes, else 0."""
		if forward - backward >= 2:
			return 1
		if backward - forward >= 2:
			return -1
		return 0

	def step(self, action_n):
		"""Advance the environment by one step.

		Args:
			action_n: One action index (0-4) per agent, in agent order.

		Returns:
			A ``(reward, terminated, info)`` tuple. ``info['state']`` holds,
			per agent, its position followed by all ball positions as they
			are after movement but before any wall respawn.

		Raises:
			ValueError: If an action is outside 0-4. Nothing is modified then.
		"""
		# Validate everything first so a bad action cannot leave a half-applied step.
		actions = [self._check_action(action) for action in action_n]
		self.t_step += 1

		# An agent's target depends only on its own pre-step position, so it
		# is computed once and shared by the push and movement phases.
		targets = [
			self._target_cell(self.state_n[i], action)
			for i, action in enumerate(actions)
		]

		# Count, per ball and per action, the agents that try to step onto it.
		ball_count = [np.zeros(len(self._MOVES)) for _ in range(self.n_ball)]
		for i, action in enumerate(actions):
			for j, ball in enumerate(self.ball_n):
				if self._bumps(self.state_n[i], targets[i], ball):
					ball_count[j][action] += 1

		# Move Ball
		new_ball_n = []
		for j, ball in enumerate(self.ball_n):
			move_x = self._net_push(ball_count[j][2], ball_count[j][0])
			move_y = self._net_push(ball_count[j][1], ball_count[j][3])
			new_ball = np.array([ball[0] + move_x, ball[1] + move_y])

			# A ball stays put when its destination holds a ball (including
			# itself when it was not pushed) or an agent. Balls already moved
			# in this step are checked too, so two balls cannot share a cell.
			blocked = (
				self._is_occupied(new_ball, self.ball_n)
				or any((new_ball == moved).all() for moved in new_ball_n)
			)
			new_ball_n.append(ball if blocked else new_ball)
		self.ball_n = new_ball_n

		# Move agents; an agent whose target still holds a ball stays in place.
		for i, target in enumerate(targets):
			origin = self.state_n[i]
			blocked = any(
				self._bumps(origin, target, ball) for ball in self.ball_n)
			self.state_n[i] = np.array(origin if blocked else target)

		ball_info = np.concatenate(
			[[ball[0], ball[1]] for ball in self.ball_n], axis=None)
		info_state_n = [
			np.concatenate([state, ball_info], axis=0)
			for state in self.state_n
		]

		reward, _ = self.get_reward()
		terminated = self.is_terminated()

		self.win += int(reward > 0)

		info = {
			'ball': self.n_ball,
			"battle_won": False,
			'state': info_state_n,
			'ext': self.win * 1000 if terminated else 0
		}

		return reward, terminated, info

	def get_obs(self):
		"""Returns all agent observations in a list."""
		return [self.get_obs_agent(i) for i in range(self.n_agent)]

	def get_obs_agent(self, agent_id):
		"""Returns observation for agent_id.

		The observation is the same for every agent: the one-hot row and
		column of each agent, followed by the one-hot row and column of each
		ball, concatenated into a vector of length ``get_obs_size()``.
		"""
		cells = list(self.state_n) + list(self.ball_n)
		return np.concatenate([
			self.eye[coordinate]
			for cell in cells
			for coordinate in (cell[0], cell[1])
		])

	def get_obs_size(self):
		"""Returns the size of the observation."""
		return self.size * 2 * (self.n_agent + self.n_ball)

	def get_state(self):
		"""Returns the global state (a constant placeholder array)."""
		return np.array([0])

	def get_state_size(self):
		"""Returns the size of the global state."""
		return 1

	def get_avail_actions(self):
		"""Returns the available actions of all agents in a list."""
		return [
			self.get_avail_agent_actions(agent_id)
			for agent_id in range(self.n_agent)
		]

	def get_avail_agent_actions(self, agent_id):
		"""Returns the available actions for agent_id (every action is always available)."""
		return [1] * self.n_action

	def get_total_actions(self):
		"""Returns the total number of actions an agent could ever take."""
		return self.n_action

	def reset(self):
		"""Start a new episode and return ``(observations, state)``."""
		self.win = 0
		self.t_step = 0
		self.state_n = [np.array([1, 1]) for _ in range(self.n_agent)]
		self.ball_n = self.random_start_n()

		return self.get_obs(), self.get_state()

	def render(self):
		"""Not implemented."""
		raise NotImplementedError

	def close(self):
		"""Close Pushball environment (there is nothing to release; it is reset)."""
		self.reset()

	def seed(self):
		"""Returns the random seed used by the environment."""
		return self._seed

	def save_replay(self):
		"""Save a replay (not implemented)."""
		raise NotImplementedError

	def get_env_info(self):
		"""Return the shapes and limits describing this environment."""
		return {
			"state_shape": self.get_state_size(),
			"obs_shape": self.get_obs_size(),
			"n_actions": self.get_total_actions(),
			"n_agents": self.n_agent,
			"episode_limit": self.episode_limit,
		}

	def get_stats(self):
		"""Return the battle counters and the win rate derived from them."""
		return {
			"battles_won": self.battles_won,
			"battles_game": self.battles_game,
			"win_rate": self.battles_won / self.battles_game
		}
