# -*- coding: utf-8 -*-
import numpy as np
import torch
"""
Replay buffer adapted from the reference TD3 implementation: https://github.com/sfujim/TD3
"""
class ReplayBuffer(object):
	def __init__(self,state_dim,action_dim,
		max_size=int(1e6), 
		device: str = 'cpu',
		batch_size=256,
		max_action=1,
		normalize_actions=False,
		):
	
		max_size = int(max_size)
		self.max_size = max_size
		self.ptr = 0
		self.size = 0
		self.device = device
		self.batch_size = batch_size
		self.state = np.zeros((max_size, state_dim))
		self.action = np.zeros((max_size, action_dim))
		self.next_state = np.zeros((max_size, state_dim))
		self.reward = np.zeros((max_size, 1))
		self.not_done = np.zeros((max_size, 1))

		

		self.normalize_actions = max_action if normalize_actions else 1

	
	def add(self, state, action, next_state, reward, done):
		self.state[self.ptr] = state
		self.action[self.ptr] = action/self.normalize_actions
		self.next_state[self.ptr] = next_state
		self.reward[self.ptr] = reward
		self.not_done[self.ptr] = 1. - done
		self.ptr = (self.ptr + 1) % self.max_size
		self.size = min(self.size + 1, self.max_size)


	def sample(self,batch_size=None):
		
		self.ind = np.random.randint(0, self.size, size=self.batch_size)

		return (
			torch.tensor(self.state[self.ind], dtype=torch.float, device=self.device),
			torch.tensor(self.action[self.ind], dtype=torch.float, device=self.device),
			torch.tensor(self.next_state[self.ind], dtype=torch.float, device=self.device),
			torch.tensor(self.reward[self.ind], dtype=torch.float, device=self.device),
			torch.tensor(self.not_done[self.ind], dtype=torch.float, device=self.device)
		)
