import numpy as np
from collections import deque


class TemporaryBuffer:
    """Bounded window of recent states and actions for a ``d``-step delay.

    ``states`` keeps at most ``d + 2`` entries and ``actions`` at most
    ``2 * d + 1`` entries; once full, appending drops the oldest entry
    (standard ``deque(maxlen=...)`` behaviour).  This class defines no
    append method, so callers presumably push onto ``self.states`` and
    ``self.actions`` directly -- confirm against the call sites.

    An "augmented state" is a state concatenated with ``d`` consecutive
    buffered actions.
    """

    def __init__(self, delayed_steps):
        """Create an empty buffer.

        Args:
            delayed_steps: Number of delayed steps ``d``; must be a
                non-negative integer.

        Raises:
            ValueError: If ``delayed_steps`` is negative.
        """
        if delayed_steps < 0:
            raise ValueError(
                f"delayed_steps must be non-negative, got {delayed_steps}"
            )
        self.d = delayed_steps
        self.states = deque(maxlen=delayed_steps + 2)
        self.actions = deque(maxlen=2 * delayed_steps + 1)

    def clear(self):
        """Drop every buffered state and action."""
        self.states.clear()
        self.actions.clear()

    def get_augmented_state(self, last_observed_state, first_action_idx):
        """Concatenate a state with ``d`` consecutive buffered actions.

        Args:
            last_observed_state: 1-D array-like placed first in the result.
            first_action_idx: Index into ``self.actions`` of the first action
                to append; actions ``first_action_idx`` through
                ``first_action_idx + d - 1`` are used.

        Returns:
            A new array ``[state, a_k, ..., a_{k+d-1}]``.  When ``d == 0`` no
            action is appended and the result is a copy of the state.

        Raises:
            IndexError: If the requested action window runs past the actions
                currently buffered.
        """
        stop = first_action_idx + self.d
        # Index explicitly (rather than slicing lazily) so that a window
        # reaching past the buffered actions raises IndexError instead of
        # silently producing a shorter augmented state.
        window = [self.actions[i] for i in range(first_action_idx, stop)]
        # One concatenate call copies each element once; pairwise
        # concatenation in a loop would recopy the growing prefix every time.
        return np.concatenate([last_observed_state, *window])

    def get_tuple(self):
        """Build one transition from a full buffer and advance the window.

        Returns:
            A tuple ``(aug_s, s, a, next_aug_s, next_s)`` where

            * ``aug_s`` is ``states[0]`` augmented with ``actions[0:d]``,
            * ``s`` is the second-to-last buffered state,
            * ``a`` is ``actions[d]``,
            * ``next_aug_s`` is ``states[1]`` augmented with
              ``actions[1:d + 1]``,
            * ``next_s`` is the last buffered state.

            After the tuple is built, the oldest state and the oldest action
            are removed from the buffer.

        Raises:
            AssertionError: If the buffer does not hold exactly ``d + 2``
                states and ``2 * d + 1`` actions.
        """
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if (
            len(self.states) != self.d + 2
            or len(self.actions) != 2 * self.d + 1
        ):
            raise AssertionError(
                f"buffer not full: expected {self.d + 2} states and "
                f"{2 * self.d + 1} actions, got {len(self.states)} and "
                f"{len(self.actions)}"
            )

        aug_s = self.get_augmented_state(self.states[0], 0)
        s = self.states[-2]
        a = self.actions[self.d]

        next_aug_s = self.get_augmented_state(self.states[1], 1)
        next_s = self.states[-1]

        # Slide the window forward by one step.
        self.states.popleft()
        self.actions.popleft()

        return aug_s, s, a, next_aug_s, next_s
