import random
import numpy as np

class ReplayBuffer:
    """Episode-based replay buffer that serves fixed-length transition sequences.

    Transitions are accumulated with :meth:`add` until one arrives with
    ``done=True``; the finished episode is then stored as a list of
    ``(state, action, reward, next_state, done)`` tuples. At most ``capacity``
    episodes are kept, with the oldest evicted first. :meth:`sample` returns
    contiguous windows of ``sequence_length`` transitions, padding episodes
    that are shorter than the window.
    """

    def __init__(self, capacity: int, sequence_length: int):
        """
        Args:
            capacity (int): Maximum number of episodes to store.
            sequence_length (int): Fixed length of sequences to sample.

        Raises:
            ValueError: If ``capacity`` or ``sequence_length`` is less than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be a positive integer.")
        if sequence_length < 1:
            raise ValueError("sequence_length must be a positive integer.")

        self.capacity = capacity                # Maximum number of episodes
        self.sequence_length = sequence_length  # Length of sequence to sample
        self.buffer = []           # Completed episodes, oldest first
        self.current_episode = []  # Transitions of the episode in progress

    def add(self, state: np.ndarray, action, reward: float, next_state: np.ndarray, done: bool):
        """
        Add a transition to the buffer. If the transition ends an episode (done==True),
        the episode is stored and the current_episode buffer is reset.

        Args:
            state: Observation before the action.
            action: Action taken (scalar or array).
            reward: Reward received for the transition.
            next_state: Observation after the action.
            done: Whether this transition terminates the episode.
        """
        self.current_episode.append((state, action, reward, next_state, done))
        if not done:
            return

        # An episode is finished; store it and start a fresh one.
        self.buffer.append(self.current_episode)
        self.current_episode = []
        # If we've exceeded capacity, remove the oldest episode.
        if len(self.buffer) > self.capacity:
            del self.buffer[0]

    def sample(self, batch_size: int) -> list:
        """
        Sample a batch of sequences from the buffer.

        Episodes are drawn uniformly at random with replacement, so the same
        episode may contribute more than one sequence to a batch.

        Args:
            batch_size: Number of sequences to return.

        Returns:
            List[List[Tuple]]: A batch (list) of sequences. Each sequence is a
            list of exactly ``sequence_length`` transitions.

        Raises:
            ValueError: If fewer than ``batch_size`` episodes are stored.
        """
        if len(self.buffer) < batch_size:
            raise ValueError("Not enough episodes in the buffer to sample the requested batch size.")

        return [
            self._sample_sequence(random.choice(self.buffer))
            for _ in range(batch_size)
        ]

    def _sample_sequence(self, episode: list) -> list:
        """Return one ``sequence_length``-long window (padded if needed) from ``episode``."""
        length = self.sequence_length
        if len(episode) >= length:
            # Sample a random contiguous sequence from the episode.
            start_idx = random.randint(0, len(episode) - length)
            return episode[start_idx:start_idx + length]

        # The episode is shorter than the window: keep all of it and pad the tail.
        seq = list(episode)
        pad_length = length - len(seq)
        # A fresh padding tuple per slot, so padded arrays are never shared.
        seq.extend(self._padding_transition(episode) for _ in range(pad_length))
        return seq

    @staticmethod
    def _padding_transition(episode: list) -> tuple:
        """Build an all-zero, ``done=True`` transition matching ``episode``'s first transition."""
        if episode:
            state, action, _, next_state, _ = episode[0]
            # zeros_like keeps shape *and* dtype, so stacking a padded sequence
            # does not upcast (e.g. float32/uint8 states stay as they are).
            pad_state = np.zeros_like(state)
            pad_next_state = np.zeros_like(next_state)
            # Array-valued actions need an array-shaped pad; scalars pad with 0.
            pad_action = np.zeros_like(action) if isinstance(action, np.ndarray) else 0
        else:
            # Nothing to infer a shape from; fall back to a 1-element state.
            pad_state = np.zeros((1,))
            pad_next_state = np.zeros((1,))
            pad_action = 0
        # reward 0.0 and done=True mark the slot as padding / past the episode end.
        return (pad_state, pad_action, 0.0, pad_next_state, True)

    def clear(self):
        """Clear the entire buffer (all stored episodes and the episode in progress)."""
        self.buffer = []
        self.current_episode = []

    def __len__(self):
        """Return the number of complete episodes stored."""
        return len(self.buffer)
