"""
SCSD v3: Cross-Block Speculative Decoding with Fixed Draft Length
Key improvements over v2:
1. Fixed draft length k (default 4) instead of per-block processing
2. Cross-block draft selection when current block has < k remaining tokens
3. Hierarchical verification tree respecting block priorities
4. Extensible design for future multi-candidate support
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

class Colors:
    """ANSI escape sequences used to decorate terminal output."""

    # Text attributes
    RESET = '\033[0m'
    BOLD = '\033[1m'

    # Bright foreground colours, in ascending escape-code order
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    MAGENTA = '\033[95m'
    CYAN = '\033[96m'

@dataclass
class DraftCandidate:
    """
    One proposed token for one position of the draft.

    Only the top-1 proposal is carried today; the ``alternative_*`` fields are
    reserved so that several candidates per position can be added later.

    Attributes:
        position: Absolute position in the sequence.
        block_id: Block this position belongs to.
        token: The proposed (top-1) token id.
        confidence: Confidence score of ``token``.
        alternative_tokens: Further candidate tokens (future extension).
        alternative_confidences: Scores matching ``alternative_tokens``
            (future extension).
    """
    position: int
    block_id: int
    token: int
    confidence: float
    alternative_tokens: List[int] = field(default_factory=list)
    alternative_confidences: List[float] = field(default_factory=list)
    

@dataclass
class VerificationNode:
    """
    A single node of the verification tree.

    Carries the state needed for cross-block verification plus a few fields
    reserved for future multi-candidate paths.

    Attributes:
        confirmed_positions: Absolute positions that are confirmed at this node.
        confirmed_blocks: Confirmed positions grouped as {block_id: [positions]}.
        children: Child nodes.
        verify_position: Position to verify at this node (None when unset).
        expected_token: Token expected at ``verify_position`` (None when unset).
        parent: Parent node, or None.
        sequence_tensor: Pre-filled sequence for this node.
        mask_positions: Positions that are still masked.
        draft_tokens_dict: Mapping position -> draft token.
        alternative_tokens: {position: [candidates]} (future extension).
        priority_score: Score for dynamic traversal ordering (future extension).
        path_probability: Joint probability of the path (future extension).
    """
    confirmed_positions: List[int]
    confirmed_blocks: Dict[int, List[int]] = field(default_factory=dict)
    children: List['VerificationNode'] = field(default_factory=list)
    verify_position: Optional[int] = None
    expected_token: Optional[int] = None
    parent: Optional['VerificationNode'] = None
    sequence_tensor: Optional[torch.Tensor] = None
    mask_positions: List[int] = field(default_factory=list)
    draft_tokens_dict: Dict[int, int] = field(default_factory=dict)

    # Reserved for future extensions
    alternative_tokens: Dict[int, List[int]] = field(default_factory=dict)
    priority_score: float = 0.0
    path_probability: float = 1.0
    

class TreeBuildStrategy(ABC):
    """
    Interface for verification-tree builders.

    Each concrete strategy implements :meth:`build_tree`, so different ways of
    arranging draft candidates into a tree can be swapped in.
    """
    @abstractmethod
    def build_tree(
        self,
        candidates: List[DraftCandidate],
        confirmed_positions: List[int],
        x: torch.Tensor,
        prompt_length: int,
        mask_id: int,
        verbose: bool = False,
        all_mask_positions: List[int] = None
    ) -> VerificationNode:
        """Return the root of a verification tree built from the draft candidates."""
        ...
    

class GreedyStrategy(TreeBuildStrategy):
    """
    Greedy strategy: Build linear chain using highest confidence tokens.
    Single-branch tree: strictly follows block priority and confidence ordering.
    """
    def build_tree(
        self,
        candidates: List[DraftCandidate],
        confirmed_positions: List[int],
        x: torch.Tensor,
        prompt_length: int,
        mask_id: int,
        verbose: bool = False,
        all_mask_positions: Optional[List[int]] = None
    ) -> VerificationNode:
        """
        Build linear verification chain with block priority and confidence ordering.

        Creates a single path from root to leaf, where each node represents
        accepting one more draft token in the priority order.

        Args:
            candidates: Draft candidates to chain; may be empty.
            confirmed_positions: Positions confirmed before this call (copied,
                never mutated).
            x: Current sequence tensor; cloned for every node.
            prompt_length: Positions below this index are never treated as masks.
            mask_id: Token id that marks a masked position in ``x``.
            verbose: Print the candidate list and the resulting chain.
            all_mask_positions: Remaining mask positions. When None they are
                derived from ``x``.

        Returns:
            The root node. It holds no new confirmations; each descendant adds
            exactly one candidate.
        """
        if all_mask_positions is None:
            # Derive the open slots from x: every mask located at or after the prompt.
            all_masks = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
            all_mask_positions = [p for p in all_masks if p >= prompt_length]

        # Root = starting state: previously confirmed positions, ALL remaining masks.
        root = VerificationNode(
            confirmed_positions=confirmed_positions.copy(),
            mask_positions=all_mask_positions.copy()
        )

        # Position -> draft token lookup shared by the whole chain.
        draft_dict = {c.position: c.token for c in candidates}
        root.draft_tokens_dict = draft_dict

        # The root sequence must not contain any draft token yet.
        root.sequence_tensor = x.clone()

        if not candidates:
            return root

        # Earlier blocks first; inside a block, higher confidence first.
        sorted_candidates = sorted(
            candidates,
            key=lambda c: (c.block_id, -c.confidence)
        )

        if verbose:
            print("\n    === Building Verification Tree (Linear Chain) ===")
            print(f"    Current state: {len(confirmed_positions)} positions already confirmed")
            print("    Draft candidates to verify (sorted by block priority):")
            for i, c in enumerate(sorted_candidates):
                print(f"      [{i}] Pos {c.position} (block {c.block_id}): "
                      f"token={c.token}, conf={c.confidence:.4f}")

            # Show top confidence token across all candidates
            top_conf_candidate = max(candidates, key=lambda c: c.confidence)
            if top_conf_candidate != sorted_candidates[0]:
                print("\n    Note: Top confidence token differs from first priority:")
                print(f"      Highest confidence: Pos {top_conf_candidate.position} "
                      f"(block {top_conf_candidate.block_id}), conf={top_conf_candidate.confidence:.4f}")
                print(f"      First by priority: Pos {sorted_candidates[0].position} "
                      f"(block {sorted_candidates[0].block_id}), conf={sorted_candidates[0].confidence:.4f}")

            print("\n    Verification strategy (Single-Branch):")
            print("    - Linear chain: each node has at most one child")
            print("    - Strictly follows priority order (no alternatives)")
            print("    - Verification stops at first mismatch")

        self._build_linear_chain(root, sorted_candidates, x, mask_id, draft_dict)

        if verbose:
            print("\n    Tree Statistics:")
            print(f"      Total nodes: {self._count_nodes(root)} (linear chain)")
            print(f"      Chain length: {self._get_tree_depth(root)}")
            print("      Structure: Single path from root to leaf")
            print("\n    Tree Structure Visualization:")
            self._print_tree(root, indent="      ", verbose=verbose)

        return root

    def _build_linear_chain(
        self,
        parent: VerificationNode,
        sorted_candidates: List[DraftCandidate],
        x: torch.Tensor,
        mask_id: int,
        draft_dict: Dict[int, int]
    ):
        """Build a linear chain of nodes.

        Each node has exactly one child, following the strict priority order.
        No alternative branches are explored.

        Args:
            parent: Node the chain is attached to (normally the root).
            sorted_candidates: Candidates in the order they should be accepted.
            x: Base sequence; cloned for every node.
            mask_id: Unused here; kept for signature compatibility.
            draft_dict: Position -> draft token mapping used to fill sequences.
        """
        current_node = parent
        accumulated_confirmed = parent.confirmed_positions.copy()

        # Start from every mask known to the parent, not just the candidates.
        all_mask_positions = parent.mask_positions.copy()

        for candidate in sorted_candidates:
            accumulated_confirmed.append(candidate.position)

            # The accepted position is no longer masked below this node.
            if candidate.position in all_mask_positions:
                all_mask_positions.remove(candidate.position)

            child = VerificationNode(
                confirmed_positions=accumulated_confirmed.copy(),
                verify_position=candidate.position,
                expected_token=candidate.token,
                parent=current_node,
                mask_positions=all_mask_positions.copy(),
                draft_tokens_dict=draft_dict.copy()
            )

            # Copy the per-block lists as well as the dict: a shallow dict copy
            # would share the lists, so appending here would also change the
            # confirmed_blocks of every ancestor node.
            child.confirmed_blocks = {
                block: list(positions)
                for block, positions in current_node.confirmed_blocks.items()
            }
            child.confirmed_blocks.setdefault(candidate.block_id, []).append(candidate.position)

            # Sequence with every accumulated draft token filled in.
            child_seq = x.clone()
            for pos in accumulated_confirmed:
                if pos in draft_dict:
                    child_seq[:, pos] = draft_dict[pos]
            child.sequence_tensor = child_seq

            # Attach as the only child and descend.
            current_node.children.append(child)
            current_node = child

    def _count_nodes(self, node: VerificationNode) -> int:
        """Count total nodes in tree."""
        count = 1
        for child in node.children:
            count += self._count_nodes(child)
        return count

    def _get_tree_depth(self, node: VerificationNode, depth: int = 0) -> int:
        """Get maximum depth of tree."""
        if not node.children:
            return depth
        return max(self._get_tree_depth(child, depth + 1) for child in node.children)

    def _get_root(self, node: VerificationNode) -> VerificationNode:
        """Get root node of tree."""
        while node.parent is not None:
            node = node.parent
        return node

    def _print_tree(self, node: VerificationNode, indent: str = "", is_last: bool = True,
                    is_root: bool = True, verbose: bool = False, depth: int = 0):
        """Print tree structure in a visual format."""
        if is_root:
            num_draft = len(node.mask_positions)
            print(f"{indent}[ROOT] State: {len(node.confirmed_positions)} confirmed (from previous iterations)")
            print(f"{indent}       To verify: {num_draft} draft positions")
            if verbose and num_draft > 0 and num_draft <= 5:
                print(f"{indent}       Draft positions: {node.mask_positions}")
        else:
            # Print branch
            branch = "└── " if is_last else "├── "
            if node.verify_position is not None:
                # Show cumulative draft tokens accepted
                root_confirmed = len(self._get_root(node).confirmed_positions)
                draft_accepted = len(node.confirmed_positions) - root_confirmed

                print(f"{indent}{branch}[Level {depth}: +{draft_accepted} draft] "
                      f"Accept pos {node.verify_position} = {node.expected_token}")

                if verbose:
                    # Show remaining masks
                    remaining = len(node.mask_positions)
                    if remaining > 0:
                        print(f"{indent}{'    ' if is_last else '│   '}"
                              f"  → {remaining} positions still to verify")

        # Print children
        for i, child in enumerate(node.children):
            is_last_child = (i == len(node.children) - 1)
            new_indent = indent + ("    " if is_last else "│   ") if not is_root else indent
            self._print_tree(child, new_indent, is_last_child, False, verbose, depth + 1)


def add_gumbel_noise(logits, temperature):
    """
    Perturb ``exp(logits)`` with temperature-scaled noise for sampling.

    With ``temperature == 0`` no noise is drawn and ``exp(logits)`` is returned
    as is. Otherwise the result is ``exp(logits) / (-log(u)) ** temperature``
    with ``u`` drawn uniformly, element-wise, via ``torch.rand_like``.
    """
    scores = logits.exp()
    if temperature == 0:
        return scores
    uniform = torch.rand_like(logits)
    return scores / uniform.log().neg().pow(temperature)


def repeat_kv_cache(past_key_values, times):
    """
    Tile every key/value tensor of a cache along its first dimension.

    Args:
        past_key_values: Iterable of ``(key, value)`` tensor pairs; each tensor
            is repeated with ``repeat(times, 1, 1, 1)``.
        times (int): Number of repetitions along the first dimension.

    Returns:
        list: One ``(key, value)`` tuple of repeated tensors per input pair.
    """
    tiling = (times, 1, 1, 1)
    return [
        (key.repeat(*tiling), value.repeat(*tiling))
        for key, value in past_key_values
    ]


def _generate_scsd_v3(
    input_ids,
    attention_mask,
    model,
    gen_length=128,
    block_length=None,  # Block length for block ID calculation
    temperature=0.0,
    cfg_scale=0.0,
    mask_id=126336,
    draft_length=4,  # Fixed draft length k
    tree_strategy='greedy',  # Tree building strategy
    verbose=False,
    refresh_interval=100,
    **kwargs  # For compatibility
):
    """
    SCSD v3 implementation with cross-block draft generation.

    Key algorithm:
    1. Generate initial draft for ALL positions (one-time)
    2. Iteratively select k positions for verification
    3. Build verification tree and verify
    4. Update draft based on verification results
    5. Repeat until all positions are confirmed

    Args:
        input_ids: Prompt token ids; its shape is unpacked as
            (batch_size, prompt_length).
        attention_mask: Passed through unchanged to the draft, verify and
            standard-decode helpers.
        model: Model handed to the helpers; ``model.device`` decides where the
            working sequence is allocated.
        gen_length: Number of mask slots appended after the prompt.
        block_length: Block size used for block ids and for the initial
            candidate window. NOTE(review): the default ``None`` is compared
            with an int during window setup and would raise TypeError there —
            callers presumably always pass a value; confirm.
        temperature: Passed through to the helpers.
        cfg_scale: Passed through to the helpers.
        mask_id: Token id that marks a not-yet-generated position.
        draft_length: Maximum number of candidates selected per iteration (k).
        tree_strategy: Name of the tree builder; only 'greedy' is accepted.
        verbose: Print progress information.
        refresh_interval: Whenever the count of unmasked tokens moves into a
            new multiple of this value, ``block_forward=True`` is passed to the
            verifier for that iteration.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(x, total_forward_passes)``: the working sequence (prompt
        followed by the generation region) and the accumulated forward-pass
        count reported by the helpers.

    Raises:
        ValueError: If ``tree_strategy`` is not 'greedy'.
    """
    with torch.no_grad():
        # No-op self-assignment (kept as in the original).
        refresh_interval = refresh_interval
        batch_size, prompt_length = input_ids.shape
        # Working sequence: prompt followed by gen_length mask tokens.
        x = torch.full(
            (batch_size, prompt_length + gen_length),
            mask_id,
            dtype=torch.long,
            device=model.device,
        )
        x[:, :prompt_length] = input_ids
        # True wherever x currently differs from mask_id (at this point: the prompt,
        # assuming the prompt itself contains no mask_id — TODO confirm).
        prompt_index = x != mask_id
        
        # Initialize strategy
        if tree_strategy == 'greedy':
            strategy = GreedyStrategy()
        else:
            raise ValueError(f"Unknown tree strategy: {tree_strategy}")
        
        total_forward_passes = 0
        iteration = 0
        confirmed_positions = []  # Global confirmed positions

        # Initial cache state: nothing cached yet.
        past_key_values = None
        new_past_key_values = None

        # Generate initial draft for ALL positions
        if verbose:
            print(f"Generating initial draft for all {gen_length} positions...")
        
        all_draft_tokens, all_confidences, initial_forwards, past_key_values= generate_initial_full_draft(
            x=x,
            model=model,
            attention_mask=attention_mask,
            prompt_length=prompt_length,
            prompt_index=prompt_index,
            temperature=temperature,
            cfg_scale=cfg_scale,
            mask_id=mask_id,
            verbose=verbose
        )
        total_forward_passes += initial_forwards
        
        if verbose:
            print(f"Initial draft generated with 1 forward pass")
        

        # Accept the first token with highest confidence (respecting block priority)
        # nonzero(...)[1] yields the column (sequence) indices of the mask tokens.
        mask_positions = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
        mask_positions = [p for p in mask_positions if p >= prompt_length]

        # Initialise the candidate window [current_start_idx, current_end_idx).
        block_forward = False
        cache_refresh_epoch = 0
        current_block_id = 0  # not read anywhere else in this function
        current_start_idx = prompt_length
        # NOTE(review): raises TypeError when block_length is None (its default).
        # Small blocks get a window of two blocks instead of one.
        if block_length > 2*(draft_length+1):
            current_end_idx = current_start_idx + block_length
        else:
            current_end_idx = current_start_idx + 2*block_length
        if mask_positions:
            # Create candidates for all positions
            # Draft tensors are read from batch row 0 only; the initial draft is
            # indexed relative to prompt_length.
            first_candidates = []
            for pos in mask_positions:
                rel_pos = pos - prompt_length
                block_id = rel_pos // block_length if block_length and block_length > 0 else 0
                first_candidates.append(DraftCandidate(
                    position=pos,
                    block_id=block_id,
                    token=int(all_draft_tokens[0, rel_pos].item()),
                    confidence=all_confidences[0, rel_pos].item()
                ))
            
            # Sort by block priority then confidence
            first_candidates.sort(key=lambda c: (c.block_id, -c.confidence))
            
            # Accept the first one
            # The token is written to every batch row of x.
            first_token = first_candidates[0]
            x[:, first_token.position] = first_token.token
            confirmed_positions.append(first_token.position)
            last_confirmed_position = first_token.position
            if verbose:
                print(f"\n✨ Initial token confirmed:")
                print(f"    Pos {first_token.position} (block {first_token.block_id}): "
                      f"token={first_token.token}, conf={first_token.confidence:.4f}")
        

        # Main iteration loop
        while True:
            # Count remaining masks
            mask_positions = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
            mask_positions = [p for p in mask_positions if p >= prompt_length]
            # mask_positions: list of absolute indices still holding mask_id
            
            if not mask_positions:
                if verbose:
                    print(f"\n✅ Generation complete after {iteration} iterations!")
                    print(f"   Total forward passes: {total_forward_passes}")
                break

            iteration += 1

            if verbose:
                print(f"\n{'='*60}")
                print(f"ITERATION {iteration}: {len(mask_positions)} masks remaining")
                print(f"ALL_DRAFT_TOKENS LENS:{all_draft_tokens.shape[1]}")
                print(f"Candidates from {current_start_idx} to {current_end_idx}")
                print(f"{'='*60}")

            # Select up to draft_length positions for this iteration
            candidates = select_draft_candidates(
                mask_positions=mask_positions,
                all_draft_tokens=all_draft_tokens,
                all_confidences=all_confidences,
                draft_length=draft_length,
                prompt_length=prompt_length,
                block_length=block_length,
                verbose=verbose,
                current_start_idx=current_start_idx,
                current_end_idx=current_end_idx,
                gen_length = gen_length
            )
            
            # Flag a block_forward pass whenever the number of unmasked tokens
            # enters a new refresh_interval-sized epoch.
            token_unmask = gen_length - len(mask_positions)
            new_cache_refresh_epoch = token_unmask // refresh_interval
            if new_cache_refresh_epoch != cache_refresh_epoch:
                cache_refresh_epoch = new_cache_refresh_epoch
                block_forward = True
            # Window for the NEXT draft: starts at the earlier of the last confirmed
            # position and the first mask, and spans up to 2*(draft_length+1) masks.
            if len(mask_positions) < draft_length:
                new_start_idx = mask_positions[0]
            else:
                # NOTE(review): last_confirmed_position can be None after an iteration
                # that confirmed nothing (see its assignment below); min() would then
                # raise TypeError — confirm whether that state is reachable.
                new_start_idx = min(last_confirmed_position,mask_positions[0])
            new_candidates_length = min(2*(draft_length+1),len(mask_positions))
            new_end_idx = mask_positions[new_candidates_length-1]
            new_end_idx += 1  # make the end index exclusive
            
            if verbose:
                print(f"{'='*60}")
                #print(f"Current Block:{current_block_id} from {current_block_start_idx} to {current_block_end_idx}")
                #print(f"Next Block:{next_block_id} from {next_block_start_idx} to {next_block_end_idx}")
                #print(f" | current block {current_block_id} remain:{current_block_mask_remain} | next block {next_block_id} remain {next_block_mask_remain}")
                print(f"New candidates from {new_start_idx} to {new_end_idx}")
                # NOTE(review): printed on every verbose iteration, independent of
                # block_forward.
                print("Cache Refreshed!")
                print(f"{'='*60}")

            if not candidates:
                break
            
            # When remaining masks < draft_length, use standard decode
            if len(mask_positions) < draft_length:
                if verbose:
                    print(f"\n  Switching to standard decode ({len(mask_positions)} masks < draft_length {draft_length})")

                # NOTE(review): unlike the exclusive end computed above, this is the
                # last mask position itself (no +1) — confirm what the helper expects.
                new_end_idx = mask_positions[-1]
                std_forwards = _decode_remaining_standard_v3(
                    x=x,
                    model=model,
                    attention_mask=attention_mask,
                    mask_positions=mask_positions.copy(),  # Pass a copy since we'll modify it
                    prompt_length=prompt_length,
                    prompt_index=prompt_index,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    mask_id=mask_id,
                    all_draft_tokens=all_draft_tokens,
                    all_confidences=all_confidences,
                    block_length=block_length,
                    confirmed_positions=confirmed_positions,
                    verbose=verbose,
                    past_key_values=past_key_values,
                    current_start_idx=new_start_idx,
                    current_end_idx=new_end_idx
                )
                total_forward_passes += std_forwards
                
                # Standard decode completes all remaining masks
                break
            
            # Build verification tree and verify
            new_confirmed, confirmed_tokens, new_draft, verify_forwards, new_past_key_values = cascaded_verify_v3(
                x=x,
                model=model,
                attention_mask=attention_mask,
                candidates=candidates,
                confirmed_positions=confirmed_positions,
                all_draft_tokens=all_draft_tokens,
                all_confidences=all_confidences,
                strategy=strategy,
                prompt_length=prompt_length,
                prompt_index=prompt_index,
                temperature=temperature,
                cfg_scale=cfg_scale,
                mask_id=mask_id,
                iteration=iteration,
                block_length=block_length,
                all_mask_positions=mask_positions,  # Pass all remaining masks
                verbose=verbose,
                past_key_values=past_key_values,
                replace_start_idx=new_start_idx,
                replace_end_idx=new_end_idx,
                current_start_idx=current_start_idx,
                current_end_idx=current_end_idx,
                draft_length=draft_length,
                block_forward = block_forward
            )
            total_forward_passes += verify_forwards

            # assert new_past_key_values is not None

            # Accept newly confirmed tokens - use verified token values
            for pos in new_confirmed:
                if pos not in confirmed_positions:
                    # Use the verified token value, NOT the draft value
                    token_value = confirmed_tokens[pos]
                    x[:, pos] = token_value
                    confirmed_positions.append(pos)
                    if verbose:
                        rel_pos = pos - prompt_length
                        block_id = rel_pos // block_length if block_length else 0
                        print(f"    ✓ Confirmed: Pos {pos} (block {block_id}) = token {token_value}")
            # Becomes None when nothing was confirmed in this iteration.
            last_confirmed_position = new_confirmed[-1] if new_confirmed else None

            # Keep the previous cache when the verifier did not return a new one.
            if new_past_key_values is not None:
                past_key_values = new_past_key_values
            else:
                # This recomputed list is overwritten at the top of the next loop
                # iteration before being read; the assert can never fail for a list.
                mask_positions = (x == mask_id).nonzero(as_tuple=True)[1].tolist()
                mask_positions = [p for p in mask_positions if p >= prompt_length]
                assert mask_positions is not None
            # Update draft if new predictions were made (after confirming tokens)
            if new_draft is not None:
                all_draft_tokens = new_draft['tokens']
                all_confidences = new_draft['confidences']
            if verbose:
                print(f"\n  📊 Iteration {iteration} Summary:")
                print(f"     Positions confirmed this iteration: {len(new_confirmed)}")
                print(f"     Total positions confirmed: {len(confirmed_positions)}")
                print(f"     Forward passes this iteration: {verify_forwards}")
                print(f"     Total forward passes so far: {total_forward_passes}")

            # The window computed before verification becomes the current window.
            current_start_idx = new_start_idx
            current_end_idx = new_end_idx
            block_forward = False

        if verbose:
            print(f"\nTotal forward passes: {total_forward_passes}")
            
            # Print final generated tokens for comparison
            print(f"\n{'='*60}")
            print(f"FINAL GENERATED TOKENS")
            print(f"{'='*60}")
            generated = x[:, prompt_length:]
            for i in range(generated.shape[1]):
                token_id = generated[0, i].item()
                print(f"Position {i}: token_id={token_id}")
            print(f"{'='*60}\n")
        
        return x, total_forward_passes

def generate_initial_full_draft(
    x, model, attention_mask, prompt_length, prompt_index,
    temperature, cfg_scale, mask_id, verbose=False
):
    """
    Draft a token for every position after the prompt in one forward pass.

    The model is called once as ``model(x, use_cache=True)``. The arguments
    ``attention_mask``, ``prompt_index``, ``cfg_scale``, ``mask_id`` and
    ``verbose`` are accepted but not used here.

    Returns:
        A 4-tuple ``(draft_tokens, confidence, 1, past_key_values)``:
        draft_tokens: Chosen token per position after the prompt.
        confidence: Softmax probability of each chosen token.
        1: Number of forward passes used.
        past_key_values: Cache taken from the model output.
    """
    model_out = model(x, use_cache=True)
    full_logits = model_out.logits
    kv_cache = model_out.past_key_values

    # Only the part after the prompt is drafted.
    tail_logits = full_logits[:, prompt_length:]

    # Noisy scores decide the token; confidence comes from the clean logits.
    draft_tokens = add_gumbel_noise(tail_logits, temperature=temperature).argmax(dim=-1)
    confidence = (
        tail_logits.softmax(dim=-1)
        .gather(-1, draft_tokens.unsqueeze(-1))
        .squeeze(-1)
    )

    return draft_tokens, confidence, 1, kv_cache


def select_draft_candidates(
    mask_positions, all_draft_tokens, all_confidences,
    draft_length, prompt_length, block_length=None, verbose=False,
    current_start_idx=None,current_end_idx=None,
    gen_length=None,
):
    """
    Select up to draft_length candidates from remaining mask positions.
    Respects block priority: earlier blocks have higher priority.

    Args:
        mask_positions: Absolute positions that are still masked. Expected in
            ascending order, because scanning stops at the first position that
            reaches ``current_end_idx``.
        all_draft_tokens: Draft token tensor; row 0 is read at
            ``pos - current_start_idx``.
        all_confidences: Confidence tensor indexed like ``all_draft_tokens``.
        draft_length: Maximum number of candidates to return.
        prompt_length: Length of the prompt; block ids are computed relative to it.
        block_length: Block size; when falsy or not positive every candidate
            gets block id 0.
        verbose: Print selection diagnostics.
        current_start_idx: Absolute position that column 0 of the draft tensors
            corresponds to. Defaults to ``prompt_length`` when None.
        current_end_idx: Exclusive upper bound for candidate positions. No
            bound is applied when None.
        gen_length: Unused; kept for call compatibility.

    Returns:
        List of DraftCandidate objects (up to draft_length)
    """
    # The window arguments default to None; previously that raised TypeError in
    # the arithmetic/comparison below. Fall back to the whole generation region.
    draft_offset = prompt_length if current_start_idx is None else current_start_idx
    use_blocks = bool(block_length) and block_length > 0

    candidates = []
    for pos in mask_positions:
        if current_end_idx is not None and pos >= current_end_idx:
            break
        # Column of this position inside the draft tensors.
        rel_pos = pos - draft_offset
        block_id = (pos - prompt_length) // block_length if use_blocks else 0
        candidates.append(DraftCandidate(
            position=pos,
            block_id=block_id,
            token=all_draft_tokens[0, rel_pos].item(),
            confidence=all_confidences[0, rel_pos].item()
        ))

    if verbose and candidates:
        print("\n  === Draft Candidate Selection ===")
        print(f"  Total mask positions available: {len(candidates)}")

        # Show distribution by blocks
        block_counts = {}
        for c in candidates:
            block_counts[c.block_id] = block_counts.get(c.block_id, 0) + 1
        print(f"  Block distribution: {dict(sorted(block_counts.items()))}")

        # Show top confidence tokens before sorting
        top_5_by_conf = sorted(candidates, key=lambda c: -c.confidence)[:5]
        print("\n  Top 5 by confidence (before block priority):")
        for i, c in enumerate(top_5_by_conf, 1):
            print(f"    {i}. Pos {c.position} (block {c.block_id}): "
                  f"token={c.token}, conf={c.confidence:.4f}")

    # Earlier blocks first; inside a block, higher confidence first.
    candidates.sort(key=lambda c: (c.block_id, -c.confidence))

    # Take only draft_length candidates
    selected = candidates[:draft_length]

    if verbose and selected:
        print(f"\n  Final selection ({len(selected)} candidates after block priority):")
        for i, c in enumerate(selected):
            print(f"    [{i}] Pos {c.position} (block {c.block_id}): "
                  f"token={c.token}, conf={c.confidence:.4f}")

        # Analysis: show if block priority changed selection
        if len(candidates) > draft_length:
            pure_conf_selection = sorted(candidates, key=lambda c: -c.confidence)[:draft_length]
            pure_conf_positions = set(c.position for c in pure_conf_selection)
            actual_positions = set(c.position for c in selected)
            diff = pure_conf_positions - actual_positions
            if diff:
                print("\n  Note: Block priority changed selection from pure confidence")
                print(f"    Positions selected due to block priority: {actual_positions - pure_conf_positions}")
                print(f"    Positions skipped despite higher confidence: {diff}")

    return selected


def cascaded_verify_v3(
    x, model, attention_mask, candidates, confirmed_positions,
    all_draft_tokens, all_confidences, strategy,
    prompt_length, prompt_index, temperature, cfg_scale, mask_id,
    iteration, block_length=None, all_mask_positions=None, verbose=False, 
    past_key_values=None,replace_start_idx=None,replace_end_idx=None,current_start_idx=None,current_end_idx=None,draft_length=None,
    block_forward=None
):
    """
    Perform cascaded verification for selected candidates.

    Builds the verification tree with ``strategy``, verifies it through
    ``traverse_and_verify_v3`` and turns the outcome into the data the
    generation loop needs for its next iteration.

    Returns:
        A 5-tuple:
        newly_confirmed: List of newly confirmed positions.
        confirmed_tokens: Token lookup for confirmed positions, as returned by
            the traversal (the caller indexes it by position).
        new_draft: Dict with 'tokens' and 'confidences', or None when the
            draft is left unchanged.
        forward_passes: Number of forward passes used.
        new_past_key_values: Cache reported by the traversal, or None.
    """
    # Build verification tree
    tree = strategy.build_tree(
        candidates=candidates,
        confirmed_positions=confirmed_positions,
        x=x,
        prompt_length=prompt_length,
        mask_id=mask_id,
        verbose=verbose,
        all_mask_positions=all_mask_positions
    )

    # Traverse and verify
    longest_confirmed, confirmed_tokens, failed_info, forward_passes = traverse_and_verify_v3(
        tree=tree,
        model=model,
        attention_mask=attention_mask,
        prompt_length=prompt_length,
        prompt_index=prompt_index,
        temperature=temperature,
        cfg_scale=cfg_scale,
        mask_id=mask_id,
        candidates=candidates,
        all_draft_tokens=all_draft_tokens,
        block_length=block_length,
        verbose=verbose,
        past_key_values=past_key_values,
        replace_start_idx=replace_start_idx,
        replace_end_idx=replace_end_idx,
        current_start_idx=current_start_idx,
        current_end_idx=current_end_idx,
        draft_length=draft_length,
        block_forward=block_forward
    )

    # Check if we made progress
    newly_confirmed = [pos for pos in longest_confirmed if pos not in confirmed_positions]

    # Handle verification result - use model's predictions as new draft
    new_draft = None
    new_past_key_values = None
    if failed_info:
        # Either verification failed or leaf node succeeded
        if verbose:
            if failed_info.get('is_leaf_success'):
                print("\n    All drafts verified successfully!")
                print(f"    Leaf node predicted additional token at pos {failed_info['failed_pos']}")
            else:
                print(f"\n    Verification stopped at pos {failed_info['failed_pos']}")
                print(f"    Using model's prediction (token={failed_info['predicted_token']}) "
                      f"instead of draft (token={failed_info['expected_token']})")

        # Re-draft from the logits of the node where traversal stopped; a
        # leading dimension is added before sampling.
        batched_logits = failed_info['new_logits'].unsqueeze(0)

        # Noisy scores decide the token; confidence comes from the clean logits.
        logits_with_noise = add_gumbel_noise(batched_logits, temperature=temperature)
        new_tokens = torch.argmax(logits_with_noise, dim=-1)

        probs = F.softmax(batched_logits, dim=-1)
        new_confidence = torch.gather(probs, dim=-1, index=new_tokens.unsqueeze(-1)).squeeze(-1)

        new_draft = {
            'tokens': new_tokens,
            'confidences': new_confidence
        }
        new_past_key_values = failed_info['new_past_key_values']

    elif not newly_confirmed and candidates:
        # No verification succeeded at all (shouldn't happen with correct tree)
        if verbose:
            print("    Warning: No verification succeeded, using fallback")

        # Fallback: accept the highest confidence position
        best_candidate = max(candidates, key=lambda c: c.confidence)
        newly_confirmed = [best_candidate.position]
        if verbose:
            print(f"    Fallback: confirmed highest confidence position {best_candidate.position}")

        # The caller looks up confirmed_tokens[pos] for every confirmed position,
        # so record the draft token for the position accepted here.
        if isinstance(confirmed_tokens, dict):
            confirmed_tokens.setdefault(best_candidate.position, best_candidate.token)

        # Keep the current draft, but in the documented dict form: the caller
        # reads new_draft['tokens'] / new_draft['confidences'], which fails on
        # a bare tensor.
        new_draft = {
            'tokens': all_draft_tokens,
            'confidences': all_confidences
        }

    return newly_confirmed, confirmed_tokens, new_draft, forward_passes, new_past_key_values


def traverse_and_verify_v3(
    tree, model, attention_mask, prompt_length, prompt_index,
    temperature, cfg_scale, mask_id, candidates, all_draft_tokens,
    block_length=None, verbose=False, past_key_values=None,
    replace_start_idx=None, replace_end_idx=None,
    current_start_idx=None, current_end_idx=None, draft_length=None,
    block_forward=None
):
    """
    Verify a draft by scoring every tree node at once, then walking the tree.

    The tree is flattened, all nodes are evaluated through
    ``verify_all_nodes_batch``, and ``find_longest_valid_path`` then follows
    the tree using those pre-computed results.

    When ``block_forward`` is truthy only the root is evaluated, and the
    root's ``children`` list is replaced by an empty list (the tree passed in
    is mutated).

    ``candidates`` is accepted for interface compatibility; it is not read here.

    Returns:
        longest_path: Confirmed positions along the accepted path
        confirmed_tokens: Dict mapping each newly confirmed position to its token
        failed_info: Dict describing where traversal stopped, or None
        forward_passes: Number of forward passes used
    """
    # Flatten the tree (root first) so every node can share one batch.
    all_nodes = []
    _collect_all_nodes(tree, all_nodes)

    if block_forward:
        # Keep the first collected node only and detach its subtree.
        first_node = all_nodes[0]
        first_node.children = []
        all_nodes = [first_node]

    if verbose:
        print(f"\n    Collected {len(all_nodes)} nodes for batch verification")

    # One batched evaluation; results are keyed by id(node).
    node_results, forward_passes = verify_all_nodes_batch(
        nodes=all_nodes,
        model=model,
        attention_mask=attention_mask,
        prompt_length=prompt_length,
        prompt_index=prompt_index,
        temperature=temperature,
        cfg_scale=cfg_scale,
        mask_id=mask_id,
        all_draft_tokens=all_draft_tokens,
        block_length=block_length,
        verbose=verbose,
        past_key_values=past_key_values,
        replace_start_idx=replace_start_idx,
        replace_end_idx=replace_end_idx,
        draft_length=draft_length,
        block_forward=block_forward,
    )

    # Walk the tree using the pre-computed per-node predictions.
    longest_path, confirmed_tokens, failed_info = find_longest_valid_path(
        tree=tree,
        node_results=node_results,
        all_draft_tokens=all_draft_tokens,
        prompt_length=prompt_length,
        block_length=block_length,
        verbose=verbose,
        replace_start_idx=replace_start_idx,
        replace_end_idx=replace_end_idx,
        current_start_idx=current_start_idx,
        current_end_idx=current_end_idx,
    )

    return longest_path, confirmed_tokens, failed_info, forward_passes


def _collect_all_nodes(node, all_nodes):
    """Append ``node`` and all of its descendants to ``all_nodes`` in pre-order."""
    pending = [node]
    while pending:
        current = pending.pop()
        all_nodes.append(current)
        # Push children reversed so the first child is popped (visited) first,
        # which reproduces the parent-then-children-left-to-right order.
        pending.extend(reversed(current.children))

def verify_all_nodes_batch(
    nodes, model, attention_mask, prompt_length, prompt_index,
    temperature, cfg_scale, mask_id, all_draft_tokens,
    block_length, verbose, past_key_values, replace_start_idx, replace_end_idx, draft_length, block_forward
):
    """
    Perform the forward pass for all nodes in a single batch.

    Each node contributes ``node.sequence_tensor`` as rows of the batch (the
    tensors are concatenated along dim 0; presumably each is ``(1, seq_len)``
    -- TODO confirm against the tree builder).

    Two modes:
      * ``block_forward`` truthy: the model is run on the full sequences and
        the logits are sliced to ``[replace_start_idx:replace_end_idx]``.
      * otherwise: only the ``[replace_start_idx:replace_end_idx]`` window is
        fed to the model together with a per-row copy of ``past_key_values``
        (via ``repeat_kv_cache``) and a boolean ``replace_position`` mask
        marking that window.

    Args:
        nodes: Tree nodes to evaluate. Each must expose ``sequence_tensor``,
            ``mask_positions``, ``verify_position`` and ``confirmed_positions``.
        model: Callable returning an object with ``logits`` and
            ``past_key_values``; ``model.device`` is read in the cached mode.
        temperature: Forwarded to ``add_gumbel_noise``.
        mask_id: Token id treated as "masked" in the verbose dump.
        block_length: Block size used only for the verbose dump.
        verbose: Print a per-node dump of the relevant blocks.
        past_key_values: Cache passed to the model in the cached mode.
        replace_start_idx, replace_end_idx: Window of positions being decoded.
        block_forward: Selects the mode described above.
        attention_mask, prompt_index, cfg_scale, all_draft_tokens, draft_length:
            Accepted for interface compatibility; not read by this function.

    Returns:
        node_results: Dict mapping ``id(node)`` to that node's predictions,
            logits, probabilities, per-row ``past_key_values`` and metadata.
        forward_passes: ``1`` when a forward pass ran, ``0`` for empty ``nodes``.
    """
    if not nodes:
        return {}, 0

    if verbose:
        print(f"\n    === Batch Verification Results (All Nodes) ===")
        print(f"    Processing {len(nodes)} nodes in single forward pass")
        print(f"    Note: These are pre-computed results for all nodes. Actual traversal may stop early.")

    batch_size = len(nodes)
    batch_x = torch.cat([node.sequence_tensor for node in nodes], dim=0)

    if verbose and block_length and block_length > 0:
        seq_len = batch_x.shape[1]

        def _print_block(row, node, block_id):
            # One line per block: "pos=token" (with a check mark when the
            # position is already confirmed) or "pos=MASK".
            block_start = prompt_length + block_id * block_length
            block_end = min(block_start + block_length, seq_len)
            filled_tokens = []
            for pos in range(block_start, block_end):
                token_val = batch_x[row, pos].item()
                if token_val == mask_id:
                    filled_tokens.append(f"{pos}=MASK")
                elif pos in node.confirmed_positions:
                    filled_tokens.append(f"{pos}={token_val}✓")
                else:
                    filled_tokens.append(f"{pos}={token_val}")
            print(f"        Block {block_id} [{block_start}-{block_end}): {', '.join(filled_tokens)}")

        print(f"\n    Input sequences for each node (showing relevant blocks):")
        for idx, node in enumerate(nodes):
            if node.verify_position is None:  # Root node
                print(f"      Node {idx} (root): Initial state")
                _print_block(idx, node, 0)
            else:
                # Explicit None test above: position 0 is a valid verify position.
                verify_block_id = (node.verify_position - prompt_length) // block_length

                # Show the current block, plus the previous one when the node
                # has confirmed generated positions located before this block.
                blocks_to_show = [verify_block_id]
                if verify_block_id > 0 and any(pos < prompt_length + verify_block_id * block_length
                                               for pos in node.confirmed_positions
                                               if pos >= prompt_length):
                    blocks_to_show.insert(0, verify_block_id - 1)

                print(f"      Node {idx} (verify pos {node.verify_position}):")
                for block_id in blocks_to_show:
                    _print_block(idx, node, block_id)

    # Single forward pass for all nodes
    if block_forward:
        output = model(batch_x, use_cache=True)
        batch_logits = output.logits[:, replace_start_idx:replace_end_idx]
    else:
        # Boolean mask over the full sequence marking the window being replaced.
        replace_position = torch.zeros(
            (1, nodes[-1].sequence_tensor.shape[1]),
            dtype=torch.bool,
            device=model.device,
        )
        replace_position[:, replace_start_idx:replace_end_idx] = True
        output = model(batch_x[:, replace_start_idx:replace_end_idx],
                       past_key_values=repeat_kv_cache(past_key_values, batch_size),
                       replace_position=replace_position,
                       use_cache=True)
        batch_logits = output.logits
    past_key_values = output.past_key_values

    # Sample from noised logits; confidences come from the un-noised softmax.
    batch_logits_with_noise = add_gumbel_noise(batch_logits, temperature=temperature)
    batch_predictions = torch.argmax(batch_logits_with_noise, dim=-1)
    batch_probs = F.softmax(batch_logits, dim=-1)

    # Store each node's slice of the batch, keyed by object identity.
    node_results = {}
    for i, node in enumerate(nodes):
        node_results[id(node)] = {
            'predictions': batch_predictions[i],
            'logits': batch_logits[i],
            'probs': batch_probs[i],
            # Row i of every (key, value) pair, kept with a leading batch dim of 1.
            'past_key_values': [(k[i].unsqueeze(0), v[i].unsqueeze(0)) for (k, v) in past_key_values],
            'verify_position': node.verify_position,
            'expected_token': getattr(node, 'expected_token', None),
            'mask_positions': node.mask_positions,
            'is_leaf': len(node.mask_positions) == 0,
            'is_root': node.verify_position is None
        }
    return node_results, 1  # Always 1 forward pass


def _select_best_masked_position(result, prompt_length, block_length,
                                 replace_start_idx, replace_end_idx):
    """Pick the masked position a node prefers: lowest block id, then highest confidence.

    Returns ``(position, token, confidence, block_id)``; ``position`` is
    ``None`` when no masked position lies below ``replace_end_idx``.
    """
    predictions = result['predictions']
    probs = result['probs']

    best_pos = None
    best_conf = -1
    best_token = None
    best_block_id = float('inf')

    for mask_pos in result['mask_positions']:
        # Stops at the first position outside the window; this assumes
        # mask_positions is in ascending order -- TODO confirm with the tree builder.
        if mask_pos >= replace_end_idx:
            break
        window_offset = mask_pos - replace_start_idx  # index into predictions/probs
        if block_length and block_length > 0:
            mask_block_id = (mask_pos - prompt_length) // block_length
        else:
            mask_block_id = 0

        pred_token = predictions[window_offset].item()
        conf = probs[window_offset, pred_token].item()

        # Lower block id wins; inside one block the higher confidence wins.
        if mask_block_id < best_block_id or (mask_block_id == best_block_id and conf > best_conf):
            best_pos = mask_pos
            best_conf = conf
            best_token = pred_token
            best_block_id = mask_block_id

    return best_pos, best_token, best_conf, best_block_id


def find_longest_valid_path(tree, node_results, all_draft_tokens, prompt_length, 
                            block_length, verbose, replace_start_idx, replace_end_idx,current_start_idx,current_end_idx):
    """
    Find the longest valid path in the tree using parent-centric logic.

    Starting at the root, each node's own prediction selects the masked
    position it wants next (see ``_select_best_masked_position``). If a child
    verifies exactly that position with exactly that token, the draft token is
    accepted and traversal moves to the child. Otherwise the node's own
    prediction is accepted and traversal stops. A node without children may
    contribute one extra predicted token.

    Args:
        tree: The root node
        node_results: Dict mapping ``id(node)`` to a result dict with at least
            ``predictions``, ``probs``, ``logits``, ``past_key_values``,
            ``mask_positions`` and ``is_root``
        prompt_length: Offset subtracted from a position before computing its block id
        block_length: Block length for priority calculation (falsy -> single block 0)
        verbose: Print a trace of the traversal
        replace_start_idx, replace_end_idx: Window covered by the predictions;
            positions >= ``replace_end_idx`` are not considered
        all_draft_tokens, current_start_idx, current_end_idx: Accepted for
            interface compatibility; not read by this function

    Returns:
        longest_confirmed: List of confirmed positions
        confirmed_token_values: Dict mapping position to token value
        failed_info: Info about where traversal stopped, or None. When set it
            contains ``node``, ``failed_pos``, ``predicted_token``,
            ``expected_token``, ``confidence``, ``new_predictions``,
            ``new_logits``, ``new_past_key_values``, ``draft_accepted_count``
            and, for a successful leaf prediction, ``is_leaf_success``.
    """
    longest_confirmed = tree.confirmed_positions.copy()
    confirmed_token_values = {}
    failed_info = None
    draft_accepted_count = 0

    if verbose:
        print(f"\n    === Actual Tree Traversal ===")
        print(f"    Starting from root with {len(tree.confirmed_positions)} already confirmed positions")

    current = tree

    while True:
        if id(current) not in node_results:
            if verbose:
                print("    Warning: Current node not in results, stopping")
            break

        current_result = node_results[id(current)]

        # Leaf node (no children): every draft on this path was accepted, so
        # the node's own prediction can supply one additional token.
        if not current.children:
            if verbose:
                print(f"\n    Reached leaf node (all drafts verified)")
                print(f"      Leaf node has {len(current_result['mask_positions'])} remaining mask positions")
                if current_result['mask_positions']:
                    print(f"      Remaining masks: {current_result['mask_positions'][:5]}")  # Show first 5

            if current_result['mask_positions']:
                best_pos, best_token, best_conf, best_block_id = _select_best_masked_position(
                    current_result, prompt_length, block_length,
                    replace_start_idx, replace_end_idx
                )

                if best_pos is not None:
                    longest_confirmed.append(best_pos)
                    confirmed_token_values[best_pos] = best_token
                    if verbose:
                        print(f"      → Leaf node predicts extra token: pos {best_pos} (block {best_block_id})")
                        print(f"        with token={best_token}, conf={best_conf:.4f}")
                        print(f"      ✓ Accepting leaf prediction")

                    failed_info = {
                        'node': current,
                        'failed_pos': best_pos,
                        'predicted_token': best_token,
                        # No draft token exists for a leaf's extra prediction.
                        'expected_token': None,
                        'confidence': best_conf,
                        'new_predictions': current_result['predictions'],
                        'new_logits': current_result['logits'],
                        'new_past_key_values': current_result['past_key_values'],
                        'is_leaf_success': True
                    }
            break

        if not current_result['mask_positions']:
            if verbose:
                print(f"\n    No more mask positions to process")
            break

        best_pos, best_token, best_conf, best_block_id = _select_best_masked_position(
            current_result, prompt_length, block_length,
            replace_start_idx, replace_end_idx
        )

        # Nothing selectable inside the window: stop instead of recording
        # ``None`` as a confirmed position.
        if best_pos is None:
            if verbose:
                print(f"\n    No mask position inside the current window, stopping")
            break

        if verbose:
            node_type = "root" if current_result['is_root'] else f"node"
            print(f"\n    Current {node_type} wants position {best_pos} (block {best_block_id}) "
                  f"with token={best_token}, conf={best_conf:.4f}")

        # Only the first child verifying the wanted position is considered.
        candidate_child = next(
            (child for child in current.children if child.verify_position == best_pos),
            None
        )
        matching_child = None
        if candidate_child is not None:
            if candidate_child.expected_token == best_token:
                matching_child = candidate_child
                if verbose:
                    print(f"      ✓ Found matching child: verify pos {best_pos} with expected token {candidate_child.expected_token}")
            elif verbose:
                print(f"      ✗ Child at pos {best_pos} has wrong token: "
                      f"expected {candidate_child.expected_token}, parent wants {best_token}")

        if matching_child is not None:
            # Draft token confirmed by the parent's own prediction: descend.
            if best_pos not in longest_confirmed:
                longest_confirmed.append(best_pos)
                confirmed_token_values[best_pos] = matching_child.expected_token
                draft_accepted_count += 1
                if verbose:
                    print(f"      → Accepting draft token at pos {best_pos}")
                    print(f"      → Moving to child node...")
            current = matching_child
        else:
            # No matching child - accept current node's choice and stop
            if best_pos not in longest_confirmed:
                longest_confirmed.append(best_pos)
                confirmed_token_values[best_pos] = best_token
                if verbose:
                    print(f"      → No matching child for pos {best_pos}")
                    print(f"      ✓ Accepting model's choice: pos {best_pos} with token {best_token}")
                    print(f"      → Stopping traversal")

                failed_info = {
                    'node': current,
                    'failed_pos': best_pos,
                    'predicted_token': best_token,
                    # Draft token proposed for this position by the tree, or
                    # None when no child verified this position at all.
                    'expected_token': (candidate_child.expected_token
                                       if candidate_child is not None else None),
                    'confidence': best_conf,
                    'new_predictions': current_result['predictions'],
                    'new_past_key_values': current_result['past_key_values'],
                    'new_logits': current_result['logits']
                }
            break

    if verbose:
        print(f"\n    === Traversal Summary ===")
        print(f"    Draft tokens accepted (as expected): {draft_accepted_count}")
        total_new = len(longest_confirmed) - len(tree.confirmed_positions)
        model_override = total_new - draft_accepted_count
        print(f"    Model override tokens: {model_override}")
        print(f"    Total new tokens confirmed: {total_new}")
        print(f"    Total positions confirmed: {len(longest_confirmed)}")
        if failed_info:
            print(f"    Stopped at position: {failed_info['failed_pos']}")
            print(f"    Reason: Model predicted different token than draft")

    # Include draft_accepted_count in failed_info for caller's use
    if failed_info:
        failed_info['draft_accepted_count'] = draft_accepted_count

    return longest_confirmed, confirmed_token_values, failed_info

def _decode_remaining_standard_v3(
    x, model, attention_mask, mask_positions,
    prompt_length, prompt_index, temperature, cfg_scale, mask_id,
    all_draft_tokens, all_confidences, block_length, confirmed_positions,
    verbose, past_key_values, current_start_idx, current_end_idx
):
    """
    Fill the remaining masked positions one token per forward pass.

    Every iteration runs the model on the window
    ``x[:, current_start_idx:current_end_idx + 1]`` with the given cache,
    picks the masked position in the lowest-numbered block (highest confidence
    breaking ties inside a block), writes the predicted token into ``x`` and
    moves the position from ``mask_positions`` to ``confirmed_positions``.

    Args:
        x: Token tensor, updated in place
        mask_positions: Remaining masked positions (emptied in place)
        confirmed_positions: Receives each accepted position (appended in place)
        past_key_values: Cache handed unchanged to every model call

    Returns:
        forward_passes: Number of forward passes used
    """
    window = slice(current_start_idx, current_end_idx + 1)

    # Boolean mask over the full sequence marking the window being decoded.
    replace_position = torch.zeros_like(x, dtype=torch.bool).to(x.device)
    replace_position[:, window] = 1

    blocks_enabled = bool(block_length) and block_length > 0
    forward_passes = 0

    while mask_positions:
        logits = model(
            x[:, window],
            past_key_values=past_key_values,
            replace_position=replace_position,
            use_cache=True,
        ).logits
        forward_passes += 1

        # Tokens are sampled from noised logits; confidence uses the plain softmax.
        predictions = torch.argmax(
            add_gumbel_noise(logits, temperature=temperature), dim=-1
        )
        confidence = torch.gather(
            F.softmax(logits, dim=-1), dim=-1, index=predictions.unsqueeze(-1)
        ).squeeze(-1)

        # chosen = (position, token, confidence, block id) of the current winner
        chosen = None
        for candidate_pos in mask_positions:
            candidate_block = (
                (candidate_pos - prompt_length) // block_length if blocks_enabled else 0
            )
            offset = candidate_pos - current_start_idx
            candidate_token = predictions[0, offset].item()
            candidate_conf = confidence[0, offset].item()

            # Earlier block wins; within the same block, strictly higher confidence wins.
            if (
                chosen is None
                or candidate_block < chosen[3]
                or (candidate_block == chosen[3] and candidate_conf > chosen[2])
            ):
                chosen = (candidate_pos, candidate_token, candidate_conf, candidate_block)

        if chosen is None:
            # Should not happen
            if verbose:
                print(f"    Warning: Could not find best position, breaking")
            break

        best_pos, best_token, best_conf, best_block_id = chosen

        # Commit the token and update the caller's bookkeeping lists.
        x[:, best_pos] = best_token
        mask_positions.remove(best_pos)
        confirmed_positions.append(best_pos)

        if verbose:
            print(f"    Standard decode: Accepted pos {best_pos} (block {best_block_id}) "
                  f"with token={best_token}, conf={best_conf:.4f}")
            print(f"    Remaining masks: {len(mask_positions)}")
            print(f"    Forward passes in standard decode: {forward_passes}")

    if verbose:
        print(f"    Standard decode completed with {forward_passes} forward passes")

    return forward_passes


def ssd_with_cache(
    input_ids,
    attention_mask,
    model,
    gen_length=128,
    block_length=None,
    temperature=0.0,
    cfg_scale=0.0,
    mask_id=126336,
    draft_length=4,
    tree_strategy='greedy',
    verbose=False,
    refresh_interval=100,
    **kwargs
):
    """
    Public entry point for SCSD v3 generation.

    Forwards every argument, unchanged and by keyword, to
    ``_generate_scsd_v3`` and returns whatever that function returns.

    Args:
        draft_length: Fixed draft length (default 4)
        tree_strategy: Tree building strategy ('greedy', etc.)
        block_length: Block length for block ID calculation (optional)
        **kwargs: Extra keyword arguments passed through as-is
    """
    generation_args = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'model': model,
        'gen_length': gen_length,
        'block_length': block_length,
        'temperature': temperature,
        'cfg_scale': cfg_scale,
        'mask_id': mask_id,
        'draft_length': draft_length,
        'tree_strategy': tree_strategy,
        'verbose': verbose,
        'refresh_interval': refresh_interval,
    }
    return _generate_scsd_v3(**generation_args, **kwargs)
