import random
from collections import deque
from typing import NamedTuple

import torch


class Transition(NamedTuple):
    state: torch.Tensor
    act: torch.Tensor
    reward: torch.Tensor
    next_state: torch.Tensor | None
    # FIXME: merge logits and value together?
    act_logits: torch.Tensor | None   # predicted
    state_value: torch.Tensor | None  # predicted
    penalties: tuple[torch.Tensor]  # activation penalties


class Trajectory(list[Transition]):
    """An ordered sequence of transitions plus trajectory-level metadata.

    Behaves exactly like a plain ``list``. The two attributes below are only
    annotated (no class-level defaults), so they exist on an instance only
    after the caller assigns them; reading one earlier raises
    ``AttributeError``.
    """

    # Why the trajectory ended — presumably a label set by the producer.
    done_reason: str
    # Presumably the summed reward of the trajectory; assigned by the producer.
    total_reward: float


class TrajectoryReplayBuffer:
    """Fixed-capacity FIFO store of whole trajectories with uniform sampling.

    Once ``capacity`` trajectories are held, each further append evicts the
    oldest one (the backing ``deque`` has ``maxlen=capacity``).
    ``total_seen`` counts every trajectory ever appended, including evicted
    ones, and is not reset by :meth:`clear`.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        # A bounded deque silently drops from the left when it is full.
        self.memory: deque[Trajectory] = deque(maxlen=capacity)
        self.total_seen = 0

    def append(self, trajectory: Trajectory):
        """Store ``trajectory``, evicting the oldest one if the buffer is full."""
        self.total_seen += 1
        self.memory.append(trajectory)

    def clear(self):
        """Drop every stored trajectory; ``total_seen`` is left untouched."""
        self.memory.clear()

    def sample(self, batch_size: int) -> list[Trajectory]:
        """Return ``batch_size`` stored trajectories drawn without replacement.

        Raises:
            ValueError: from ``random.sample`` when ``batch_size`` is negative
                or larger than the number of stored trajectories.
        """
        population = self.memory
        return random.sample(population, k=batch_size)

    def __len__(self):
        """Number of trajectories currently held (never more than ``capacity``)."""
        held = self.memory
        return len(held)
