"""
Two classes for Transformer Diagonal Prediction Maps with efficient inference.
The first is the `AmortizedConditioningEngine` (this should be renamed). This 
class is the main class that we should use for training.

The second class is the `InferenceEngine` which we should use for inference 
via the `.sample_sequence` method. TODO: this doesn't currently compile.
"""

from typing import Optional, Tuple

import torch
from torch.nn.attention.flex_attention import BlockMask, flex_attention

from src.models.masks import create_context_self_attention_block_mask
from src.models.utils import expand_kv_heads
from src.utils import (
    DataAttr,
    LossAttr,
    concatenate_batches,
    create_context_buffer_datapoint,
    fetch_next_query,
    fetch_next_query_batch,
)


# Pre-compile flex_attention for GPU execution only
# CPU doesn't support compiled flex_attention with score_mod
if torch.cuda.is_available() or hasattr(torch.version, 'hip'):
    flex_attention = torch.compile(flex_attention, fullgraph=True)


class AmortizedConditioningEngine(torch.nn.Module):
    """Amortized Conditioning Engine for efficient conditional neural processes.
    
    Three-section transformer architecture:
    1. Context: Observed data (all tokens attend to this)
    2. Buffer: Autoregressive tokens for efficient generation  
    3. Targets: Prediction points
    
    The attention mask allows targets to attend to context and a limited buffer window,
    enabling efficient autoregressive generation during inference.
    """

    # Configuration constants
    DEFAULT_MAX_BUFFER_SIZE: int = 32
    DEFAULT_NUM_TARGET_POINTS: int = 500
    DEFAULT_TARGETS_BLOCK_SIZE: int = 5

    def __init__(
        self,
        embedder: torch.nn.Module,
        backbone: torch.nn.Module,
        head: torch.nn.Module,
        max_buffer_size: int = DEFAULT_MAX_BUFFER_SIZE,
        num_target_points: int = DEFAULT_NUM_TARGET_POINTS,
        targets_block_size_for_buffer_attend: int = DEFAULT_TARGETS_BLOCK_SIZE,
        ar_token_init_std: float = 0.02,
        seed: Optional[int] = None,
    ) -> None:
        """Initialize ACE model.
        
        Args:
            embedder: Embedding module for inputs
            backbone: Transformer backbone
            head: Output distribution head
            max_buffer_size: Buffer size for autoregressive generation
            num_target_points: Expected number of targets during training
            targets_block_size_for_buffer_attend: Number of target chunks attending to buffer
            ar_token_init_std: Std dev for AR token initialization
            seed: Random seed for reproducibility
        """
        super().__init__()
        
        self.embedder = embedder
        self.backbone = backbone
        self.head = head
        self.max_buffer_size = max_buffer_size
        self.num_target_points = num_target_points
        self.targets_block_size_for_buffer_attend = targets_block_size_for_buffer_attend

        if seed is not None:
            torch.manual_seed(seed)
        
        self.ar_token = torch.nn.Parameter(
            torch.randn(self.max_buffer_size, self.backbone.dim_model) * ar_token_init_std
        )
        self.training_mask_cache: Optional[BlockMask] = None
    
    def validate_batch(self, batch: DataAttr) -> None:
        """Validate batch dimensions match model configuration.
        
        Call this before training, not in forward pass to avoid compilation issues.
        """
        assert batch.xb.shape[1] == self.max_buffer_size, (
            f"Buffer size {batch.xb.shape[1]} != {self.max_buffer_size}"
        )
        assert batch.xt.shape[1] <= self.num_target_points, (
            f"Targets {batch.xt.shape[1]} > {self.num_target_points}"
        )

    def embed_draft(self, buffer: DataAttr) -> torch.Tensor:
        """Add AR position tokens to buffer embeddings."""
        buffer_embeddings = self.embedder.embed_buffer(buffer)
        return buffer_embeddings + self.ar_token.unsqueeze(0)

    def forward(
        self,
        batch: DataAttr,
        block_mask: BlockMask,
    ) -> LossAttr:
        """Forward pass for training the ACE model.
        
        Args:
            batch: DataAttr with context (xc,yc), buffer (xb,yb), and targets (xt,yt)
            block_mask: Pre-computed attention mask

        Returns:
            LossAttr containing loss and predictions from the head
        """
        context = self.embedder.embed_context(batch)
        draft = self.embed_draft(batch) 
        target = self.embedder.embed_target(batch)
        
        embeddings = torch.cat([context, draft, target], dim=1)
        z, _ = self.backbone(embeddings, block_mask=block_mask)
        
        C, T = context.shape[1] + draft.shape[1], target.shape[1]
        _, zt = torch.split(z, [C, T], dim=1)
        
        return self.head(zt, batch.yt)


class InferenceEngine(torch.nn.Module):
    """Inference engine for autoregressive decoding with the ACE model.

    This class provides functionality for sampling from mode autoregressively,
    managing KV cache for efficient transformer decoding, and handling the
    interaction between context embeddings and new predictions.
    """

    # Default block sizes for self-attention masks
    DEFAULT_Q_BLOCK_SIZE: int = 32
    DEFAULT_KV_BLOCK_SIZE: int = 32

    def __init__(
        self,
        embedder: torch.nn.Module,
        backbone: torch.nn.Module,
        head: torch.nn.Module,
        ar_tokens: torch.nn.Parameter,
        max_buffer_size: int,
        q_block_size: int = DEFAULT_Q_BLOCK_SIZE,
        kv_block_size: int = DEFAULT_KV_BLOCK_SIZE,
    ):
        """Initialize the inference engine.

        Args:
            embedder: Model embedder for encoding context/buffer/target data
            backbone: Transformer backbone for processing embeddings
            head: Output head for generating predictions
            ar_tokens: Autoregressive position tokens from the trained model
            max_buffer_size: Maximum number of autoregressive buffer positions
            q_block_size: Block size for queries in flex attention
            kv_block_size: Block size for keys/values in flex attention
        """
        super().__init__()
        self.embedder = embedder
        self.backbone = backbone
        self.head = head
        self.ar_tokens = torch.nn.Parameter(ar_tokens.clone())
        self.max_buffer_size = max_buffer_size
        self.q_block_size = q_block_size
        self.kv_block_size = kv_block_size

        # Cache for score_mod functions at different offsets to avoid recompilation
        self.score_mod_cache = {}
        # Cache for self-attention masks at different context lengths
        self.selfattn_mask_cache = {}

    @classmethod
    def from_trained_model(
        cls,
        model: AmortizedConditioningEngine,
        q_block_size: int = None,
        kv_block_size: int = None,
    ) -> "InferenceEngine":
        """Create an inference engine from a trained ACE model.

        Args:
            model: Trained AmortizedConditioningEngine instance
            q_block_size: Override for query block size (defaults to model's setting)
            kv_block_size: Override for key/value block size (defaults to model's setting)

        Returns:
            InferenceEngine configured with the trained model's components
        """
        return cls(
            embedder=model.embedder,
            backbone=model.backbone,
            head=model.head,
            ar_tokens=model.ar_token,
            max_buffer_size=model.max_buffer_size,
            q_block_size=q_block_size or cls.DEFAULT_Q_BLOCK_SIZE,
            kv_block_size=kv_block_size or cls.DEFAULT_KV_BLOCK_SIZE,
        )

    def sample_sequence(self, batch: DataAttr, K: int = 4) -> DataAttr:
        """Sample a sequence of predictions autoregressively.

        This method generates predictions for all target points in batches of size K.
        Each batch is decoded autoregressively, and predictions from previous batches
        are added to the context for subsequent batches.

        Args:
            batch: Input data containing context (xc, yc) and target positions (xt)
            K: Batch size for autoregressive decoding (number of predictions per batch)

        Returns:
            DataAttr containing predicted values (yc) at target positions (xc)
        """
        T = batch.xt.shape[1]
        num_batches = T // K
        batch_size = batch.xt.shape[0]

        # Store initial context embeddings
        self.store_context_embeddings(batch)

        # Initialize KV cache for efficient transformer decoding
        max_seq = self.context_embeddings.shape[1] + T
        self.init_kv_cache(batch_size, max_seq, device=batch.xt.device)

        # Pre-allocate output tensors
        predicted_positions = torch.zeros(batch_size, T, batch.xt.shape[2], 
                                        device=batch.xt.device, dtype=batch.xt.dtype)
        predicted_values = torch.zeros(batch_size, T, 1,  # Assuming 1D output
                                     device=batch.xt.device, dtype=batch.xt.dtype)
        
        for batch_idx in range(num_batches):
            start_idx = batch_idx * K
            end_idx = min(start_idx + K, T)
            batch_K = end_idx - start_idx

            # Prefill KV cache with current context (grows as we add predictions)
            context_len = self.context_embeddings.shape[1]
            selfattention_mask = self._get_cached_selfattn_mask(
                context_len, batch.xt.device
            )
            self.prefill_kv_cache(selfattention_mask)

            # Get the next batch of K target points
            query_batch = fetch_next_query_batch(batch, start_idx, batch_K)

            # Decode the batch autoregressively
            batch_predictions = self.batch_decode(query_batch, batch_K)

            # Place predictions directly into pre-allocated tensors
            predicted_positions[:, start_idx:end_idx] = batch_predictions.xc
            predicted_values[:, start_idx:end_idx] = batch_predictions.yc

        # Return as DataAttr
        return DataAttr(xc=predicted_positions, yc=predicted_values)

    def batch_decode(self, batch: DataAttr, K: int = 4) -> DataAttr:
        """Decode K target points autoregressively without intermittent self-attention.

        Each prediction depends on all previous predictions within this batch.
        After completing the batch, the last prediction is added to the context
        for subsequent batches.

        Args:
            batch: Batch containing K target points to decode
            K: Number of points to decode in this batch
            
        Returns:
            DataAttr containing K predictions
        """
        batch_predictions = None
        previous_prediction = None

        for k in range(K):
            # Get k-th target point from the batch
            query = fetch_next_query(batch, k)

            if k == 0:
                # First prediction in batch: no previous prediction
                prediction = self.autoregressive_decode(query)
                batch_predictions = prediction
            else:
                # Subsequent predictions: condition on previous prediction
                prediction = self.autoregressive_decode(query, previous_prediction)
                batch_predictions = concatenate_batches(batch_predictions, prediction)

            previous_prediction = prediction

        # Update context with the last prediction for next batch
        self.update_context_embeddings(prediction)
        return batch_predictions

    def autoregressive_decode(
        self, query: DataAttr, previous_prediction: DataAttr = None
    ) -> DataAttr:
        """Generate a single prediction autoregressively.

        Args:
            query: Single target point to predict (contains xt position)
            previous_prediction: Previous prediction in the current batch (if any)

        Returns:
            DataAttr containing the prediction formatted for both context and buffer
        """
        # Embed the target query
        query_embedding = self.embedder.embed_target(query)

        if previous_prediction is not None:
            # Get autoregressive embedding for previous prediction
            # The AR token index corresponds to position within current batch
            num_previous = previous_prediction.xc.shape[1]
            previous_prediction_embedding = self.autoregressive_embedder_with_idx(
                previous_prediction, num_previous - 1
            )
            # Update context for transformer to attend to
            self.update_context_embeddings(previous_prediction)
            # Combine previous prediction embedding with current query
            embedding = torch.cat(
                [previous_prediction_embedding, query_embedding], dim=1
            )
        else:
            # First prediction in batch: just use query embedding
            embedding = query_embedding.clone()

        # Decode through transformer (uses KV cache for efficiency)
        z = self.transformer_decode(embedding)

        # Generate prediction from final position
        samples = self.head.sample(z[:, -1, :].unsqueeze(1), num_samples=1)
        yhat = samples.squeeze(2)

        # Format prediction for dual use as context and buffer
        prediction = create_context_buffer_datapoint(query, yhat)
        return prediction

    def store_context_embeddings(self, context: DataAttr) -> None:
        """Store initial context embeddings from the input batch."""
        self.context_embeddings = self.embedder.embed_context(context)

    def update_context_embeddings(self, new_context: DataAttr) -> None:
        """Append new predictions to the context embeddings."""
        new_embeddings = self.embedder.embed_context(new_context)
        self.context_embeddings = torch.cat(
            [self.context_embeddings, new_embeddings], dim=1
        )

    def init_kv_cache(
        self,
        B: int,
        max_seq: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Initialize key-value cache for efficient transformer decoding.

        Args:
            B: Batch size
            max_seq: Maximum sequence length to allocate
            device: Device to allocate cache on (defaults to cuda if available)
            dtype: Data type for cache (defaults to model's dtype)
        """
        # Get device and dtype from model if not provided
        if device is None or dtype is None:
            # Get from model parameters
            param = next(self.backbone.parameters())
            if device is None:
                device = param.device
            if dtype is None:
                dtype = param.dtype

        # Allocate cache for each transformer layer
        for lyr in self.backbone.layers:
            H, Dh = lyr.attn.num_heads, lyr.attn.head_dim
            lyr.k_cache = torch.zeros(B, H, max_seq, Dh, dtype=dtype, device=device)
            lyr.v_cache = torch.zeros_like(lyr.k_cache)

        # Track current sequence length and offset for causal masking
        self.backbone.seq_len = torch.zeros(B, dtype=torch.int32, device=device)
        self.backbone.offset = torch.tensor(0, dtype=torch.int64, device=device)

    @torch.no_grad()
    def prefill_kv_cache(self, mask: BlockMask) -> None:
        """Prefill the kv cache once with regular forward using fully dense (self-attention) mask"""
        _, kv_pairs = self.backbone(self.context_embeddings, mask)  # forward
        L0 = self.context_embeddings.shape[1]

        for (k, v), lyr in zip(kv_pairs, self.backbone.layers):  # k, v = [B, H, L0, Dh]
            lyr.k_cache[:, :, :L0, :] = k  # same layout
            lyr.v_cache[:, :, :L0, :] = v

        self.backbone.seq_len = torch.tensor(L0, device=self.backbone.seq_len.device)
        # Update offset in-place to maintain tensor reference
        self.backbone.offset.fill_(L0)

    def transformer_decode(
        self, new_query_embedding: torch.Tensor, keep_last: bool = False
    ) -> torch.Tensor:
        """Incremental decode function for transformer"""
        B, N, d_model = new_query_embedding.shape
        commit = N if keep_last else max(N - 1, 0)
        H = self.backbone.layers[0].attn.num_heads

        # Get cached score_mod function for current offset
        current_offset_value = int(self.backbone.offset.item())
        causal_with_offset = self._get_causal_score_mod(current_offset_value)

        for lyr in self.backbone.layers:
            Dh = lyr.attn.head_dim
            q = lyr.norm1(new_query_embedding)

            qh = lyr.attn.q_proj(q).view(B, N, H, Dh).transpose(1, 2).contiguous()  # [B,H,N,Dh]
            kh = (
                lyr.attn.k_proj(q).view(B, N, lyr.attn.num_kv_heads, Dh).transpose(1, 2).contiguous()
            )
            vh = (
                lyr.attn.v_proj(q).view(B, N, lyr.attn.num_kv_heads, Dh).transpose(1, 2).contiguous()
            )
            kh = expand_kv_heads(kh, H // lyr.attn.num_kv_heads)  # [B,H,N,Dh]
            vh = expand_kv_heads(vh, H // lyr.attn.num_kv_heads)

            past = int(self.backbone.seq_len.max().item())
            k_full = torch.cat([lyr.k_cache[:, :, :past, :], kh], dim=2)  # [B,H,T,Dh]
            v_full = torch.cat([lyr.v_cache[:, :, :past, :], vh], dim=2)

            attn = flex_attention(
                qh, k_full, v_full, score_mod=causal_with_offset
            )                                                   # [B,H,N,Dh]
            attn = attn.transpose(1, 2).reshape(B, N, d_model)  # back to [B,N,d_model]

            new_query_embedding = new_query_embedding + lyr.drop_attn(
                lyr.attn.o_proj(attn)
            )
            y = lyr.norm2(new_query_embedding)
            new_query_embedding = new_query_embedding + lyr.ff2(
                lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y)))
            )

            if commit:
                idx = self.backbone.seq_len.unsqueeze(0).unsqueeze(0) + torch.arange(
                    commit, device=new_query_embedding.device
                )

                idx_expanded_b_h_commit = idx.expand(B, H, commit)  # [B, H, commit]
                idx_unsqueezed_for_dh = idx_expanded_b_h_commit.unsqueeze(
                    -1
                )  # [B, H, commit, 1]
                indices_for_scatter = idx_unsqueezed_for_dh.expand(
                    B, H, commit, Dh
                )  # [B, H, commit, Dh]

                lyr.k_cache.scatter_(2, indices_for_scatter, kh[:, :, :commit, :])
                lyr.v_cache.scatter_(2, indices_for_scatter, vh[:, :, :commit, :])

        if commit:
            self.backbone.seq_len += commit
            self.backbone.offset += commit

        return self.backbone.norm(new_query_embedding)

    def _get_causal_score_mod(self, offset: int):
        """Get or create a cached score_mod function for the given offset."""
        if offset not in self.score_mod_cache:
            # Create a score_mod function for this specific offset value
            def make_causal_with_offset(offset_value):
                def causal_score_mod(score, b, h, q_idx, kv_idx):
                    return torch.where(
                        q_idx + offset_value >= kv_idx, score, -float("inf")
                    )

                return causal_score_mod

            self.score_mod_cache[offset] = make_causal_with_offset(offset)

        return self.score_mod_cache[offset]

    def _device_to_cache_key(self, device: torch.device) -> str:
        """Convert device to a consistent string for cache keys."""
        return str(device)

    def _get_cached_selfattn_mask(self, context_len: int, device: torch.device) -> BlockMask:
        """Get or create a cached self-attention mask for the given context length."""
        device_key = self._device_to_cache_key(device)
        cache_key = (context_len, device_key)

        if cache_key not in self.selfattn_mask_cache:
            # Create and cache the mask
            self.selfattn_mask_cache[cache_key] = (
                create_context_self_attention_block_mask(
                    current_num_context=context_len,
                    q_block_size=self.q_block_size,
                    kv_block_size=self.kv_block_size,
                    device=device,
                )
            )

        return self.selfattn_mask_cache[cache_key]

    def autoregressive_embedder_with_idx(self, x: DataAttr, idx: int) -> torch.Tensor:
        """Get autoregressive embedding for a datapoint at a specific index."""
        buffer_embeddings = self.embedder.embed_buffer(x)
        B, _, _ = buffer_embeddings.shape
        return buffer_embeddings + self.ar_tokens[idx].unsqueeze(0).expand(B, -1, -1)

    def prepare_inference_caches(self, batch: DataAttr, K: int) -> None:
        """Precompute all masks and score_mod functions needed for inference.

        This method analyzes the batch and precomputes all self-attention masks
        and score_mod functions that will be needed during sample_sequence.
        Call this before sample_sequence for optimal performance.

        Args:
            batch: Input batch to analyze for cache preparation
            K: Batch size for autoregressive decoding
        """
        # Get dimensions from batch
        initial_context_len = batch.xc.shape[1] if batch.xc is not None else 0
        T = batch.xt.shape[1]
        device = batch.xt.device

        # Calculate how many batches we'll process
        num_batches = T // K

        # Precompute self-attention masks for each batch iteration
        # At the start of each batch, context has grown by previous batch sizes
        current_context_len = initial_context_len
        for batch_idx in range(num_batches):
            # Determine batch size (might be less than K for last batch)
            start_idx = batch_idx * K
            end_idx = min(start_idx + K, T)
            batch_K = end_idx - start_idx

            # Cache mask for the context length at start of this batch
            self._get_cached_selfattn_mask(current_context_len, device)

            # After this batch, context will have grown by batch_K
            current_context_len += batch_K

        # Precompute score_mod functions for all offsets we'll encounter
        # Offsets range from initial_context_len to initial_context_len + T
        for offset in range(initial_context_len, initial_context_len + T + 1):
            self._get_causal_score_mod(offset)


class InferenceEngine2(torch.nn.Module):
    """Fast inference engine for ACE models with KV caching and torch.compile support."""
    
    DEFAULT_Q_BLOCK_SIZE: int = 32
    DEFAULT_KV_BLOCK_SIZE: int = 32
    
    def __init__(
        self,
        embedder: torch.nn.Module,
        backbone: torch.nn.Module,
        head: torch.nn.Module,
        ar_tokens: torch.nn.Parameter,
        max_buffer_size: int,
        q_block_size: int = DEFAULT_Q_BLOCK_SIZE,
        kv_block_size: int = DEFAULT_KV_BLOCK_SIZE,
    ):
        super().__init__()
        self.embedder = embedder
        self.backbone = backbone
        self.head = head
        self.ar_tokens = torch.nn.Parameter(ar_tokens.clone())
        self.max_buffer_size = max_buffer_size
        self.q_block_size = q_block_size
        self.kv_block_size = kv_block_size
        self.selfattn_mask_cache = {}
    
    @classmethod
    def from_trained_model(
        cls,
        model: AmortizedConditioningEngine,
        q_block_size: int = None,
        kv_block_size: int = None,
    ) -> "InferenceEngine2":
        """Create an inference engine from a trained ACE model."""
        engine = cls(
            embedder=model.embedder,
            backbone=model.backbone,
            head=model.head,
            ar_tokens=model.ar_token,
            max_buffer_size=model.max_buffer_size,
            q_block_size=q_block_size or cls.DEFAULT_Q_BLOCK_SIZE,
            kv_block_size=kv_block_size or cls.DEFAULT_KV_BLOCK_SIZE,
        )
        
        return engine
    
    def sample_sequence(self, batch: DataAttr, K: int = 4) -> DataAttr:
        """Sample predictions autoregressively in batches of K."""
        T = batch.xt.shape[1]
        num_batches = T // K
        batch_size = batch.xt.shape[0]
        
        self.store_context_embeddings(batch)
        
        max_seq = self.context_embeddings.shape[1] + T
        self.init_kv_cache(batch_size, max_seq, device=batch.xt.device)
        self.offset = torch.zeros([], dtype=torch.int64, device=batch.xt.device)  # Scalar tensor
        
        predicted_positions = torch.zeros(batch_size, T, batch.xt.shape[2], 
                                        device=batch.xt.device, dtype=batch.xt.dtype)
        predicted_values = torch.zeros(batch_size, T, 1,
                                     device=batch.xt.device, dtype=batch.xt.dtype)
        
        for batch_idx in range(num_batches):
            start_idx = batch_idx * K
            end_idx = min(start_idx + K, T)
            batch_K = end_idx - start_idx
            
            context_len = self.context_embeddings.shape[1]
            selfattention_mask = self._get_cached_selfattn_mask(context_len, batch.xt.device)
            self.prefill_kv_cache(selfattention_mask)
            
            query_batch = fetch_next_query_batch(batch, start_idx, batch_K)
            batch_predictions = self.batch_decode(query_batch, batch_K)
            
            predicted_positions[:, start_idx:end_idx] = batch_predictions.xc
            predicted_values[:, start_idx:end_idx] = batch_predictions.yc
        
        return DataAttr(xc=predicted_positions, yc=predicted_values)

    def evaluate_joint_loglikelihood(self, batch: DataAttr, K: int = 8) -> Tuple[DataAttr, torch.Tensor]:
        """Evaluate log-likelihood of true targets while generating predictions.
        
        Uses generated predictions as context (not true values) for realistic evaluation.
        """
        T = batch.xt.shape[1]
        batch_size = batch.xt.shape[0]
        num_batches = T // K

        self.store_context_embeddings(batch)
        
        max_seq = self.context_embeddings.shape[1] + T
        self.init_kv_cache(batch_size, max_seq, device=batch.xt.device)
        self.offset = torch.zeros([], dtype=torch.int64, device=batch.xt.device)  # Scalar tensor

        predicted_positions = torch.zeros(batch_size, T, batch.xt.shape[2], 
                                        device=batch.xt.device, dtype=batch.xt.dtype)
        predicted_values = torch.zeros(batch_size, T, 1,
                                     device=batch.xt.device, dtype=batch.xt.dtype)
        log_likelihoods = torch.zeros(batch_size, T, 1,
                                     device=batch.xt.device, dtype=batch.xt.dtype)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * K
            end_idx = min(start_idx + K, T)
            batch_K = end_idx - start_idx

            context_len = self.context_embeddings.shape[1]
            selfattention_mask = self._get_cached_selfattn_mask(context_len, batch.xt.device)
            self.prefill_kv_cache(selfattention_mask)

            query_batch = fetch_next_query_batch(batch, start_idx, batch_K)
            batch_predictions, batch_log_liks = self.batch_decode_with_evaluation(query_batch, batch_K)

            predicted_positions[:, start_idx:end_idx] = batch_predictions.xc
            predicted_values[:, start_idx:end_idx] = batch_predictions.yc
            log_likelihoods[:, start_idx:end_idx] = batch_log_liks

        return DataAttr(xc=predicted_positions, yc=predicted_values), log_likelihoods
    
    def batch_decode(self, batch: DataAttr, K: int = 4) -> DataAttr:
        """Decode K target points one after another, chaining predictions.

        The first point is decoded on its own; every later point is decoded
        with the immediately preceding prediction passed along. The final
        prediction is appended to the context embeddings before returning.

        Args:
            batch: Batch holding the K target points to decode
            K: Number of points to decode

        Returns:
            The K predictions, concatenated with `concatenate_batches`
        """
        decoded = None

        for step in range(K):
            point = fetch_next_query(batch, step)

            if step == 0:
                # Nothing decoded yet in this batch.
                latest = self.autoregressive_decode(point)
                decoded = latest
            else:
                # `latest` still holds the prediction from the previous step.
                latest = self.autoregressive_decode(point, latest)
                decoded = concatenate_batches(decoded, latest)

        # The last prediction joins the context here.
        self.update_context_embeddings(latest)
        return decoded
    
    def batch_decode_with_evaluation(self, batch: DataAttr, K: int = 4) -> Tuple[DataAttr, torch.Tensor]:
        """Decode K target points while scoring the true targets.

        Works like `batch_decode`, but each step also yields the
        log-likelihood of the true target value. Conditioning always uses the
        previously *generated* prediction, never the true value.

        Args:
            batch: DataAttr containing K target points with both positions (xt)
                and true values (yt)
            K: Number of points to decode in this batch

        Returns:
            Tuple of (DataAttr with the K generated predictions, per-step
            log-likelihoods concatenated along dim 1 — presumably
            [B, K, dim_y]; confirm against the head's output)
        """
        decoded = None
        log_lik_chunks = []

        for step in range(K):
            # k-th target point of the batch (carries the true yt)
            point = fetch_next_query(batch, step)

            if step == 0:
                # Nothing decoded yet in this batch.
                latest, point_log_lik = self.autoregressive_decode_with_evaluation(point)
                decoded = latest
            else:
                # `latest` still holds the prediction generated in the previous step.
                latest, point_log_lik = self.autoregressive_decode_with_evaluation(
                    point, latest
                )
                decoded = concatenate_batches(decoded, latest)

            log_lik_chunks.append(point_log_lik)

        # The last prediction joins the context here.
        self.update_context_embeddings(latest)

        # Join the per-step log-likelihoods along the sequence dimension.
        return decoded, torch.cat(log_lik_chunks, dim=1)

    def autoregressive_decode_with_evaluation(
        self, query: DataAttr, previous_prediction: DataAttr = None
    ) -> Tuple[DataAttr, torch.Tensor]:
        """Decode one target point and score its true value.

        Args:
            query: Single target point; `query.yt` may be None
            previous_prediction: Prediction generated for the preceding point
                of the current batch, if any

        Returns:
            Tuple of (prediction built by `create_context_buffer_datapoint`
            from a sampled value, log-likelihood of `query.yt` — or a
            [B, 1, 1] zero tensor when `query.yt` is None)
        """
        tokens = self.embedder.embed_target(query)

        if previous_prediction is not None:
            # AR token index is derived from the length of the previous prediction.
            ar_index = previous_prediction.xc.shape[1] - 1
            previous_tokens = self.autoregressive_embedder_with_idx(
                previous_prediction, ar_index
            )
            # The previous prediction also becomes part of the stored context.
            self.update_context_embeddings(previous_prediction)
            tokens = torch.cat([previous_tokens, tokens], dim=1)

        hidden = self.transformer_decode(tokens)
        # Only the final position (the query) is passed to the head.
        last_hidden = hidden[:, -1, :].unsqueeze(1)

        # Score the true target when one is supplied.
        if query.yt is None:
            log_likelihood = torch.zeros(
                hidden.shape[0], 1, 1, device=hidden.device, dtype=hidden.dtype
            )
        else:
            log_likelihood = self.head(last_hidden, query.yt).log_likelihood

        # Sample the value that later steps condition on.
        drawn = self.head.sample(last_hidden, num_samples=1)
        prediction = create_context_buffer_datapoint(query, drawn.squeeze(2))

        return prediction, log_likelihood

        
    def autoregressive_decode(
        self, query: DataAttr, previous_prediction: DataAttr = None
    ) -> DataAttr:
        """Sample one prediction for ``query``, optionally conditioned on the last one.

        Args:
            query: Target point to decode.
            previous_prediction: Prediction produced by the previous step, or
                None on the first step. When given, it is embedded as a buffer
                token, appended to the stored context embeddings, and decoded
                together with the query.

        Returns:
            The datapoint built by ``create_context_buffer_datapoint`` from the
            query and the single sampled value.
        """
        target_tokens = self.embedder.embed_target(query)

        if previous_prediction is None:
            decoder_input = target_tokens
        else:
            # The AR token index is derived from the size of the previous
            # prediction along dim 1 of its ``xc`` field.
            buffer_idx = previous_prediction.xc.shape[1] - 1
            buffer_tokens = self.autoregressive_embedder_with_idx(
                previous_prediction, buffer_idx
            )
            self.update_context_embeddings(previous_prediction)
            decoder_input = torch.cat((buffer_tokens, target_tokens), dim=1)

        hidden = self.transformer_decode(decoder_input)
        # Only the final position (the query token) is fed to the head.
        last_hidden = hidden[:, -1, :].unsqueeze(1)
        # Draw a single sample and drop the sample axis (dim 2).
        yhat = self.head.sample(last_hidden, num_samples=1).squeeze(2)

        return create_context_buffer_datapoint(query, yhat)
    
    def store_context_embeddings(self, context: DataAttr) -> None:
        """Embed ``context`` and keep the result as ``self.context_embeddings``.

        Any previously stored context embeddings are replaced.
        """
        embedded_context = self.embedder.embed_context(context)
        self.context_embeddings = embedded_context
    
    def update_context_embeddings(self, new_context: DataAttr) -> None:
        """Embed ``new_context`` as context and append it to the stored embeddings.

        The new embeddings are concatenated after the existing
        ``self.context_embeddings`` along dim 1.
        """
        appended = self.embedder.embed_context(new_context)
        stored = self.context_embeddings
        self.context_embeddings = torch.cat((stored, appended), dim=1)
    
    def init_kv_cache(
        self,
        B: int,
        max_seq: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate zero-filled key/value caches on every backbone layer.

        Each layer in ``self.backbone.layers`` receives ``k_cache`` and
        ``v_cache`` tensors of shape ``[B, num_heads, max_seq, head_dim]``, and
        ``self.backbone.seq_len`` is reset to a 0-d int32 tensor holding 0.

        Args:
            B: Batch size of the cache.
            max_seq: Number of sequence positions the cache can hold.
            device: Device for the caches. When None, the device of the first
                backbone parameter is used.
            dtype: Dtype for the caches. When None, the dtype of the first
                backbone parameter is used.

        Raises:
            StopIteration: If a fallback is needed but the backbone has no
                parameters.
        """
        if device is None or dtype is None:
            param = next(self.backbone.parameters())
            # Compare against None explicitly: a truthiness test would discard
            # valid-but-falsy specs such as the integer device ordinal 0.
            if device is None:
                device = param.device
            if dtype is None:
                dtype = param.dtype

        for lyr in self.backbone.layers:
            cache_shape = (B, lyr.attn.num_heads, max_seq, lyr.attn.head_dim)
            lyr.k_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
            lyr.v_cache = torch.zeros_like(lyr.k_cache)

        # Track as scalar tensor since all batch items have same length
        self.backbone.seq_len = torch.tensor(0, dtype=torch.int32, device=device)
    
    @torch.no_grad()
    def prefill_kv_cache(self, mask: BlockMask) -> None:
        """Prefill KV cache with context embeddings."""
        _, kv_pairs = self.backbone(self.context_embeddings, mask)
        L0 = self.context_embeddings.shape[1]
        
        for (k, v), lyr in zip(kv_pairs, self.backbone.layers):
            lyr.k_cache[:, :, :L0, :] = k
            lyr.v_cache[:, :, :L0, :] = v
        
        self.backbone.seq_len.fill_(L0)  # In-place update
        self.offset = torch.tensor(L0, dtype=torch.int64, device=self.context_embeddings.device)
    
    def transformer_decode(
        self, new_query_embedding: torch.Tensor, keep_last: bool = False
    ) -> torch.Tensor:
        """Run the backbone layers over new tokens, attending to the KV cache.

        For every backbone layer the new tokens are normalised, projected to
        queries/keys/values, and their keys/values are written into the layer's
        cache starting at position ``self.backbone.seq_len``. Attention is then
        computed over the cache prefix that ends with the new tokens, followed
        by the layer's residual attention and feed-forward updates.

        ``self.offset`` and ``self.backbone.seq_len`` must already exist; in
        this file they are set by ``prefill_kv_cache`` / ``init_kv_cache``.

        Args:
            new_query_embedding: Tensor unpacked as ``[B, N, d_model]``.
            keep_last: Controls how many of the ``N`` new positions are
                committed to the cache. When True all ``N`` are committed;
                otherwise ``max(N - 1, 0)`` are, so the final token's cache
                entries sit past ``seq_len`` and the next call's writes (which
                start at ``seq_len``) overwrite them.

        Returns:
            ``self.backbone.norm`` applied to the updated token embeddings.
        """
        B, N, d_model = new_query_embedding.shape
        # Number of new positions by which the cache length is advanced.
        commit = N if keep_last else max(N - 1, 0)
        # Query head count is read from the first layer and reused for all layers.
        H = self.backbone.layers[0].attn.num_heads

        # score_mod closure over the offset tensor: query index q_idx is treated
        # as position q_idx + offset, and any key position beyond it is masked
        # out with -inf (causal masking relative to the cached prefix).
        offset_tensor = self.offset
        def causal_score_mod(score, b, h, q_idx, kv_idx):
            return torch.where(
                q_idx + offset_tensor >= kv_idx,
                score,
                -float("inf")
            )

        for lyr in self.backbone.layers:
            Dh = lyr.attn.head_dim
            # Pre-norm: attention projections read the normalised input.
            q = lyr.norm1(new_query_embedding)

            # Project and reshape to [B, heads, N, Dh].
            qh = lyr.attn.q_proj(q).view(B, N, H, Dh).transpose(1, 2).contiguous()
            kh = lyr.attn.k_proj(q).view(B, N, lyr.attn.num_kv_heads, Dh).transpose(1, 2).contiguous()
            vh = lyr.attn.v_proj(q).view(B, N, lyr.attn.num_kv_heads, Dh).transpose(1, 2).contiguous()
            # Expand K/V heads by the ratio H // num_kv_heads before caching
            # (presumably to match the H query heads — see expand_kv_heads).
            kh = expand_kv_heads(kh, H // lyr.attn.num_kv_heads)
            vh = expand_kv_heads(vh, H // lyr.attn.num_kv_heads)

            # Use tensor directly for slicing (no .item() for better compilation)
            past_len = self.backbone.seq_len

            # In-place update of KV cache for ALL N positions, including any
            # that will not be committed below.
            cache_positions = torch.arange(past_len, past_len + N, device=kh.device)
            lyr.k_cache[:, :, cache_positions, :] = kh
            lyr.v_cache[:, :, cache_positions, :] = vh

            # Use views instead of concatenation for attention: slice the cache
            # up to and including the positions just written.
            total_len = past_len + N
            k_full = lyr.k_cache[:, :, :total_len, :]
            v_full = lyr.v_cache[:, :, :total_len, :]

            attn = flex_attention(qh, k_full, v_full, score_mod=causal_score_mod)
            # Back to [B, N, d_model] (merges the head and head_dim axes).
            attn = attn.transpose(1, 2).reshape(B, N, d_model)

            # Residual attention update, then residual feed-forward update
            # (norm2 -> ff1 -> GELU -> drop_ff -> ff2).
            new_query_embedding = new_query_embedding + lyr.drop_attn(lyr.attn.o_proj(attn))
            y = lyr.norm2(new_query_embedding)
            new_query_embedding = new_query_embedding + lyr.ff2(
                lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y)))
            )

        # Advance both length trackers only after every layer has used the
        # same past_len / offset values.
        if commit:
            self.backbone.seq_len += commit
            self.offset.add_(commit)  # In-place for graph compatibility

        return self.backbone.norm(new_query_embedding)
    
    def _get_cached_selfattn_mask(self, context_len: int, device: torch.device) -> BlockMask:
        """Get or create cached self-attention mask."""
        cache_key = (context_len, str(device))
        
        if cache_key not in self.selfattn_mask_cache:
            self.selfattn_mask_cache[cache_key] = create_context_self_attention_block_mask(
                current_num_context=context_len,
                q_block_size=self.q_block_size,
                kv_block_size=self.kv_block_size,
                device=device,
            )
        
        return self.selfattn_mask_cache[cache_key]
    
    def autoregressive_embedder_with_idx(self, x: DataAttr, idx: int) -> torch.Tensor:
        """Embed ``x`` as buffer tokens and add the AR token stored at ``idx``.

        Args:
            x: Datapoint passed to ``self.embedder.embed_buffer``.
            idx: Index into ``self.ar_tokens`` selecting which token to add.

        Returns:
            The buffer embeddings plus the selected AR token, which is expanded
            across the batch dimension (dim 0 of the buffer embeddings).
        """
        embedded = self.embedder.embed_buffer(x)
        batch_size = embedded.shape[0]
        # Give the selected token a leading batch axis and expand it to the batch.
        position_token = self.ar_tokens[idx].unsqueeze(0).expand(batch_size, -1, -1)
        return embedded + position_token
