import numpy as np

from .base_replay_buffer import BaseReplayBuffer


class SimpleReplayBuffer(BaseReplayBuffer):
    """Fixed-capacity circular buffer of transitions, stored field by field.

    Every transition carries the fields 'state', 'action', 'next_state',
    'reward', 'done' and 'timeout'. Each field lives in its own storage object
    created by ``_buf_init``. Subclasses pick the storage type by implementing
    ``_buf_init``, ``_buf_add`` and ``_buf_sample``.

    Transitions are written to slot ``length % max_buf_size``. Once the buffer
    is full, each new transition overwrites the oldest one.
    """

    def __init__(self, max_buf_size):
        """Create empty storage with ``max_buf_size`` slots for every field."""
        self.max_buf_size = max_buf_size
        keys = ['state', 'action', 'next_state', 'reward', 'done', 'timeout']
        self.data = {k: self._buf_init(k, max_buf_size) for k in keys}
        # Total number of transitions ever added; may exceed max_buf_size.
        self.length = 0

    def _buf_init(self, key, max_buf_size):
        """Return empty storage for field ``key`` with ``max_buf_size`` slots."""
        raise NotImplementedError

    def _buf_add(self, buf, idx, data):
        """Write ``data`` into slot ``idx`` of storage ``buf``."""
        raise NotImplementedError

    def _buf_sample(self, buf, indices):
        """Return the entries of storage ``buf`` at slot positions ``indices``."""
        raise NotImplementedError

    def add_transition(self, transition):
        """Store a single transition.

        Args:
            transition: Mapping (or anything indexable by field name) that
                provides a value for every field of the buffer.
        """
        idx = self.length % self.max_buf_size
        for key, buf in self.data.items():
            self._buf_add(buf, idx, transition[key])
        self.length += 1

    def sample(self, n_samples=1, *, indices=None):
        """Return a batch of stored transitions.

        Args:
            n_samples: Number of indices to draw when ``indices`` is None.
                Draws are uniform over ``[0, len(self))``, with replacement.
            indices: Explicit storage-slot positions to read instead of
                random ones. These are physical slot positions, not
                insertion order.

        Returns:
            Dict mapping each field name to ``_buf_sample`` of that field's
            storage at ``indices``.

        Raises:
            ValueError: If random sampling is requested from an empty buffer.
        """
        if indices is None:
            n_stored = len(self)
            if n_stored == 0:
                raise ValueError('cannot sample from an empty replay buffer')
            indices = np.random.randint(n_stored, size=(n_samples,), dtype=np.int64)
        return {k: self._buf_sample(v, indices) for k, v in self.data.items()}

    def __len__(self):
        """Number of transitions currently held (at most ``max_buf_size``)."""
        return min(self.length, self.max_buf_size)

    def add_transitions(self, transitions):
        """Append many transitions at once.

        Args:
            transitions: One of
                * a dict mapping field name to an equal-length sequence;
                * another ``SimpleReplayBuffer`` (its contents are copied
                  oldest-first);
                * a list, tuple or ``np.recarray`` of individual transitions.

        Raises:
            AssertionError: If the dict's values have differing lengths.
            NotImplementedError: For any other input type.
        """
        if isinstance(transitions, (dict, SimpleReplayBuffer)):
            if isinstance(transitions, dict):
                lengths = [len(v) for v in transitions.values()]
                n = lengths[0] if lengths else 0
                assert all(length == n for length in lengths), (
                    f'all transition fields must have the same length, got {lengths}')
                start = 0
            else:
                n = len(transitions)
                # Once a buffer has wrapped around, its oldest item sits at its
                # next write slot. Reading from there keeps insertion order, so
                # a smaller destination keeps the newest items.
                if transitions.length > transitions.max_buf_size:
                    start = transitions.length % transitions.max_buf_size
                else:
                    start = 0
            if n == 0:
                return

            for key, buf in self.data.items():
                # Look up the field once: a buffer's __getitem__ copies the
                # whole field on every call.
                values = transitions[key]
                for i in range(n):
                    self._buf_add(buf, (self.length + i) % self.max_buf_size,
                                  values[(start + i) % n])
            self.length += n
        elif isinstance(transitions, (list, tuple, np.recarray)):
            for transition in transitions:
                self.add_transition(transition)
        else:
            raise NotImplementedError

    def __getattr__(self, item: str):
        """Expose each field's stored values as an attribute (e.g. ``buf.reward``).

        This is only called when normal attribute lookup fails. 'data' is
        excluded so that an instance without ``data`` (e.g. one created
        without ``__init__``) raises AttributeError instead of recursing.
        """
        if item != 'data' and item in self.data:
            return self.data[item][:len(self)]
        raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')

    def __getitem__(self, item):
        """Return the first ``len(self)`` slots of field ``item``."""
        return self.data[item][:len(self)]

    def sampling_data_loader(self, n_iters_per_epoch, batch_size):
        """Return an iterable that yields random batches of this buffer.

        Each pass over the result yields ``n_iters_per_epoch`` batches, each
        from ``self.sample(batch_size)``. ``len()`` of the result is
        ``n_iters_per_epoch``.
        """
        buf = self

        class Loader:
            def __iter__(self):
                for _ in range(n_iters_per_epoch):
                    yield buf.sample(batch_size)

            def __len__(self):
                return n_iters_per_epoch

        return Loader()

    def __iadd__(self, other):
        """Support ``buf += transitions``, delegating to ``add_transitions``."""
        self.add_transitions(other)
        return self


class ListReplayBuffer(SimpleReplayBuffer):
    """Replay buffer that keeps every field in a pre-sized Python list."""

    def _buf_init(self, key, max_buf_size):
        """Return a list of ``max_buf_size`` empty (None) slots.

        The same layout is used for every field, so ``key`` is not consulted.
        """
        slots = [None for _ in range(max_buf_size)]
        return slots

    def _buf_add(self, buf, idx, data):
        """Overwrite slot ``idx`` of list ``buf`` with ``data``."""
        buf[idx] = data

    def _buf_sample(self, buf, indices):
        """Gather the slots of ``buf`` at ``indices`` into a numpy array."""
        picked = [buf[pos] for pos in indices]
        return np.array(picked)
