from transformers import LlamaModel
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# 
# transformers-4.52.4
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union, Dict
import numpy as np
import torch
import random
import torch.utils.checkpoint
import math
from torch import nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from transformers import LlamaConfig
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from transformers.integrations.flex_attention import make_flex_block_causal_mask

from transformers.integrations import use_kernel_forward_from_hub

logger = logging.get_logger(__name__)
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding, ALL_ATTENTION_FUNCTIONS

from models.modeling_outputs import (
    BaseModelOutputWithPastAndPruning,
    CausalLMOutputWithPastAndPruning
)

# from kan import KANLayer

def rotate_half(x):
    """Rotate the last dimension by half: return ``cat(-second_half, first_half)``.

    The split point is ``x.shape[-1] // 2``; all leading dimensions are left untouched.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; accepted only for signature compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which a singleton axis is inserted into `cos` and `sin` so that they broadcast
            against `q` and `k`. For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim],
            use `unsqueeze_dim=1` when `q`/`k` are [batch_size, heads, seq_len, head_dim] and
            `unsqueeze_dim=2` when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query tensor followed by the rotated key tensor.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard RoPE combination: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)


class LlamaMLP(nn.Module):
    """Gated MLP whose projection inputs are magnitude-pruned and then compensated.

    Each projection (gate, up, down) is fed an input in which only the
    ``int(dim * pruning_ratio)`` largest-magnitude entries along the last dimension are
    kept (see `column_masking`). The pruned input is compensated in one of two ways:

    * ``lora=True``: a low-rank branch ``B(A(sparse_input))`` is added to the projection output.
    * ``lora=False``: learnable "firing" vectors are added to the pruned input before projection.

    `forward` optionally also runs a second, unpruned input ``y`` through the same projections
    and, in training mode, returns an MSE loss between the pruned and unpruned activations.
    """

    def __init__(self, config, num_extra_neurons, pruning_ratio=0.5, input_dependent=False, lora=False):
        """Build the projections and the compensation parameters.

        Args:
            config: Model config; `hidden_size`, `intermediate_size`, `mlp_bias` and `hidden_act` are read.
            num_extra_neurons: Number of firing vectors per projection (or the LoRA rank when `lora=True`).
            pruning_ratio: Fraction of input entries that are KEPT by `column_masking`.
            input_dependent: When True, extra `*_act` linear layers are created (they are not used
                anywhere inside this class).
            lora: Selects the low-rank compensation branch instead of firing vectors.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]
        self.num_extra_neurons = num_extra_neurons
        self.pruning_ratio = pruning_ratio

        # Uninitialised storage (torch.empty); call `bias_initialization` to zero it.
        self.gate_bias = nn.ParameterList([nn.Parameter(torch.empty(self.intermediate_size)) for _ in range(num_extra_neurons)])
        self.up_bias = nn.ParameterList([nn.Parameter(torch.empty(self.intermediate_size)) for _ in range(num_extra_neurons)])
        self.down_bias = nn.ParameterList([nn.Parameter(torch.empty(self.hidden_size)) for _ in range(num_extra_neurons)])

        self.input_dependent = input_dependent
        if self.input_dependent:
            self.gate_act = nn.Linear(self.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.up_act = nn.Linear(self.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.down_act = nn.Linear(self.intermediate_size // 2, self.num_extra_neurons, bias=False)
        else:
            self.gate_act = None
            self.up_act = None
            self.down_act = None

        self.ensemble = False
        self.lora = lora
        if self.lora:
            self.gate_A = nn.Linear(self.hidden_size, self.num_extra_neurons, bias=False)
            self.gate_B = nn.Linear(self.num_extra_neurons, self.intermediate_size, bias=False)

            self.up_A = nn.Linear(self.hidden_size, self.num_extra_neurons, bias=False)
            self.up_B = nn.Linear(self.num_extra_neurons, self.intermediate_size, bias=False)

            self.down_A = nn.Linear(self.intermediate_size, self.num_extra_neurons, bias=False)
            self.down_B = nn.Linear(self.num_extra_neurons, self.hidden_size, bias=False)

            self.gate_firing = None
            self.up_firing = None
            self.down_firing = None
        else:
            # One firing vector per extra neuron, living in the INPUT space of each projection.
            self.gate_firing = nn.Parameter(torch.empty(self.num_extra_neurons, self.hidden_size))
            self.up_firing = nn.Parameter(torch.empty(self.num_extra_neurons, self.hidden_size))
            self.down_firing = nn.Parameter(torch.empty(self.num_extra_neurons, self.intermediate_size))

            self.gate_A = None
            self.gate_B = None

            self.up_A = None
            self.up_B = None

            self.down_A = None
            self.down_B = None

    def set_pruning_ratio(self, pruning_ratio):
        """Replace the keep-fraction used by `column_masking`."""
        self.pruning_ratio = pruning_ratio

    def set_bias(self):
        """Overwrite each projection's bias data with that projection applied to its firing tensor.

        NOTE(review): this requires the projections to have been built with a bias
        (`config.mlp_bias=True`) and the firing tensors to exist (`lora=False`). The assigned
        tensor has shape (num_extra_neurons, out_features), not (out_features,) -- confirm the
        intended use with the caller before relying on it.
        """
        self.gate_proj.bias.data = nn.Parameter(self.gate_proj(self.gate_firing))
        self.up_proj.bias.data = nn.Parameter(self.up_proj(self.up_firing))
        self.down_proj.bias.data = nn.Parameter(self.down_proj(self.down_firing))

    def lora_initialization(self):
        """Re-create the low-rank A/B layers, zeroing every B so the branch initially outputs zero."""
        self.gate_A = nn.Linear(self.hidden_size, self.num_extra_neurons, bias=False)
        self.gate_B = nn.Linear(self.num_extra_neurons, self.intermediate_size, bias=False)
        nn.init.zeros_(self.gate_B.weight)

        self.up_A = nn.Linear(self.hidden_size, self.num_extra_neurons, bias=False)
        self.up_B = nn.Linear(self.num_extra_neurons, self.intermediate_size, bias=False)
        nn.init.zeros_(self.up_B.weight)

        self.down_A = nn.Linear(self.intermediate_size, self.num_extra_neurons, bias=False)
        self.down_B = nn.Linear(self.num_extra_neurons, self.hidden_size, bias=False)
        nn.init.zeros_(self.down_B.weight)

    def bias_initialization(self):
        """Zero the extra bias and firing parameters (and re-create the `*_act` layers if enabled).

        Firing tensors are `None` in LoRA mode; they are skipped explicitly rather than by
        swallowing every exception, so genuine errors are no longer hidden.
        """
        if self.input_dependent:
            self.gate_act = nn.Linear(self.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.up_act = nn.Linear(self.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.down_act = nn.Linear(self.intermediate_size // 2, self.num_extra_neurons, bias=False)

        for bias_list in (self.gate_bias, self.up_bias, self.down_bias):
            if bias_list is None:
                continue
            for extra_bias in bias_list:
                nn.init.zeros_(extra_bias.data)

        for firing in (self.gate_firing, self.up_firing, self.down_firing):
            if firing is not None:
                nn.init.zeros_(firing.data)

    def enable_ensemble(self):
        """Set the `ensemble` flag (not read anywhere inside this class)."""
        self.ensemble = True

    def column_masking(self, input, layer, bias=None, act=None, firing=None, A=None, B=None):
        """Prune `input` by magnitude, compensate it, and apply `layer`.

        Args:
            input: 3-D tensor; masking is done along the last dimension (indexing and
                `scatter_(dim=2)` require exactly three dimensions).
            layer: Projection applied to the pruned (and compensated) input.
            bias: Unused; kept for call-site compatibility.
            act: Unused; kept for call-site compatibility.
            firing: Tensor of shape (num_extra_neurons, input_dim); required when `self.lora` is False.
            A, B: Low-rank layers; required when `self.lora` is True.

        Returns:
            Tuple `(output, orthogonal_loss)`. `orthogonal_loss` is a tensor only in training mode
            with more than one firing vector; otherwise it is the int `0`.
        """
        # Rank entries by |value|; everything after the first int(D * pruning_ratio) is dropped.
        _, indices = torch.sort(torch.abs(input), dim=-1, descending=True)
        masked_indices = indices[:, :, int(indices.size(-1) * self.pruning_ratio):]

        mask = torch.ones_like(input)
        mask.scatter_(dim=2, index=masked_indices, value=False)
        sparse_input = input * mask

        if self.lora:
            return layer(sparse_input) + B(A(sparse_input)), 0

        if self.training and self.num_extra_neurons > 1:
            shape = (input.shape[0], input.shape[1], self.num_extra_neurons)
            # Random convex combination of the firing vectors per token. Sampled directly on the
            # input's device and cast to the firing dtype so the matmul below never mixes
            # devices or dtypes (the previous CPU/float32 sample failed for non-float32 weights).
            alpha = F.softmax(torch.randn(shape, device=input.device), dim=-1).to(firing.dtype)
            sparse_input = sparse_input + alpha @ firing

            # Contrastive-style penalty: each firing vector should be most similar to itself.
            # 0.01 acts as the softmax temperature.
            similarity = F.cosine_similarity(firing.unsqueeze(1), firing.unsqueeze(0), dim=-1) / 0.01
            label = torch.arange(self.num_extra_neurons, device=similarity.device)
            orthogonal_loss = F.cross_entropy(similarity, label)
        else:
            # Evaluation (or a single firing vector): add the mean firing vector.
            sparse_input = sparse_input + firing.mean(dim=0, keepdim=True)
            orthogonal_loss = 0

        return layer(sparse_input), orthogonal_loss

    def forward(self, x, y=None):
        """Run the pruned MLP on `x` and, optionally, the unpruned MLP on `y`.

        Args:
            x: Input of the pruned path.
            y: Optional input of the unpruned reference path.

        Returns:
            Tuple `(down_proj_x, loss, down_proj_y)` where `loss` is the sum of the orthogonality
            losses and, in training mode with `y` given, the MSE between the pruned and reference
            gate/up/down activations. `down_proj_y` is `None` when `y` is `None`.
        """
        if y is not None:
            gate_proj_y, up_proj_y = self.gate_proj(y), self.up_proj(y)
            y = self.act_fn(gate_proj_y) * up_proj_y
            down_proj_y = self.down_proj(y)
        else:
            down_proj_y = None

        gate_proj_x, gate_orthogonal = self.column_masking(
            x,
            self.gate_proj,
            self.gate_bias,
            self.gate_act,
            self.gate_firing,
            self.gate_A,
            self.gate_B,
        )
        up_proj_x, up_orthogonal = self.column_masking(
            x,
            self.up_proj,
            self.up_bias,
            self.up_act,
            self.up_firing,
            self.up_A,
            self.up_B,
        )
        x = self.act_fn(gate_proj_x) * up_proj_x
        down_proj_x, down_orthogonal = self.column_masking(
            x,
            self.down_proj,
            self.down_bias,
            self.down_act,
            self.down_firing,
            self.down_A,
            self.down_B,
        )

        orthogonal_loss = gate_orthogonal + up_orthogonal + down_orthogonal

        if self.training and y is not None:
            # Match every pruned activation to its unpruned counterpart.
            gate_loss = F.mse_loss(
                gate_proj_x.reshape(-1, gate_proj_x.size(-1)),
                gate_proj_y.reshape(-1, gate_proj_y.size(-1)),
            )
            up_loss = F.mse_loss(
                up_proj_x.reshape(-1, up_proj_x.size(-1)),
                up_proj_y.reshape(-1, up_proj_y.size(-1)),
            )
            down_loss = F.mse_loss(
                down_proj_x.reshape(-1, down_proj_x.size(-1)),
                down_proj_y.reshape(-1, down_proj_y.size(-1)),
            )
        else:
            gate_loss, up_loss, down_loss = 0, 0, 0

        semantic_loss = gate_loss + up_loss + down_loss

        return down_proj_x, semantic_loss + orthogonal_loss, down_proj_y
          

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, which is equivalent to
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The input goes from (batch, num_key_value_heads,
    seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the
    input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis next to the head axis, broadcast it, then fold it into the heads.
        expanded = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain (non-fused) scaled dot-product attention.

    Keys and values are expanded with `repeat_kv` by `module.num_key_value_groups`. When given,
    `attention_mask` is cropped on its last axis to the key length and added to the scores.
    The softmax is computed in float32 and cast back to the query dtype; dropout is active only
    when `module.training` is True. Extra `kwargs` are accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`, with `attn_output` transposed so the sequence axis
        precedes the head axis.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()

    return context, probs


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with an additional pruned path.

    The q/k/v/o projections of the pruned path receive magnitude-pruned inputs (see
    `column_masking`) that are compensated either by a low-rank ``B(A(x))`` branch
    (``lora=True``) or by learnable "firing" vectors added to the pruned input (``lora=False``).
    `forward` can also run an unpruned reference path through the same projections and, in
    training mode, returns an MSE loss between the two paths.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
        num_extra_neurons: int,
        pruning_ratio: float = 0.5,
        input_dependent: bool = False,
        lora: bool = False,
    ):
        """Build the projections and the compensation parameters.

        Args:
            config: Model config (hidden size, head counts, attention bias/dropout are read).
            layer_idx: Index passed to the key/value caches in `forward`.
            num_extra_neurons: Number of firing vectors per projection (or the LoRA rank when `lora=True`).
            pruning_ratio: Fraction of input entries that are KEPT by `column_masking`.
            input_dependent: When True, extra `*_act` linear layers are created (they are not used
                anywhere inside this class).
            lora: Selects the low-rank compensation branch instead of firing vectors.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Uninitialised storage (torch.empty); call `bias_initialization` to zero it.
        self.q_bias = nn.ParameterList([nn.Parameter(torch.empty(config.num_attention_heads * self.head_dim)) for _ in range(num_extra_neurons)])
        self.k_bias = nn.ParameterList([nn.Parameter(torch.empty(config.num_key_value_heads * self.head_dim)) for _ in range(num_extra_neurons)])
        self.v_bias = nn.ParameterList([nn.Parameter(torch.empty(config.num_key_value_heads * self.head_dim)) for _ in range(num_extra_neurons)])
        self.o_bias = nn.ParameterList([nn.Parameter(torch.empty(config.hidden_size)) for _ in range(num_extra_neurons)])

        self.num_extra_neurons = num_extra_neurons
        self.lora = lora
        if self.lora:
            self.q_A = nn.Linear(config.hidden_size, self.num_extra_neurons, bias=False)
            self.q_B = nn.Linear(self.num_extra_neurons, config.num_attention_heads * self.head_dim, bias=False)

            self.k_A = nn.Linear(config.hidden_size, self.num_extra_neurons, bias=False)
            self.k_B = nn.Linear(self.num_extra_neurons, config.num_key_value_heads * self.head_dim, bias=False)

            self.v_A = nn.Linear(config.hidden_size, self.num_extra_neurons, bias=False)
            self.v_B = nn.Linear(self.num_extra_neurons, config.num_key_value_heads * self.head_dim, bias=False)

            self.o_A = nn.Linear(config.num_attention_heads * self.head_dim, self.num_extra_neurons, bias=False)
            self.o_B = nn.Linear(self.num_extra_neurons, config.hidden_size, bias=False)

            self.q_firing = None
            self.k_firing = None
            self.v_firing = None
            self.o_firing = None
        else:
            # One firing vector per extra neuron, living in the INPUT space of each projection.
            self.q_firing = nn.Parameter(torch.empty(self.num_extra_neurons, config.hidden_size))
            self.k_firing = nn.Parameter(torch.empty(self.num_extra_neurons, config.hidden_size))
            self.v_firing = nn.Parameter(torch.empty(self.num_extra_neurons, config.hidden_size))
            self.o_firing = nn.Parameter(torch.empty(self.num_extra_neurons, config.num_attention_heads * self.head_dim))

            self.q_A = None
            self.q_B = None

            self.k_A = None
            self.k_B = None

            self.v_A = None
            self.v_B = None

            self.o_A = None
            self.o_B = None

        self.pruning_ratio = pruning_ratio
        self.input_dependent = input_dependent
        if self.input_dependent:
            self.q_act = nn.Linear(config.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.k_act = nn.Linear(config.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.v_act = nn.Linear(config.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.o_act = nn.Linear(config.num_attention_heads * self.head_dim // 2, self.num_extra_neurons, bias=False)
        else:
            self.q_act = None
            self.k_act = None
            self.v_act = None
            self.o_act = None

        self.ensemble = False

    def bias_initialization(self):
        """Zero the extra bias and firing parameters (and re-create the `*_act` layers if enabled).

        Firing tensors are `None` in LoRA mode; they are skipped explicitly rather than by
        swallowing every exception, so genuine errors are no longer hidden.
        """
        if self.input_dependent:
            self.q_act = nn.Linear(self.config.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.k_act = nn.Linear(self.config.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.v_act = nn.Linear(self.config.hidden_size // 2, self.num_extra_neurons, bias=False)
            self.o_act = nn.Linear(self.config.num_attention_heads * self.head_dim // 2, self.num_extra_neurons, bias=False)

        for bias_list in (self.q_bias, self.k_bias, self.v_bias, self.o_bias):
            if bias_list is None:
                continue
            for extra_bias in bias_list:
                nn.init.zeros_(extra_bias.data)

        for firing in (self.q_firing, self.k_firing, self.v_firing, self.o_firing):
            if firing is not None:
                nn.init.zeros_(firing.data)

    def set_pruning_ratio(self, pruning_ratio):
        """Replace the keep-fraction used by `column_masking`."""
        self.pruning_ratio = pruning_ratio

    def set_bias(self):
        """Overwrite each projection's bias data with that projection applied to its firing tensor.

        NOTE(review): this requires the projections to have been built with a bias
        (`config.attention_bias=True`) and the firing tensors to exist (`lora=False`). The
        assigned tensor has shape (num_extra_neurons, out_features), not (out_features,) --
        confirm the intended use with the caller before relying on it.
        """
        self.q_proj.bias.data = nn.Parameter(self.q_proj(self.q_firing))
        self.k_proj.bias.data = nn.Parameter(self.k_proj(self.k_firing))
        self.v_proj.bias.data = nn.Parameter(self.v_proj(self.v_firing))
        self.o_proj.bias.data = nn.Parameter(self.o_proj(self.o_firing))

    def lora_initialization(self):
        """Re-create the low-rank A/B layers, zeroing every B so the branch initially outputs zero."""
        self.q_A = nn.Linear(self.config.hidden_size, self.num_extra_neurons, bias=False)
        self.q_B = nn.Linear(self.num_extra_neurons, self.config.num_attention_heads * self.head_dim, bias=False)
        nn.init.zeros_(self.q_B.weight)

        self.k_A = nn.Linear(self.config.hidden_size, self.num_extra_neurons, bias=False)
        self.k_B = nn.Linear(self.num_extra_neurons, self.config.num_key_value_heads * self.head_dim, bias=False)
        nn.init.zeros_(self.k_B.weight)

        self.v_A = nn.Linear(self.config.hidden_size, self.num_extra_neurons, bias=False)
        self.v_B = nn.Linear(self.num_extra_neurons, self.config.num_key_value_heads * self.head_dim, bias=False)
        nn.init.zeros_(self.v_B.weight)

        self.o_A = nn.Linear(self.config.num_attention_heads * self.head_dim, self.num_extra_neurons, bias=False)
        self.o_B = nn.Linear(self.num_extra_neurons, self.config.hidden_size, bias=False)
        nn.init.zeros_(self.o_B.weight)

    def enable_ensemble(self):
        """Set the `ensemble` flag (not read anywhere inside this class)."""
        self.ensemble = True

    def column_masking(self, input, layer, bias=None, act=None, firing=None, A=None, B=None):
        """Prune `input` by magnitude, compensate it, and apply `layer`.

        Args:
            input: 3-D tensor; masking is done along the last dimension (indexing and
                `scatter_(dim=2)` require exactly three dimensions).
            layer: Projection applied to the pruned (and compensated) input.
            bias: Unused; kept for call-site compatibility.
            act: Unused; kept for call-site compatibility.
            firing: Tensor of shape (num_extra_neurons, input_dim); required when `self.lora` is False.
            A, B: Low-rank layers; required when `self.lora` is True.

        Returns:
            Tuple `(output, orthogonal_loss)`. `orthogonal_loss` is a tensor only in training mode
            with more than one firing vector; otherwise it is the int `0`.
        """
        # Rank entries by |value|; everything after the first int(D * pruning_ratio) is dropped.
        _, indices = torch.sort(torch.abs(input), dim=-1, descending=True)
        masked_indices = indices[:, :, int(indices.size(-1) * self.pruning_ratio):]

        mask = torch.ones_like(input)
        mask.scatter_(dim=2, index=masked_indices, value=False)
        sparse_input = input * mask

        if self.lora:
            return layer(sparse_input) + B(A(sparse_input)), 0

        if self.training and self.num_extra_neurons > 1:
            shape = (input.shape[0], input.shape[1], self.num_extra_neurons)
            # Random convex combination of the firing vectors per token. Sampled directly on the
            # input's device and cast to the firing dtype so the matmul below never mixes
            # devices or dtypes (the previous CPU/float32 sample failed for non-float32 weights).
            alpha = F.softmax(torch.randn(shape, device=input.device), dim=-1).to(firing.dtype)
            sparse_input = sparse_input + alpha @ firing

            # Contrastive-style penalty: each firing vector should be most similar to itself.
            # 0.01 acts as the softmax temperature.
            similarity = F.cosine_similarity(firing.unsqueeze(1), firing.unsqueeze(0), dim=-1) / 0.01
            label = torch.arange(self.num_extra_neurons, device=similarity.device)
            orthogonal_loss = F.cross_entropy(similarity, label)
        else:
            # Evaluation (or a single firing vector): add the mean firing vector.
            sparse_input = sparse_input + firing.mean(dim=0, keepdim=True)
            orthogonal_loss = 0

        return layer(sparse_input), orthogonal_loss

    def forward(
        self,
        pruned_states: torch.Tensor,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        pruned_past_key_value: Optional[Cache] = None,
        hidden_past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention on the pruned path and, if `hidden_states` is given, on the reference path.

        Args:
            pruned_states: Input of the pruned path.
            hidden_states: Input of the unpruned reference path, or `None` to skip that path.
            position_embeddings: `(cos, sin)` rotary tensors applied to queries and keys of both paths.
            attention_mask: Mask forwarded to the attention implementation for both paths.
            pruned_past_key_value: Optional cache updated with the pruned keys/values.
            hidden_past_key_value: Optional cache updated with the reference keys/values (only
                touched when `hidden_states` is given).
            cache_position: Forwarded to the caches through `cache_kwargs`.

        Returns:
            A 4-tuple `(pruned_attn_output, pruned_attn_weights, loss, hidden_attn_output)`.
            `loss` is the sum of the orthogonality losses and, in training mode with
            `hidden_states` given, the MSE between pruned and reference q/k/v/output activations.
            `hidden_attn_output` is `None` when `hidden_states` is `None`.
            NOTE: the return annotation still describes a 3-tuple and is kept only for
            signature compatibility.
        """
        input_shape = pruned_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # ---- Reference (unpruned) projections --------------------------------------------
        if hidden_states is not None:
            hidden_query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            hidden_key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            hidden_value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        else:
            hidden_query_states, hidden_key_states, hidden_value_states = None, None, None

        # ---- Pruned projections (column masking) -----------------------------------------
        pruned_query_states, query_orthogonal = self.column_masking(
            pruned_states, self.q_proj, self.q_bias, self.q_act, self.q_firing, self.q_A, self.q_B
        )
        pruned_query_states = pruned_query_states.view(hidden_shape).transpose(1, 2)
        pruned_key_states, key_orthogonal = self.column_masking(
            pruned_states, self.k_proj, self.k_bias, self.k_act, self.k_firing, self.k_A, self.k_B
        )
        pruned_key_states = pruned_key_states.view(hidden_shape).transpose(1, 2)
        pruned_value_states, value_orthogonal = self.column_masking(
            pruned_states, self.v_proj, self.v_bias, self.v_act, self.v_firing, self.v_A, self.v_B
        )
        pruned_value_states = pruned_value_states.view(hidden_shape).transpose(1, 2)

        # The q/k/v matching losses are taken before RoPE and before any cache concatenation.
        if self.training and hidden_states is not None:
            query_loss = F.mse_loss(
                pruned_query_states.reshape(-1, pruned_query_states.size(-1)),
                hidden_query_states.reshape(-1, hidden_query_states.size(-1)),
            )
            key_loss = F.mse_loss(
                pruned_key_states.reshape(-1, pruned_key_states.size(-1)),
                hidden_key_states.reshape(-1, hidden_key_states.size(-1)),
            )
            value_loss = F.mse_loss(
                pruned_value_states.reshape(-1, pruned_value_states.size(-1)),
                hidden_value_states.reshape(-1, hidden_value_states.size(-1)),
            )
        else:
            query_loss, key_loss, value_loss = 0, 0, 0

        cos, sin = position_embeddings
        pruned_query_states, pruned_key_states = apply_rotary_pos_emb(pruned_query_states, pruned_key_states, cos, sin)
        if hidden_states is not None:
            hidden_query_states, hidden_key_states = apply_rotary_pos_emb(hidden_query_states, hidden_key_states, cos, sin)

        if pruned_past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            pruned_key_states, pruned_value_states = pruned_past_key_value.update(
                pruned_key_states,
                pruned_value_states,
                self.layer_idx,
                cache_kwargs,
            )
        if hidden_past_key_value is not None and hidden_states is not None:
            # The reference cache must store and return the REFERENCE keys/values. Feeding it
            # the pruned states (and overwriting them with its result) would corrupt both paths.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            hidden_key_states, hidden_value_states = hidden_past_key_value.update(
                hidden_key_states,
                hidden_value_states,
                self.layer_idx,
                cache_kwargs,
            )

        attention_interface: Callable = eager_attention_forward

        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # ---- Pruned attention -------------------------------------------------------------
        pruned_attn_output, pruned_attn_weights = attention_interface(
            self,
            pruned_query_states,
            pruned_key_states,
            pruned_value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        pruned_attn_output = pruned_attn_output.reshape(*input_shape, -1).contiguous()
        pruned_attn_output, attn_orthogonal = self.column_masking(
            pruned_attn_output, self.o_proj, self.o_bias, self.o_act, self.o_firing, self.o_A, self.o_B
        )

        orthogonal_loss = query_orthogonal + key_orthogonal + value_orthogonal + attn_orthogonal

        # ---- Reference attention ----------------------------------------------------------
        if hidden_states is not None:
            hidden_attn_output, hidden_attn_weights = attention_interface(
                self,
                hidden_query_states,
                hidden_key_states,
                hidden_value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

            hidden_attn_output = hidden_attn_output.reshape(*input_shape, -1).contiguous()
            hidden_attn_output = self.o_proj(hidden_attn_output)
        else:
            hidden_attn_output = None

        if self.training and hidden_states is not None:
            attn_loss = F.mse_loss(
                pruned_attn_output.reshape(-1, pruned_attn_output.size(-1)),
                hidden_attn_output.reshape(-1, hidden_attn_output.size(-1)),
            )
        else:
            attn_loss = 0

        semantic_loss = query_loss + key_loss + value_loss + attn_loss

        return pruned_attn_output, pruned_attn_weights, semantic_loss + orthogonal_loss, hidden_attn_output


class LlamaDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer that carries a pruned stream and an optional reference stream.

    Both streams pass through the same layer norms, the same attention module and
    the same MLP. The attention and MLP modules each return a loss term next to
    their outputs; this layer adds the two terms and returns the sum as the last
    element of its output tuple.

    A pruning ratio may be given as a number (used for attention and MLP alike)
    or as a dict with the keys ``'mlp'`` and ``'attn'``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
        num_extra_neurons: int,
        pruning_ratio: Optional[Union[float, Dict]] = None,
        input_dependent: bool = False,
        lora: bool = False,
    ):
        """Build the attention, MLP and normalization sub-modules.

        Args:
            config: Model configuration; `hidden_size` and `rms_norm_eps` are read here.
            layer_idx: Index of this layer, forwarded to `LlamaAttention`.
            num_extra_neurons: Forwarded unchanged to `LlamaAttention` and `LlamaMLP`.
            pruning_ratio: Number or ``{'mlp': ..., 'attn': ...}`` dict. Any other
                value (including `None`) selects 0.5 for both sub-modules.
            input_dependent: Forwarded unchanged to `LlamaAttention` and `LlamaMLP`.
            lora: Forwarded unchanged to `LlamaAttention` and `LlamaMLP`; also stored
                on the layer as `self.lora`.

        Raises:
            KeyError: If `pruning_ratio` is a dict lacking ``'mlp'`` or ``'attn'``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        ratios = self._split_pruning_ratio(pruning_ratio)
        if ratios is None:
            # Missing or unsupported specification: fall back to 0.5 for both parts.
            ratios = (0.5, 0.5)
        mlp_pruning_ratio, attn_pruning_ratio = ratios

        self.self_attn = LlamaAttention(
            config=config,
            layer_idx=layer_idx,
            num_extra_neurons=num_extra_neurons,
            pruning_ratio=attn_pruning_ratio,
            input_dependent=input_dependent,
            lora=lora,
        )
        self.mlp = LlamaMLP(
            config,
            num_extra_neurons,
            pruning_ratio=mlp_pruning_ratio,
            input_dependent=input_dependent,
            lora=lora,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lora = lora

    @staticmethod
    def _split_pruning_ratio(pruning_ratio):
        """Return ``(mlp_ratio, attn_ratio)`` for a ratio spec, or `None` if its type is unsupported."""
        if isinstance(pruning_ratio, dict):
            return pruning_ratio['mlp'], pruning_ratio['attn']
        if isinstance(pruning_ratio, (int, float)):
            # Integral ratios such as 1 are valid input; hand them on as floats so
            # that they are treated exactly like 1.0.
            ratio = pruning_ratio if isinstance(pruning_ratio, float) else float(pruning_ratio)
            return ratio, ratio
        return None

    def bias_initialization(self):
        """Call `bias_initialization` on the attention module and on the MLP."""
        self.self_attn.bias_initialization()
        self.mlp.bias_initialization()

    def lora_initialization(self):
        """Call `lora_initialization` on the attention module and on the MLP."""
        self.self_attn.lora_initialization()
        self.mlp.lora_initialization()

    def set_pruning_ratio(self, pruning_ratio):
        """Update the pruning ratio of the attention module and of the MLP.

        Args:
            pruning_ratio: Number applied to both sub-modules, or a dict with the
                keys ``'mlp'`` and ``'attn'``.

        Raises:
            TypeError: If `pruning_ratio` is neither a number nor a dict.
            KeyError: If a dict is given without ``'mlp'`` or ``'attn'``.
        """
        ratios = self._split_pruning_ratio(pruning_ratio)
        if ratios is None:
            raise TypeError(
                "pruning_ratio must be a number or a dict with 'mlp' and 'attn' keys, "
                f"got {type(pruning_ratio).__name__}"
            )
        mlp_pruning_ratio, attn_pruning_ratio = ratios

        self.self_attn.set_pruning_ratio(attn_pruning_ratio)
        self.mlp.set_pruning_ratio(mlp_pruning_ratio)

    def set_bias(self):
        """Call `set_bias` on the attention module and on the MLP."""
        self.self_attn.set_bias()
        self.mlp.set_bias()

    def enable_ensemble(self):
        """Call `enable_ensemble` on the attention module and on the MLP."""
        self.self_attn.enable_ensemble()
        self.mlp.enable_ensemble()

    def get_post_attention_layernorm(self):
        """Return the norm module applied between the attention block and the MLP."""
        return self.post_attention_layernorm

    def forward(
        self,
        pruned_states: torch.Tensor,
        hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        pruned_past_key_value: Optional[Cache] = None,
        hidden_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, ...]:
        """Run attention and MLP, each wrapped in a residual connection, on both streams.

        Args:
            pruned_states: Input of the pruned stream.
            hidden_states: Input of the reference stream, or `None` to skip that stream.
            attention_mask, position_ids, output_attentions, use_cache, cache_position,
            position_embeddings, **kwargs: Forwarded unchanged to the attention module.
            pruned_past_key_value: Cache handed to the attention module for the pruned stream.
            hidden_past_key_value: Cache handed to the attention module for the reference stream.

        Returns:
            A tuple laid out as ``(pruned_states, [hidden_states], [pruned_attn_weights], loss)``.
            `hidden_states` is present only when the reference stream produced an output,
            `pruned_attn_weights` only when `output_attentions` is truthy, and `loss` is the
            sum of the attention loss and the MLP loss.
        """
        pruned_residual = pruned_states
        hidden_residual = hidden_states

        pruned_states = self.input_layernorm(pruned_states)
        if hidden_states is not None:
            hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        pruned_states, pruned_attn_weights, attn_loss, hidden_states = self.self_attn(
            pruned_states=pruned_states,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pruned_past_key_value=pruned_past_key_value,
            hidden_past_key_value=hidden_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        pruned_states = pruned_residual + pruned_states
        if hidden_states is not None:
            hidden_states = hidden_residual + hidden_states

        # Fully Connected
        pruned_residual = pruned_states
        hidden_residual = hidden_states

        pruned_states = self.post_attention_layernorm(pruned_states)
        if hidden_states is not None:
            hidden_states = self.post_attention_layernorm(hidden_states)
        pruned_states, mlp_loss, hidden_states = self.mlp(pruned_states, hidden_states)

        pruned_states = pruned_residual + pruned_states
        if hidden_states is not None:
            hidden_states = hidden_residual + hidden_states

        # The loss is always the last element so callers can read it with index -1
        # regardless of which optional entries are present.
        outputs = (pruned_states,)
        if hidden_states is not None:
            outputs += (hidden_states,)
        if output_attentions:
            outputs += (pruned_attn_weights,)
        outputs += (attn_loss + mlp_loss,)

        return outputs


@auto_docstring
class LlamaPreTrainedModel(PreTrainedModel):
    # Shared base class of the Llama model classes in this file: binds the config
    # class, declares capability flags, and supplies the weight initializer.
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize `module` in place according to its type.

        Linear and embedding weights are drawn from a normal distribution with
        mean 0 and standard deviation `config.initializer_range`; linear biases
        and the embedding padding row are zeroed; `LlamaRMSNorm` weights are set
        to 1. Modules of any other type are left untouched.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            linear_bias = module.bias
            if linear_bias is not None:
                linear_bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)


@auto_docstring
class LlamaModel(LlamaPreTrainedModel):
    # Llama decoder stack that carries two activation streams through every layer:
    #   * the "pruned" stream (`pruned_states`), which is always computed, and
    #   * a reference stream (`hidden_states`), which is computed only in training mode.
    # Every decoder layer returns a loss term as the last element of its output;
    # `forward` sums those terms and reports the total as `pruning_loss`.
    # Documentation lives in `#` comments because `@auto_docstring` generates the docstrings.

    def __init__(
        self,
        config: LlamaConfig,
        num_extra_neurons=1,
        pruning_ratio: Union[float, list]=0.5,
        input_dependent: bool = False,
        lora: bool = False
    ):
        # `pruning_ratio` is either one specification shared by all layers (a number,
        # or a dict with 'mlp' and 'attn' keys) or a list holding one specification
        # per layer. `num_extra_neurons`, `input_dependent` and `lora` are forwarded
        # unchanged to every decoder layer.
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        pruning_ratio = self._expand_pruning_ratio(pruning_ratio, config.num_hidden_layers)

        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    config,
                    layer_idx,
                    num_extra_neurons,
                    pruning_ratio[layer_idx],
                    input_dependent,
                    lora,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.lora = lora

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _expand_pruning_ratio(pruning_ratio, num_layers):
        # Turn a pruning-ratio specification into an object indexable by layer index.
        #   * number                      -> repeated for every layer
        #   * dict with 'mlp' and 'attn'  -> repeated for every layer
        #   * list                        -> must contain exactly `num_layers` entries
        #   * anything else               -> returned unchanged
        # Raises ValueError when a list has the wrong length.
        def _normalize(ratio):
            # Decoder layers test `isinstance(ratio, float)`, so integral ratios
            # (e.g. 1) are converted before they are handed down.
            return float(ratio) if isinstance(ratio, int) else ratio

        if isinstance(pruning_ratio, (int, float)):
            return [_normalize(pruning_ratio)] * num_layers
        if isinstance(pruning_ratio, dict) and "mlp" in pruning_ratio and "attn" in pruning_ratio:
            return [pruning_ratio] * num_layers
        if isinstance(pruning_ratio, list):
            if len(pruning_ratio) != num_layers:
                raise ValueError(
                    f"pruning_ratio has {len(pruning_ratio)} entries but the model has {num_layers} layers"
                )
            return [_normalize(ratio) for ratio in pruning_ratio]
        return pruning_ratio

    def _active_layers(self):
        # The same slice of `self.layers` that `forward` iterates over.
        return self.layers[: self.config.num_hidden_layers]

    def bias_initialization(self):
        # Delegate to every decoder layer.
        for layer in self._active_layers():
            layer.bias_initialization()

    def lora_initialization(self):
        # Delegate to every decoder layer.
        for layer in self._active_layers():
            layer.lora_initialization()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_post_attention_layernorm(self, layer_idx):
        # Return the post-attention norm module of decoder layer `layer_idx`.
        return self.layers[layer_idx].get_post_attention_layernorm()

    def set_pruning_ratio(self, pruning_ratio):
        # Accepts the same specifications as `__init__` and pushes one entry to each layer.
        pruning_ratio = self._expand_pruning_ratio(pruning_ratio, self.config.num_hidden_layers)

        for layer_idx, layer in enumerate(self._active_layers()):
            layer.set_pruning_ratio(pruning_ratio[layer_idx])

    def set_bias(self):
        # Delegate to every decoder layer.
        for layer in self._active_layers():
            layer.set_bias()

    def enable_ensemble(self):
        # Delegate to every decoder layer.
        for layer in self._active_layers():
            layer.enable_ensemble()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPastAndPruning:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            # Fresh caches: one per stream so the two streams never share entries.
            pruned_past_key_values = DynamicCache()
            hidden_past_key_values = DynamicCache()
        else:
            # NOTE(review): a caller-supplied cache is handed to both streams; when the
            # reference stream is active (training) both would use the same object.
            pruned_past_key_values = past_key_values
            hidden_past_key_values = past_key_values

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        pruned_states = inputs_embeds
        # The reference stream only exists in training mode; otherwise it stays None
        # and `last_hidden_state` of the output is None as well.
        hidden_states = inputs_embeds.clone() if self.training else None

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(pruned_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Running sum of the loss term each decoder layer returns as its last output.
        semantic_loss = 0

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (pruned_states,)

            # The keyword names must match `LlamaDecoderLayer.forward` (singular
            # `*_past_key_value`); any other spelling is swallowed by its `**kwargs`
            # and the caches never reach the attention module as caches.
            layer_outputs = decoder_layer(
                pruned_states=pruned_states,
                hidden_states=hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                pruned_past_key_value=pruned_past_key_values,
                hidden_past_key_value=hidden_past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            # Layer output layout: (pruned_states, [hidden_states], [attn_weights], loss).
            pruned_states = layer_outputs[0]
            hidden_states = layer_outputs[1] if hidden_states is not None else None

            if output_attentions:
                attn_index = 1 if hidden_states is None else 2
                all_self_attns += (layer_outputs[attn_index],)

            semantic_loss += layer_outputs[-1]

        pruned_states = self.norm(pruned_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (pruned_states,)

        return BaseModelOutputWithPastAndPruning(
            last_pruned_state=pruned_states,
            last_hidden_state=self.norm(hidden_states) if hidden_states is not None else None,
            # Return the cache that was actually used, which may have been created above.
            past_key_values=pruned_past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            pruning_loss=semantic_loss,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the attention mask to hand to the decoder layers for the configured attention backend.

        Returns `None` when the backend can work without an explicit mask, the (possibly
        converted) input mask for flash-attention / flex-attention, and otherwise a 4D
        mask produced by `_prepare_4d_causal_attention_mask_with_cache_position`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention only needs the 2D mask, and only when something is masked out.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the masking value only at key positions that lie after the query's cache position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # A sum of zero marks positions the causal mask allows but the 2D mask excludes (value 0).
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


# Keyword-argument type for `LlamaForCausalLM.forward`'s `**kwargs`: the union of
# `FlashAttentionKwargs` and `LossKwargs`. It declares no fields of its own.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    # Causal language-model head on top of `LlamaModel`. Logits always come from the
    # pruned stream. In training mode with labels, the returned loss is rebuilt from
    # the model's pruning loss and a KL term against the reference-stream logits.
    # Documentation lives in `#` comments because `@auto_docstring` generates the docstrings.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    # Added to loss-ratio denominators so a zero-valued loss term cannot divide by zero.
    _LOSS_EPS = 1e-8

    def __init__(self, config, num_extra_neurons=1, pruning_ratio: Union[float, list]=0.5, input_dependent: bool = False, lora: bool = False):
        # All arguments after `config` are forwarded unchanged to `LlamaModel`.
        super().__init__(config)
        self.model = LlamaModel(config, num_extra_neurons, pruning_ratio, input_dependent, lora)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        # Remembered so `forward` can skip the distillation loss when the ratio equals 1.
        self.pruning_ratio = pruning_ratio

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_post_attention_layernorm(self, layer_idx):
        # Return the post-attention norm module of decoder layer `layer_idx`.
        return self.model.get_post_attention_layernorm(layer_idx)

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def bias_initialization(self):
        # Delegate to the wrapped `LlamaModel`.
        self.model.bias_initialization()

    def lora_initialization(self):
        # Delegate to the wrapped `LlamaModel`.
        self.model.lora_initialization()

    def set_pruning_ratio(self, pruning_ratio):
        # Update the wrapped model and keep the local copy used by `forward` in sync.
        self.model.set_pruning_ratio(pruning_ratio)
        self.pruning_ratio = pruning_ratio

    def set_bias(self):
        # Delegate to the wrapped `LlamaModel`.
        self.model.set_bias()

    def enable_ensemble(self):
        # Delegate to the wrapped `LlamaModel`.
        self.model.enable_ensemble()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # The base model reports the pruned-stream output (`last_pruned_state`), the
        # reference-stream output (`last_hidden_state`, None outside training) and the
        # accumulated `pruning_loss`.
        outputs: BaseModelOutputWithPastAndPruning = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        pruned_state = outputs.last_pruned_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        pruned_logits = self.lm_head(pruned_state[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=pruned_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

            reference_state = outputs.last_hidden_state
            # The reference stream is missing when the base model did not run in training
            # mode; in that case the plain language-modeling loss is kept.
            if self.training and self.pruning_ratio != 1 and reference_state is not None:
                hidden_logits = self.lm_head(reference_state[:, slice_indices, :])

                # KL term between pruned and reference distributions (default `reduction`).
                kl_loss = F.kl_div(F.log_softmax(pruned_logits, dim=-1), F.softmax(hidden_logits, dim=-1))
                # Rescale the pruning loss to half the magnitude of the KL term.
                pruning_loss = 0.5 * outputs.pruning_loss * (
                    kl_loss.detach().item() / (outputs.pruning_loss.detach().item() + self._LOSS_EPS)
                )
                # Rescale the combined term to the magnitude of the language-modeling loss.
                # The epsilon prevents a ZeroDivisionError when both terms are exactly 0.
                combined_loss = pruning_loss + kl_loss
                lambda_weight = loss.detach().item() / (combined_loss.detach().item() + self._LOSS_EPS)
                loss = lambda_weight * combined_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=pruned_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        transformer_outputs: BaseModelOutputWithPastAndPruning = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Read the pruned stream: `LlamaModel` sets `last_hidden_state` to None outside
        # training mode, whereas `last_pruned_state` is always populated.
        hidden_states = transformer_outputs.last_pruned_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        # One row of logits per sequence, taken at the selected token position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    # Span-extraction head: a linear layer maps every token representation to a
    # start logit and an end logit.
    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        self.transformer = LlamaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        outputs: BaseModelOutputWithPastAndPruning = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Read the pruned stream: `LlamaModel` sets `last_hidden_state` to None outside
        # training mode, whereas `last_pruned_state` is always populated.
        sequence_output = outputs.last_pruned_state

        logits = self.qa_outputs(sequence_output)
        # Split the two output channels into separate start / end logits.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class LlamaForTokenClassification(LlamaPreTrainedModel):
    # Token-level classifier: a `LlamaModel` backbone, then dropout, then a linear layer that
    # scores every position of the sequence against `config.num_labels` classes.
    # Class-level notes are kept as comments because the class is decorated with `@auto_docstring`.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)

        # Dropout probability, in order of precedence: `config.classifier_dropout`, then
        # `config.hidden_dropout`, then 0.1 when neither is present (missing or None).
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = getattr(config, "hidden_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the `embed_tokens` object of the wrapped `self.model` backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the `embed_tokens` attribute of the wrapped `self.model` backbone with `value`."""
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """

        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Every position's hidden state goes through dropout and then the linear scorer, so
        # `logits` keeps the sequence axis and gains a trailing axis of size `config.num_labels`.
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        # The loss is only computed when targets are supplied; otherwise it stays None.
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Public API of this module: the names made available by a wildcard import
# (`from <module> import *`).
__all__ = [
    "LlamaForCausalLM",
    "LlamaModel",
    "LlamaPreTrainedModel",
    "LlamaForSequenceClassification",
    "LlamaForQuestionAnswering",
    "LlamaForTokenClassification",
]
