# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import torch
from safetensors.torch import load_file
from torch.nn.modules.module import _IncompatibleKeys

from cosmos1.models.autoregressive.configs.base.model import ModelConfig
from cosmos1.models.autoregressive.configs.base.tokenizer import TokenizerConfig
from cosmos1.models.autoregressive.modules.mm_projector import MultimodalProjector
from cosmos1.models.autoregressive.networks.transformer import Transformer
from cosmos1.models.autoregressive.networks.vit import VisionTransformer, get_vit_config
from cosmos1.models.autoregressive.sjd.multi_token_utils import (
    find_first_mismatch,
    get_multi_token_initialization,
    rollback_kv_cache,
)
from cosmos1.models.autoregressive.sjd.sjd_config import SJDConfig
from cosmos1.models.autoregressive.sjd.speculative_sampler import CosmosSpeculativeSampler
from cosmos1.models.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
from cosmos1.models.autoregressive.utils.checkpoint import (
    get_partial_state_dict,
    process_state_dict,
    substrings_to_ignore,
)
from cosmos1.models.autoregressive.utils.sampling import (
    decode_n_tokens,
    decode_one_token,
    prefill,
    sample_top_p,
    sample_top_p_all_positions,
)
from cosmos1.utils import log, misc


def _sample_from_probs_with_global_rand(
    prob: torch.Tensor,
    random_table: torch.Tensor,
    batch_idx: int,
    prob_idx: int,
    table_idx: int,
    vocab_size: int,
):
    rand_for_vocab = random_table[batch_idx, table_idx, :vocab_size]

    eps = 1e-9
    gumbels = -torch.log(-torch.log(rand_for_vocab + eps) + eps)  # [vocab_size]

    log_probs = prob.log()

    z = log_probs + gumbels
        
    token_id = z.argmax().item()

    return token_id

class AutoRegressiveModel(torch.nn.Module):
    """
    A class to build and use a AutoRegressiveModel model for text generation.

    Methods:
        build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint.
        generate: Generate text sequences based on provided prompts using the language generation model.
    """

    def __init__(
        self,
        model: Optional[Transformer] = None,
        tokenizer: Optional[DiscreteMultimodalTokenizer] = None,
        config: Optional[ModelConfig] = None,
        vision_encoder: Optional[VisionTransformer] = None,
        mm_projector: Optional[MultimodalProjector] = None,
    ) -> None:
        """
        Initialize the AutoRegressiveModel instance by storing the given components.

        Every argument is optional and defaults to None; the constructor only keeps
        references and performs no loading itself.

        Args:
            model (Transformer, optional): The Transformer network used for generation.
            tokenizer (DiscreteMultimodalTokenizer, optional): The tokenizer associated with the model.
            config (ModelConfig, optional): The configuration for the AutoRegressiveModel model.
            vision_encoder (VisionTransformer, optional): The vision encoder for the AutoRegressiveModel model.
            mm_projector (MultimodalProjector, optional): The multi-modal projector for the AutoRegressiveModel model.
        """
        super().__init__()
        # NOTE: the assignment order below is the order in which torch.nn.Module
        # registers any sub-modules, so it is kept stable.
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        self.vision_encoder = vision_encoder
        self.mm_projector = mm_projector

    @property
    def precision(self):
        return self.model.precision

    def get_num_params(
        self,
    ) -> int:
        """
        Return the number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def load_ar_model(
        self,
        tokenizer_config,
    ):
        """
        Load the AR transformer from ``self.config.ckpt_path`` and store it in ``self.model``.

        The checkpoint is read on CPU (safetensors API when the path ends with
        "safetensors", ``torch.load`` otherwise), a ``Transformer`` is built under the
        configured precision, its vocabulary is expanded if the video tokenizer needs
        more entries, the weights are loaded, and the model is moved to CUDA.

        The process-wide torch default dtype is switched to the model precision while
        the network is built and is always restored afterwards, including when
        construction or weight loading raises.

        Args:
            tokenizer_config: Tokenizer configuration; ``video_tokenizer.vocab_size`` and
                ``training_type`` are read from it and it is forwarded to ``Transformer``.

        Raises:
            AssertionError: If keys other than ``*_extra_state`` are reported missing.
        """
        model_config = self.config
        ckpt_path = model_config.ckpt_path
        with misc.timer(f"loading checkpoint from {ckpt_path}"):
            if ckpt_path.endswith("safetensors"):
                # Load with safetensors API
                checkpoint = load_file(ckpt_path, device="cpu")
            else:
                # The pytorch version
                checkpoint = torch.load(
                    ckpt_path,
                    map_location="cpu",
                    mmap=True,  # load the checkpoint in memory-mapped mode
                    weights_only=True,
                )
        llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint

        orig_precision = torch.get_default_dtype()
        precision = getattr(torch, model_config.precision)
        torch.set_default_dtype(precision)
        try:
            log.debug(f"Setting torch default dtype to {precision}")

            model = Transformer(
                params=model_config,
                tokenizer_config=tokenizer_config,
            )
            log.debug(
                f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}"
            )
            vocab_size = update_vocab_size(
                existing_vocab_size=0,
                to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size,
                training_type=tokenizer_config.training_type,
                add_special_tokens=False,
            )
            log.debug(
                f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}  vocab_size {vocab_size}"
            )
            # Perform vocab expansion
            if vocab_size > model.vocab_size:
                log.debug(f"Expanding vocab size to {vocab_size}")
                # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
                expand_output_layer = not (tokenizer_config.training_type == "text_to_video")
                model.expand_vocab(
                    vocab_size,
                    init_method="gaussian",
                    expand_output_layer=expand_output_layer,
                )
            # Remove the "model." prefix in the state_dict
            llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
            with misc.timer("loading state_dict into model"):
                missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
            # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
            # NOTE(review): with strict=True, load_state_dict raises on missing keys before
            # this filter runs, so the filter only matters if strictness is relaxed - confirm intent.
            missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
            assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"

            self.model = model.to(precision).to("cuda")
        finally:
            # Reset the default dtype to the original value, even if loading failed,
            # so a failure here cannot leak the model precision into unrelated code.
            torch.set_default_dtype(orig_precision)

    def load_tokenizer(self, tokenizer_config):
        """
        Build a ``DiscreteMultimodalTokenizer`` from ``tokenizer_config`` and keep it on
        the instance as ``self.tokenizer``.
        """
        tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
        self.tokenizer = tokenizer

    @staticmethod
    def build(
        model_config: Optional[ModelConfig] = None,
        tokenizer_config: TokenizerConfig = None,
    ) -> "AutoRegressiveModel":
        """
        Build a AutoRegressiveModel instance by initializing and loading a model checkpoint.

        When ``model_config.ckpt_path`` is None, the single ``*.safetensors`` (preferred)
        or ``*.pth`` file in ``model_config.ckpt_dir`` is used, and a ``config.json`` in
        that directory, if present, overrides matching attributes of ``model_config``.
        Otherwise the checkpoint at ``model_config.ckpt_path`` is loaded and no
        ``config.json`` is read.

        Args:
            model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance.
                Defaults to None, in which case a fresh ``ModelConfig()`` is created for this call.
                A caller-supplied config is updated in place by the ``config.json`` overrides.
            tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None.

        Returns:
            AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer.

        Raises:
            AssertionError: If the checkpoint directory contains no checkpoint file or more than one,
                if vision/projector weights are required but empty, if ``mm_projector`` is not configured
                alongside ``vision_encoder``, or if keys other than ``*_extra_state`` are reported missing.

        Note:
            Checkpoint tensors are loaded on CPU and the returned model is not moved to another device here.
            The torch default dtype is switched to the model precision while the networks are built and is
            always restored before returning, including when an exception is raised.
        """
        # A default of `ModelConfig()` in the signature would be created once and then
        # mutated by the overrides below, leaking state between calls; create it per call.
        if model_config is None:
            model_config = ModelConfig()

        # Initialize model configuration parameters
        config_params = {}

        # Load checkpoint and model parameters
        if model_config.ckpt_path is None:
            # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir
            ckpt_dir = model_config.ckpt_dir
            ckpt_root = Path(ckpt_dir)

            # We prioritize safetensors version over the pytorch version, since the former is
            # much faster for checkpoint loading.
            checkpoints = sorted(ckpt_root.glob("*.safetensors"))
            if len(checkpoints) == 0:
                checkpoints = sorted(ckpt_root.glob("*.pth"))

            assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
            assert (
                len(checkpoints) == 1
            ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)"
            ckpt_path = str(checkpoints[0])  # Assuming single checkpoint for non-parallel case

            config_path = ckpt_root / "config.json"
            if config_path.exists():
                with open(config_path, "r") as f:
                    config_params = json.load(f)
            else:
                log.info(
                    f"No config.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config."
                )

        else:
            # If ckpt_path is provided, we load the model from the specified path,
            # and use the default model configuration
            ckpt_path = model_config.ckpt_path

        for key, value in config_params.items():
            if hasattr(model_config, key):
                # Override the default model configuration with the parameters from the checkpoint
                setattr(model_config, key, value)

        with misc.timer(f"loading checkpoint from {ckpt_path}"):
            if ckpt_path.endswith("safetensors"):
                # Load with safetensors API
                checkpoint = load_file(ckpt_path, device="cpu")
            else:
                # The pytorch version
                checkpoint = torch.load(
                    ckpt_path,
                    map_location="cpu",
                    mmap=True,  # load the checkpoint in memory-mapped mode
                    weights_only=True,
                )
        llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint

        if model_config.vision_encoder is not None:
            # Take the LLM weights (starting with "model.") from the VLM checkpoint
            llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.")
            # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']`
            #   and `checkpoint['mm_projector']` are both for those weights
            # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights
            # NOTE(review): the fine-tuned branches below read from `llm_checkpoint` after it has been
            # narrowed to the "model." prefix above - confirm get_partial_state_dict semantics.
            if "vision_encoder" in checkpoint:
                log.debug("Using pretrained vision_encoder")
                vit_checkpoint = checkpoint["vision_encoder"]
            else:
                log.debug("Using fine-tuned vision_encoder")
                vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.")
                vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.")
            if "mm_projector" in checkpoint:
                log.debug("Using pretrained mm_projector")
                projector_checkpoint = checkpoint["mm_projector"]
            else:
                log.debug("Using fine-tuned mm_projector")
                projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.")
                projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.")
            assert (
                len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0
            ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector."

        tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
        orig_precision = torch.get_default_dtype()
        precision = getattr(torch, model_config.precision)
        torch.set_default_dtype(precision)
        try:
            log.debug(f"Setting torch default dtype to {precision}")

            model = Transformer(
                params=model_config,
                tokenizer_config=tokenizer_config,
            )
            model_kwargs = {}

            if model_config.vision_encoder is not None:
                assert (
                    model_config.mm_projector is not None
                ), "mm_projector must be provided if vision_encoder is provided."
                vit_config = get_vit_config(model_config.vision_encoder)
                vision_encoder = VisionTransformer.build(
                    vit_config,
                )

                mm_projector = MultimodalProjector(
                    mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"]
                )
                model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector})

            # Perform vocab expansion
            if tokenizer.vocab_size > model.vocab_size:
                log.debug(f"Expanding vocab size to {tokenizer.vocab_size}")
                # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
                expand_output_layer = not (tokenizer.training_type == "text_to_video")
                model.expand_vocab(
                    tokenizer.vocab_size,
                    init_method="gaussian",
                    expand_output_layer=expand_output_layer,
                )

            # Remove the "model." prefix in the state_dict
            llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
            with misc.timer("loading state_dict into model"):
                missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
            # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
            missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
            assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"

            if model_config.vision_encoder is not None:
                vision_encoder.load_state_dict(vit_checkpoint)
                mm_projector.load_state_dict(projector_checkpoint)
                if model_config.vision_encoder_in_channels != 3:
                    vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels)

            model = model.to(precision)  # ensure model parameters are in the correct precision
            log.debug(f"Model config: {model_config}")
        finally:
            # Reset the default dtype to the original value, even when building failed.
            torch.set_default_dtype(orig_precision)

        return AutoRegressiveModel(model, tokenizer, model_config, **model_kwargs)

    @torch.no_grad()
    def generate(
        self,
        prompt_tokens: List[List[int]] | torch.Tensor,
        max_gen_len: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        num_gen_seq: int = 1,
        logprobs: bool = False,
        echo: bool = False,
        seed: Optional[int] = None,
        context: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
        sjd_config: Optional[SJDConfig] = None,
        compile_sampling: bool = True,
        compile_prefill: bool = False,
        verbose: bool = True,
        stop_tokens: Optional[Set[int]] = None,
        images: Optional[torch.Tensor] = None,
    ):
        """
        Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast).

        The prompt is prefilled in one pass, then tokens are decoded either one at a time
        (standard path) or several per iteration with speculative Jacobi-style decoding
        (SJD path, when ``sjd_config`` is given and ``sjd_config.enable_sjd`` is truthy).
        Working tensors are created on the "cuda" device.

        Args:
            prompt_tokens (List[List[int]] | torch.Tensor): Prompt token ids, 1-D or of shape (batch, seq_len).
            max_gen_len (int): Maximum number of tokens to generate; truncated so that
                prompt length + generated length does not exceed ``max_seq_len``.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
                A value of 0 disables top-k/top-p and is not supported on the SJD path.
            top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
            num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1.
                Values > 1 require a single prompt and must not exceed ``max_batch_size``.
            logprobs (bool, optional): Must be False; log-probabilities are not supported. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
            seed (int, optional): Random seed for reproducibility. Defaults to None.
            context (torch.Tensor, optional): Extra conditioning passed to the model; moved to the prompt
                device and cast to the model precision. Defaults to None.
            context_mask (torch.Tensor, optional): Mask for ``context``; cast to bool, and a 2-D mask is
                reshaped to (batch, 1, 1, context_len). Defaults to None.
            sjd_config (SJDConfig, optional): Speculative decoding configuration. Defaults to None (standard decoding).
            compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
            compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
            verbose (bool, optional): Flag indicating whether to log timing information. Defaults to True.
            stop_tokens (Set[int], optional): Decoding stops once the most recently written token of any
                sequence is in this set. Defaults to None.
            images (torch.Tensor, optional): Images to embed together with the prompt tokens; cast to bfloat16.
                Defaults to None.

        Returns:
            Tuple[torch.Tensor, None]: The generated token ids (preceded by the prompt when ``echo`` is True)
            and None in place of log-probabilities.

        Raises:
            AssertionError: On unsupported argument combinations (both top_k and top_p, ``logprobs=True``,
                bad prompt rank, batch-size mismatches, greedy decoding with SJD, or an SJD step
                that accepts no token).
        """

        assert top_k is None or top_p is None, f"Only one of top_k ({top_k} or top_p ({top_p} should be specified."
        if temperature == 0:
            top_p, top_k = None, None
            log.debug("Setting top_p and top_k to None because temperature is 0")
        if top_p is not None:
            log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}")
        elif top_k is not None:
            log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}")
        else:
            log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")

        # Vocabulary width used by the SJD draft initialisation and its random table.
        # NOTE(review): hard-coded; presumably the video tokenizer vocab size - confirm.
        sjd_vocab_size = 64000

        orig_precision = torch.get_default_dtype()
        torch.set_default_dtype(self.precision)
        try:
            torch._inductor.config.coordinate_descent_tuning = True
            torch._inductor.config.triton.unique_kernel_names = True
            # Experimental features to reduce compilation times, will be on by default in future
            torch._inductor.config.fx_graph_cache = True

            if seed is not None:
                misc.set_random_seed(seed)

            assert not logprobs, "logprobs are not supported for fast_generate yet"
            # Examine if the function prefil and decode_one_token functions are compiled yet. If not, compile them based on the flags
            if compile_sampling and not getattr(self, "inference_decode_compiled", False):
                self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
                self.inference_decode_compiled = True
                log.info("Compiled AR sampling function. Note: the first run will be slower due to compilation")
            if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
                self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
                self.inference_prefill_compiled = True
                log.info("Compiled prefill function. Note: the first run will be slower due to compilation")

            if not hasattr(self, "decode_one_token"):
                self.decode_one_token = decode_one_token
            if not hasattr(self, "prefill"):
                self.prefill = prefill

            # Determine if we should use SJD
            use_sjd = sjd_config is not None and sjd_config.enable_sjd
            if use_sjd:
                log.debug(f"Using SJD with config: {sjd_config}")
                # Initialize the speculative sampler
                generator = torch.Generator(device="cuda").manual_seed(seed) if seed is not None else None
                speculative_sampler = CosmosSpeculativeSampler(generator=generator)
                # Initialize buffers for reusing leftover advanced tokens from the previous step
                next_draft_tokens = None
                next_draft_probs = None

            # Initialization and Assertions
            if isinstance(self.model.params, list):
                # During training, model.params is a list
                log.debug(
                    f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
                )
                params = self.config
            else:
                params = self.model.params
            if isinstance(prompt_tokens, list):
                prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")
            if prompt_tokens.ndim == 1:
                prompt_tokens = prompt_tokens.view(1, -1)
            else:
                assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}"
            batch_size, prompt_len = prompt_tokens.shape
            total_len = min(params.max_seq_len, max_gen_len + prompt_len)
            if max_gen_len + prompt_len > params.max_seq_len:
                log.warning(
                    f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}"
                )
                max_gen_len = params.max_seq_len - prompt_len

            if context_mask is not None:
                context_mask = context_mask.to(dtype=torch.bool)
                if context_mask.ndim == 2:
                    assert (
                        context_mask.shape[0] == batch_size
                    ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}"
                    # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len]
                    context_mask = context_mask.view(batch_size, 1, 1, -1)

            if num_gen_seq > 1:
                assert (
                    batch_size == 1
                ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts"
                log.debug(f"Generating {num_gen_seq} sequences with the same prompt")
                assert (
                    num_gen_seq <= params.max_batch_size
                ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}"
                # repeat the prompt tokens for num_gen_seq times
                prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1)
                assert prompt_tokens.shape == (
                    num_gen_seq,
                    prompt_len,
                ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}"
                batch_size = len(prompt_tokens)

            # create an empty tensor of the expected final shape and fill in the current tokens
            empty = torch.zeros(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device)
            empty[:, :prompt_len] = prompt_tokens
            seq = empty
            input_pos = torch.arange(0, prompt_len, device="cuda")

            # Iteration counters are incremented in the decode loop regardless of `verbose`,
            # so they must exist even when timing output is disabled.
            sjd_iterations = 0
            standard_iterations = 0
            if verbose:
                prefill_start = time.time()

            if images is not None:
                images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16)
                prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images)
            else:
                prompt_token_embeddings = None

            if context is not None:
                context = context.to(device=prompt_tokens.device, dtype=self.precision)

            # Prefill stage
            # The prefill function now needs to handle the kv_cache directly on the model
            next_token = self.prefill(
                self.model,
                input_pos=input_pos,
                tokens=prompt_tokens if prompt_token_embeddings is None else None,
                token_embeddings=prompt_token_embeddings,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                context=context,
                context_mask=context_mask,
            )
            if verbose:
                prefill_time = time.time() - prefill_start

            seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype)
            current_pos = prompt_len + 1

            stop_tokens_tensor = torch.tensor(list(stop_tokens or []), dtype=torch.long, device="cuda")

            if verbose:
                decode_start = time.time()
                if use_sjd:
                    sjd_step1_time = 0.0  # Prepare draft tokens
                    sjd_step2_time = 0.0  # Forward pass to get advanced logits
                    sjd_step3_time = 0.0  # Vectorized sampling of advanced tokens
                    sjd_step4_time = 0.0  # Speculative sampler acceptance
                    sjd_step5_time = 0.0  # KV-cache rollback
                    sjd_step6_time = 0.0  # Update sequence and store leftovers

            # The random table is SJD bookkeeping and needs `sjd_config`; without a config
            # (the default) there is nothing to size it with, so it is skipped. It is still
            # drawn whenever a config is supplied, even with SJD disabled, so that seeded
            # runs consume the RNG exactly as before.
            if sjd_config is not None:
                random_table = torch.rand([1, sjd_config.max_num_new_tokens, sjd_vocab_size], device="cuda")

            # Decode stage
            while current_pos < total_len:
                if use_sjd:
                    # 1. Prepare draft tokens by reusing leftovers and filling with new random ones.
                    sjd_iterations += 1
                    remaining_len = total_len - current_pos
                    num_draft_tokens = min(sjd_config.max_num_new_tokens, remaining_len)

                    if verbose:
                        _t1 = time.time()
                    draft_tokens_parts = []
                    draft_probs_parts = []
                    num_reused_tokens = 0

                    if next_draft_tokens is not None and next_draft_tokens.shape[1] > 0:
                        num_to_reuse = min(next_draft_tokens.shape[1], num_draft_tokens)
                        draft_tokens_parts.append(next_draft_tokens[:, :num_to_reuse])
                        draft_probs_parts.append(next_draft_probs[:, :num_to_reuse])
                        num_reused_tokens = num_to_reuse

                    num_new_random_tokens = num_draft_tokens - num_reused_tokens
                    if num_new_random_tokens > 0:
                        new_draft_tokens, new_draft_probs = get_multi_token_initialization(
                            input_ids=seq[:, current_pos - 1:current_pos],
                            num_tokens=num_new_random_tokens,
                            vocab_size=sjd_vocab_size,
                            init_scheme=sjd_config.multi_token_init_scheme,
                            device="cuda",
                        )
                        draft_tokens_parts.append(new_draft_tokens)
                        draft_probs_parts.append(new_draft_probs)

                    draft_tokens = torch.cat(draft_tokens_parts, dim=1)
                    draft_probs = torch.cat(draft_probs_parts, dim=1)
                    if verbose:
                        sjd_step1_time += time.time() - _t1

                    # 2. generate advanced logits with one forward pass
                    if verbose:
                        _t2 = time.time()
                    # Shift right by one: the model input starts with the last committed token
                    # and drops the final draft token.
                    draft_tokens = torch.cat([seq[:, current_pos - 1:current_pos], draft_tokens[:, :-1]], dim=1)
                    draft_probs = torch.cat([torch.zeros_like(draft_probs[:, :1]), draft_probs[:, :-1]], dim=1)
                    draft_input_pos = torch.arange(current_pos - 1, current_pos + num_draft_tokens - 1, device="cuda")
                    advanced_logits = self.model(
                        tokens=draft_tokens,
                        input_pos=draft_input_pos,
                        context=context,
                        context_mask=context_mask,
                    )
                    if verbose:
                        sjd_step2_time += time.time() - _t2

                    # 3. vectorized sampling of advanced tokens
                    if verbose:
                        _t3 = time.time()
                    if temperature > 0:
                        if top_p is not None:
                            # Vectorized top-p sampling for all positions at once
                            advanced_tokens, advanced_probs = sample_top_p_all_positions(
                                advanced_logits,
                                temperature=temperature,
                                top_p=top_p,
                                return_probs=True,
                            )
                        else:
                            # Fallback to multinomial sampling over all positions without top-p
                            probs = torch.softmax(advanced_logits / temperature, dim=-1)
                            advanced_tokens = torch.multinomial(
                                probs.view(-1, probs.shape[-1]), num_samples=1
                            ).view(probs.shape[:-1])
                            advanced_probs = probs
                    else:  # Greedy decoding not supported in SJD path
                        assert False, "Greedy decoding is not supported for SJD"
                    if verbose:
                        sjd_step3_time += time.time() - _t3

                    # 4. Use speculative sampler to get accepted tokens
                    if verbose:
                        _t4 = time.time()
                    mismatch_positions, accepted_tokens, _ = speculative_sampler(
                        draft_tokens=draft_tokens,
                        advanced_tokens=advanced_tokens,
                        draft_prob=draft_probs,
                        advanced_prob=advanced_probs,
                        maximal_coupling=sjd_config.maximal_coupling,
                    )
                    num_accepted = mismatch_positions[0]  # Assuming batch size of 1 for now
                    if verbose:
                        sjd_step4_time += time.time() - _t4

                    # 5. Rollback unused KV cache
                    if verbose:
                        _t5 = time.time()
                    num_to_rollback = num_draft_tokens - num_accepted
                    if num_to_rollback > 0:
                        rollback_kv_cache(self.model, num_to_rollback)
                    if verbose:
                        sjd_step5_time += time.time() - _t5

                    # 6. Update sequence and position; store leftovers for next iteration
                    if verbose:
                        _t6 = time.time()
                    if num_accepted > 0:
                        accepted_end = current_pos + num_accepted
                        seq[:, current_pos:accepted_end] = accepted_tokens[:, :num_accepted]
                        current_pos = accepted_end

                        # update random table: shift out the consumed slots and refill the tail
                        random_table[:, :-num_accepted, :] = random_table[:, num_accepted:, :].clone()
                        random_table[:, sjd_config.max_num_new_tokens - num_accepted:, :] = torch.rand(
                            [1, num_accepted, sjd_vocab_size], device="cuda"
                        )

                    # Store the unused advanced tokens for the next iteration's draft
                    # NOTE(review): leftovers are sliced from `accepted_tokens`, not `advanced_tokens` -
                    # confirm against the sampler's return contract.
                    if advanced_tokens.shape[1] > num_accepted:
                        next_draft_tokens = accepted_tokens[:, num_accepted:]
                        next_draft_probs = advanced_probs[:, num_accepted:]
                    else:
                        next_draft_tokens = None
                        next_draft_probs = None
                    if verbose:
                        sjd_step6_time += time.time() - _t6

                    # If no tokens were accepted or we have space, generate one more token conventionally
                    if current_pos < total_len and (num_accepted == 0):
                        assert False, "num_accepted cannot be 0"
                else:
                    # Standard auto-regressive decoding path
                    standard_iterations += 1
                    input_pos = torch.tensor([current_pos - 1], device=prompt_tokens.device)
                    next_token_tensor, _ = self.decode_one_token(
                        self.model,
                        seq[:, input_pos],
                        input_pos,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                    )
                    seq[:, current_pos] = next_token_tensor.squeeze(-1)
                    current_pos += 1

                # NOTE(review): only the most recently written position is checked, so on the SJD
                # path a stop token in the middle of an accepted run does not end decoding.
                if stop_tokens and torch.isin(seq[:, current_pos - 1], stop_tokens_tensor).any():
                    break

            if verbose:
                decode_time = time.time() - decode_start
                generated_len = current_pos - prompt_len
                decode_throughput = generated_len / decode_time if decode_time > 0 else 0
                log.info(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prompt_len / prefill_time:.2f} tokens/s")
                log.info(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s")
                if use_sjd:
                    log.info(f"  SJD iterations: {sjd_iterations}; Standard fallbacks: {standard_iterations}")
                    if sjd_iterations > 0:
                        log.info(
                            f"  [SJD] Step1 Prepare: total {sjd_step1_time:.3f}s; avg {sjd_step1_time/sjd_iterations:.3f}s/iter"
                        )
                        log.info(
                            f"  [SJD] Step2 Forward: total {sjd_step2_time:.3f}s; avg {sjd_step2_time/sjd_iterations:.3f}s/iter"
                        )
                        log.info(
                            f"  [SJD] Step3 Sample: total {sjd_step3_time:.3f}s; avg {sjd_step3_time/sjd_iterations:.3f}s/iter"
                        )
                        log.info(
                            f"  [SJD] Step4 Speculate: total {sjd_step4_time:.3f}s; avg {sjd_step4_time/sjd_iterations:.3f}s/iter"
                        )
                        log.info(
                            f"  [SJD] Step5 Rollback: total {sjd_step5_time:.3f}s; avg {sjd_step5_time/sjd_iterations:.3f}s/iter"
                        )
                        log.info(
                            f"  [SJD] Step6 Update+Store: total {sjd_step6_time:.3f}s; avg {sjd_step6_time/sjd_iterations:.3f}s/iter"
                        )

            final_seq = seq[:, :current_pos]
            if not echo:
                final_seq = final_seq[:, prompt_len:]

            return final_seq, None
        finally:
            # Reset the default dtype to the original value, including when generation
            # fails part-way, so the model precision never leaks into unrelated code.
            torch.set_default_dtype(orig_precision)

    def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        """
        Embed vision and language features into a combined representation.

        Positions of ``input_ids`` that differ from the vision encoder's ``image_token_id`` are embedded with
        ``self.model.tok_embeddings``. Positions equal to ``image_token_id`` are filled with the flattened
        output of ``self.mm_projector(self.vision_encoder(images))``.

        Args:
            input_ids (torch.Tensor): Input token IDs, unpacked as ``(B, seq_len)``.
            images (torch.Tensor): Input images. They are moved to the device and dtype of the text
                embeddings before being passed to the vision encoder.

        Returns:
            torch.Tensor: Combined vision-language features of shape ``(B, seq_len, D)``, where ``D`` is the
                shared text/image feature dimension.

        Raises:
            AssertionError: If vision encoder or mm projector is not initialized, if ``image_token_id`` is not
                            a non-negative int, or if dimensions mismatch.
        """
        # Ensure vision encoder and mm projector are initialized
        assert self.vision_encoder is not None
        assert self.mm_projector is not None

        # Get image token ID and validate it
        image_token_id = self.vision_encoder.image_token_id
        assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}"

        # Identify text and image locations in the input (complementary boolean masks)
        text_locations = input_ids != image_token_id
        image_locations = input_ids == image_token_id

        # Process text features: boolean-mask indexing flattens to (N_txt,), so the result is (N_txt, D_txt)
        text_features = self.model.tok_embeddings(input_ids[text_locations])

        # Process image features on the same device/dtype as the text embeddings
        images = images.to(device=text_features.device, dtype=text_features.dtype)
        vit_outputs = self.vision_encoder(images)
        image_features = self.mm_projector(vit_outputs)

        # Get dimensions
        B, seq_len = input_ids.shape
        N_total = B * seq_len
        N_txt, D_txt = text_features.shape
        N_img, N_patch, D_img = image_features.shape

        # Flatten image features to one row per image-token slot
        image_features = image_features.reshape(N_img * N_patch, D_img)

        # Validate dimensions. The count check is over ALL positions in the batch (B * seq_len), so the
        # message reports that total; the tuple also carries the actual number of image-token positions.
        assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
        assert N_total == N_txt + N_img * N_patch, (
            f"B*seq_len {N_total} should be equal to N_txt + N_img*N_patch "
            f"{(N_txt, N_img * N_patch, image_locations.sum().item())}"
        )

        # Combine text and image features; every position is written by exactly one of the two masks
        combined_features = torch.empty(
            (B, seq_len, D_txt),
            dtype=text_features.dtype,
            device=text_features.device,
        )
        combined_features[text_locations, :] = text_features
        combined_features[image_locations, :] = image_features

        return combined_features

    def state_dict(self, *args, **kwargs):
        """
        Return the parent class's state dict after passing it through `process_state_dict`
        (e.g., to remove "_extra_state" keys imposed by TransformerEngine for FP8).

        All positional and keyword arguments are forwarded unchanged to the parent implementation.
        """
        return process_state_dict(super().state_dict(*args, **kwargs))

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
        """
        Load `state_dict` while tolerating missing keys that contain any of `substrings_to_ignore`
        (e.g., "_extra_state" keys imposed by TransformerEngine for FP8).

        The incoming dict is first passed through `process_state_dict`, then loaded by the parent class in
        non-strict mode so that ignorable missing keys can be filtered out before strictness is enforced here.

        Args:
            state_dict (Dict[str, Any]): State dict to load.
            strict (bool): If True, raise when any non-ignored key is missing or any key is unexpected.
            assign (bool): Forwarded unchanged to the parent `load_state_dict`.

        Returns:
            _IncompatibleKeys: The filtered missing keys and the unexpected keys.

        Raises:
            ValueError: If `strict` is True and there are non-ignored missing keys or unexpected keys.
        """
        cleaned_state_dict = process_state_dict(state_dict)
        # Always load non-strictly; strictness is applied below on the filtered key lists.
        missing_keys, unexpected_keys = super().load_state_dict(cleaned_state_dict, strict=False, assign=assign)
        actual_missing_keys = [
            key for key in missing_keys if not any(substring in key for substring in substrings_to_ignore)
        ]
        if strict and (actual_missing_keys or unexpected_keys):
            raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
        return _IncompatibleKeys(actual_missing_keys, unexpected_keys)
