import numpy as np
import random
import components.sum_tree as sum_tree

class Experience(object):
	""" The class represents prioritized experience replay buffer.

	The class has functions: store samples, pick samples with
	probability in proportion to sample's priority, update
	each sample's priority, reset alpha.

	see https://arxiv.org/pdf/1511.05952.pdf .

	"""

	def __init__(self, memory_size, alpha=1):
		self.tree = sum_tree.SumTree(memory_size)
		self.memory_size = memory_size
		self.alpha = alpha

	def add(self, priority):
		self.tree.add(priority ** self.alpha)

	def select(self, batch_size):

		if self.tree.filled_size() < batch_size:
			return None

		indices = []
		priorities = []
		for _ in range(batch_size):
			r = random.random()
			priority, index = self.tree.find(r)
			priorities.append(priority)
			indices.append(index)
			self.priority_update([index], [0])  # To avoid duplicating

		self.priority_update(indices, priorities)  # Revert priorities

		return indices

	def priority_update(self, indices, priorities):
		""" The methods update samples's priority.

		Parameters
		----------
		indices :
			list of sample indices
		"""
		for i, p in zip(indices, priorities):
			self.tree.val_update(i, p ** self.alpha)
