from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import torch
from torch.utils.data import Dataset


class ReservoirBuffer(Dataset):
    """Allows uniform sampling over a stream of data.

    This class supports the storage of arbitrary elements, such as observation
    tensors, integer actions, etc. Once the buffer holds
    ``reservoir_buffer_capacity`` elements, every further ``add`` keeps the
    new element with probability ``capacity / (add_calls + 1)`` (reservoir
    sampling, "Algorithm R"). As a result, the stored elements form a uniform
    sample of everything that was ever added.

    For fast minibatching, ``batch_data`` stacks every field of the stored
    elements into one tensor per field, and ``batch_sample`` draws rows from
    that cached snapshot. The snapshot is not refreshed automatically: call
    ``batch_data`` again after ``add``/``shuffle_data`` if it should reflect
    the current contents.

    See https://en.wikipedia.org/wiki/Reservoir_sampling for more details.
    """

    def __init__(self, reservoir_buffer_capacity, num_elements=4):
        """Creates an empty buffer.

        Args:
          reservoir_buffer_capacity: maximum number of elements kept.
          num_elements: number of fields per element. Stored for reference
            only; no method currently reads it.
        """
        self._reservoir_buffer_capacity = reservoir_buffer_capacity
        self._num_elements = num_elements
        self._data = []
        self._add_calls = 0
        # Per-field stacked tensors built by `batch_data`; None until then.
        self._batched_data = None

    def add(self, element):
        """Potentially adds `element` to the reservoir buffer.

        Args:
          element: data to be added to the reservoir buffer.
        """
        if len(self._data) < self._reservoir_buffer_capacity:
            self._data.append(element)
        else:
            # Uniform slot in [0, add_calls]; only slots inside the buffer
            # cause a replacement, giving keep-probability capacity/(n+1).
            idx = np.random.randint(0, self._add_calls + 1)
            if idx < self._reservoir_buffer_capacity:
                self._data[idx] = element
        self._add_calls += 1

    def sample(self, num_samples):
        """Returns `num_samples` uniformly sampled from the buffer.

        Args:
          num_samples: `int`, number of samples to draw.

        Returns:
          An iterable over `num_samples` random elements of the buffer.

        Raises:
          ValueError: If there are less than `num_samples` elements in the buffer
        """
        if len(self._data) < num_samples:
            raise ValueError(
                "{} elements could not be sampled from size {}".format(
                    num_samples, len(self._data)
                )
            )
        return random.sample(self._data, num_samples)

    def batch_sample(self, batch_size):
        """Returns a minibatch drawn without replacement from the batched snapshot.

        Args:
          batch_size: `int`, number of rows to draw. If the snapshot holds
            fewer rows, all of them are returned (in random order).

        Returns:
          A tuple with one tensor per field, each indexed by the same random
          row indices. Returns an empty tuple if the snapshot is empty.

        Raises:
          AttributeError: If `batch_data` has not been called since creation
            or the last `clear`.
        """
        batched = getattr(self, "_batched_data", None)
        if batched is None:
            raise AttributeError(
                "batch_data() must be called before batch_sample()"
            )
        if not batched:
            return ()
        # Draw indices from the snapshot itself, not `self._data`, which may
        # have changed size since `batch_data` was called.
        num_rows = len(batched[0])
        batch_size = min(batch_size, num_rows)
        sampled_idx = np.random.choice(num_rows, batch_size, replace=False)
        return tuple(field[sampled_idx] for field in batched)

    def batch_data(self, device="cuda"):
        """Stacks each field of the stored elements into one tensor on `device`.

        Every element must be indexable with the same number of fields as the
        first one. Field ``k`` of all elements is converted with
        ``torch.as_tensor`` and stacked along a new leading dimension. The
        result is cached for `batch_sample`.

        Args:
          device: torch device the stacked tensors are moved to.
        """
        if not self._data:
            # Nothing to stack; `batch_sample` will return an empty tuple.
            self._batched_data = []
            return
        num_fields = len(self._data[0])
        self._batched_data = [
            torch.stack(
                [torch.as_tensor(element[k]) for element in self._data]
            ).to(device)
            for k in range(num_fields)
        ]

    def clear(self):
        """Removes all elements, resets the add counter and drops the snapshot."""
        self._data = []
        self._add_calls = 0
        self._batched_data = None

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)

    def __getitem__(self, idx):
        return self._data[idx]

    @staticmethod
    def _element_to_device(element, device):
        """Returns `element` moved to `device` (see `to` for the rules)."""
        if hasattr(element, "to"):
            return element.to(device)
        parts = [torch.as_tensor(part).to(device) for part in element]
        if isinstance(element, tuple):
            # namedtuples take positional fields; plain tuples take an iterable.
            if hasattr(element, "_fields"):
                return type(element)(*parts)
            return tuple(parts)
        return parts

    def to(self, device):
        """Moves the stored elements (and any batched snapshot) to `device`.

        Elements that expose a ``.to`` method (e.g. tensors) are moved
        directly. Any other element is treated as a sequence of fields: each
        field is converted with ``torch.as_tensor`` and moved, and tuples
        (including namedtuples) stay tuples while other sequences become lists.

        Args:
          device: target torch device.
        """
        self._data = [
            self._element_to_device(element, device) for element in self._data
        ]
        batched = getattr(self, "_batched_data", None)
        if batched:
            self._batched_data = [field.to(device) for field in batched]

    @property
    def data(self):
        """The list of stored elements (live reference, not a copy)."""
        return self._data

    def get_data(self):
        """Returns the list of stored elements (live reference, not a copy)."""
        return self._data

    def shuffle_data(self):
        """Shuffles the stored elements in place (the snapshot is untouched)."""
        random.shuffle(self._data)

    def get_num_calls(self):
        """Returns how many times `add` has been called since creation/clear."""
        return self._add_calls
