from abc import ABC
from operator import itemgetter

import numpy as np


class ReplayMemory(ABC):
    """Replay buffer front-end that samples transitions from a ``CoreMemory``.

    Storage and indexing are delegated to ``self.memory``; this class only
    decides *which* timesteps to read back.
    """

    def __init__(self, capacity, seed):
        """Create the underlying storage.

        Args:
            capacity: Maximum number of transitions kept by the storage.
            seed: Seed forwarded to ``CoreMemory`` for random sampling.
        """
        self.capacity = capacity
        # NOTE(review): ``size`` and ``allocated`` are never updated in this
        # class; they are kept only so existing attribute accesses keep
        # working. Use ``len(self)`` for the current number of transitions.
        self.size = 0
        self.allocated = False

        self.memory = CoreMemory(self.capacity, seed)
        # Expose the storage's ``save`` directly on this object.
        self.save = self.memory.save

    def __len__(self):
        """Return the number of transitions currently held by the storage."""
        return len(self.memory)

    def sample_uniform(self, batch_size=1):
        """Sample ``batch_size`` independently drawn timesteps.

        Each timestep comes from its own ``random_timestep()`` call, so the
        same timestep may appear more than once in a batch.

        Args:
            batch_size: Number of timesteps to draw; must be at least 1.

        Returns:
            Whatever ``self.memory.sample`` returns for a 1-D array of
            ``batch_size`` timesteps.

        Raises:
            AssertionError: If ``batch_size`` is smaller than 1.
        """
        # Same precondition as ``sample_trajectories``. Without it an empty
        # batch becomes a float64 array that cannot be used as an index.
        assert batch_size >= 1
        timesteps = np.array([self.memory.random_timestep() for _ in range(batch_size)])
        return self.memory.sample(timesteps)

    def sample_trajectories(self, batch_size=1, length=1):
        """Sample ``batch_size`` runs of ``length`` consecutive timesteps.

        Every run starts at a timestep drawn with
        ``random_timestep(except_last=length)`` and covers
        ``start, start + 1, ..., start + length - 1``.

        Args:
            batch_size: Number of trajectories; must be at least 1.
            length: Number of consecutive timesteps per trajectory; must be
                at least 1.

        Returns:
            Whatever ``self.memory.sample`` returns for a timestep array of
            shape ``(batch_size, length)``.

        Raises:
            AssertionError: If ``batch_size`` or ``length`` is smaller than 1.
        """
        assert batch_size >= 1
        assert length >= 1
        trajectories = []
        for _ in range(batch_size):
            start = self.memory.random_timestep(except_last=length)
            trajectories.append(np.arange(start, start + length))

        return self.memory.sample(np.stack(trajectories))


class CoreMemory:
    """Fixed-capacity ring buffer of transitions addressed by global timestep.

    Every call to ``save`` stores one transition and advances the global step
    counter ``t``. Once more than ``capacity`` transitions have been saved the
    oldest slots are overwritten, so only timesteps in ``[begin, end]`` can be
    read back.
    """

    def __init__(self, capacity, seed):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept at once.
            seed: Seed passed to ``np.random.default_rng`` for sampling.
        """
        self.capacity = capacity
        self.write_ptr = 0  # slot the next transition is written to
        self.allocated = False  # storage arrays are created on the first save
        self.t = 0  # total number of transitions saved so far
        self.np_random = np.random.default_rng(seed)

    def __len__(self):
        """Return the number of transitions currently readable."""
        return min(self.t, self.capacity)

    def _allocate(self, array):
        """Return an uninitialised buffer with ``capacity`` rows like ``array``.

        The per-row shape and the dtype are taken from ``np.array(array)``.
        """
        array = np.array(array)
        return np.empty([self.capacity, *array.shape], dtype=array.dtype)

    def save(self, obs, action, reward, terminated, truncated, b_prob):
        """Store one transition at the current write position.

        On the first call the storage arrays are allocated from the values
        passed in. NOTE: shapes and dtypes are therefore fixed by the first
        transition (e.g. an integer first reward yields an integer reward
        buffer).
        """
        if not self.allocated:
            self.observations = self._allocate(obs)
            self.actions = self._allocate(action)
            self.rewards = self._allocate(reward)
            self.terminateds = self._allocate(terminated)
            self.truncateds = self._allocate(truncated)
            self.b_probs = self._allocate(b_prob)
            self.allocated = True

        p = self.write_ptr
        self.observations[p] = obs
        self.actions[p] = action
        self.rewards[p] = reward
        self.terminateds[p] = terminated
        self.truncateds[p] = truncated
        self.b_probs[p] = b_prob

        # Wrap around so the oldest slot is overwritten once the buffer is full.
        self.write_ptr = (p + 1) % self.capacity
        self.t += 1

    @property
    def begin(self):
        """Oldest timestep that is still stored."""
        return self.t - len(self)

    @property
    def end(self):
        """Newest stored timestep (``-1`` while nothing has been saved)."""
        return self.t - 1

    def sample(self, timesteps: np.ndarray):
        """Read back the transitions stored at the given global timesteps.

        Args:
            timesteps: Integer array of timesteps, each within
                ``[begin, end]``.

        Returns:
            Tuple ``(observations, actions, rewards, terminateds, truncateds,
            b_probs)``, each storage array indexed by ``timesteps``.

        Raises:
            AssertionError: If the buffer is empty or a timestep is outside
                ``[begin, end]``.
        """
        assert len(self) > 0
        assert (self.begin <= timesteps).all()
        assert (timesteps <= self.end).all()
        # While the buffer is not yet full len(self) == t, so slot == timestep;
        # once full len(self) == capacity and slots wrap modulo the capacity.
        indices = timesteps % len(self)
        buffers = (
            self.observations,
            self.actions,
            self.rewards,
            self.terminateds,
            self.truncateds,
            self.b_probs,
        )
        return tuple(buffer[indices] for buffer in buffers)

    def random_timestep(self, except_last=0):
        """Draw a stored timestep uniformly, skipping the newest ones.

        Args:
            except_last: Number of newest timesteps excluded from the draw.

        Returns:
            A timestep in ``[begin, end - except_last]``.

        Raises:
            AssertionError: If no timestep is left to draw from, i.e.
                ``len(self) <= except_last`` (this includes an empty buffer).
        """
        available = len(self) - except_last
        # At least one candidate must remain: ``integers(0)`` would otherwise
        # fail with an unrelated ValueError from NumPy.
        assert available > 0, (
            f"cannot sample: {len(self)} stored timestep(s), "
            f"{except_last} excluded"
        )
        return self.begin + self.np_random.integers(available)
