import random
from collections import namedtuple, deque
import torch
import numpy as np

class ReplayBuffer:
  '''
  Hold one Replay buffer per agent and route add/sample calls by agent index.

  Args:
    num_agents: number of per-agent Replay buffers to create.
    memory_size: capacity handed to every per-agent Replay.
    batch_size: batch size handed to every per-agent Replay.
  '''
  def __init__(self, num_agents, memory_size, batch_size):
    self.pg_frames = 0
    self.total_frames = 0
    per_agent = []
    for _ in range(num_agents):
      per_agent.append(Replay(memory_size, batch_size))
    self.buffer = per_agent

  def sample(self, agent_id):
    '''Return whatever the Replay buffer of agent `agent_id` samples.'''
    agent_memory = self.buffer[agent_id]
    return agent_memory.sample()

  def add(self, experience):
    '''
    Forward experience[i] to agent i's Replay buffer, for every i.

    Both frame counters are incremented once per forwarded entry, so a
    successful call grows them by len(experience), not by one.
    '''
    # NOTE(review): per-agent (rather than per-call) frame counting looks
    # deliberate but is unusual for "frames" counters -- confirm with callers.
    n_entries = len(experience)
    for agent_idx in range(n_entries):
      agent_entry = experience[agent_idx]
      self.buffer[agent_idx].add(agent_entry)
      self.pg_frames += 1
      self.total_frames += 1

  def __len__(self):
    '''Number of experiences currently stored for agent 0.'''
    first_agent = self.buffer[0]
    return len(first_agent.memory)



class Replay(object):
  '''
  Fixed-capacity ring buffer of experience tuples.

  Each stored entry is an ``Experience`` namedtuple
  (state, action, next_state, reward, done) whose fields are tensors of
  torch's default floating dtype. Once ``capacity`` entries are stored,
  each new entry overwrites the slot at ``pos``, which advances cyclically
  from index 0.
  '''
  def __init__(self, capacity, batch_size):
    '''
    Args:
      capacity: maximum number of stored experiences (coerced with int()).
      batch_size: number of experiences returned by each sample() call.
    '''
    self.capacity = int(capacity)
    self.batch_size = batch_size
    self.experience = namedtuple("Experience", field_names=["state", "action", "next_state", "reward", "done"])
    self.memory = []
    # Index of the next slot to write; wraps around modulo capacity.
    self.pos = 0

  @staticmethod
  def _to_tensor(value):
    '''
    Convert one experience field to a floating-point tensor.

    ``torch.Tensor(x)`` interprets a bare Python int as a *size*
    (``torch.Tensor(0)`` is empty, ``torch.Tensor(1)`` is an uninitialized
    one-element tensor), which silently corrupts integer rewards and
    boolean/int done flags. ``torch.as_tensor`` always treats its argument
    as data.
    '''
    return torch.as_tensor(value, dtype=torch.get_default_dtype())

  def add(self, experience):
    '''
    Add experience(s) into memory.

    Args:
      experience: iterable of experience records; fields 0-4 of each record
        are read as (state, action, next_state, reward, done).
    '''
    for exp in experience:
      entry = self.experience(*(self._to_tensor(exp[k]) for k in range(5)))
      if self.pos >= len(self.memory):
        # Buffer not yet full: grow the list.
        self.memory.append(entry)
      else:
        # Buffer full: overwrite the slot at pos.
        self.memory[self.pos] = entry
      self.pos = (self.pos + 1) % self.capacity

  def sample(self):
    '''
    Draw ``batch_size`` distinct stored experiences uniformly at random.

    Returns:
      None if memory is empty; otherwise a 5-tuple of stacked tensors
      (states, actions, next_states, rewards, dones) with the batch along
      dim 0.

    Raises:
      ValueError: (from random.sample) if fewer than ``batch_size``
        experiences are stored.
    '''
    if self.empty():
      return None
    indices = random.sample(range(len(self.memory)), self.batch_size)
    batch = [self.memory[i] for i in indices]
    # Stack each Experience field across the batch, in field order.
    return tuple(torch.stack([entry[field] for entry in batch])
                 for field in range(len(self.experience._fields)))

  def __len__(self):
    '''Number of experiences currently stored.'''
    return len(self.memory)

  def size(self):
    '''Return the maximum number of experiences this buffer can hold.'''
    return self.capacity

  def empty(self):
    '''Return True if no experiences are stored.'''
    return not self.memory

  def shuffle(self):
    '''Shuffle the stored experiences in place using numpy's global RNG.'''
    # NOTE: pos is not updated here, so once the buffer is full the next
    # add() overwrites whatever now sits at pos, not necessarily the oldest.
    np.random.shuffle(self.memory)