from __future__ import annotations

from collections import deque
from typing import Deque, Tuple
import numpy as np


class NStepBuffer:
    """Fixed-length window of recent transitions folded into one n-step transition.

    Each pushed transition is ``(obs, act, rew, next_obs, done, cost)``; ``rew``,
    ``done`` and ``cost`` are stored as Python floats. At most ``n`` transitions
    are kept: pushing onto a full window evicts the oldest one.

    Args:
        n: Window length, i.e. the maximum number of steps aggregated into one
            transition. Must be >= 1.
        gamma: Discount factor applied to both the reward and the cost sums.

    Raises:
        ValueError: if ``n < 1``.
    """

    def __init__(self, n: int, gamma: float):
        if n < 1:
            # With n < 1, push() would evict every transition immediately and
            # pop_nstep() could never produce anything.
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        self.gamma = gamma
        # Oldest transition on the left, newest on the right.
        self.buf: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, float, float]] = deque()

    def push(self, obs, act, rew, next_obs, done, cost) -> None:
        """Append one transition, evicting the oldest if more than ``n`` are held."""
        self.buf.append((obs, act, float(rew), next_obs, float(done), float(cost)))
        if len(self.buf) > self.n:
            self.buf.popleft()

    def ready(self) -> bool:
        """Return True when the window holds exactly ``n`` transitions."""
        return len(self.buf) == self.n

    def pop_nstep(self):
        """Collapse the buffered transitions into one n-step transition and clear the buffer.

        Starting at the oldest transition, accumulates
        ``R = sum_i gamma**i * r_i`` and ``C = sum_i gamma**i * c_i``. It stops
        after the first transition whose done flag is non-zero.

        Returns:
            ``(obs, act, R, next_obs, done, C)``:

            * ``obs`` and ``act`` come from the oldest buffered transition.
            * ``next_obs`` comes from the last transition that was accumulated.
            * ``R``, ``done`` and ``C`` are shape-(1,) float32 arrays.
            * ``done`` is the largest done flag seen over the accumulated steps.

        Raises:
            IndexError: if the buffer is empty.

        NOTE(review): the whole buffer is cleared, not just the oldest entry.
        In a push/ready/pop loop this emits one transition per ``n`` steps, and
        any transitions after a mid-window done are dropped. Confirm this is
        intended; a sliding-window design would ``popleft()`` a single entry
        instead.
        """
        if not self.buf:
            raise IndexError("pop_nstep() called on an empty NStepBuffer")
        obs, act = self.buf[0][0], self.buf[0][1]
        R = 0.0
        C = 0.0
        next_obs = None
        done_flag = 0.0
        for i, (_, _, r, s_next, d, c) in enumerate(self.buf):
            discount = self.gamma ** i
            R += discount * r
            C += discount * c
            next_obs = s_next
            done_flag = max(done_flag, d)
            if d:
                # Stop at the first terminal transition so that later entries
                # (presumably from a following episode) do not contribute.
                break
        self.buf.clear()
        return (
            obs,
            act,
            np.array([R], dtype=np.float32),
            next_obs,
            np.array([done_flag], dtype=np.float32),
            np.array([C], dtype=np.float32),
        )

