from typing import Dict, List, Tuple, Optional
import numpy as np
import torch

class ReplayBuffer:
    """Fixed-capacity ring buffer of labelled trajectory pairs.

    Each entry is a dict with the keys ``"traj1"``, ``"traj2"`` and
    ``"true_label"``. Once ``max_size`` entries have been stored, every new
    entry overwrites the oldest one.
    """

    def __init__(self, max_size: int = 500_000):
        """Create an empty buffer.

        Args:
            max_size: Maximum number of entries kept at once. Must be positive.

        Raises:
            ValueError: If ``max_size`` is not positive.
        """
        # A non-positive capacity would otherwise only surface later, inside
        # push(), as an unrelated-looking IndexError / ZeroDivisionError.
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")
        self.max_size = max_size
        self.buffer: List[Optional[Dict]] = []
        # Index of the slot the next push() writes to.
        self.position = 0

    def push(self, traj1_key: str, traj2_key: str, true_label: torch.Tensor) -> None:
        """Store one entry, overwriting the oldest entry once the buffer is full.

        Args:
            traj1_key: Stored under the ``"traj1"`` key of the entry.
            traj2_key: Stored under the ``"traj2"`` key of the entry.
            true_label: Stored under ``"true_label"``; kept by reference,
                not copied.
        """
        # While still filling up, grow the list by one slot so that
        # ``self.position`` is a valid index for the assignment below.
        if len(self.buffer) < self.max_size:
            self.buffer.append(None)
        self.buffer[self.position] = {
            "traj1": traj1_key,
            "traj2": traj2_key,
            "true_label": true_label,
        }
        # Wrap around so the next write lands on the oldest slot.
        self.position = (self.position + 1) % self.max_size

    def sample(self, batch_size: int) -> List[Dict]:
        """Return ``batch_size`` entries drawn uniformly at random.

        Indices come from ``np.random.choice`` with its default settings, so
        sampling is *with replacement*: the same entry may appear more than
        once, and ``batch_size`` may exceed ``len(self)``. The returned dicts
        are the stored objects themselves, not copies.

        Args:
            batch_size: Number of entries to draw.

        Returns:
            A list of ``batch_size`` stored entries.

        Raises:
            ValueError: Propagated from ``np.random.choice`` when entries are
                requested from an empty buffer.
        """
        indices = np.random.choice(len(self.buffer), batch_size)
        return [self.buffer[idx] for idx in indices]

    def __len__(self) -> int:
        """Return the number of entries currently stored."""
        return len(self.buffer)

