from collections import namedtuple
import random
import torch
import threading
# One step of experience. ReplayBuffer.add builds a Transition from its
# positional arguments, so the field order here is the order callers pass.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'state_next', 'done'],
)

class ReplayBuffer:
    """
    Fixed-capacity ring buffer of Transitions with background prefetching.

    Once `capacity` transitions are stored, each new one overwrites the
    oldest. After every call to `sample`, a worker thread starts preparing
    the next batch with the same arguments, so the following call can
    usually return without sampling on the caller's thread.
    """

    def __init__(self, capacity):
        """
        capacity: maximum number of transitions kept; must be positive.

        Raises ValueError if capacity is not positive.
        """
        if capacity <= 0:
            raise ValueError(
                "capacity must be positive, got {!r}".format(capacity)
            )

        self._capacity = capacity
        # Slot index -> Transition. Slots are filled in the order 0, 1, 2, ...
        # and then overwritten in place, so the keys are always exactly
        # range(len(self)).
        self._memorery = {}
        self._position = 0

        # State shared with the prefetch worker thread.
        self.next_batch_process = None
        self.next_batch_size = None
        self.next_batch_device = None
        self.next_batch = None
        # True only while next_batch holds a batch that has not been handed
        # out yet.
        self.next_batch_ready = False

    def add(self, *args):
        """
        Saves a transition, overwriting the oldest one once the buffer is full.

        args: the Transition fields, in order
              (state, action, reward, state_next, done).
        """
        if self.next_batch_process is not None:
            # Don't add to the buffer when sampling from it.
            self.next_batch_process.join()

        self._memorery[self._position] = Transition(*args)

        self._position = (self._position + 1) % self._capacity

    def _prepare_sample(self, batch_size, device=None):
        """
        Draws `batch_size` distinct transitions and stores them, stacked
        field by field and moved to `device`, in `self.next_batch`.

        Raises ValueError (from random.sample) if batch_size is negative or
        larger than the number of stored transitions.
        """
        # Cleared first so a failed draw is never mistaken for a fresh batch.
        self.next_batch_ready = False
        self.next_batch_size = batch_size
        self.next_batch_device = device

        # The keys are exactly range(len(self)), so sampling slot indices
        # picks the same items as sampling a copy of the values, without
        # copying the whole buffer on every call.
        slots = random.sample(range(len(self._memorery)), batch_size)
        batch = [self._memorery[slot] for slot in slots]

        self.next_batch = [
            torch.stack(tensors).to(device) for tensors in zip(*batch)
        ]
        self.next_batch_ready = True

    def launch_sample(self, *args):
        """
        Starts a worker thread that prepares the next batch.

        args: forwarded to _prepare_sample (batch_size[, device]).
        """
        if self.next_batch_process is not None:
            # Never run two workers at once: both would write next_batch.
            self.next_batch_process.join()

        self.next_batch_process = threading.Thread(
            target=self._prepare_sample, args=args
        )
        self.next_batch_process.start()

    def sample(self, batch_size, device=None):
        """
        Samples a batch of Transitions, with the tensors already stacked
        and transferred to the specified device.
        Return a list of tensors in the order specified in Transition.

        Uses the prefetched batch when one is available for the same
        batch_size and device; otherwise samples synchronously. In both
        cases a prefetch for the next call is started before returning.

        Raises ValueError if batch_size is negative or larger than the
        number of stored transitions.
        """
        if self.next_batch_process is not None:
            self.next_batch_process.join()

        prefetched = (
            self.next_batch_ready
            and self.next_batch_size == batch_size
            and self.next_batch_device == device
        )
        if not prefetched:
            # Nothing usable was prefetched (first call, different arguments,
            # or the worker failed): sample on the caller's thread so that
            # errors surface here instead of being lost in the worker.
            self._prepare_sample(batch_size, device)

        batch = self.next_batch
        # Mark the batch as consumed so it can never be returned twice.
        self.next_batch_ready = False
        self.launch_sample(batch_size, device)
        return batch

    def __len__(self):
        """Returns the number of transitions currently stored."""
        return len(self._memorery)