"""
SCSD v3: Cross-Block Speculative Decoding with Fixed Draft Length
Key improvements over v2:
1. Fixed draft length k (default 4) instead of per-block processing
2. Cross-block draft selection when current block has < k remaining tokens
3. Hierarchical verification tree respecting block priorities
4. Extensible design for future multi-candidate support
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Optional, Dict, Set, Union
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from transformers import AutoTokenizer

@dataclass
class DraftCandidate:
    """
    One proposed token for a single masked position of the draft.

    The current strategies only use the top-1 ``token``; the two
    ``alternative_*`` lists allow extra candidates for the same position to be
    carried along (paired index-by-index) without changing this type.
    """
    # Absolute index of the position in the full (prompt + generation) sequence.
    position: int
    # Identifier of the block that ``position`` belongs to.
    block_id: int
    # Proposed token id (top-1).
    token: int
    # Confidence score attached to ``token``.
    confidence: float
    # Optional extra candidate tokens for the same position ...
    alternative_tokens: List[int] = field(default_factory=list)
    # ... and their confidence scores (same order as ``alternative_tokens``).
    alternative_confidences: List[float] = field(default_factory=list)
    

@dataclass(eq=False)
class VerificationNode:
    """
    Node in the verification tree.

    Supports cross-block verification and (in future) multi-candidate paths.

    Nodes use identity equality and hashing (``eq=False``). The generated
    value-based ``__eq__`` would compare ``sequence_tensor`` fields
    element-wise, which raises for multi-element tensors, and would recurse
    through ``parent``/``children``. It would also make nodes unhashable.
    With identity semantics, ``==``, ``in``, ``list.index``/``remove`` and
    sets/dicts of nodes all work.
    """
    # Absolute positions treated as confirmed at this node.
    confirmed_positions: List[int]
    # {block_id: [positions]} accepted along the path to this node.
    confirmed_blocks: Dict[int, List[int]] = field(default_factory=dict)
    children: List['VerificationNode'] = field(default_factory=list)
    # Position verified at this node (None for the root).
    verify_position: Optional[int] = None
    # Draft token expected at ``verify_position``.
    expected_token: Optional[int] = None
    # Back-pointer to the parent node. Excluded from repr so printing a node
    # does not dump every ancestor.
    parent: Optional['VerificationNode'] = field(default=None, repr=False)
    # Pre-filled sequence for this node.
    sequence_tensor: Optional[torch.Tensor] = None
    # Positions still masked at this node.
    mask_positions: List[int] = field(default_factory=list)
    # Position -> draft token mapping.
    draft_tokens_dict: Dict[int, int] = field(default_factory=dict)

    # For future extensions.
    # {position: [candidate tokens]}.
    alternative_tokens: Dict[int, List[int]] = field(default_factory=dict)
    # Intended for dynamic traversal ordering; GreedyWithAlternativesStrategy
    # sets a negative value to mark alternative branches.
    priority_score: float = 0.0
    # Joint probability of the path.
    path_probability: float = 1.0
    

class TreeBuildStrategy(ABC):
    """
    Interface for verification-tree builders.

    A strategy receives the current draft candidates and returns the root of a
    tree of ``VerificationNode`` objects. Each concrete subclass decides the
    tree's shape.
    """

    @abstractmethod
    def build_tree(
        self,
        candidates: List[DraftCandidate],
        confirmed_positions: List[int],
        x: torch.Tensor,
        prompt_length: int,
        mask_id: int,
        verbose: bool = False,
        all_mask_positions: Optional[List[int]] = None,
    ) -> VerificationNode:
        """
        Construct a verification tree from the draft candidates and return its root.

        Args:
            candidates: Draft tokens proposed for masked positions.
            confirmed_positions: Positions confirmed before this call.
            x: Current token sequence tensor.
            prompt_length: Number of leading prompt positions in ``x``.
            mask_id: Token id used for masked positions.
            verbose: Whether to print diagnostics.
            all_mask_positions: Remaining masked positions, if already known.
        """
    

class GreedyStrategy(TreeBuildStrategy):
    """
    Greedy strategy: build a linear chain from the draft tokens in priority order.

    Candidates are sorted by block id (ascending), then by confidence
    (descending). The resulting tree is a single path. Every node has at most
    one child, and each child accepts exactly one more draft token than its
    parent.
    """
    def build_tree(
        self,
        candidates: List[DraftCandidate],
        confirmed_positions: List[int],
        x: torch.Tensor,
        prompt_length: int,
        mask_id: int,
        verbose: bool = False,
        all_mask_positions: Optional[List[int]] = None
    ) -> VerificationNode:
        """
        Build a linear verification chain ordered by block priority, then confidence.

        Args:
            candidates: Draft candidates. Their tokens populate
                ``draft_tokens_dict`` on every node.
            confirmed_positions: Positions confirmed before this call. Copied
                into the root; not mutated.
            x: Current sequence tensor, indexed as ``x[:, pos]``. Every node
                gets its own clone in ``sequence_tensor``.
            prompt_length: When mask positions are derived from ``x``, only
                positions ``>= prompt_length`` are kept.
            mask_id: Token id that marks a masked position.
            verbose: If True, print the candidate ordering and the tree.
            all_mask_positions: Every position still masked. If None, derived
                from ``x`` as the ascending union over all batch rows.

        Returns:
            The root node. It represents the state with no new draft tokens
            accepted. If ``candidates`` is empty, it has no children.
        """
        if all_mask_positions is None:
            # Deduplicate across batch rows. With batch size > 1 the raw column
            # indices repeat, and the single ``list.remove`` per accepted
            # position in ``_build_linear_chain`` would leave stale duplicates.
            mask_columns = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
            all_mask_positions = sorted({p for p in mask_columns if p >= prompt_length})

        # The root is the starting state with NO new confirmations.
        root = VerificationNode(
            confirmed_positions=confirmed_positions.copy(),  # Only previously confirmed
            mask_positions=all_mask_positions.copy()  # ALL remaining mask positions
        )

        # Position -> proposed token, copied onto every node of the chain.
        draft_dict = {c.position: c.token for c in candidates}
        root.draft_tokens_dict = draft_dict

        # The root sequence carries no draft tokens; x already has confirmed tokens filled.
        root.sequence_tensor = x.clone()

        if not candidates:
            return root

        # Block priority first, then highest confidence first.
        sorted_candidates = sorted(
            candidates,
            key=lambda c: (c.block_id, -c.confidence)
        )

        if verbose:
            print(f"\n    === Building Verification Tree (Linear Chain) ===")
            print(f"    Current state: {len(confirmed_positions)} positions already confirmed")
            print(f"    Draft candidates to verify (sorted by block priority):")
            for i, c in enumerate(sorted_candidates):
                print(f"      [{i}] Pos {c.position} (block {c.block_id}): "
                      f"token={c.token}, conf={c.confidence:.4f}")

            # Show the top-confidence token across all candidates.
            top_conf_candidate = max(candidates, key=lambda c: c.confidence)
            if top_conf_candidate != sorted_candidates[0]:
                print(f"\n    Note: Top confidence token differs from first priority:")
                print(f"      Highest confidence: Pos {top_conf_candidate.position} "
                      f"(block {top_conf_candidate.block_id}), conf={top_conf_candidate.confidence:.4f}")
                print(f"      First by priority: Pos {sorted_candidates[0].position} "
                      f"(block {sorted_candidates[0].block_id}), conf={sorted_candidates[0].confidence:.4f}")

            print(f"\n    Verification strategy (Single-Branch):")
            print(f"    - Linear chain: each node has at most one child")
            print(f"    - Strictly follows priority order (no alternatives)")
            print(f"    - Verification stops at first mismatch")

        self._build_linear_chain(root, sorted_candidates, x, mask_id, draft_dict)

        if verbose:
            print(f"\n    Tree Statistics:")
            print(f"      Total nodes: {self._count_nodes(root)} (linear chain)")
            print(f"      Chain length: {self._get_tree_depth(root)}")
            print(f"      Structure: Single path from root to leaf")
            print(f"\n    Tree Structure Visualization:")
            self._print_tree(root, indent="      ", verbose=verbose)

        return root

    @staticmethod
    def _extend_blocks(
        blocks: Dict[int, List[int]], block_id: int, position: int
    ) -> Dict[int, List[int]]:
        """Return a copy of ``blocks`` with ``position`` appended under ``block_id``.

        The per-block lists are copied as well. A plain ``dict.copy()`` would
        share them, and appending would silently mutate every ancestor node.
        """
        extended = {bid: list(positions) for bid, positions in blocks.items()}
        extended.setdefault(block_id, []).append(position)
        return extended

    def _build_linear_chain(
        self,
        parent: VerificationNode,
        sorted_candidates: List[DraftCandidate],
        x: torch.Tensor,
        mask_id: int,
        draft_dict: Dict[int, int]
    ):
        """Attach a single chain of nodes below ``parent``, one per candidate.

        The node at depth ``i`` accepts the first ``i`` entries of
        ``sorted_candidates``. No alternative branches are created.
        ``mask_id`` is not used by this method.
        """
        current_node = parent
        accumulated_confirmed = parent.confirmed_positions.copy()
        # Track ALL remaining masks, not just the draft positions.
        remaining_masks = parent.mask_positions.copy()

        for candidate in sorted_candidates:
            accumulated_confirmed.append(candidate.position)
            if candidate.position in remaining_masks:
                remaining_masks.remove(candidate.position)

            child = VerificationNode(
                confirmed_positions=accumulated_confirmed.copy(),
                verify_position=candidate.position,
                expected_token=candidate.token,
                parent=current_node,
                mask_positions=remaining_masks.copy(),
                draft_tokens_dict=draft_dict.copy()
            )
            child.confirmed_blocks = self._extend_blocks(
                current_node.confirmed_blocks, candidate.block_id, candidate.position
            )

            # Sequence with every accumulated draft token written in.
            child_seq = x.clone()
            for pos in accumulated_confirmed:
                if pos in draft_dict:
                    child_seq[:, pos] = draft_dict[pos]
            child.sequence_tensor = child_seq

            # This is the node's only child; continue the chain from it.
            current_node.children.append(child)
            current_node = child

    def _count_nodes(self, node: VerificationNode) -> int:
        """Count total nodes in the subtree rooted at ``node``."""
        return 1 + sum(self._count_nodes(child) for child in node.children)

    def _get_tree_depth(self, node: VerificationNode, depth: int = 0) -> int:
        """Get the maximum depth of the subtree rooted at ``node``."""
        if not node.children:
            return depth
        return max(self._get_tree_depth(child, depth + 1) for child in node.children)

    def _get_root(self, node: VerificationNode) -> VerificationNode:
        """Follow ``parent`` links up to the root node."""
        while node.parent is not None:
            node = node.parent
        return node

    def _print_tree(self, node: VerificationNode, indent: str = "", is_last: bool = True,
                    is_root: bool = True, verbose: bool = False, depth: int = 0):
        """Print the tree structure in a visual format."""
        if is_root:
            # The root's mask_positions holds *all* remaining masks. The draft
            # positions actually being verified are the keys of draft_tokens_dict.
            draft_positions = sorted(node.draft_tokens_dict)
            num_draft = len(draft_positions)
            print(f"{indent}[ROOT] State: {len(node.confirmed_positions)} confirmed (from previous iterations)")
            print(f"{indent}       To verify: {num_draft} draft positions")
            if verbose and 0 < num_draft <= 5:
                print(f"{indent}       Draft positions: {draft_positions}")
        else:
            branch = "└── " if is_last else "├── "
            if node.verify_position is not None:
                # Cumulative draft tokens accepted relative to the root.
                root_confirmed = len(self._get_root(node).confirmed_positions)
                draft_accepted = len(node.confirmed_positions) - root_confirmed

                print(f"{indent}{branch}[Level {depth}: +{draft_accepted} draft] "
                      f"Accept pos {node.verify_position} = {node.expected_token}")

                if verbose:
                    remaining = len(node.mask_positions)
                    if remaining > 0:
                        print(f"{indent}{'    ' if is_last else '│   '}"
                              f"  → {remaining} positions still to verify")

        # The root's children keep the root's indent; deeper nodes extend it.
        new_indent = indent if is_root else indent + ("    " if is_last else "│   ")
        for i, child in enumerate(node.children):
            is_last_child = (i == len(node.children) - 1)
            self._print_tree(child, new_indent, is_last_child, False, verbose, depth + 1)


class GreedyWithAlternativesStrategy(TreeBuildStrategy):
    """
    Greedy strategy with alternative branches.

    Builds the same main chain as the greedy strategy (block id ascending,
    then confidence descending). At each main-chain step, if the candidate
    carries ``alternative_tokens``, up to ``num_alternatives - 1`` of them are
    added as extra leaf children next to the main-chain child.

    Alternative branches are leaf nodes (never expanded further). They are
    marked with a negative ``priority_score``.
    """
    def __init__(self, num_alternatives: int = 1):
        """
        Args:
            num_alternatives: Total number of options considered per position.
                             1 = greedy only (no alternatives)
                             2 = greedy + 1 alternative
                             3 = greedy + 2 alternatives, etc.
                             Values below 1 are clamped to 1.
        """
        self.num_alternatives = max(1, num_alternatives)

    def build_tree(
        self,
        candidates: List[DraftCandidate],
        confirmed_positions: List[int],
        x: torch.Tensor,
        prompt_length: int,
        mask_id: int,
        verbose: bool = False,
        all_mask_positions: Optional[List[int]] = None
    ) -> VerificationNode:
        """
        Build a verification tree with a main chain and alternative branches.

        The tree has this shape:
        - The main chain follows the priority-sorted (greedy) candidates.
        - Each main-chain step may add alternative leaves for next-best tokens.

        Args:
            candidates: Draft candidates (``alternative_tokens`` may be set).
            confirmed_positions: Positions confirmed before this call. Copied
                into the root; not mutated.
            x: Current sequence tensor, indexed as ``x[:, pos]``.
            prompt_length: When mask positions are derived from ``x``, only
                positions ``>= prompt_length`` are kept.
            mask_id: Token id that marks a masked position.
            verbose: If True, print the candidate ordering and the tree.
            all_mask_positions: Every position still masked. If None, derived
                from ``x`` as the ascending union over all batch rows.

        Returns:
            The root node. It has no children when ``candidates`` is empty.
        """
        if all_mask_positions is None:
            # Deduplicate across batch rows. With batch size > 1 the raw column
            # indices repeat, and the later single ``list.remove`` per
            # accepted position would leave stale duplicates behind.
            mask_columns = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
            all_mask_positions = sorted({p for p in mask_columns if p >= prompt_length})

        root = VerificationNode(
            confirmed_positions=confirmed_positions.copy(),
            mask_positions=all_mask_positions.copy()
        )

        # Position -> proposed (top-1) token.
        draft_dict = {c.position: c.token for c in candidates}
        root.draft_tokens_dict = draft_dict

        # The root sequence carries no draft tokens.
        root.sequence_tensor = x.clone()

        if not candidates:
            return root

        # Block priority first, then highest confidence first.
        sorted_candidates = sorted(
            candidates,
            key=lambda c: (c.block_id, -c.confidence)
        )

        if verbose:
            print(f"\n    === Building Verification Tree (Greedy with Alternatives) ===")
            print(f"    Number of alternatives per position: {self.num_alternatives}")
            print(f"    Current state: {len(confirmed_positions)} positions already confirmed")
            print(f"    Draft candidates to verify (sorted by block priority):")
            for i, c in enumerate(sorted_candidates):
                print(f"      [{i}] Pos {c.position} (block {c.block_id}): "
                      f"token={c.token}, conf={c.confidence:.4f}")

            if self.num_alternatives > 1:
                print(f"\n    Verification strategy (Multi-Branch):")
                print(f"    - Main chain: greedy path with highest confidence")
                print(f"    - Alternative branches: {self.num_alternatives - 1} next-best tokens at each node")
                print(f"    - Alternatives are leaf nodes (not expanded further)")
            else:
                print(f"\n    Verification strategy (Single-Branch):")
                print(f"    - Linear chain: each node has at most one child")
                print(f"    - No alternative branches (num_alternatives=1)")

        self._build_tree_with_alternatives(root, sorted_candidates, x, mask_id, draft_dict, verbose)

        if verbose:
            print(f"\n    Tree Statistics:")
            total_nodes = self._count_nodes(root)
            print(f"      Total nodes: {total_nodes}")
            print(f"      Main chain length: {self._get_main_chain_depth(root)}")
            if self.num_alternatives > 1:
                alt_count = self._count_alternative_nodes(root)
                print(f"      Alternative branches: {alt_count}")
                print(f"      Structure: Main chain with alternative branches")
            else:
                print(f"      Structure: Single linear path")
            print(f"\n    Tree Structure Visualization:")
            self._print_tree(root, indent="      ", verbose=verbose)

        return root

    @staticmethod
    def _extend_blocks(
        blocks: Dict[int, List[int]], block_id: int, position: int
    ) -> Dict[int, List[int]]:
        """Return a copy of ``blocks`` with ``position`` appended under ``block_id``.

        The per-block lists are copied as well. A plain ``dict.copy()`` would
        share them with the parent, so an alternative branch's append would
        leak into the parent and into the main-chain sibling.
        """
        extended = {bid: list(positions) for bid, positions in blocks.items()}
        extended.setdefault(block_id, []).append(position)
        return extended

    def _build_tree_with_alternatives(
        self,
        parent: VerificationNode,
        sorted_candidates: List[DraftCandidate],
        x: torch.Tensor,
        mask_id: int,
        draft_dict: Dict[int, int],
        verbose: bool
    ):
        """Attach the main chain below ``parent``, plus alternative leaves at each step.

        ``mask_id`` is not used by this method.
        """
        current_node = parent
        accumulated_confirmed = parent.confirmed_positions.copy()
        remaining_masks = parent.mask_positions.copy()

        for i, candidate in enumerate(sorted_candidates):
            # Options for this position: the main candidate first, then any
            # alternatives it carries (searched only among remaining candidates).
            position_candidates = self._get_position_alternatives(
                candidate.position,
                sorted_candidates[i:],
                self.num_alternatives
            )

            accumulated_confirmed.append(candidate.position)
            if candidate.position in remaining_masks:
                remaining_masks.remove(candidate.position)

            # Main-chain child (greedy choice).
            main_child = VerificationNode(
                confirmed_positions=accumulated_confirmed.copy(),
                verify_position=candidate.position,
                expected_token=candidate.token,
                parent=current_node,
                mask_positions=remaining_masks.copy(),
                draft_tokens_dict=draft_dict.copy()
            )
            main_child.confirmed_blocks = self._extend_blocks(
                current_node.confirmed_blocks, candidate.block_id, candidate.position
            )

            child_seq = x.clone()
            for pos in accumulated_confirmed:
                if pos in draft_dict:
                    child_seq[:, pos] = draft_dict[pos]
            main_child.sequence_tensor = child_seq

            current_node.children.append(main_child)

            if self.num_alternatives > 1 and len(position_candidates) > 1:
                # Skip index 0 (the main candidate); keep at most
                # num_alternatives - 1 alternatives.
                for alt_idx, alt_candidate in enumerate(
                    position_candidates[1:self.num_alternatives], start=1
                ):
                    # Alternative leaf: the parent's confirmations plus this token.
                    alt_confirmed = current_node.confirmed_positions.copy()
                    alt_confirmed.append(alt_candidate.position)

                    alt_child = VerificationNode(
                        confirmed_positions=alt_confirmed,
                        verify_position=alt_candidate.position,
                        expected_token=alt_candidate.token,
                        parent=current_node,
                        mask_positions=remaining_masks.copy(),
                        draft_tokens_dict=draft_dict.copy()
                    )
                    alt_child.confirmed_blocks = self._extend_blocks(
                        current_node.confirmed_blocks, alt_candidate.block_id, alt_candidate.position
                    )

                    # Parent's draft tokens, then the alternative token on top.
                    alt_seq = x.clone()
                    for pos in current_node.confirmed_positions:
                        if pos in draft_dict:
                            alt_seq[:, pos] = draft_dict[pos]
                    alt_seq[:, alt_candidate.position] = alt_candidate.token
                    alt_child.sequence_tensor = alt_seq

                    # A negative score marks the node as an alternative (used by
                    # _count_alternative_nodes and _print_tree).
                    alt_child.priority_score = -alt_idx

                    current_node.children.append(alt_child)

                    if verbose:
                        print(f"      Added alternative branch at pos {alt_candidate.position}: "
                              f"token={alt_candidate.token}, conf={alt_candidate.confidence:.4f}")

            # Only the main child is expanded further.
            current_node = main_child

    def _get_position_alternatives(
        self,
        position: int,
        remaining_candidates: List[DraftCandidate],
        max_alternatives: int
    ) -> List[DraftCandidate]:
        """
        Get the candidate options for a specific position.

        If the first candidate at ``position`` carries ``alternative_tokens``,
        returns that candidate followed by up to ``max_alternatives - 1``
        synthesized candidates built from its alternative tokens and
        confidences. ``zip`` pairs the two lists, so an unmatched tail is
        dropped. Otherwise returns at most ``max_alternatives`` candidates from
        ``remaining_candidates`` at that position (empty if none).
        """
        position_candidates = [c for c in remaining_candidates if c.position == position]

        if position_candidates and position_candidates[0].alternative_tokens:
            main_candidate = position_candidates[0]
            alternatives = [main_candidate]  # Main candidate first

            for alt_token, alt_conf in zip(
                main_candidate.alternative_tokens[:max_alternatives-1],
                main_candidate.alternative_confidences[:max_alternatives-1]
            ):
                alternatives.append(DraftCandidate(
                    position=main_candidate.position,
                    block_id=main_candidate.block_id,
                    token=alt_token,
                    confidence=alt_conf,
                    alternative_tokens=[],
                    alternative_confidences=[]
                ))

            return alternatives

        return position_candidates[:max_alternatives] if position_candidates else []

    def _count_nodes(self, node: VerificationNode) -> int:
        """Count total nodes in the subtree rooted at ``node``."""
        return 1 + sum(self._count_nodes(child) for child in node.children)

    def _get_main_chain_depth(self, node: VerificationNode, depth: int = 0) -> int:
        """Get the depth of the main chain (always following the first child)."""
        if not node.children:
            return depth
        return self._get_main_chain_depth(node.children[0], depth + 1)

    def _count_alternative_nodes(self, node: VerificationNode) -> int:
        """Count alternative-branch nodes (those with a negative priority_score)."""
        own = 1 if node.priority_score < 0 else 0
        return own + sum(self._count_alternative_nodes(child) for child in node.children)

    def _print_tree(self, node: VerificationNode, indent: str = "", is_last: bool = True,
                    is_root: bool = True, verbose: bool = False, depth: int = 0):
        """Print the tree structure in a visual format (main chain before alternatives)."""
        if is_root:
            # The root's mask_positions holds *all* remaining masks. The draft
            # positions actually being verified are the keys of draft_tokens_dict.
            draft_positions = sorted(node.draft_tokens_dict)
            num_draft = len(draft_positions)
            print(f"{indent}[ROOT] State: {len(node.confirmed_positions)} confirmed (from previous iterations)")
            print(f"{indent}       To verify: {num_draft} draft positions")
            if verbose and 0 < num_draft <= 5:
                print(f"{indent}       Draft positions: {draft_positions}")
        else:
            is_alternative = node.priority_score < 0

            branch = "└── " if is_last else "├── "
            if node.verify_position is not None:
                # Cumulative draft tokens accepted relative to the root.
                root_confirmed = len(self._get_root(node).confirmed_positions)
                draft_accepted = len(node.confirmed_positions) - root_confirmed

                node_type = "[ALT]" if is_alternative else f"[Level {depth}: +{draft_accepted} draft]"
                print(f"{indent}{branch}{node_type} "
                      f"Accept pos {node.verify_position} = {node.expected_token}")

                if verbose:
                    remaining = len(node.mask_positions)
                    if remaining > 0 and not is_alternative:  # Don't show for alternatives
                        print(f"{indent}{'    ' if is_last else '│   '}"
                              f"  → {remaining} positions still to verify")

        # Main-chain children first, then alternatives.
        main_children = [c for c in node.children if c.priority_score >= 0]
        alt_children = [c for c in node.children if c.priority_score < 0]
        # Compare by id() so the check never touches node equality.
        main_children_ids = {id(c) for c in main_children}

        all_children = main_children + alt_children
        new_indent = indent if is_root else indent + ("    " if is_last else "│   ")
        for i, child in enumerate(all_children):
            is_last_child = (i == len(all_children) - 1)
            # Only main-chain nodes increase the displayed level.
            new_depth = depth + 1 if id(child) in main_children_ids else depth
            self._print_tree(child, new_indent, is_last_child, False, verbose, new_depth)

    def _get_root(self, node: VerificationNode) -> VerificationNode:
        """Follow ``parent`` links up to the root node."""
        while node.parent is not None:
            node = node.parent
        return node


def add_gumbel_noise(logits, temperature):
    """Add Gumbel noise for sampling.

    Returns ``exp(logits) / (-log(u)) ** temperature`` with ``u ~ U[0, 1)``
    drawn per element. Taking ``argmax`` of the result therefore samples from
    ``softmax(logits / temperature)``, because
    ``log`` of the result equals ``temperature * (logits / temperature + g)``
    with ``g`` standard Gumbel noise. With ``temperature == 0`` no noise is
    added and the result is just ``exp(logits)``, so ``argmax`` picks the
    greedy token.

    The computation is done in float64. In float16, ``exp`` overflows to
    ``inf`` for inputs above ~11.09, which would turn greedy ``argmax`` into a
    tie between all overflowed entries. Full precision also keeps the Gumbel
    noise accurate.

    Args:
        logits: Logit tensor (any floating dtype).
        temperature: Sampling temperature; 0 means greedy.

    Returns:
        A float64 tensor with the same shape as ``logits``.
    """
    logits = logits.to(torch.float64)
    if temperature == 0:
        return logits.exp()
    noise = torch.rand_like(logits)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def _generate_scsd_v3(
    input_ids,
    attention_mask,
    model,
    gen_length=128,
    block_length=None,  # Block length for block ID calculation
    temperature=0.0,
    cfg_scale=0.0,
    mask_id=126336,
    draft_length=4,  # Fixed draft length k
    tree_strategy='greedy',  # Tree building strategy
    num_alternatives=1,  # Number of alternatives per position (for greedy_with_alternatives)
    verbose=False,
    **kwargs  # For compatibility
):
    """
    SCSD v3 implementation with cross-block draft generation.

    Key algorithm:
    1. Generate initial draft for ALL positions (one-time)
    2. Confirm the single best draft token (earliest block, then highest
       confidence)
    3. Iteratively select k positions for verification
    4. Build verification tree and verify
    5. Update draft based on verification results
    6. Repeat until all positions are confirmed; once fewer than
       ``draft_length`` masks remain, finish with standard decoding

    Args:
        input_ids: Prompt token ids, shape (batch_size, prompt_length).
        attention_mask: Passed unchanged to every model call.
        model: Model whose call returns an object with ``.logits`` and
            which exposes ``.device``.
        gen_length: Number of positions to generate after the prompt.
        block_length: Block size used to prioritise earlier blocks;
            ``None`` or non-positive puts every position in block 0.
        temperature: Gumbel-max sampling temperature (0 = greedy).
        cfg_scale: Classifier-free guidance scale (0 disables CFG).
        mask_id: Token id marking not-yet-generated positions.
        draft_length: Number of draft positions verified per iteration (k).
        tree_strategy: ``'greedy'`` or ``'greedy_with_alternatives'``.
        num_alternatives: Candidates per position for the alternatives
            strategy.
        verbose: Print a detailed trace of every iteration.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        (x, total_forward_passes): ``x`` holds prompt + generated tokens,
        shape (batch_size, prompt_length + gen_length).

    Raises:
        ValueError: If ``tree_strategy`` is not recognised.

    NOTE(review): drafts are read from batch row 0 only, and confirmed tokens
    are written to every row (``x[:, pos]``). Results are therefore only
    meaningful for batch_size == 1.
    """
    with torch.no_grad():
        batch_size, prompt_length = input_ids.shape
        x = torch.full(
            (batch_size, prompt_length + gen_length),
            mask_id,
            dtype=torch.long,
            device=model.device,
        )
        x[:, :prompt_length] = input_ids
        prompt_index = x != mask_id

        # Initialize strategy
        if tree_strategy == 'greedy':
            strategy = GreedyStrategy()
        elif tree_strategy == 'greedy_with_alternatives':
            strategy = GreedyWithAlternativesStrategy(num_alternatives=num_alternatives)
        else:
            raise ValueError(f"Unknown tree strategy: {tree_strategy}")

        total_forward_passes = 0
        iteration = 0
        confirmed_positions = []  # Global confirmed positions (absolute indices)

        # Generate initial draft for ALL positions
        if verbose:
            print(f"Generating initial draft for all {gen_length} positions...")

        # Alternatives are only computed when the strategy can use them
        need_alternatives = (tree_strategy == 'greedy_with_alternatives' and num_alternatives > 1)

        draft_result = generate_initial_full_draft(
            x=x,
            model=model,
            attention_mask=attention_mask,
            prompt_length=prompt_length,
            prompt_index=prompt_index,
            temperature=temperature,
            cfg_scale=cfg_scale,
            mask_id=mask_id,
            verbose=verbose,
            return_alternatives=need_alternatives,
            num_alternatives=num_alternatives
        )

        # generate_initial_full_draft returns a 4-tuple only when alternatives were requested
        if need_alternatives:
            all_draft_tokens, all_confidences, initial_forwards, alternatives_dict = draft_result
        else:
            all_draft_tokens, all_confidences, initial_forwards = draft_result
            alternatives_dict = None

        total_forward_passes += initial_forwards

        if verbose:
            print(f"Initial draft generated with 1 forward pass")

        # Accept the first token with highest confidence (respecting block priority)
        mask_positions = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
        mask_positions = [p for p in mask_positions if p >= prompt_length]

        if mask_positions:
            # Reuse the per-iteration selector with draft_length=1. It ranks by
            # (block_id, -confidence), exactly as iterations do. Alternatives
            # are irrelevant for this single unverified acceptance, so skip them.
            first_token = select_draft_candidates(
                mask_positions=mask_positions,
                all_draft_tokens=all_draft_tokens,
                all_confidences=all_confidences,
                draft_length=1,
                prompt_length=prompt_length,
                block_length=block_length,
                verbose=False,
                alternatives_dict=None,
                num_alternatives=num_alternatives,
            )[0]
            x[:, first_token.position] = first_token.token
            confirmed_positions.append(first_token.position)

            if verbose:
                print(f"\n✨ Initial token confirmed:")
                print(f"    Pos {first_token.position} (block {first_token.block_id}): "
                      f"token={first_token.token}, conf={first_token.confidence:.4f}")

        # Main iteration loop
        while True:
            # Count remaining masks
            mask_positions = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
            mask_positions = [p for p in mask_positions if p >= prompt_length]

            if not mask_positions:
                if verbose:
                    print(f"\n✅ Generation complete after {iteration} iterations!")
                    print(f"   Total forward passes: {total_forward_passes}")
                break

            iteration += 1
            if verbose:
                print(f"\n{'='*60}")
                print(f"ITERATION {iteration}: {len(mask_positions)} masks remaining")
                print(f"{'='*60}")

            # When remaining masks < draft_length, use standard decode. This is
            # checked before candidate selection because that selection would be
            # discarded on this path anyway.
            if len(mask_positions) < draft_length:
                if verbose:
                    print(f"\n  Switching to standard decode ({len(mask_positions)} masks < draft_length {draft_length})")

                std_forwards = _decode_remaining_standard_v3(
                    x=x,
                    model=model,
                    attention_mask=attention_mask,
                    mask_positions=mask_positions.copy(),  # Pass a copy since we'll modify it
                    prompt_length=prompt_length,
                    prompt_index=prompt_index,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    mask_id=mask_id,
                    all_draft_tokens=all_draft_tokens,
                    all_confidences=all_confidences,
                    block_length=block_length,
                    confirmed_positions=confirmed_positions,
                    verbose=verbose
                )
                total_forward_passes += std_forwards

                # Standard decode completes all remaining masks
                break

            # Select up to draft_length positions for this iteration
            candidates = select_draft_candidates(
                mask_positions=mask_positions,
                all_draft_tokens=all_draft_tokens,
                all_confidences=all_confidences,
                draft_length=draft_length,
                prompt_length=prompt_length,
                block_length=block_length,
                verbose=verbose,
                alternatives_dict=alternatives_dict,
                num_alternatives=num_alternatives
            )

            # Only empty for a non-positive draft_length
            if not candidates:
                break

            # Build verification tree and verify
            new_confirmed, confirmed_tokens, new_draft, verify_forwards = cascaded_verify_v3(
                x=x,
                model=model,
                attention_mask=attention_mask,
                candidates=candidates,
                confirmed_positions=confirmed_positions,
                all_draft_tokens=all_draft_tokens,
                all_confidences=all_confidences,
                strategy=strategy,
                prompt_length=prompt_length,
                prompt_index=prompt_index,
                temperature=temperature,
                cfg_scale=cfg_scale,
                mask_id=mask_id,
                iteration=iteration,
                block_length=block_length,
                all_mask_positions=mask_positions,  # Pass all remaining masks
                verbose=verbose,
                need_alternatives=need_alternatives,
                num_alternatives=num_alternatives
            )
            total_forward_passes += verify_forwards

            # Accept newly confirmed tokens using the verified values, NOT the draft values
            for pos in new_confirmed:
                if pos not in confirmed_positions:
                    token_value = confirmed_tokens[pos]
                    x[:, pos] = token_value
                    confirmed_positions.append(pos)
                    if verbose:
                        rel_pos = pos - prompt_length
                        block_id = rel_pos // block_length if block_length and block_length > 0 else 0
                        print(f"    ✓ Confirmed: Pos {pos} (block {block_id}) = token {token_value}")

            # Update draft if new predictions were made (after confirming tokens)
            if new_draft is not None:
                all_draft_tokens = new_draft['tokens']
                all_confidences = new_draft['confidences']
                # A new draft without alternatives invalidates the old ones
                alternatives_dict = new_draft.get('alternatives')

            if verbose:
                print(f"\n  📊 Iteration {iteration} Summary:")
                print(f"     Positions confirmed this iteration: {len(new_confirmed)}")
                print(f"     Total positions confirmed: {len(confirmed_positions)}")
                print(f"     Forward passes this iteration: {verify_forwards}")
                print(f"     Total forward passes so far: {total_forward_passes}")

        if verbose:
            print(f"\nTotal forward passes: {total_forward_passes}")

            # Print final generated tokens for comparison
            print(f"\n{'='*60}")
            print(f"FINAL GENERATED TOKENS")
            print(f"{'='*60}")
            generated = x[:, prompt_length:]
            for i in range(generated.shape[1]):
                token_id = generated[0, i].item()
                print(f"Position {i}: token_id={token_id}")
            print(f"{'='*60}\n")

        return x, total_forward_passes


def generate_initial_full_draft(
    x, model, attention_mask, prompt_length, prompt_index,
    temperature, cfg_scale, mask_id, verbose=False, return_alternatives=False, num_alternatives=1
):
    """
    Draft a token for every generation position using one model evaluation.

    Args:
        x: Full sequence (prompt followed by masked positions).
        model: Callable returning an object with ``.logits``.
        attention_mask: Passed through to the model.
        prompt_length: Number of prompt tokens at the start of ``x``.
        prompt_index: Boolean mask of prompt positions; they are replaced by
            ``mask_id`` for the unconditional pass when ``cfg_scale > 0``.
        temperature: Gumbel-max temperature (0 = greedy).
        cfg_scale: Classifier-free guidance scale (0 disables it).
        mask_id: Mask token id.
        verbose: Unused.
        return_alternatives: If True (and ``num_alternatives > 1``), also
            return the top-k candidates per position.
        num_alternatives: k for the top-k alternatives.

    Returns:
        ``(draft_tokens, confidences, forward_passes)``. When alternatives are
        requested with ``num_alternatives > 1``, the tuple has a fourth item:
        a dict ``{'tokens': top_k_ids, 'probs': top_k_probs}``.

        * ``draft_tokens`` and ``confidences`` are indexed
          ``[batch, position - prompt_length]``.
        * ``confidences`` is the softmax probability of the drafted token.
        * ``forward_passes`` is reported as 1 even though CFG makes two
          model calls.
    """
    conditional = model(x, attention_mask=attention_mask).logits
    if cfg_scale > 0.0:
        # Unconditional pass: hide the prompt behind mask tokens
        unconditional_input = x.clone()
        unconditional_input[prompt_index] = mask_id
        unconditional = model(unconditional_input, attention_mask=attention_mask).logits
        residual = conditional - unconditional
        logits = (conditional - residual) + (cfg_scale + 1) * residual
    else:
        logits = conditional

    # Only the generation region is drafted
    gen_logits = logits[:, prompt_length:]
    draft_tokens = add_gumbel_noise(gen_logits, temperature=temperature).argmax(dim=-1)

    # Confidence is measured on the clean (noise-free) distribution
    probs = F.softmax(gen_logits, dim=-1)
    confidence = probs.gather(-1, draft_tokens.unsqueeze(-1)).squeeze(-1)

    if not (return_alternatives and num_alternatives > 1):
        return draft_tokens, confidence, 1

    k = min(num_alternatives, probs.shape[-1])
    top_probs, top_tokens = probs.topk(k, dim=-1)
    return draft_tokens, confidence, 1, {'tokens': top_tokens, 'probs': top_probs}


def select_draft_candidates(
    mask_positions, all_draft_tokens, all_confidences,
    draft_length, prompt_length, block_length=None, verbose=False,
    alternatives_dict=None, num_alternatives=1
):
    """
    Select up to draft_length candidates from remaining mask positions.

    Respects block priority: candidates are ranked by ``(block_id,
    -confidence)``, so every position of an earlier block comes before any
    position of a later one. Ties keep ``mask_positions`` order.

    Args:
        mask_positions: Absolute indices of still-masked positions.
        all_draft_tokens: Draft tokens indexed ``[batch, position - prompt_length]``;
            only batch row 0 is read.
        all_confidences: Confidences with the same indexing as ``all_draft_tokens``.
        draft_length: Maximum number of candidates to return.
        prompt_length: Offset converting absolute to relative positions.
        block_length: Block size; ``None`` or non-positive puts everything in block 0.
        verbose: Print selection diagnostics.
        alternatives_dict: Optional dict with 'tokens' and 'probs' for alternatives
            (indexed ``[batch, rel_pos, k]``).
        num_alternatives: Total candidates per position including the draft
            token; up to ``num_alternatives - 1`` alternatives are kept.

    Returns:
        List of DraftCandidate objects (up to draft_length)
    """
    if not mask_positions:
        return []

    # Move batch row 0 to Python once. Calling .item() per position forces a
    # device sync for every access.
    draft_row = all_draft_tokens[0].tolist()
    confidence_row = all_confidences[0].tolist()
    if alternatives_dict is not None:
        alt_token_rows = alternatives_dict['tokens'][0].tolist()
        alt_prob_rows = alternatives_dict['probs'][0].tolist()

    use_blocks = bool(block_length) and block_length > 0

    candidates = []
    for pos in mask_positions:
        rel_pos = pos - prompt_length
        block_id = rel_pos // block_length if use_blocks else 0
        draft_token = draft_row[rel_pos]

        candidate = DraftCandidate(
            position=pos,
            block_id=block_id,
            token=draft_token,
            confidence=confidence_row[rel_pos]
        )

        if alternatives_dict is not None:
            # Drop the draft token itself so alternatives never duplicate it
            alternative_tokens = []
            alternative_confidences = []
            for alt_tok, alt_prob in zip(alt_token_rows[rel_pos], alt_prob_rows[rel_pos]):
                if alt_tok != draft_token and len(alternative_tokens) < num_alternatives - 1:
                    alternative_tokens.append(alt_tok)
                    alternative_confidences.append(alt_prob)

            if alternative_tokens:
                candidate.alternative_tokens = alternative_tokens
                candidate.alternative_confidences = alternative_confidences

        candidates.append(candidate)

    if verbose and candidates:
        print(f"\n  === Draft Candidate Selection ===")
        print(f"  Total mask positions available: {len(candidates)}")

        # Show distribution by blocks
        block_counts = {}
        for c in candidates:
            block_counts[c.block_id] = block_counts.get(c.block_id, 0) + 1
        print(f"  Block distribution: {dict(sorted(block_counts.items()))}")

        # Show top confidence tokens before sorting
        top_5_by_conf = sorted(candidates, key=lambda c: -c.confidence)[:5]
        print(f"\n  Top 5 by confidence (before block priority):")
        for i, c in enumerate(top_5_by_conf, 1):
            print(f"    {i}. Pos {c.position} (block {c.block_id}): "
                  f"token={c.token}, conf={c.confidence:.4f}")

    # Sort by block priority then confidence (stable sort keeps input order on ties)
    candidates.sort(key=lambda c: (c.block_id, -c.confidence))

    selected = candidates[:draft_length]

    if verbose and selected:
        print(f"\n  Final selection ({len(selected)} candidates after block priority):")
        for i, c in enumerate(selected):
            print(f"    [{i}] Pos {c.position} (block {c.block_id}): "
                  f"token={c.token}, conf={c.confidence:.4f}")

        # Analysis: show if block priority changed selection
        if len(candidates) > draft_length:
            pure_conf_selection = sorted(candidates, key=lambda c: -c.confidence)[:draft_length]
            pure_conf_positions = set(c.position for c in pure_conf_selection)
            actual_positions = set(c.position for c in selected)
            diff = pure_conf_positions - actual_positions
            if diff:
                print(f"\n  Note: Block priority changed selection from pure confidence")
                print(f"    Positions selected due to block priority: {actual_positions - pure_conf_positions}")
                print(f"    Positions skipped despite higher confidence: {diff}")

    return selected


def cascaded_verify_v3(
    x, model, attention_mask, candidates, confirmed_positions,
    all_draft_tokens, all_confidences, strategy,
    prompt_length, prompt_index, temperature, cfg_scale, mask_id,
    iteration, block_length=None, all_mask_positions=None, verbose=False,
    need_alternatives=False, num_alternatives=1
):
    """
    Perform cascaded verification for selected candidates.

    Builds a verification tree with ``strategy``, verifies all of its nodes
    (via ``traverse_and_verify_v3``) and turns the node where traversal stopped
    into a fresh draft for the next iteration.

    ``all_confidences`` and ``iteration`` are accepted but not used here.

    Returns:
        newly_confirmed: Positions confirmed by this call that were not
            already in ``confirmed_positions``.
        confirmed_tokens: Dict position -> token value. The caller looks up
            ``confirmed_tokens[pos]`` for every position in ``newly_confirmed``,
            so the fallback path adds its entry here too.
        new_draft: Dict with 'tokens' and 'confidences' (plus 'alternatives'
            when requested), built from the stopping node's logits; ``None``
            when traversal reported no stopping node.
        forward_passes: Number of forward passes used.
    """
    # Build verification tree
    tree = strategy.build_tree(
        candidates=candidates,
        confirmed_positions=confirmed_positions,
        x=x,
        prompt_length=prompt_length,
        mask_id=mask_id,
        verbose=verbose,
        all_mask_positions=all_mask_positions
    )

    # Traverse and verify
    longest_confirmed, confirmed_tokens, failed_info, forward_passes = traverse_and_verify_v3(
        tree=tree,
        model=model,
        attention_mask=attention_mask,
        prompt_length=prompt_length,
        prompt_index=prompt_index,
        temperature=temperature,
        cfg_scale=cfg_scale,
        mask_id=mask_id,
        candidates=candidates,
        all_draft_tokens=all_draft_tokens,
        block_length=block_length,
        verbose=verbose
    )
    if confirmed_tokens is None:
        confirmed_tokens = {}

    # Check if we made progress (set lookup keeps this linear)
    already_confirmed = set(confirmed_positions)
    newly_confirmed = [pos for pos in longest_confirmed if pos not in already_confirmed]

    # Handle verification result - use model's predictions as new draft
    new_draft = None
    if failed_info:
        # Either verification failed or leaf node succeeded
        if verbose:
            if failed_info.get('is_leaf_success'):
                print(f"\n    All drafts verified successfully!")
                print(f"    Leaf node predicted additional token at pos {failed_info['failed_pos']}")
            else:
                print(f"\n    Verification stopped at pos {failed_info['failed_pos']}")
                print(f"    Using model's prediction (token={failed_info['predicted_token']}) "
                      f"instead of draft (token={failed_info['expected_token']})")

        # Extract new draft from the failed node's predictions (per-node logits, no batch dim)
        new_logits = failed_info['new_logits']
        gen_logits = new_logits[prompt_length:]

        # Generate new draft tokens (re-add the batch dim expected downstream)
        logits_with_noise = add_gumbel_noise(gen_logits.unsqueeze(0), temperature=temperature)
        new_tokens = torch.argmax(logits_with_noise, dim=-1)

        probs = F.softmax(gen_logits.unsqueeze(0), dim=-1)
        new_confidence = torch.gather(probs, dim=-1, index=new_tokens.unsqueeze(-1)).squeeze(-1)

        new_draft = {
            'tokens': new_tokens,
            'confidences': new_confidence
        }
        if need_alternatives and num_alternatives > 1:
            top_k_probs, top_k_tokens = torch.topk(probs, k=min(num_alternatives, probs.shape[-1]), dim=-1)
            new_draft['alternatives'] = {'tokens': top_k_tokens, 'probs': top_k_probs}
    elif not newly_confirmed and candidates:
        # No verification succeeded at all (shouldn't happen with correct tree)
        if verbose:
            print(f"    Warning: No verification succeeded, using fallback")

        # Fallback: accept the highest confidence position with its draft token.
        # The token must also go into confirmed_tokens, otherwise the caller's
        # confirmed_tokens[pos] lookup raises KeyError.
        best_candidate = max(candidates, key=lambda c: c.confidence)
        newly_confirmed = [best_candidate.position]
        confirmed_tokens.setdefault(best_candidate.position, best_candidate.token)

        if verbose:
            print(f"    Fallback: confirmed highest confidence position {best_candidate.position}")

    return newly_confirmed, confirmed_tokens, new_draft, forward_passes


def traverse_and_verify_v3(
    tree, model, attention_mask, prompt_length, prompt_index,
    temperature, cfg_scale, mask_id, candidates, all_draft_tokens, 
    block_length=None, verbose=False
):
    """
    Verify every node of ``tree`` in one batch, then walk the tree.

    All nodes (root included) are gathered first, so a single batched model
    evaluation covers the whole tree. The tree is then traversed using those
    precomputed results. ``candidates`` is accepted but not used here.

    Returns:
        longest_confirmed: Positions on the longest valid path.
        confirmed_tokens: Mapping position -> verified token value.
        failed_info: Details of the node where traversal stopped (or None).
        forward_passes: Forward passes reported by the batch verifier.
    """
    nodes = []
    _collect_all_nodes(tree, nodes)

    if verbose:
        print(f"\n    Collected {len(nodes)} nodes for batch verification")

    results_by_node_id, passes = verify_all_nodes_batch(
        nodes=nodes,
        model=model,
        attention_mask=attention_mask,
        prompt_length=prompt_length,
        prompt_index=prompt_index,
        temperature=temperature,
        cfg_scale=cfg_scale,
        mask_id=mask_id,
        all_draft_tokens=all_draft_tokens,
        block_length=block_length,
        verbose=verbose,
    )

    # Results are keyed by id(node), which is what the path finder expects
    path_positions, verified_tokens, failure = find_longest_valid_path(
        tree=tree,
        node_results=results_by_node_id,
        all_draft_tokens=all_draft_tokens,
        prompt_length=prompt_length,
        block_length=block_length,
        verbose=verbose,
    )
    return path_positions, verified_tokens, failure, passes


def _collect_all_nodes(node, all_nodes):
    """Recursively collect all nodes from the tree."""
    # Include current node (including root)
    all_nodes.append(node)
    # Then collect all children recursively
    for child in node.children:
        _collect_all_nodes(child, all_nodes)


def verify_all_nodes_batch(
    nodes, model, attention_mask, prompt_length, prompt_index,
    temperature, cfg_scale, mask_id, all_draft_tokens, block_length, verbose
):
    """
    Perform forward pass for all nodes in a single batch.

    Every node's ``sequence_tensor`` is concatenated along dim 0 and evaluated
    together. For each node the result stores:
    - the Gumbel-max argmax predictions (see ``add_gumbel_noise``),
    - the raw (optionally CFG-combined) logits,
    - their softmax probabilities,
    - a few node attributes for the later traversal.

    Args:
        nodes: Verification nodes to evaluate (e.g. from ``_collect_all_nodes``).
        model: Callable returning an object with ``.logits``.
        attention_mask: Repeated once per node (``repeat(len(nodes), 1)``);
            presumably has a batch dim of 1 - TODO confirm at call sites.
        prompt_length: Number of prompt tokens (used for the verbose dump).
        prompt_index: Prompt-position mask, repeated per node for the CFG
            unconditional pass.
        temperature: Gumbel-max temperature for the predictions.
        cfg_scale: Classifier-free guidance scale (0 disables it).
        mask_id: Mask token id.
        all_draft_tokens: Not used in this function.
        block_length: Only used by the verbose per-block dump.
        verbose: Print the input sequence of every node.

    Returns:
        node_results: Dict mapping node ID (``id(node)``) to predictions and logits
        forward_passes: Number of forward passes (always 1, even though CFG
            makes two model calls)
    """
    if not nodes:
        return {}, 0
    
    if verbose:
        print(f"\n    === Batch Verification Results (All Nodes) ===")
        print(f"    Processing {len(nodes)} nodes in single forward pass")
        print(f"    Note: These are pre-computed results for all nodes. Actual traversal may stop early.")
    
    # Prepare batch - all nodes' sequences
    # (assumes each sequence_tensor has a batch dim of 1 so rows line up with
    # the repeated attention mask - TODO confirm)
    batch_x = torch.cat([node.sequence_tensor for node in nodes], dim=0)
    batch_attention_mask = attention_mask.repeat(len(nodes), 1) if attention_mask is not None else None
    
    if verbose and block_length and block_length > 0:
        print(f"\n    Input sequences for each node (showing relevant blocks):")
        for idx, node in enumerate(nodes):
            if node.verify_position is None:  # Root node
                print(f"      Node {idx} (root): Initial state")
                # Show first block for root
                block_start = prompt_length
                block_end = min(block_start + block_length, batch_x.shape[1])
                filled_tokens = []
                for pos in range(block_start, block_end):
                    if pos >= batch_x.shape[1]:
                        break
                    token_val = batch_x[idx, pos].item()
                    if token_val != mask_id:
                        if pos in node.confirmed_positions:
                            filled_tokens.append(f"{pos}={token_val}✓")
                        else:
                            filled_tokens.append(f"{pos}={token_val}")
                    else:
                        filled_tokens.append(f"{pos}=MASK")
                print(f"        Block 0 [{block_start}-{block_end}): {', '.join(filled_tokens)}")
            # Truthiness test: a verify_position of 0 would skip this branch
            elif node.verify_position:
                # Calculate block_id for verify_position
                verify_block_id = (node.verify_position - prompt_length) // block_length
                
                # Show current block and potentially previous block if cross-block
                blocks_to_show = [verify_block_id]
                if verify_block_id > 0 and any(pos < prompt_length + verify_block_id * block_length 
                                               for pos in node.confirmed_positions 
                                               if pos >= prompt_length):
                    blocks_to_show.insert(0, verify_block_id - 1)
                
                print(f"      Node {idx} (verify pos {node.verify_position}):")
                
                for block_id in blocks_to_show:
                    block_start = prompt_length + block_id * block_length
                    block_end = min(block_start + block_length, batch_x.shape[1])
                    
                    # Collect token info for this block
                    filled_tokens = []
                    for pos in range(block_start, block_end):
                        if pos >= batch_x.shape[1]:
                            break
                        token_val = batch_x[idx, pos].item()
                        if token_val != mask_id:
                            # Check if this position is in confirmed list
                            if pos in node.confirmed_positions:
                                filled_tokens.append(f"{pos}={token_val}✓")
                            else:
                                filled_tokens.append(f"{pos}={token_val}")
                        else:
                            filled_tokens.append(f"{pos}=MASK")
                    
                    print(f"        Block {block_id} [{block_start}-{block_end}): {', '.join(filled_tokens)}")
            # NOTE(review): only reachable when verify_position is falsy but not
            # None (i.e. 0); leaf nodes with a verify_position print via the
            # branch above instead.
            elif not node.mask_positions:  # Leaf node
                print(f"      Node {idx} (leaf): All draft tokens filled, ready for next prediction")
            # NOTE(review): unreachable - verify_position None is handled by the first branch.
            elif node.verify_position is None:  # Root node
                print(f"      Node {idx} (root): Initial state sequence")
    
    # Single forward pass for all nodes
    if cfg_scale > 0.0:
        # Unconditional copy hides the prompt behind mask tokens. The combination
        # below equals cfg_logits + (cfg_scale + 1) * (logits - cfg_logits).
        batch_prompt_index = prompt_index.repeat(len(nodes), 1)
        cfg_batch_x = batch_x.clone()
        cfg_batch_x[batch_prompt_index] = mask_id
        batch_logits = model(batch_x, attention_mask=batch_attention_mask).logits
        cfg_batch_logits = model(cfg_batch_x, attention_mask=batch_attention_mask).logits
        cfg_residual = batch_logits - cfg_batch_logits
        batch_logits = (batch_logits - cfg_residual) + (cfg_scale + 1) * cfg_residual
    else:
        batch_logits = model(batch_x, attention_mask=batch_attention_mask).logits
    
    # Get predictions and confidence
    # (predictions use Gumbel noise; probs are from the noise-free logits)
    batch_logits_with_noise = add_gumbel_noise(batch_logits, temperature=temperature)
    batch_predictions = torch.argmax(batch_logits_with_noise, dim=-1)
    batch_probs = F.softmax(batch_logits, dim=-1)
    
    # Process results for each node - simplified to just store predictions
    node_results = {}
    for i, node in enumerate(nodes):
        # Store predictions and probabilities for each node, keyed by object identity
        node_results[id(node)] = {
            'predictions': batch_predictions[i],
            'logits': batch_logits[i],
            'probs': batch_probs[i],
            'verify_position': node.verify_position,
            'expected_token': node.expected_token if hasattr(node, 'expected_token') else None,
            'mask_positions': node.mask_positions,
            'is_leaf': len(node.mask_positions) == 0,
            'is_root': node.verify_position is None
        }
    
    return node_results, 1  # Always 1 forward pass


def _select_priority_mask_position(predictions, probs, mask_positions, prompt_length, block_length):
    """
    Pick the masked position a node most wants to fill next.

    Positions in the lowest block win. Within the same block the highest
    confidence wins, and on exact ties the earliest entry of
    ``mask_positions`` is kept.

    Args:
        predictions: Predicted token ids indexed by absolute position
            (``predictions[pos]``).
        probs: Probabilities indexed as ``probs[pos, token]``.
        mask_positions: Absolute positions that are still masked.
        prompt_length: Prompt length; block ids are computed relative to it.
        block_length: Block size. Falsy or non-positive values put every
            position in block 0.

    Returns:
        Tuple ``(best_pos, best_token, best_conf, best_block_id)``. When
        ``mask_positions`` is empty this is ``(None, None, -1, inf)``.
    """
    use_blocks = bool(block_length and block_length > 0)
    best_pos = None
    best_token = None
    best_conf = -1
    best_block_id = float('inf')

    for mask_pos in mask_positions:
        mask_block_id = (mask_pos - prompt_length) // block_length if use_blocks else 0
        pred_token = predictions[mask_pos].item()
        conf = probs[mask_pos, pred_token].item()

        # A lower block id always wins; in the same block, strictly higher confidence wins.
        if mask_block_id < best_block_id or (mask_block_id == best_block_id and conf > best_conf):
            best_pos = mask_pos
            best_token = pred_token
            best_conf = conf
            best_block_id = mask_block_id

    return best_pos, best_token, best_conf, best_block_id


def find_longest_valid_path(tree, node_results, all_draft_tokens, prompt_length, block_length, verbose):
    """
    Find the longest valid path in the verification tree using parent-centric logic.

    Traversal starts at ``tree``. Each node picks the masked position it most
    wants to fill: the earliest block first, then the highest confidence.
    - If a child verifies that position with the same token, the draft token is
      accepted and traversal descends into that child.
    - Otherwise the node's own prediction is accepted and traversal stops.
    - A node without children (a leaf) contributes one extra prediction from
      its remaining masks.

    Args:
        tree: The root node. Its ``confirmed_positions`` list is copied, not mutated.
        node_results: Dict mapping ``id(node)`` to that node's result dict.
            - The dict uses the keys ``'predictions'``, ``'probs'``, ``'logits'``,
              ``'mask_positions'`` and ``'is_root'``.
            - ``predictions`` and ``probs`` are indexed by absolute position.
            - Traversal stops quietly at a node that has no entry.
        all_draft_tokens: Draft tokens indexed as ``[0, pos - prompt_length]``,
            or None. When None, the model's own token is reported as the
            expected token.
        prompt_length: Number of prompt positions; block ids are relative to it.
        block_length: Block length for priority calculation. Falsy or
            non-positive values put all positions in block 0.
        verbose: Print a trace of the traversal.

    Returns:
        longest_confirmed: The root's confirmed positions followed by newly
            confirmed positions, in acceptance order.
        confirmed_token_values: Dict mapping each newly confirmed position to its token.
        failed_info: None, or a dict describing where traversal stopped.
            - Keys: ``'node'``, ``'failed_pos'``, ``'predicted_token'``,
              ``'expected_token'``, ``'confidence'``, ``'new_predictions'``,
              ``'new_logits'`` and ``'draft_accepted_count'``.
            - It also has ``'is_leaf_success': True`` when the stop was a
              leaf's extra prediction.
    """

    def draft_token_at(pos, fallback):
        # Drafts are stored relative to the start of the generated region.
        if all_draft_tokens is None:
            return fallback
        return all_draft_tokens[0, pos - prompt_length].item()

    longest_confirmed = tree.confirmed_positions.copy()
    confirmed_token_values = {}  # position -> verified token value
    failed_info = None
    draft_accepted_count = 0

    if verbose:
        print(f"\n    === Actual Tree Traversal ===")
        print(f"    Starting from root with {len(tree.confirmed_positions)} already confirmed positions")

    current = tree

    while True:
        if id(current) not in node_results:
            if verbose:
                print(f"    Warning: Current node not in results, stopping")
            break

        current_result = node_results[id(current)]
        mask_positions = current_result['mask_positions']

        if not current.children:
            # Leaf node: every draft on this path was verified, so it may add one more token.
            if verbose:
                print(f"\n    Reached leaf node (all drafts verified)")
                print(f"      Leaf node has {len(mask_positions)} remaining mask positions")
                if mask_positions:
                    print(f"      Remaining masks: {mask_positions[:5]}")  # Show first 5

            if mask_positions:
                predictions = current_result['predictions']
                best_pos, best_token, best_conf, best_block_id = _select_priority_mask_position(
                    predictions, current_result['probs'], mask_positions, prompt_length, block_length
                )

                if best_pos is not None:
                    longest_confirmed.append(best_pos)
                    confirmed_token_values[best_pos] = best_token  # Save leaf prediction token
                    if verbose:
                        print(f"      → Leaf node predicts extra token: pos {best_pos} (block {best_block_id})")
                        print(f"        with token={best_token}, conf={best_conf:.4f}")
                        print(f"      ✓ Accepting leaf prediction")

                    # Set up info for the next draft
                    failed_info = {
                        'node': current,
                        'failed_pos': best_pos,
                        'predicted_token': best_token,
                        'expected_token': draft_token_at(best_pos, best_token),
                        'confidence': best_conf,
                        'new_predictions': predictions,
                        'new_logits': current_result['logits'],
                        'is_leaf_success': True,  # A successful leaf prediction, not a mismatch
                    }
            break

        if not mask_positions:
            if verbose:
                print(f"\n    No more mask positions to process")
            break

        predictions = current_result['predictions']
        best_pos, best_token, best_conf, best_block_id = _select_priority_mask_position(
            predictions, current_result['probs'], mask_positions, prompt_length, block_length
        )

        if verbose:
            node_type = "root" if current_result['is_root'] else "node"
            print(f"\n    Current {node_type} wants position {best_pos} (block {best_block_id}) "
                  f"with token={best_token}, conf={best_conf:.4f}")

        # Several children may verify the same position with different tokens
        # (alternatives), so keep looking until one carries the wanted token.
        # The child's own result is not needed here; a missing entry is
        # handled at the top of the loop once we descend.
        matching_child = None
        for child in current.children:
            if child.verify_position != best_pos:
                continue
            if child.expected_token == best_token:
                matching_child = child
                if verbose:
                    print(f"      ✓ Found matching child: verify pos {best_pos} with expected token {child.expected_token}")
                break
            if verbose:
                print(f"      ✗ Child at pos {best_pos} has wrong token: "
                      f"expected {child.expected_token}, parent wants {best_token}")

        if matching_child is not None:
            # Draft token confirmed; continue down the matching branch.
            if best_pos not in longest_confirmed:
                longest_confirmed.append(best_pos)
                confirmed_token_values[best_pos] = matching_child.expected_token  # Save verified token
                draft_accepted_count += 1
                if verbose:
                    print(f"      → Accepting draft token at pos {best_pos}")
                    print(f"      → Moving to child node...")
            current = matching_child
        else:
            # No matching child: accept the current node's own choice and stop.
            if best_pos not in longest_confirmed:
                longest_confirmed.append(best_pos)
                confirmed_token_values[best_pos] = best_token  # Save model's choice
                if verbose:
                    print(f"      → No matching child for pos {best_pos}")
                    print(f"      ✓ Accepting model's choice: pos {best_pos} with token {best_token}")
                    print(f"      → Stopping traversal")

                # Set up failed_info for the draft update
                failed_info = {
                    'node': current,
                    'failed_pos': best_pos,
                    'predicted_token': best_token,
                    'expected_token': draft_token_at(best_pos, best_token),
                    'confidence': best_conf,
                    'new_predictions': predictions,
                    'new_logits': current_result['logits'],
                }
            break

    if verbose:
        print(f"\n    === Traversal Summary ===")
        print(f"    Draft tokens accepted (as expected): {draft_accepted_count}")
        total_new = len(longest_confirmed) - len(tree.confirmed_positions)
        model_override = total_new - draft_accepted_count
        print(f"    Model override tokens: {model_override}")
        print(f"    Total new tokens confirmed: {total_new}")
        print(f"    Total positions confirmed: {len(longest_confirmed)}")
        if failed_info:
            print(f"    Stopped at position: {failed_info['failed_pos']}")
            if failed_info.get('is_leaf_success'):
                print("    Reason: Leaf node predicted one extra token after all drafts were verified")
            else:
                print(f"    Reason: Model predicted different token than draft")

    # Include draft_accepted_count in failed_info for caller's use
    if failed_info:
        failed_info['draft_accepted_count'] = draft_accepted_count

    return longest_confirmed, confirmed_token_values, failed_info


def _decode_remaining_standard_v3(
    x, model, attention_mask, mask_positions, 
    prompt_length, prompt_index, temperature, cfg_scale, mask_id,
    all_draft_tokens, all_confidences, block_length, confirmed_positions, verbose
):
    """
    Fill the remaining masks one token per forward pass, without drafting.

    This is the fallback when fewer masks remain than the draft length. On
    every pass it picks the masked position in the lowest block, breaking
    ties in that block by highest confidence, and commits the model's token
    there.

    Args:
        x: Token tensor, written in place via ``x[:, pos] = token``.
            Predictions are read from batch row 0 only.
        model: Callable whose return value exposes ``.logits``.
        attention_mask: Passed to every model call.
        mask_positions: Remaining masked absolute positions. Filled entries
            are removed from this list in place.
        prompt_length: Index where generation starts; block ids are relative to it.
        prompt_index: Index of the prompt positions (presumably a boolean
            mask; confirm with caller). Those positions are set to ``mask_id``
            for the unconditional pass when ``cfg_scale > 0.0``.
        temperature: Forwarded to ``add_gumbel_noise``.
        cfg_scale: Guidance scale. The two-call guided path runs only when it
            is > 0.0.
        mask_id: Token written over the prompt in the unconditional pass.
        all_draft_tokens: Not used by this function.
        all_confidences: Not used by this function.
        block_length: Block size. Falsy or non-positive values mean a single block.
        confirmed_positions: Each filled position is appended here.
        verbose: Print per-step diagnostics.

    Returns:
        forward_passes: Number of loop iterations that ran the model. The
            extra unconditional call made when ``cfg_scale > 0.0`` is not
            counted separately.
    """
    use_blocks = bool(block_length and block_length > 0)

    def run_model():
        # Without guidance, a single conditional call suffices.
        if not cfg_scale > 0.0:
            return model(x, attention_mask=attention_mask).logits
        uncond_input = x.clone()
        uncond_input[prompt_index] = mask_id
        cond_logits = model(x, attention_mask=attention_mask).logits
        uncond_logits = model(uncond_input, attention_mask=attention_mask).logits
        residual = cond_logits - uncond_logits
        return (cond_logits - residual) + (cfg_scale + 1) * residual

    passes = 0
    while mask_positions:
        logits = run_model()
        passes += 1

        # Only the generated region matters; offsets below are relative to prompt_length.
        gen_logits = logits[:, prompt_length:]
        sampled = torch.argmax(add_gumbel_noise(gen_logits, temperature=temperature), dim=-1)
        token_conf = torch.gather(
            F.softmax(gen_logits, dim=-1), dim=-1, index=sampled.unsqueeze(-1)
        ).squeeze(-1)

        chosen = None  # (position, token, confidence, block_id)
        for candidate in mask_positions:
            offset = candidate - prompt_length
            block_id = offset // block_length if use_blocks else 0
            token = sampled[0, offset].item()
            conf = token_conf[0, offset].item()
            if chosen is None or block_id < chosen[3] or (block_id == chosen[3] and conf > chosen[2]):
                chosen = (candidate, token, conf, block_id)

        if chosen is None:
            # Not expected while mask_positions is non-empty; kept as a safeguard.
            if verbose:
                print(f"    Warning: Could not find best position, breaking")
            break

        pos, token, conf, block_id = chosen
        x[:, pos] = token
        mask_positions.remove(pos)
        confirmed_positions.append(pos)

        if verbose:
            print(f"    Standard decode: Accepted pos {pos} (block {block_id}) "
                  f"with token={token}, conf={conf:.4f}")
            print(f"    Remaining masks: {len(mask_positions)}")
            print(f"    Forward passes in standard decode: {passes}")

    if verbose:
        print(f"    Standard decode completed with {passes} forward passes")

    return passes


def ssd_without_cache(
    input_ids,
    attention_mask,
    model,
    gen_length=128,
    block_length=None,
    temperature=0.0,
    cfg_scale=0.0,
    mask_id=126336,
    draft_length=4,
    tree_strategy='greedy',
    num_alternatives=1,
    verbose=False,
    **kwargs
):
    """
    Public entry point for SCSD v3 generation.

    Bundles the named settings and forwards them, along with any extra
    keyword arguments, to ``_generate_scsd_v3``. Its result is returned
    unchanged.

    Args:
        input_ids: Prompt token ids, forwarded as-is.
        attention_mask: Attention mask, forwarded as-is.
        model: Model handed to the generator.
        gen_length: Generation length setting (default 128).
        block_length: Block length for block ID calculation (optional).
        temperature: Sampling temperature setting (default 0.0).
        cfg_scale: Classifier-free guidance scale setting (default 0.0).
        mask_id: Mask token id (default 126336).
        draft_length: Fixed draft length (default 4).
        tree_strategy: Tree building strategy ('greedy', 'greedy_with_alternatives').
        num_alternatives: Number of alternatives per position (for greedy_with_alternatives).
        verbose: Diagnostic printing flag.
        **kwargs: Additional options passed through untouched.
    """
    settings = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        model=model,
        gen_length=gen_length,
        block_length=block_length,
        temperature=temperature,
        cfg_scale=cfg_scale,
        mask_id=mask_id,
        draft_length=draft_length,
        tree_strategy=tree_strategy,
        num_alternatives=num_alternatives,
        verbose=verbose,
    )
    return _generate_scsd_v3(**settings, **kwargs)