import random
import numpy as np

class ReplayBuffer():
	""" Fixed-capacity experience replay buffer backed by pre-allocated NumPy arrays.

	IMPORTANT NOTE: Initializing the buffer with a dataset works differently from initializing an
	empty buffer and storing a dataset in it. Data passed to the constructor occupies slots
	[0, init_cntr); when the buffer wraps around, new transitions are written starting at
	init_cntr, so that prefix is never overwritten. Storing a dataset into an empty buffer leaves
	init_cntr = 0, so every slot can be overwritten. Make sure to use the variation that suits
	your purpose. (The protection is only meaningful when the initial data is shorter than
	mem_size; otherwise the initial store itself wraps.)

	Note: terminal_memory holds a *not-done* mask -- every stored value is 1 - done, i.e. True
	while the episode continues. When done is omitted, every transition is stored as not done.
	"""
	def __init__(self, mem_size, state_dim, action_dim, data = None):
		"""
		Args:
			mem_size: number of transitions the buffer can hold.
			state_dim: length of each state vector.
			action_dim: length of each action vector.
			data: optional tuple (state, action, reward, state_[, done]) stored immediately and
				protected from being overwritten on wrap-around (see class docstring).
		"""
		self.mem_size = mem_size
		self.mem_cntr = 0
		self.mem_full = False
		self.state_memory = np.zeros((self.mem_size, state_dim))
		self.state_memory_ = np.zeros((self.mem_size, state_dim))
		self.action_memory = np.zeros((self.mem_size, action_dim))
		self.reward_memory = np.zeros(self.mem_size)
		# np.bool was deprecated in NumPy 1.20 and removed in 1.24; builtin bool gives the same dtype.
		self.terminal_memory = np.zeros(self.mem_size, dtype = bool)

		# init_cntr must exist before the initial store_batch call, because a wrap-around
		# inside that call reads it.
		self.init_cntr = 0
		if data is not None:
			self.store_batch(*data)
			self.init_cntr = self.mem_cntr

	def store_batch(self, state, action, reward, state_, done = None):
		""" Append a batch of transitions, wrapping around to init_cntr whenever the end is reached.

		Args:
			state, action, reward, state_: sequences of equal length, one entry per transition.
			done: optional per-transition episode-end flags; stored as 1 - done. If None, every
				transition is stored as not done (True).
		Raises:
			ValueError: if the buffer has no writable capacity (mem_size == 0).
		"""
		n_items = len(state)
		if done is None:
			not_done = np.ones(n_items, dtype = bool)
		else:
			not_done = 1 - np.asarray(done)

		src = 0
		while src < n_items:
			# Write as much as fits before the physical end of the arrays, then wrap.
			chunk = min(self.mem_size - self.mem_cntr, n_items - src)
			if chunk <= 0:
				raise ValueError("ReplayBuffer has no writable capacity (mem_size must be > 0)")
			dst = slice(self.mem_cntr, self.mem_cntr + chunk)
			part = slice(src, src + chunk)
			self.state_memory[dst] = state[part]
			self.state_memory_[dst] = state_[part]
			self.action_memory[dst] = action[part]
			self.reward_memory[dst] = reward[part]
			self.terminal_memory[dst] = not_done[part]

			src += chunk
			self.mem_cntr += chunk
			if self.mem_cntr >= self.mem_size:
				# Wrap to init_cntr rather than 0 so constructor-supplied data is preserved.
				self.mem_cntr = self.init_cntr
				self.mem_full = True

	def sample(self, batch_size):
		""" Draw batch_size transitions uniformly at random, with replacement.

		Only filled slots are eligible (all of them once the buffer has wrapped). NumPy raises
		ValueError if the buffer is still empty.

		Returns:
			(state, action, reward, state_, not_done) arrays indexed by the sampled slots.
		"""
		n_valid = self.mem_size if self.mem_full else self.mem_cntr
		indices = np.random.choice(n_valid, batch_size)

		state_batch = self.state_memory[indices]
		state_batch_ = self.state_memory_[indices]
		action_batch = self.action_memory[indices]
		reward_batch = self.reward_memory[indices]
		terminal_batch = self.terminal_memory[indices]

		return state_batch, action_batch, reward_batch, state_batch_, terminal_batch

class PrioritizedReplay():
	""" Proportional prioritized experience replay using the same array layout as ReplayBuffer.

	Slot i is sampled with probability P(i) = p_i**alpha / sum_k p_k**alpha, where p_i is its
	stored priority. Samples are returned with importance-sampling weights (N * P(i))**(-beta),
	normalized by their maximum, where N is the number of filled slots. beta grows linearly from
	beta_start to 1 over beta_frames calls to sample().

	As with ReplayBuffer, data passed to the constructor occupies slots [0, init_cntr) and is not
	overwritten on wrap-around. terminal_memory holds a *not-done* mask (1 - done; True when done
	is omitted).

	NOTE(review): this class was previously flagged as outdated; validate it against your
	training setup before relying on it.
	"""
	def __init__(self, mem_size, state_dim, action_dim, data = None, alpha = 0.6, beta_start = 0.4, beta_frames = 100000):
		"""
		Args:
			mem_size: number of transitions the buffer can hold.
			state_dim: length of each state vector.
			action_dim: length of each action vector.
			data: optional tuple (state, action, reward, state_[, done]) stored immediately and
				protected from being overwritten on wrap-around.
			alpha: exponent applied to priorities when forming sampling probabilities.
			beta_start: initial importance-sampling exponent.
			beta_frames: number of sample() calls over which beta is annealed to 1.
		"""
		self.mem_size = mem_size
		self.mem_cntr = 0
		self.mem_full = False
		self.state_memory = np.zeros((self.mem_size, state_dim))
		self.state_memory_ = np.zeros((self.mem_size, state_dim))
		self.action_memory = np.zeros((self.mem_size, action_dim))
		self.reward_memory = np.zeros(self.mem_size)
		# np.bool was deprecated in NumPy 1.20 and removed in 1.24; builtin bool gives the same dtype.
		self.terminal_memory = np.zeros(self.mem_size, dtype = bool)

		self.frame = 1
		self.alpha = alpha
		self.beta_start = beta_start
		self.beta_frames = beta_frames
		self.priorities = np.zeros((self.mem_size, ), dtype = np.float32)

		# init_cntr must exist before the initial store_batch call, because a wrap-around
		# inside that call reads it.
		self.init_cntr = 0
		if data is not None:
			self.store_batch(*data)
			self.init_cntr = self.mem_cntr

	def beta_by_frame(self, frame_idx):
		""" Return the importance-sampling exponent for frame_idx, annealed linearly and capped at 1. """
		return min(1, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)

	def store_batch(self, state, action, reward, state_, done = None):
		""" Append a batch of transitions, wrapping around to init_cntr whenever the end is reached.

		New transitions receive the current maximum stored priority (1.0 while the buffer is empty).

		Args:
			state, action, reward, state_: sequences of equal length, one entry per transition.
			done: optional per-transition episode-end flags; stored as 1 - done. If None, every
				transition is stored as not done (True).
		Raises:
			ValueError: if the buffer has no writable capacity (mem_size == 0).
		"""
		# Emptiness is tracked by the counters; inspecting state values misfires when real
		# states are all zero and costs a full scan of state_memory on every call.
		has_data = self.mem_full or self.mem_cntr > 0
		max_prio = self.priorities.max() if has_data else 1.0

		n_items = len(state)
		if done is None:
			not_done = np.ones(n_items, dtype = bool)
		else:
			not_done = 1 - np.asarray(done)

		src = 0
		while src < n_items:
			# Write as much as fits before the physical end of the arrays, then wrap.
			chunk = min(self.mem_size - self.mem_cntr, n_items - src)
			if chunk <= 0:
				raise ValueError("PrioritizedReplay has no writable capacity (mem_size must be > 0)")
			dst = slice(self.mem_cntr, self.mem_cntr + chunk)
			part = slice(src, src + chunk)
			self.state_memory[dst] = state[part]
			self.state_memory_[dst] = state_[part]
			self.action_memory[dst] = action[part]
			self.reward_memory[dst] = reward[part]
			self.terminal_memory[dst] = not_done[part]
			self.priorities[dst] = max_prio

			src += chunk
			self.mem_cntr += chunk
			if self.mem_cntr >= self.mem_size:
				# Wrap to init_cntr (not 0) so constructor-supplied data is preserved.
				self.mem_cntr = self.init_cntr
				self.mem_full = True

	def sample(self, batch_size):
		""" Draw batch_size transitions (with replacement) proportionally to priority**alpha.

		Each call advances the frame counter used to anneal beta.

		Returns:
			(state, action, reward, state_, not_done, indices, weights), where indices are the
			sampled slots (to pass to update_priorities) and weights are float32
			importance-sampling weights normalized to a maximum of 1.
		"""
		n_valid = self.mem_size if self.mem_full else self.mem_cntr
		# float64 keeps the normalized probabilities accurate for large buffers.
		prios = self.priorities[:n_valid].astype(np.float64)

		probs = prios ** self.alpha
		P = probs / probs.sum()
		indices = np.random.choice(n_valid, batch_size, p = P)

		state_batch = self.state_memory[indices]
		state_batch_ = self.state_memory_[indices]
		action_batch = self.action_memory[indices]
		reward_batch = self.reward_memory[indices]
		terminal_batch = self.terminal_memory[indices]

		beta = self.beta_by_frame(self.frame)
		self.frame += 1

		weights = (n_valid * P[indices]) ** (-beta)
		weights /= weights.max()
		weights = np.asarray(weights, dtype = np.float32)

		return state_batch, action_batch, reward_batch, state_batch_, terminal_batch, indices, weights

	def update_priorities(self, batch_indices, batch_priorities):
		""" Set the priority of each slot in batch_indices to the absolute value of its new priority.

		Entries are applied in order, so a repeated index keeps its last value.
		"""
		for idx, prio in zip(batch_indices, batch_priorities):
			self.priorities[idx] = abs(prio)
