from typing import  Optional, Union, Callable

import torch
from torch import nn
import wandb

from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)
from transformers import AutoModelForCausalLM
from transformers.processing_utils import Unpack
from transformers.utils import logging


from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2PreTrainedModel, Qwen2ForCausalLM, Qwen2RMSNorm, Qwen2RotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward

from .configuration_mem_model import MemQwen2Config
from .memory_bank import (ExternalMemory, AttentionExternalMemory, retrieve_layer_memory, update_memory_fifo,
                          update_attn_weights, update_memory_attn)


# Module logger from transformers' logging utilities; used in this module for
# warnings (attention-backend fallback, use_cache vs. gradient checkpointing)
# and for the info message emitted once the external memory is built.
logger = logging.get_logger(__name__)


class MemQwen2CrossAttention(nn.Module):
    """Multi-head cross-attention from decoder hidden states to an external memory bank.

    Queries are projected from ``hidden_states``; keys and values are projected
    from memory entries. The projection layout matches Qwen2 self-attention
    (biased q/k/v projections, bias-free output projection), which is what lets
    ``MemQwen2ForCausalLM.init_cross_attention_from_self`` copy weights over.
    """

    def __init__(self, config: MemQwen2Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Query heads per key/value head (grouped-query attention); presumably read
        # by the attention backend to expand k/v heads.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Memory slots have no temporal order relative to the query tokens, so no
        # causal mask must be applied. Backends such as SDPA / flash-attention
        # consult ``module.is_causal`` when no explicit mask is given; ``True``
        # would hide part of the memory from each query on those backends only.
        self.is_causal = False
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        mem_hidden_states_v: Optional[torch.Tensor] = None,
        mem_hidden_states_k: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from ``hidden_states`` to the memory entries.

        Args:
            hidden_states: query-side inputs, last dim ``hidden_size``.
            position_embeddings: optional ``(cos, sin)`` RoPE tables applied to
                queries and keys. Memory entries carry no positions, so callers
                usually omit this; when omitted no rotary embedding is applied.
            mem_hidden_states_v: memory entries used as values, and as keys unless
                ``config.use_last_prompt_token_as_key`` is set. Its sequence
                dimension may differ from the query length.
            mem_hidden_states_k: memory entries used as keys when
                ``config.use_last_prompt_token_as_key`` is set.
            past_key_value: optional cache updated with the projected keys/values.
            attention_mask: optional mask forwarded to the attention backend.
            cache_position: forwarded to ``past_key_value.update``.
            **kwargs: forwarded to the attention backend (e.g. ``output_attentions``).

        Returns:
            ``(attn_output, attn_weights)``; ``attn_weights`` may be None depending
            on the attention backend.

        Raises:
            ValueError: if ``mem_hidden_states_v`` is not provided.
        """
        if mem_hidden_states_v is None:
            raise ValueError("MemQwen2CrossAttention requires `mem_hidden_states_v`.")

        input_shape = hidden_states.shape[:-1]
        query_states = self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2)

        key_source = mem_hidden_states_k if self.config.use_last_prompt_token_as_key else mem_hidden_states_v
        key_states = self.k_proj(key_source)
        value_states = self.v_proj(mem_hidden_states_v)
        # The memory length generally differs from the query length, so split each
        # projection into heads using its OWN leading dims, not the query's shape.
        key_states = key_states.view(*key_states.shape[:-1], -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(*value_states.shape[:-1], -1, self.head_dim).transpose(1, 2)

        cos = sin = None
        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # NOTE(review): this writes to cache slot ``self.layer_idx``, presumably the
            # same slot the layer's self-attention uses; callers in this file pass no cache.
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        sliding_window = None
        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=sliding_window,  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class MemQwen2DecoderLayer(Qwen2DecoderLayer):
    r"""
    MeM: Qwen2 decoder layer fused with an external memory bank.

    The layer runs the usual pre-norm self-attention and MLP sub-blocks, lets the
    MLP output cross-attend to this layer's memory entries, and blends the memory
    read-out into the residual stream with a learned gate (``config.fusion_func``).
    """

    def __init__(self, config: MemQwen2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Memory attention weights are only needed by the attention-based update strategy.
        self.output_cross_attentions = config.update_strategy == 'attn'
        self.use_flash_attention = getattr(config, "use_flash_attention", False)

        self.cross_attn_mem = self.build_cross_attention_mem(config, layer_idx)
        self.layer_idx = layer_idx

        self.fusion_func = config.fusion_func
        self.use_last_prompt_token_as_key = config.use_last_prompt_token_as_key

        # Gate parameters for each fusion function. NOTE: ``forward`` currently only
        # implements 'sigmoid_alpha'; the other variants build weights but raise there.
        if self.fusion_func in ['softmax', 'softmax*']:
            self.alpha = nn.Linear(in_features=2 * config.hidden_size, out_features=2, bias=False)
        elif self.fusion_func in ['sigmoid', 'sigmoid*']:
            self.alpha = nn.Linear(in_features=2 * config.hidden_size, out_features=config.hidden_size, bias=True)
        elif self.fusion_func == 'sigmoid_alpha':
            self.alpha = nn.Linear(in_features=config.hidden_size, out_features=1)
        else:
            raise NotImplementedError(f'Fusion function "{self.fusion_func}" not implemented.')

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        external_memory_v: Optional[ExternalMemory] = None,
        external_memory_k: Optional[ExternalMemory] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder layer with memory fusion.

        Returns:
            ``(hidden_states, [self_attn_weights], hidden_states_ln, [mem_attn_weights])``.
            ``self_attn_weights`` is included only when ``output_attentions`` is set;
            ``mem_attn_weights`` (cross-attention weights averaged over dims 1 and 2)
            only when the cross-attention returned weights. ``hidden_states_ln`` is the
            layer input after ``input_layernorm``; the model writes it to memory.

        Raises:
            NotImplementedError: for any ``fusion_func`` other than 'sigmoid_alpha'.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states_ln = hidden_states

        # Self Attention -- executed exactly once, so the KV cache (if any) receives a
        # single update for this layer per forward pass.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        # Memory read-out, computed from the MLP output before its residual add.
        mem_hidden_states, mem_attn_weights = self.memory_interaction(
            hidden_states, external_memory_v, external_memory_k
        )
        if mem_attn_weights is not None:
            # Presumably (bs, heads, q_len, memory_size) -> (bs, memory_size).
            mem_attn_weights = mem_attn_weights.mean(dim=(1, 2))

        if self.fusion_func != 'sigmoid_alpha':
            raise NotImplementedError(f'Fusion function "{self.fusion_func}" not implemented.')

        hidden_states = residual + hidden_states
        # One gate value per position (alpha maps hidden_size -> 1).
        gate_mem = torch.sigmoid(self.alpha(hidden_states))  # [bs, seq_len, 1]
        hidden_states = (1 - gate_mem) * hidden_states + gate_mem * mem_hidden_states
        if wandb.run is not None:
            # NOTE(review): by default every wandb.log call advances wandb's step, so
            # per-layer logging here can desynchronise steps from trainer logs.
            wandb.log(
                {
                    f'gate_orig_{self.layer_idx}': (1 - gate_mem).mean().item(),
                    f'gate_mem_{self.layer_idx}': gate_mem.mean().item(),
                }
            )

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        outputs += (hidden_states_ln,)

        if mem_attn_weights is not None:
            outputs += (mem_attn_weights,)

        return outputs

    def build_cross_attention_mem(self, config: MemQwen2Config, layer_idx: int):
        """Build the memory cross-attention module (overridable hook for subclasses)."""
        return MemQwen2CrossAttention(config, layer_idx)

    def memory_interaction(self, hidden_states: torch.Tensor, external_memory_v: Optional[ExternalMemory], external_memory_k: Optional[ExternalMemory]):
        """Cross-attend ``hidden_states`` to this layer's memory entries.

        ``external_memory_v`` / ``external_memory_k`` are used as 2-D tensors
        ``(memory_size, hidden_dim)`` (their ``.shape`` is unpacked into two values)
        and are broadcast over the batch. Returns ``(mem_outputs, attn_weights)``.
        """
        bs = hidden_states.shape[0]
        memory_size, hidden_dim = external_memory_v.shape

        # expand() shares storage instead of copying the memory once per sample.
        bs_external_memory_v = external_memory_v.unsqueeze(0).expand(bs, memory_size, hidden_dim)
        if self.use_last_prompt_token_as_key:
            bs_external_memory_k = external_memory_k.unsqueeze(0).expand(bs, memory_size, hidden_dim)
        else:
            bs_external_memory_k = None

        # No position embeddings are passed: memory slots have no positions relative
        # to the query tokens. NOTE(review): the cross-attention module must therefore
        # accept a missing ``position_embeddings``.
        mem_outputs, attn_weights = self.cross_attn_mem(
            hidden_states=hidden_states,
            mem_hidden_states_v=bs_external_memory_v,
            mem_hidden_states_k=bs_external_memory_k,
            attention_mask=None,
            output_attentions=self.output_cross_attentions,
        )
        del bs_external_memory_v, bs_external_memory_k

        return mem_outputs, attn_weights


class MemQwen2PreTrainedModel(Qwen2PreTrainedModel):
    """Shared base for the MeM Qwen2 models, bound to ``MemQwen2Config``.

    Both the stock and the memory-augmented decoder layer classes are declared
    non-splittable, so automatic device-map placement keeps each layer whole.
    """

    _no_split_modules = ['Qwen2DecoderLayer', 'MemQwen2DecoderLayer']
    config_class = MemQwen2Config

class MemQwen2Model(MemQwen2PreTrainedModel):
    """Qwen2 decoder stack with an external memory bank on selected layers.

    Layers whose index is in ``config.memory_insert_layers`` are
    ``MemQwen2DecoderLayer`` instances that cross-attend to their slice of the
    memory; all other layers are plain ``Qwen2DecoderLayer``. While training, or
    when ``update_while_predicting`` is enabled, each memory layer's memory is
    refreshed from the current batch during the forward pass.
    """

    # Token ids treated as newlines by ``find_last_non_pad_newline_token``
    # ({'\n': 198, '\n\n': 271, '\n\n\n': 1406}; verify against the tokenizer in use).
    NEWLINE_TOKEN_IDS = (198, 271, 1406)

    def __init__(self, config: MemQwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([])
        # MeM: Insert memory to target layers.
        self.memory_insert_layers = config.memory_insert_layers
        for idx in range(config.num_hidden_layers):
            if idx in self.memory_insert_layers:
                self.layers.append(MemQwen2DecoderLayer(config, layer_idx=idx))
            else:
                self.layers.append(Qwen2DecoderLayer(config, idx))
        # MeM: Add memory
        self.update_memory = config.update_memory
        # Whether to keep updating the memory outside training (eval / generation).
        # Read in ``forward``; defaults to False when the config does not define it.
        self.update_while_predicting = getattr(config, "update_while_predicting", False)

        self.external_memory_v = ExternalMemory(config)
        self.use_last_prompt_token_as_key = config.use_last_prompt_token_as_key
        if self.use_last_prompt_token_as_key:
            self.external_memory_k = ExternalMemory(config)
        else:
            self.external_memory_k = None
        self.update_strategy = config.update_strategy
        if self.update_strategy == 'attn':
            self.external_memory_attn = AttentionExternalMemory(config)  # MeM: Save attention scores of memory tokens
        else:
            self.external_memory_attn = None

        self.memory_update_steps = config.memory_update_steps

        if config.mem_hidden_path is not None:
            raise NotImplementedError

        logger.info('MeM: Successfully built external memory.')

        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def find_last_non_pad_newline_token(self, input_ids, newline_token_ids=None):
        r"""
        Find, per sample, the index of the last token that is neither padding nor a newline.

        Args:
            input_ids: 2-D ``(batch, seq_len)`` tensor of token ids.
            newline_token_ids: ids treated as newlines; defaults to ``NEWLINE_TOKEN_IDS``.

        Returns:
            list[int]: one index per sample. A sample made only of padding/newline
            tokens falls back to its last position (``seq_len - 1``).
        """
        if newline_token_ids is None:
            newline_token_ids = self.NEWLINE_TOKEN_IDS
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[-1]
        if seq_len == 0:
            return [-1] * batch_size

        # Vectorised scan: avoids one host/device sync per token of a Python loop.
        newline_ids = torch.as_tensor(newline_token_ids, dtype=input_ids.dtype, device=input_ids.device)
        keep = ~torch.isin(input_ids, newline_ids)
        if self.padding_idx is not None:
            keep &= input_ids != self.padding_idx

        positions = torch.arange(seq_len, device=input_ids.device).expand_as(input_ids)
        last = positions.masked_fill(~keep, -1).amax(dim=-1)
        last = last.masked_fill(last < 0, seq_len - 1)
        return last.tolist()

    def find_last_prompt_token(self, labels):
        """Return, per sample, the index of the last position whose label is -100.

        Positions labelled -100 are treated as prompt tokens. A sample without any
        such position raises ``IndexError``.

        NOTE(review): if padding positions are also labelled -100 (common with right
        padding), this returns a padding position -- confirm against the collator.
        """
        last_prompt_token_indices = []
        for label_sequence in labels:
            # Find indices where labels have the padding value (-100), marking prompt tokens
            prompt_token_indices = (label_sequence == -100).nonzero(as_tuple=True)[0]
            last_prompt_token_indices.append(prompt_token_indices[-1])
        return last_prompt_token_indices

    def process_hidden_states(self, hidden_states, input_ids, labels):
        """Select the per-sample vectors to write into memory.

        Transforms ``hidden_states`` ``(bz, seq_len, hidden_dim)`` into detached
        ``(bz, 1, hidden_dim)`` value entries taken at the last non-pad/non-newline
        token, and -- when ``use_last_prompt_token_as_key`` -- key entries taken at
        the last prompt token (else ``None``).
        """
        last_non_pad = self.find_last_non_pad_newline_token(input_ids)
        mem_hidden_lst_v = [hidden_states[i, last_non_pad[i], :].detach() for i in range(len(last_non_pad))]
        mem_hidden_states_v = torch.stack(mem_hidden_lst_v).unsqueeze(1)

        if self.use_last_prompt_token_as_key:
            last_prompt = self.find_last_prompt_token(labels)
            mem_hidden_lst_k = [hidden_states[i, last_prompt[i], :].detach() for i in range(len(last_prompt))]
            mem_hidden_states_k = torch.stack(mem_hidden_lst_k).unsqueeze(1)
        else:
            mem_hidden_states_k = None

        return mem_hidden_states_v, mem_hidden_states_k

    def _update_layer_memory(self, layer_idx, layer_outputs, input_ids, labels, output_attentions):
        """Write the current batch's representations into the memory of layer ``layer_idx``.

        ``layer_outputs`` is the tuple returned by ``MemQwen2DecoderLayer.forward``:
        ``(hidden_states, [self_attn_weights], hidden_states_ln, [mem_attn_weights])``.
        """
        if input_ids is None:
            raise ValueError("Updating the external memory requires `input_ids`; `inputs_embeds` alone is not supported.")

        hidden_states_ln_idx = 2 if output_attentions else 1
        hidden_states_ln = layer_outputs[hidden_states_ln_idx]
        memory_hidden_states_v, memory_hidden_states_k = self.process_hidden_states(
            hidden_states_ln, input_ids, labels
        )

        if self.update_strategy != 'attn':
            update_memory_fifo(layer_idx, self.external_memory_v, memory_hidden_states_v.detach())
            if self.use_last_prompt_token_as_key:
                update_memory_fifo(layer_idx, self.external_memory_k, memory_hidden_states_k.detach())
            return

        if not self.external_memory_v.memory_full[layer_idx]:
            # MeM: If memory_bank not full, fill the bank
            update_memory_fifo(layer_idx, self.external_memory_v, memory_hidden_states_v.detach())
            if self.use_last_prompt_token_as_key:
                update_memory_fifo(layer_idx, self.external_memory_k, memory_hidden_states_k.detach())
            return

        mem_attn_weights = layer_outputs[hidden_states_ln_idx + 1]
        if torch.isnan(mem_attn_weights.mean(0)).any():
            # MeM: If attn_weights contain NaN, do not update.
            return

        update_attn_weights(self.external_memory_attn, layer_idx, mem_attn_weights.detach())
        if self.external_memory_attn.memory_accumulate_steps[layer_idx] >= self.memory_update_steps:
            update_memory_attn(layer_idx, self.external_memory_attn, self.external_memory_v,
                               memory_hidden_states_v.detach())
            if self.use_last_prompt_token_as_key:
                update_memory_attn(layer_idx, self.external_memory_attn, self.external_memory_k,
                                   memory_hidden_states_k.detach())

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: torch.LongTensor = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack, reading from and (optionally) updating the memory.

        ``labels`` is only used to locate the last prompt token when
        ``use_last_prompt_token_as_key`` is set. The memory is updated only when the
        model is training or ``update_while_predicting`` is True, and that update
        requires ``input_ids``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # NOTE(review): ``_update_causal_mask`` is not defined in this file; it must be
        # provided by a base class in the installed transformers version -- confirm.
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # MeM: decide once per forward whether the memory is refreshed.
        update_memory_now = self.training or self.update_while_predicting
        if update_memory_now:
            for memory in (self.external_memory_v, self.external_memory_k, self.external_memory_attn):
                if memory is not None:
                    memory.update_to_memory_modules()

        for layer_idx, decoder_layer in enumerate(self.layers[:self.config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if layer_idx in self.memory_insert_layers:
                layer_external_memory_v, layer_external_memory_k = retrieve_layer_memory(
                    layer_idx, self.external_memory_v, self.external_memory_k
                )
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    external_memory_v=layer_external_memory_v,
                    external_memory_k=layer_external_memory_k,
                    **flash_attn_kwargs,
                )
                # The memory is written only when an update is requested; in plain
                # inference the memory is read-only.
                if update_memory_now:
                    self._update_layer_memory(layer_idx, layer_outputs, input_ids, labels, output_attentions)
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class MemQwen2ForCausalLM(Qwen2ForCausalLM):
    """Causal-LM head on top of ``MemQwen2Model`` (memory-augmented Qwen2)."""

    config_class = MemQwen2Config

    def __init__(self, config: MemQwen2Config):
        # Bypass Qwen2ForCausalLM.__init__ (presumably so the stock Qwen2Model is not
        # built) and construct the memory-augmented backbone instead.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = MemQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def init_cross_attention_from_self(self):
        """Initialise every ``cross_attn_mem`` parameter from the same layer's ``self_attn``.

        Each parameter named ``...cross_attn_mem...`` is overwritten with the tensor
        found under the same name with ``cross_attn_mem`` replaced by ``self_attn``.
        Raises ``KeyError`` if such a counterpart is missing.
        """
        # Build the state dict once instead of once per parameter.
        state_dict = self.state_dict()
        with torch.no_grad():
            for name, param in self.named_parameters():
                if 'cross_attn_mem' in name:
                    self_attention_module_name = name.replace('cross_attn_mem', 'self_attn')
                    param.copy_(state_dict[self_attention_module_name])

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute LM logits (and the loss when ``labels`` is given).

        ``labels`` is also forwarded to the backbone, which may use it to pick the
        memory key token. ``logits_to_keep`` limits the positions passed to
        ``lm_head``: an int keeps the last N positions (0 keeps all), a tensor is
        used directly as an index.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # MeM: Add "labels" as an argument
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            labels=labels,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

# Import-time side effect: register the custom config/model pair with
# AutoModelForCausalLM so that a MemQwen2Config resolves to MemQwen2ForCausalLM.
AutoModelForCausalLM.register(MemQwen2Config, MemQwen2ForCausalLM)
