from typing import List, Optional, Tuple
from pathlib import Path

import torch
import flashinfer
from transformers import PreTrainedTokenizer
from flashinfer.sampling import chain_speculative_sampling

from spec_benchmark.Engine.models.base import LoRAConfig
from spec_benchmark.Engine.models.mtp_model import Transformer, MTPTransformer
from spec_benchmark.Engine.utils import PageManager, register_custom_attn_op, sample, get_sampling_probs
from spec_benchmark.profiler import backend_bucket_timer


def print_on_rank0(message):
    """Print ``message`` only from the rank-0 process.

    When ``torch.distributed`` is unavailable or no default process group has
    been initialized (e.g. a plain single-process run), the current process is
    treated as rank 0 and the message is printed instead of raising.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message)


class MTPLMBackend:
    def __init__(
        self,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda:0",
        draft_length: List[int] = [4],
        tokenizer: PreTrainedTokenizer = None,
    ) -> None:
        self.dtype = dtype
        self.device = device
        self.draft_lengths = draft_length
        self.cnt = 0
        self.tokenizer = tokenizer

        self.prefill_forward = lambda model, x, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen:\
                model(x, None, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type="prefill")
        self.draft_forward = lambda model, x, gate_mask, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen:\
                model(x, gate_mask, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type="draft")
        self.draft_and_verify_forward = lambda model, x, gate_mask, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen:\
                model(x, gate_mask, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type="draft_and_verify")


    def load_model(
        self,
        model_name: str,
        target_checkpoint: Path,
        lora_checkpoint: Path,
        lora_config: LoRAConfig,
        use_tp: bool = False,
        rank_group = None,
        group = None,
    ):
        with torch.device('meta'):
            base_model = Transformer.from_name(model_name, lora_config)
            model = MTPTransformer(base_model)

            target_checkpoint = torch.load(str(target_checkpoint), mmap=True, weights_only=True)
            model.base_model.load_state_dict(target_checkpoint, assign=True, strict=False)

            lora_checkpoint = torch.load(str(lora_checkpoint), mmap=True, weights_only=True)
            # LoRA adapter load
            model.base_model.load_state_dict(lora_checkpoint, assign=True, strict=False)
            # Sampler load
            model.load_state_dict(lora_checkpoint, assign=True, strict=False)

            # Append <mask> token and resize token embeddings
            if '<mask>' not in self.tokenizer.get_vocab():
                self.tokenizer.add_special_tokens({'additional_special_tokens': ['<mask>']})
                print(f"Added <mask> token with ID: {self.tokenizer.convert_tokens_to_ids('<mask>')}")
            
            self.mask_token_id = self.tokenizer.convert_tokens_to_ids('<mask>')
            model.base_model.resize_token_embeddings(len(self.tokenizer))

            if use_tp:
                from spec_benchmark.Engine.utils import apply_tp
                print("Applying tensor parallel to model ...")
                apply_tp(model.base_model, rank_group, group=group)

            model = model.to(device=self.device, dtype=self.dtype)
            self.model = model.eval()
        

    def setup_caches(self, max_batch_size=1, max_seq_length=2048, page_size=16, prefill_chunk_size=128):
        self.batch_size = max_batch_size
        self.page_size = page_size
        self.prefill_chunk_size = prefill_chunk_size
        if len(self.draft_lengths) == 1:
            self.draft_and_verify_len = (self.draft_lengths[0] + 1) ** 2
        else:
            self.draft_and_verify_len = (self.draft_lengths[0] + 1) * (self.draft_lengths[1] + 1)
        self.max_cache_len = max_seq_length + self.draft_and_verify_len + 1
        self.common_attn_masks, self.common_position_ids = self._setup_common_attn_masks_and_position_ids(self.draft_lengths)

        self.cachelens = torch.zeros(max_batch_size, dtype=torch.int32, device=self.device)
        self.max_num_pages = max_batch_size * ((self.max_cache_len + self.page_size - 1) // self.page_size)
        self.max_num_pages_per_request = self.max_num_pages // max_batch_size
        self.num_pages_per_request = torch.zeros(max_batch_size, device=self.device, dtype=torch.int32)
        
        self.qo_indptr = torch.arange(max_batch_size+1, dtype=torch.int32, device=self.device)
        self.paged_kv_indptr = torch.arange(max_batch_size+1, dtype=torch.int32, device=self.device)
        self.paged_kv_indices = torch.empty(self.max_num_pages, dtype=torch.int32, device=self.device)
        self.paged_kv_last_page_len = torch.zeros((max_batch_size), dtype=torch.int32, device=self.device)
        self.page_manager = PageManager(max_batch_size, self.max_num_pages_per_request, self.device)

        self.prefill_buffer = torch.empty(3 * 128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        self.draft_buffer = torch.empty(3 * 128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        self.draft_and_verify_buffer = torch.empty(3 * 128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        self.custom_mask = torch.empty(max_batch_size * self.draft_and_verify_len * self.max_cache_len // 8 + 1, dtype=torch.uint8, device=self.device)
        self.mask_indptr = torch.arange(max_batch_size+1, dtype=torch.int32, device=self.device)
        self.attn_wrappers = {
            "prefill": flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.prefill_buffer, "NHD", use_cuda_graph=True,
                                                                      qo_indptr_buf=self.qo_indptr * prefill_chunk_size, 
                                                                      paged_kv_indptr_buf=self.paged_kv_indptr, 
                                                                      paged_kv_indices_buf=self.paged_kv_indices, 
                                                                      paged_kv_last_page_len_buf=self.paged_kv_last_page_len),
            "draft": flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.draft_buffer, "NHD", use_cuda_graph=True,
                                                                      qo_indptr_buf=self.qo_indptr * self.draft_lengths[0],
                                                                      paged_kv_indptr_buf=self.paged_kv_indptr,
                                                                      paged_kv_indices_buf=self.paged_kv_indices,
                                                                      paged_kv_last_page_len_buf=self.paged_kv_last_page_len),
            "draft_and_verify": flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.draft_and_verify_buffer, "NHD", use_cuda_graph=True,
                                                                      qo_indptr_buf=self.qo_indptr * self.draft_and_verify_len, 
                                                                      paged_kv_indptr_buf=self.paged_kv_indptr, 
                                                                      paged_kv_indices_buf=self.paged_kv_indices, 
                                                                      paged_kv_last_page_len_buf=self.paged_kv_last_page_len,
                                                                      custom_mask_buf=self.custom_mask,
                                                                      mask_indptr_buf=self.mask_indptr),
        }
        register_custom_attn_op("mylib::attn_prefill", self.attn_wrappers["prefill"])
        register_custom_attn_op("mylib::attn_draft", self.attn_wrappers["draft"])
        register_custom_attn_op("mylib::attn_draft_and_verify", self.attn_wrappers["draft_and_verify"])
        
        with torch.device(self.device):
            self.model.base_model.setup_caches(num_pages=self.max_num_pages, page_size=self.page_size, use_position_ids=True)


    def _setup_common_attn_masks_and_position_ids(self, draft_lengths) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build non-causal MTP attention mask (bool) and position_ids.

        Sequence (length S = k*(k+1)) consists of k groups:
        group g in [0..k-1]: [x_g, m1, m2, ..., mk]
        index i:
            g(i) = i // (k+1)      # group id
            r(i) = i %  (k+1)      # 0 -> regular, 1..k -> mask id

        Mask rule:
        - row r(i)=0 (regular): allow columns with r(j)=0 and j <= i
        - row r(i)>0 (mask):    allow columns with r(j)=0 and j <= i  (all prev regulars)
                                + same-group masks with 1 <= r(j) <= r(i)
        position_ids:
        pos[i] = g(i) + r(i)
        """
        assert all(draft_length >= 1 for draft_length in draft_lengths)

        if len(draft_lengths) == 1:
            draft_length = draft_lengths[0]
            S = (draft_length + 1) ** 2
            idx = torch.arange(S, device=self.device)

            # Closed-form position ids
            g = torch.div(idx, (draft_length + 1), rounding_mode='floor')
            r = idx % (draft_length + 1)
            position_ids = (g + r).to(torch.long)

            # Vectorized boolean mask
            I, J = idx[:, None], idx[None, :]
            g_i, g_j = g[:, None], g[None, :]
            r_i, r_j = r[:, None], r[None, :]

            allowed_regulars = (r_j == 0) & (J <= I)
            same_group_masks = (g_i == g_j) & (r_j > 0) & (r_j <= r_i)
            attn_mask = allowed_regulars | same_group_masks

            attn_mask_list = [attn_mask]
            position_ids_list = [position_ids]

        else:
            S = (draft_lengths[0] + 1) * (draft_lengths[1] + 1)
            idx = torch.arange(S, device=self.device)

            attn_mask_list, position_ids_list = [], []
            for draft_length in draft_lengths:
                # Closed-form position ids
                g = torch.div(idx, (draft_length + 1), rounding_mode='floor')
                r = idx % (draft_length + 1)
                position_ids = (g + r).to(torch.long)

                # Vectorized boolean mask
                I, J = idx[:, None], idx[None, :]
                g_i, g_j = g[:, None], g[None, :]
                r_i, r_j = r[:, None], r[None, :]

                allowed_regulars = (r_j == 0) & (J <= I)
                same_group_masks = (g_i == g_j) & (r_j > 0) & (r_j <= r_i)
                attn_mask = allowed_regulars | same_group_masks

                attn_mask_list.append(attn_mask)
                position_ids_list.append(position_ids)

        return attn_mask_list, position_ids_list


    def setup_sampling_params(self, temperature=0.0, top_k=0, top_p=0.95):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.greedy = self.temperature == 0.0

    
    def compile(self):
        import torch._dynamo.config
        import torch._inductor.config
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.fx_graph_cache = True
        torch._functorch.config.enable_autograd_cache = True

        self.prefill_forward = torch.compile(self.prefill_forward)
        self.draft_forward = torch.compile(self.draft_forward)
        self.draft_and_verify_forward = torch.compile(self.draft_and_verify_forward)


    @torch.no_grad()
    def interleave_mask_tokens(self, input_ids: torch.LongTensor):
        if len(self.draft_lengths) > 1:
            # Flip the draft length
            self.cnt = (self.cnt + 1) % 2
            
        B, L = input_ids.shape
        D = self.draft_lengths[self.cnt]

        # Output length = L * (D + 1)
        out_len = L * (D + 1)

        # 1) Allocate and fill with mask_token_id
        out_ids = torch.empty((B, out_len), dtype=input_ids.dtype, device=self.device)
        out_ids.fill_(self.mask_token_id)

        # View as [B, L, D+1] and write tokens at slot 0 of each block
        view_ids = out_ids.view(B, L, D + 1)
        view_ids[:, :, 0] = input_ids  # tokens at the first position of each (D+1)-block

        # 2) gate_mask: 1 for masks, 0 for tokens
        gate_mask = torch.ones((B, out_len), dtype=self.dtype, device=self.device)
        view_gate = gate_mask.view(B, L, D + 1)
        view_gate[:, :, 0] = 0  # token slots

        return out_ids, gate_mask[..., None]
        

    def encode(self, input_ids: torch.LongTensor, query_lens: torch.Tensor):
        """Prefill the target model chunk by chunk and sample the first new token.

        The KV cache is cleared first (``clear_kv``). ``input_ids`` is then
        processed in ``self.prefill_chunk_size`` chunks: each processed chunk
        reserves ``chunk_size`` KV slots per request (``pre_encode``), runs the
        causal ``prefill_forward``, records the logits of every request whose
        last real token falls in the chunk, and rolls back the KV slots that
        belong to padding (``delete_kv``).

        Args:
            input_ids: Token ids of shape [bsz, seq_len]; ``seq_len`` must be a
                multiple of ``self.prefill_chunk_size``. Presumably right-padded,
                since only the first ``query_lens[b]`` positions of row b are
                treated as real — confirm with callers.
            query_lens: Number of real (non-pad) tokens per request, shape [bsz].

        Returns:
            The result of ``sample`` on each request's last-real-token logits
            (annotated as [bsz, 1] at the call site).

        Raises:
            AssertionError: If ``seq_len`` is not divisible by the chunk size, or
                if some request's last-token logits were never recorded.
        """
        self.clear_kv()
        logits = None
        bsz, seq_len = input_ids.shape
        assert seq_len % self.prefill_chunk_size == 0, f"The sequence length must be divisible by the prefill chunk size, but got seq_len={seq_len} and prefill_chunk_size={self.prefill_chunk_size}"

        last_logits = None # For lazy initialization
        # Per-request flag: True once that request's last-token logits are stored.
        last_recorded = torch.zeros(bsz, dtype=torch.bool, device=self.device)

        chunk_size = self.prefill_chunk_size
        num_chunks = seq_len // chunk_size

        for i in range(num_chunks):
            chunk_input_ids = input_ids[:, i*chunk_size:(i+1)*chunk_size]
            # Number of real tokens of each request inside this chunk, in [0, chunk_size].
            chunk_query_lens = query_lens - (i * chunk_size)
            chunk_query_lens = torch.clamp(chunk_query_lens, min=0, max=chunk_size)

            # if every query in chunk only has pad tokens, skip the chunk
            if torch.all(chunk_query_lens == 0): continue

            # Target Prefill: positions continue from each request's current cache length.
            position_ids = (self.cachelens[:, None] + torch.arange(chunk_size, device=self.cachelens.device)[None, :]).flatten()
            self.pre_encode(dec_len=chunk_size)
            with torch.inference_mode():
                logits, _ = self.prefill_forward(
                    model=self.model,
                    x=chunk_input_ids,
                    position_ids=position_ids,
                    kv_append_indptr=self.qo_indptr*chunk_size,
                    kv_page_indices=self.paged_kv_indices,
                    kv_page_indptr=self.paged_kv_indptr,
                    kv_page_lastlen=self.paged_kv_last_page_len,
                )

            # Lazy initialization: the logits width (last dim) is only known after the
            # first forward pass (it may differ under TP — confirm). NaN marks "not yet recorded".
            if last_logits is None:
                last_logits = torch.full((bsz, logits.shape[-1]), float('nan'), device=self.device, dtype=self.dtype)

            # Grab the last real token's logits for each request whose sequence ends in this chunk
            target_indices_in_chunk = chunk_query_lens - 1
            finishes_in_this_chunk = (query_lens > i*chunk_size) & (query_lens <= (i+1)*chunk_size)
            target_sequences_mask = finishes_in_this_chunk & (~last_recorded)

            target_batch_indices = torch.where(target_sequences_mask)[0]
            if target_batch_indices.numel() > 0:
                indices_in_chunk_to_grab = target_indices_in_chunk[target_batch_indices]
                last_logits[target_batch_indices] = logits[target_batch_indices, indices_in_chunk_to_grab, :]
                last_recorded[target_batch_indices] = True

            # pre_encode reserved chunk_size slots for every request; drop those that hold padding.
            exists_padding = (chunk_query_lens < chunk_size)
            if exists_padding.any():
                self.delete_kv(chunk_size - chunk_query_lens)

        # NOTE(review): if every chunk was skipped (all query_lens == 0), last_logits is
        # still None here and torch.isnan raises a TypeError rather than this assertion.
        assert not torch.isnan(last_logits).any(), "Found NaN in last_logits."
        return sample(last_logits, top_p=self.top_p, top_k=self.top_k, temperature=self.temperature) # [bsz, 1]
        
    
    def pre_encode(self, dec_len):
        self.insert_kv(dec_len)
        self.attn_wrappers["prefill"].plan(
            qo_indptr=self.qo_indptr*dec_len,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.base_model.config.n_head,
            num_kv_heads=self.model.base_model.config.n_local_heads,
            head_dim_qk=self.model.base_model.config.head_dim,
            page_size=self.page_size,
            q_data_type=self.dtype,
            causal=True,
        )


    def draft_and_verify(self, input_ids: torch.LongTensor, gate_mask: torch.Tensor):
        """
        Regular routine for self-speculative decoding with MTP.

        Assume that k draft tokens are generated in the previous step (x_0, ..., x_k) where k is the draft length.
        Then the input sequence is [x_0, m_1, ..., m_k, x_1, m_1, ..., m_k, ..., x_k, m_1, ..., m_k] where m_i is the <mask> token.
        And the corresponding gate_mask is [0, 1, ..., 1, 0, 1, ..., 1, ..., 0, 1, ..., 1] where the 0's for x_0, x_1, ..., x_k and 1's for m_1, ..., m_k.
        
        Note that in this case, we use non-causal attention mask and custom position_ids.
        The attention mask follows the rules:
            1. Every regular token (x_i) only attends to the previous regular tokens,
            2. Every mask token (m_i) only attends to the previous regular tokens and previous mask tokens within the same block of mask tokens (m_1, ..., m_k).
        
        For example, when k = 2, the input sequence would be [x_0, m_1, m_2, x_1, m_1, m_2] and the gate_mask would be [0, 1, 1, 0, 1, 1].
        Then, the corresponding attention mask would be:
            [[1,0,0,0,0,0],
             [1,1,0,0,0,0],
             [1,1,1,0,0,0],
             [1,0,0,1,0,0],
             [1,0,0,1,1,0],
             [1,0,0,1,1,1]]
        , and the corresponding position_ids would be [0, 1, 2, 1, 2, 3] which can be derived from the attention mask.
        """

        bsz, dec_len = input_ids.shape
        assert dec_len == self.draft_and_verify_len, f"The input sequence length must be equal to the draft_and_verify length, but got dec_len={dec_len} and draft_and_verify_len={self.draft_and_verify_len}"
        
        # prepare position_ids
        position_ids = (self.cachelens[:, None] + self.common_position_ids[self.cnt][None, :]).flatten()

        # model forward for draft_and_verify
        self.pre_draft_and_verify(bsz, dec_len)
        with torch.inference_mode():
            logits, hidden_states = self.draft_and_verify_forward(
                model=self.model, 
                x=input_ids,
                gate_mask=gate_mask,
                position_ids=position_ids,
                kv_append_indptr=self.qo_indptr*dec_len,
                kv_page_indices=self.paged_kv_indices,
                kv_page_indptr=self.paged_kv_indptr,
                kv_page_lastlen=self.paged_kv_last_page_len,
            )

        return logits, hidden_states


    def pre_draft_and_verify(self, bsz, dec_len):
        mask_arr = []
        for i in range(bsz):
            ones_mask = torch.ones((dec_len, self.cachelens[i]), device=self.device)
            mask_i = torch.cat((ones_mask, self.common_attn_masks[self.cnt]), dim=-1)
            mask_arr.append(mask_i.flatten())

        attn_mask = torch.cat(mask_arr, dim=0)
        attn_mask = attn_mask.contiguous().to(device=self.device, dtype=torch.bool)

        self.insert_kv(dec_len)
        self.attn_wrappers["draft_and_verify"].plan(
            qo_indptr=self.qo_indptr*dec_len,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.base_model.config.n_head, 
            num_kv_heads=self.model.base_model.config.n_local_heads, 
            head_dim_qk=self.model.base_model.config.head_dim, 
            page_size=self.page_size, 
            q_data_type=self.dtype, 
            causal=False,
            custom_mask=attn_mask
        )


    def draft(self, input_ids: torch.LongTensor, gate_mask: torch.Tensor):
        """
        First draft after prefill / Fallback draft
        
        Assume that a single token is generated in the previous step (x_0).
        Then the input sequence is [x_0, m_1, ..., m_k] where m_i is the <mask> token, and k is the draft length.
        And the corresponding gate_mask is [0, 1, ..., 1] where the first 0 is for x_0.
        
        Note that in this case, we use causal attention mask and normal position_ids.
        """
        dec_len = input_ids.shape[1]
        assert dec_len == self.draft_lengths[self.cnt] + 1, f"The input sequence length must be equal to the draft length + 1, but got dec_len={dec_len} and draft_length={self.draft_lengths[self.cnt]}"

        # prepare position_ids
        position_ids = (self.cachelens[:, None] + torch.arange(dec_len, device=self.cachelens.device)[None, :]).flatten()

        # model forward for draft
        self.pre_draft(dec_len)
        with torch.inference_mode():
            logits, hidden_states = self.draft_forward(
                model=self.model, 
                x=input_ids,
                gate_mask=gate_mask,
                position_ids=position_ids,
                kv_append_indptr=self.qo_indptr*dec_len,
                kv_page_indices=self.paged_kv_indices,
                kv_page_indptr=self.paged_kv_indptr,
                kv_page_lastlen=self.paged_kv_last_page_len,
            ) # [bsz, dec_len, vocab_size], [bsz, dec_len, hidden_size]
        
        # delete the KV cache entries for the draft(mask) tokens
        self.delete_kv(self.draft_lengths[self.cnt])

        return logits, hidden_states


    def pre_draft(self, dec_len):
        self.insert_kv(dec_len)
        self.attn_wrappers["draft"].plan(
            qo_indptr=self.qo_indptr*dec_len,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.base_model.config.n_head,
            num_kv_heads=self.model.base_model.config.n_local_heads,
            head_dim_qk=self.model.base_model.config.head_dim,
            page_size=self.page_size,
            q_data_type=self.dtype,
            causal=True,
        )


    def sampler_draft(self, next_tokens, draft_hidden_states):
        """
        Draft with sampler.
        Assume that the next_tokens is one generated from the prefill or last accepted token from the evaluate posterior.
        And the draft_hidden_states is obtained by the model forward for draft(mask) tokens proceeding the one that gives the next_tokens.
        
        In other words, in the previous step, the input sequence was either [x_0, m_1, ..., m_k] (prefill) or [x_0, m_1, ..., m_k, ..., x_k, m_1, ..., m_k] (draft_and_verify).
        For the former, the next_tokens is x_1 and the draft_hidden_states is model(m_1, ..., m_k).
        For the latter, for example, when [x_0, x_1, x_2] is accepted, the next_tokens is x_3 and the draft_hidden_states is model(x_2, m_1, ..., m_k)[1:] (drop the first one).

        This function generates the actual draft tokens by sampling from the sampler.

        Args:
            next_tokens (torch.Tensor): The next tokens to be generated. Shape: [bsz, 1]
            draft_hidden_states (torch.Tensor): The hidden states of the draft tokens proceeding the one that gives the next_tokens. Shape: [bsz, draft_length, hidden_size]

        Returns:
            draft_tokens (torch.Tensor): The actual draft tokens. Shape: [bsz, draft_length]
        """

        # placeholder for draft tokens (actually, 1 regular token + draft_length draft tokens)
        draft_length = self.draft_lengths[self.cnt]
        tokens_buffer = torch.zeros((self.batch_size, 1 + draft_length), dtype=torch.long, device=self.device) # [bsz, 1 + draft_length]
        tokens_buffer[:, :1] = next_tokens # [bsz, 1]
        
        # model forward for sampler
        with torch.inference_mode():
            for j in range(draft_length):
                tokens_buffer[:, j+1:j+2] = self.model.sampler_forward(tokens_buffer[:, j:j+1], draft_hidden_states[:, j:j+1, :]) # [bsz, 1, vocab_size]

        return tokens_buffer[:, 1:]


    def evaluate_posterior(self, draft_tokens, target_preds, eot_token):
        """
        Evaluate the posterior for the draft tokens.

        Args:
            draft_tokens (torch.Tensor): The tokens that are generated by the sampler. Shape: [bsz, draft_length]
            target_preds (torch.Tensor)
                - (Greedy) The predicted tokens by the target model. Shape: [bsz, draft_length + 1]
                - (Sampling) The predicted logits by the target model. Shape: [bsz, draft_length + 1]
            eot_token (int): The end-of-text token.

        Returns:
            bonus_tokens (torch.Tensor): The bonus tokens. Shape: [bsz, 1]
            accept_nums (torch.Tensor): The number of accepted tokens. Shape: [bsz, 1]
            eot_accepted (torch.Tensor): The flag of whether the EOT tokens are accepted. Shape: [bsz, 1]
        """
        bsz, draft_length = draft_tokens.shape
        if self.greedy:
            eot_condition = (draft_tokens == eot_token)  # [B, draft_length]
            accept_flags_matrix = target_preds[:, :draft_length] == draft_tokens # [bsz, draft_length]
            accept_nums = accept_flags_matrix.sum(dim=1, keepdim=True) + 1  # [bsz, 1]
            
            bonus_tokens = target_preds.gather(1, accept_nums - 1) # [bsz, 1]
            eot_accepted = (eot_condition & accept_flags_matrix).any(dim=1, keepdim=True) # [bsz, 1]
            return bonus_tokens, accept_nums[:, 0], eot_accepted
        else:
            vocab_size = target_preds.shape[-1]

            target_probs = get_sampling_probs(target_preds, top_p=self.top_p, top_k=self.top_k, temperature=self.temperature) # [bsz, draft_length+1, V]
            draft_probs = torch.zeros((bsz * draft_length, vocab_size), dtype=target_preds.dtype, device=target_preds.device)
            draft_probs.scatter_(1, draft_tokens.reshape(-1, 1), 1.0)
            draft_probs = draft_probs.reshape(bsz, draft_length, vocab_size)

            output_tokens, _, emitted_nums = chain_speculative_sampling(draft_probs, draft_tokens, target_probs)
            accept_nums = emitted_nums + 1 # [bsz, 1]
            
            last_valid_idx = ((output_tokens != -1).to(torch.long) * torch.arange(output_tokens.size(1), device=output_tokens.device)).argmax(dim=1, keepdim=True)
            bonus_tokens = output_tokens.gather(1, last_valid_idx) # [bsz, 1]
            eot_accepted = (eot_token == output_tokens).any(dim=1, keepdim=True) # [bsz, 1]
            return bonus_tokens, accept_nums, eot_accepted


    def collate_accepted_kv_cache(self, accept_nums: torch.Tensor, prev_cachelens: torch.Tensor):
        """
        Collate the accepted KV cache entries based on the accepted numbers.
        Assume that k(k+1) KV entries were inserted to the KV cache from the previous step (draft_and_verify) where k is the draft length.
        Here, we need to 
            1. collate the accepted KV cache entries based on the accepted numbers.
            2. delete the KV cache entries that are (1) not accepted, (2) from the mask tokens.
        
        For example, suppose that the k=3, accepted_nums=[2, 3], prev_cachelens=[10, 15].
        Then, the number of newly inserted KV entries from the previous step is 12.
        In this case, we need to compute the following indices:
            1. save indices : [[10, 14], [15, 19, 23]]
            2. delete indices : [[11, 12, 13, 15, 16, 17, 18, 19, 20, 21], [16, 17, 18, 20, 21, 22, 24, 25, 26]]
        The accepted KV cache entries from save indices will be appended to the back-front of the KV cache.

        Args:
            accept_nums (torch.Tensor): The number of accepted tokens. Shape: [bsz]
            prev_cachelens (torch.Tensor): The number of KV cache entries from the previous step. Shape: [bsz]

        Returns:
            None
        """ 
        assert accept_nums.dim() == 1 and prev_cachelens.dim() == 1, f"The accept_nums and prev_cachelens are expected to be a 1D tensor but got {accept_nums.dim()}D and {prev_cachelens.dim()}D."

        bsz = accept_nums.shape[0]
        n_local_heads, head_dim = self.model.base_model.config.n_local_heads, self.model.base_model.config.head_dim
        draft_length = self.draft_lengths[self.cnt]
        stride = draft_length + 1
        max_accept_len = int(accept_nums.max().item())

        base = torch.arange(max_accept_len, device=self.device, dtype=torch.long) * stride # [max_accept_len]
        src = (prev_cachelens[:, None] + base[None, :]).reshape(-1) # [bsz * max_accept_len]
        dst = (prev_cachelens[:, None] + torch.arange(max_accept_len, device=self.device, dtype=torch.long)[None, :]).reshape(-1) # [bsz * max_accept_len]

        cols = torch.arange(max_accept_len, device=self.device, dtype=torch.long).expand(bsz, -1) # [bsz, max_accept_len]
        valid = (cols < accept_nums[:, None]).reshape(-1) # [bsz * max_accept_len]

        bidx = torch.arange(bsz, device=self.device, dtype=torch.long).repeat_interleave(max_accept_len) # [bsz * max_accept_len]
        bidx, src, dst = bidx[valid], src[valid], dst[valid] # [nnz]

        for layer in self.model.base_model.layers:
            kv = layer.attention.kv_cache.kv_cache
            kv = kv.permute(0, 2, 1, 3, 4) # [num_pages, page_size, 2, n_local_heads, head_dim]
            orig = kv.shape
            kv = kv.reshape(bsz, -1, 2, n_local_heads, head_dim) # [bsz, num_pages * page_size, 2, n_local_heads, head_dim]
            kv[bidx, dst] = kv[bidx, src] # RHS copy → LHS write (overlap-safe)
            kv = kv.reshape(orig).permute(0, 2, 1, 3, 4) # [num_pages, 2, page_size, n_local_heads, head_dim]
            layer.attention.kv_cache.kv_cache = kv

        if len(self.draft_lengths) == 1:
            inserted_len = (draft_length + 1) ** 2
        else:
            inserted_len = (self.draft_lengths[0] + 1) * (self.draft_lengths[1] + 1)
        self.delete_kv(inserted_len - accept_nums)


    def _as_len_tensor(self, lens):
        if isinstance(lens, torch.Tensor):
            t = lens.to(device=self.device, dtype=self.paged_kv_last_page_len.dtype)
        else:
            t = torch.tensor(lens, device=self.device, dtype=self.paged_kv_last_page_len.dtype)
        if t.dim() == 0:
            t = t.expand_as(self.paged_kv_last_page_len)
        return t


    def insert_kv(self, dec_lens):
        dec = self._as_len_tensor(dec_lens)
        if torch.all(dec <= 0):
            return

        old_full = self.num_pages_per_request.clone() - 1
        old_tail = self.paged_kv_last_page_len.clone()
        ps = self.page_size

        total_after = old_full * ps + old_tail + dec
        new_full = torch.where(
            total_after > 0,
            torch.div(total_after - 1, ps, rounding_mode='floor').to(old_full.dtype),
            torch.zeros_like(old_full),
        )
        new_tail = torch.where(
            total_after > 0,
            (((total_after - 1) % ps) + 1).to(old_tail.dtype),
            torch.zeros_like(old_tail),
        )

        add_pages = (new_full - old_full).clamp_min(0).to(torch.int32)
        if add_pages.max().item() > 0:
            self.paged_kv_indptr, self.paged_kv_indices = self.page_manager.allocate_counts(
                add_pages, self.paged_kv_indices, self.paged_kv_indptr
            )
            self.num_pages_per_request += add_pages  # sync with new_full

        self.paged_kv_last_page_len = new_tail
        self.cachelens = (self.cachelens + dec).clamp_min(0)


    def delete_kv(self, del_lens):
        dec = self._as_len_tensor(del_lens)
        if torch.all(dec <= 0):
            return

        old_full = self.num_pages_per_request.clone() - 1
        old_tail = self.paged_kv_last_page_len.clone()
        ps = self.page_size

        total_before = old_full * ps + old_tail
        total_after = (total_before - dec).clamp_min(0)

        new_full = torch.where(
            total_after > 0,
            torch.div(total_after - 1, ps, rounding_mode='floor').to(old_full.dtype),
            torch.zeros_like(old_full),
        )
        new_tail = torch.where(
            total_after > 0,
            (((total_after - 1) % ps) + 1).to(old_tail.dtype),
            torch.zeros_like(old_tail),
        )

        free_pages = (old_full - new_full).clamp_min(0).to(torch.int32)
        if free_pages.max().item() > 0:
            self.paged_kv_indptr, self.paged_kv_indices = self.page_manager.free_counts(
                free_pages, self.paged_kv_indices, self.paged_kv_indptr
            )
            self.num_pages_per_request -= free_pages  # sync with new_full

        self.paged_kv_last_page_len = new_tail
        self.cachelens = (self.cachelens - dec).clamp_min(0)
    

    def clear_kv(self):
        for b in self.model.base_model.layers:
            b.attention.kv_cache.kv_cache.zero_()
        
        self.cachelens = torch.zeros(self.batch_size, dtype=torch.int32, device=self.device)
        self.qo_indptr = torch.arange(self.batch_size+1, dtype=torch.int32, device=self.device)
        
        self.page_manager.reset()
        self.num_pages_per_request = torch.ones((self.batch_size), device=self.device, dtype=torch.int32)
        self.paged_kv_indptr = torch.arange(self.batch_size+1, dtype=torch.int32, device=self.device)
        self.paged_kv_indices = self.page_manager.allocate(torch.arange(self.batch_size, dtype=torch.int32, device=self.device))
        self.paged_kv_last_page_len = torch.zeros((self.batch_size), dtype=torch.int32, device=self.device)