"""
An ensemble of multiple standard transformers LLM models with automatic KV-cache projection. It exposes the same interface as a standard transformers LLM model.
"""

from typing import List, Optional, Union
import torch
import copy
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
import json

from rosetta.model.projector import Projector
from rosetta.model.aggregator import Aggregator
from rosetta.model.sampling import sample_token
from transformers.utils import ModelOutput
try:
    from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except Exception:
    GreedySearchDecoderOnlyOutput = None
    SampleDecoderOnlyOutput = None

def clone_kv_cache(kv_cache: DynamicCache) -> DynamicCache:
    """Return a detached copy of ``kv_cache``.

    Each entry of ``kv_cache.key_cache`` / ``kv_cache.value_cache`` is cloned
    and detached into a fresh ``DynamicCache``, so in-place writes to the copy
    do not touch the tensors held by the original cache.
    """
    duplicate = DynamicCache()
    layer_pairs = zip(kv_cache.key_cache, kv_cache.value_cache)
    for layer_key, layer_value in layer_pairs:
        duplicate.key_cache.append(layer_key.clone().detach())
        duplicate.value_cache.append(layer_value.clone().detach())
    return duplicate

class RosettaModel(nn.Module):
    """
    Drop in replacement for the standard transformers LLM models, like Qwen3ForCausalLM
    """
    def __init__(self, model_list: List[PreTrainedModel], base_model_idx = 0, projector_list: List[Projector] = [], aggregator_list: List[Aggregator] = [], include_response: bool = False):
        super().__init__()
        # model list: a list of model, model 0 by default is the base model
        # projector list: a list of projector
        # standard init with additional model list parameter
        # kv-cache dict: key (source_model_idx, target_model_idx), value (Cache), assume only convert at prefill with one type of model
        # projector dict: key (source_model_idx, target_model_idx) value dict(key (source_model_layer_idx, M_target value )

        self.base_model_idx = base_model_idx
        self.model_list = nn.ModuleList(model_list)

        device = model_list[base_model_idx].device
        dtype = model_list[base_model_idx].dtype
        self.projector_list = nn.ModuleList(projector_list).to(device=device, dtype=dtype)
        self.aggregator_list = nn.ModuleList(aggregator_list).to(device=device, dtype=dtype)

        self.projector_dict = {}
        self.aggregator_dict = {}
        self.kv_cache_dict = {}
        self._generation_hook_handlers = []

        self.include_response = include_response

    @property
    def device(self):
        return self.model_list[self.base_model_idx].device
    
    def to(self, device):
        """
        Move the RosettaModel and all underlying models and projectors to the specified device.
        """
        super().to(device)
        for model in self.model_list:
            model.to(device)
        for projector in self.projector_list:
            projector.to(device)
        for aggregator in self.aggregator_list:
            aggregator.to(device)
        return self
        
    # set projector 
    def set_projector_config(self, 
                        source_model_idx: int, 
                        source_model_layer_idx: int, 
                        target_model_idx: int,
                        target_model_layer_idx: int, 
                        projector_idx: int):
        """
        Set the projector configuration
        Args:
            source_model_idx: int, the index of the source model
            source_model_layer_idx: int, the index of the source model layer
            target_model_idx: int, the index of the target model
            target_model_layer_idx: int, the index of the target model layer
            projector_idx: int, the index of the projector

        The projector dict structure supports multiple projectors per target layer.
        Structure:
        {
            target_model_idx: {
                source_model_idx: {
                    target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...]
                }
            }
        }
        Repeated calls for the same (target, source, target_layer) append additional pairs.
        """

        if target_model_idx not in self.projector_dict.keys():
            self.projector_dict[target_model_idx] = {}
        if source_model_idx not in self.projector_dict[target_model_idx].keys():
            self.projector_dict[target_model_idx][source_model_idx] = {}
        # Accumulate list of (source_layer, projector_idx) for this target layer
        layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx)
        if layer_entry is None:
            self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)]
        else:
            layer_entry.append((source_model_layer_idx, projector_idx))


    def load_projector(self, projector_list):
        self.projector_list: List[Projector] = projector_list

    def load_aggregator(self, aggregator_list):
        self.aggregator_list: List[Aggregator] = aggregator_list


    def get_projector(self, 
                        source_model_idx, 
                        source_model_layer_idx, 
                        target_model_idx,
                        target_model_layer_idx):
        pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx]
        if len(pair_list) == 0:
            raise ValueError("No projector configured for the given target layer")
        # Prefer exact source layer match
        for src_layer, projector_id in pair_list:
            if src_layer == source_model_layer_idx:
                return self.projector_list[projector_id]
        # Fallback: return the first projector
        return self.projector_list[pair_list[0][1]]

    def set_aggregator_idx(self,
                           source_model_idx: int,
                           target_model_idx: int,
                           target_model_layer_idx: int,
                           aggregator_idx: int):
        if target_model_idx not in self.aggregator_dict:
            self.aggregator_dict[target_model_idx] = {}
        if source_model_idx not in self.aggregator_dict[target_model_idx]:
            self.aggregator_dict[target_model_idx][source_model_idx] = {}
        self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx


    @staticmethod
    def load_json(file_name):
        with open(file_name, "r") as f:
            result = json.load(f)
        return result
    
    @staticmethod
    def _convert_dict_keys_to_ints(obj):
        """
        Recursively convert dictionary keys that look like integers back to int.
        This reverses json.dump's coercion of dict keys to strings.
        """
        if isinstance(obj, dict):
            new_obj = {}
            for key, value in obj.items():
                if isinstance(key, str) and key.lstrip('-').isdigit():
                    new_key = int(key)
                else:
                    new_key = key
                new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value)
            return new_obj
        if isinstance(obj, list):
            return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj]
        return obj
    
    
    def save_projector_config(self, file_name):
        with open(file_name, "w") as f:
            json.dump(self.projector_dict, f)

    
    def load_projector_config(self, config_path):
        if config_path.endswith(".json"):
            loaded = RosettaModel.load_json(config_path)
            self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded)

    def save_aggregator_config(self, file_name):
        with open(file_name, "w") as f:
            json.dump(self.aggregator_dict, f)

    def load_aggregator_config(self, config_path):
        if config_path.endswith(".json"):
            loaded = RosettaModel.load_json(config_path)
            self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded)


    def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache):
        if target_model_idx not in self.kv_cache_dict.keys():
            self.kv_cache_dict[target_model_idx] = {}
        if cache is None:
            # Initialize with a DynamicCache instead of RosettaCache for now
            self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here
        else:
            self.kv_cache_dict[target_model_idx][source_model_idx] = cache

    def make_k_proj_hook(new_k_cache):
        def k_proj_hook(module, input, output):
            updated_k_cache = output.clone()
            batch_size, seq_len, dim = updated_k_cache.shape

            reshaped_cache = new_k_cache.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
            updated_k_cache[:, 0:seq_len, :] = reshaped_cache
        
            return updated_k_cache
        return k_proj_hook

    def make_v_proj_hook(new_v_cache):
        def v_proj_hook(module, input, output):
            updated_v_cache = output.clone()
            batch_size, seq_len, dim = updated_v_cache.shape

            reshaped_cache = new_v_cache.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
            updated_v_cache[:, 0:seq_len, :] = reshaped_cache

            return updated_v_cache
        return v_proj_hook
    
    def register_hooks(self, input_ids, attention_mask, position_ids, base_kv_cache, source_model_idx, source_kv_cache):
        """
        Pre-compute projected KV entries for ``input_ids`` and install hooks that inject them.

        Steps performed:
            1. Clone ``base_kv_cache`` and ``source_kv_cache`` and run the base
               model and the source model on ``input_ids`` with those clones.
            2. For every target layer configured in
               ``projector_dict[base_model_idx][source_model_idx]``, take the last
               ``input_ids.shape[1]`` positions (dim 2) of the resulting caches,
               project the source slices with the configured projectors,
               optionally aggregate them, and write the result back into the
               base output cache.
            3. Register forward hooks on ``k_proj`` / ``v_proj`` of every base-model
               layer that return those cache slices in place of the module output.

        Args:
            input_ids: token ids fed to both models; ``shape[1]`` is the number of new positions
            attention_mask: forwarded unchanged to both models
            position_ids: forwarded unchanged to both models
            base_kv_cache: existing cache of the base model (cloned, not passed to the model directly)
            source_model_idx: index of the source model in ``self.model_list``
            source_kv_cache: existing cache of the source model (cloned, not passed to the model directly)

        Returns:
            List of ``(handler_k, handler_v)`` hook handles, one pair per base-model
            layer; pass it to ``remove_hooks`` to uninstall them.
        """

        # Work on copies so the forward passes below do not extend the caller's caches.
        base_kv_copy = clone_kv_cache(base_kv_cache)
        source_kv_copy = clone_kv_cache(source_kv_cache)

        # Number of positions contributed by this call.
        new_length = input_ids.shape[1]

        base_output_kv_cache = self.model_list[self.base_model_idx].forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask, 
                    position_ids=position_ids,
                    past_key_values=base_kv_copy,
                    labels=None,
                    use_cache=True, 
                ).past_key_values
        source_output_kv_cache = self.model_list[source_model_idx].forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask, 
                    position_ids=position_ids,
                    past_key_values=source_kv_copy,
                    labels=None,
                    use_cache=True, 
                ).past_key_values
        
        # entry is the list of (source_layer_idx, projector_idx) pairs for this target layer.
        for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
            base_key_cache, base_value_cache = base_output_kv_cache[target_layer_idx]
            # Keep only the newly appended positions (last new_length along dim 2).
            new_base_key_cache = base_key_cache[:, :, -new_length:, :]
            new_base_value_cache = base_value_cache[:, :, -new_length:, :]
            new_base_kv_cache = (new_base_key_cache, new_base_value_cache)

            pair_list = entry

            projected_kv_list = []
            source_kv_list = []
            for source_model_layer_idx, projector_idx in pair_list:
                source_key_cache, source_value_cache = source_output_kv_cache[source_model_layer_idx]
                new_source_key_cache = source_key_cache[:, :, -new_length:, :]
                new_source_value_cache = source_value_cache[:, :, -new_length:, :]
                new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                projected_key, projected_value = self.projector_list[projector_idx].forward(
                    new_source_kv_cache, # tuple of (key, value); 4-D tensors sliced along dim 2 above
                    new_base_kv_cache
                )
                projected_kv_list.append((projected_key, projected_value))
                source_kv_list.append(new_source_kv_cache)

            # Aggregate (fallback to first projector if no aggregator is available)
            use_aggregator = (
                len(projected_kv_list) > 1 and
                len(self.aggregator_list) > 0 and
                self.base_model_idx in self.aggregator_dict and
                source_model_idx in self.aggregator_dict[self.base_model_idx] and
                target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
            )

            if use_aggregator:
                aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
                agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
                    source_kv_list,
                    new_base_kv_cache,
                    projected_kv_list
                )
            else:
                # Fallback to first projector result when no aggregator is available
                agg_key, agg_value = projected_kv_list[0]

            # Update cache: overwrite the new positions of this layer in place.
            base_output_kv_cache.key_cache[target_layer_idx][:, :, -new_length:, :] = agg_key
            base_output_kv_cache.value_cache[target_layer_idx][:, :, -new_length:, :] = agg_value

        hook_handlers = []

        # Hooks are installed on every base-model layer, including layers without a
        # projector; those receive the slices produced by the base forward pass above.
        for i in range(self.model_list[self.base_model_idx].config.num_hidden_layers):
            handler_k = self.model_list[self.base_model_idx].model.layers[i].self_attn.k_proj.register_forward_hook(
                RosettaModel.make_k_proj_hook(base_output_kv_cache.key_cache[i][:, :, -new_length:, :])
            )
            handler_v = self.model_list[self.base_model_idx].model.layers[i].self_attn.v_proj.register_forward_hook(
                RosettaModel.make_v_proj_hook(base_output_kv_cache.value_cache[i][:, :, -new_length:, :])
            )
            hook_handlers.append((handler_k, handler_v))

        return hook_handlers
    
    def remove_hooks(self, hook_handlers):
        """Uninstall every ``(key_handle, value_handle)`` pair in ``hook_handlers``."""
        for handle_pair in hook_handlers:
            key_handle, value_handle = handle_pair
            key_handle.remove()
            value_handle.remove()

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, 
                                    inputs_embeds=None, cache_position=None, position_ids=None, 
                                    use_cache=True, **kwargs):
        """
        Custom prepare_inputs_for_generation to handle KV cache projection during generation
        """
        # Get prepare inputs from base model's original method
        original_prepare_inputs = getattr(self, '_original_prepare_inputs', None)
        if original_prepare_inputs is not None:
            # Use the correct parameter format for transformers' prepare_inputs_for_generation
            base_model_inputs = original_prepare_inputs(
                input_ids, 
                past_key_values=past_key_values, 
                attention_mask=attention_mask, 
                inputs_embeds=inputs_embeds, 
                use_cache=use_cache, 
                cache_position=cache_position,
                **kwargs
            )
        else:
            # Fallback: create basic inputs manually
            base_model_inputs = {
                'input_ids': input_ids,
                'past_key_values': past_key_values,
                'attention_mask': attention_mask,
                'position_ids': position_ids,
                'use_cache': use_cache
            }
        
        # If we're in generation mode and have projectors configured, update caches and set up hooks
        if (past_key_values is not None and 
            hasattr(self, '_in_generation') and self._in_generation and
            self.base_model_idx in self.projector_dict and
            len(self.projector_dict[self.base_model_idx]) > 0):
            
            # Clean up any existing hooks first
            self._cleanup_generation_hooks()
            
            # Update source model caches for the new input_ids
            self._update_generation_source_cache(input_ids, attention_mask, position_ids)
            
            # Update base model cache and compute projections
            updated_base_cache = self._update_base_cache_and_compute_projections(
                input_ids, attention_mask, position_ids, past_key_values
            )
            
            # Set up hooks with pre-computed projections
            self._setup_generation_hooks_with_projections()
        
        return base_model_inputs
    
    def _setup_generation_hooks_with_projections(self):
        """
        Set up hooks for generation process using pre-computed projections
        """
        if not hasattr(self, '_generation_projections'):
            return
            
        # Set up hooks for each layer that has projections
        for target_layer_idx, (projected_key, projected_value) in self._generation_projections.items():
            # Create hooks to override the k_proj and v_proj outputs for this layer
            def make_k_override_hook(proj_k):
                def k_hook(module, input, output):
                    # Replace the output with projected key
                    batch_size, seq_len, dim = output.shape
                    projected_reshaped = proj_k.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
                    return projected_reshaped
                return k_hook
            
            def make_v_override_hook(proj_v):
                def v_hook(module, input, output):
                    # Replace the output with projected value
                    batch_size, seq_len, dim = output.shape
                    projected_reshaped = proj_v.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
                    return projected_reshaped
                return v_hook
            
            # Register hooks
            k_handler = self.model_list[self.base_model_idx].model.layers[target_layer_idx].self_attn.k_proj.register_forward_hook(
                make_k_override_hook(projected_key)
            )
            v_handler = self.model_list[self.base_model_idx].model.layers[target_layer_idx].self_attn.v_proj.register_forward_hook(
                make_v_override_hook(projected_value)
            )
            self._generation_hook_handlers.extend([k_handler, v_handler])

    def _cleanup_generation_hooks(self):
        """
        Clean up generation hooks and projections
        """
        for handler in self._generation_hook_handlers:
            handler.remove()
        self._generation_hook_handlers = []
        
        # Clean up projections cache
        if hasattr(self, '_generation_projections'):
            self._generation_projections.clear()

    def _update_generation_source_cache(self, input_ids, attention_mask, position_ids):
        """
        Update source model caches during generation
        """
        if not hasattr(self, '_generation_source_cache'):
            self._generation_source_cache = {}
        

        if input_ids.shape[1] > 1:

            last_token_ids = input_ids[:, -1:]
            last_position_ids = position_ids[:, -1:] if position_ids is not None else None
        else:
            last_token_ids = input_ids
            last_position_ids = position_ids
            
        for source_model_idx in self.projector_dict[self.base_model_idx].keys():
            if source_model_idx not in self._generation_source_cache:
                self._generation_source_cache[source_model_idx] = None
                
            # Forward through source model to get updated cache
            source_output = self.model_list[source_model_idx].forward(
                input_ids=last_token_ids,
                attention_mask=attention_mask,  
                position_ids=last_position_ids,
                past_key_values=self._generation_source_cache[source_model_idx],
                use_cache=True
            )
            self._generation_source_cache[source_model_idx] = source_output.past_key_values

    def _update_base_cache_and_compute_projections(self, input_ids, attention_mask, position_ids, past_key_values):
        """
        Update base model cache and compute projections for the new token
        """
        
        if hasattr(self, '_generation_base_cache_clean'):
            clean_base_cache = self._generation_base_cache_clean
        else:
            clean_base_cache = past_key_values

       
        if input_ids.shape[1] > 1:    
            last_token_ids = input_ids[:, -1:]
            last_position_ids = position_ids[:, -1:] if position_ids is not None else None
        else:
            last_token_ids = input_ids
            last_position_ids = position_ids
        
        # Forward through base model using clean cache
        base_output = self.model_list[self.base_model_idx].forward(
            input_ids=last_token_ids,
            attention_mask=attention_mask,  
            position_ids=last_position_ids,
            past_key_values=clean_base_cache,
            use_cache=True
        )
        updated_base_cache = base_output.past_key_values
        

        self._generation_base_cache_clean = updated_base_cache
        
        # Store projections for hooks
        if not hasattr(self, '_generation_projections'):
            self._generation_projections = {}
        
        # Compute projections for each configured layer
        for source_model_idx in self.projector_dict[self.base_model_idx].keys():
            if source_model_idx in self._generation_source_cache:
                source_kv_cache = self._generation_source_cache[source_model_idx]
                
                for target_layer_idx, (source_model_layer_idx, projector_idx) in self.projector_dict[self.base_model_idx][source_model_idx].items():
                    source_key_cache, source_value_cache = source_kv_cache[source_model_layer_idx]
                    base_key_cache, base_value_cache = updated_base_cache[target_layer_idx]

                    # Get the new token's KV (last position in the cache)
                    new_source_key_cache = source_key_cache[:, :, -1:, :]
                    new_source_value_cache = source_value_cache[:, :, -1:, :]
                    new_base_key_cache = base_key_cache[:, :, -1:, :]
                    new_base_value_cache = base_value_cache[:, :, -1:, :]

                    new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                    new_base_kv_cache = (new_base_key_cache, new_base_value_cache)

                    # Apply projection
                    projected_key, projected_value = self.projector_list[projector_idx].forward(
                        new_source_kv_cache,
                        new_base_kv_cache
                    )
                    
                    # Store projections for this layer
                    self._generation_projections[target_layer_idx] = (projected_key, projected_value)
        
        return updated_base_cache

    def forward(
        self,
        kv_cache_index: Optional[List] = None,
        input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
        attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        # **kwargs: Unpack[KwargsForCausalLM],
        *args,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass
        KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index

        If input_ids is LongTensor, default to same input ids for different models
        If input_ids is Tuple, default to different input ids for different models.

        No Rosetta: (-1, 0)
        """
        
        # noqa
        self.kv_cache_dict = dict()

        # Handle different input formats: if input_ids is a list, use per-model inputs
        if isinstance(input_ids, list):
            # Use list format: different input_ids and attention_mask for each model
            base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
            base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
            _, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
        else:
            # Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
            base_input_ids = input_ids
            base_attention_mask = attention_mask
            _, seqlen = input_ids.size() if input_ids is not None else (0, 0)

        num_sections = len(kv_cache_index) if kv_cache_index is not None else 1

        section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
        section_starts = [0]
        for l in section_lengths:
            section_starts.append(section_starts[-1] + l)
        
        curr_base_kv_cache = past_key_values

        if seqlen >= 1:
            for i in range(num_sections):
                start = section_starts[i]
                end = section_starts[i + 1]
                prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
                prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
                prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
                prefill_labels = labels[:, start:end] if labels is not None else None

                if i == num_sections - 1 and self.include_response:
                    hook_handlers = self.register_hooks(input_ids=input_ids[:, start:end], attention_mask=attention_mask[:, :end], position_ids=position_ids[:, start:end],
                                                        base_kv_cache=self.kv_cache_dict[self.base_model_idx][self.base_model_idx],
                                                        source_model_idx=1, 
                                                        source_kv_cache=self.kv_cache_dict[self.base_model_idx][1])

                # calculate target model kvcache
                output = self.model_list[self.base_model_idx].forward(
                    input_ids=prefill_input_ids,
                    attention_mask=prefill_attention_mask, 
                    position_ids=prefill_position_ids,
                    past_key_values=curr_base_kv_cache,
                    labels=prefill_labels,
                    use_cache=use_cache, 
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    *args,
                    **kwargs
                )

                if self.base_model_idx not in self.kv_cache_dict:
                    self.kv_cache_dict[self.base_model_idx] = {}
                if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
                    self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
                self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values

                curr_base_kv_cache: DynamicCache = output.past_key_values

                if i == num_sections - 1 and self.include_response:
                    self.remove_hooks(hook_handlers)
                
                # if i != num_sections - 1:
                for source_model_idx in range(1, len(self.model_list)):
                    if self.base_model_idx not in self.kv_cache_dict:
                        self.kv_cache_dict[self.base_model_idx] = {}
                    if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
                        self.kv_cache_dict[self.base_model_idx][source_model_idx] = None

                    # Get model-specific input_ids and attention_mask
                    if isinstance(input_ids, list):
                        source_input_ids = input_ids[source_model_idx]
                        source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
                        source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
                        source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
                    else:
                        # Backward compatibility: use same input for all models
                        source_prefill_input_ids = prefill_input_ids
                        source_prefill_attention_mask = prefill_attention_mask

                    curr_source_kv_cache = self.model_list[source_model_idx].forward(
                        input_ids=source_prefill_input_ids,
                        attention_mask=source_prefill_attention_mask,
                        position_ids=prefill_position_ids,
                        past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
                        use_cache=use_cache, 
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                        *args,
                        **kwargs
                    ).past_key_values
                    self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache

                # calculate source model kvcache and apply projections
                if self.base_model_idx in self.projector_dict:
                    source_model_idx = kv_cache_index[i][0][0][0].item()  # Get the source model index from the kv_cache_index
                    if source_model_idx != -1:
                        for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
                            base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
                            new_base_key_cache = base_key_cache[:, :, start:end, :]
                            new_base_value_cache = base_value_cache[:, :, start:end, :]
                            new_base_kv_cache = (new_base_key_cache, new_base_value_cache)

                            pair_list = entry

                            projected_kv_list = []
                            source_kv_list = []
                            for source_model_layer_idx, projector_idx in pair_list:
                                source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
                                new_source_key_cache = source_key_cache[:, :, start:end, :]
                                new_source_value_cache = source_value_cache[:, :, start:end, :]
                                new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                                projected_key, projected_value = self.projector_list[projector_idx].forward(
                                    new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
                                    new_base_kv_cache
                                )
                                projected_kv_list.append((projected_key, projected_value))
                                source_kv_list.append(new_source_kv_cache)

                            # Aggregate (fallback to first projector if no aggregator is available)
                            use_aggregator = (
                                len(projected_kv_list) > 1 and
                                len(self.aggregator_list) > 0 and
                                self.base_model_idx in self.aggregator_dict and
                                source_model_idx in self.aggregator_dict[self.base_model_idx] and
                                target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
                            )

                            if use_aggregator:
                                aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
                                agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
                                    source_kv_list,
                                    new_base_kv_cache,
                                    projected_kv_list
                                )
                            else:
                                # Fallback to first projector result when no aggregator is available
                                agg_key, agg_value = projected_kv_list[0]

                            # Update cache with aggregated result
                            curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
                            curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
                        
                        output.past_key_values = curr_base_kv_cache
                                                                             
        # use base model for decode phase
        else:
            # Handle list input format for decode phase as well
            decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
            decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
            
            output = self.model_list[self.base_model_idx].forward(
                input_ids=decode_input_ids,
                attention_mask=decode_attention_mask,
                position_ids=position_ids,
                past_key_values=curr_base_kv_cache,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                *args,
                **kwargs
            )

        return output
    
    def oracle_forward(
        self,
        kv_cache_index: Optional[List] = None,
        input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
        attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        *args,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Sectioned forward pass that also accumulates a projector MSE loss.

        When the base input holds more than one token, the prompt is processed
        section by section: the base model and every model at index >= 1 run on
        the section, then each configured projector maps the source KV slice of
        that section and the (optionally aggregated) result is written into the
        base cache in place. For every projector call, the MSE between the
        projected KV and the source KV slice is added to the returned loss.
        Otherwise (at most one token) a single base-model decode step is run.

        Args:
            kv_cache_index: One tensor per section. Dimension 1 is used as the
                section length, and element ``[0][0][0]`` as the source model
                index for the whole section (``-1`` disables projection).
                Original note: shape ``(B, sec_seq_len, 2)`` holding
                (source, target) model indices, "No Rosetta" being ``(-1, 0)``.
                If ``None``, the whole input is one section and no projection
                is applied.
            input_ids: A tensor shared by all models, or a list with one tensor
                per model (indexed by model index).
            attention_mask: Same tensor-or-list convention as ``input_ids``.
            position_ids: Sliced per section and passed to every model.
            past_key_values: Initial base-model cache.
            inputs_embeds: Only forwarded in the decode branch.
            labels: Sliced per section for the base model (prefill) or passed
                through unchanged (decode).
            use_cache, output_attentions, output_hidden_states: Forwarded to
                the underlying models.
            cache_position: Only forwarded in the decode branch.
            logits_to_keep: Accepted for interface compatibility; not used.
            *args, **kwargs: Forwarded to the underlying models.

        Returns:
            A tuple ``(output, loss_output)``: the base-model output of the
            last section (or of the decode step) and the accumulated loss,
            which stays the integer ``0`` if no projector ran.
            NOTE(review): the return annotation predates the tuple return.
        """
        # The cache registry is rebuilt from scratch on every call.
        self.kv_cache_dict = dict()

        # A list means per-model inputs; a tensor is shared by all models.
        if isinstance(input_ids, list):
            base_input_ids = input_ids[self.base_model_idx]
            base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
        else:
            base_input_ids = input_ids
            base_attention_mask = attention_mask
        seqlen = base_input_ids.size(1) if base_input_ids is not None else 0

        if kv_cache_index is not None:
            section_lengths = [section.shape[1] for section in kv_cache_index]
        else:
            section_lengths = [seqlen]
        num_sections = len(section_lengths)

        # section_starts[i]:section_starts[i + 1] is the token span of section i.
        section_starts = [0]
        for length in section_lengths:
            section_starts.append(section_starts[-1] + length)

        curr_base_kv_cache = past_key_values

        loss = nn.MSELoss()
        loss_output = 0

        if seqlen <= 1:
            # Decode phase: only the base model runs.
            decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
            decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask

            output = self.model_list[self.base_model_idx].forward(
                input_ids=decode_input_ids,
                attention_mask=decode_attention_mask,
                position_ids=position_ids,
                past_key_values=curr_base_kv_cache,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                *args,
                **kwargs
            )
            return output, loss_output

        for i in range(num_sections):
            start = section_starts[i]
            end = section_starts[i + 1]
            hook_this_section = self.include_response and i == num_sections - 1

            prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
            # The mask covers everything up to `end` because earlier sections live in the cache.
            prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
            prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
            prefill_labels = labels[:, start:end] if labels is not None else None

            if hook_this_section:
                # NOTE(review): this path slices the raw arguments, so it needs tensor
                # (non-list) inputs and explicit position_ids. It also reads caches that
                # exist only after an earlier section has run, and hard-codes model 1 as
                # the source.
                hook_handlers = self.register_hooks(
                    input_ids=input_ids[:, start:end],
                    attention_mask=attention_mask[:, :end],
                    position_ids=position_ids[:, start:end],
                    base_kv_cache=self.kv_cache_dict[self.base_model_idx][self.base_model_idx],
                    source_model_idx=1,
                    source_kv_cache=self.kv_cache_dict[self.base_model_idx][1],
                )

            # Base (target) model KV cache for this section.
            output = self.model_list[self.base_model_idx].forward(
                input_ids=prefill_input_ids,
                attention_mask=prefill_attention_mask,
                position_ids=prefill_position_ids,
                past_key_values=curr_base_kv_cache,
                labels=prefill_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                *args,
                **kwargs
            )

            # Caches are stored as kv_cache_dict[base_model_idx][model_idx].
            caches_for_base = self.kv_cache_dict.setdefault(self.base_model_idx, {})
            caches_for_base[self.base_model_idx] = output.past_key_values
            curr_base_kv_cache = output.past_key_values

            if hook_this_section:
                self.remove_hooks(hook_handlers)

            # Source model KV caches for this section.
            # NOTE(review): starting at index 1 assumes the base model sits at index 0.
            for source_model_idx in range(1, len(self.model_list)):
                caches_for_base.setdefault(source_model_idx, None)

                if isinstance(input_ids, list):
                    source_input_ids = input_ids[source_model_idx]
                    source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
                    source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
                    source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
                else:
                    # Backward compatibility: every model sees the same input.
                    source_prefill_input_ids = prefill_input_ids
                    source_prefill_attention_mask = prefill_attention_mask

                caches_for_base[source_model_idx] = self.model_list[source_model_idx].forward(
                    input_ids=source_prefill_input_ids,
                    attention_mask=source_prefill_attention_mask,
                    position_ids=prefill_position_ids,
                    past_key_values=caches_for_base[source_model_idx],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    *args,
                    **kwargs
                ).past_key_values

            # Projection needs a per-section source index; without kv_cache_index there
            # is nothing to look up, so the section is left unprojected.
            if kv_cache_index is None or self.base_model_idx not in self.projector_dict:
                continue

            # The first entry of the section selects the source model for the whole section.
            source_model_idx = kv_cache_index[i][0][0][0].item()
            if source_model_idx == -1:
                continue

            source_cache = caches_for_base[source_model_idx]
            for target_layer_idx, pair_list in self.projector_dict[self.base_model_idx][source_model_idx].items():
                base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
                new_base_kv_cache = (
                    base_key_cache[:, :, start:end, :],
                    base_value_cache[:, :, start:end, :],
                )

                projected_kv_list = []
                source_kv_list = []
                for source_model_layer_idx, projector_idx in pair_list:
                    source_key_cache, source_value_cache = source_cache[source_model_layer_idx]
                    new_source_key_cache = source_key_cache[:, :, start:end, :]
                    new_source_value_cache = source_value_cache[:, :, start:end, :]
                    new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                    projected_key, projected_value = self.projector_list[projector_idx].forward(
                        new_source_kv_cache,
                        new_base_kv_cache
                    )
                    # Oracle objective: MSE between the projected KV and the source KV slice.
                    loss_output = loss_output + loss(
                        torch.dstack([projected_key, projected_value]),
                        torch.dstack([new_source_key_cache, new_source_value_cache]),
                    )
                    projected_kv_list.append((projected_key, projected_value))
                    source_kv_list.append(new_source_kv_cache)

                # Aggregate only when there are several projections and an aggregator
                # is registered for this (base, source, layer) triple.
                use_aggregator = (
                    len(projected_kv_list) > 1 and
                    len(self.aggregator_list) > 0 and
                    self.base_model_idx in self.aggregator_dict and
                    source_model_idx in self.aggregator_dict[self.base_model_idx] and
                    target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
                )

                if use_aggregator:
                    aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
                    agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
                        source_kv_list,
                        new_base_kv_cache,
                        projected_kv_list
                    )
                else:
                    # Fall back to the first projector result.
                    agg_key, agg_value = projected_kv_list[0]

                # Overwrite this section of the base cache in place.
                curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
                curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value

            output.past_key_values = curr_base_kv_cache

        return output, loss_output
    

    @torch.no_grad()
    def old_generate(self, 
                 kv_cache_index,
                 input_ids, 
                 past_key_values=None,
                 attention_mask=None,
                 position_ids=None,
                 use_cache=True,
                 *args,
                 **kwargs):
        """
        Legacy generation path that delegates decoding to the base model's ``generate``.

        When projectors are configured for the base model and the prompt has
        more than one token, all tokens but the last are prefilled through
        ``self.forward`` (which builds and projects the KV caches). The
        resulting cache plus the last token are then handed to the base model's
        own ``generate``. With ``include_response`` set, the base model's
        ``prepare_inputs_for_generation`` is temporarily replaced by this
        module's version and restored afterwards.

        Args:
            kv_cache_index: List of per-section index tensors. Dimension 1 of
                each tensor is treated as the section length.
            input_ids: A tensor shared by all models, or a list with one tensor
                per model.
            past_key_values: Optional initial cache; replaced by the prefill
                cache when a prefill is run.
            attention_mask: Same tensor-or-list convention as ``input_ids``.
            position_ids: Used for the prefill only; the base model's
                ``generate`` is always called with ``position_ids=None``.
            use_cache: Forwarded to the base model's ``generate``.
            *args, **kwargs: Forwarded to the base model's ``generate``.

        Returns:
            Whatever the base model's ``generate`` returns.
        """
        # Set generation mode flag
        self._in_generation = True

        try:
            # A list means per-model inputs; a tensor is shared by all models.
            if isinstance(input_ids, list):
                base_input_ids = input_ids[self.base_model_idx]
                base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
            else:
                base_input_ids = input_ids
                base_attention_mask = attention_mask
            input_shape = base_input_ids.shape if base_input_ids is not None else (0, 0)

            if self.base_model_idx in self.projector_dict:
                self._generation_source_cache = {}
                # Keeps a clean (unprojected) copy of the base cache.
                self._generation_base_cache_clean = None

                if input_shape[1] > 1:
                    # Prefill phase: every token except the last builds the initial cache.
                    if isinstance(input_ids, list):
                        prefill_input_ids = [ids[:, :-1] for ids in input_ids]
                        prefill_attention_mask = [mask[:, :-1] for mask in attention_mask] if attention_mask is not None else None
                    else:
                        prefill_input_ids = base_input_ids[:, :-1]
                        prefill_attention_mask = base_attention_mask[:, :-1] if base_attention_mask is not None else None
                    prefill_position_ids = position_ids[:, :-1] if position_ids is not None else None

                    # Drop the last token from the final section so the index matches the
                    # prefill input. The final section is found by position, not identity:
                    # the same tensor object may legitimately appear in several sections.
                    last_section = len(kv_cache_index) - 1
                    prefill_kv_cache_index = []
                    for section_pos, section_kv_cache_index in enumerate(kv_cache_index):
                        if section_pos != last_section:
                            prefill_kv_cache_index.append(section_kv_cache_index)
                        elif section_kv_cache_index.shape[1] > 1:
                            prefill_kv_cache_index.append(section_kv_cache_index[:, :-1, :])
                        # A one-token final section disappears from the prefill entirely.

                    output = self.forward(
                        kv_cache_index=prefill_kv_cache_index,
                        input_ids=prefill_input_ids,
                        attention_mask=prefill_attention_mask,
                        position_ids=prefill_position_ids,
                        use_cache=True
                    )

                    for source_model_idx in self.projector_dict[self.base_model_idx]:
                        self._generation_source_cache[source_model_idx] = self.kv_cache_dict[self.base_model_idx][source_model_idx]

                    self._generation_base_cache_clean = clone_kv_cache(self.kv_cache_dict[self.base_model_idx][self.base_model_idx])

                    past_key_values = output.past_key_values

                    # Only the last token is fed to the actual generation call.
                    past_length = input_shape[1] - 1
                    if isinstance(input_ids, list):
                        input_ids = [ids[:, -1:] for ids in input_ids]
                        base_input_ids = input_ids[self.base_model_idx]
                    else:
                        # Backward compatibility
                        input_ids = base_input_ids[:, -1:]
                        base_input_ids = input_ids

                    if position_ids is not None:
                        position_ids = position_ids[:, -1:]
                else:
                    # Single-token input: no prefill; snapshot any cache supplied by the caller.
                    if past_key_values is not None:
                        self._generation_base_cache_clean = clone_kv_cache(past_key_values)
                    past_length = 0
            else:
                # No projectors: the base inputs resolved above are used unchanged.
                past_length = max(base_input_ids.shape[1] - 1, 0)

            # Override the base model's prepare_inputs_for_generation method temporarily
            if self.include_response:
                original_prepare_inputs = self.model_list[self.base_model_idx].prepare_inputs_for_generation
                # Stored on self so the custom method can use it and `finally` can restore it.
                self._original_prepare_inputs = original_prepare_inputs
                self.model_list[self.base_model_idx].prepare_inputs_for_generation = self.prepare_inputs_for_generation

            # Call the original generate method using base model's inputs
            cache_position = torch.arange(past_length, past_length + 1, device=base_input_ids.device)

            output = self.model_list[self.base_model_idx].generate(
                input_ids=base_input_ids,
                past_key_values=past_key_values,
                attention_mask=base_attention_mask,
                position_ids=None,
                use_cache=use_cache,
                cache_position=cache_position,
                *args,
                **kwargs
            )

            # Restore original method
            if self.include_response:
                self.model_list[self.base_model_idx].prepare_inputs_for_generation = original_prepare_inputs

            return output

        finally:
            # Cleanup
            self._in_generation = False
            self._cleanup_generation_hooks()
            for transient_attr in ('_generation_source_cache', '_generation_projections', '_generation_base_cache_clean'):
                if hasattr(self, transient_attr):
                    delattr(self, transient_attr)
            # Ensure prepare_inputs_for_generation is restored even if an exception occurred
            if hasattr(self, '_original_prepare_inputs'):
                self.model_list[self.base_model_idx].prepare_inputs_for_generation = self._original_prepare_inputs
                delattr(self, '_original_prepare_inputs')

    @torch.no_grad()
    def generate(
        self,
        kv_cache_index,
        input_ids,
        max_new_tokens: Optional[int] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        pad_token_id: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        repetition_penalty: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        do_sample: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        max_length: Optional[int] = None,
        use_cache: bool = True,
        *args,
        **kwargs,
    ):
        """
        Generation loop that does not use the base model's ``generate``.

        The prompt is prefilled with ``self.forward`` and then decoded one
        token at a time, again through ``self.forward``. Tokens are drawn with
        ``rosetta.model.sampling.sample_token``.

        Args:
            kv_cache_index: Per-section index tensors for the prefill call.
            input_ids: A tensor shared by all models, or a list with one tensor
                per model; the base model's entry is the stream being extended.
            max_new_tokens: Number of tokens to generate. If ``None`` it is
                derived from ``max_length`` minus the prompt length, floored at 0.
            past_key_values: Optional cache forwarded to the prefill call.
            attention_mask: Same tensor-or-list convention as ``input_ids``.
                If the base mask is missing, an all-ones mask is used for the
                decode steps.
            position_ids: Forwarded to the prefill call only.
            eos_token_id: Token id(s) that end a sequence. Defaults to the base
                model's ``generation_config`` (or ``config``) value.
            pad_token_id: Fill token for sequences that have already ended.
                Defaults to the config value, then to the (first) EOS id.
            temperature, top_p, top_k: Passed to ``sample_token``; the
                temperature is forced to 0.0 unless ``do_sample`` is true.
            repetition_penalty: Positive logits of already-seen tokens are
                divided by it, non-positive ones multiplied by it.
            presence_penalty: Subtracted from the logit of every seen token.
            frequency_penalty: Subtracted once per occurrence of a token.
            do_sample: ``None`` is treated as ``False``.
            return_dict_in_generate: Return an output object instead of a tensor.
            output_scores: With ``return_dict_in_generate``, also return the
                unpenalized per-step logits.
            max_length: Only used when ``max_new_tokens`` is ``None``.
            use_cache: Forwarded to the prefill call.
            *args, **kwargs: Forwarded to every ``self.forward`` call.

        Returns:
            A tensor ``[batch, prompt_len + generated_len]`` for the base model
            stream, or an output object holding ``sequences`` (and optionally
            ``scores``) when ``return_dict_in_generate`` is set.

        Raises:
            ValueError: If neither ``max_new_tokens`` nor ``max_length`` is
                given, or if ``max_new_tokens`` is negative.
        """
        base_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
        prompt_len = base_input_ids.size(1)

        # Default eos/pad from the base model's generation config (or config) if not provided
        base_model = self.model_list[self.base_model_idx]
        gen_cfg = getattr(base_model, "generation_config", None)
        cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None)
        if eos_token_id is None and cfg_obj is not None:
            eos_token_id = getattr(cfg_obj, "eos_token_id", None)
        if pad_token_id is None and cfg_obj is not None:
            pad_token_id = getattr(cfg_obj, "pad_token_id", None)
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0]

        # Derive the number of tokens to generate.
        if max_new_tokens is None:
            if max_length is None:
                raise ValueError("Provide max_new_tokens or max_length")
            max_new_tokens = max(max_length - prompt_len, 0)
        if max_new_tokens < 0:
            raise ValueError("max_new_tokens must be non-negative")

        if isinstance(input_ids, list):
            base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
        else:
            base_attention_mask = attention_mask
        if base_attention_mask is None:
            base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device)

        batch_size = base_input_ids.size(0)

        # Prefill to build caches and obtain initial logits
        prefill_output = self.forward(
            kv_cache_index=kv_cache_index,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            *args,
            **kwargs,
        )

        current_past = prefill_output.past_key_values
        all_input_ids = base_input_ids
        current_attention_mask = base_attention_mask

        # EOS handling setup
        eos_set = None
        if eos_token_id is not None:
            eos_set = set(eos_token_id if isinstance(eos_token_id, list) else [eos_token_id])
        finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device)

        # Start from last prefill logits
        last_logits = prefill_output.logits[:, -1, :]

        # Greedy decoding is expressed as temperature 0.
        if do_sample is None:
            do_sample = False
        effective_temperature = temperature if do_sample else 0.0

        # Optional scores collection
        collect_scores = bool(return_dict_in_generate) and bool(output_scores)
        scores = []

        penalties_active = (
            (repetition_penalty is not None and repetition_penalty != 1.0) or
            (presence_penalty is not None and presence_penalty != 0.0) or
            (frequency_penalty is not None and frequency_penalty != 0.0)
        )

        # Every decode step is a one-token section: source model 1 when responses are
        # projected too, otherwise -1 (no projection). Shape (1, 1, 2), built once.
        decode_source_idx = 1 if self.include_response else -1
        decode_section_index = torch.tensor(
            [[[decode_source_idx, 0]]], dtype=torch.long, device=all_input_ids.device
        )

        for _ in range(max_new_tokens):
            if collect_scores:
                scores.append(last_logits)

            # Apply repetition/presence/frequency penalties to logits before sampling
            adjusted_logits = last_logits
            if penalties_active:
                adjusted_logits = last_logits.clone()
                vocab_size = adjusted_logits.size(-1)
                # Counts cover the whole running sequence, prompt included.
                for b in range(batch_size):
                    seq_tokens = all_input_ids[b]
                    if seq_tokens.numel() == 0:
                        continue
                    # bincount yields integer counts; cast so they combine with the logits.
                    counts = torch.bincount(seq_tokens, minlength=vocab_size).to(adjusted_logits.dtype)
                    seen_mask = counts > 0
                    # Presence penalty: penalize any token that has appeared
                    if presence_penalty and seen_mask.any():
                        adjusted_logits[b, seen_mask] = adjusted_logits[b, seen_mask] - presence_penalty
                    # Frequency penalty: penalize proportionally to frequency
                    if frequency_penalty:
                        adjusted_logits[b] = adjusted_logits[b] - frequency_penalty * counts
                    # Repetition penalty (HF-style): divide positive logits, multiply the rest
                    if repetition_penalty and repetition_penalty != 1.0 and seen_mask.any():
                        pos_mask = seen_mask & (adjusted_logits[b] > 0)
                        neg_mask = seen_mask & ~pos_mask
                        if pos_mask.any():
                            adjusted_logits[b, pos_mask] = adjusted_logits[b, pos_mask] / repetition_penalty
                        if neg_mask.any():
                            adjusted_logits[b, neg_mask] = adjusted_logits[b, neg_mask] * repetition_penalty

            # Sample next token
            next_token = sample_token(adjusted_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k)
            if not isinstance(next_token, torch.Tensor):
                next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size)

            # Apply EOS logic
            if eos_set is not None:
                # Pad only sequences that ended on an EARLIER step, so the EOS token
                # sampled now stays in the output (as in HF generate) even when
                # pad_token_id differs from the EOS id.
                if pad_token_id is not None:
                    next_token = torch.where(
                        finished,
                        torch.tensor(pad_token_id, device=next_token.device, dtype=next_token.dtype),
                        next_token,
                    )
                for eid in eos_set:
                    finished = finished | (next_token == eid)

            # Append sampled token
            next_token_unsqueezed = next_token.unsqueeze(1)
            all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1)
            current_attention_mask = torch.cat(
                [
                    current_attention_mask,
                    torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype),
                ],
                dim=1,
            )

            # Early stop if all sequences finished
            if eos_set is not None and torch.all(finished):
                break

            # Decode one step using cached states; pass base-stream tensors
            decode_output = self.forward(
                kv_cache_index=[decode_section_index],
                input_ids=next_token_unsqueezed,
                attention_mask=current_attention_mask,
                position_ids=None,
                past_key_values=current_past,
                use_cache=True,
                *args,
                **kwargs,
            )
            current_past = decode_output.past_key_values
            last_logits = decode_output.logits[:, -1, :]

        # Return style compatible with HF generate
        if not return_dict_in_generate:
            return all_input_ids

        if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None:
            output_cls = SampleDecoderOnlyOutput if do_sample else GreedySearchDecoderOnlyOutput
            return output_cls(
                sequences=all_input_ids,
                scores=scores if collect_scores else None,
            )

        # Fallback to generic ModelOutput
        result = {"sequences": all_input_ids}
        if collect_scores:
            result["scores"] = scores
        return ModelOutput(**result)

class OracleRosettaModel(nn.Module):
    """
    Drop in replacement for the standard transformers LLM models, like Qwen3ForCausalLM
    """
    def __init__(self, model_list: List[PreTrainedModel], base_model_idx = 0, projector_list: List[Projector] = [], aggregator_list: List[Aggregator] = [], include_response: bool = False):
        super().__init__()
        # model list: a list of model, model 0 by default is the base model
        # projector list: a list of projector
        # standard init with additional model list parameter
        # kv-cache dict: key (source_model_idx, target_model_idx), value (Cache), assume only convert at prefill with one type of model
        # projector dict: key (source_model_idx, target_model_idx) value dict(key (source_model_layer_idx, M_target value )

        self.base_model_idx = base_model_idx
        self.model_list = nn.ModuleList(model_list)

        device = model_list[base_model_idx].device
        dtype = model_list[base_model_idx].dtype
        self.projector_list = nn.ModuleList(projector_list).to(device=device, dtype=dtype)
        self.aggregator_list = nn.ModuleList(aggregator_list).to(device=device, dtype=dtype)

        self.projector_dict = {}
        self.aggregator_dict = {}
        self.kv_cache_dict = {}
        self._generation_hook_handlers = []

        self.include_response = include_response

    @property
    def device(self):
        return self.model_list[self.base_model_idx].device
    
    def to(self, device):
        """
        Move the RosettaModel and all underlying models and projectors to the specified device.
        """
        super().to(device)
        for model in self.model_list:
            model.to(device)
        for projector in self.projector_list:
            projector.to(device)
        for aggregator in self.aggregator_list:
            aggregator.to(device)
        return self
        
    # set projector 
    def set_projector_config(self, 
                        source_model_idx: int, 
                        source_model_layer_idx: int, 
                        target_model_idx: int,
                        target_model_layer_idx: int, 
                        projector_idx: int):
        """
        Set the projector configuration
        Args:
            source_model_idx: int, the index of the source model
            source_model_layer_idx: int, the index of the source model layer
            target_model_idx: int, the index of the target model
            target_model_layer_idx: int, the index of the target model layer
            projector_idx: int, the index of the projector

        The projector dict structure supports multiple projectors per target layer.
        Structure:
        {
            target_model_idx: {
                source_model_idx: {
                    target_model_layer_idx: [(source_model_layer_idx, projector_idx), ...]
                }
            }
        }
        Repeated calls for the same (target, source, target_layer) append additional pairs.
        """

        if target_model_idx not in self.projector_dict.keys():
            self.projector_dict[target_model_idx] = {}
        if source_model_idx not in self.projector_dict[target_model_idx].keys():
            self.projector_dict[target_model_idx][source_model_idx] = {}
        # Accumulate list of (source_layer, projector_idx) for this target layer
        layer_entry = self.projector_dict[target_model_idx][source_model_idx].get(target_model_layer_idx)
        if layer_entry is None:
            self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx] = [(source_model_layer_idx, projector_idx)]
        else:
            layer_entry.append((source_model_layer_idx, projector_idx))


    def load_projector(self, projector_list):
        self.projector_list: List[Projector] = projector_list

    def load_aggregator(self, aggregator_list):
        self.aggregator_list: List[Aggregator] = aggregator_list


    def get_projector(self, 
                        source_model_idx, 
                        source_model_layer_idx, 
                        target_model_idx,
                        target_model_layer_idx):
        pair_list = self.projector_dict[target_model_idx][source_model_idx][target_model_layer_idx]
        if len(pair_list) == 0:
            raise ValueError("No projector configured for the given target layer")
        # Prefer exact source layer match
        for src_layer, projector_id in pair_list:
            if src_layer == source_model_layer_idx:
                return self.projector_list[projector_id]
        # Fallback: return the first projector
        return self.projector_list[pair_list[0][1]]

    def set_aggregator_idx(self,
                           source_model_idx: int,
                           target_model_idx: int,
                           target_model_layer_idx: int,
                           aggregator_idx: int):
        if target_model_idx not in self.aggregator_dict:
            self.aggregator_dict[target_model_idx] = {}
        if source_model_idx not in self.aggregator_dict[target_model_idx]:
            self.aggregator_dict[target_model_idx][source_model_idx] = {}
        self.aggregator_dict[target_model_idx][source_model_idx][target_model_layer_idx] = aggregator_idx


    @staticmethod
    def load_json(file_name):
        with open(file_name, "r") as f:
            result = json.load(f)
        return result
    
    @staticmethod
    def _convert_dict_keys_to_ints(obj):
        """
        Recursively convert dictionary keys that look like integers back to int.
        This reverses json.dump's coercion of dict keys to strings.
        """
        if isinstance(obj, dict):
            new_obj = {}
            for key, value in obj.items():
                if isinstance(key, str) and key.lstrip('-').isdigit():
                    new_key = int(key)
                else:
                    new_key = key
                new_obj[new_key] = RosettaModel._convert_dict_keys_to_ints(value)
            return new_obj
        if isinstance(obj, list):
            return [RosettaModel._convert_dict_keys_to_ints(v) for v in obj]
        return obj
    
    
    def save_projector_config(self, file_name):
        with open(file_name, "w") as f:
            json.dump(self.projector_dict, f)

    
    def load_projector_config(self, config_path):
        if config_path.endswith(".json"):
            loaded = RosettaModel.load_json(config_path)
            self.projector_dict = RosettaModel._convert_dict_keys_to_ints(loaded)

    def save_aggregator_config(self, file_name):
        with open(file_name, "w") as f:
            json.dump(self.aggregator_dict, f)

    def load_aggregator_config(self, config_path):
        if config_path.endswith(".json"):
            loaded = RosettaModel.load_json(config_path)
            self.aggregator_dict = RosettaModel._convert_dict_keys_to_ints(loaded)


    def set_kv_cache_dict(self, source_model_idx, target_model_idx, cache):
        if target_model_idx not in self.kv_cache_dict.keys():
            self.kv_cache_dict[target_model_idx] = {}
        if cache is None:
            # Initialize with a DynamicCache instead of RosettaCache for now
            self.kv_cache_dict[target_model_idx][source_model_idx] = DynamicCache() # noqa, maybe we should use RosettaCache here
        else:
            self.kv_cache_dict[target_model_idx][source_model_idx] = cache

    def make_k_proj_hook(new_k_cache):
        def k_proj_hook(module, input, output):
            updated_k_cache = output.clone()
            batch_size, seq_len, dim = updated_k_cache.shape

            reshaped_cache = new_k_cache.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
            updated_k_cache[:, 0:seq_len, :] = reshaped_cache
        
            return updated_k_cache
        return k_proj_hook

    def make_v_proj_hook(new_v_cache):
        def v_proj_hook(module, input, output):
            updated_v_cache = output.clone()
            batch_size, seq_len, dim = updated_v_cache.shape

            reshaped_cache = new_v_cache.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
            updated_v_cache[:, 0:seq_len, :] = reshaped_cache

            return updated_v_cache
        return v_proj_hook
    
    def register_hooks(self, input_ids, attention_mask, position_ids, base_kv_cache, source_model_idx, source_kv_cache):
        """
        Pre-compute projected K/V for ``input_ids`` and hook the base model to emit them.

        Steps:
          1. Clone both caches and run the base model and the source model on
             ``input_ids`` with the clones as ``past_key_values``.
          2. For every target layer configured in
             ``projector_dict[base_model_idx][source_model_idx]``, project the source
             K/V of the last ``input_ids.shape[1]`` positions, optionally aggregate,
             and write the result into the base model's freshly computed cache.
          3. Register forward hooks on ``k_proj`` / ``v_proj`` of every base-model
             layer that return the (possibly overwritten) K/V slice of that cache.

        Args:
            input_ids, attention_mask, position_ids: forwarded unchanged to both models.
            base_kv_cache: base-model cache; only a clone is passed to the model.
            source_model_idx: index of the source model in ``model_list``.
            source_kv_cache: source-model cache; only a clone is passed to the model.

        Returns:
            A list with one ``(k_handle, v_handle)`` tuple per base-model layer; pass it
            to ``remove_hooks`` to uninstall the hooks.
        """

        # Clones are used presumably so the forward passes below do not touch the
        # caller's caches -- TODO confirm against the cache implementation.
        base_kv_copy = clone_kv_cache(base_kv_cache)
        source_kv_copy = clone_kv_cache(source_kv_cache)

        # Number of new positions; used to slice the tail of every cache below.
        new_length = input_ids.shape[1]

        base_output_kv_cache = self.model_list[self.base_model_idx].forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask, 
                    position_ids=position_ids,
                    past_key_values=base_kv_copy,
                    labels=None,
                    use_cache=True, 
                ).past_key_values
        source_output_kv_cache = self.model_list[source_model_idx].forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask, 
                    position_ids=position_ids,
                    past_key_values=source_kv_copy,
                    labels=None,
                    use_cache=True, 
                ).past_key_values
        
        # entry is the list of (source_layer_idx, projector_idx) pairs for this target layer.
        for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
            base_key_cache, base_value_cache = base_output_kv_cache[target_layer_idx]
            # Keep only the trailing new_length positions along dim 2.
            new_base_key_cache = base_key_cache[:, :, -new_length:, :]
            new_base_value_cache = base_value_cache[:, :, -new_length:, :]
            new_base_kv_cache = (new_base_key_cache, new_base_value_cache)

            pair_list = entry

            projected_kv_list = []
            source_kv_list = []
            for source_model_layer_idx, projector_idx in pair_list:
                source_key_cache, source_value_cache = source_output_kv_cache[source_model_layer_idx]
                new_source_key_cache = source_key_cache[:, :, -new_length:, :]
                new_source_value_cache = source_value_cache[:, :, -new_length:, :]
                new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                projected_key, projected_value = self.projector_list[projector_idx].forward(
                    new_source_kv_cache, # tuple of (key, value); sliced along dim 2 above
                    new_base_kv_cache
                )
                projected_kv_list.append((projected_key, projected_value))
                source_kv_list.append(new_source_kv_cache)

            # Aggregate (fallback to first projector if no aggregator is available)
            use_aggregator = (
                len(projected_kv_list) > 1 and
                len(self.aggregator_list) > 0 and
                self.base_model_idx in self.aggregator_dict and
                source_model_idx in self.aggregator_dict[self.base_model_idx] and
                target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
            )

            if use_aggregator:
                aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
                agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
                    source_kv_list,
                    new_base_kv_cache,
                    projected_kv_list
                )
            else:
                # Fallback to first projector result when no aggregator is available
                agg_key, agg_value = projected_kv_list[0]

            # Update cache: overwrite the trailing new_length positions in place.
            base_output_kv_cache.key_cache[target_layer_idx][:, :, -new_length:, :] = agg_key
            base_output_kv_cache.value_cache[target_layer_idx][:, :, -new_length:, :] = agg_value

        hook_handlers = []

        # Hooks are installed on every base layer, not only the projected ones; layers
        # without a projector get a hook built from their unmodified cache slice.
        for i in range(self.model_list[self.base_model_idx].config.num_hidden_layers):
            handler_k = self.model_list[self.base_model_idx].model.layers[i].self_attn.k_proj.register_forward_hook(
                RosettaModel.make_k_proj_hook(base_output_kv_cache.key_cache[i][:, :, -new_length:, :])
            )
            handler_v = self.model_list[self.base_model_idx].model.layers[i].self_attn.v_proj.register_forward_hook(
                RosettaModel.make_v_proj_hook(base_output_kv_cache.value_cache[i][:, :, -new_length:, :])
            )
            hook_handlers.append((handler_k, handler_v))

        return hook_handlers
    
    def remove_hooks(self, hook_handlers):
        for handler_k, handler_v in hook_handlers:
            handler_k.remove()
            handler_v.remove()

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, 
                                    inputs_embeds=None, cache_position=None, position_ids=None, 
                                    use_cache=True, **kwargs):
        """
        Custom prepare_inputs_for_generation to handle KV cache projection during generation
        """
        # Get prepare inputs from base model's original method
        original_prepare_inputs = getattr(self, '_original_prepare_inputs', None)
        if original_prepare_inputs is not None:
            # Use the correct parameter format for transformers' prepare_inputs_for_generation
            base_model_inputs = original_prepare_inputs(
                input_ids, 
                past_key_values=past_key_values, 
                attention_mask=attention_mask, 
                inputs_embeds=inputs_embeds, 
                use_cache=use_cache, 
                cache_position=cache_position,
                **kwargs
            )
        else:
            # Fallback: create basic inputs manually
            base_model_inputs = {
                'input_ids': input_ids,
                'past_key_values': past_key_values,
                'attention_mask': attention_mask,
                'position_ids': position_ids,
                'use_cache': use_cache
            }
        
        # If we're in generation mode and have projectors configured, update caches and set up hooks
        if (past_key_values is not None and 
            hasattr(self, '_in_generation') and self._in_generation and
            self.base_model_idx in self.projector_dict and
            len(self.projector_dict[self.base_model_idx]) > 0):
            
            # Clean up any existing hooks first
            self._cleanup_generation_hooks()
            
            # Update source model caches for the new input_ids
            self._update_generation_source_cache(input_ids, attention_mask, position_ids)
            
            # Update base model cache and compute projections
            updated_base_cache = self._update_base_cache_and_compute_projections(
                input_ids, attention_mask, position_ids, past_key_values
            )
            
            # Set up hooks with pre-computed projections
            self._setup_generation_hooks_with_projections()
        
        return base_model_inputs
    
    def _setup_generation_hooks_with_projections(self):
        """
        Set up hooks for generation process using pre-computed projections
        """
        if not hasattr(self, '_generation_projections'):
            return
            
        # Set up hooks for each layer that has projections
        for target_layer_idx, (projected_key, projected_value) in self._generation_projections.items():
            # Create hooks to override the k_proj and v_proj outputs for this layer
            def make_k_override_hook(proj_k):
                def k_hook(module, input, output):
                    # Replace the output with projected key
                    batch_size, seq_len, dim = output.shape
                    projected_reshaped = proj_k.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
                    return projected_reshaped
                return k_hook
            
            def make_v_override_hook(proj_v):
                def v_hook(module, input, output):
                    # Replace the output with projected value
                    batch_size, seq_len, dim = output.shape
                    projected_reshaped = proj_v.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
                    return projected_reshaped
                return v_hook
            
            # Register hooks
            k_handler = self.model_list[self.base_model_idx].model.layers[target_layer_idx].self_attn.k_proj.register_forward_hook(
                make_k_override_hook(projected_key)
            )
            v_handler = self.model_list[self.base_model_idx].model.layers[target_layer_idx].self_attn.v_proj.register_forward_hook(
                make_v_override_hook(projected_value)
            )
            self._generation_hook_handlers.extend([k_handler, v_handler])

    def _cleanup_generation_hooks(self):
        """
        Clean up generation hooks and projections
        """
        for handler in self._generation_hook_handlers:
            handler.remove()
        self._generation_hook_handlers = []
        
        # Clean up projections cache
        if hasattr(self, '_generation_projections'):
            self._generation_projections.clear()

    def _update_generation_source_cache(self, input_ids, attention_mask, position_ids):
        """
        Update source model caches during generation
        """
        if not hasattr(self, '_generation_source_cache'):
            self._generation_source_cache = {}
        
       
        if input_ids.shape[1] > 1:
            
            last_token_ids = input_ids[:, -1:]
            last_position_ids = position_ids[:, -1:] if position_ids is not None else None
        else:
            last_token_ids = input_ids
            last_position_ids = position_ids
            
        for source_model_idx in self.projector_dict[self.base_model_idx].keys():
            if source_model_idx not in self._generation_source_cache:
                self._generation_source_cache[source_model_idx] = None
                
            # Forward through source model to get updated cache
            source_output = self.model_list[source_model_idx].forward(
                input_ids=last_token_ids,
                attention_mask=attention_mask,  
                position_ids=last_position_ids,
                past_key_values=self._generation_source_cache[source_model_idx],
                use_cache=True
            )
            self._generation_source_cache[source_model_idx] = source_output.past_key_values

    def _update_base_cache_and_compute_projections(self, input_ids, attention_mask, position_ids, past_key_values):
        """
        Update base model cache and compute projections for the new token
        """
        
        if hasattr(self, '_generation_base_cache_clean'):
            clean_base_cache = self._generation_base_cache_clean
        else:
            
            clean_base_cache = past_key_values

        
        if input_ids.shape[1] > 1:
            
            last_token_ids = input_ids[:, -1:]
            last_position_ids = position_ids[:, -1:] if position_ids is not None else None
        else:
            last_token_ids = input_ids
            last_position_ids = position_ids
        
        # Forward through base model using clean cache
        base_output = self.model_list[self.base_model_idx].forward(
            input_ids=last_token_ids,
            attention_mask=attention_mask,  
            position_ids=last_position_ids,
            past_key_values=clean_base_cache,
            use_cache=True
        )
        updated_base_cache = base_output.past_key_values
        
        
        self._generation_base_cache_clean = updated_base_cache
        
        # Store projections for hooks
        if not hasattr(self, '_generation_projections'):
            self._generation_projections = {}
        
        # Compute projections for each configured layer
        for source_model_idx in self.projector_dict[self.base_model_idx].keys():
            if source_model_idx in self._generation_source_cache:
                source_kv_cache = self._generation_source_cache[source_model_idx]
                
                for target_layer_idx, (source_model_layer_idx, projector_idx) in self.projector_dict[self.base_model_idx][source_model_idx].items():
                    source_key_cache, source_value_cache = source_kv_cache[source_model_layer_idx]
                    base_key_cache, base_value_cache = updated_base_cache[target_layer_idx]

                    # Get the new token's KV (last position in the cache)
                    new_source_key_cache = source_key_cache[:, :, -1:, :]
                    new_source_value_cache = source_value_cache[:, :, -1:, :]
                    new_base_key_cache = base_key_cache[:, :, -1:, :]
                    new_base_value_cache = base_value_cache[:, :, -1:, :]

                    new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                    new_base_kv_cache = (new_base_key_cache, new_base_value_cache)

                    # Apply projection
                    projected_key, projected_value = self.projector_list[projector_idx].forward(
                        new_source_kv_cache,
                        new_base_kv_cache
                    )
                    
                    # Store projections for this layer
                    self._generation_projections[target_layer_idx] = (projected_key, projected_value)
        
        return updated_base_cache

    def forward(
        self,
        kv_cache_index: Optional[List] = None,
        input_ids: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
        attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        # **kwargs: Unpack[KwargsForCausalLM],
        identifier = -1,
        subject = None,
        *args,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass
        KVCache index is a list of tensors with shape (B, sec_seq_len, 2), indicating the source and target kv cache model index

        If input_ids is LongTensor, default to same input ids for different models
        If input_ids is Tuple, default to different input ids for different models.

        No Rosetta: (-1, 0)
        """
        
        # noqa
        self.kv_cache_dict = dict()

        # Handle different input formats: if input_ids is a list, use per-model inputs
        if isinstance(input_ids, list):
            # Use list format: different input_ids and attention_mask for each model
            base_input_ids = input_ids[self.base_model_idx] if input_ids is not None else None
            base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
            _, seqlen = base_input_ids.size() if base_input_ids is not None else (0, 0)
        else:
            # Use tensor format: same input_ids and attention_mask for all models (backward compatibility)
            base_input_ids = input_ids
            base_attention_mask = attention_mask
            _, seqlen = input_ids.size() if input_ids is not None else (0, 0)

        num_sections = len(kv_cache_index) if kv_cache_index is not None else 1

        section_lengths = [kv_cache_index[i].shape[1] for i in range(num_sections)] if kv_cache_index is not None else [seqlen]
        section_starts = [0]
        for l in section_lengths:
            section_starts.append(section_starts[-1] + l)
        
        curr_base_kv_cache = past_key_values

        if seqlen > 1:
            for i in range(num_sections):
                start = section_starts[i]
                end = section_starts[i + 1]
                prefill_input_ids = base_input_ids[:, start:end] if base_input_ids is not None else None
                prefill_attention_mask = base_attention_mask[:, :end] if base_attention_mask is not None else None
                prefill_position_ids = position_ids[:, start:end] if position_ids is not None else None
                prefill_labels = labels[:, start:end] if labels is not None else None

                if i == num_sections - 1 and self.include_response:
                    hook_handlers = self.register_hooks(input_ids=input_ids[:, start:end], attention_mask=attention_mask[:, :end], position_ids=position_ids[:, start:end],
                                                        base_kv_cache=self.kv_cache_dict[self.base_model_idx][self.base_model_idx],
                                                        source_model_idx=1, 
                                                        source_kv_cache=self.kv_cache_dict[self.base_model_idx][1])

                # calculate target model kvcache
                output = self.model_list[self.base_model_idx].forward(
                    input_ids=prefill_input_ids,
                    attention_mask=prefill_attention_mask, 
                    position_ids=prefill_position_ids,
                    past_key_values=curr_base_kv_cache,
                    labels=prefill_labels,
                    use_cache=use_cache, 
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    *args,
                    **kwargs
                )

                if self.base_model_idx not in self.kv_cache_dict:
                    self.kv_cache_dict[self.base_model_idx] = {}
                if self.base_model_idx not in self.kv_cache_dict[self.base_model_idx]:
                    self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = None
                self.kv_cache_dict[self.base_model_idx][self.base_model_idx] = output.past_key_values

                curr_base_kv_cache: DynamicCache = output.past_key_values

                if i == num_sections - 1 and self.include_response:
                    self.remove_hooks(hook_handlers)
                
                # if i != num_sections - 1:
                for source_model_idx in range(1, len(self.model_list)):
                    if self.base_model_idx not in self.kv_cache_dict:
                        self.kv_cache_dict[self.base_model_idx] = {}
                    if source_model_idx not in self.kv_cache_dict[self.base_model_idx]:
                        self.kv_cache_dict[self.base_model_idx][source_model_idx] = None

                    # Get model-specific input_ids and attention_mask
                    if isinstance(input_ids, list):
                        source_input_ids = input_ids[source_model_idx]
                        source_attention_mask = attention_mask[source_model_idx] if attention_mask is not None else None
                        source_prefill_input_ids = source_input_ids[:, start:end] if source_input_ids is not None else None
                        source_prefill_attention_mask = source_attention_mask[:, :end] if source_attention_mask is not None else None
                    else:
                        # Backward compatibility: use same input for all models
                        source_prefill_input_ids = prefill_input_ids
                        source_prefill_attention_mask = prefill_attention_mask

                    curr_source_kv_cache = self.model_list[source_model_idx].forward(
                        input_ids=source_prefill_input_ids,
                        attention_mask=source_prefill_attention_mask,
                        position_ids=prefill_position_ids,
                        past_key_values=self.kv_cache_dict[self.base_model_idx][source_model_idx],
                        use_cache=use_cache, 
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                        *args,
                        **kwargs
                    ).past_key_values
                    self.kv_cache_dict[self.base_model_idx][source_model_idx] = curr_source_kv_cache

                # calculate source model kvcache and apply projections
                if self.base_model_idx in self.projector_dict:
                    source_model_idx = kv_cache_index[i][0][0][0].item()  # Get the source model index from the kv_cache_index
                    if source_model_idx != -1:
                        for target_layer_idx, entry in self.projector_dict[self.base_model_idx][source_model_idx].items():
                            base_key_cache, base_value_cache = curr_base_kv_cache[target_layer_idx]
                            new_base_key_cache = base_key_cache[:, :, start:end, :]
                            new_base_value_cache = base_value_cache[:, :, start:end, :]
                            new_base_kv_cache = (new_base_key_cache, new_base_value_cache)

                            pair_list = entry

                            projected_kv_list = []
                            source_kv_list = []
                            for source_model_layer_idx, projector_idx in pair_list:
                                source_key_cache, source_value_cache = self.kv_cache_dict[self.base_model_idx][source_model_idx][source_model_layer_idx]
                                new_source_key_cache = source_key_cache[:, :, start:end, :]
                                new_source_value_cache = source_value_cache[:, :, start:end, :]
                                new_source_kv_cache = (new_source_key_cache, new_source_value_cache)
                                projected_key, projected_value = self.projector_list[projector_idx].forward(
                                    new_source_kv_cache, # tuple of (key, value), each of shape (B, N, H, D)
                                    new_base_kv_cache
                                )
                                projected_kv_list.append((projected_key, projected_value))

                                # --------------
                                # save base and projected kv cache
                                torch.save((projected_key, projected_value), f"oracle/projected_kv/{subject}_{identifier}_{i}.pt")
                                torch.save(new_base_kv_cache, f"oracle/target_kv/{subject}_{identifier}_{i}.pt")
                                # --------------
                                source_kv_list.append(new_source_kv_cache)

                            # Aggregate (fallback to first projector if no aggregator is available)
                            use_aggregator = (
                                len(projected_kv_list) > 1 and
                                len(self.aggregator_list) > 0 and
                                self.base_model_idx in self.aggregator_dict and
                                source_model_idx in self.aggregator_dict[self.base_model_idx] and
                                target_layer_idx in self.aggregator_dict[self.base_model_idx][source_model_idx]
                            )

                            if use_aggregator:
                                aggregator_idx = self.aggregator_dict[self.base_model_idx][source_model_idx][target_layer_idx]
                                agg_key, agg_value = self.aggregator_list[aggregator_idx].forward(
                                    source_kv_list,
                                    new_base_kv_cache,
                                    projected_kv_list
                                )
                            else:
                                # Fallback to first projector result when no aggregator is available
                                agg_key, agg_value = projected_kv_list[0]

                            # Update cache with aggregated result
                            curr_base_kv_cache.key_cache[target_layer_idx][:, :, start:end, :] = agg_key
                            curr_base_kv_cache.value_cache[target_layer_idx][:, :, start:end, :] = agg_value
                        
                        output.past_key_values = curr_base_kv_cache
                                                                             
        # use base model for decode phase
        else:
            # Handle list input format for decode phase as well
            decode_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
            decode_attention_mask = attention_mask[self.base_model_idx] if isinstance(attention_mask, list) else attention_mask
            
            output = self.model_list[self.base_model_idx].forward(
                input_ids=decode_input_ids,
                attention_mask=decode_attention_mask,
                position_ids=position_ids,
                past_key_values=curr_base_kv_cache,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                *args,
                **kwargs
            )

        return output
    

    @torch.no_grad()
    def old_generate(self,
                 kv_cache_index,
                 input_ids,
                 past_key_values=None,
                 attention_mask=None,
                 position_ids=None,
                 use_cache=True,
                 *args,
                 **kwargs):
        """
        Legacy generation path: prefill through ``self.forward``, then hand decoding
        to the base model's own ``generate``.

        When ``self.base_model_idx`` has an entry in ``self.projector_dict`` and the
        prompt is longer than one token, every prompt token except the last is run
        through ``self.forward``. The resulting cache is passed to the base model's
        ``generate`` together with the last prompt token only.

        Args:
            kv_cache_index: Sequence of index tensors, one per prompt section. For the
                prefill call the final position of the last section is removed; if that
                section has a single position it is dropped entirely.
            input_ids: A tensor shared by all models, or a list with one tensor per
                model (the entry at ``self.base_model_idx`` drives generation).
            past_key_values: Cache forwarded to the base model's ``generate``. It is
                replaced by the prefill output whenever a prefill is executed.
            attention_mask: Tensor or per-model list, matching ``input_ids``.
            position_ids: Optional position ids; sliced alongside the prompt.
            use_cache: Forwarded to the base model's ``generate``.
            *args, **kwargs: Forwarded to the base model's ``generate``.

        Returns:
            Whatever the base model's ``generate`` returns.
        """
        # Set generation mode flag
        self._in_generation = True

        try:
            # A list carries one tensor per model; a bare tensor is shared by all
            # models (backward compatibility). Only the base model's stream is
            # handed to the base model's ``generate``.
            if isinstance(input_ids, list):
                base_input_ids = input_ids[self.base_model_idx]
                base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
            else:
                base_input_ids = input_ids
                base_attention_mask = attention_mask
            input_shape = base_input_ids.shape if base_input_ids is not None else (0, 0)

            if self.base_model_idx in self.projector_dict:
                # Per-call scratch state; removed again in the ``finally`` block.
                self._generation_source_cache = {}
                self._generation_base_cache_clean = None

                if input_shape[1] > 1:
                    # Prefill phase: process all tokens except the last one.
                    if isinstance(input_ids, list):
                        prefill_input_ids = [ids[:, :-1] for ids in input_ids]
                        prefill_attention_mask = (
                            [mask[:, :-1] for mask in attention_mask] if attention_mask is not None else None
                        )
                    else:
                        prefill_input_ids = base_input_ids[:, :-1]
                        prefill_attention_mask = (
                            base_attention_mask[:, :-1] if base_attention_mask is not None else None
                        )
                    prefill_position_ids = position_ids[:, :-1] if position_ids is not None else None

                    # Shorten the index by one position to match the shortened prompt.
                    # The last section is located by position (not by object identity)
                    # so a list that reuses one tensor object is handled correctly, and
                    # a one-position last section is dropped rather than passed on as a
                    # zero-length section.
                    prefill_kv_cache_index = list(kv_cache_index[:-1])
                    if kv_cache_index and kv_cache_index[-1].shape[1] > 1:
                        prefill_kv_cache_index.append(kv_cache_index[-1][:, :-1, :])

                    output = self.forward(
                        kv_cache_index=prefill_kv_cache_index,
                        input_ids=prefill_input_ids,
                        attention_mask=prefill_attention_mask,
                        position_ids=prefill_position_ids,
                        use_cache=True
                    )

                    # Keep references to the source caches and a detached copy of the
                    # base model's own cache as stored in ``self.kv_cache_dict``.
                    base_caches = self.kv_cache_dict[self.base_model_idx]
                    for source_model_idx in self.projector_dict[self.base_model_idx]:
                        self._generation_source_cache[source_model_idx] = base_caches[source_model_idx]
                    self._generation_base_cache_clean = clone_kv_cache(base_caches[self.base_model_idx])

                    past_key_values = output.past_key_values

                    # Use only the last token for actual generation
                    past_length = input_shape[1] - 1
                    base_input_ids = base_input_ids[:, -1:]
                    if position_ids is not None:
                        position_ids = position_ids[:, -1:]
                else:
                    # Single-token prompt: nothing to prefill.
                    if past_key_values is not None:
                        self._generation_base_cache_clean = clone_kv_cache(past_key_values)
                    past_length = 0
            else:
                # No projectors registered for the base model: no prefill is run here.
                past_length = max(base_input_ids.shape[1] - 1, 0)

            # Temporarily route the base model's prepare_inputs_for_generation through
            # this module; the original is put back in the ``finally`` block.
            if self.include_response:
                base_model = self.model_list[self.base_model_idx]
                self._original_prepare_inputs = base_model.prepare_inputs_for_generation
                base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation

            # Call the original generate method using base model's inputs
            cache_position = torch.arange(past_length, past_length + 1, device=base_input_ids.device)

            return self.model_list[self.base_model_idx].generate(
                input_ids=base_input_ids,
                past_key_values=past_key_values,
                attention_mask=base_attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                cache_position=cache_position,
                *args,
                **kwargs
            )

        finally:
            # Cleanup runs on success and on failure alike.
            self._in_generation = False
            self._cleanup_generation_hooks()
            for scratch_attr in (
                '_generation_source_cache',
                '_generation_projections',
                '_generation_base_cache_clean',
            ):
                if hasattr(self, scratch_attr):
                    delattr(self, scratch_attr)
            # Restore prepare_inputs_for_generation even if an exception occurred
            if hasattr(self, '_original_prepare_inputs'):
                self.model_list[self.base_model_idx].prepare_inputs_for_generation = self._original_prepare_inputs
                delattr(self, '_original_prepare_inputs')

    @torch.no_grad()
    def generate(
        self,
        kv_cache_index,
        input_ids,
        max_new_tokens: Optional[int] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        pad_token_id: Optional[int] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        do_sample: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        max_length: Optional[int] = None,
        use_cache: bool = True,
        *args,
        **kwargs,
    ):
        """
        Generate tokens with a hand-written loop instead of the base model's ``generate``.

        ``self.forward`` is called once on the full prompt (prefill) and then once per
        generated token (decode). Tokens are chosen with
        ``rosetta.model.sampling.sample_token``.

        Args:
            kv_cache_index: Forwarded unchanged to ``self.forward`` for the prefill call.
                Decode steps use an internally built single-position index instead.
            input_ids: A tensor, or a list of tensors indexed by model; the entry at
                ``self.base_model_idx`` is the stream that is extended and returned.
            max_new_tokens: Number of tokens to generate. When ``None`` it is derived as
                ``max(max_length - prompt_len, 0)``.
            past_key_values: Forwarded to the prefill ``self.forward`` call.
            attention_mask: Tensor or per-model list, matching ``input_ids``. When
                ``None``, an all-ones mask is used for the base stream during decode.
            position_ids: Forwarded to the prefill call only; decode passes ``None``.
            eos_token_id: One id or a list/tuple of ids that end a sequence. Defaults to
                the base model's ``generation_config`` (or ``config``) value.
            pad_token_id: Id written after a sequence has finished. Defaults to the base
                model's configured value, then to the first EOS id.
            temperature, top_p, top_k: Passed to ``sample_token``. ``temperature`` is
                replaced by ``0.0`` unless ``do_sample`` is truthy.
            do_sample: ``None`` is treated as ``False``.
            return_dict_in_generate: Return an output object instead of a bare tensor.
            output_scores: With ``return_dict_in_generate``, also return the per-step
                logits that were passed to ``sample_token``.
            max_length: Total length bound, used only when ``max_new_tokens`` is ``None``.
            use_cache: Forwarded to the prefill call; decode always uses the cache.
            *args, **kwargs: Forwarded to every ``self.forward`` call.

        Returns:
            A tensor of shape ``[batch, prompt_len + generated_len]`` for the base model
            stream, or, when ``return_dict_in_generate`` is truthy, an output object with
            ``sequences`` (and ``scores`` if requested).

        Raises:
            ValueError: If neither ``max_new_tokens`` nor ``max_length`` is given, or if
                ``max_new_tokens`` is negative.
        """
        # The base model's stream determines prompt length and is what gets extended.
        base_input_ids = input_ids[self.base_model_idx] if isinstance(input_ids, list) else input_ids
        prompt_len = base_input_ids.size(1)

        # Default eos/pad from the base model's generation config (or config) if not provided
        base_model = self.model_list[self.base_model_idx]
        gen_cfg = getattr(base_model, "generation_config", None)
        cfg_obj = gen_cfg if gen_cfg is not None else getattr(base_model, "config", None)
        if eos_token_id is None and cfg_obj is not None:
            eos_token_id = getattr(cfg_obj, "eos_token_id", None)
        if pad_token_id is None and cfg_obj is not None:
            pad_token_id = getattr(cfg_obj, "pad_token_id", None)

        # Normalise EOS to a flat list so a single id, a list, a tuple, and an empty
        # collection are all handled the same way (an empty collection means "no EOS").
        if eos_token_id is None:
            eos_ids = []
        elif isinstance(eos_token_id, (list, tuple)):
            eos_ids = list(eos_token_id)
        else:
            eos_ids = [eos_token_id]
        if pad_token_id is None and eos_ids:
            pad_token_id = eos_ids[0]
        eos_set = set(eos_ids) if eos_ids else None

        # Derive number of tokens to generate
        if max_new_tokens is None:
            if max_length is None:
                raise ValueError("Provide max_new_tokens or max_length")
            max_new_tokens = max(max_length - prompt_len, 0)
        if max_new_tokens < 0:
            raise ValueError("max_new_tokens must be non-negative")

        # Resolve the base model's attention mask
        if isinstance(input_ids, list):
            base_attention_mask = attention_mask[self.base_model_idx] if attention_mask is not None else None
        else:
            base_attention_mask = attention_mask
        if base_attention_mask is None:
            base_attention_mask = torch.ones_like(base_input_ids, dtype=torch.long, device=base_input_ids.device)

        batch_size = base_input_ids.size(0)

        # Prefill to build caches and obtain initial logits
        prefill_output = self.forward(
            kv_cache_index=kv_cache_index,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            *args,
            **kwargs,
        )

        current_past = prefill_output.past_key_values
        all_input_ids = base_input_ids
        current_attention_mask = base_attention_mask

        # Per-sequence "already produced EOS" flags
        finished = torch.zeros(batch_size, dtype=torch.bool, device=all_input_ids.device)

        # Start from last prefill logits
        last_logits = prefill_output.logits[:, -1, :]

        # Determine sampling mode
        if do_sample is None:
            do_sample = False
        effective_temperature = temperature if do_sample else 0.0

        # Optional scores collection
        collect_scores = bool(return_dict_in_generate) and bool(output_scores)
        scores = []

        # The decode-step index section is the same for every step, so build it once.
        # Shape (1, 1, 2): [1, 0] when include_response is set, [-1, 0] otherwise.
        decode_section = torch.tensor(
            [[[1 if self.include_response else -1, 0]]],
            dtype=torch.long,
            device=all_input_ids.device,
        )

        for _ in range(max_new_tokens):
            if collect_scores:
                scores.append(last_logits)

            # Sample next token
            next_token = sample_token(last_logits, temperature=effective_temperature, top_p=top_p, top_k=top_k)
            if not isinstance(next_token, torch.Tensor):
                next_token = torch.tensor([next_token], device=all_input_ids.device, dtype=torch.long).repeat(batch_size)

            if eos_set is not None:
                # Pad only sequences that finished on an EARLIER step, so the EOS token
                # itself stays in the output (as HF generate does), then record
                # sequences that hit EOS on this step.
                if pad_token_id is not None:
                    next_token = next_token.masked_fill(finished, pad_token_id)
                hit_eos = torch.zeros_like(finished)
                for eid in eos_set:
                    hit_eos |= (next_token == eid)
                finished = finished | hit_eos

            # Append sampled token and extend the mask by one attended position
            next_token_unsqueezed = next_token.unsqueeze(1)
            all_input_ids = torch.cat([all_input_ids, next_token_unsqueezed], dim=1)
            current_attention_mask = torch.cat(
                [
                    current_attention_mask,
                    torch.ones((batch_size, 1), device=current_attention_mask.device, dtype=current_attention_mask.dtype),
                ],
                dim=1,
            )

            # Early stop if all sequences finished
            if eos_set is not None and torch.all(finished):
                break

            # Decode one step using cached states; pass base-stream tensors.
            # A fresh list is passed each step so forward never shares list state.
            decode_output = self.forward(
                kv_cache_index=[decode_section],
                input_ids=next_token_unsqueezed,
                attention_mask=current_attention_mask,
                position_ids=None,
                past_key_values=current_past,
                use_cache=True,
                *args,
                **kwargs,
            )
            current_past = decode_output.past_key_values
            last_logits = decode_output.logits[:, -1, :]

        if not return_dict_in_generate:
            return all_input_ids

        # Return style compatible with HF generate
        if GreedySearchDecoderOnlyOutput is not None and SampleDecoderOnlyOutput is not None:
            output_cls = SampleDecoderOnlyOutput if do_sample else GreedySearchDecoderOnlyOutput
            return output_cls(
                sequences=all_input_ids,
                scores=scores if collect_scores else None,
            )

        # Fallback to generic ModelOutput when the HF output classes are unavailable
        result = {"sequences": all_input_ids}
        if collect_scores:
            result["scores"] = scores
        return ModelOutput(**result)