# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import inspect
import warnings

import time
import math

from functools import partial
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, KLDivLoss

import concurrent

from transformers.models.llama.modeling_llama import (
    _CONFIG_FOR_DOC,
    LLAMA_INPUTS_DOCSTRING,
    LlamaModel, 
    LlamaForCausalLM,
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    rotate_half,
    repeat_kv,
    logger
)
from transformers.cache_utils import (
    Cache,
    DynamicCache,
)

# import LlamaConfig
from transformers.models.llama.configuration_llama import LlamaConfig

from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer

from transformers.generation import GenerationConfig, GenerationMixin
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import (
    ModelOutput,
    EncoderDecoderCache,
    GenerateOutput, 
    GenerateNonBeamOutput,
    GenerateEncoderDecoderOutput,
    is_deepspeed_zero3_enabled,
    is_torchdynamo_compiling,
)

from modeling_outputs import MultiHeadCausalLMOutputWithPast, MultiHeadGenerateDecoderOnlyOutput, MultiHeadBaseModelOutputWithPast

from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

class MyDynamicCache(DynamicCache):
    """`DynamicCache` variant used by the parallel early-exit decoding path.

    Only `early_exit_layer_idx` cache slots are ever appended. A layer whose index is
    `>= early_exit_layer_idx` shares the slot `layer_idx % early_exit_layer_idx` with a
    lower layer. When `update` is called with batch size 1 the incoming states are
    duplicated along the batch dimension before being stored, so stored tensors then
    have two batch rows: row 1 is what layers below `early_exit_layer_idx` read back,
    row 0 is what layers at or above it read back.
    """

    def __init__(self) -> None:
        # Initialise the parent cache first so any state it owns exists before we
        # (re)set the fields this subclass relies on.
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        kv_pad: Optional[torch.Tensor] = None,
        early_exit_layer_idx: int = 16,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache. The batch size is read from dim 0 and the number of
                new positions from dim -2.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Accepted for signature compatibility; not used here.
            kv_pad (`torch.Tensor`, `optional`):
                Accepted for signature compatibility; not used here.
            early_exit_layer_idx (`int`, defaults to 16):
                Number of cache slots. Layers with `layer_idx >= early_exit_layer_idx` write into
                batch row 0 of slot `layer_idx % early_exit_layer_idx`.

        Return:
            A tuple containing the updated key and value states. For a batch-size-1 call this is a
            single batch row of the slot (row 1 for layers below `early_exit_layer_idx`, row 0
            otherwise); for any other batch size the whole slot is returned.
        """
        cur_len = key_states.shape[-2]
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += cur_len

        bsz = key_states.shape[0]
        is_lower_layer = layer_idx < early_exit_layer_idx
        slot = layer_idx if is_lower_layer else layer_idx % early_exit_layer_idx

        if len(self.key_cache) < early_exit_layer_idx:
            # Slot does not exist yet: create it (duplicating a single row into two).
            if bsz == 1:
                key_states = torch.cat([key_states, key_states], dim=0)
                value_states = torch.cat([value_states, value_states], dim=0)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif is_lower_layer:
            # Grow the slot along the sequence dimension.
            if bsz == 1:
                key_states = torch.cat([key_states, key_states], dim=0)
                value_states = torch.cat([value_states, value_states], dim=0)
            self.key_cache[slot] = torch.cat([self.key_cache[slot], key_states], dim=-2)
            self.value_cache[slot] = torch.cat([self.value_cache[slot], value_states], dim=-2)
            if bsz != 1:
                # Batched call: copy row 0's newest position over its second-to-last position.
                self.key_cache[slot][0, :, -2, :] = self.key_cache[slot][0, :, -1, :]
                self.value_cache[slot][0, :, -2, :] = self.value_cache[slot][0, :, -1, :]
        else:
            # Upper layer: overwrite the trailing `cur_len` positions of row 0 in the shared slot.
            self.key_cache[slot][0, :, -cur_len:, :] = key_states
            self.value_cache[slot][0, :, -cur_len:, :] = value_states

        keys, values = self.key_cache[slot], self.value_cache[slot]
        if bsz != 1:
            return keys, values
        if is_lower_layer:
            return keys[1:], values[1:]
        return keys[:1], values[:1]
        

class MyLlamaModel(LlamaModel):

    def __init__(self, config: LlamaConfig):
        """Build the decoder stack and the extra state used by early-exit inference.

        Args:
            config: Llama configuration; `num_hidden_layers`, `hidden_size`, `vocab_size`,
                `pad_token_id` and `rms_norm_eps` are read here.
        """
        super().__init__(config)

        pad_id = config.pad_token_id
        hidden_size = config.hidden_size
        num_layers = config.num_hidden_layers

        self.padding_idx = pad_id
        self.vocab_size = config.vocab_size

        # NOTE(review): these assignments replace whatever the parent constructor stored under
        # the same attribute names - presumably equivalent modules; confirm whether redundant.
        self.embed_tokens = nn.Embedding(config.vocab_size, hidden_size, pad_id)
        decoder_layers = [LlamaDecoderLayer(config, idx) for idx in range(num_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LlamaRMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Early-exit bookkeeping.
        self.is_parallelism_initialized = False  # flipped by `init_parallelism`
        self.is_placeholders_initialized = False  # flipped by `init_placeholders`
        self.config = config
        self.threshold = 0.2  # probability above which `forward` triggers the early-exit shortcut
        self.EE_layer = num_layers // 2  # early-exit layer: midpoint of the stack

        # Pre-allocate constant tensors (only needed for inference).
        self.init_placeholders()

        # Initialize weights and apply final processing
        self.post_init()

    def init_placeholders(
        self,
        device: Optional[Union[str, torch.device]] = None,
        max_positions: int = 4096,
    ):
        """Pre-allocate constant tensors reused by the parallel early-exit decode path.

        Args:
            device: Device on which to allocate the placeholders. When `None`, `'cuda'` is used
                if CUDA is available and `'cpu'` otherwise, so the model can still be constructed
                on a host without a GPU.
            max_positions: Sequence length the position-id and attention-mask placeholders are
                sized for. Defaults to 4096, the previously hard-coded value.

        Note:
            `attention_mask_placeholder_matrix` holds
            `num_attention_heads * max_positions * max_positions` elements in the default dtype,
            so large `max_positions` values are expensive.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        num_heads = self.config.num_attention_heads
        head_dim = self.config.hidden_size // num_heads

        self.next_token_placeholder = torch.tensor([[1]], device=device)

        # Row 0 holds positions 0..max_positions-1, row 1 the same positions shifted by one.
        self.position_ids_placeholder = torch.arange(max_positions, device=device).unsqueeze(0).repeat(2, 1)
        self.position_ids_placeholder[1] += 1

        # All-zero additive mask, in the embedding dtype.
        self.attention_mask_placeholder = torch.zeros(
            (2, num_heads, 1, max_positions), device=device, dtype=self.embed_tokens.weight.dtype
        )

        # Strictly-upper-triangular -inf matrix (zeros on and below the diagonal).
        self.attention_mask_placeholder_matrix = torch.triu(
            torch.full((num_heads, 1, max_positions, max_positions), float('-inf'), device=device),
            diagonal=1,
        )

        # One-column mask pad: -inf for batch row 0, zero for batch row 1.
        self.attention_mask_pad = torch.zeros((2, num_heads, 1, 1), device=device)
        self.attention_mask_pad[0, :, :, -1] = float("-inf")

        self.kv_states_pad = torch.zeros((1, self.config.num_key_value_heads, 1, head_dim), device=device)
        self.layer_norm_weights_placeholder = torch.zeros((2, 1, self.config.hidden_size), device=device)

        # Only report success once every placeholder exists.
        self.is_placeholders_initialized = True

    def init_parallelism(self):
        """Stack per-layer weights pairwise for the batched two-token decode path.

        With `half = num_hidden_layers // 2`, for every `i` in `range(half)` the weight of
        layer `half + i` and the weight of layer `i` are stacked, in that order, along a new
        leading dimension of size 2. One list of `half` stacked tensors is stored per weight
        kind (layer norms, attention projections, MLP projections). Only weights are combined;
        no biases are read. Layer-norm stacks get an extra singleton dimension at position 1.
        """
        self.is_parallelism_initialized = True

        half = self.config.num_hidden_layers // 2
        paired_layers = [(self.layers[half + offset], self.layers[offset]) for offset in range(half)]

        def stack_pairs(select_weight, norm_weight=False):
            # One stacked tensor per (upper, lower) layer pair; upper layer goes first.
            combined = []
            for upper_layer, lower_layer in paired_layers:
                pair = torch.stack((select_weight(upper_layer), select_weight(lower_layer)), dim=0)
                combined.append(pair.unsqueeze(1) if norm_weight else pair)
            return combined

        # Layer norms
        self.input_layernorm_weights_combined = stack_pairs(
            lambda layer: layer.input_layernorm.weight, norm_weight=True
        )
        self.post_layernorm_weights_combined = stack_pairs(
            lambda layer: layer.post_attention_layernorm.weight, norm_weight=True
        )

        # Attention projections
        self.q_proj_weights_combined = stack_pairs(lambda layer: layer.self_attn.q_proj.weight)
        self.k_proj_weights_combined = stack_pairs(lambda layer: layer.self_attn.k_proj.weight)
        self.v_proj_weights_combined = stack_pairs(lambda layer: layer.self_attn.v_proj.weight)
        self.o_proj_weights_combined = stack_pairs(lambda layer: layer.self_attn.o_proj.weight)

        # MLP projections
        self.mlp_gate_proj_weights_combined = stack_pairs(lambda layer: layer.mlp.gate_proj.weight)
        self.mlp_up_proj_weights_combined = stack_pairs(lambda layer: layer.mlp.up_proj.weight)
        self.mlp_down_proj_weights_combined = stack_pairs(lambda layer: layer.mlp.down_proj.weight)

        
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_state_shortcuts: Optional[nn.Linear] = None,
        combined_lm_head_weights: Dict[int, torch.Tensor] = None,
        combined_eye_and_shortcut: Dict[int, torch.Tensor] = None,
        enable_early_exit: Optional[bool] = False,
        lm_head: Optional[nn.Linear] = None,
        pre_calculated_early_exit_layer_hidden_states: Optional[torch.FloatTensor] = None,
        pre_decoded_early_exit_token: Optional[int] = None,
    ) -> Union[Tuple, MultiHeadBaseModelOutputWithPast]:
        """
        Early-exit specific behaviour (active only when `enable_early_exit` is truthy and the module
        is not in training mode):

        When the loop reaches layer `self.EE_layer`, a token is either taken from
        `pre_decoded_early_exit_token` or predicted from the hidden state entering that layer. If a
        token is available, the remaining layers for the current token and the first `self.EE_layer`
        layers for the predicted token are computed together by `_update_KV_cache_in_parallel`, and
        the prediction is then verified against the full-depth logits.

        Additional args:
            hidden_state_shortcuts: Indexed with `f'layer_{idx}'`; the entry is called on the
                early-exit hidden state before `self.norm` and `lm_head`.
            combined_lm_head_weights: Not used by this method.
            combined_eye_and_shortcut: Indexed with `f'layer_{self.EE_layer}'`; the entry is
                right-multiplied onto the combined hidden states before `self.norm`.
            enable_early_exit: Enables the early-exit path (inference only).
            lm_head: Callable mapping normalised hidden states to vocabulary logits.
            pre_calculated_early_exit_layer_hidden_states: When given, layers below `self.EE_layer`
                are skipped and this tensor is used as the input of layer `self.EE_layer`.
            pre_decoded_early_exit_token: When given, the shortcut is triggered with this token
                without computing shortcut logits.

        Returns:
            A plain tuple when `return_dict` is false, otherwise a `MultiHeadBaseModelOutputWithPast`
            that additionally carries the early-exit state for the next step
            (`pre_calculated_early_exit_layer_hidden_states`, `pre_decoded_early_exit_token`),
            the verification result (`passed_verify`), the accepted token (`early_exit_id`) and the
            shortcut logits (`early_exit_logits`).
        """
        if not self.training and not self.is_parallelism_initialized:  # only for inference
            self.init_parallelism()

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False

        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = MyDynamicCache.from_legacy_cache(past_key_values) if enable_early_exit else DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # early-exit state
        next_early_exit_token = None
        next_early_exit_layer_hidden_states = None
        shortcut_triggered = False
        passed_verify = None
        token_id = None
        shortcut_logits = None
        early_exit_active = bool(enable_early_exit) and not self.training

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_input = hidden_states

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                if early_exit_active:
                    layer_id = decoder_layer.self_attn.layer_idx

                    if layer_id < self.EE_layer and pre_calculated_early_exit_layer_hidden_states is not None:
                        # The input of the early-exit layer was already computed by the previous step.
                        continue

                    if layer_id == self.EE_layer:
                        # Use the running hidden state directly rather than `all_hidden_states`,
                        # which only exists when `output_hidden_states` is enabled.
                        if pre_calculated_early_exit_layer_hidden_states is not None:
                            layer_input = pre_calculated_early_exit_layer_hidden_states

                        if pre_decoded_early_exit_token is not None:
                            token_id = pre_decoded_early_exit_token
                            shortcut_triggered = True
                        else:
                            shortcut_logits = lm_head(self.norm(hidden_state_shortcuts[f'layer_{layer_id}'](layer_input))).float()
                            probs = nn.functional.softmax(shortcut_logits[:, -1, :], dim=-1)
                            max_prob, max_token = torch.max(probs, dim=-1)

                            if max_prob[-1] > self.threshold:
                                token_id = max_token[-1]
                                shortcut_triggered = True
                            else:
                                # NOTE(review): debug output kept as-is; consider routing through `logger`.
                                print(f'max_prob : {max_prob}, token: {max_token}')

                        if shortcut_triggered:
                            # Run the rest of the current token and the first half of the predicted
                            # token as one batch of two.
                            next_token_embeds = self.embed_tokens(self.next_token_placeholder * token_id)
                            combined_hidden_states, past_key_value_parallel = self._update_KV_cache_in_parallel(layer_input, next_token_embeds, early_exit_layer_idx=self.EE_layer, position_ids=position_ids, past_key_value=past_key_values, cache_position=cache_position, attention_mask=causal_mask)

                            # Row 0: current token at full depth. Last row: predicted token at the early-exit layer.
                            hidden_states = combined_hidden_states[:1, ...]
                            next_early_exit_layer_hidden_states = combined_hidden_states[-1:, ...]
                            next_decoder_cache = past_key_value_parallel
                            # No per-layer outputs exist on this path, so skip the bookkeeping below.
                            break

                layer_outputs = decoder_layer(
                    layer_input,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            hidden_states = layer_outputs[0]

        if shortcut_triggered:
            # Verify the early-exit token (row 0) and, in the same pass, predict the following
            # early-exit token (row 1).
            verification_states = self.norm(combined_hidden_states @ combined_eye_and_shortcut[f'layer_{self.EE_layer}'])
            combined_logits = lm_head(verification_states).float()

            combined_probs = nn.functional.softmax(combined_logits[:, -1, :], dim=-1)
            max_prob, max_token = torch.max(combined_probs, dim=-1)

            print(f'max_prob (verification): {max_prob}')
            print(f'max_token (verification): {max_token}')

            if token_id != max_token[0]:
                # Verification failed: shorten the cache by one position so the speculated token's
                # entries are discarded, and fall back to the fully decoded token.
                past_key_value_parallel.crop(past_key_value_parallel.get_seq_length() - 1)

                next_early_exit_layer_hidden_states = None
                token_id = max_token[0]
                passed_verify = False
            else:
                passed_verify = True
                next_early_exit_token = max_token[1] if max_prob[1] > self.threshold else None

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return MultiHeadBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            pre_calculated_early_exit_layer_hidden_states=next_early_exit_layer_hidden_states,
            pre_decoded_early_exit_token=next_early_exit_token,
            passed_verify=passed_verify,
            early_exit_id=token_id if shortcut_triggered else None,
            early_exit_logits=shortcut_logits if (shortcut_triggered and pre_decoded_early_exit_token is None) else None,
        )
    
    def _update_KV_cache_in_parallel(self, hidden_states_tok_1, hidden_states_tok_2, early_exit_layer_idx, position_ids, past_key_value, cache_position, attention_mask):

        bsz, seq_len, _ = hidden_states_tok_1.shape
        head_dim, num_key_value_heads, num_heads, num_key_value_groups = self.layers[0].self_attn.head_dim, self.layers[0].self_attn.num_key_value_heads, self.layers[0].self_attn.num_heads, self.layers[0].self_attn.num_key_value_groups

        position_ids_tok_1, position_ids_tok_2 = position_ids, position_ids + 1 # [bsz, 1]

        combined_hidden_states = torch.cat([hidden_states_tok_1, hidden_states_tok_2], dim=0)
        

        for layer_id in range(early_exit_layer_idx):

            # Self Attention
            # Parallel Layer Norm: 6e-5 
            # [2, 1, 4096]
            combined_residual = combined_hidden_states

            combined_hidden_states = self.parallel_LLamaRMSNorm(combined_hidden_states, self.layers[early_exit_layer_idx + layer_id].input_layernorm, self.layers[layer_id].input_layernorm, layer_id, 'input_norm')

            query_states_combined = self.parallel_proj('q_proj', layer_id, combined_hidden_states) # [2, 1, 4096]
            key_states_combined = self.parallel_proj('k_proj', layer_id, combined_hidden_states) # [2, 1, 1024]
            value_states_combined = self.parallel_proj('v_proj', layer_id,combined_hidden_states)  # [2, 1, 1024]

            # Batched RoPE
            # Choice 1: 1e-4 seconds
            cos_combined, sin_combined = self.rotary_emb(value_states_combined, position_ids=torch.cat([position_ids_tok_1, position_ids_tok_2], dim=0)) 

            # Choice 2: 1e-3 seconds
            # print(f'torch cat position ids: {torch.cat([position_ids_tok_1, position_ids_tok_2], dim=0)}')
            # print(f'placeholder positon ids: {self.position_ids_placeholder[:, position_ids[0][0]].unsqueeze(1)}')

            query_states_combined = query_states_combined.view(2, seq_len, num_heads, head_dim).transpose(1, 2)
            key_states_combined = key_states_combined.view(2, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
            value_states_combined = value_states_combined.view(2, seq_len, num_key_value_heads, head_dim).transpose(1, 2)
            query_states_combined, key_states_combined = apply_rotary_pos_emb(query_states_combined, key_states_combined, cos_combined, sin_combined)

            # (2, 1, 1024)

            if past_key_value is not None:

                # key_states_tok_1, value_states_tok_1 = past_key_value.update(key_states_combined[0].unsqueeze(0), value_states_combined[0].unsqueeze(0), layer_id + early_exit_layer_idx, kv_pad=self.kv_states_pad)

                # key_states_tok_2, value_states_tok_2 = past_key_value.update(key_states_combined[1].unsqueeze(0), value_states_combined[1].unsqueeze(0), layer_id)

                key_states_combined, value_states_combined = past_key_value.update(key_states_combined, value_states_combined, layer_id, early_exit_layer_idx=early_exit_layer_idx)

            # key_states_combined = torch.cat([key_states_tok_1, key_states_tok_2], dim=0)
            # value_states_combined = torch.cat([value_states_tok_1, value_states_tok_2], dim=0) # 1e-5 for 2 calls

            kv_combined = torch.cat([key_states_combined, value_states_combined], dim=0)
            kv_combined = repeat_kv(kv_combined, num_key_value_groups)

            causal_mask = attention_mask
            if attention_mask is not None:
                causal_mask = causal_mask[:, :, :, : key_states_tok_1.shape[-2]]
            else:
                # torch cat 1e-5 seconds
                # causal_mask_cat = torch.cat([self.attention_mask_placeholder[:, :, :, :key_states_tok_1.shape[-2]-1], self.attention_mask_pad], dim=-1) 

                # matrix slicing 9e-6 seconds
                idx = key_states_combined.shape[-2]
                a = self.attention_mask_placeholder_matrix[:, :, idx-2:idx, :idx]
                reshaped = a.reshape(num_heads, 2, idx, 1)
                causal_mask = reshaped.permute(1, 0, 3, 2)
                # assert torch.equal(causal_mask, causal_mask_cat)

            is_causal = True if causal_mask is None and seq_len > 1 else False
            attn_output = torch.nn.functional.scaled_dot_product_attention( # 2e-5 seconds
                query_states_combined, 
                kv_combined[:2, ...], # key_state_combined
                kv_combined[2:, ...], # value_state_combined
                attn_mask=causal_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=is_causal,
            )

            # causal_mask[0, :, :, -1] = 0
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(attn_output.shape[0], seq_len, -1)
            combined_attn_output = attn_output
            # print(f'Batched SDPA Time: {time.time() - start_time}')

            combined_hidden_states = self.parallel_proj('o_proj', layer_id, combined_attn_output)
            # print(f'combined_hidden_states shape: {combined_hidden_states.shape}')

            combined_hidden_states = self.parallel_residual_connection(self.layers, early_exit_layer_idx, layer_id, combined_hidden_states, combined_residual)
            # hidden_states_tok_1 = combined_hidden_states[:1, ...]
            # hidden_states_tok_2 = combined_hidden_states[-1:, ...]

        # return combined_hidden_states[:1, ...], combined_hidden_states[-1:, ...], past_key_value
        return combined_hidden_states, past_key_value
    
    def compute_sdpa(self, bsz, seq_len, query_states, key_states, value_states, attention_mask):
        """Run scaled-dot-product attention and merge the heads back together.

        Args:
            bsz: Leading size used when reshaping the attention output.
            seq_len: Query length; also decides whether the causal flag is set.
            query_states: Query tensor passed to
                `torch.nn.functional.scaled_dot_product_attention`.
            key_states: Key tensor; its second-to-last dimension is the key length.
            value_states: Value tensor.
            attention_mask: Optional 4-D mask. It is truncated along its last
                dimension to the key length before use.

        Returns:
            The attention output with dims 1 and 2 swapped and then reshaped to
            `(bsz, seq_len, -1)`.
        """
        mask = attention_mask
        if mask is not None:
            key_len = key_states.shape[-2]
            mask = mask[:, :, :, :key_len]

            # NOTE(review): presumably works around an SDPA limitation with
            # non-contiguous inputs combined with an explicit mask on CUDA -- confirm.
            if query_states.device.type == "cuda":
                query_states = query_states.contiguous()
                key_states = key_states.contiguous()
                value_states = value_states.contiguous()

        # Only let the kernel build its own causal mask when none was supplied
        # and there is more than one query position.
        use_builtin_causal = bool(mask is None and seq_len > 1)
        dropout = self.attention_dropout if self.training else 0.0

        attended = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=mask,
            dropout_p=dropout,
            is_causal=use_builtin_causal,
        )

        # Swap dims 1 and 2, then flatten the trailing dims into one.
        return attended.transpose(1, 2).contiguous().view(bsz, seq_len, -1)

    def parallel_LLamaRMSNorm(self, combined_hidden_states, layer_norm_1, layer_norm_2, layer_id, norm_type):
        """Apply RMS normalisation to the stacked hidden states in one pass.

        The statistics are computed in float32, the result is cast back to the
        input dtype, and it is then scaled by the pre-stacked weights stored on
        the model for `layer_id`.

        Args:
            combined_hidden_states: Stacked hidden states; normalised over the
                last dimension.
            layer_norm_1: Norm module supplying `variance_epsilon`.
            layer_norm_2: Second norm module. Not read here; its epsilon is
                assumed to equal `layer_norm_1`'s.
            layer_id: Index into the pre-stacked weight containers.
            norm_type: `'input_norm'` or `'post_norm'`, selecting which
                pre-stacked weights to use.

        Returns:
            The normalised and scaled hidden states.

        Raises:
            ValueError: If `norm_type` is neither `'input_norm'` nor `'post_norm'`.
        """
        original_dtype = combined_hidden_states.dtype

        states_fp32 = combined_hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + layer_norm_1.variance_epsilon)
        normalized = (states_fp32 * inv_rms).to(original_dtype)

        if norm_type == 'input_norm':
            weight_bank = self.input_layernorm_weights_combined
        elif norm_type == 'post_norm':
            weight_bank = self.post_layernorm_weights_combined
        else:
            raise ValueError(f'Invalid norm type: {norm_type}')

        return weight_bank[layer_id] * normalized

    def parallel_residual_connection(self, layers, early_exit_layer_idx, layer_id, combined_hidden_states, combined_residual):
        """Finish a decoder layer for two stacked token streams at once.

        Adds the attention residual, applies the post-attention RMS norm, runs
        the gated MLP (gate/up/down projections) with pre-stacked weights, and
        adds the second residual.

        Index 0 of the stacked dimension corresponds to layer
        `early_exit_layer_idx + layer_id` and index 1 to layer `layer_id`.
        This assumes the hidden states are `(2, seq, hidden)` and the stacked
        weights are `(2, out, in)`, as the surrounding code comments indicate.

        Args:
            layers: Sequence of decoder layers exposing `post_attention_layernorm`
                and `mlp` (with `gate_proj`, `up_proj`, `down_proj`, `act_fn`).
            early_exit_layer_idx: Offset of the first stream's layer relative to
                `layer_id`.
            layer_id: Layer index of the second stream; also the index into the
                pre-stacked weight containers.
            combined_hidden_states: Stacked attention output. NOTE: it is updated
                in place by the first residual addition.
            combined_residual: Stacked residual to add to the attention output.

        Returns:
            The stacked hidden states after the MLP and its residual connection.
        """

        def batched_linear(inputs, stacked_weight, proj_tok_1, proj_tok_2):
            # One batched matmul applies each stream's own weight matrix.
            outputs = torch.matmul(inputs, stacked_weight.transpose(-1, -2))
            if proj_tok_1.bias is not None and proj_tok_2.bias is not None:
                # Stack to (2, 1, out) so each stream gets its own bias and it
                # broadcasts over the sequence dimension. Concatenating the two
                # biases would give (2 * out,), which cannot broadcast here.
                stacked_bias = torch.stack([proj_tok_1.bias, proj_tok_2.bias], dim=0).unsqueeze(1)
                outputs = outputs + stacked_bias
            return outputs

        # First residual connection; the sum is reused as the residual for the MLP.
        combined_hidden_states += combined_residual
        combined_residual = combined_hidden_states

        combined_hidden_states = self.parallel_LLamaRMSNorm(
            combined_hidden_states,
            layers[early_exit_layer_idx + layer_id].post_attention_layernorm,
            layers[layer_id].post_attention_layernorm,
            layer_id,
            'post_norm',
        )

        mlp_tok_1 = layers[early_exit_layer_idx + layer_id].mlp
        mlp_tok_2 = layers[layer_id].mlp

        combined_gate_states = batched_linear(
            combined_hidden_states,
            self.mlp_gate_proj_weights_combined[layer_id],
            mlp_tok_1.gate_proj,
            mlp_tok_2.gate_proj,
        )
        # Both MLPs are expected to share the same activation, so the first one's is used.
        combined_gate_states = mlp_tok_1.act_fn(combined_gate_states)

        combined_up_proj_states = batched_linear(
            combined_hidden_states,
            self.mlp_up_proj_weights_combined[layer_id],
            mlp_tok_1.up_proj,
            mlp_tok_2.up_proj,
        )

        combined_down_proj_states = batched_linear(
            combined_gate_states * combined_up_proj_states,
            self.mlp_down_proj_weights_combined[layer_id],
            mlp_tok_1.down_proj,
            mlp_tok_2.down_proj,
        )

        # Second residual connection.
        return combined_residual + combined_down_proj_states

    def parallel_proj(self, proj_name, layer_id, combined_hidden_states):
        """Apply one attention projection to both stacked token streams at once.

        The stacked weights are read from `self.<proj_name>_weights_combined`
        at `layer_id` and applied with a single batched matmul against their
        transpose. No bias is applied.

        Args:
            proj_name: One of `'q_proj'`, `'k_proj'`, `'v_proj'`, `'o_proj'`.
            layer_id: Index into the pre-stacked weight container.
            combined_hidden_states: Stacked input states.

        Returns:
            The projected stacked states.

        Raises:
            ValueError: If `proj_name` is not a supported projection name.
        """
        if proj_name not in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise ValueError(f'Invalid projection name: {proj_name}')

        stacked_weight = getattr(self, f'{proj_name}_weights_combined')[layer_id]

        # NOTE: the batched result may differ slightly from running the two
        # per-layer projections separately and concatenating them (floating-point
        # precision), so compare with `torch.allclose`, not `torch.equal`.
        return torch.matmul(combined_hidden_states, stacked_weight.transpose(-1, -2))

    
class MultiHeadLlamaForCausalLM(LlamaForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the causal LM with per-layer hidden-state shortcut projections.

        Args:
            config: Model configuration. `num_hidden_layers` must be one of
                32, 40, 48 or 80; it selects which early layers get a shortcut.

        Raises:
            ValueError: If `config.num_hidden_layers` is not a supported depth.
        """
        super().__init__(config)
        # Replace the stock backbone with the custom one that `forward` drives
        # with the shortcut / early-exit arguments.
        self.model = MyLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.n_layers = config.num_hidden_layers

        # Early layers that receive a shortcut projection, keyed by model depth.
        early_layers_by_depth = {
            32: [4, 8, 12, 16, 20, 24, 28],    # range(4, 32, 4), 7B/8B, e.g. CodeLlama-7b-Instruct
            40: [5, 10, 15, 20, 25, 30, 35],   # range(5, 40, 5), 13B, e.g. CodeLlama-13b-Instruct
            48: [24],                          # 34B, e.g. CodeLlama-34b-Instruct (middle layer only)
            80: [10, 20, 30, 40, 50, 60, 70],  # range(10, 80, 10), 70B, e.g. CodeLlama-70b-Instruct
        }
        if self.n_layers not in early_layers_by_depth:
            raise ValueError(f'Invalid number of layers: {self.n_layers}')
        self.early_layers_with_lm_heads = early_layers_by_depth[self.n_layers]

        self.hidden_state_shortcuts = nn.ModuleDict({
            f"layer_{i}": nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            for i in self.early_layers_with_lm_heads
        })

        # Inference-time caches, filled lazily by `init_combined_weights`.
        # They are created here (empty) because `forward` reads them even in
        # training mode, where the lazy initialisation is skipped.
        self.is_combined_weights_initialized = False
        self.combined_eye_and_shortcut = {}
        self.combined_lm_head_weights = {}
        self.estimated_lm_head_weights = {}

        # Initialize weights and apply final processing
        self.post_init()

    def init_combined_weights(self):
        """Build the cached stacked matrices used on the inference path.

        For every layer in `early_layers_with_lm_heads`, stores under the key
        `'layer_<id>'` a tensor of shape `(2, hidden, hidden)` whose first slice
        is the identity matrix and whose second slice is the transposed weight
        of that layer's hidden-state shortcut. `combined_lm_head_weights` and
        `estimated_lm_head_weights` are reset to empty dicts.

        The cache is a snapshot: it is not refreshed if the shortcut weights
        change afterwards.
        """
        hidden_size = self.config.hidden_size
        eye_dtype = self.model.embed_tokens.weight.dtype

        combined_eye_and_shortcut = {}
        for layer_id in self.early_layers_with_lm_heads:
            shortcut_weight = self.hidden_state_shortcuts[f'layer_{layer_id}'].weight
            # Create the identity on the same device as the shortcut weight so
            # this works on CPU and when the model is not on the default CUDA device.
            eye_matrix = torch.eye(hidden_size, device=shortcut_weight.device, dtype=eye_dtype)
            combined_eye_and_shortcut[f'layer_{layer_id}'] = torch.cat(
                [eye_matrix.unsqueeze(0), shortcut_weight.T.unsqueeze(0)], dim=0
            )

        self.combined_eye_and_shortcut = combined_eye_and_shortcut
        self.combined_lm_head_weights = {}
        self.estimated_lm_head_weights = {}
        # Mark as initialised only after everything was built successfully.
        self.is_combined_weights_initialized = True

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MultiHeadCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        return_early_layer_logits: Optional[bool] = False,
        enable_early_exit: Optional[bool] = False,
        pre_calculated_early_exit_layer_hidden_states: Optional[torch.FloatTensor] = None,
        pre_decoded_early_exit_token: Optional[int] = None,
    ) -> Union[Tuple, MultiHeadCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                When not `None`, a distillation loss is computed: for every layer in
                `early_layers_with_lm_heads`, the KL divergence between that layer's shortcut logits and the
                final logits. The label values themselves are not read by this method.
            return_early_layer_logits (`bool`, *optional*, defaults to `False`):
                When `True`, `logits_dict` is filled with logits for every second layer index, using the
                shortcut projection on layers that have one. Reads `outputs.hidden_states`.
            enable_early_exit (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to the backbone model.
            pre_calculated_early_exit_layer_hidden_states (`torch.FloatTensor`, *optional*):
                Forwarded unchanged to the backbone model.
            pre_decoded_early_exit_token (`int`, *optional*):
                Forwarded unchanged to the backbone model.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # The stacked inference weights are built lazily, and only outside training.
        if not (self.is_combined_weights_initialized or self.training):
            self.init_combined_weights()

        # Fall back to the config for any output flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            hidden_state_shortcuts=self.hidden_state_shortcuts,
            lm_head=self.lm_head,
            combined_lm_head_weights=self.combined_lm_head_weights,
            combined_eye_and_shortcut=self.combined_eye_and_shortcut,
            enable_early_exit=enable_early_exit,
            pre_calculated_early_exit_layer_hidden_states=pre_calculated_early_exit_layer_hidden_states,
            pre_decoded_early_exit_token=pre_decoded_early_exit_token,
        )

        final_hidden = outputs[0]
        tp_degree = self.config.pretraining_tp
        if tp_degree > 1:
            # Apply the LM head shard by shard and join the pieces along the vocab axis.
            weight_shards = self.lm_head.weight.split(self.vocab_size // tp_degree, dim=0)
            shard_logits = [F.linear(final_hidden, weight_shards[shard_idx]) for shard_idx in range(tp_degree)]
            logits = torch.cat(shard_logits, dim=-1).float()
        elif outputs.early_exit_id is not None:
            # The backbone already produced (float) logits for the early-exit token.
            logits = outputs.early_exit_logits
        else:
            logits = self.lm_head(final_hidden).float()

        def shortcut_logits(idx):
            # shortcut projection -> final norm -> LM head, for one early layer
            projected = self.hidden_state_shortcuts[f'layer_{idx}'](outputs.hidden_states[idx])
            return self.lm_head(self.model.norm(projected)).float()

        loss = None
        loss_dict = {}
        logits_dict = {}

        if return_early_layer_logits is True:
            for layer_id in range(0, self.n_layers + 2, 2):
                if layer_id in self.early_layers_with_lm_heads:
                    logits_dict[layer_id] = shortcut_logits(layer_id)
                else:
                    # no shortcut on this layer: use the default LM head directly
                    logits_dict[layer_id] = self.lm_head(self.model.norm(outputs.hidden_states[layer_id])).float()

        if labels is not None:
            # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
            kl_criterion = KLDivLoss(reduction="batchmean", log_target=True)
            vocab = self.config.vocab_size

            # Final-layer distribution (log-space) that every shortcut is trained to match.
            log_target = F.log_softmax(logits.view(-1, vocab), dim=-1)

            for layer_id in self.early_layers_with_lm_heads:
                layer_logits = shortcut_logits(layer_id)
                logits_dict[layer_id] = layer_logits
                layer_log_probs = F.log_softmax(layer_logits.view(-1, vocab), dim=-1)
                loss_dict[f"layer_{layer_id}"] = kl_criterion(layer_log_probs, log_target)

            loss = sum(loss_dict.values())

        if not return_dict:
            as_tuple = (logits,) + outputs[1:]
            if loss is None:
                return as_tuple
            return (loss,) + as_tuple

        return MultiHeadCausalLMOutputWithPast(
            loss=loss,
            loss_dict=loss_dict,
            logits=logits,
            logits_dict=logits_dict,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            pre_calculated_early_exit_layer_hidden_states=outputs.pre_calculated_early_exit_layer_hidden_states,
            early_exit_id=outputs.early_exit_id,
            pre_decoded_early_exit_token=outputs.pre_decoded_early_exit_token,
            passed_verify=outputs.passed_verify
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList] = None,
        enable_early_exit: bool = False,
        return_early_layer_logits: bool = False,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids one token per loop iteration, with optional early exit.

        Tokens are drawn with **multinomial sampling** when `generation_config.do_sample` is set and with greedy
        `argmax` otherwise. When a forward pass reports an `early_exit_id`, that token is used directly and the
        logits processing / token selection for that step is skipped.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
                `generation_config`)
            enable_early_exit (`bool`, *optional*, defaults to `False`):
                Passed to the model's `forward` on every iteration except the first one, where `False` is always
                passed.
            return_early_layer_logits (`bool`, *optional*, defaults to `False`):
                Passed unchanged to the model's `forward` on every iteration.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`MultiHeadGenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour), or a
            [`MultiHeadGenerateDecoderOnlyOutput`] (which additionally carries `early_exit_cnt` and
            `success_verify_cnt`) if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True`, or a
            [`~generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.

        Raises:
            ValueError: If `do_sample` is `True` and `logits_warper` is not a `LogitsProcessorList`.
        """
    
        # init values
        pad_token_id = generation_config.pad_token_id
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

        # init attention / hidden states / scores tuples
        # NOTE: `early_exit_layer_logits` is never appended to below (the code that
        # filled it is disabled), so it is returned as an empty tuple or None.
        early_exit_layer_logits = () if (return_dict_in_generate and output_logits) else None
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        # Early-exit bookkeeping:
        #   cnt                -- number of completed loop iterations
        #   early_exit_cnt     -- iterations whose output carried early-exit state
        #   success_verify_cnt -- iterations whose output reported `passed_verify is True`
        # The two `pre_*` values are taken from one forward pass and fed into the next.
        cnt = 0
        early_exit_cnt = 0
        success_verify_cnt = 0
        pre_calculated_early_exit_layer_hidden_states = None
        pre_decoded_early_exit_token=None
        
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # Early exit is always disabled on the very first iteration (cnt == 0).
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_early_layer_logits=return_early_layer_logits,
                enable_early_exit=enable_early_exit if cnt > 0 else False,
                pre_calculated_early_exit_layer_hidden_states=pre_calculated_early_exit_layer_hidden_states,
                pre_decoded_early_exit_token=pre_decoded_early_exit_token
            )

            # Carry the early-exit state over to the next iteration.
            pre_calculated_early_exit_layer_hidden_states = outputs.pre_calculated_early_exit_layer_hidden_states
            pre_decoded_early_exit_token = outputs.pre_decoded_early_exit_token
            passed_verify=outputs.passed_verify

            if pre_calculated_early_exit_layer_hidden_states is not None or passed_verify is not None:
                early_exit_cnt += 1
                if passed_verify is True:
                    success_verify_cnt += 1
            
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need
            
            if outputs.early_exit_id is None:
                # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
                # (the clone itself is always small)
                next_token_logits = outputs.logits[:, -1, :].clone()

                # pre-process distribution
                next_token_scores = logits_processor(input_ids, next_token_logits)

                if do_sample:
                    next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            # NOTE(review): on an early-exit step `next_token_scores` / `next_token_logits`
            # are not recomputed above, so the values appended here are the ones left
            # over from the most recent non-early-exit step -- confirm this is intended.
            if return_dict_in_generate:

                if output_scores:
                    scores += (next_token_scores,)

                if output_logits:
                    raw_logits += (next_token_logits,)
                    
                    # Disabled: collecting the per-layer early-exit logits from
                    # `outputs.logits_dict` into `early_exit_layer_logits`.
                    # early_exit_layer_logits_dict = {}
                    # for k,v in outputs.logits_dict.items():
                    #     early_exit_layer_logits_dict[k] = logits_processor(input_ids, v[:, -1, :].clone())

                    # early_exit_layer_logits += (early_exit_layer_logits_dict,)

                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            if outputs.early_exit_id is None:
                # token selection
                if do_sample:
                    probs = nn.functional.softmax(next_token_scores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(next_token_scores, dim=-1)
            else:
                # Use the token chosen by the early-exit path directly.
                # NOTE(review): `unsqueeze(0)` adds a single leading dimension, which
                # presumably assumes batch size 1 -- confirm against the model.
                next_tokens = outputs.early_exit_id.unsqueeze(0)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # Disabled verification step: compare the early-exit token with the fully
            # decoded token and, on mismatch, drop the pre-calculated hidden states and
            # crop the last position from the KV cache.
            # if pre_calculated_early_exit_layer_hidden_states is not None:
            #     early_exit_id = outputs.early_exit_id
            #     assert early_exit_id is not None
            #     if next_tokens[0] != early_exit_id:
            #         pre_calculated_early_exit_layer_hidden_states = None
            #         past_key_values = model_kwargs.get("past_key_values")
            #         past_key_values.crop(past_key_values.get_seq_length() -1)
            #         model_kwargs["past_key_values"] = past_key_values

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

            cnt += 1

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                # Decoder-only result, extended with the early-exit statistics.
                return MultiHeadGenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    early_exit_layer_logits = early_exit_layer_logits,
                    early_exit_cnt = early_exit_cnt,
                    success_verify_cnt = success_verify_cnt,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        enable_early_exit: bool = False,
        return_early_layer_logits: bool = False,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

        For an overview of generation strategies and code examples, check out the [following
        guide](../generation_strategies).

        </Tip>

        Parameters:
            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            generation_config ([`~generation.GenerationConfig`], *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logit processor is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complements the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
                sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
                intended for advanced users.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            synced_gpus (`bool`, *optional*):
                Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
                `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
                generating before other GPUs. Otherwise it'll be set to `False`.
            assistant_model (`PreTrainedModel`, *optional*):
                An assistant model that can be used to accelerate generation. The assistant model must have the exact
                same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
                is much faster than running generation with the model you're calling generate from. As such, the
                assistant model should be much smaller.
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The negative prompt needed for some processors such as CFG. The batch size must match the input batch
                size. This is an experimental feature, subject to breaking API changes in future versions.
            negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Attention_mask for `negative_prompt_ids`.
            enable_early_exit (`bool`, *optional*, defaults to `False`):
                Whether to enable early-exit decoding. When set and no cache was supplied by the caller, the default
                KV cache created here is a `MyDynamicCache` instead of a `DynamicCache` (unless a cross-attention
                cache is required, in which case an `EncoderDecoderCache` is always used). The flag is forwarded to
                `_sample`, i.e. it is only passed on in the sample / greedy search generation modes.
            return_early_layer_logits (`bool`, *optional*, defaults to `False`):
                Forwarded as-is to `_sample` in the sample / greedy search generation modes. It is not passed on to
                any other generation mode.
            kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation.GenerateDecoderOnlyOutput`],
                    - [`~generation.GenerateBeamDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation.GenerateEncoderDecoderOutput`],
                    - [`~generation.GenerateBeamEncoderDecoderOutput`]

        Raises:
            ValueError: If both `cache_implementation` and an explicit cache object are given, if the requested cache
                type is not supported by the model, if `streamer` is combined with `num_beams > 1`, if the selected
                generation mode is incompatible with the model or the arguments, or if `force_words_ids` is malformed.
            ImportError: If a quantized cache is requested but its backend (`quanto` or `HQQ`) is not installed.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
                synced_gpus = True
            else:
                synced_gpus = False

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # decoder-only models must use left-padding for batched generation.
        if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            if (
                generation_config._pad_token_tensor is not None
                and batch_size > 1
                and len(inputs_tensor.shape) == 2
                and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        # 4. Define other model kwargs
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            model_kwargs["use_cache"] = True
        else:
            model_kwargs["use_cache"] = generation_config.use_cache

        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                device=inputs_tensor.device,
            )
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        use_dynamic_cache_by_default = False
        if "mamba" in self.__class__.__name__.lower():
            cache_name = "cache_params"
        else:
            cache_name = "past_key_values"
        if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        elif generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                if generation_config.cache_implementation == "static" and not self._supports_static_cache:
                    raise ValueError(
                        "This model does not support `cache_implementation='static'`. Please check the following "
                        "issue: https://github.com/huggingface/transformers/issues/28981"
                    )
                model_kwargs[cache_name] = self._get_cache(
                    generation_config.cache_implementation,
                    getattr(generation_config, "num_beams", 1) * batch_size,
                    generation_config.max_length,
                    model_kwargs,
                )
            elif generation_config.cache_implementation == "quantized":
                if not self._supports_quantized_cache:
                    raise ValueError(
                        "This model does not support the quantized cache. If you want your model to support quantized "
                        "cache, please open an issue."
                    )

                cache_config = (
                    generation_config.cache_config
                    if generation_config.cache_config is not None
                    else QuantizedCacheConfig()
                )
                cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]

                if cache_config.backend == "quanto" and not is_quanto_available():
                    raise ImportError(
                        "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
                        "Please install it via  with `pip install quanto`"
                    )
                elif cache_config.backend == "HQQ" and not is_hqq_available():
                    raise ImportError(
                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                        "Please install it via  with `pip install hqq`"
                    )

                model_kwargs[cache_name] = cache_class(cache_config)
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
            past = model_kwargs.get(cache_name, None)
            requires_cross_attention_cache = (
                self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
            )
            # The cross-attention requirement takes priority over the early-exit cache. This is spelled out as an
            # explicit if/elif/else because chained conditional expressions group right-to-left, which would hand
            # a plain `MyDynamicCache` to a model that needs an `EncoderDecoderCache` when early exit is enabled.
            if past is None:
                if requires_cross_attention_cache:
                    model_kwargs[cache_name] = EncoderDecoderCache(DynamicCache(), DynamicCache())
                elif enable_early_exit:
                    model_kwargs[cache_name] = MyDynamicCache()
                else:
                    model_kwargs[cache_name] = DynamicCache()
                use_dynamic_cache_by_default = True
            elif isinstance(past, tuple):
                # Legacy tuple-format cache supplied by the caller: convert it to the matching cache class.
                if requires_cross_attention_cache:
                    model_kwargs[cache_name] = EncoderDecoderCache.from_legacy_cache(past)
                elif enable_early_exit:
                    model_kwargs[cache_name] = MyDynamicCache.from_legacy_cache(past)
                else:
                    model_kwargs[cache_name] = DynamicCache.from_legacy_cache(past)
                use_dynamic_cache_by_default = True

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        # The early-exit arguments are only forwarded to `_sample` below; make it visible when they are requested
        # together with a generation mode that never receives them instead of dropping them silently.
        if (enable_early_exit or return_early_layer_logits) and generation_mode not in (
            GenerationMode.SAMPLE,
            GenerationMode.GREEDY_SEARCH,
        ):
            logger.warning(
                "`enable_early_exit` and `return_early_layer_logits` are only forwarded to the decoding loop in "
                f"sample / greedy search mode; they are not forwarded for generation mode {generation_mode}."
            )

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        if self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )

        # 8. prepare distribution pre_processing samplers
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        # 9. prepare stopping criteria
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # 10. go into different generation modes
        if generation_mode == GenerationMode.ASSISTED_GENERATION:
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    "num_return_sequences has to be 1 when doing assisted generate, "
                    f"but is {generation_config.num_return_sequences}."
                )
            if batch_size > 1:
                raise ValueError("assisted generate is only supported for batch_size = 1")
            if not model_kwargs["use_cache"]:
                raise ValueError("assisted generate requires `use_cache=True`")
            if generation_config.cache_implementation == "static":
                raise ValueError("assisted generate is not supported with `static_cache`")
            if self._is_stateful:
                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
                raise ValueError(
                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
                )

            # 11. Get the candidate generator, given the parameterization
            candidate_generator = self._get_candidate_generator(
                generation_config=generation_config,
                input_ids=input_ids,
                inputs_tensor=inputs_tensor,
                assistant_model=assistant_model,
                logits_processor=logits_processor,
                model_kwargs=model_kwargs,
            )

            # 12. prepare logits warper (if `do_sample` is `True`)
            prepared_logits_warper = (
                self._get_logits_warper(
                    generation_config,
                    device=input_ids.device,
                )
                if generation_config.do_sample
                else None
            )

            # 13. run assisted generate
            result = self._assisted_decoding(
                input_ids,
                candidate_generator=candidate_generator,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
        elif generation_mode == GenerationMode.DOLA_GENERATION:
            if self._is_stateful:
                # DoLa decoding was not designed for stateful models, and would require some changes
                raise ValueError(
                    f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
                )
            prepared_logits_warper = (
                self._get_logits_warper(generation_config, device=input_ids.device)
                if generation_config.do_sample
                else None
            )
            result = self._dola_decoding(
                input_ids,
                dola_layers=generation_config.dola_layers,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

        elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
            if not model_kwargs["use_cache"]:
                raise ValueError("Contrastive search requires `use_cache=True`")
            if self._is_stateful:
                # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
                raise ValueError(
                    f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
                )

            result = self._contrastive_search(
                input_ids,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. prepare logits warper
            prepared_logits_warper = (
                self._get_logits_warper(generation_config, device=input_ids.device)
                if generation_config.do_sample
                else None
            )

            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
            # This is the only branch that receives the early-exit arguments.
            result = self._sample(
                input_ids,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                enable_early_exit=enable_early_exit,
                return_early_layer_logits=return_early_layer_logits,
                **model_kwargs,
            )

        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
            # 11. prepare logits warper
            prepared_logits_warper = (
                self._get_logits_warper(generation_config, device=input_ids.device)
                if generation_config.do_sample
                else None
            )

            # 12. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )

            # 13. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 14. run beam sample
            result = self._beam_search(
                input_ids,
                beam_scorer,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                num_beam_groups=generation_config.num_beam_groups,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            result = self._group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
            final_constraints = []
            if generation_config.constraints is not None:
                final_constraints = generation_config.constraints

            if generation_config.force_words_ids is not None:

                def typeerror():
                    raise ValueError(
                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` "
                        f"of positive integers, but is {generation_config.force_words_ids}."
                    )

                if (
                    not isinstance(generation_config.force_words_ids, list)
                    or len(generation_config.force_words_ids) == 0
                ):
                    typeerror()

                for word_ids in generation_config.force_words_ids:
                    # Validate the container before indexing into it, so that an empty or non-list entry raises
                    # the descriptive ValueError above rather than an IndexError / TypeError from `word_ids[0]`.
                    if not isinstance(word_ids, list) or len(word_ids) == 0:
                        typeerror()

                    if isinstance(word_ids[0], list):
                        # Nested lists: a disjunctive constraint (any one of several token sequences).
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                        ):
                            typeerror()

                        constraint = DisjunctiveConstraint(word_ids)
                    else:
                        # Flat list of token ids: a single phrasal constraint.
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 11. prepare beam search scorer
            constrained_beam_scorer = ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            result = self._constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        # Convert to legacy cache if needed
        if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
            if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
                if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
                    result.past_key_values = result.past_key_values.to_legacy_cache()
        return result