r"""
    Implementation stubbed with MaskDP https://github.com/FangchenLiu/MaskDP_public
    Streaming realtime buffer for online RL.

    Replay buffer should hold...
    - for now let's implement the time uniform, we'll unshape and reshape as needed
    States: T x H Neural Data
    Actions: T x H (assuming a single task dimensionality) - may be padding tokens
    Rewards: Scalar

    Removes terminals, which don't exist in the online domain.
"""
import numpy as np
import torch
import torch.nn as nn

class StreamingReplayBuffer:
    r"""
        Fixed-capacity, time-ordered replay buffer for streaming (online) RL.

        Storage is four preallocated float32 arrays (states, constraints, actions,
        rewards) whose first axis is the timestep. Rows ``0 .. _last_timestep`` are
        valid and ordered oldest to newest. Once all ``max_timesteps`` rows are in
        use, each insert shifts the contents back by one row (dropping the oldest)
        and writes the new timestep into the last row.

        No specific dataset API, we'll just make explicit slices while I'm not sure how to serve it
        TODO add backbone states - since these aren't tenable to actually decode with
    """

    def __init__(
        self,
        max_timesteps: int,
        neural_state_dim: int,
        max_constraint_dim: int,
        max_action_dim: int,
        discount: float,
        traj_length: int, # ? Not sure if relevant if we're slicing
    ):
        r"""
            Allocate zero-filled storage for `max_timesteps` timesteps.

            max_timesteps: capacity of the buffer, in timesteps.
            neural_state_dim: width of one state row.
            max_constraint_dim: width of one constraint row; constraints are
                stored as (3, max_constraint_dim) per timestep.
            max_action_dim: width of one action row.
            discount: stored as-is; not used by any method of this class.
            traj_length: stored as-is; not used by any method of this class.
        """
        self._max_timesteps = max_timesteps
        self._neural_state_dim = neural_state_dim
        self._max_constraint_dim = max_constraint_dim
        self._max_action_dim = max_action_dim
        self._discount = discount
        self._traj_length = traj_length
        # Index of the newest valid row; -1 means "empty". Starting at -1 (rather
        # than 0) lets the very first insert land in row 0, so no unwritten
        # all-zero row is ever exposed to `sample` and the full capacity is used.
        # Caps out at max_timesteps - 1 once the buffer has filled.
        self._last_timestep = -1

        self._states = np.zeros((max_timesteps, neural_state_dim), dtype=np.float32)
        self._constraints = np.zeros((max_timesteps, 3, max_constraint_dim), dtype=np.float32)
        self._actions = np.zeros((max_timesteps, max_action_dim), dtype=np.float32)
        self._rewards = np.zeros((max_timesteps, 1), dtype=np.float32)

    def insert(
            self,
            state: np.ndarray,
            constraint: np.ndarray,
            action: np.ndarray, # This should be post-smoothing.
            reward: np.ndarray
        ):
        r"""
            Append one timestep as the newest entry, evicting the oldest if full.

            Each argument is assigned into a single row of the matching storage
            array, so it must be broadcastable to that row's shape:
            (neural_state_dim,), (3, max_constraint_dim), (max_action_dim,), (1,).

            TODO pad to proper dimension
        """
        if self._last_timestep == self._max_timesteps - 1:
            # Buffer is full: shift every row one step towards index 0. np.roll
            # wraps the oldest row around to the last slot, which is overwritten
            # with the new timestep below.
            self._states = np.roll(self._states, -1, axis=0)
            self._constraints = np.roll(self._constraints, -1, axis=0)
            self._actions = np.roll(self._actions, -1, axis=0)
            self._rewards = np.roll(self._rewards, -1, axis=0)
        else:
            # Still filling: claim the next unused row.
            self._last_timestep += 1
        self._states[self._last_timestep] = state
        self._constraints[self._last_timestep] = constraint
        self._actions[self._last_timestep] = action
        self._rewards[self._last_timestep] = reward

    def reset(self):
        r"""
            Mark the buffer as empty.

            Storage is not cleared; old rows simply become unreachable and are
            overwritten by subsequent inserts.
        """
        self._last_timestep = -1

    def sample(self, batch_size):
        r"""
            Provide `batch_size` number of transitions for IQL e.g. S A R S'

            Start indices are drawn uniformly with replacement (via the global
            `np.random` state) from every stored timestep that has a successor.
            Returns a dict with keys "observations", "actions", "rewards" and
            "next_observations"; each value has `batch_size` rows. Stored
            constraints are not part of the returned batch.

            Raises ValueError if fewer than two timesteps are stored, since no
            (S, S') pair exists yet.

            TODO incorp traj length
        """
        if self._last_timestep < 1:
            raise ValueError(
                "Need at least 2 stored timesteps to sample a transition, "
                f"but the buffer holds {self._last_timestep + 1}."
            )
        # randint's upper bound is exclusive, so this draws from
        # [0, _last_timestep - 1]: every row whose successor row is also valid,
        # including the most recent transition.
        rand_sample = np.random.randint(0, self._last_timestep, batch_size)
        states = self._states[rand_sample]
        actions = self._actions[rand_sample]
        rewards = self._rewards[rand_sample]
        next_states = self._states[rand_sample + 1]
        # My actor can't actually act with just this... surely?
        return {
            "observations": states,
            "actions": actions,
            "rewards": rewards,
            "next_observations": next_states,
        }