"""
Experience Replay Buffer for FPVR-DQN
Supports both standard uniform sampling and Prioritized Experience Replay (PER).
"""
from collections import deque
import numpy as np
import torch


class ReplayBuffer:
    """Standard uniform sampling replay buffer
    
    Memory-optimized design:
    - Store only a single frame (1, H, W) instead of the full stacked observation (C, H, W)
    - Do not store next_obs explicitly; reconstruct it from (idx+1) at sampling time
    - Dynamically build the C-frame stack while handling episode boundaries
    
    Notes:
    - This implementation stores the replay buffer on CPU (NumPy) to keep the code simple and avoid using GPU memory.
    """
    
    def __init__(self, capacity, state_shape, device="cpu", n_step=None, gamma=None, *, full_mc: bool = False):
        """
        Allocate the CPU (NumPy) storage for a single-frame circular replay buffer.

        Args:
            capacity: maximum number of transitions kept (converted with int()).
            state_shape: stacked observation shape (C, H, W); C is the frame-stack size.
            device: torch device that sampled batches are moved to.
            n_step: n-step horizon. n-step bookkeeping is enabled only when n_step and
                gamma are both given, n_step > 0 and full_mc is False.
            gamma: discount factor used for the n-step and full-episode MC returns.
            full_mc: when True, track full-episode Monte-Carlo returns-to-go
                (this disables the n-step bookkeeping).

        Raises:
            ValueError: if full_mc is True but gamma is None, or if n-step is enabled
                and capacity <= n_step + 1.
        """
        self.capacity = int(capacity)
        self.device = device
        self.n_step = n_step
        self.gamma = gamma

        # Return-mode flags: full-episode MC takes precedence over n-step.
        self.use_full_mc = bool(full_mc)
        self.use_n_step = (
            (not self.use_full_mc)
            and (n_step is not None and gamma is not None and n_step > 0)
        )
        if self.use_full_mc and gamma is None:
            raise ValueError("full_mc=True requires gamma to be provided")
        if self.use_n_step and self.capacity <= int(n_step) + 1:
            raise ValueError(
                f"ReplayBuffer capacity must be > n_step+1 for n-step returns. capacity={self.capacity}, n_step={n_step}"
            )

        # state_shape is (C, H, W); only H and W are needed for the per-slot frame.
        self.frame_stack, self.h, self.w = state_shape
        h, w = self.h, self.w
        slots = self.capacity

        # One uint8 frame per slot; stacked obs / next_obs are rebuilt at sampling time.
        self.frames_buf = np.zeros((slots, 1, h, w), dtype=np.uint8)
        self.action_buf = np.zeros((slots,), dtype=np.int64)
        self.reward_buf = np.zeros((slots,), dtype=np.float32)
        self.done_buf = np.zeros((slots,), dtype=np.float32)

        if self.use_n_step:
            self.n_step_return_buf = np.zeros((slots,), dtype=np.float32)
            # Only the slot index of s_{t+n} is kept; the stacked observation is rebuilt on demand.
            self.n_step_next_idx_buf = np.zeros((slots,), dtype=np.int32)
            self.n_step_done_buf = np.zeros((slots,), dtype=np.float32)
            self.n_step_valid_buf = np.zeros((slots,), dtype=bool)

            # Sliding window of recent transitions: (buffer_idx, reward, done, next_obs_stack_uint8).
            self._nstep_queue = deque()
            # Discounted reward sum over the current window.
            self._nstep_return = 0.0
            # Finalized entries waiting for their s_{t+n} frame to be written:
            # (start_idx, next_idx_to_wait_for).
            self._nstep_pending = deque()

        if self.use_full_mc:
            self.mc_return_buf = np.zeros((slots,), dtype=np.float32)
            self.mc_valid_buf = np.zeros((slots,), dtype=bool)
            # Scalars only, so long episodes stay cheap: (buffer_idx, reward, done).
            self._mc_queue = deque()

        # Circular write cursor and fill state.
        self.ptr = 0
        self.size = 0
        self.full = False

        # Report the estimated CPU footprint (bytes per slot times capacity).
        bytes_per_slot = (
            1 * h * w * 1  # frames_buf (uint8, single frame)
            + 8  # action_buf (int64)
            + 4 * 2  # reward_buf + done_buf (float32)
        )
        if self.use_n_step:
            # return(float32) + next_idx(int32) + done(float32) + valid(bool)
            bytes_per_slot += 4 + 4 + 4 + 1
        if self.use_full_mc:
            # mc_return_buf(float32) + mc_valid_buf(bool)
            bytes_per_slot += 4 + 1
        buffer_memory_mb = (slots * bytes_per_slot) / (1024**2)
        print(f"[ReplayBuffer] Using CPU buffer (optimized: single frames only). Estimated memory: {buffer_memory_mb:.2f} MB")
    
    def _encode_observation(self, index):
        """
        Rebuild the stacked (C,H,W) observation that ends at buffer slot ``index``.

        Args:
            index: buffer slot of the most recent frame of the stack.

        Returns:
            uint8 array [C, H, W]; position 0 holds the oldest frame and position -1
            the frame stored at ``index``.

        While the buffer is not full, history before slot 0 is padded with the frame
        at slot 0. Once it is full, source slots wrap modulo capacity. Whenever the
        slot just before a source frame carries done=1, that frame is treated as the
        first frame of a new episode and all older stack positions are overwritten
        with it.
        """
        stack = np.zeros((self.frame_stack, self.h, self.w), dtype=np.uint8)
        newest_pos = self.frame_stack - 1

        for pos in range(self.frame_stack):
            # Source slot for this stack position (pos 0 = oldest).
            src = index - (newest_pos - pos)
            if self.full:
                src = src % self.capacity
            elif src < 0:
                src = 0
            elif src > index:
                # Defensive clamp; not expected to trigger.
                src = index

            if pos > 0:
                before = src - 1
                if self.full:
                    before = before % self.capacity
                # When not full, there is no slot before slot 0 to inspect.
                if (self.full or before >= 0) and self.done_buf[before] > 0.5:
                    # Episode boundary right before src: pad every older position with src's frame.
                    stack[:pos] = self.frames_buf[src, 0]

            stack[pos] = self.frames_buf[src, 0]

        return stack
    
    def _has_enough_history(self, index):
        """
        Check whether slot ``index`` can be sampled with a clean, full frame stack.

        A slot is rejected when:
        - the buffer is not full and the slot has fewer than (frame_stack-1) earlier
          slots, or it is the last written slot (idx+1 is needed for next_obs);
        - the buffer is full and the slot is the most recently written one, or it is
          one of the (frame_stack-1) oldest slots: their history would wrap across
          the write head into the *newest* frames, which belong to a different time;
        - any of the previous (frame_stack-1) slots carries done=1, i.e. the stack
          would cross an episode boundary.

        Args:
            index: buffer slot to test.

        Returns:
            True if the slot is safe to sample, False otherwise.
        """
        history = self.frame_stack - 1

        if not self.full:
            # Need (frame_stack-1) earlier frames, and a written successor slot.
            if index < history:
                return False
            if index >= self.size - 1:
                return False
        else:
            # Age of the slot in write order: 0 = oldest (next to be overwritten),
            # capacity-1 = most recently written.
            age = (index - self.ptr) % self.capacity
            if age == self.capacity - 1:
                return False
            # The oldest slots have their history on the other side of the write head.
            if age < history:
                return False

        # Reject stacks that would span an episode boundary.
        for back in range(1, self.frame_stack):
            prev_idx = index - back
            if self.full:
                prev_idx = prev_idx % self.capacity
            elif prev_idx < 0:
                break
            if self.done_buf[prev_idx] > 0.5:
                return False

        return True
    
    def store(self, obs, action, reward, next_obs, done):
        """
        Store one transition and update the n-step / full-episode MC bookkeeping.

        Only the most recent frame of ``obs`` is written to the frame buffer; stacked
        observations are rebuilt from neighbouring slots at sampling time.

        Args:
            obs: stacked observation [C,H,W] (NumPy array or torch tensor); a bare
                [H,W] frame is accepted as well.
            action: action taken (int or single-element tensor).
            reward: reward received (float or single-element tensor).
            next_obs: next stacked observation [C,H,W] or [1,C,H,W]. It is only
                shape-checked when n-step is enabled and is never written to the
                frame buffer.
            done: episode termination flag.

        Raises:
            ValueError: if n-step is enabled and next_obs is not [C,H,W] / [1,C,H,W].
        """
        # Extract the latest frame from the stacked observation as a CPU NumPy array.
        if isinstance(obs, np.ndarray):
            frame_to_store = obs[-1] if obs.ndim == 3 else obs  # [H,W] or [1,H,W]
        else:  # torch tensor
            frame_to_store = obs[-1] if obs.ndim == 3 else obs
            frame_to_store = frame_to_store.detach().cpu().numpy()

        # Normalize shapes: accept [H,W] or [1,H,W]
        if isinstance(frame_to_store, np.ndarray) and frame_to_store.ndim == 3:
            frame_to_store = frame_to_store[0]

        # Normalize scalars
        if isinstance(action, torch.Tensor):
            action = int(action.item())
        if isinstance(reward, torch.Tensor):
            reward = float(reward.item())
        if isinstance(done, torch.Tensor):
            done = float(done.item())

        slot = int(self.ptr)

        # Write the transition into the circular buffers.
        self.frames_buf[slot, 0] = frame_to_store  # [H,W]
        self.action_buf[slot] = action
        self.reward_buf[slot] = reward
        self.done_buf[slot] = float(done)

        # n-step: entries finalized earlier were waiting for the frame of s_{t+n}, which lives
        # in the slot just written. Marking them valid only now guarantees that sampling never
        # rebuilds n_step_next_obs from a slot that has not been written yet.
        if self.use_n_step:
            while len(self._nstep_pending) > 0:
                start_idx, wait_next_idx = self._nstep_pending[0]
                if int(wait_next_idx) != slot:
                    break
                self._nstep_pending.popleft()
                self.n_step_valid_buf[int(start_idx)] = True

        # Full-episode MC: compute return-to-go only when the episode ends (done=True).
        if self.use_full_mc:
            self.mc_valid_buf[slot] = False

            # If the circular buffer overwrote a slot still referenced by the queue, drop the queue.
            if self.full and any(int(it[0]) == slot for it in self._mc_queue):
                self._mc_queue.clear()

            self._mc_queue.append((slot, float(reward), float(done)))

            if bool(done):
                # Backward discounted return-to-go for the whole episode.
                G = 0.0
                for idx, r, d in reversed(self._mc_queue):
                    G = float(r) + float(self.gamma) * (1.0 - float(d)) * G
                    self.mc_return_buf[idx] = float(G)
                    self.mc_valid_buf[idx] = True
                self._mc_queue.clear()

        # n-step: maintain the sliding window and finalize targets online.
        if self.use_n_step:
            horizon = int(self.n_step)
            gamma = float(self.gamma)

            # Reset the fields of the newly written slot (important when overwriting old data).
            self.n_step_valid_buf[slot] = False
            self.n_step_next_idx_buf[slot] = 0

            # Validate / normalize next_obs to a uint8 stacked observation [C,H,W].
            if isinstance(next_obs, torch.Tensor):
                next_obs_np = next_obs.detach().cpu().numpy()
            else:
                next_obs_np = np.asarray(next_obs)
            if next_obs_np.ndim == 4 and next_obs_np.shape[0] == 1:
                next_obs_np = next_obs_np[0]
            if next_obs_np.ndim != 3:
                raise ValueError(f"next_obs must have shape [C,H,W] (or [1,C,H,W]), got {next_obs_np.shape}")
            if next_obs_np.dtype != np.uint8:
                next_obs_np = next_obs_np.astype(np.uint8, copy=False)

            # If the circular buffer overwrote a slot still referenced by the window, drop the window.
            if self.full and any(int(it[0]) == slot for it in self._nstep_queue):
                self._nstep_queue.clear()

            self._nstep_queue.append((slot, float(reward), float(done), next_obs_np))

            if bool(done):
                # Terminal inside the window: flush every queued start with a truncated horizon.
                # The bootstrap is masked by n_step_done=1, but the training code may still
                # evaluate the network on n_step_next_obs, so next_idx must reference a written
                # slot; the terminal slot itself is always written.
                terminal_next_idx = int(self._nstep_queue[-1][0])
                G = 0.0
                for idx, r, d, _no in reversed(self._nstep_queue):
                    G = float(r) + gamma * (1.0 - float(d)) * G
                    self.n_step_return_buf[idx] = float(G)
                    self.n_step_done_buf[idx] = 1.0
                    self.n_step_next_idx_buf[idx] = terminal_next_idx
                    self.n_step_valid_buf[idx] = True
                self._nstep_queue.clear()
            elif len(self._nstep_queue) >= horizon:
                # The oldest queued transition now has a complete n-step window.
                window = list(self._nstep_queue)[:horizon]
                start_idx = window[0][0]
                end_next_idx = (int(window[-1][0]) + 1) % self.capacity

                # Recompute the discounted sum from the window itself. The previous running
                # update R' = (R - r0) / gamma divided by zero for gamma == 0 and multiplied any
                # rounding residue by 1/gamma on every step of a long episode.
                G = 0.0
                for _idx, r, _d, _no in reversed(window):
                    G = float(r) + gamma * G

                self.n_step_return_buf[start_idx] = float(G)
                self.n_step_done_buf[start_idx] = 0.0
                self.n_step_next_idx_buf[start_idx] = int(end_next_idx)
                # The frame at end_next_idx is only written by the *next* store() call, so the
                # entry stays invalid until then.
                self.n_step_valid_buf[start_idx] = False
                self._nstep_pending.append((int(start_idx), int(end_next_idx)))

                self._nstep_queue.popleft()

            # Keep the running-return attribute consistent with the current window contents.
            running = 0.0
            for _idx, r, _d, _no in reversed(self._nstep_queue):
                running = float(r) + gamma * running
            self._nstep_return = running

        # Advance the circular write cursor.
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
        self.full = (self.size == self.capacity)
    
    def sample(self, batch_size):
        """
        Draw a uniformly sampled batch, rebuilding obs / next_obs stacks on the fly.

        Args:
            batch_size: number of transitions to draw.

        Returns:
            dict with 'obs', 'actions', 'rewards', 'next_obs', 'dones' (torch tensors on
            self.device) and 'idxs' (int32 NumPy array of buffer slots). With n-step
            enabled it also holds 'n_step_returns', 'n_step_next_obs', 'n_step_dones'
            and 'n_step_valid'; with full-episode MC enabled, 'mc_returns' and 'mc_valid'.
        """
        # Rejection-sample slots that have a clean frame history, within a bounded budget.
        chosen = []
        tries_left = batch_size * 10
        while len(chosen) < batch_size and tries_left > 0:
            candidate = np.random.randint(0, self.size)
            if self._has_enough_history(candidate):
                chosen.append(candidate)
            tries_left -= 1

        # Budget exhausted: top up with unchecked slots so the batch always has full size.
        while len(chosen) < batch_size:
            chosen.append(np.random.randint(0, self.size))

        idxs = np.array(chosen[:batch_size], dtype=np.int32)

        def repeated_frame(slot):
            # Stack built by repeating the single frame stored at `slot`.
            frame = self.frames_buf[slot, 0]
            return np.stack([frame] * self.frame_stack, axis=0)

        def as_tensor(array):
            return torch.from_numpy(array).to(self.device)

        obs_list = []
        next_obs_list = []
        for idx in idxs:
            obs_list.append(self._encode_observation(idx))

            next_idx = (idx + 1) % self.capacity
            if not (self.full or next_idx < self.size):
                # Successor slot not written yet: fall back to repeating idx's own frame.
                next_obs = repeated_frame(idx)
            elif self.done_buf[idx] > 0.5:
                # Terminal transition: the successor slot starts a new episode, so repeat its frame.
                next_obs = repeated_frame(next_idx)
            else:
                next_obs = self._encode_observation(next_idx)
            next_obs_list.append(next_obs)

        batch = {
            'obs': as_tensor(np.stack(obs_list, axis=0)),  # [B, C, H, W]
            'actions': as_tensor(self.action_buf[idxs]),
            'rewards': as_tensor(self.reward_buf[idxs]),
            'next_obs': as_tensor(np.stack(next_obs_list, axis=0)),  # [B, C, H, W]
            'dones': as_tensor(self.done_buf[idxs]),
            'idxs': idxs,  # buffer slots (used for n-step or PER updates)
        }

        if self.use_n_step:
            batch['n_step_returns'] = as_tensor(self.n_step_return_buf[idxs])
            # s_{t+n} stacks are rebuilt from the stored slot indices.
            n_next_stacks = [
                self._encode_observation(int(ni)) for ni in self.n_step_next_idx_buf[idxs]
            ]
            batch['n_step_next_obs'] = as_tensor(np.stack(n_next_stacks, axis=0))
            batch['n_step_dones'] = as_tensor(self.n_step_done_buf[idxs])
            batch['n_step_valid'] = as_tensor(self.n_step_valid_buf[idxs])

        if self.use_full_mc:
            batch['mc_returns'] = as_tensor(self.mc_return_buf[idxs])
            batch['mc_valid'] = as_tensor(self.mc_valid_buf[idxs])

        return batch
    
    def get_n_step_return(self, idxs, n_step=None, gamma=None):
        """
        Look up the precomputed n-step targets for the given buffer slots.

        Args:
            idxs: starting slots (batch_size,)
            n_step: legacy argument; only inspected when n-step is not enabled
            gamma: legacy argument; only inspected when n-step is not enabled

        Returns:
            Tuple of tensors on self.device:
            n_step_returns: discounted n-step reward sums (batch_size,), without Q bootstrap
            n_step_dones: n-step done flags (batch_size,)
            n_step_next_obs: stacks rebuilt from the stored next-slot indices (batch_size, C, H, W)
            n_step_valid: validity mask (batch_size,)

        Raises:
            ValueError: always, when the buffer was created without n-step support
                (on-the-fly computation is not implemented).
        """
        if not self.use_n_step:
            # No precomputed buffers exist; the message depends on whether legacy args were passed.
            if n_step is None or gamma is None:
                raise ValueError("n_step and gamma must be provided if buffer was not initialized with n-step support")
            raise ValueError("n-step not enabled in buffer. Please initialize ReplayBuffer with n_step and gamma parameters.")

        def as_tensor(array):
            return torch.from_numpy(array).to(self.device)

        returns = as_tensor(self.n_step_return_buf[idxs])
        dones = as_tensor(self.n_step_done_buf[idxs])
        # s_{t+n} stacks are rebuilt from the stored slot indices.
        next_stacks = [self._encode_observation(int(i)) for i in self.n_step_next_idx_buf[idxs]]
        next_obs = as_tensor(np.stack(next_stacks, axis=0))
        valid = as_tensor(self.n_step_valid_buf[idxs])
        return (returns, dones, next_obs, valid)
    
    def __len__(self):
        """Return the number of transitions currently stored (at most ``capacity``)."""
        return self.size


class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized Experience Replay buffer (PER)."""
    
    def __init__(self, capacity, state_shape, alpha=0.6, beta=0.4, device="cpu", n_step=None, gamma=None, *, full_mc: bool = False):
        """
        Args:
            alpha: priority exponent (0=uniform, 1=full prioritization)
            beta: importance-sampling correction exponent
            n_step: n-step horizon (when mixed_mc is enabled)
            gamma: discount factor (when mixed_mc is enabled)
        """
        super().__init__(capacity, state_shape, device, n_step=n_step, gamma=gamma, full_mc=full_mc)
        
        self.alpha = alpha
        self.beta = beta
        
        # Sum tree for efficient sampling
        tree_capacity = 1
        while tree_capacity < capacity:
            tree_capacity *= 2
        
        self.tree_capacity = tree_capacity
        self.sum_tree = np.zeros(2 * tree_capacity - 1, dtype=np.float32)
        self.max_priority = 1.0
    
    def store(self, obs, action, reward, next_obs, done):
        """Store transition and set max priority for the new item."""
        idx = self.ptr
        super().store(obs, action, reward, next_obs, done)
        
        # New items get max priority
        tree_idx = idx + self.tree_capacity - 1
        self._update_tree(tree_idx, self.max_priority ** self.alpha)
    
    def _update_tree(self, tree_idx, priority):
        """Update sum-tree (optimized to reduce array accesses)."""
        change = priority - self.sum_tree[tree_idx]
        self.sum_tree[tree_idx] = priority
        
        # Propagate to the root
        while tree_idx > 0:
            tree_idx = (tree_idx - 1) // 2
            self.sum_tree[tree_idx] += change
    
    def _get_leaf(self, value):
        """Find a leaf index for a cumulative value (optimized)."""
        parent_idx = 0
        tree_size = len(self.sum_tree)
        
        # Optimization: compute left child index once per loop
        while True:
            left_idx = 2 * parent_idx + 1
            
            # If left child is out of range, parent_idx is a leaf
            if left_idx >= tree_size:
                leaf_idx = parent_idx
                break
            
            # Direct comparison
            left_priority = self.sum_tree[left_idx]
            if value <= left_priority:
                parent_idx = left_idx
            else:
                value -= left_priority
                parent_idx = left_idx + 1  # right_idx = left_idx + 1
        
        data_idx = leaf_idx - self.tree_capacity + 1
        return data_idx, leaf_idx
    
    def sample(self, batch_size):
        """
        Prioritized sampling (optimized to reduce duplicates and overhead).

        Optimization: dynamically build obs and next_obs frame stacks on-the-fly.
        """
        idxs = []
        tree_idxs = []
        priorities = []
        
        # Total priority mass
        total_priority = self.sum_tree[0]
        if total_priority <= 0:
            # If total priority is 0, fall back to uniform sampling
            candidate_idxs = list(range(self.size))
            valid_candidates = [idx for idx in candidate_idxs if self._has_enough_history(idx)]
            if len(valid_candidates) > 0:
                sample_idxs = np.random.choice(valid_candidates, size=min(batch_size, len(valid_candidates)), replace=False)
            else:
                sample_idxs = np.random.choice(candidate_idxs, size=min(batch_size, len(candidate_idxs)), replace=False)
            idxs = list(sample_idxs)
            tree_idxs = [idx + self.tree_capacity - 1 for idx in idxs]
            priorities = np.ones(len(idxs), dtype=np.float32)
        else:
            segment = total_priority / batch_size
            
            # Stratified sampling (avoid duplicates; use a set for fast membership checks)
            seen_idxs = set()
            attempts = 0
            max_attempts = batch_size * 20  # allow more attempts (may need to skip invalid indices)
            
            while len(idxs) < batch_size and attempts < max_attempts:
                a = segment * len(idxs)
                b = segment * (len(idxs) + 1)
                value = np.random.uniform(a, b)
                
                data_idx, tree_idx = self._get_leaf(value)
                
                # Ensure in-range, non-duplicate, and (preferably) enough history
                if data_idx < self.size and data_idx not in seen_idxs:
                    # Prefer indices with enough history; allow some early fallbacks
                    if self._has_enough_history(data_idx) or len(idxs) < batch_size // 2:
                        idxs.append(data_idx)
                        tree_idxs.append(tree_idx)
                        priorities.append(self.sum_tree[tree_idx])
                        seen_idxs.add(data_idx)
                
                attempts += 1
            
            # If still short, fill the remainder with random indices (prefer valid history)
            if len(idxs) < batch_size:
                remaining = batch_size - len(idxs)
                available = [i for i in range(self.size) if i not in seen_idxs]
                if len(available) > 0:
                    # Prefer indices with enough history
                    valid_available = [i for i in available if self._has_enough_history(i)]
                    if len(valid_available) >= remaining:
                        additional = np.random.choice(valid_available, size=remaining, replace=False)
                    else:
                        # If not enough valid indices, include some invalid ones
                        additional = np.random.choice(available, size=min(remaining, len(available)), replace=False)
                    
                    for idx in additional:
                        idxs.append(idx)
                        tree_idxs.append(idx + self.tree_capacity - 1)
                        priorities.append(self.sum_tree[idx + self.tree_capacity - 1])
        
        # Convert to numpy arrays
        idxs = np.array(idxs, dtype=np.int32)
        priorities = np.array(priorities, dtype=np.float32)
        
        # Importance-sampling weights (vectorized)
        if total_priority > 0:
            sampling_probs = priorities / total_priority
            weights = (self.size * sampling_probs) ** (-self.beta)
            weights /= weights.max()  # normalize
        else:
            weights = np.ones(len(idxs), dtype=np.float32)
        
        # Dynamically build obs and next_obs
        obs_list = []
        next_obs_list = []
        
        for idx in idxs:
            # Build current state's frame stack
            obs = self._encode_observation(idx)
            
            # Build next state's frame stack starting from idx+1
            next_idx = (idx + 1) % self.capacity
            
            # Check whether we crossed an episode boundary
            done = self.done_buf[idx]
            
            if done > 0.5:  # done=True: next_obs should be the first state of the next episode
                # For a new episode, repeat the frame at next_idx (if valid).
                # If next_idx is invalid (buffer not yet filled), repeat idx's frame as a terminal-safe fallback.
                if not self.full:
                    if next_idx < self.size:
                        # next_idx is valid
                        next_frame = self.frames_buf[next_idx, 0]
                        next_obs = np.stack([next_frame] * self.frame_stack, axis=0)
                    else:
                        # next_idx invalid
                        next_frame = self.frames_buf[idx, 0]
                        next_obs = np.stack([next_frame] * self.frame_stack, axis=0)
                else:
                    # Buffer full: next_idx always valid
                    next_frame = self.frames_buf[next_idx, 0]
                    next_obs = np.stack([next_frame] * self.frame_stack, axis=0)
            else:
                # Normal case: build next_obs from next_idx
                if not self.full:
                    if next_idx < self.size:
                        next_obs = self._encode_observation(next_idx)
                    else:
                        # next_idx invalid (should not happen; guard only)
                        next_frame = self.frames_buf[idx, 0]
                        next_obs = np.stack([next_frame] * self.frame_stack, axis=0)
                else:
                    # Buffer full: next_idx always valid
                    next_obs = self._encode_observation(next_idx)
            
            obs_list.append(obs)
            next_obs_list.append(next_obs)
        
        # Convert to tensors
        obs_batch = torch.from_numpy(np.stack(obs_list, axis=0)).to(self.device)  # [B, C, H, W]
        next_obs_batch = torch.from_numpy(np.stack(next_obs_list, axis=0)).to(self.device)  # [B, C, H, W]
        
        batch = {
            'obs': obs_batch,
            'actions': torch.from_numpy(self.action_buf[idxs]).to(self.device),
            'rewards': torch.from_numpy(self.reward_buf[idxs]).to(self.device),
            'next_obs': next_obs_batch,
            'dones': torch.from_numpy(self.done_buf[idxs]).to(self.device),
            'idxs': idxs,
            'tree_idxs': tree_idxs,
            'weights': torch.from_numpy(weights).float().to(self.device),
        }
        
        if self.use_n_step:
            batch['n_step_returns'] = torch.from_numpy(self.n_step_return_buf[idxs]).to(self.device)
            n_next_obs_list = []
            next_idxs = self.n_step_next_idx_buf[idxs]
            for ni in next_idxs:
                n_next_obs_list.append(self._encode_observation(int(ni)))
            batch['n_step_next_obs'] = torch.from_numpy(np.stack(n_next_obs_list, axis=0)).to(self.device)
            batch['n_step_dones'] = torch.from_numpy(self.n_step_done_buf[idxs]).to(self.device)
            batch['n_step_valid'] = torch.from_numpy(self.n_step_valid_buf[idxs]).to(self.device)
        
        return batch
    
    def update_priorities(self, tree_idxs, priorities):
        """Update priorities (batch-optimized)."""
        # Batch processing: compute priorities first, then update the tree
        priorities = np.asarray(priorities, dtype=np.float32)
        priorities = np.maximum(priorities, 1e-6)  # avoid zero priorities
        priorities_alpha = priorities ** self.alpha
        
        # Update max_priority
        self.max_priority = max(self.max_priority, float(priorities.max()))
        
        # Update tree nodes (still per-node, but reduced compute overhead)
        for tree_idx, priority_alpha in zip(tree_idxs, priorities_alpha):
            self._update_tree(tree_idx, priority_alpha)

