import gym
from gym import spaces
import pygame
from ple import PLE_multiplayer as PLEm
from ple.games.pong import Pong
import numpy as np
import warnings
import math

class statePreproc:
	"""
	Holder for the two players' preprocessed observations.

	Both inputs are copied on construction, so later mutation of the arrays
	passed in does not affect this object. `shape` mirrors the shape of
	player 1's observation.
	"""
	def __init__(self, s1, s2):
		self.player1_state, self.player2_state = s1.copy(), s2.copy()
		self.shape = self.player1_state.shape

	def copy(self):
		"""Return an independent duplicate (the arrays are copied again)."""
		first, second = self.player1_state, self.player2_state
		return statePreproc(first, second)


def nv_state_preprocessor(state):
	"""
		State preprocessor handed to PLE: rescales each player's raw state.

		Every feature of state['player1_state'] and state['player2_state'] is
		divided by a fixed per-feature divisor, giving values roughly in 0,1
		(positions) and -1,1 (signed quantities). The features are taken in the
		dicts' iteration order, which must match the divisor order below.
		Returns a statePreproc holding both scaled arrays.
	"""
	# Per-feature divisors, read off the game's source by inspection.
	# TODO: obtain these from the game instead of hard-coding them.
	divisors = np.array([200., 11.0, 200., 256.0, 200., 150.0, 150.0, 30.])

	def _scale(player_key):
		raw = list(state[player_key].values())
		return np.array(raw) / divisors

	return statePreproc(_scale('player1_state'), _scale('player2_state'))


def reward_return2sender(reward, state, game_h):
	"""
		Reward shaping that discourages wall bounces on the agent's return.

		- reward : raw environment reward; clipped to [-1, 0] so that the
		  default +1 for winning a point is removed.
		- state : preprocessed observation of the agent (see
		  nv_state_preprocessor). state[4] is used as the ball height as a
		  fraction of game_h; state[5] > 0 marks the ball being returned.
		- game_h : screen height in pixels.

		Returns the clipped reward minus 1 when the ball touches the top or
		bottom wall while being returned, otherwise plus 0.1 (a small bonus for
		every frame the episode keeps running).
	"""
	# Clip reward to remove default +1 reward for win
	reward = np.clip(reward, -1., 0.)

	# Check if the ball has bounced off the top or bottom of the screen.
	# round(0.03 * game_h) is presumably the ball radius used by the Pong
	# game - TODO confirm.
	ball_y = state[4] * game_h
	margin = np.round(0.03 * game_h)
	up_bounce = (ball_y - margin - 1 <= 0.)
	down_bounce = (margin + ball_y + 1 >= game_h)

	# Bounces are harder for the default AI to handle, so we don't want the
	# agent to bounce the ball too much in defensive play. Only bounces while
	# returning the ball (state[5] > 0) are penalized; bounces on receive are
	# ignored. Otherwise just get a small bonus for every frame of the episode.
	if (up_bounce or down_bounce) and state[5] > 0:
		reward -= 1.
	else:
		reward += 0.1

	return reward


def reward_devensive(reward, state, last_state, game_h, done):
	"""
		Defensive reward: reward_return2sender plus a bonus for paddle hits.

		- reward, state, game_h : as for reward_return2sender.
		- last_state : the agent's observation on the previous step, or None
		  when there is none (then no hit can be detected).
		- done : True on the terminal step; no hit is credited on that step.

		A hit is a change of the sign of state[5] between last_state and
		state. A change towards positive counts as a p1 hit (+0.5); a change
		towards negative counts as a p2 hit (+0.1).
	"""
	reward = reward_return2sender(reward, state, game_h)

	# Check if p1 or p2 hit the ball
	p1hit, p2hit = False, False
	# Compare sign with sign: only a change of direction is a hit, not a mere
	# change of speed.
	if (not done) and (last_state is not None) and (np.sign(state[5]) != np.sign(last_state[5])):
		if state[5] - last_state[5] > 0:
			p1hit = True
		elif state[5] - last_state[5] < 0:
			p2hit = True

	# Points for hits to tell the agent that hits are good (and misses are not)
	if p1hit:
		reward += 0.5
	if p2hit:
		reward += 0.1

	return reward



class PongGym(gym.Env):
	"""
	Gym-like environment around the PLE multiplayer Pong game.

	The agent driven through step() is player 1, or player 2 when `flip` is
	set; the other paddle takes the optional `p2_action` argument of step().
	reset() and step() both return the controlled agent's observation.
	"""
	def __init__(self, display_screen=False, player1_speed_ratio=0.5, player2_speed_ratio=0.5, max_score=1, flip=False, **kwargs):
		"""
		Initialization of gym-like pong environment
		- Key inputs:
			-- display_screen : (bool) render every frame?
			-- playerX_speed_ratio : speed setting for players between 0 and 1
			-- max_score : max score per episode of a game
			-- flip : (bool) control player 2 instead? <for debug purposes for future multi-player interaction>
			-- kwargs : accepted for compatibility; currently unused
		"""
		self.game_w, self.game_h = 256, 200
		game = Pong(width=self.game_w, height=self.game_h,
					player1_speed_ratio=player1_speed_ratio, player2_speed_ratio=player2_speed_ratio, MAX_SCORE=max_score,
					p2_enabled=flip)
		self.env = PLEm(game, fps=60, state_preprocessor=nv_state_preprocessor, display_screen=display_screen)

		self.valid_acts = self.env.getActionSet() # p1up, p1down, p1rot_in, p1rot_out, p2up, p2down, p2rot_in, p2rot_out, no-op
		# NOTE(review): these bounds are in raw game units (obs_high equals the
		# divisors used in nv_state_preprocessor), but the observations returned
		# by step()/reset() have already been divided by those divisors.
		# Confirm whether the declared space should be rescaled.
		self.obs_low = np.array([0., -1., 0., 0., 0., -1., -1., -30.])
		self.obs_high = np.array([200., 11.0, 200., 256.0, 200., 150.0, 150.0, 30.])

		self.action_space = spaces.Discrete(5) # up, down, rotate in, rotate out, no-op
		self.observation_space = spaces.Box(self.obs_low, self.obs_high)

		self.flip = flip

	def step(self, p1_action, p2_action=None):
		"""
		Advance the game by one frame.

		- p1_action : action (0-4) of the controlled agent, or None for no-op.
		- p2_action : optional action (0-4) for the other paddle; None is no-op.

		Returns (obs, reward, done, info) from the controlled agent's point of
		view; info carries the other player's state and reward.
		Raises AssertionError on an action outside action_space.
		"""
		if self.flip:
			# The caller's agent is player 2: route its action to the p2 slot.
			p1_action, p2_action = p2_action, p1_action

		# Check action validity
		if p1_action is not None:
			assert self.action_space.contains(p1_action), "%r (%s) invalid" % (p1_action, type(p1_action))
		if p2_action is not None:
			assert self.action_space.contains(p2_action), "%r (%s) invalid" % (p2_action, type(p2_action))

		rewards = self.act(p1_action, p2_action)
		next_state = self.env.getGameState()
		done = self.env.game_over()

		if self.flip:
			return next_state.player2_state, rewards['player2'], done, {'p1state': next_state.player1_state, 'p1reward':rewards['player1']}
		else:
			return next_state.player1_state, rewards['player1'], done, {'p2state': next_state.player2_state, 'p2reward':rewards['player2']}

	def act(self, p1action, p2action):
		"""
		Translate per-player actions (0-4, or None) into PLE actions and apply them.

		Action 4 / None maps to the last entry of valid_acts (the no-op).
		Player 2's actions are offset by 4 into valid_acts, so p2 action 4
		lands on index 8 (the no-op in the 9-entry action set).
		Returns whatever PLEm.act returns (indexed per player by step()).
		"""
		if (p1action == 4) or (p1action is None):
			p1act = self.valid_acts[-1]
		else:
			p1act = self.valid_acts[p1action]

		if p2action is not None:
			p2act = self.valid_acts[p2action + 4]
		else:
			p2act = self.valid_acts[-1]

		return self.env.act([p1act, p2act])

	def reset(self):
		"""Reset the game and return the controlled agent's first observation."""
		self.env.reset_game()
		obs = self.env.getGameState()
		# Match step(): with flip the caller controls player 2.
		return obs.player2_state if self.flip else obs.player1_state

	def getP2state(self):
		"""Return player 2's current (preprocessed) observation."""
		obs = self.env.getGameState()
		return obs.player2_state

	def render(self, mode='human', **kwargs):
		"""No-op: rendering is controlled by `display_screen` at construction."""
		warnings.warn("Set `display_screen` for displaying game. Render function not used")

	def close(self):
		"""Shut down pygame."""
		pygame.quit()

	def seed(self, seed=None):
		"""Not implemented; only emits a warning."""
		warnings.warn('Seed is not implemented for PongGym')

class PongGym_Basic(gym.Wrapper):
	"""
	Goal : (basic) Only paddle position control, no rotation.

	Actions: 0 = up, 1 = down, 2 = no-op (forwarded as the wrapped env's
	action 4). The last observation component is dropped, and rewards are
	clipped to [-1, 1].
	"""
	def __init__(self, pongGym):
		super().__init__(pongGym)
		self.action_space = spaces.Discrete(3) # up, down, no-op
		self.obs_low = np.array([0., -1., 0., 0., 0., -1., -1.])
		self.obs_high = np.array([200., 11.0, 200., 256.0, 200., 150.0, 150.0])
		self.observation_space = spaces.Box(self.obs_low, self.obs_high)

	def step(self, action):
		forwarded = 4 if action == 2 else action
		obs, raw_reward, done, info = self.env.step(forwarded)
		clipped = np.clip(raw_reward, -1., 1.)
		return obs[:-1], clipped, done, info

	def reset(self):
		return self.env.reset()[:-1]

class PongGym_ReturnToOpponent(gym.Wrapper):
	"""
	Goal : (defensive) Keep the game going as long as possible

	Replaces the wrapped env's reward with reward_return2sender; observations,
	done flag and info pass through unchanged.
	"""
	def __init__(self, pongGym):
		super().__init__(pongGym)

	def step(self, action):
		obs, raw_reward, done, info = self.env.step(action)
		shaped = reward_return2sender(raw_reward, obs, self.game_h)
		return obs, shaped, done, info

	def reset(self):
		initial_obs = self.env.reset()
		return initial_obs

class PongGym_Defensive(gym.Wrapper):
	"""
	Goal : (defensive) Have as many paddle hits as possible to extend game time and keep game going as long as possible

	Replaces the reward with reward_devensive, which needs the previous
	observation; that observation is tracked in `last_state`.
	"""
	def __init__(self, pongGym):
		super().__init__(pongGym)
		self.last_state = None

	def step(self, action):
		obs, raw_reward, done, info = self.env.step(action)
		shaped = reward_devensive(raw_reward, obs, self.last_state, self.game_h, done)
		# Remember this frame so the next step can detect direction changes.
		self.last_state = obs.copy()
		return obs, shaped, done, info

	def reset(self):
		initial_obs = self.env.reset()
		self.last_state = initial_obs.copy()
		return initial_obs


class PongGym_BinaryStyles(gym.Wrapper):
	"""
	Goal : Learn to operate both defensively and aggressively (aggressive behavior theoretically comes from basic game)

	Each episode draws setting in {0, 1}. Observations are prefixed with
	[setting, 1 - setting]. setting == 1 uses the offensive reward (raw reward
	clipped to [-1, 1]); setting == 0 uses reward_devensive. Both rewards are
	reported in info['RewardBreakdown'].

	NOTE(review): PongGym_LinearStyles / PongGym_SetStyle weight the
	*defensive* reward by `setting`, the opposite of this class's mapping;
	confirm which convention is intended.
	"""
	def __init__(self, pongGym):
		super().__init__(pongGym)
		self.setting = np.random.randint(0,2)
		self.obs_low = np.concatenate(([0, 0], self.obs_low))
		self.obs_high = np.concatenate(([1, 1], self.obs_high))
		self.observation_space = spaces.Box(self.obs_low, self.obs_high)
		self.last_state = None

	def _with_setting(self, obs):
		"""Prefix an observation with the [setting, 1 - setting] tag."""
		return np.concatenate(([self.setting, 1 - self.setting], obs))

	def step(self, action):
		obs, raw_reward, done, info = self.env.step(action)

		offensive_reward = np.clip(raw_reward, -1., 1.)
		defensive_reward = reward_devensive(raw_reward, obs, self.last_state, self.game_h, done)
		chosen = offensive_reward if self.setting > 0 else defensive_reward

		self.last_state = obs.copy()
		info['RewardBreakdown'] = np.array([offensive_reward, defensive_reward])

		return self._with_setting(obs), chosen, done, info

	def reset(self):
		initial_obs = self.env.reset()
		self.last_state = initial_obs.copy()
		# Draw a fresh style for the new episode (after the env reset, as before).
		self.setting = np.random.randint(0,2)
		return self._with_setting(initial_obs)


class PongGym_LinearStyles(gym.Wrapper):
	"""
	Goal : Learn to operate both defensively and aggressively with interpolation

	Each episode draws setting uniformly from [0, 1). Observations are
	prefixed with [setting, 1 - setting], and the reward is
	(1 - setting) * offensive + setting * defensive, where offensive is the raw
	reward clipped to [-1, 1] and defensive is reward_devensive. Both
	components are reported in info['RewardBreakdown'].
	"""
	def __init__(self, pongGym):
		super().__init__(pongGym)
		self.setting = np.random.rand()
		self.obs_low = np.concatenate(([0, 0], self.obs_low))
		self.obs_high = np.concatenate(([1, 1], self.obs_high))
		self.observation_space = spaces.Box(self.obs_low, self.obs_high)
		self.last_state = None

	def _with_setting(self, obs):
		"""Prefix an observation with the [setting, 1 - setting] tag."""
		return np.concatenate(([self.setting, 1 - self.setting], obs))

	def step(self, action):
		obs, raw_reward, done, info = self.env.step(action)

		offensive_reward = np.clip(raw_reward, -1., 1.)
		defensive_reward = reward_devensive(raw_reward, obs, self.last_state, self.game_h, done)
		weight = self.setting
		blended = (1 - weight) * offensive_reward + weight * defensive_reward

		self.last_state = obs.copy()
		info['RewardBreakdown'] = np.array([offensive_reward, defensive_reward])

		return self._with_setting(obs), blended, done, info

	def reset(self):
		initial_obs = self.env.reset()
		self.last_state = initial_obs.copy()
		# Draw a fresh interpolation weight for the new episode.
		self.setting = np.random.rand()
		return self._with_setting(initial_obs)

class PongGym_SetStyle(gym.Wrapper):
	"""
	For environments with extra 'setting' state, externally set 'setting'.

	Observations are prefixed with [setting, 1 - setting]. The reward is
	(1 - setting) * offensive + setting * defensive, where offensive is the raw
	reward clipped to [-1, 1] and defensive is reward_devensive. Both
	components are reported in info['RewardBreakdown'].
	"""
	def __init__(self, pongGym, setting):
		super().__init__(pongGym)
		self.setting = setting
		self.obs_low = np.concatenate(([0, 0], self.obs_low))
		self.obs_high = np.concatenate(([1, 1], self.obs_high))
		self.observation_space = spaces.Box(self.obs_low, self.obs_high)
		self.last_state = None

	def step(self, action):
		next_state, reward, done, info = self.env.step(action)

		offensive_reward = np.clip(reward, -1., 1.)
		defensive_reward = reward_devensive(reward, next_state, self.last_state, self.game_h, done)

		# Blend by the externally chosen setting so that `setting` actually
		# controls the reward (it must not be overwritten afterwards).
		rew = (1 - self.setting) * offensive_reward + self.setting * defensive_reward

		self.last_state = next_state.copy()
		next_state = np.concatenate(([self.setting, 1 - self.setting], next_state))

		# Same reward breakdown as the other style wrappers expose.
		info['RewardBreakdown'] = np.array([offensive_reward, defensive_reward])

		return next_state, rew, done, info

	def reset(self):
		obs = self.env.reset()
		self.last_state = obs.copy()
		# setting is fixed externally, so it is not re-drawn here.
		return np.concatenate(([self.setting, 1 - self.setting], obs))

