# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code for Knowledge Base augmented Language Model model, code mostly adapted from LLaMA's source code
Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
"""

import math
import os
import copy
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    _CONFIG_FOR_DOC,
    LLAMA_INPUTS_DOCSTRING,
    LLAMA_START_DOCSTRING,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaMLP,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from atlaskv.models.kblam_config import KBLaMConfig, KGLaMConfig

# Module-level logger, obtained through the `transformers.utils.logging` helper.
logger = logging.get_logger(__name__)

# Most negative finite value representable in bfloat16.
# NOTE(review): presumably used as an additive fill for masked/padded attention
# positions; its use is not visible in this part of the file, so confirm.
PADDING_VALUE = torch.finfo(torch.bfloat16).min

def get_llama_attention_classes(use_kg: bool = False):
    """Build the attention-implementation lookup table.

    Every implementation name ("eager", "flash_attention_2", "sdpa") maps to the
    same class: ``KglamLlamaAttention`` when ``use_kg`` is true, otherwise
    ``KblamLlamaAttention``.

    Args:
        use_kg: Select the KG attention class instead of the KB one.

    Returns:
        A dict from implementation name to attention class.
    """
    attention_cls = KglamLlamaAttention if use_kg else KblamLlamaAttention
    return dict.fromkeys(("eager", "flash_attention_2", "sdpa"), attention_cls)

# Currently selected attention-class table; stays None until it is set
# explicitly or filled lazily on first read.
_LLAMA_ATTENTION_CLASSES = None
def set_llama_attention_classes(use_kg: bool = False):
    """Select the module-wide attention-class table and log the choice.

    Args:
        use_kg: Forwarded to ``get_llama_attention_classes`` to pick the KG or
            the KB attention class.
    """
    global _LLAMA_ATTENTION_CLASSES
    selected_classes = get_llama_attention_classes(use_kg)
    _LLAMA_ATTENTION_CLASSES = selected_classes
    logger.info(f"Using KG attention classes: {_LLAMA_ATTENTION_CLASSES}")

def get_current_llama_attention_classes():
    """Return the module-wide attention-class table, creating it on first use.

    If no table has been set yet, it is initialised with
    ``get_llama_attention_classes(False)`` and that table is returned.
    """
    global _LLAMA_ATTENTION_CLASSES
    if _LLAMA_ATTENTION_CLASSES is not None:
        return _LLAMA_ATTENTION_CLASSES
    _LLAMA_ATTENTION_CLASSES = get_llama_attention_classes(False)
    return _LLAMA_ATTENTION_CLASSES


class KglamLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' implemented as Rectangular attention"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        """Set up the attention layer's dimensions, projections and rotary embedding.

        Besides the usual q/k/v/o projections, the layer owns a per-head
        ``score_shift`` parameter of shape ``(num_heads, 1)`` initialised to -3
        and a second query projection ``q_proj_kg`` with the same shape as
        ``q_proj``.

        Args:
            config: Model configuration supplying the sizes, dropout, RoPE and
                bias settings read below.
            layer_idx: Index of this layer; a warning is logged when omitted.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by the number of heads.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # Dimensions are worked out in locals first, then mirrored onto the module.
        self.attention_dropout = config.attention_dropout
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = hidden_size // num_heads
        kv_heads = config.num_key_value_heads

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_key_value_heads = kv_heads
        self.num_key_value_groups = num_heads // kv_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if head_dim * num_heads != hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Creation order below is kept stable: it fixes both the parameter
        # registration order and the order in which random init is drawn.
        self.score_shift = nn.Parameter(torch.full((num_heads, 1), -3.0))

        use_bias = config.attention_bias
        query_width = num_heads * head_dim
        kv_width = kv_heads * head_dim

        self.q_proj = nn.Linear(hidden_size, query_width, bias=use_bias)
        #TODO: define different q_proj for different KB layers
        self.q_proj_kg = nn.Linear(hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=use_bias)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        With no scaling configured, or with rope type "llama3", the plain
        ``LlamaRotaryEmbedding`` is used (for "llama3" the factor is read but
        not passed on). "linear" and "dynamic" select the corresponding scaled
        variants and receive the configured factor.

        Raises:
            ValueError: if the configured rope type is none of the above.
        """
        rope_scaling = self.config.rope_scaling
        shared_kwargs = {
            "max_position_embeddings": self.max_position_embeddings,
            "base": self.rope_theta,
        }

        if rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, **shared_kwargs)
            return

        # Both keys are read up front, before the type is inspected.
        scaling_type = rope_scaling["rope_type"]
        scaling_factor = rope_scaling["factor"]

        if scaling_type == "llama3":
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, **shared_kwargs)
            return

        if scaling_type == "linear":
            scaled_cls = LlamaLinearScalingRotaryEmbedding
        elif scaling_type == "dynamic":
            scaled_cls = LlamaDynamicNTKScalingRotaryEmbedding
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

        self.rotary_emb = scaled_cls(
            self.head_dim,
            scaling_factor=scaling_factor,
            **shared_kwargs,
        )

    def prune_key_value(self, query, kb_keys, kb_values, topk_size=20):
        """Keep only the ``topk_size`` best-scoring KB entries for the given queries.

        Scaled dot-product scores between ``query`` and ``kb_keys`` are summed
        over the head and query-position axes, giving one relevance score per KB
        entry and batch element. The ``topk_size`` best entries are gathered
        from the keys and values, and the score columns of those same entries
        are returned, so column ``j`` of the returned weights always belongs to
        entry ``j`` of the returned keys/values.

        Args:
            query: Queries, ``(batch, num_heads, q_len, head_dim)``.
            kb_keys: KB keys, ``(batch, num_heads, kb_len, head_dim)``.
            kb_values: KB values, gathered with the same indices as the keys.
            topk_size: Number of KB entries to keep.

        Returns:
            ``(kb_keys, kb_values, attn_weights)``. When ``topk_size >= kb_len``
            the inputs are returned untouched together with the full
            ``(batch, num_heads, q_len, kb_len)`` score matrix; otherwise all
            three are reduced to ``topk_size`` entries, ordered by decreasing
            summed score.

        Raises:
            AssertionError: if ``query`` requires grad (test-time use only).
        """
        assert (
            query.requires_grad is False
        ), "This function should only be used at test time"
        batch_size, num_heads, kb_len, head_dim = kb_keys.shape
        attn_weights = torch.matmul(query, kb_keys.transpose(2, 3)) / math.sqrt(
            self.head_dim
        )  # Batchsize, num_heads, query_size, key_size
        if topk_size >= kb_len:
            return kb_keys, kb_values, attn_weights
        with torch.autograd.no_grad():
            # One ranking per batch element, shared by every head and position.
            top_idx = attn_weights.sum((1, 2)).topk(topk_size, -1)[1]  # (batch, topk_size)
            kv_idx = top_idx.view(batch_size, -1, topk_size, 1).expand(
                batch_size, num_heads, topk_size, head_dim
            )
            kb_keys = kb_keys.gather(-2, kv_idx)
            kb_values = kb_values.gather(-2, kv_idx)
        # Select the score columns of the chosen entries. Slicing the first
        # `topk_size` columns instead would pair the pruned keys/values with the
        # scores of unrelated KB entries.
        q_len = attn_weights.shape[2]
        weight_idx = top_idx.view(batch_size, 1, 1, topk_size).expand(
            batch_size, num_heads, q_len, topk_size
        )
        return kb_keys, kb_values, attn_weights.gather(-1, weight_idx)

    def prune_key(self, attn_weights, layer_kb_indices=None, topk_size=20):
        """Pick the ``topk_size`` most attended KB entries of one hierarchy level.

        The logits are soft-maxed over the KB axis, the probabilities are summed
        over the head and query-position axes, and the best ``topk_size``
        entries are selected per batch element.

        Args:
            attn_weights: Attention logits of shape
                ``(batch, num_heads, q_len, layer_kb_len)``; must not require grad.
            layer_kb_indices: Optional indexable collection of identifiers, one
                per KB entry of this level. When given, the selected positions
                are translated through it.
            topk_size: Number of entries to keep.

        Returns:
            * ``topk_size >= layer_kb_len``: ``layer_kb_indices`` unchanged, or
              ``range(layer_kb_len)`` when no indices were supplied.
            * otherwise: the selected positions as a CPU tensor, or
              ``layer_kb_indices`` indexed by them. For a batch of one the
              result is always 1-D, including when a single entry is selected;
              for larger batches the result of ``squeeze()`` is returned.

        Raises:
            AssertionError: if ``attn_weights`` requires grad (test-time use only).
        """
        assert (
            attn_weights.requires_grad is False
        ), "This function should only be used at test time"
        _, _, _, layer_kb_len = attn_weights.shape

        if topk_size >= layer_kb_len:
            # Nothing to prune; skip the softmax altogether.
            if layer_kb_indices is not None:
                return layer_kb_indices
            return range(layer_kb_len)

        with torch.autograd.no_grad():
            # softmax here is necessary: entries are ranked by summed
            # probabilities, which are normalised per row, not by raw logits.
            attn_weights_prob = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            )
            top_idx = torch.topk(attn_weights_prob.sum((1, 2)), topk_size, dim=-1)[1]

        if top_idx.shape[0] == 1:
            # Drop only the batch axis. A blanket squeeze() would also collapse
            # a single selected index into a 0-d tensor that callers cannot
            # iterate over.
            top_idx = top_idx[0]
        else:
            top_idx = top_idx.squeeze()
        top_idx = top_idx.to('cpu')

        if layer_kb_indices is None:
            return top_idx
        return layer_kb_indices[top_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        kb_kvs: Optional[tuple] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        save_attention_weights: bool = True,
        attention_save_loc: Optional[str] = None,
        attention_file_base_name: Optional[str] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if save_attention_weights:
            assert (
                attention_save_loc is not None
            ), "Please provide a location to save the attention weights"
            assert (
                attention_file_base_name is not None
            ), "Please provide a base name for the attention weights"
        bsz, q_len, _ = hidden_states.size()
        if self.config.pretraining_tp > 1:
            key_value_slicing = (
                self.num_key_value_heads * self.head_dim
            ) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [
                F.linear(hidden_states, query_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [
                F.linear(hidden_states, key_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [
                F.linear(hidden_states, value_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            value_states = torch.cat(value_states, dim=-1)

        else:   # obtain qkv representations after heads
            query_states = self.q_proj(hidden_states)           # [bs, max_len, 4096]
            query_states_2 = self.q_proj_kg(hidden_states)     # [bs, max_len, 4096]
            key_states = self.k_proj(hidden_states)            # [bs, max_len, 1024]
            value_states = self.v_proj(hidden_states)          # [bs, max_len, 1024]

        query_states = query_states.view(                       # [bs, num_heads, max_len, head_dim]
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        query_states_2 = query_states_2.view(                   # [bs, num_heads, max_len, head_dim]
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(                           # [bs, num_key_value_heads, max_len, head_dim]
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(                       # [bs, num_key_value_heads, max_len, head_dim]
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)  # position embeddings
        query_states, key_states = apply_rotary_pos_emb(        # [bs, num_key_value_heads, max_len, head_dim]
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)  # [bs, num_heads, max_len, head_dim]
        value_states = repeat_kv(value_states, self.num_key_value_groups)  # [bs, num_heads, max_len, head_dim]
        kb_layer_frequency = kb_config.kb_layer_frequency

        if kb_kvs is not None:
            if self.layer_idx % kb_layer_frequency == 0:
                if kb_config.use_hierarchial_kv:
                    root_kb_keys, inter_kb_keys, leaf_kb_keys, kb_values, root_kb_indices, inter_kb_indices, leaf_kb_indices, root_c2id, inter_c2id = (
                        kb_kvs  # (kb_len, head_dim * num_heads * num_adapters)
                    )
                    kb_idx = (self.layer_idx // kb_layer_frequency)  # Should be something inside the kb config
                    if len(root_kb_keys.shape) == 2:  # Not batch dim
                        root_kb_len = root_kb_keys.shape[0]
                        inter_kb_len = inter_kb_keys.shape[0]
                        leaf_kb_len = leaf_kb_keys.shape[0]
                        # root keys transformation
                        root_kb_keys = root_kb_keys.reshape(
                            root_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, kb_idx]
                        root_kb_keys = root_kb_keys.view(
                            root_kb_len, self.num_heads, self.head_dim
                        ).transpose(0, 1)
                        root_kb_keys = root_kb_keys.unsqueeze(0).expand(
                            bsz, self.num_heads, root_kb_len, self.head_dim
                        )
                        # inter keys transformation
                        inter_kb_keys = inter_kb_keys.reshape(
                            inter_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, kb_idx]
                        inter_kb_keys = inter_kb_keys.view(
                            inter_kb_len, self.num_heads, self.head_dim
                        ).transpose(0, 1)
                        inter_kb_keys = inter_kb_keys.unsqueeze(0).expand(
                            bsz, self.num_heads, inter_kb_len, self.head_dim
                        )
                        # leaf keys transformation
                        leaf_kb_keys = leaf_kb_keys.reshape(
                            leaf_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, kb_idx]
                        leaf_kb_keys = leaf_kb_keys.view(
                            leaf_kb_len, self.num_heads, self.head_dim
                        ).transpose(0, 1)
                        leaf_kb_keys = leaf_kb_keys.unsqueeze(0).expand(
                            bsz, self.num_heads, leaf_kb_len, self.head_dim
                        )
                        # values transformation
                        kb_values = kb_values.reshape(
                            leaf_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, kb_idx]
                        kb_values = kb_values.view(
                            leaf_kb_len, self.num_heads, self.head_dim
                        ).transpose(0, 1)
                        kb_values = kb_values.unsqueeze(0).expand(
                            bsz, self.num_heads, leaf_kb_len, self.head_dim
                        )
                    elif len(root_kb_keys.shape) == 3:  # Has a batch dim
                        root_kb_len = root_kb_keys.shape[1]
                        inter_kb_len = inter_kb_keys.shape[1]
                        leaf_kb_len = leaf_kb_keys.shape[1]
                        # root keys transformation
                        root_kb_keys = root_kb_keys.view(
                            bsz,
                            root_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, :, kb_idx]                 # [bs, root_kb_size, 4096]
                        root_kb_keys = root_kb_keys.view(
                            bsz, root_kb_len, self.num_heads, self.head_dim
                        ).transpose(1, 2)               # [bs, num_heads, root_kb_size, head_dim]
                        # inter keys transformation
                        inter_kb_keys = inter_kb_keys.view(
                            bsz,
                            inter_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, :, kb_idx]                 # [bs, inter_kb_size, 4096]
                        inter_kb_keys = inter_kb_keys.view(
                            bsz, inter_kb_len, self.num_heads, self.head_dim
                        ).transpose(1, 2)               # [bs, num_heads, inter_kb_size, head_dim]
                        # leaf keys transformation
                        leaf_kb_keys = leaf_kb_keys.view(
                            bsz,
                            leaf_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, :, kb_idx]                 # [bs, leaf_kb_size, 4096]
                        leaf_kb_keys = leaf_kb_keys.view(
                            bsz, leaf_kb_len, self.num_heads, self.head_dim
                        ).transpose(1, 2)               # [bs, num_heads, leaf_kb_size, head_dim]
                        # values transformation
                        kb_values = kb_values.view(
                            bsz,
                            leaf_kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, :, kb_idx]                 # [bs, kb_size, 4096]
                        kb_values = kb_values.view(
                            bsz, leaf_kb_len, self.num_heads, self.head_dim
                        ).transpose(1, 2)               # [bs, num_heads, kb_size, head_dim]
                    root_kb_len = root_kb_keys.shape[2]
                else:
                    kb_keys, kb_values = (
                        kb_kvs  # (kb_len, head_dim * num_heads * num_adapters)
                    )

                    kb_idx = (
                        self.layer_idx // kb_layer_frequency
                    )  # Should be something inside the kb config
                    if len(kb_keys.shape) == 2:  # Not batch dim
                        kb_len = kb_keys.shape[0]
                        kb_keys = kb_keys.reshape(
                            kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, kb_idx]
                        kb_values = kb_values.reshape(
                            kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, kb_idx]
                        kb_keys = kb_keys.view(
                            kb_len, self.num_heads, self.head_dim
                        ).transpose(0, 1)
                        kb_values = kb_values.view(
                            kb_len, self.num_heads, self.head_dim
                        ).transpose(0, 1)
                        kb_keys = kb_keys.unsqueeze(0).expand(
                            bsz, self.num_heads, kb_len, self.head_dim
                        )
                        kb_values = kb_values.unsqueeze(0).expand(
                            bsz, self.num_heads, kb_len, self.head_dim
                        )
                        # Append the KB keys and values in the front, in front of padding
                        # key_states = torch.concat([kb_keys, key_states], dim=2)
                        # value_states = torch.concat([kb_values, value_states], dim=2)
                    elif len(kb_keys.shape) == 3:  # Has a batch dim
                        kb_len = kb_keys.shape[1]
                        kb_keys = kb_keys.view(
                            bsz,
                            kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, :, kb_idx]                 # [bs, kb_size, 4096]
                        kb_values = kb_values.view(
                            bsz,
                            kb_len,
                            1 + self.config.num_hidden_layers // kb_layer_frequency,
                            -1,
                        )[:, :, kb_idx]                 # [bs, kb_size, 4096]
                        kb_keys = kb_keys.view(
                            bsz, kb_len, self.num_heads, self.head_dim
                        ).transpose(1, 2)               # [bs, num_heads, kb_size, head_dim]
                        kb_values = kb_values.view(
                            bsz, kb_len, self.num_heads, self.head_dim
                        ).transpose(1, 2)               # [bs, num_heads, kb_size, head_dim]
                    # Modify the attention matrix: Appendx a (seq_len, kb_len) block to the left
                    kb_len = kb_keys.shape[2]

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)                    # [bs, num_heads, max_len, max_len]
        sep_query_head = kb_config.sep_query_head
        kb_scale_factor = kb_config.kb_scale_factor
        if sep_query_head:
            if kb_kvs is not None:
                if self.layer_idx % kb_layer_frequency == 0:
                    if kb_config.use_hierarchial_kv:
                        if torch.is_grad_enabled():     # train mode
                            # hierarchial conditional probabilities
                            # make sure all keys are on GPU
                            root_kb_keys_gpu = root_kb_keys.to(query_states_2.device)   # [bs, num_heads, kb_len, head_dim]
                            inter_kb_keys_gpu = inter_kb_keys.to(query_states_2.device) # [bs, num_heads, kb_len, head_dim]
                            leaf_kb_keys_gpu = leaf_kb_keys.to(query_states_2.device)   # [bs, num_heads, kb_len, head_dim]
                            # calculate three-level attention logits, no pruning
                            attn_weights_2_root = torch.matmul(
                                query_states_2, root_kb_keys_gpu.transpose(2, 3)
                            ) / math.sqrt(self.head_dim)    # [bs, num_heads, max_len, kb_len]
                            attn_weights_2_inter = torch.matmul(
                                query_states_2, inter_kb_keys_gpu.transpose(2, 3)
                            ) / math.sqrt(self.head_dim)    # [bs, num_heads, max_len, kb_len]
                            attn_weights_2_leaf = torch.matmul(
                                query_states_2, leaf_kb_keys_gpu.transpose(2, 3)
                            ) / math.sqrt(self.head_dim)    # [bs, num_heads, max_len, kb_len]
                            # scale factor
                            if kb_scale_factor is not None:
                                kb_len = attn_weights_2.shape[-1]   # kb sizes of all layers are the same
                                attn_weights_2_root = (
                                    attn_weights_2_root - np.log(kb_len) + np.log(kb_scale_factor)
                                )
                                attn_weights_2_inter = (
                                    attn_weights_2_inter - np.log(kb_len) + np.log(kb_scale_factor)
                                )
                                attn_weights_2_leaf = (
                                    attn_weights_2_leaf - np.log(kb_len) + np.log(kb_scale_factor)
                                )
                            select_leaf_kb_values = kb_values
                        else:           # eval mode
                            #FIXME: device communication cost should be reduced
                            # calculate root layer logits, slow here
                            attn_weights_2_root = torch.matmul(
                                query_states_2, root_kb_keys.transpose(2, 3)
                            ) / math.sqrt(self.head_dim)    # [bs, num_heads, max_len, root_kb_len]
                            # prune root layer keys and get top-k indices
                            top_idx = self.prune_key(attn_weights_2_root, root_kb_indices, kb_config.root_top_k_kb)
                            select_inter_key_idx = []
                            for i in np.array(top_idx.cpu()):
                                select_inter_key_idx.extend(root_c2id[str(i)])
                            select_inter_key_idx = torch.tensor(select_inter_key_idx).to(query_states_2.device)
                            # ============offloading root layer keys and logits to CPU============
                            attn_weights_2_root_cpu = attn_weights_2_root.to('cpu', non_blocking=True)
                            root_kb_keys_cpu = root_kb_keys.to('cpu', non_blocking=True)
                            del attn_weights_2_root
                            del root_kb_keys
                            torch.cuda.empty_cache()
                            # ====================================================================
                            # select covered inter layer keys from CPU 
                            select_inter_kb_indices = np.intersect1d(select_inter_key_idx.cpu(), inter_kb_indices.cpu())
                            select_inter_local_key_idx = np.where(np.isin(inter_kb_indices, select_inter_kb_indices))[0]
                            select_inter_kb_keys_cpu = inter_kb_keys[:,:,select_inter_local_key_idx,:]
                            # ============uploading inter layer keys to GPU=======================
                            select_inter_kb_keys = select_inter_kb_keys_cpu.to(query_states_2.device, non_blocking=True)
                            # ====================================================================
                            # calculate inter layer logits
                            attn_weights_2_inter = torch.matmul(
                                query_states_2, select_inter_kb_keys.transpose(2, 3)
                            ) / math.sqrt(self.head_dim)    # [bs, num_heads, max_len, inter_kb_len]
                            # prune inter layer keys and get top-k indices
                            top_idx = self.prune_key(attn_weights_2_inter, inter_kb_indices, kb_config.inter_top_k_kb)
                            select_leaf_key_idx = []
                            for i in np.array(top_idx.cpu()):
                                select_leaf_key_idx.extend(inter_c2id[str(i)])
                            select_leaf_key_idx = torch.tensor(select_leaf_key_idx).to(query_states_2.device)
                            # ============offloading inter layer keys and logits to CPU============
                            attn_weights_2_inter_cpu = attn_weights_2_inter.to('cpu', non_blocking=True)
                            select_inter_kb_keys_cpu = select_inter_kb_keys.to('cpu', non_blocking=True)
                            del attn_weights_2_inter
                            del select_inter_kb_keys
                            torch.cuda.empty_cache()
                            # =====================================================================
                            # select covered leaf layer keys from CPU
                            select_leaf_kb_indices = np.intersect1d(select_leaf_key_idx.cpu(), leaf_kb_indices.cpu())
                            select_leaf_local_key_idx = np.where(np.isin(leaf_kb_indices, select_leaf_kb_indices))[0]
                            select_leaf_kb_keys_cpu = leaf_kb_keys[:,:,select_leaf_local_key_idx,:]
                            # ============uploading leaf layer keys to GPU========================
                            select_leaf_kb_keys = select_leaf_kb_keys_cpu.to(query_states_2.device, non_blocking=True)
                            # ====================================================================
                            # calculate leaf layer logits
                            attn_weights_2_leaf = torch.matmul(
                                query_states_2, select_leaf_kb_keys.transpose(2, 3)
                            ) / math.sqrt(self.head_dim)    # [bs, num_heads, max_len, leaf_kb_len]
                            # prune leaf layer keys and get top-k indices
                            top_idx = self.prune_key(attn_weights_2_leaf, None, kb_config.leaf_top_k_kb)
                            # get final leaf layer logits
                            select_attn_weights_2_leaf = attn_weights_2_leaf[:,:,:,top_idx]
                            select_leaf_kb_values_cpu = kb_values[:,:,top_idx,:]
                            # ============offloading before-pruned leaf layer keys and logits to CPU============
                            attn_weights_2_leaf_cpu = attn_weights_2_leaf.to('cpu', non_blocking=True)
                            select_leaf_kb_keys_cpu = select_leaf_kb_keys.to('cpu', non_blocking=True)
                            del attn_weights_2_leaf
                            del select_leaf_kb_keys
                            torch.cuda.empty_cache()
                            # ==================================================================================
                            # ============================uploading selected values to GPU======================
                            select_leaf_kb_values = select_leaf_kb_values_cpu.to(query_states_2.device, non_blocking=True)
                            # ==================================================================================
                            attn_weights_2 = select_attn_weights_2_leaf
                            if kb_scale_factor is not None:
                                kb_len = attn_weights_2.shape[-1]
                                attn_weights_2 = (
                                    attn_weights_2 - np.log(kb_len) + np.log(kb_scale_factor)
                                )
                    else:
                        attn_weights_2 = torch.matmul(
                            query_states_2, kb_keys.transpose(2, 3)
                        ) / math.sqrt(self.head_dim)
                        if kb_scale_factor is not None:
                            attn_weights_2 = (
                                attn_weights_2 - np.log(kb_len) + np.log(kb_scale_factor)
                            )
                        select_leaf_kb_values = kb_values

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask
        # upcast attention to fp32 for numerical stability
        attn_weights_original = attn_weights.clone()
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

        if kb_kvs is not None and self.layer_idx % kb_layer_frequency == 0:
            if kb_config.use_hierarchial_kv and torch.is_grad_enabled():    # train mode and hierarchial mode enabled
                leaf_norm_exp_logtis = torch.exp(attn_weights_2_leaf - \
                    torch.max(torch.concat([attn_weights_original, attn_weights_2_leaf, attn_weights_2_inter, attn_weights_2_root], dim=-1), dim=-1, keepdim=True)[0])
                inter_norm_exp_logits = torch.exp(attn_weights_2_inter - \
                    torch.max(torch.concat([attn_weights_original, attn_weights_2_leaf, attn_weights_2_inter, attn_weights_2_root], dim=-1), dim=-1, keepdim=True)[0])
                root_norm_exp_logits = torch.exp(attn_weights_2_root - \
                    torch.max(torch.concat([attn_weights_original, attn_weights_2_leaf, attn_weights_2_inter, attn_weights_2_root], dim=-1), dim=-1, keepdim=True)[0])
                attn_weights_2_softmax_denominator = torch.sum(leaf_norm_exp_logtis, dim=-1, keepdim=True, dtype=torch.float32) + \
                    torch.sum(inter_norm_exp_logits, dim=-1, keepdim=True, dtype=torch.float32) + \
                    torch.sum(root_norm_exp_logits, dim=-1, keepdim=True, dtype=torch.float32)
                attn_weights_softmax_denominator = torch.sum(torch.exp(attn_weights_original - torch.max(torch.concat([attn_weights_original, attn_weights_2], dim=-1), dim=-1, keepdim=True)[0]), dim=-1, keepdim=True, dtype=torch.float32)                              
                # scale the attention weights
                alpha1 = attn_weights_softmax_denominator.to(torch.float32) / torch.sum(torch.concat([attn_weights_softmax_denominator.to(torch.float32), attn_weights_2_softmax_denominator.to(torch.float32)], dim=-1), dim=-1, keepdim=True, dtype=torch.float32)
                # scale the attention weights2
                alpha2 = attn_weights_2_softmax_denominator.to(torch.float32) / torch.sum(torch.concat([attn_weights_softmax_denominator.to(torch.float32), attn_weights_2_softmax_denominator.to(torch.float32)], dim=-1), dim=-1, keepdim=True, dtype=torch.float32)
                print("NO!!!!!!!!!!!!!!!")
            else:   # all eval mode and training without hierarchial mode goes here
                attn_weights_softmax_denominator = torch.sum(torch.exp(attn_weights_original - torch.max(torch.concat([attn_weights_original, attn_weights_2], dim=-1), dim=-1, keepdim=True)[0]), dim=-1, keepdim=True, dtype=torch.float32)
                attn_weights_2_softmax_denominator = torch.sum(torch.exp(attn_weights_2 - torch.max(torch.concat([attn_weights_original, attn_weights_2], dim=-1), dim=-1, keepdim=True)[0]), dim=-1, keepdim=True, dtype=torch.float32)
                attn_weights_2 = nn.functional.softmax(attn_weights_2, dim=-1, dtype=torch.float32)
                # scale the attention weights
                alpha1 = attn_weights_softmax_denominator.to(torch.float32) / torch.sum(torch.concat([attn_weights_softmax_denominator.to(torch.float32), attn_weights_2_softmax_denominator.to(torch.float32)], dim=-1), dim=-1, keepdim=True, dtype=torch.float32)
                # scale the attention weights2
                alpha2 = attn_weights_2_softmax_denominator.to(torch.float32) / torch.sum(torch.concat([attn_weights_softmax_denominator.to(torch.float32), attn_weights_2_softmax_denominator.to(torch.float32)], dim=-1), dim=-1, keepdim=True, dtype=torch.float32)

        if not attn_weights.requires_grad:
            # TODO: Make this function injectable
            if save_attention_weights:
                if q_len > 1:   # when during the decoding process, q_len is 1, so only save during the prefilling process
                    if kb_config.use_hierarchial_kv:
                        save_path = os.path.join(
                            attention_save_loc,
                            f"{attention_file_base_name}_{self.layer_idx}.npy",
                        )
                        if kb_kvs is not None and self.layer_idx % kb_layer_frequency == 0:
                            full_shape_attention_weights2 = torch.full((bsz, self.num_heads, attn_weights.shape[2], kb_values.shape[-2]), -1e9).to(torch.float32).to(query_states.device)
                            full_shape_attention_weights2[:,:,:,select_leaf_local_key_idx[top_idx]] = attn_weights_2
                            attn_weights_save = torch.concat([alpha2 * full_shape_attention_weights2, alpha1 * attn_weights], dim=-1)
                        else:
                            attn_weights_save = attn_weights
                        np.save(
                            save_path,
                            attn_weights_save.to(torch.float32).cpu().detach().numpy(),
                        )
                    else:
                        save_path = os.path.join(
                            attention_save_loc,
                            f"{attention_file_base_name}_{self.layer_idx}.npy",
                        )
                        if kb_kvs is not None and self.layer_idx % kb_layer_frequency == 0:
                            attn_weights_save = torch.concat([alpha2 * attn_weights_2, alpha1 * attn_weights], dim=-1)
                        else:
                            attn_weights_save = attn_weights
                        np.save(
                            save_path,
                            attn_weights_save.to(torch.float32).cpu().detach().numpy(),
                        )
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)              # [bs, num_heads, max_len, head_dim]
        if kb_kvs is not None and self.layer_idx % kb_layer_frequency == 0:
            attn_weights_2 = attn_weights_2.to(query_states.dtype)
            attn_weights_2 = nn.functional.dropout(
                attn_weights_2, p=self.attention_dropout, training=self.training
            )
            attn_output_2 = torch.matmul(attn_weights_2, select_leaf_kb_values)             # [bs, num_heads, max_len, head_dim]
            alpha1 = alpha1.to(query_states.dtype)
            alpha2 = alpha2.to(query_states.dtype)
            attn_output = alpha1 * attn_output + alpha2 * attn_output_2
            # attn_output = attn_output + attn_output_2

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(
                self.hidden_size // self.config.pretraining_tp, dim=2
            )
            o_proj_slices = self.o_proj.weight.split(
                self.hidden_size // self.config.pretraining_tp, dim=1
            )
            attn_output = sum(
                [
                    F.linear(attn_output[i], o_proj_slices[i])
                    for i in range(self.config.pretraining_tp)
                ]
            )
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class KblamLlamaAttention(nn.Module):
    """Multi-headed "rectangular" attention: prompt tokens attend to themselves and to KB entries.

    On every ``kb_layer_frequency``-th layer the knowledge-base (KB) keys/values passed in
    ``kb_kvs`` are prepended to the prompt keys/values along the sequence axis, so the
    attention matrix has shape ``(q_len, kb_len + kv_len)``. When ``kb_config.sep_query_head``
    is set, the KB columns are scored with a separate query projection (``q_proj_new``).
    """

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Per-head parameter initialised to -3. It is not read anywhere in this class;
        # it is kept so that parameter registration (and hence checkpoints) is unchanged.
        self.score_shift = nn.Parameter(torch.zeros(self.num_heads, 1) - 3)
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        # Second query projection, used to score the KB keys (see `forward`).
        self.q_proj_new = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=config.attention_bias
        )
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        Raises:
            KeyError: if ``rope_scaling`` has neither a ``"rope_type"`` nor a ``"type"`` key,
                or has no ``"factor"`` key.
            ValueError: if the scaling type is not one of ``linear``, ``dynamic``, ``llama3``.
        """
        rope_scaling = self.config.rope_scaling
        if rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
            return

        # Accept the legacy "type" key as a fallback for configs that predate "rope_type".
        if "rope_type" in rope_scaling:
            scaling_type = rope_scaling["rope_type"]
        else:
            scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "dynamic":
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "llama3":
            # NOTE(review): the scaling factor is not forwarded here, so this is the plain
            # rotary embedding -- confirm that this is intentional for llama3 checkpoints.
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def prune_key_value(self, query, kb_keys, kb_values, topk_size=20):
        """Keep only the ``topk_size`` KB entries with the largest total attention logit.

        Logits are summed over heads and query positions, so one set of entries is selected
        per batch element and shared by all heads.

        Args:
            query: query states of shape ``(batch, num_heads, q_len, head_dim)``; must not
                require grad (inference only).
            kb_keys: KB keys of shape ``(batch, num_heads, kb_len, head_dim)``.
            kb_values: KB values, same shape as ``kb_keys``.
            topk_size: number of KB entries to keep.

        Returns:
            ``(kb_keys, kb_values, attn_weights)`` where keys/values have ``kb_len`` reduced to
            ``topk_size`` and ``attn_weights`` holds the scaled logits of exactly those selected
            entries, shape ``(batch, num_heads, q_len, topk_size)``. If ``topk_size >= kb_len``
            nothing is pruned and the full logits are returned.
        """
        assert (
            query.requires_grad is False
        ), "This function should only be used at test time"
        batch_size, num_heads, kb_len, head_dim = kb_keys.shape
        attn_weights = torch.matmul(query, kb_keys.transpose(2, 3)) / math.sqrt(
            self.head_dim
        )  # (batch, num_heads, q_len, kb_len)
        if topk_size >= kb_len:
            return kb_keys, kb_values, attn_weights
        with torch.autograd.no_grad():
            top_idx = attn_weights.sum((1, 2)).topk(topk_size, -1)[1]  # (batch, topk_size)
            kv_idx = top_idx.view(batch_size, 1, topk_size, 1).expand(
                batch_size, num_heads, topk_size, head_dim
            )
            kb_keys = kb_keys.gather(-2, kv_idx)
            kb_values = kb_values.gather(-2, kv_idx)
            weight_idx = top_idx.view(batch_size, 1, 1, topk_size).expand(
                batch_size, num_heads, attn_weights.shape[2], topk_size
            )
        # Return the logits of the *selected* entries (same order as the pruned keys), not
        # the first `topk_size` columns, so that logits and keys/values stay aligned.
        return kb_keys, kb_values, attn_weights.gather(-1, weight_idx)

    def _reshape_kb(self, kb_keys, kb_values, bsz, kb_idx, num_kb_layers):
        """Select this layer's slice of the KB and split it into heads.

        Args:
            kb_keys / kb_values: either ``(kb_len, num_kb_layers * num_heads * head_dim)``
                (shared across the batch) or ``(bsz, kb_len, num_kb_layers * num_heads * head_dim)``.
            bsz: batch size of the current hidden states.
            kb_idx: index of this layer among the KB-augmented layers.
            num_kb_layers: number of per-layer slices packed in the last dimension.

        Returns:
            ``(kb_keys, kb_values)``, each of shape ``(bsz, num_heads, kb_len, head_dim)``.

        Raises:
            ValueError: if ``kb_keys`` is neither 2-D nor 3-D.
        """
        if len(kb_keys.shape) == 2:  # No batch dim: broadcast the same KB to every sample
            kb_len = kb_keys.shape[0]
            kb_keys = kb_keys.reshape(kb_len, num_kb_layers, -1)[:, kb_idx]
            kb_values = kb_values.reshape(kb_len, num_kb_layers, -1)[:, kb_idx]
            kb_keys = kb_keys.view(kb_len, self.num_heads, self.head_dim).transpose(0, 1)
            kb_values = kb_values.view(kb_len, self.num_heads, self.head_dim).transpose(0, 1)
            kb_keys = kb_keys.unsqueeze(0).expand(
                bsz, self.num_heads, kb_len, self.head_dim
            )
            kb_values = kb_values.unsqueeze(0).expand(
                bsz, self.num_heads, kb_len, self.head_dim
            )
            return kb_keys, kb_values
        if len(kb_keys.shape) == 3:  # Has a batch dim
            kb_len = kb_keys.shape[1]
            kb_keys = kb_keys.view(bsz, kb_len, num_kb_layers, -1)[:, :, kb_idx]
            kb_values = kb_values.view(bsz, kb_len, num_kb_layers, -1)[:, :, kb_idx]
            kb_keys = kb_keys.view(
                bsz, kb_len, self.num_heads, self.head_dim
            ).transpose(1, 2)
            kb_values = kb_values.view(
                bsz, kb_len, self.num_heads, self.head_dim
            ).transpose(1, 2)
            return kb_keys, kb_values
        raise ValueError(
            f"KB keys must be 2-D or 3-D, but got a tensor with shape {tuple(kb_keys.shape)}"
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        kb_kvs: Optional[tuple] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        save_attention_weights: bool = True,
        attention_save_loc: Optional[str] = None,
        attention_file_base_name: Optional[str] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention over the prompt and, on KB layers, over the knowledge base.

        Args:
            hidden_states: input of shape ``(bsz, q_len, hidden_size)``.
            attention_mask: additive mask whose last dimension indexes the prompt keys; on KB
                layers a ``(bsz, 1, q_len, kb_len)`` block is prepended to it.
            position_ids: positions passed to the rotary embedding.
            past_key_value: optional cache, updated in place with this layer's keys/values
                (before the KB entries are prepended).
            output_attentions: if false, ``None`` is returned in place of the attention weights.
            use_cache: unused here; the cache is updated whenever ``past_key_value`` is given.
            cache_position: forwarded to the cache update.
            kb_kvs: optional ``(kb_keys, kb_values)`` pair, see `_reshape_kb` for shapes.
            kb_config: KB settings; required whenever ``kb_kvs`` is given. Reads
                ``kb_layer_frequency``, ``dynamic_sparsify``, ``top_k_kb``, ``sep_query_head``
                and ``kb_scale_factor``.
            save_attention_weights: if true (the default of this method), the softmaxed weights
                are written with ``np.save`` when no gradient is tracked and ``q_len > 1``.
            attention_save_loc: directory for the saved weights; required when saving.
            attention_file_base_name: file name prefix for the saved weights; required when saving.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)`` with ``attn_output`` of shape
            ``(bsz, q_len, hidden_size)``.

        Raises:
            ValueError: if ``kb_kvs`` is given without ``kb_config``, if the KB tensors have an
                unsupported rank, or if the attention output has an unexpected shape.
        """
        if save_attention_weights:
            assert (
                attention_save_loc is not None
            ), "Please provide a location to save the attention weights"
            assert (
                attention_file_base_name is not None
            ), "Please provide a base name for the attention weights"
        bsz, q_len, _ = hidden_states.size()
        if self.config.pretraining_tp > 1:
            key_value_slicing = (
                self.num_key_value_heads * self.head_dim
            ) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [
                F.linear(hidden_states, query_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [
                F.linear(hidden_states, key_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [
                F.linear(hidden_states, value_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            value_states = torch.cat(value_states, dim=-1)

        else:  # obtain q/k/v projections before splitting into heads
            query_states = self.q_proj(hidden_states)  # [bs, q_len, num_heads * head_dim]
            key_states = self.k_proj(hidden_states)  # [bs, q_len, num_kv_heads * head_dim]
            value_states = self.v_proj(hidden_states)  # [bs, q_len, num_kv_heads * head_dim]
        # The KB query projection is needed on both paths; it used to be computed only on the
        # non-tensor-parallel path, which made `pretraining_tp > 1` fail further down.
        query_states_2 = self.q_proj_new(hidden_states)  # [bs, q_len, num_heads * head_dim]

        query_states = query_states.view(  # [bs, num_heads, q_len, head_dim]
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        query_states_2 = query_states_2.view(  # [bs, num_heads, q_len, head_dim]
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(  # [bs, num_key_value_heads, q_len, head_dim]
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(  # [bs, num_key_value_heads, q_len, head_dim]
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Rotary position embeddings are applied to the prompt query/keys only; the KB query
        # (`query_states_2`) and the KB keys are left position-free.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)  # [bs, num_heads, kv_len, head_dim]
        value_states = repeat_kv(value_states, self.num_key_value_groups)  # [bs, num_heads, kv_len, head_dim]

        # The KB is only injected on every `kb_layer_frequency`-th layer.
        use_kb = False
        if kb_kvs is not None:
            if kb_config is None:
                raise ValueError("`kb_config` must be provided when `kb_kvs` is passed")
            use_kb = self.layer_idx % kb_config.kb_layer_frequency == 0

        attn_weights_2 = None
        if use_kb:
            kb_layer_frequency = kb_config.kb_layer_frequency
            kb_keys, kb_values = kb_kvs
            kb_idx = self.layer_idx // kb_layer_frequency
            num_kb_layers = 1 + self.config.num_hidden_layers // kb_layer_frequency
            kb_keys, kb_values = self._reshape_kb(
                kb_keys, kb_values, bsz, kb_idx, num_kb_layers
            )
            if kb_config.dynamic_sparsify:
                kb_keys, kb_values, attn_weights_2 = self.prune_key_value(
                    query_states_2, kb_keys, kb_values, kb_config.top_k_kb
                )
            # Append the KB keys and values in the front, in front of padding
            key_states = torch.concat([kb_keys, key_states], dim=2)  # [bs, num_heads, kb_len + kv_len, head_dim]
            value_states = torch.concat([kb_values, value_states], dim=2)  # [bs, num_heads, kb_len + kv_len, head_dim]
            kb_len = kb_keys.shape[2]  # may be smaller than the raw KB size after pruning
            if attention_mask is not None:
                # Modify the attention matrix: append a (q_len, kb_len) block to the left.
                # Query rows whose mask is negative everywhere (padding rows) stay masked
                # for the KB columns too; every other row can see all KB entries.
                kb_atten_mask = attention_mask.new_zeros(bsz, 1, q_len, kb_len)
                padding_mask = torch.all(attention_mask < 0, -1, keepdim=True)
                kb_atten_mask = (
                    padding_mask * PADDING_VALUE + (~padding_mask) * kb_atten_mask
                )
                attention_mask = torch.concat([kb_atten_mask, attention_mask], dim=-1)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)  # [bs, num_heads, q_len, kb_len + kv_len] on KB layers
        if use_kb and kb_config.sep_query_head:
            # If we have pruned the KB tokens, then this quantity should have been computed,
            # if not, then we compute it here
            if attn_weights_2 is None:
                attn_weights_2 = torch.matmul(
                    query_states_2, kb_keys.transpose(2, 3)
                ) / math.sqrt(self.head_dim)
            # Replace the KB columns (scored with the prompt query) by the logits obtained
            # with the dedicated KB query head.
            attn_weights = attn_weights[:, :, :, kb_len:]
            kb_scale_factor = kb_config.kb_scale_factor
            if kb_scale_factor is not None:
                attn_weights_2 = (
                    attn_weights_2 - np.log(kb_len) + np.log(kb_scale_factor)
                )
            attn_weights = torch.concat([attn_weights_2, attn_weights], -1)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        if not attn_weights.requires_grad:
            # TODO: Make this function injectable
            if save_attention_weights:
                if q_len > 1:  # q_len == 1 during step-by-step decoding, so only prefill is saved
                    save_path = os.path.join(
                        attention_save_loc,
                        f"{attention_file_base_name}_{self.layer_idx}.npy",
                    )
                    np.save(
                        save_path,
                        attn_weights.to(torch.float32).cpu().detach().numpy(),
                    )
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(
                self.hidden_size // self.config.pretraining_tp, dim=2
            )
            o_proj_slices = self.o_proj.weight.split(
                self.hidden_size // self.config.pretraining_tp, dim=1
            )
            attn_output = sum(
                [
                    F.linear(attn_output[i], o_proj_slices[i])
                    for i in range(self.config.pretraining_tp)
                ]
            )
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class LlamaDecoderLayer(nn.Module):
    """One decoder block: pre-norm attention and a pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The attention implementation is looked up by name in the currently active registry.
        attention_cls = get_current_llama_attention_classes()[
            config._attn_implementation
        ]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        kb_kvs: Optional[tuple] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        save_attention_weights: bool = False,
        attention_save_loc: Optional[str] = None,
        attention_file_base_name: Optional[str] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Apply the attention sub-layer and the MLP sub-layer, each with a residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            kb_kvs (`tuple`, *optional*): knowledge-base keys/values, forwarded unchanged to the attention module.
            kb_config (*optional*): knowledge-base configuration, forwarded unchanged to the attention module.
            save_attention_weights (`bool`), attention_save_loc (`str`, *optional*),
            attention_file_base_name (`str`, *optional*): attention-dump options, forwarded unchanged to the
                attention module.

        Returns:
            A tuple whose first element is the layer output; the attention weights follow when
            `output_attentions` is truthy, then the present key/value cache when `use_cache` is truthy.
        """
        # Rectangular Attention (pre-norm, then residual add)
        normed_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            kb_kvs=kb_kvs,
            kb_config=kb_config,
            save_attention_weights=save_attention_weights,
            attention_save_loc=attention_save_loc,
            attention_file_base_name=attention_file_base_name,
        )
        after_attn = hidden_states + attn_out

        # Fully Connected (pre-norm, then residual add)
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        layer_out = after_attn + mlp_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)

        return (layer_out, *extras)


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Besides the standard LLaMA inputs, `forward` accepts knowledge-base arguments (`kb_kvs`, `kb_config`) and
    attention-saving options, and hands them unchanged to every decoder layer.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        """Build the token embedding, the stack of decoder layers and the final RMS norm from `config`."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding table; `padding_idx` is passed as the third positional argument of `nn.Embedding`.
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        # One decoder layer per hidden layer; each layer is given its own index.
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Off by default. When it is on and the module is training, `forward` calls each layer
        # through `self._gradient_checkpointing_func`.
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        kb_kvs: Optional[tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        save_attention_weights: bool = False,
        attention_save_loc: Optional[str] = None,
        attention_file_base_name: Optional[str] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Runs the embedding, every decoder layer and the final norm.
        # `kb_kvs`, `kb_config`, `save_attention_weights`, `attention_save_loc` and
        # `attention_file_base_name` are not interpreted here; they are only passed on to each layer.
        # Returns a `BaseModelOutputWithPast`, or a plain tuple of the non-None fields when
        # `return_dict` is false.

        # Any flag left as None falls back to the value stored in the model config.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Exactly one of `input_ids` / `inputs_embeds` must be given: the XOR is true both when
        # neither is provided and when both are.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)    # [bs, seq_len, hidden_size] for 2-D input_ids

        # A non-`Cache` `past_key_values` (including None) is wrapped in a `DynamicCache` for the
        # layers and converted back to the legacy format before returning.
        return_legacy_cache = False
        if use_cache and not isinstance(
            past_key_values, Cache
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        # Default cache positions continue right after the tokens already held in the cache.
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        # Default position ids are the cache positions with a leading dimension of size 1.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # May be None, the caller's mask, or a 4-D mask built by `_update_causal_mask`.
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        # The input embeddings are the initial hidden states.
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            # Record the hidden states that enter each layer.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # The checkpointing function receives positional arguments only.
                # NOTE(review): this order must match `LlamaDecoderLayer.forward`'s parameter
                # order — confirm against that signature.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    kb_kvs,
                    kb_config,
                    save_attention_weights,
                    attention_save_loc,
                    attention_file_base_name,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    kb_kvs=kb_kvs,
                    kb_config=kb_config,
                    save_attention_weights=save_attention_weights,
                    attention_save_loc=attention_save_loc,
                    attention_file_base_name=attention_file_base_name,
                )

            hidden_states = layer_outputs[0]

            # The layer returns (hidden_states[, attn_weights][, cache]), so the cache sits at
            # index 2 when attention weights are included and at index 1 otherwise.
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        # Tuple output drops every field that is None, so its length varies with the flags.
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Produce the attention mask that `forward` hands to the decoder layers.

        Depending on the configured attention implementation this returns:

        * for `"flash_attention_2"`: `attention_mask` itself if it contains a zero, else `None`;
        * for `"sdpa"` (no static cache, no attention output): `None` when
          `AttentionMaskConverter._ignore_causal_mask_sdpa` says the mask can be dropped;
        * otherwise a 4-D mask. A 4-D `attention_mask` is used as given (its maximum must be 0). Any
          other mask is built here with `torch.finfo(dtype).min` at blocked positions and 0 elsewhere,
          combining the causal structure with the padding information of a 2-D `attention_mask`.

        Raises:
            ValueError: If a 4-D `attention_mask` is passed whose maximum is not 0.
        """
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            # Keep the mask only when it actually masks something (contains a 0).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # Key-axis length of the mask: the static cache's maximum length, else the width of the
        # given mask, else the cached tokens plus the current tokens plus one.
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0`"
                )
            causal_mask = attention_mask
        else:
            # Start fully blocked.
            causal_mask = torch.full(                   # [sequence_length, target_length]
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)   # zero the diagonal and everything below it
            # Multiplying by the boolean zeroes every key index <= the query's cache position;
            # later key indices keep their current value.
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(     # [bs, 1, sequence_length, target_length]
                input_tensor.shape[0], 1, -1, -1
            )
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # The sum is 0 exactly where the causal mask allows attention (0) and the
                # 2-D mask marks the key with 0; those positions are blocked as well.
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask


class KblamLlamaForCausalLM(LlamaPreTrainedModel):
    """Causal language model built on this file's knowledge-base-aware `LlamaModel`.

    The decoder is loaded from pretrained base weights in `__init__`, and a bias-free linear head maps
    its hidden states to vocabulary logits. `forward` and `prepare_inputs_for_generation` additionally
    carry `kb_kvs` / `kb_config` through to the decoder.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LlamaConfig):
        """Load the base decoder and create the language-model head.

        Args:
            config: Model configuration. If it has a `base_model_name_or_path` attribute the decoder
                weights are loaded from there, otherwise from `config._name_or_path`.

        Raises:
            NotImplementedError: If `config._attn_implementation` is `"flash_attention_2"`.
        """
        # Function-scope stdlib import: used to report the base model path as a log record
        # instead of printing to stdout on every construction.
        import logging

        super().__init__(config)

        # Fail fast: reject the unsupported attention backend before loading any base weights.
        if config._attn_implementation == "flash_attention_2":
            raise NotImplementedError(
                "Flash Attention 2 is not yet supported for KBLaM."
            )

        base_model_name_or_path = (
            config.base_model_name_or_path
            if hasattr(config, "base_model_name_or_path")
            else config._name_or_path
        )
        logging.getLogger(__name__).info(
            "Loading base model from %s", base_model_name_or_path
        )
        self.model = LlamaModel.from_pretrained(
            base_model_name_or_path, torch_dtype=config.torch_dtype
        )
        # Head dimensions follow the loaded decoder's config rather than `config`.
        self.vocab_size = self.model.config.vocab_size
        self.lm_head = nn.Linear(
            self.model.config.hidden_size, self.model.config.vocab_size, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model

    def get_kblam_config(self):
        """Return `self.config`."""
        return self.config

    def set_kblam_config(self, config):
        """Replace `self.config` with `config`."""
        self.config = config

    def update_generation_config(self, tokenizer):
        """Copy the tokenizer's pad and EOS token ids into `self.generation_config`."""
        self.generation_config.pad_token_id = tokenizer.pad_token_id
        self.generation_config.eos_token_id = tokenizer.eos_token_id

    def load_query_head(self, ckpt_dir):
        """Load learned query projections into every decoder layer's `self_attn.q_proj_new`.

        Args:
            ckpt_dir: Passed directly to `torch.load`, so despite its name it must identify a
                checkpoint file. The loaded object must have one entry per decoder layer, keyed
                `"layer_0"`, `"layer_1"`, ..., each a state dict for that layer's `q_proj_new`.

        Raises:
            ValueError: If the number of entries differs from `num_hidden_layers`. (This was an
                `assert`, which is skipped when Python runs with `-O`.)
        """
        # SECURITY NOTE: `torch.load` unpickles its input; only load checkpoints from trusted sources.
        # `map_location="cpu"` lets checkpoints saved on an accelerator load on hosts without one;
        # `load_state_dict` then copies the values into the existing parameters.
        learned_query_heads = torch.load(ckpt_dir, map_location="cpu")
        expected_layers = self.model.config.num_hidden_layers
        if len(learned_query_heads) != expected_layers:
            raise ValueError(
                f"Query-head checkpoint has {len(learned_query_heads)} entries, "
                f"expected {expected_layers} (one per decoder layer)."
            )
        for i, attn_layer in enumerate(self.model.layers):
            attn_layer.self_attn.q_proj_new.load_state_dict(
                learned_query_heads[f"layer_{i}"]
            )
        self.config.sep_query_head = True

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        kb_kvs: Optional[tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        save_attention_weights: bool = False,
        attention_save_loc: Optional[str] = None,
        attention_file_base_name: Optional[str] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # NOTE: the `Returns:` line above must stay; `replace_return_docstrings` fills it in.

        # Flags left as None fall back to the model config. `use_cache` is passed through as given.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            kb_kvs=kb_kvs,              # ([bs, kb_size+entities, embd_dim], [bs, kb_size+entities, embd_dim])
            kb_config=kb_config,
            save_attention_weights=save_attention_weights,
            attention_save_loc=attention_save_loc,
            attention_file_base_name=attention_file_base_name,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Split the head weight along the vocabulary axis, project with each slice
            # separately, and concatenate the partial logits.
            lm_head_slices = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0
            )
            logits = [
                F.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        # Tuple output: (loss?, logits, *remaining decoder outputs).
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        kb_kvs: Optional[tuple] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Drops tokens already covered by the cache from `input_ids`, derives `position_ids` from
        `attention_mask` when none are supplied in `kwargs`, and fills in `cache_position`.

        Returns:
            dict: A shallow copy of `kwargs` extended with either `inputs_embeds` or `input_ids`,
            plus `position_ids`, `cache_position`, `past_key_values`, `use_cache`,
            `attention_mask`, `kb_kvs` and `kb_config`.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = (
                    cache_position[0]
                    if cache_position is not None
                    else past_key_values.get_seq_length()
                )
                max_cache_length = (
                    torch.tensor(
                        past_key_values.get_max_length(), device=input_ids.device
                    )
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = (
                    past_length
                    if max_cache_length is None
                    else torch.min(max_cache_length, past_length)
                )
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Legacy cache: length is read from dim 2 of the first layer's first tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get the placeholder position 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # Start from a shallow copy of the extra kwargs so they also reach `forward`.
        model_inputs = copy.copy(kwargs)
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs["input_ids"] = input_ids.contiguous()

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        # These entries override any same-named keys copied from `kwargs`.
        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "kb_kvs": kb_kvs,
                "kb_config": kb_config,
            }
        )
        return model_inputs

    # NOTE: disabled draft of a `save_pretrained` override that would persist only the
    # `q_proj_new` parameters. Kept for reference; it is not active code.
    # def save_pretrained(
    #     self,
    #     save_directory: Union[str, os.PathLike],
    #     is_main_process: bool = True,
    #     state_dict: Optional[dict] = None,
    #     **kwargs,
    # ):
    #     """ Save the learned query heads in the model. The rest of the weight can be loaded from the base model. """
    #     if state_dict is not None:
    #         super().save_pretrained(save_directory, is_main_process, state_dict, **kwargs)

    #     else:
    #         state_dict = self.state_dict()

    #         new_state_dict = {}
    #         for param in state_dict.keys():
    #             if "q_proj_new" not in param:
    #                 pass
    #             else:
    #                 new_state_dict[param] = state_dict[param]
    #         super().save_pretrained(save_directory, is_main_process, new_state_dict, **kwargs)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 according to `beam_idx`.

        Returns a tuple with one inner tuple per layer, in the input's layer order.
        """
        # Built in a single pass instead of by repeated tuple concatenation.
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )


class KglamLlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LlamaConfig):
        """Load the base decoder and create the language-model head.

        Args:
            config: Model configuration. If it has a `base_model_name_or_path` attribute the decoder
                weights are loaded from there, otherwise from `config._name_or_path`.

        Raises:
            NotImplementedError: If `config._attn_implementation` is `"flash_attention_2"`.
        """
        # Function-scope stdlib import: used to report the base model path as a log record
        # instead of printing to stdout on every construction.
        import logging

        super().__init__(config)

        # Fail fast: reject the unsupported attention backend before loading any base weights.
        if config._attn_implementation == "flash_attention_2":
            raise NotImplementedError(
                "Flash Attention 2 is not yet supported for KBLaM."
            )

        base_model_name_or_path = (
            config.base_model_name_or_path
            if hasattr(config, "base_model_name_or_path")
            else config._name_or_path
        )
        logging.getLogger(__name__).info(
            "Loading base model from %s", base_model_name_or_path
        )
        self.model = LlamaModel.from_pretrained(
            base_model_name_or_path, torch_dtype=config.torch_dtype
        )
        # Head dimensions follow the loaded decoder's config rather than `config`.
        self.vocab_size = self.model.config.vocab_size
        self.lm_head = nn.Linear(
            self.model.config.hidden_size, self.model.config.vocab_size, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module held by the wrapped decoder."""
        decoder = self.model
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Install `value` as the token embedding module of the wrapped decoder."""
        decoder = self.model
        decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head (`self.lm_head`)."""
        head = self.lm_head
        return head

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the language-model head."""
        # Plain attribute assignment.
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap the wrapped decoder (`self.model`) for `decoder`."""
        # Plain attribute assignment; the language-model head is left untouched.
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder (`self.model`)."""
        decoder = self.model
        return decoder

    def get_kblam_config(self):
        """Return the configuration object stored on this model (`self.config`)."""
        current_config = self.config
        return current_config

    def set_kblam_config(self, config):
        """Store `config` as this model's configuration (`self.config`)."""
        # Plain attribute assignment; nothing else is updated here.
        self.config = config

    def update_generation_config(self, tokenizer):
        """Copy the tokenizer's pad and EOS token ids into `self.generation_config`."""
        # Same order as before: pad id first, then EOS id.
        for token_field in ("pad_token_id", "eos_token_id"):
            setattr(self.generation_config, token_field, getattr(tokenizer, token_field))

    def load_query_head(self, ckpt_dir):
        """Load learned query projections into every decoder layer's `self_attn.q_proj_kg`.

        Args:
            ckpt_dir: Passed directly to `torch.load`, so despite its name it must identify a
                checkpoint file. The loaded object must have one entry per decoder layer, keyed
                `"layer_0"`, `"layer_1"`, ..., each a state dict for that layer's `q_proj_kg`.

        Raises:
            ValueError: If the number of entries differs from `num_hidden_layers`. (This was an
                `assert`, which is skipped when Python runs with `-O`.)
        """
        # SECURITY NOTE: `torch.load` unpickles its input; only load checkpoints from trusted sources.
        # `map_location="cpu"` lets checkpoints saved on an accelerator load on hosts without one;
        # `load_state_dict` then copies the values into the existing parameters.
        learned_query_heads = torch.load(ckpt_dir, map_location="cpu")
        expected_layers = self.model.config.num_hidden_layers
        if len(learned_query_heads) != expected_layers:
            raise ValueError(
                f"Query-head checkpoint has {len(learned_query_heads)} entries, "
                f"expected {expected_layers} (one per decoder layer)."
            )
        for i, attn_layer in enumerate(self.model.layers):
            attn_layer.self_attn.q_proj_kg.load_state_dict(
                learned_query_heads[f"layer_{i}"]
            )
        self.config.sep_query_head = True

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        kb_kvs: Optional[tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        save_attention_weights: bool = False,
        attention_save_loc: Optional[str] = None,
        attention_file_base_name: Optional[str] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # The docstring above is consumed by `replace_return_docstrings`; its `Returns:` line must stay.

        # Unset flags inherit the model configuration. `use_cache` is passed through as given.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Run the decoder. The knowledge-base and attention-saving arguments are only passed on.
        # Its outputs are (dec_features, layer_state, dec_hidden, dec_attn).
        outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            # ([bs, kb_size+entities, embd_dim], [bs, kb_size+entities, embd_dim])
            kb_kvs=kb_kvs,
            kb_config=kb_config,
            save_attention_weights=save_attention_weights,
            attention_save_loc=attention_save_loc,
            attention_file_base_name=attention_file_base_name,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Project with each vocabulary slice of the head weight, then concatenate.
            weight_shards = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0
            )
            partial_logits = []
            for shard_idx in range(self.config.pretraining_tp):
                partial_logits.append(F.linear(hidden_states, weight_shards[shard_idx]))
            logits = torch.cat(partial_logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the first label,
            # then flatten both for the loss.
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].contiguous().view(-1)
            # Labels follow the logits' device (model parallelism).
            flat_labels = flat_labels.to(flat_logits.device)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: (loss?, logits, *remaining decoder outputs).
        tuple_output = (logits,) + outputs[1:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        kb_kvs: Optional[tuple] = None,
        kb_config: Optional[KBLaMConfig | KGLaMConfig] = None,
        **kwargs,
    ):
        """Assemble the keyword arguments for a single generation step.

        Trims `input_ids`, `attention_mask` and `position_ids` down to the tokens that are
        not yet covered by `past_key_values`, builds `cache_position` when the caller did not
        supply one, and forwards `kb_kvs` / `kb_config` together with every extra keyword
        argument.

        Args:
            input_ids: Token ids, sliced along dim 1. Always required: its shape and device
                are read even when `inputs_embeds` is used.
            past_key_values: A `Cache` instance, a legacy per-layer tuple, or `None`.
            attention_mask: Optional mask; its dim-1 length is compared with that of
                `input_ids` to decide how to trim.
            inputs_embeds: Forwarded instead of `input_ids` only when `past_key_values` is
                `None`.
            cache_position: Optional positions. With a `Cache`, its first element is taken as
                the number of already processed tokens.
            use_cache: Forwarded as-is. When truthy and `cache_position` was supplied, only
                the trailing `input_length` entries of `cache_position` are kept.
            kb_kvs: Forwarded untouched under the `"kb_kvs"` key.
            kb_config: Forwarded untouched under the `"kb_config"` key.
            **kwargs: Shallow-copied into the result; a `"position_ids"` entry, if present,
                is used instead of deriving positions from `attention_mask`.

        Returns:
            dict: `kwargs` plus `input_ids` or `inputs_embeds`, `position_ids`,
            `cache_position`, `past_key_values`, `use_cache`, `attention_mask`, `kb_kvs`
            and `kb_config`.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = (
                    cache_position[0]
                    if cache_position is not None
                    else past_key_values.get_seq_length()
                )
                max_length = past_key_values.get_max_length()
                if max_length is None:
                    max_cache_length = None
                    cache_length = past_length
                else:
                    max_cache_length = torch.tensor(
                        max_length, device=input_ids.device
                    )
                    # `get_seq_length()` may hand back a plain int. The element-wise form of
                    # `torch.min` needs a tensor as its second operand, so promote the int
                    # here; an existing tensor (e.g. `cache_position[0]`) is left untouched.
                    past_length_tensor = (
                        past_length
                        if torch.is_tensor(past_length)
                        else torch.tensor(past_length, device=input_ids.device)
                    )
                    cache_length = torch.min(max_cache_length, past_length_tensor)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out slots get a dummy position of 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # Only the positions of the tokens fed in this step are needed.
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # Start from the caller's extra kwargs so they reach `forward` unchanged; the explicit
        # entries written below take precedence over same-named keys.
        model_inputs = copy.copy(kwargs)
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs["input_ids"] = input_ids.contiguous()

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "kb_kvs": kb_kvs,
                "kb_config": kb_config,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached state along dim 0 according to `beam_idx`.

        Args:
            past_key_values: Iterable of layers, each an iterable of cached states.
            beam_idx: Index tensor passed to `index_select` on dim 0 of every state.

        Returns:
            tuple: One inner tuple per layer holding the re-indexed states, in the
            original order.
        """

        def _pick_beams(state):
            # The index has to live on the same device as the state it selects from.
            return state.index_select(0, beam_idx.to(state.device))

        return tuple(
            tuple(_pick_beams(state) for state in layer_states)
            for layer_states in past_key_values
        )
