import copy
import json
import time

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig,AutoConfig
from .modeling_llama_kv_target_hybrid import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_llama_kv_draft_hybrid import LlamaForCausalLM as KVLlamaForCausalLM_retrieval
from .utils_hybrid import *
from .kv_cache import initialize_past_key_values
from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download
from .configs import EConfig

from .tree import Tree
from termcolor import colored
from datetime import datetime
from typing import Optional, List, Tuple

from functorch import vmap
import time
import numpy as np
import gc

class SPModel(nn.Module):

    def __init__(
            self,
            base_model,
            base_model_name_or_path,
            draft_model,
    ):
        """Wrap a target (base) model and a draft model for speculative decoding.

        Args:
            base_model: target model. Its ``lm_head.weight`` shape supplies
                ``vocab_size`` (dim 0) and ``hidden_size`` (last dim).
            base_model_name_or_path: path / hub id the tokenizer is loaded from.
            draft_model: model used by ``draft()`` to propose token trees.
        """
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        self.draft_model = draft_model

        # Draft-side KV-cache state. spgenerate() resets these per generation.
        self.draft_stable_kv = None
        self.full_draft_kv = None
        self.use_streamingLLMcache = False
        self.evicted = 0

        # Runtime flags read by forward()/draft(). spgenerate() overwrites every
        # one of them on each call. The defaults below match spgenerate's own
        # defaults and only exist so that forward() does not raise
        # AttributeError when invoked before spgenerate() has run.
        self.use_retrieval_cache = False
        self.retrieval_verbose = False
        self.test_generate_cache = False
        self.show_time = False
        self.measure_time = False
        self.target_use_flash_prefill = False
        self.target_use_hybrid_tree_attn = False
        self.draft_use_flash_prefill = False

    def get_tokenizer(self):
        return self.tokenizer

    @classmethod
    def from_pretrained(
            cls,
            Type="LLaMA",
            base_model_path=None,
            draft_model_path=None,
            **kwargs,
    ):
        """Load the target and draft checkpoints and wrap them in a new instance.

        ``Type`` is accepted for call compatibility but not used here. Every
        extra keyword argument is forwarded unchanged to both underlying
        ``from_pretrained`` calls (target first, then draft).
        """
        target = KVLlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        drafter = KVLlamaForCausalLM_retrieval.from_pretrained(draft_model_path, **kwargs)
        return cls(target, base_model_path, drafter)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            tree_attention_mask=None,
            labels=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
            init=True,
            nodes=None,
            threshold=None,
            max_depth=None,
            logits_processor=None,
            retrieve_attn_scores=False
    ):
        with torch.inference_mode():
            if self.measure_time:
                torch.cuda.synchronize()
                start = time.perf_counter()
            
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                tree_attention_mask=tree_attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                output_attentions=True,
                init=init,
                target_use_flash_prefill = self.target_use_flash_prefill,
                target_use_hybrid_tree_attn = self.target_use_hybrid_tree_attn,
                retrieve_attn_scores=retrieve_attn_scores
            )
            
            if self.measure_time:
                torch.cuda.synchronize()
                time_target_forward = time.perf_counter() - start
                if init:
                    self.time_target_prefill_list.append(time_target_forward)
                else:
                    if retrieve_attn_scores:
                        self.time_target_forward_yes_retrieval_list.append(time_target_forward)
                    else:
                        self.time_target_forward_no_retrieval_list.append(time_target_forward)

            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
            hidden_states = outputs[0].clone()
            
            if self.use_retrieval_cache:
                if retrieve_attn_scores:
                    self.attn_scores = outputs.attentions[-1] # already returns last layer attention only
            
        if init:
            if logits_processor is not None:
                logits = orig[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                token = torch.multinomial(probabilities, 1)
            else:
                token = torch.argmax(orig[:, -1])
                token = token[None, None]
            input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)

            input_ids, position_ids, tree_attention_mask,parent=self.draft(input_ids,nodes,threshold,max_depth)

            return input_ids,position_ids,tree_attention_mask,token,parent, outputs
        else:
            return outputs, orig, hidden_states

    def process_tree_mask(self, tree_attention_mask, init_len):
        attention_mask=torch.full((tree_attention_mask.size(0), init_len), 0, device=tree_attention_mask.device)
        tree_mask = torch.where(tree_attention_mask == 0, torch.finfo(torch.float32).min, 0)
        attention_mask=torch.cat([attention_mask,tree_mask],dim=-1)
        attention_mask = attention_mask[None, None, :, :]
        return attention_mask

    @torch.no_grad()
    def draft(self,input_ids,nodes,threshold,max_depth):
        """Advance the draft model's KV cache and grow a speculative token tree.

        First runs the draft model over the tokens of ``input_ids`` that are not
        yet cached (or over the whole prompt when no draft cache exists). It then
        refreshes the working cache for the active mode: StreamingLLM window,
        retrieval, or plain. Finally it repeatedly feeds ``Tree.update`` with
        draft-model probabilities until the tree reports ``is_final``.

        Args:
            input_ids: ids of the full sequence so far. ``input_ids.shape[1]`` is
                used as the sequence length (spgenerate asserts batch size 1).
            nodes, threshold, max_depth: forwarded to ``Tree``.

        Returns:
            ``(input_ids, position_ids, attention_mask, parent_last)`` of the
            final tree. In the StreamingLLM / retrieval modes the position ids
            are shifted by ``target_model_pos_diff`` to undo the draft cache's
            eviction offset. Returns ``None`` when ``self.test_generate_cache``
            is set.
        """
        if self.measure_time:
            torch.cuda.synchronize()
            t1 = time.perf_counter()
        len_posi = input_ids.shape[1]-1

        ###### Initial forward over the not-yet-cached tokens ######
        if hasattr(self, "draft_stable_kv") and self.draft_stable_kv is not None:
            if self.use_streamingLLMcache or self.use_retrieval_cache:
                # The working cache is a window, so its length is not the number
                # of tokens already processed; total_seq_len (full cache) is.
                full_kv_len = self.total_seq_len
                draft_outputs = self.draft_model.model(
                    input_ids=input_ids[:, full_kv_len:].to(self.draft_model.model.embed_tokens.weight.device),
                    past_key_values=self.draft_stable_kv,
                    return_kv=True,
                    draft_use_flash_prefill = self.draft_use_flash_prefill
                )
            else:
                kv_len = self.draft_stable_kv[0][0].shape[2]
                draft_outputs = self.draft_model.model(
                    input_ids=input_ids[:, kv_len:].to(self.draft_model.model.embed_tokens.weight.device),
                    past_key_values=self.draft_stable_kv,
                    return_kv=True,
                    draft_use_flash_prefill = self.draft_use_flash_prefill
                )
            if self.measure_time:
                torch.cuda.synchronize()
                time_init_forward = time.perf_counter() - t1
                self.time_init_forward_list.append(time_init_forward)

        ###### Prefill ######
        else:
            draft_outputs = self.draft_model.model(
                input_ids=input_ids.to(self.draft_model.model.embed_tokens.weight.device),
                return_kv=True,
                init=True,
                draft_use_flash_prefill = self.draft_use_flash_prefill
            )
            if self.measure_time:
                torch.cuda.synchronize()
                time_draft_prefill = time.perf_counter() - t1
                self.time_draft_prefill_list.append(time_draft_prefill)

        ###### Refresh the working cache ######
        if self.measure_time:
            torch.cuda.synchronize()
            t2 = time.perf_counter()
        if self.use_streamingLLMcache:
            newly_appended_len = input_ids.shape[-1] - self.total_seq_len
            self.update_full_draft_cache(draft_outputs[1], tokens_appended=newly_appended_len)
            self.draft_stable_kv = self.update_working_cache_from_full()

        elif self.use_retrieval_cache:
            if self.measure_time:
                torch.cuda.synchronize()
                start_1 = time.perf_counter()
            newly_appended_len = input_ids.shape[-1] - self.total_seq_len
            self.update_full_draft_cache(draft_outputs[1], tokens_appended=newly_appended_len)
            if self.measure_time:
                torch.cuda.synchronize()
                time_update_cache_1 = time.perf_counter() - start_1

            if self.measure_time:
                torch.cuda.synchronize()
                start_2 = time.perf_counter()
            self.draft_stable_kv = self.update_working_cache_retrieval_main(top_k_chunks=self.retrieve_top_k)
            if self.measure_time:
                torch.cuda.synchronize()
                # Measured from start_2. It used to be start_1, which also counted
                # the full-cache update timed above.
                time_update_cache_2 = time.perf_counter() - start_2

            # if self.show_time:
            #     print(colored(f'update_full_draft_cache: {1000*time_update_cache_1:.3f}','red'))
            #     print(colored(f'update_working_cache_retrieval_main: {1000*time_update_cache_2:.3f}','red'))

            if self.retrieval_verbose:
                self.print_retrieved_chunks()
        else:
            self.draft_stable_kv=draft_outputs[1]

        if self.measure_time:
            torch.cuda.synchronize()
            time_update_cache = time.perf_counter() - t2

        ###### Tree setup ######
        if self.measure_time:
            torch.cuda.synchronize()
            t3 = time.perf_counter()
        past_key_values=self.draft_stable_kv

        # Length of the (possibly windowed) draft KV cache after the initial forward.
        init_len = past_key_values[0][0].size(2)

        # Offset between the draft cache's position space and the target's.
        target_model_pos_diff = len_posi - (init_len - 1)

        last_hidden=draft_outputs[0][:,-1]
        last_headout = self.draft_model.lm_head(last_hidden)

        tree = Tree(nodes, last_hidden.device, threshold, max_depth)

        logits = last_headout.unsqueeze(0)

        # Debug mode: sample autoregressively from the cache instead of drafting.
        if self.test_generate_cache:
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            next_token = torch.multinomial(probabilities, num_samples=1).view(1, -1)
            self.test_autoregressive_generation(next_token, past_key_values, num_tokens=30)
            return

        if self.measure_time:
            torch.cuda.synchronize()
            time_init_tree = time.perf_counter() - t3

            torch.cuda.synchronize()
            t4 = time.perf_counter()

        ###### Grow the tree one level per draft-model step ######
        step = 0
        while True:
            tree_output = tree.update(
                torch.softmax(logits.to(last_hidden.device), dim=-1, dtype=torch.float32))

            input_ids = tree_output["input_ids"].unsqueeze(0)

            if self.use_streamingLLMcache or self.use_retrieval_cache:
                position_ids = tree_output["position_ids"] + init_len-1
            else:
                position_ids = tree_output["position_ids"] + len_posi

            if tree_output["is_final"]:
                break
            tree_attention_mask_with_kv=self.process_tree_mask(tree_output["attention_mask"], init_len)

            if self.measure_time:
                torch.cuda.synchronize()
                t5 = time.perf_counter()
            draft_outputs = self.draft_model.model(
                input_ids=input_ids,
                position_ids=position_ids,
                past_key_values=past_key_values,
                tree_attention_mask=tree_attention_mask_with_kv,
                return_kv=True,
                draft_use_flash_prefill = self.draft_use_flash_prefill
            )
            if self.measure_time:
                torch.cuda.synchronize()
                time_draft_step = time.perf_counter() - t5
                self.time_draft_step_list.append(time_draft_step)

            past_key_values=draft_outputs[1]
            last_hidden = draft_outputs[0]
            last_headout = self.draft_model.lm_head(last_hidden)
            logits = last_headout

            step += 1

        if self.use_streamingLLMcache or self.use_retrieval_cache:
            position_ids += target_model_pos_diff

        if self.measure_time:
            torch.cuda.synchronize()
            time_draft_tree = time.perf_counter() - t4

            self.time_update_cache_list.append(time_update_cache)
            self.time_init_tree_list.append(time_init_tree)
            self.time_draft_tree_list.append(time_draft_tree)
        return input_ids, position_ids, tree_output["attention_mask"], tree_output["parent_last"]

    @torch.no_grad()
    def spgenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=1000,
            nodes=50,
            threshold=0.5,
            max_depth=10,
            output_result_line=False,
            print_input=False,
            verbose = False,
            use_streamingLLM_cache=False,
            sink_size=16,
            recent_size=512,
            use_retrieval_cache=False,
            retrieval_verbose=False,
            retrieval_chunk_size=32,
            retrieve_top_k=15,
            print_draft_tree=False,
            test_generate_cache=False,
            show_time=False,
            measure_time=False,
            retrieve_every_n_steps=16,
            target_use_flash_prefill = False,
            target_use_hybrid_tree_attn = False,
            draft_use_flash_prefill = False
    ):
        """Generate a continuation of ``input_ids`` with tree-based speculative decoding.

        Each iteration has two steps:

        1. The target model scores the current draft tree (``tree_decoding``).
        2. ``verify`` accepts a prefix and returns the next draft tree.

        Generation stops on EOS, once more than ``max_new_tokens`` tokens have
        been accepted, or after ``max_length`` iterations.

        Args (selected):
            input_ids: prompt ids. Only batch size 1 is supported.
            temperature, top_p, top_k: sampling settings.
                ``temperature <= 1e-5`` means greedy decoding.
            nodes, threshold, max_depth: draft-tree budget passed down to ``Tree``.
            use_streamingLLM_cache / use_retrieval_cache: mutually exclusive
                draft KV-cache modes.
            measure_time: record CUDA-synchronized per-phase timings.
            show_time: print timing summaries. This only has data when
                ``measure_time`` is also True.

        Returns:
            ``(input_ids, results)``. ``input_ids`` includes the generated tokens.
            ``results`` holds acceptance / throughput metrics, the tree config,
            and the raw latency lists.
        """
        assert not (use_streamingLLM_cache and use_retrieval_cache), f"Both streamingLLM and retrieval_cache are True. Select one cache type."

        if print_input:
            decoded_text = self.tokenizer.decode(input_ids[0,-50:].tolist(), skip_special_tokens=True)
            print(colored(f'Input: {decoded_text}','blue'))

        tree_config = {
            'max_depth': max_depth,
            'total_nodes': nodes,
            'sp_threshold': threshold
        }
        self.print_draft_tree = print_draft_tree

        self.use_streamingLLMcache = use_streamingLLM_cache
        self.sink_size = sink_size
        self.recent_size = recent_size

        self.use_retrieval_cache = use_retrieval_cache
        self.retrieval_chunk_size = retrieval_chunk_size
        self.retrieve_top_k = retrieve_top_k
        self.retrieval_verbose = retrieval_verbose
        self.retrieve_every_n_steps = retrieve_every_n_steps
        self.num_chunks_old = 0
        self.retrieval_condition = False

        self.attn_scores = None # attention scores returned from the forward pass
        self.attn_scores_final = None # final attention scores used for retrieval (includes newly accepted tokens)

        self.test_generate_cache = test_generate_cache
        self.show_time = show_time
        self.measure_time = measure_time

        self.time_init_forward_list = []
        self.time_target_prefill_list = []
        self.time_draft_prefill_list = []
        self.time_update_cache_list = []
        self.time_init_tree_list = []
        self.time_draft_tree_list = []
        self.time_draft_step_list = []
        self.time_target_forward_no_retrieval_list = []
        self.time_target_forward_yes_retrieval_list = []

        self.timestep = 0

        self.target_use_flash_prefill = target_use_flash_prefill
        self.target_use_hybrid_tree_attn = target_use_hybrid_tree_attn
        self.draft_use_flash_prefill = draft_use_flash_prefill

        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None

        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"

        input_len = input_ids.shape[1]

        # initialize caches
        if self.use_streamingLLMcache or self.use_retrieval_cache:
            self.full_cache_budget = input_len + 500
            self.init_caches()
        else:
            self.draft_stable_kv = None
            self.full_draft_kv = None
            self.evicted = 0

        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(self.base_model, input_len)
        self.past_key_values = past_key_values
        self.past_key_values_data = past_key_values_data
        self.current_length_data = current_length_data

        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()

        start_time = datetime.now()

        # Prefill target model and draft model + initial draft
        draft_input_ids,draft_position_ids,tree_attention_mask,last_token,parent, outputs = self(
            input_ids=input_ids, past_key_values=past_key_values,  output_orig=True,
            nodes=nodes, threshold=threshold, max_depth=max_depth, logits_processor=logits_processor
        )

        # Prepend the freshly sampled token as the tree root: one extra row and
        # column in the mask, with the root visible to every node.
        draft_input_ids=torch.cat([last_token.to(draft_input_ids.device),draft_input_ids],dim=-1)
        draft_position_ids=torch.cat([torch.tensor([draft_position_ids[0]-1],device=draft_position_ids.device), draft_position_ids],dim=-1)
        tree_attention_mask=torch.cat([torch.zeros(1,tree_attention_mask.size(1),dtype=tree_attention_mask.dtype,device=tree_attention_mask.device),tree_attention_mask],dim=0)
        tree_attention_mask = torch.cat([torch.ones(tree_attention_mask.size(0), 1,dtype=tree_attention_mask.dtype,device=tree_attention_mask.device), tree_attention_mask],
                                        dim=1)

        # Kept as a plain int (accept_length may be a tensor) so results are
        # plain Python numbers.
        new_token = 0
        total_tokens_list = []
        accept_length_list = []

        time_TD_list = []
        time_verify_list = []
        for _ in range(max_length):
            assert past_key_values[0][0].shape[2]==draft_position_ids[0]
            if self.measure_time:
                torch.cuda.synchronize()
                start_1 = time.perf_counter()
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                draft_input_ids,
                past_key_values,
                draft_position_ids,
                tree_attention_mask
            )
            if self.measure_time:
                torch.cuda.synchronize()
                time_tree_decoding = time.perf_counter() - start_1
                time_TD_list.append(time_tree_decoding)

            old_len = input_ids.shape[1]

            if self.measure_time:
                torch.cuda.synchronize()
                start_2 = time.perf_counter()
            input_ids, best_candidate, accept_length, draft_input_ids, draft_position_ids, tree_attention_mask, parent=verify(input_ids,
                                                                      logits,
                                                                      draft_input_ids,
                                                                      draft_position_ids,
                                                                      tree_attention_mask,
                                                                      past_key_values_data,
                                                                      current_length_data,
                                                                      parent,
                                                                      self,
                                                                      nodes,
                                                                      threshold,
                                                                      max_depth,
                                                                      logits_processor)
            if self.measure_time:
                torch.cuda.synchronize()
                time_verify = time.perf_counter() - start_2
                time_verify_list.append(time_verify)

            accepted = accept_length.item() if isinstance(accept_length, torch.Tensor) else int(accept_length)
            accept_length_list.append(accepted)

            generated_tokens_list = print_newly_accepted_tokens(old_len, input_ids,
                                                        self.tokenizer, verbose=verbose)
            # +1 for the bonus token produced by the target model at each step.
            new_token += accepted + 1
            total_tokens_list.extend(generated_tokens_list)

            self.timestep += 1

            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > max_new_tokens:
                break

        # Calculate eval metrics (guarded against a zero-iteration run).
        if accept_length_list:
            avg_accept_length = round(sum(accept_length_list)/len(accept_length_list), 3)
        else:
            avg_accept_length = 0.0
        inference_time = (datetime.now() - start_time).total_seconds()

        total_generated = new_token
        tokens_per_sec = round(total_generated/inference_time,2) if inference_time > 0 else 0.0

        def mean_ms(values):
            # np.mean([]) emits a RuntimeWarning; report nan quietly instead.
            return round(float(np.mean(values)) * 1000, 3) if len(values) else float('nan')

        if measure_time:
            avg_tree_decode_time = mean_ms(time_TD_list)
            avg_verify_time = mean_ms(time_verify_list)

            time_target_prefill = mean_ms(self.time_target_prefill_list)
            time_draft_prefill = mean_ms(self.time_draft_prefill_list)
            time_init_forward = mean_ms(self.time_init_forward_list[1:])
            time_update_cache = mean_ms(self.time_update_cache_list[1:])
            time_draft_tree = mean_ms(self.time_draft_tree_list[1:])
            time_draft_step = mean_ms(self.time_draft_step_list[1:]) # exclude very first draft
            time_target_forward_yes_retrieval = mean_ms(self.time_target_forward_yes_retrieval_list[1:]) # exclude very first draft
            time_target_forward_no_retrieval = mean_ms(self.time_target_forward_no_retrieval_list[1:]) # exclude very first draft

        # Timing values only exist when measure_time is set. Previously
        # show_time alone raised NameError here.
        if self.show_time and measure_time:
            print(colored(f'\n\nAvg time Target prefill: {time_target_prefill} ms', 'yellow'))
            print(colored(f'Avg time Draft prefill: {time_draft_prefill} ms', 'yellow'))
            print(colored(f'Avg time Tree Decoding: {avg_tree_decode_time} ms', 'yellow'))
            print(colored(f'Avg time Draft init forward: {time_init_forward} ms', 'yellow'))
            print(colored(f'Avg time Draft tree attn step: {time_draft_step} ms', 'yellow'))
            print(colored(f'Avg time Verify: {avg_verify_time} ms', 'yellow'))
            print(colored(f'\nOther Draft times: \n  time_update_cache:{time_update_cache} ms\n  time_draft_tree:{time_draft_tree} ms\n', 'yellow'))
            print(colored(f'\n\nAvg time Target Forward Pass (yes retrieval): {time_target_forward_yes_retrieval} ms', 'yellow'))
            print(colored(f'Avg time Target Forward Pass (no retrieval): {time_target_forward_no_retrieval} ms', 'yellow'))

        if output_result_line:
            print(colored(
                f"\nGenerated {total_generated} tokens in {inference_time:.2f}s. "
                f"\nToken/sec: {tokens_per_sec}"
                f"\nAverage acceptance length: {avg_accept_length:.3f}",
                'cyan'
            ))

        results = {
            'avg_accept_length': avg_accept_length,
            'total_generated': total_generated,
            'inference_time': inference_time,
            'tokens_per_sec': tokens_per_sec,
            'tree_config': tree_config,
            'latency': {
                'verify': time_verify_list,
                'tree_decode': time_TD_list,
                'target_prefill': self.time_target_prefill_list,
                'draft_prefill': self.time_draft_prefill_list,
                'draft_init_forward': self.time_init_forward_list[1:],
                'draft_update_cache': self.time_update_cache_list[1:],
                'draft_tree': self.time_draft_tree_list[1:],
                'draft_step': self.time_draft_step_list[1:],
                'target_forward_yes_retrieval': self.time_target_forward_yes_retrieval_list[1:],
                'target_forward_no_retrieval': self.time_target_forward_no_retrieval_list[1:]
            }
        }

        return input_ids, results
        
    def warm_up_gpus(self, input_shape, allocate_full_target_cache=None):
        """Run two short throwaway generations on random ids to warm up the GPU.

        Args:
            input_shape: shape of the random prompt tensor. Its first dimension
                must be 1, because ``spgenerate`` asserts batch size 1.
            allocate_full_target_cache: accepted only for backward compatibility
                and ignored. ``spgenerate`` has no such parameter, and
                forwarding it raised ``TypeError``.
        """
        print(colored('Warming up GPUs..','cyan'))
        dummy_input = torch.randint(
            low=0,
            high=self.tokenizer.vocab_size,
            size=input_shape,
            device=self.draft_model.model.embed_tokens.weight.device
        )
        for _ in range(2):
            _ = self.spgenerate(
                dummy_input,
                temperature=0,
                max_new_tokens=3,
                nodes=50,
                threshold=0.7,
                max_depth=3,
                output_result_line=False,
                print_input=False,
                verbose=False,
                use_streamingLLM_cache=True,
                sink_size=16,
                recent_size=512,
                use_retrieval_cache=False,
                retrieval_verbose=False,
                print_draft_tree=False,
                test_generate_cache=False,
                show_time=False,
            )

        torch.cuda.synchronize()
        print(colored('Warmup complete!','cyan'))

    def init_caches(self):
        """
        Preallocate both the full cache and the working cache.
        Both are allocated with a budget dimension of full_cache_budget.
        (We assume that full_cache_budget and working_cache_budget are provided;
        here we assume working_cache_budget == sink_size + recent_size.)
        Initially both caches are empty (filled with zeros), and total_seq_len is 0.
        """
        # Get values from the model config.
        num_hidden_layers = self.draft_model.config.num_hidden_layers
        num_heads = self.draft_model.config.num_attention_heads
        head_dim = self.draft_model.config.hidden_size // num_heads

        # full cache: shape: [layers, batch_size, num_heads, full_cache_budget, head_dim]
        self.full_draft_kv = []
        for _ in range(num_hidden_layers):
            full_K = torch.zeros(
                [1, num_heads, self.full_cache_budget, head_dim],
                dtype=torch.float16,
                device=self.draft_model.device
            )
            full_V = torch.zeros(
                [1, num_heads, self.full_cache_budget, head_dim],
                dtype=torch.float16,
                device=self.draft_model.device
            )
            self.full_draft_kv.append((full_K, full_V))

        self.total_seq_len = 0
        self.seq_len_total_old = 0
        self.evicted = 0
        self.draft_stable_kv = None
        self.chunks = None
        self.cached_chunks = None
        self.draft_model.model.past_key_position_ids = None
        
        self.recent_start = 0
        self.recent_end = 0

    def update_full_draft_cache(self, new_kv: List[Tuple[torch.Tensor, torch.Tensor]], tokens_appended: int):
        """
        Update the full draft KV cache with the new tokens.
        new_kv is the returned KV from the forward pass (a working-cache view).
        tokens_appended is the number of new tokens processed in this forward pass.
        
        The full cache is preallocated with size self.full_cache_budget, and
        self.total_seq_len tracks the current number of tokens stored.
        This function copies the last tokens_appended tokens from new_kv (from the working view)
        into the full cache.
        """
        # Check that we don't exceed the allocated budget.
        if self.total_seq_len + tokens_appended > self.full_cache_budget:
            raise RuntimeError(
                f"Full cache budget exceeded: total_seq_len {self.total_seq_len} + new {tokens_appended} > {self.full_cache_budget}"
            )

        # Precompute destination slice indices.
        dest_start = self.total_seq_len
        dest_end = dest_start + tokens_appended
        device = self.draft_model.device

        # For each layer in the new KV, copy the last tokens_appended tokens into the full cache.
        for i, (new_K, new_V) in enumerate(new_kv):
            full_K, full_V = self.full_draft_kv[i]
            # Ensure new_K and new_V are on the correct device.
            new_K = new_K.to(device, non_blocking=True)
            new_V = new_V.to(device, non_blocking=True)
            # Copy the last tokens_appended tokens from new_K/new_V into the full cache.
            full_K[:, :, dest_start:dest_end, :].copy_(new_K[:, :, -tokens_appended:, :])
            full_V[:, :, dest_start:dest_end, :].copy_(new_V[:, :, -tokens_appended:, :])
        
        # Update the total sequence length.
        self.total_seq_len = dest_end

    def update_working_cache_from_full(self) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Build and return the sliding-window working cache by indexing into the full draft cache.

        The working cache is the concatenation of:
        - the first ``sink_size`` tokens (the "sink" region), and
        - the last ``recent_size`` tokens (the "recent" region).
        If ``total_seq_len <= sink_size + recent_size``, all ``total_seq_len`` tokens are used.
        The returned tensors are fresh copies, so later writes to the full cache do not
        alias them.

        Side effects:
        - ``self.evicted`` is set to the number of tokens outside the working window.
        - The draft model's ``past_key_position_ids`` is extended with consecutive ids
          (this happens right after a prefill whose chunk was shorter than the working
          cache) and then truncated, so that its length equals the working-cache length.
        - ``self.recent_start`` / ``self.recent_end`` record the recent region (inclusive).

        Note: this method does not assign ``self.draft_stable_kv``; the caller stores
        the returned list.

        Returns:
            One ``(K, V)`` pair per layer, sliced along dim 2 (the sequence axis).
        """
        window = self.sink_size + self.recent_size
        keep_everything = self.total_seq_len <= window
        recent_from = self.total_seq_len - self.recent_size

        working_kv = []
        for (full_K, full_V) in self.full_draft_kv:
            if keep_everything:
                # Clone so the working cache does not alias the full cache's storage.
                working_K_layer = full_K[:, :, :self.total_seq_len, :].clone()
                working_V_layer = full_V[:, :, :self.total_seq_len, :].clone()
            else:
                # torch.cat already allocates a new tensor, so the slices need no clone().
                working_K_layer = torch.cat(
                    [full_K[:, :, :self.sink_size, :],
                     full_K[:, :, recent_from:self.total_seq_len, :]],
                    dim=2,
                )
                working_V_layer = torch.cat(
                    [full_V[:, :, :self.sink_size, :],
                     full_V[:, :, recent_from:self.total_seq_len, :]],
                    dim=2,
                )
            working_kv.append((working_K_layer, working_V_layer))

        self.evicted = max(self.total_seq_len - window, 0)
        working_kv_len = working_kv[0][0].shape[2]

        # Resize the draft model's past_key_position_ids to the working-cache length:
        # extend with consecutive ids when it is too short (only right after prefill,
        # when the prefill chunk is smaller than the working cache), then truncate.
        past_ids = self.draft_model.model.past_key_position_ids
        current_len = past_ids.shape[1]
        if current_len < working_kv_len:
            extra_ids = torch.arange(current_len, working_kv_len, device=past_ids.device).unsqueeze(0)
            past_ids = torch.cat([past_ids, extra_ids], dim=1)
        self.draft_model.model.past_key_position_ids = past_ids[:, :working_kv_len]

        self.recent_start = max(0, recent_from)
        self.recent_end = self.total_seq_len - 1
        return working_kv
    
    def print_cache_status(self):
        """Print (in yellow) the total KV length and the sink / recent spans of the working cache."""
        sink_last = self.sink_size - 1
        summary = (f'Total KV Length: {self.total_seq_len}, Sink: 0~{sink_last}, '
                   f'Recent: {self.recent_start}~{self.recent_end}')
        print(colored(summary, 'yellow'))

    @torch.no_grad()
    def parallel_prefill_draft(self, input_ids: torch.Tensor, chunk_size: int = 256):
        """
        Prefill the draft model on all chunks of the prompt in one batched forward pass.

        ``input_ids`` (shape ``[1, L]``) is split into consecutive chunks of
        ``chunk_size`` tokens. The chunks are right-padded to a common length, stacked
        along the batch dimension, and run through the draft model once. No
        ``past_key_values`` are passed, so each chunk is encoded without the context of
        earlier chunks. No attention mask is passed either; this relies on right padding
        plus (presumably) causal attention so that pad tokens cannot influence real
        positions — TODO confirm for the draft model.

        Per layer, the padded positions are dropped from the returned KV tensors and
        the chunks are concatenated back in sequence order. This method does not modify
        any cache; the caller decides what to do with the returned KV cache.

        Args:
            input_ids: Tensor of shape ``[1, L]`` with ``L >= 1``.
            chunk_size: Number of tokens per chunk (must be positive).

        Returns:
            A tuple ``(last_hidden, full_past_key_values)`` where:
            - ``last_hidden`` is the draft model's first output for the last chunk with
              padding removed, shape ``[1, last_chunk_len, hidden_dim]``.
            - ``full_past_key_values`` is a list with one ``(K, V)`` pair per layer,
              each of shape ``[1, num_heads, L, head_dim]``.

        Raises:
            ValueError: if the batch size is not 1, ``L`` is 0, or ``chunk_size`` is
                not positive.
        """
        batch_size, seq_len = input_ids.shape
        if batch_size != 1:
            raise ValueError(f"parallel_prefill_draft expects batch size 1, got {batch_size}")
        if seq_len == 0:
            raise ValueError("parallel_prefill_draft received an empty input sequence")
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        # 1. Split the (single) sequence into consecutive chunks and record their lengths.
        chunks = list(input_ids[0].split(chunk_size))
        chunk_lengths = [chunk.size(0) for chunk in chunks]

        # 2. Right-pad the chunks to a common length -> [num_chunks, max_chunk_len].
        # pad_token_id may exist but be None (common for LLaMA configs); fall back to 0.
        pad_token = getattr(self.draft_model.model.config, "pad_token_id", None)
        if pad_token is None:
            pad_token = 0
        padded_chunks = torch.nn.utils.rnn.pad_sequence(chunks, batch_first=True, padding_value=pad_token)

        # 3. One batched forward pass. draft_outputs[0]: [num_chunks, max_chunk_len, hidden_dim];
        # draft_outputs[1]: per-layer (k, v), each [num_chunks, num_heads, max_chunk_len, head_dim].
        draft_outputs = self.draft_model.model(
            input_ids=padded_chunks,
            return_kv=True
        )
        hidden_states, batched_kv = draft_outputs[0], draft_outputs[1]

        # 4. Per layer: strip each chunk's padding and concatenate along the sequence axis.
        full_past_key_values = []
        for k, v in batched_kv:
            k_cat = torch.cat([k[i, :, :n, :] for i, n in enumerate(chunk_lengths)], dim=1)
            v_cat = torch.cat([v[i, :, :n, :] for i, n in enumerate(chunk_lengths)], dim=1)
            # Restore the batch dimension (B == 1).
            full_past_key_values.append((k_cat.unsqueeze(0), v_cat.unsqueeze(0)))

        # 5. Non-padded hidden states of the final chunk -> [1, last_chunk_len, hidden_dim].
        last_chunk_hidden = hidden_states[-1, :chunk_lengths[-1], :].unsqueeze(0)
        return last_chunk_hidden, full_past_key_values
        
    def prepare_chunks(self):
        """
        Split the full cache into consecutive fixed-size chunks (called once right after prefill).

        ``self.chunks`` becomes a list of ``(chunk_idx, start, end)`` tuples covering
        ``[0, self.total_seq_len)``. Each chunk spans ``self.retrieval_chunk_size``
        tokens except possibly the last one, and chunk ids equal their position in the
        list.

        Also records ``self.seq_len_total_old`` (used later to detect how many tokens
        were appended) and the chunk counts ``self.num_chunks`` / ``self.num_chunks_old``.

        Raises:
            ValueError: if ``self.retrieval_chunk_size`` is not positive; without this
                check the chunking would never terminate.
        """
        chunk_size = self.retrieval_chunk_size
        if chunk_size <= 0:
            raise ValueError(f"retrieval_chunk_size must be positive, got {chunk_size}")
        total = self.total_seq_len
        self.chunks = [
            (chunk_idx, start, min(start + chunk_size, total))
            for chunk_idx, start in enumerate(range(0, total, chunk_size))
        ]
        # Save the current full length so that later we know how many new tokens were appended.
        self.seq_len_total_old = total
        self.num_chunks = len(self.chunks)
        self.num_chunks_old = self.num_chunks

    def update_chunks(self) -> bool:
        """
        Extend ``self.chunks`` to cover tokens appended since the last update.

        ``self.total_seq_len`` is updated externally (by update_full_draft_cache); the
        number of new tokens is ``self.total_seq_len - self.seq_len_total_old``. This
        method:
        1) fills the last chunk (if not already full) with some of the new tokens, and
        2) creates new chunk(s) of ``self.retrieval_chunk_size`` tokens (the last one
           possibly shorter) for any remaining new tokens.
        If no chunks exist yet, it builds them from scratch with ``prepare_chunks``.

        Returns:
            True if the chunk count grew beyond ``self.num_chunks_old`` (the caller
            uses this to add the new chunk to the working cache), otherwise False.
            The early-exit paths also return False.

        Raises:
            ValueError: if new chunks are needed but ``self.retrieval_chunk_size`` is
                not positive; without this check the loop would never terminate.
        """
        new_tokens = self.total_seq_len - self.seq_len_total_old
        if new_tokens <= 0:
            return False  # No new tokens; nothing to do.

        # If there are no chunks yet, create them from scratch.
        if not getattr(self, "chunks", None):
            self.prepare_chunks()
            return False

        chunk_size = self.retrieval_chunk_size
        if chunk_size <= 0:
            raise ValueError(f"retrieval_chunk_size must be positive, got {chunk_size}")

        remaining = new_tokens

        # 1) Top up the last chunk if it still has room.
        last_idx, last_start, last_end = self.chunks[-1]
        capacity = chunk_size - (last_end - last_start)
        if capacity > 0:
            grow = min(capacity, remaining)
            self.chunks[-1] = (last_idx, last_start, last_end + grow)
            remaining -= grow

        # 2) Put any leftover tokens into fresh chunks with consecutive ids.
        next_start = self.total_seq_len - remaining
        next_idx = self.chunks[-1][0] + 1
        while remaining > 0:
            size = min(chunk_size, remaining)
            self.chunks.append((next_idx, next_start, next_start + size))
            next_idx += 1
            next_start += size
            remaining -= size

        # Remember the length we have chunked up to.
        self.seq_len_total_old = self.total_seq_len
        self.num_chunks = len(self.chunks)

        if self.num_chunks > self.num_chunks_old:
            self.num_chunks_old = self.num_chunks
            return True  # new chunk was added => update working cache
        return False

    def update_working_cache_retrieval(self, top_k_chunks: int = 15,
                                       do_retrieval=False,
                                       is_updated_chunks=False) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Rebuild the draft working cache from the chunks currently selected for retrieval.

        Steps:
        1. On the first call, seed ``self.selected_chunks`` with the most recent
           ``self.retrieve_top_k`` chunks (no attention scores exist yet).
        2. If ``do_retrieval`` is set, score every chunk by the mean of
           ``self.attn_scores_final`` over its token range. Keep the ``top_k_chunks``
           highest-scoring chunks, ordered by chunk id. Then reset
           ``retrieval_condition``, ``attn_scores`` and ``attn_scores_final``.
        3. If ``is_updated_chunks`` is set, append the newest chunk to the selection
           (if it is not already selected).
        4. Refresh each selected chunk's bounds from ``self.chunks``. Chunks grow as
           tokens are appended, so stored tuples may be stale — e.g. a chunk that was
           the last one when selected and has since been filled up.
        5. Gather the selected token positions from every layer of
           ``self.full_draft_kv`` into ``self.draft_stable_kv``.
        6. Update ``self.evicted`` and resize the draft model's
           ``past_key_position_ids`` to the new working-cache length.

        Args:
            top_k_chunks: Number of chunks to keep when ``do_retrieval`` is set.
            do_retrieval: Whether to re-select chunks from ``self.attn_scores_final``.
            is_updated_chunks: Whether ``update_chunks`` just created a new chunk.

        Returns:
            The working cache: one ``(K, V)`` pair per layer.

        Raises:
            ValueError: if retrieval is requested with no chunks, or if the selection
                yields no tokens.
        """
        # Initial cache: use recent chunks (we don't have attn scores yet).
        # NOTE(review): when num_init == 0 the slice [-0:] selects *all* chunks.
        if not hasattr(self, "selected_chunks"):
            num_init = min(self.retrieve_top_k, len(self.chunks))
            self.selected_chunks = self.chunks[-num_init:]

        # Only retrieve top-k upon retrieval condition.
        if do_retrieval:
            attn = self.attn_scores_final
            if len(self.chunks) == 0:
                raise ValueError("No chunks available for retrieval.")

            bounds = torch.tensor(
                [[start, end] for (_, start, end) in self.chunks],
                dtype=torch.long,
                device=attn.device,
            )
            starts, ends = bounds[:, 0], bounds[:, 1]

            # Range sums via a prefix sum: sum[start:end] = cum[end-1] - cum[start-1].
            cum_attn = torch.cumsum(attn, dim=0)
            # For start == 0, cum_attn[-1] is read but discarded by torch.where.
            lower = torch.where(starts > 0, cum_attn[starts - 1], torch.zeros_like(starts, dtype=attn.dtype))
            ends_minus_one = torch.clamp(ends - 1, max=cum_attn.size(0) - 1)
            chunk_means = (cum_attn[ends_minus_one] - lower) / (ends - starts).float()

            k = min(top_k_chunks, chunk_means.size(0))
            top_indices = torch.topk(chunk_means, k=k).indices.tolist()
            self.selected_chunks = sorted((self.chunks[i] for i in top_indices), key=lambda c: c[0])

            # Reset retrieval condition and attn scores.
            self.retrieval_condition = False
            self.attn_scores = None
            self.attn_scores_final = None

        # If a new chunk was added, include it automatically.
        if is_updated_chunks:
            newest_chunk = self.chunks[-1]  # (chunk_id, start, end)
            if newest_chunk[0] not in {cid for cid, _, _ in self.selected_chunks}:
                self.selected_chunks.append(newest_chunk)

        if not self.selected_chunks:
            raise ValueError("No tokens retrieved from the full cache. Check your chunk settings and attn_scores_final.")

        # Refresh stored bounds from self.chunks. Chunk ids equal list positions
        # (assigned sequentially by prepare_chunks/update_chunks); the id check keeps
        # the stored tuple if that invariant ever fails.
        refreshed = []
        for cid, start, end in self.selected_chunks:
            current = self.chunks[cid] if 0 <= cid < len(self.chunks) else None
            refreshed.append(current if current is not None and current[0] == cid else (cid, start, end))
        self.selected_chunks = refreshed

        index_ranges = [torch.arange(start, end, dtype=torch.long) for (_, start, end) in self.selected_chunks]
        retrieved_indices = torch.unique(torch.cat(index_ranges), sorted=True).to(self.draft_model.device)
        if retrieved_indices.numel() == 0:
            raise ValueError("No tokens retrieved from the full cache. Check your chunk settings and attn_scores_final.")

        # Build the working cache by gathering the selected positions (dim 2) of each layer.
        working_kv = []
        for (full_K, full_V) in self.full_draft_kv:
            working_K_layer = full_K.index_select(dim=2, index=retrieved_indices)
            working_V_layer = full_V.index_select(dim=2, index=retrieved_indices)
            working_kv.append((working_K_layer, working_V_layer))
        self.draft_stable_kv = working_kv

        # Update evicted count.
        self.evicted = self.total_seq_len - retrieved_indices.numel()

        # Resize past_key_position_ids (shape [1, current_length]) to the working-cache length.
        past_ids = self.draft_model.model.past_key_position_ids
        current_length = past_ids.shape[1]
        target_length = retrieved_indices.numel()
        if current_length < target_length:
            extra_ids = torch.arange(current_length, target_length, device=past_ids.device).unsqueeze(0)
            new_past_ids = torch.cat([past_ids, extra_ids], dim=1)
        else:
            new_past_ids = past_ids[:, :target_length]
        self.draft_model.model.past_key_position_ids = new_past_ids

        # retrieval_condition was already reset above, so key the report on do_retrieval.
        if self.retrieval_verbose and do_retrieval:
            self.print_retrieved_chunks(order="id")
        return working_kv


    def update_working_cache_retrieval_main(self, top_k_chunks: int = 15):
        """
        Refresh the chunk metadata, then rebuild the retrieval-based working cache.

        First calls ``update_chunks`` to cover newly appended tokens. Then it calls
        ``update_working_cache_retrieval``, passing the current ``retrieval_condition``
        (read after the chunk update) as ``do_retrieval`` and the result of
        ``update_chunks`` as ``is_updated_chunks``.

        It assumes ``self.attn_scores_final`` is already set when retrieval is due, and
        that ``self.total_seq_len`` has been updated by ``update_full_draft_cache``.

        When ``self.measure_time`` is set, each phase is bracketed with
        ``torch.cuda.synchronize()`` and timed. The elapsed times are currently not
        reported (the printing is disabled).

        Returns:
            The working cache returned by ``update_working_cache_retrieval``.
        """
        def _synced_clock():
            # Wait for queued CUDA work so the wall-clock reading covers it.
            torch.cuda.synchronize()
            return time.perf_counter()

        chunk_phase_start = _synced_clock() if self.measure_time else None
        chunks_grew = self.update_chunks()
        if self.measure_time:
            elapsed_chunks = _synced_clock() - chunk_phase_start  # noqa: F841 (reporting disabled)

        retrieval_phase_start = _synced_clock() if self.measure_time else None
        working_kv = self.update_working_cache_retrieval(
            top_k_chunks=top_k_chunks,
            do_retrieval=self.retrieval_condition,
            is_updated_chunks=chunks_grew,
        )
        if self.measure_time:
            elapsed_retrieval = _synced_clock() - retrieval_phase_start  # noqa: F841 (reporting disabled)

        return working_kv
        
    def print_retrieval_cache_status(self):
        """
        Print, in yellow, the status of the retrieval-based cache:
        - the total sequence length stored in the full KV cache,
        - the number and list of all fixed-size chunks (``self.chunks``),
        - the currently retrieved chunks: ``self.selected_chunks`` (maintained by
          ``update_working_cache_retrieval``), falling back to ``self.cached_chunks``,
        - the number of evicted tokens, and
        - the working cache length, when a working cache exists.
        """
        print(colored("==== Retrieval Cache Status ====", "yellow"))
        print(colored(f"Total Sequence Length: {self.total_seq_len}", "yellow"))
        print(colored(f"Number of Chunks: {len(self.chunks)}", "yellow"))
        print(colored(f"All Chunks: {self.chunks}", "yellow"))

        # selected_chunks is what the live retrieval path updates; cached_chunks was
        # only ever set by the older implementation.
        retrieved = getattr(self, "selected_chunks", None)
        if retrieved is None:
            retrieved = getattr(self, "cached_chunks", None)
        if retrieved is not None:
            print(colored(f"Retrieved Chunks: {retrieved}", "yellow"))
        else:
            print(colored("Retrieved Chunks: None", "yellow"))

        print(colored(f"Evicted Tokens: {self.evicted}", "yellow"))

        if self.draft_stable_kv is not None:
            print(colored(f'Working cache length: {self.draft_stable_kv[0][0].shape[2]}','yellow'))
    
    def print_retrieved_chunks(self, order="id"):
        """
        Print (in yellow) the ids of the chunks in ``self.selected_chunks``.

        Args:
            order: ``"id"`` prints the ids in ascending chunk-id order. ``"score"`` prints
                them in the order stored in ``self.selected_chunks``. The retrieval
                path sorts that list by chunk id, so the "descending attn score" label
                may not match the actual order — TODO confirm. Any other value prints
                an error message in red and returns.
        """
        if order not in ("score", "id"):
            print(colored(f"Unknown 'order' option: {order}. Choose 'score' or 'id'.", 'red'))
            return

        if order == "id":
            ordered = sorted(self.selected_chunks, key=lambda chunk: chunk[0])
            header = "\nRetrieved chunk IDs: "
        else:
            ordered = self.selected_chunks
            header = "\nRetrieved chunk IDs (descending attn score): "

        id_text = ", ".join([str(chunk[0]) for chunk in ordered])
        print(colored(header + id_text + '\n', 'yellow'))
        
    @torch.no_grad()
    def test_autoregressive_generation(self, next_token, past_key_values, num_tokens=10):
        """
        Debug helper: sample ``num_tokens`` tokens from the draft model and print the text.

        Starting from ``next_token`` and ``past_key_values``, each step:
        1. runs the draft model on the current token,
        2. projects the last hidden state through ``lm_head``, and
        3. samples the next token from the softmax distribution (multinomial
           sampling, not greedy decoding).
        The seed token plus all sampled tokens are decoded with ``self.tokenizer`` and
        printed in red.

        Args:
            next_token: Seed token tensor. ``.item()`` is called on it, so it must hold
                exactly one element (presumably shape ``[1, 1]`` — TODO confirm).
            past_key_values: KV cache for the first forward call (e.g. the draft
                working cache).
            num_tokens: Number of tokens to sample.
        """
        current_past = past_key_values
        generated_tokens = [next_token.item()]

        for _ in range(num_tokens):
            outputs = self.draft_model.model(
                input_ids=next_token,
                past_key_values=current_past,
                return_kv=True
            )
            # outputs[0] holds hidden states (projected through lm_head below);
            # outputs[1] is the updated KV cache.
            last_hidden = outputs[0][:, -1]
            logits = self.draft_model.lm_head(last_hidden).unsqueeze(0)
            current_past = outputs[1]

            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            # Sample (not argmax) the next token; reshape to [1, 1] for the next forward pass.
            next_token = torch.multinomial(probabilities, num_samples=1).view(1, -1)
            generated_tokens.append(next_token.item())

        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print(colored(f"{generated_text}", 'red'))


    # def update_working_cache_retrieval(self, top_k_chunks: int = 15) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    #     """
    #     Build the working cache for retrieval-based caching by indexing into the full cache.
    #     For retrieval-based caching, we do not use a sink region.
    #     We call retrieve_chunks() (which uses self.attn_scores_final to score each chunk) to get
    #     the selected chunks. Then we combine the token indices from these chunks (in sorted order)
    #     and build the working cache KV pairs.
        
    #     Returns:
    #         working_kv (List[Tuple[torch.Tensor, torch.Tensor]]): The working cache KV pairs.
    #     """
    #     # Retrieve the selected chunks.
    #     selected_chunks = self.retrieve_chunks(top_k_chunks=top_k_chunks)
        
    #     # # Combine token indices from selected chunks.
    #     # retrieved_indices = []
    #     # for (_, start, end) in selected_chunks:
    #     #     retrieved_indices.extend(range(start, end))
    #     # retrieved_indices = sorted(set(retrieved_indices))

    #     # Combine token indices from selected chunks.
    #     chunk_ranges = torch.cat([torch.arange(start, end, device='cuda') 
    #                             for _, start, end in selected_chunks])
    #     retrieved_indices = torch.unique(chunk_ranges).sort().values

    #     if len(retrieved_indices) == 0:
    #         raise ValueError("No tokens retrieved from the full cache. Check your chunk settings and attn_scores_final.")
        
    #     # Build working cache by indexing into each layer's full cache.
    #     working_kv = []
    #     for (full_K, full_V) in self.full_draft_kv:
    #         # working_K_layer = full_K[:, :, retrieved_indices, :].clone()
    #         # working_V_layer = full_V[:, :, retrieved_indices, :].clone()

    #         full_K = full_K.contiguous()
    #         full_V = full_V.contiguous()
    #         working_K_layer = full_K.index_select(2, retrieved_indices)
    #         working_V_layer = full_V.index_select(2, retrieved_indices)

    #         working_kv.append((working_K_layer, working_V_layer))
        
    #     # Update the working cache.
    #     self.draft_stable_kv = working_kv
        
    #     # Update the evicted count (number of tokens not retrieved).
    #     self.evicted = self.total_seq_len - len(retrieved_indices)
        
    #     # Optionally, store the selected chunks for debugging.
    #     self.cached_chunks = selected_chunks

    #     past_ids = self.draft_model.model.past_key_position_ids
    #     current_length = past_ids.shape[1]
    #     target_length = len(retrieved_indices)
        
    #     # truncate Draft model's past_key_position_ids accordingly
    #     # occurs if prefill chunk size is less than working retrieval cache size
    #     if current_length < target_length:
    #         # Create additional indices from current_length up to target_length - 1
    #         extra_ids = torch.arange(current_length, target_length, device=past_ids.device).unsqueeze(0)
    #         # Concatenate along the second dimension to fill up to target_length
    #         new_past_ids = torch.cat([past_ids, extra_ids], dim=1)
    #     else:
    #         # Truncate if current_length is greater than or equal to target_length
    #         new_past_ids = past_ids[:, :target_length]
    #     self.draft_model.model.past_key_position_ids = new_past_ids
    #     return working_kv

    @torch.no_grad()
    def streaming_prefill_draft(self, input_ids: torch.Tensor, chunk_size: int = 256):
        """
        Prefill the draft model chunk by chunk, keeping the full and working caches in sync.

        For each slice of ``chunk_size`` tokens:
        - the draft model runs on the slice, with the current working cache as
          ``past_key_values`` after the first slice;
        - ``update_full_draft_cache`` copies the last ``tokens_appended`` entries of the
          returned KV (``out[1]``) into the full cache;
        - ``update_working_cache_from_full`` rebuilds the working cache from the full
          cache.

        Args:
            input_ids: 2-D token ids, typically ``[1, L]``. Slicing is along dim 1.
            chunk_size: Number of tokens per forward pass (must be positive).

        Returns:
            ``(out, past_key_values)``: the draft model output for the last slice and
            the working cache rebuilt after it. Both are ``None`` when ``L == 0``.

        Raises:
            ValueError: if ``chunk_size`` is not positive; without this check the loop
                would never advance.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        device = self.draft_model.model.embed_tokens.weight.device
        input_ids = input_ids.to(device, non_blocking=True)
        _, seq_len = input_ids.shape

        out = None
        past_key_values = None
        for start in range(0, seq_len, chunk_size):
            chunk_ids = input_ids[:, start:start + chunk_size]
            # On the first slice no cache exists yet, so past_key_values is omitted entirely.
            cache_kwargs = {} if past_key_values is None else {"past_key_values": past_key_values}
            out = self.draft_model.model(
                input_ids=chunk_ids,
                use_cache=True,
                return_kv=True,
                **cache_kwargs,
            )
            # Append only this slice's new tokens to the full cache, then rebuild the
            # working cache by indexing into the full cache.
            self.update_full_draft_cache(out[1], tokens_appended=chunk_ids.shape[-1])
            past_key_values = self.update_working_cache_from_full()

        return out, past_key_values