import time
from typing import List, Optional, Tuple, Union
import math

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, LlamaPreTrainedModel, LlamaModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaDecoderLayer, LlamaForCausalLM, LlamaRotaryEmbedding, repeat_kv
from transformers.utils import logging

import wandb

from .configuration_mem_model import MemLlamaConfig
from .memory_bank import (
    ExternalMemory,
    AttentionExternalMemory,
    retrieve_layer_memory,
    update_memory_fifo,
    update_attn_weights,
    update_memory_attn,
)

# Raise the transformers logger to INFO at import time so that `logger.info(...)` messages in this
# module (e.g. "MeM: Successfully built external memory.") are emitted.
# NOTE(review): `set_verbosity_info` acts on the transformers library logger, so importing this module
# presumably changes verbosity for all transformers code in the process — confirm this is intended.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# MeM: Modify based on transformers.models.llama.modeling_llama.LlamaAttention
class MemLlamaCrossAttention(nn.Module):
    """Cross-attention from the token stream (queries) onto external-memory slots (keys/values).

    Mirrors ``LlamaAttention`` in parameter names and shapes (``q_proj``/``k_proj``/``v_proj``/``o_proj``,
    grouped-query heads expanded with ``repeat_kv``) so weights can be copied from a self-attention
    module. No rotary position embedding is applied, and no mask is used unless one is passed.
    """

    def __init__(self, config: MemLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # Hyper-parameters read from the config; attribute names match LlamaAttention.
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True  # kept for parity with LlamaAttention; forward() does not read it

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        use_bias = config.attention_bias
        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mem_hidden_states_v: torch.Tensor,
        mem_hidden_states_k: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from ``hidden_states`` to the memory slots.

        Args:
            hidden_states: queries, shape ``(batch, q_len, hidden_size)``.
            mem_hidden_states_v: memory projected into values (and into keys unless
                ``config.use_last_prompt_token_as_key``), shape ``(batch, mem_size, hidden_size)``.
            mem_hidden_states_k: memory projected into keys when ``config.use_last_prompt_token_as_key``
                is set; ignored otherwise.
            attention_mask: optional additive mask; its last dimension is truncated to ``mem_size``.
            output_attentions: when True, also return the (post-softmax, post-dropout) weights.

        Returns:
            ``(attn_output, attn_weights)``: ``attn_output`` is ``(batch, q_len, hidden_size)``;
            ``attn_weights`` is ``(batch, num_heads, q_len, mem_size)`` or ``None``.

        Raises:
            NotImplementedError: if ``config.pretraining_tp > 1``.
        """
        batch_size, query_len, _ = hidden_states.size()
        _, num_slots, _ = mem_hidden_states_v.size()

        if self.config.pretraining_tp > 1:
            raise NotImplementedError('pretraining_tp > 1 is not supported for MemLlamaCrossAttention.')

        # Keys come from a dedicated memory tensor only in last-prompt-token mode;
        # otherwise the value memory doubles as the key memory.
        key_source = mem_hidden_states_k if self.config.use_last_prompt_token_as_key else mem_hidden_states_v
        q = self.q_proj(hidden_states)
        k = self.k_proj(key_source)
        v = self.v_proj(mem_hidden_states_v)

        q = q.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        kv_shape = (batch_size, num_slots, self.num_key_value_heads, self.head_dim)
        # Grouped-query attention: replicate each key/value head across its group of query heads.
        k = repeat_kv(k.view(kv_shape).transpose(1, 2), self.num_key_value_groups)
        v = repeat_kv(v.view(kv_shape).transpose(1, 2), self.num_key_value_groups)

        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # The mask may be wider than the memory; only its first `num_slots` columns apply.
            scores = scores + attention_mask[:, :, :, : k.shape[-2]]

        # Softmax in fp32 for numerical stability, then back to the query dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = nn.functional.dropout(probs, p=self.attention_dropout, training=self.training)
        context = torch.matmul(probs, v)

        expected_shape = (batch_size, self.num_heads, query_len, self.head_dim)
        if context.size() != expected_shape:
            raise ValueError(
                f"`attn_output` should be of size {expected_shape}, but is"
                f" {context.size()}"
            )

        merged = context.transpose(1, 2).contiguous().reshape(batch_size, query_len, self.hidden_size)
        output = self.o_proj(merged)
        return output, (probs if output_attentions else None)


# MeM: Modify based on transformers.models.llama.modeling_llama.LlamaDecoderLayer
class MemLlamaDecoderLayer(LlamaDecoderLayer):
    r"""
    MeM: a ``LlamaDecoderLayer`` whose MLP output is fused with a cross-attention read of an external memory.

    After self-attention and the MLP, the MLP output attends to the memory (``cross_attn_mem``). The
    result is gated against the MLP output according to ``config.fusion_func``:

    * ``'softmax'`` / ``'sigmoid'``: ``residual + (g_orig * mlp_out + g_mem * mem_out)``
    * ``'softmax*'`` / ``'sigmoid*'``: ``g_orig * (residual + mlp_out) + g_mem * mem_out``
    * ``'sigmoid_alpha'``: ``(1 - g) * (residual + mlp_out) + g * mem_out`` with ``g`` from the post-residual state

    For the first four, the gates are computed from ``concat(mlp_out, mem_out)``.
    """

    def __init__(self, config: MemLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # MeM: cross-attention weights over memory are only requested for the 'attn' update strategy.
        self.output_cross_attentions = config.update_strategy == 'attn'
        # Must be set before build_cross_attention_mem(), which reads it.
        self.use_flash_attention = getattr(config, "use_flash_attention", False)

        self.cross_attn_mem = self.build_cross_attention_mem(config, layer_idx)
        self.layer_idx = layer_idx

        self.fusion_func = config.fusion_func
        self.use_last_prompt_token_as_key = config.use_last_prompt_token_as_key

        hidden = config.hidden_size
        fusion = self.fusion_func
        if fusion in ('softmax', 'softmax*'):
            # Two logits per token: weight of the original stream vs. the memory read.
            self.alpha = nn.Linear(in_features=2 * hidden, out_features=2, bias=False)
        elif fusion in ('sigmoid', 'sigmoid*'):
            # One gate per hidden dimension.
            self.alpha = nn.Linear(in_features=2 * hidden, out_features=hidden, bias=True)
        elif fusion == 'sigmoid_alpha':
            # A single scalar gate per token, computed from the post-residual hidden state.
            self.alpha = nn.Linear(in_features=hidden, out_features=1)
        else:
            raise NotImplementedError(f'Fusion function "{self.fusion_func}" not implemented.')

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        external_memory_v: Optional[ExternalMemory] = None,
        external_memory_k: Optional[ExternalMemory] = None,
    ) -> Tuple[torch.FloatTensor]:
        """Run the Llama block, then fuse its MLP output with a read of the layer's memory.

        ``external_memory_v`` / ``external_memory_k`` are indexed with ``.shape`` and ``.unsqueeze`` in
        ``memory_interaction``, so they are presumably ``(memory_size, hidden_dim)`` tensors despite the
        ``ExternalMemory`` annotation — TODO confirm against ``retrieve_layer_memory``.

        Returns:
            ``(hidden_states, [self_attn_weights], hidden_states_ln, [mem_attn_weights])``. Self-attention
            weights are present iff ``output_attentions``. ``hidden_states_ln`` is the input-layernormed
            layer input. ``mem_attn_weights`` (``(bs, memory_size)``, averaged over heads and positions)
            is present iff the cross-attention returned weights.
        """
        residual = hidden_states
        normed_input = self.input_layernorm(hidden_states)

        # Self Attention
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + attn_out

        # Fully Connected (its residual is added inside the fusion step below)
        residual = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))

        # MeM: read memory. mem_out: (bs, seq_len, h_dim); raw weights: (bs, num_heads, seq_len, memory_size)
        mem_out, mem_attn_weights = self.memory_interaction(mlp_out, external_memory_v, external_memory_k)
        if mem_attn_weights is not None:
            mem_attn_weights = mem_attn_weights.mean(dim=(1, 2))  # (bs, memory_size)

        hidden_states = self._fuse_with_memory(residual, mlp_out, mem_out)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        outputs += (normed_input,)
        if mem_attn_weights is not None:
            outputs += (mem_attn_weights,)
        return outputs

    def _fuse_with_memory(self, residual, mlp_out, mem_out):
        """Combine the MLP output, its residual and the memory read according to ``self.fusion_func``."""
        fusion = self.fusion_func

        if fusion == 'sigmoid_alpha':
            post_residual = residual + mlp_out
            gate_mem = torch.sigmoid(self.alpha(post_residual))  # [bs, seq_len, 1]
            gate_orig = 1 - gate_mem
            fused = gate_orig * post_residual + gate_mem * mem_out
            self._log_gates(gate_orig, gate_mem)
            return fused

        if fusion not in ('softmax', 'softmax*', 'sigmoid', 'sigmoid*'):
            # Unreachable after __init__ validation; the plain MLP output passes through unchanged.
            return mlp_out

        gate_logits = self.alpha(torch.cat((mlp_out, mem_out), dim=-1))  # gates see the pre-residual MLP output
        if fusion.startswith('softmax'):
            shares = torch.softmax(gate_logits, dim=-1)  # (bs, seq_len, 2)
            gate_orig = shares[:, :, 0].unsqueeze(-1)  # (bs, seq_len, 1)
            gate_mem = shares[:, :, 1].unsqueeze(-1)  # (bs, seq_len, 1)
        else:
            gate_orig = torch.sigmoid(gate_logits)  # (bs, seq_len, h_dim)
            gate_mem = 1 - gate_orig

        # TODO: check residual placement
        if fusion.endswith('*'):
            # Starred variants gate the post-residual stream against the memory read.
            fused = gate_orig * (residual + mlp_out) + gate_mem * mem_out
        else:
            # Plain variants gate the MLP output, then add the residual.
            fused = residual + (gate_orig * mlp_out + gate_mem * mem_out)

        self._log_gates(gate_orig, gate_mem)
        return fused

    def _log_gates(self, gate_orig, gate_mem):
        """Log mean gate values for this layer to wandb, if a run is active.

        NOTE(review): `.item()` forces a device sync, and each `wandb.log` call without `commit=False`
        advances wandb's step counter — confirm per-layer, per-forward logging is intended.
        """
        if wandb.run is None:
            return
        wandb.log(
            {
                f'gate_orig_{self.layer_idx}': gate_orig.mean().item(),
                f'gate_mem_{self.layer_idx}': gate_mem.mean().item(),
            }
        )

    def build_cross_attention_mem(self, config: MemLlamaConfig, layer_idx: int):
        """Create the memory cross-attention module; the flash-attention variant is not available."""
        if not self.use_flash_attention:
            return MemLlamaCrossAttention(config=config, layer_idx=layer_idx)
        # TODO: Implement MemLlamaFlashCrossAttention
        raise NotImplementedError('MemLlamaFlashCrossAttention has not been implemented yet.')

    def memory_interaction(self, hidden_states, external_memory_v, external_memory_k):
        """Broadcast the 2-D memory tensors over the batch and attend to them from ``hidden_states``.

        Returns the ``(attn_output, attn_weights)`` pair of ``cross_attn_mem``. The weights are only
        requested when ``self.output_cross_attentions`` is set.
        """
        batch_size = hidden_states.shape[0]
        memory_size, hidden_dim = external_memory_v.shape
        batched_shape = (batch_size, memory_size, hidden_dim)

        # expand() shares storage instead of copying, which is enough since the memory is only read.
        values = external_memory_v.unsqueeze(0).expand(*batched_shape)
        keys = external_memory_k.unsqueeze(0).expand(*batched_shape) if self.use_last_prompt_token_as_key else None

        mem_outputs, attn_weights = self.cross_attn_mem(
            hidden_states=hidden_states,
            mem_hidden_states_v=values,
            mem_hidden_states_k=keys,
            attention_mask=None,
            output_attentions=self.output_cross_attentions,
        )
        return mem_outputs, attn_weights


class MemLlamaPreTrainedModel(LlamaPreTrainedModel):
    """Base class tying MeM models to ``MemLlamaConfig``.

    Both the stock and the memory-augmented decoder layers are declared non-splittable. transformers
    consults ``_no_split_modules`` when it builds a ``device_map``.
    """

    config_class = MemLlamaConfig
    _no_split_modules = [
        'LlamaDecoderLayer',
        'MemLlamaDecoderLayer',
    ]


class MemLlamaModel(MemLlamaPreTrainedModel):
    r"""
    MeM: Llama decoder stack with an external memory bank attached to selected layers.

    Layers whose index is in ``config.memory_insert_layers`` are ``MemLlamaDecoderLayer`` instances that
    read from the external memory; every other layer is a stock ``LlamaDecoderLayer``. While training, or
    when ``config.update_while_predicting`` is set, each memory layer's memory is updated after the layer
    runs. The update uses the input-layernormed hidden state that the layer returns. It is FIFO, or
    attention-score based when ``config.update_strategy == 'attn'``.
    """

    def __init__(self, config: MemLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # MeM: Insert memory-augmented layers at the configured indices; all others stay stock Llama.
        self.memory_insert_layers = config.memory_insert_layers
        self.layers = nn.ModuleList(
            [
                MemLlamaDecoderLayer(config, layer_idx=idx)
                if idx in self.memory_insert_layers
                else LlamaDecoderLayer(config, idx)
                for idx in range(config.num_hidden_layers)
            ]
        )

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.use_last_prompt_token_as_key = config.use_last_prompt_token_as_key
        self.update_while_predicting = config.update_while_predicting
        self.update_strategy = config.update_strategy

        # MeM: Add memory. (`update_memory` is stored but not read within this class.)
        self.update_memory = config.update_memory

        self.external_memory_v = ExternalMemory(config)
        # A separate key bank is only kept when keys come from the last prompt token.
        self.external_memory_k = ExternalMemory(config) if self.use_last_prompt_token_as_key else None
        # MeM: Save attention scores of memory tokens (only needed by the 'attn' update strategy).
        self.external_memory_attn = AttentionExternalMemory(config) if self.update_strategy == 'attn' else None

        self.memory_update_steps = config.memory_update_steps

        if config.mem_hidden_path is not None:
            # Loading a pre-computed memory from `mem_hidden_path` is not supported yet.
            raise NotImplementedError

        logger.info('MeM: Successfully built external memory.')

        self.post_init()

        # Reuse LlamaModel's mask construction. `_update_causal_mask` is bound to this instance; the 4D-mask
        # helper is stored as a plain instance attribute, so it is called without `self`
        # (presumably a staticmethod on LlamaModel — confirm for the pinned transformers version).
        self._update_causal_mask = LlamaModel._update_causal_mask.__get__(self, MemLlamaModel)

        self._prepare_4d_causal_attention_mask_with_cache_position = (
            LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
        )

    def find_last_non_pad_newline_token(self, input_ids, newline_token_ids=(198, 271, 1432)):
        r"""
        Find, per sample, the last token which is not a padding token or a newline token.

        The default newline IDs are those of the Llama vocab:
            {'\n': 198, '\n\n': 271, '\n\n\n': 1432}
        Padding is ``self.padding_idx``; when it is ``None`` only newline tokens are skipped.

        Args:
            input_ids: ``(batch, seq_len)`` token IDs (tensor or rectangular nested list).
            newline_token_ids: token IDs treated as newlines.

        Returns:
            list[int]: one position per sample. A sample consisting only of padding/newline tokens maps to
            ``seq_len - 1``.
        """
        token_ids = torch.as_tensor(input_ids)
        batch_size, seq_len = token_ids.shape[0], token_ids.shape[-1]
        if seq_len == 0:
            return [-1] * batch_size

        # Vectorized over the whole batch: a single device->host transfer at the end instead of one
        # comparison-and-sync per inspected token.
        keep = torch.ones_like(token_ids, dtype=torch.bool)
        if self.padding_idx is not None:
            keep &= token_ids != self.padding_idx
        for newline_id in newline_token_ids:
            keep &= token_ids != newline_id

        positions = torch.arange(seq_len, device=token_ids.device).expand_as(token_ids)
        last = positions.masked_fill(~keep, -1).amax(dim=-1)
        # Samples without any qualifying token fall back to their last position.
        last = last.masked_fill(last < 0, seq_len - 1)
        return last.tolist()

    def find_last_prompt_token(self, labels, ignore_index=-100):
        """
        Find, per sample, the position of the last label equal to ``ignore_index`` (the prompt marker).

        NOTE(review): if the collator pads on the right with ``ignore_index`` labels, the last match is a
        padding position rather than the last prompt token — confirm the padding side used in training.

        Args:
            labels: ``(batch, seq_len)`` label IDs (tensor or rectangular nested list).
            ignore_index: label value marking prompt (non-target) positions.

        Returns:
            list[int]: one position per sample.

        Raises:
            ValueError: if some sample contains no label equal to ``ignore_index``.
        """
        is_prompt = torch.as_tensor(labels) == ignore_index
        if is_prompt.shape[0] == 0:
            return []

        has_prompt = is_prompt.any(dim=-1)
        if not bool(has_prompt.all()):
            missing = (~has_prompt).nonzero(as_tuple=True)[0].tolist()
            raise ValueError(
                f"MeM: cannot locate the last prompt token; samples {missing} have no label equal to {ignore_index}."
            )

        positions = torch.arange(is_prompt.shape[-1], device=is_prompt.device).expand_as(is_prompt)
        return positions.masked_fill(~is_prompt, -1).amax(dim=-1).tolist()

    @staticmethod
    def _gather_positions(hidden_states, positions):
        """Return ``hidden_states[i, positions[i], :]`` for each sample ``i``, detached, shaped ``(n, 1, hidden_dim)``."""
        rows = torch.arange(len(positions), device=hidden_states.device)
        cols = torch.as_tensor(positions, dtype=torch.long, device=hidden_states.device)
        return hidden_states[rows, cols].detach().unsqueeze(1)

    def process_hidden_states(self, hidden_states, input_ids, labels):
        """
        Select the per-sample hidden states to write into memory.

        Transforms ``hidden_states`` (bs, seq_len, hidden_dim) into detached (bs, 1, hidden_dim) tensors.
        Values are taken at the last non-pad/non-newline token of ``input_ids``. Keys are taken at the last
        prompt token found in ``labels``, and only when ``use_last_prompt_token_as_key`` is set.

        Returns:
            ``(mem_hidden_states_v, mem_hidden_states_k)``; the key tensor is ``None`` when keys are not used.

        Raises:
            ValueError: if ``input_ids`` is ``None``, or keys are needed and ``labels`` is ``None``.
        """
        if input_ids is None:
            raise ValueError('MeM: `input_ids` are required to update the external memory.')
        mem_hidden_states_v = self._gather_positions(hidden_states, self.find_last_non_pad_newline_token(input_ids))

        if not self.use_last_prompt_token_as_key:
            return mem_hidden_states_v, None

        if labels is None:
            raise ValueError(
                'MeM: `labels` are required to build memory keys when `use_last_prompt_token_as_key` is set.'
            )
        mem_hidden_states_k = self._gather_positions(hidden_states, self.find_last_prompt_token(labels))
        return mem_hidden_states_v, mem_hidden_states_k

    def _update_layer_memory(self, layer_idx, layer_outputs, output_attentions, input_ids, labels):
        """
        MeM: Update the memory of memory layer ``layer_idx`` from that layer's outputs.

        ``layer_outputs`` is ``(hidden_states, [self_attn_weights], hidden_states_ln, [mem_attn_weights])``.
        Self-attention weights are present iff ``output_attentions``. Memory-attention weights are read
        only under the 'attn' update strategy.
        """
        # TODO: decide which hidden state to update memory
        hidden_states_ln_idx = 2 if output_attentions else 1
        hidden_states_ln = layer_outputs[hidden_states_ln_idx]
        mem_attn_weights = layer_outputs[hidden_states_ln_idx + 1] if self.update_strategy == 'attn' else None

        if not (self.training or self.update_while_predicting):
            return

        memory_hidden_states_v, memory_hidden_states_k = self.process_hidden_states(
            hidden_states_ln, input_ids, labels
        )

        if self.update_strategy != 'attn' or not self.external_memory_v.memory_full[layer_idx]:
            # FIFO update: either the configured strategy, or the 'attn' strategy while the bank is still filling.
            update_memory_fifo(layer_idx, self.external_memory_v, memory_hidden_states_v.detach())
            if self.use_last_prompt_token_as_key:
                update_memory_fifo(layer_idx, self.external_memory_k, memory_hidden_states_k.detach())
            return

        if torch.isnan(mem_attn_weights.mean(0)).any():
            # MeM: If attn_weights contain NaN, do not update.
            return

        update_attn_weights(self.external_memory_attn, layer_idx, mem_attn_weights.detach())

        # Replace memory slots only once enough attention statistics have been accumulated.
        if self.external_memory_attn.memory_accumulate_steps[layer_idx] >= self.memory_update_steps:
            update_memory_attn(
                layer_idx,
                self.external_memory_attn,
                self.external_memory_v,
                memory_hidden_states_v.detach(),
            )
            if self.use_last_prompt_token_as_key:
                update_memory_attn(
                    layer_idx,
                    self.external_memory_attn,
                    self.external_memory_k,
                    memory_hidden_states_k.detach(),
                )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: torch.LongTensor = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the decoder stack, reading from (and, when enabled, updating) the external memory.

        Mirrors ``LlamaModel.forward`` plus ``labels``, which are only used to locate memory keys when
        ``use_last_prompt_token_as_key`` is set. ``return_dict`` is accepted for API compatibility; a
        ``BaseModelOutputWithPast`` is always returned.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # MeM: Sync the memory modules (value bank, key bank, attention scores) before the layers read them.
        if self.training or self.update_while_predicting:
            for memory in (self.external_memory_v, self.external_memory_k, self.external_memory_attn):
                if memory is not None:
                    memory.update_to_memory_modules()

        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # `cache_position` is passed to every layer, memory layers included: their self-attention
            # hands it to the cache update, and leaving it out breaks caches that index by position.
            layer_kwargs = dict(
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            is_memory_layer = layer_idx in self.memory_insert_layers
            if is_memory_layer:
                layer_external_memory_v, layer_external_memory_k = retrieve_layer_memory(
                    layer_idx, self.external_memory_v, self.external_memory_k
                )
                layer_kwargs.update(
                    external_memory_v=layer_external_memory_v,
                    external_memory_k=layer_external_memory_k,
                )

            layer_outputs = decoder_layer(hidden_states, **layer_kwargs)

            if is_memory_layer:
                self._update_layer_memory(layer_idx, layer_outputs, output_attentions, input_ids, labels)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class MemLlamaForCausalLM(LlamaForCausalLM):
    """Causal language-modeling head on top of the memory-augmented ``MemLlamaModel``."""

    config_class = MemLlamaConfig

    def __init__(self, config: MemLlamaConfig):
        # Skip LlamaForCausalLM.__init__ (it would build a plain LlamaModel) and call its parent's
        # initializer directly, then install the memory-augmented backbone.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = MemLlamaModel(config)  # MeM: Set model to MemLlamaModel
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Finalize as LlamaForCausalLM.__init__ does. This initializes lm_head and ties it to the input
        # embeddings when config.tie_word_embeddings is set; skipping it left both undone.
        self.post_init()

    def init_cross_attention_from_self(self):
        """Initialize every memory cross-attention parameter from the same layer's self-attention.

        For each parameter whose name contains ``cross_attn_mem``, the state-dict entry with that
        substring replaced by ``self_attn`` is copied in place.

        Raises:
            KeyError: if a cross-attention parameter has no self-attention counterpart.
        """
        # Build the state dict once: rebuilding it for every parameter was quadratic in the parameter
        # count. The source self-attention tensors are never written here, so the snapshot stays valid.
        source_tensors = self.state_dict()
        for name, param in self.named_parameters():
            if 'cross_attn_mem' in name:
                self_attention_module_name = name.replace('cross_attn_mem', 'self_attn')
                param.data.copy_(source_tensors[self_attention_module_name])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits (in fp32) and, when ``labels`` are given, the shifted next-token cross-entropy.

        ``labels`` are also forwarded to the backbone, which uses them to locate memory keys.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # MeM: "labels" is passed to the backbone as well
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            labels=labels,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Register MemLlamaForCausalLM as the causal-LM class for MemLlamaConfig with the Auto API.
# NOTE(review): resolving a saved checkpoint through `AutoModelForCausalLM.from_pretrained` presumably also
# needs `MemLlamaConfig` registered with `AutoConfig`; that is not visible here — confirm where it happens.
AutoModelForCausalLM.register(MemLlamaConfig, MemLlamaForCausalLM)
