import numpy as np
import random
from collections import namedtuple, deque
import torch


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples.

    Once ``buffer_size`` experiences are stored, adding a new one evicts the
    oldest (the backing store is a ``deque`` with ``maxlen``).
    """

    def __init__(self, buffer_size, batch_size, seed, device):
        """Initialize a ReplayBuffer object.

        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            seed (int): random seed (seeds the global ``random`` module,
                which ``sample`` draws from)
            device (string): GPU or CPU; sampled tensors are moved here
        """
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience",
                                     field_names=["state", "needs_weight", "action", "delta_utility", "reward",
                                                  "next_state", "done"])
        # random.seed() returns None, so store the seed value itself; the
        # global RNG is still seeded exactly as before.
        random.seed(seed)
        self.seed = seed
        self.device = device

    def add(self, state, needs_weight, action, delta_utility, reward, next_state, done):
        """Add a new experience to memory (evicting the oldest if full)."""
        e = self.experience(state, needs_weight, action, delta_utility, reward, next_state, done)
        self.memory.append(e)

    def _to_tensor(self, values, dtype):
        """Stack per-experience values into one tensor of ``dtype`` on ``self.device``."""
        return torch.from_numpy(np.array(values)).to(dtype=dtype, device=self.device)

    def sample(self):
        """Randomly sample a batch of experiences from memory.

        Returns
        =======
            tuple of tensors ``(states, needs_weights, actions, delta_utilities,
            rewards, next_states, dones)``, each with one row per sampled
            experience. ``actions`` are int64; everything else is float32
            (``dones`` are converted bool -> uint8 -> float, i.e. 0.0 / 1.0).

        Raises
        ======
            ValueError: if fewer than ``batch_size`` experiences are stored
                (propagated from ``random.sample``).
        """
        experiences = [e for e in random.sample(self.memory, k=self.batch_size) if e is not None]

        states = self._to_tensor([e.state for e in experiences], torch.float)
        # Bug fix: these two previously read ``e.state`` by mistake.
        needs_weights = self._to_tensor([e.needs_weight for e in experiences], torch.float)
        actions = self._to_tensor([e.action for e in experiences], torch.long)
        delta_utilities = self._to_tensor([e.delta_utility for e in experiences], torch.float)
        rewards = self._to_tensor([e.reward for e in experiences], torch.float)
        next_states = self._to_tensor([e.next_state for e in experiences], torch.float)
        dones = self._to_tensor(
            np.array([e.done for e in experiences]).astype(np.uint8), torch.float)

        return (states, needs_weights, actions, delta_utilities, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
