from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer, LlamaAttention, apply_rotary_pos_emb, repeat_kv, LlamaRMSNorm, LlamaRotaryEmbedding, LlamaMLP, LlamaForCausalLM, LlamaRotaryEmbedding, LlamaSdpaAttention, LlamaFlashAttention2
from transformers import AutoModelForCausalLM
from typing import List, Optional, Tuple, Union
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from peft import (    # LoRA Setting
    PeftModel,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
import torch
import math
import os
from torch import nn
import torch.nn.functional as F
from modeling.common import freeze, print_trainable_parameters, print_0


# Module-level logger obtained from `transformers.utils.logging` (imported at the top of this file).
logger = logging.get_logger(name=__name__)


class GLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary position embedding used by ``GLlamaModel``.

    Construction is delegated unchanged to ``LlamaRotaryEmbedding``; ``forward`` turns
    ``position_ids`` into the ``(cos, sin)`` tables consumed by ``apply_rotary_pos_emb``.
    Position ids are used exactly as given (they need not be ``0..seq-1``).
    """

    def __init__(self, dim=None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1, rope_type="default", config=None):
        # Forwarded positionally, in the parent's declared parameter order.
        super().__init__(dim, max_position_embeddings, base, device, scaling_factor, rope_type, config)

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        Args:
            x: Tensor used only for its device (autocast / dynamic update) and its dtype (output cast).
            position_ids: Rank-2 ``[batch, seq]`` positions.

        Returns:
            ``(cos, sin)``, each ``[batch, seq, 2 * n_freq]`` with ``n_freq = len(self.inv_freq)``,
            multiplied by ``self.attention_scaling``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        batch_size = position_ids.shape[0]
        # [b, n_freq, 1] x [b, 1, seq] -> one rotation angle per (frequency, position).
        freq_col = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        pos_row = position_ids[:, None, :].float()

        # Always compute the angles in float32, even under an enclosing autocast
        # (see https://github.com/huggingface/transformers/pull/29285); autocast on MPS is routed to "cpu".
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = torch.matmul(freq_col.float(), pos_row.float()).transpose(1, 2)  # [b, seq, n_freq]
            angles = torch.cat([angles, angles], dim=-1)  # [b, seq, 2 * n_freq]
            cos, sin = angles.cos(), angles.sin()

        # Some RoPE variants (e.g. yarn) post-scale cos/sin, which is equivalent to scaling attention.
        scale = self.attention_scaling
        return (cos * scale).to(dtype=x.dtype), (sin * scale).to(dtype=x.dtype)


class GLlamaAttention(LlamaAttention):
    """Eager Llama self-attention with an optional additive graph-node bias (``rel_position``).

    Pipeline: q/k/v projection (optionally tensor-parallel sliced), RoPE, cache update,
    KV-head repetition, scaled dot-product scores, additive mask, additive ``rel_position``
    bias on the leading ``nodes x nodes`` block, fp32 softmax, dropout, output projection.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int | None = None):
        super().__init__(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        rel_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute eager self-attention, optionally biased by ``rel_position``.

        Args:
            hidden_states: ``[batch, q_len, hidden]`` input.
            attention_mask: Optional additive 4D mask; its key axis is sliced to the key length.
            position_embeddings: Optional precomputed ``(cos, sin)``; computed from
                ``position_ids`` via ``self.rotary_emb`` when absent.
            rel_position: Optional ``[batch, nodes, nodes]`` additive bias. It is broadcast over
                heads and added to the top-left ``nodes x nodes`` block of the pre-softmax scores,
                independently of whether ``attention_mask`` is given.
            Remaining arguments follow transformers' ``LlamaAttention.forward``.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # The graph bias is independent of the mask: apply it even when no mask was supplied
        # (previously it was silently skipped in that case).
        if rel_position is not None:
            # rel_position [b, nodes, nodes] -> [b, 1, nodes, nodes], broadcast over heads.
            num_nodes = rel_position.shape[-1]
            attn_weights[:, :, :num_nodes, :num_nodes] += rel_position[:, None, :, :]

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Maps `config._attn_implementation` to the attention class that GLlamaDecoderLayer instantiates.
# Only "eager" resolves to the graph-aware GLlamaAttention, which consumes `rel_position`; the other
# two entries are stock transformers classes and presumably ignore it.
# NOTE(review): confirm before training with graph inputs under sdpa / flash_attention_2.
LLAMA_ATTENTION_CLASSES = dict(
    eager=GLlamaAttention,
    flash_attention_2=LlamaFlashAttention2,
    sdpa=LlamaSdpaAttention,
)

class GLlamaDecoderLayer(LlamaDecoderLayer):
    """Pre-norm Llama decoder block whose self-attention also receives ``rel_position``.

    The attention module is chosen from ``LLAMA_ATTENTION_CLASSES`` using
    ``config._attn_implementation``.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        # Skip LlamaDecoderLayer.__init__ (resolve to the next class in the MRO) so the attention
        # module is taken from this file's LLAMA_ATTENTION_CLASSES.
        super(LlamaDecoderLayer, self).__init__()
        self.hidden_size = config.hidden_size

        attention_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        rel_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply attention and MLP sub-blocks, each as ``x + f(norm(x))``.

        Returns:
            ``(hidden_states,)`` followed by the attention weights if ``output_attentions``
            and the cache object returned by attention if ``use_cache``.
        """
        # Self-attention sub-block.
        attn_out, attn_weights, present_kv = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            rel_position=rel_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_kv)
        return (hidden_states, *extras)
    

class GLlamaModel(LlamaModel):
    """Llama decoder stack that can prepend graph-node embeddings to the text sequence.

    When ``graph_embeds`` is given it is concatenated in front of the token embeddings, and a
    graph-aware 4D additive mask is built by
    ``_prepare_4d_graph_causal_attention_mask_with_cache_position``. ``rel_position`` is passed
    to every decoder layer. ``config.memory_token_nums`` is stored at construction and used
    when building the graph mask.
    """

    def __init__(self, config):
        # Skip LlamaModel.__init__ and build the submodules here so the layers are
        # GLlamaDecoderLayer and the rotary embedding is GLlamaRotaryEmbedding.
        super(LlamaModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.memory_token_nums = config.memory_token_nums

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GLlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        graph_embeds: Optional[torch.FloatTensor] = None,
        graph_attention_mask: Optional[torch.Tensor] = None,
        rel_position: Optional[torch.Tensor] = None,
        text_length: torch.LongTensor = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder over ``[graph_embeds; token embeddings]``.

        Graph-specific arguments (the others follow ``LlamaModel.forward``):
            graph_embeds: Optional rank-3 ``[batch, nodes, hidden]`` embeddings prepended to the
                text. Requires explicit ``position_ids`` covering ``nodes + text`` positions.
            graph_attention_mask: Optional ``[batch, nodes]`` node mask; entries equal to 0 mark
                padded nodes, which are masked out as keys. Defaults to all ones when
                ``graph_embeds`` is given.
            rel_position: Optional ``[batch, nodes, nodes]`` additive bias forwarded to each layer.
            text_length: Per-sample text lengths; required with ``graph_embeds`` to build the
                graph/memory-token mask.

        Returns:
            ``BaseModelOutputWithPast`` (or a tuple when ``return_dict`` is False). With graph
            input, the hidden states cover the graph nodes followed by the text tokens.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # graph info
        graph_node_nums = graph_embeds.shape[1] if graph_embeds is not None else 0
        if graph_embeds is not None and graph_attention_mask is None:
            # The mask builder derives the node count from graph_attention_mask; without it the
            # graph path used to crash. Treat every node as valid.
            graph_attention_mask = torch.ones(
                graph_embeds.shape[:2], dtype=torch.long, device=graph_embeds.device
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] + graph_node_nums, device=inputs_embeds.device
            )

        if graph_embeds is not None:
            if position_ids is None:
                raise ValueError("You must specify position_ids if use graph_embeds")
            else:
                assert position_ids.shape[-1] == (graph_embeds.shape[1] + inputs_embeds.shape[1])
        else:
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0) # [1, seq]

        if graph_embeds is None:
            causal_mask = self._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            hidden_states = inputs_embeds
        else:
            causal_mask = self._update_graph_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, graph_attention_mask, text_length
            )
            assert graph_embeds.dim() == inputs_embeds.dim()
            assert graph_embeds.dim() == 3
            hidden_states = torch.cat([graph_embeds, inputs_embeds], dim=1)

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match GLlamaDecoderLayer.forward's parameter order.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    rel_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    rel_position=rel_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_graph_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
        graph_attention_mask: torch.Tensor,
        text_length: torch.LongTensor,
    ):
        """Build the additive attention mask for a graph-prefixed sequence.

        Args:
            attention_mask: 2D text mask ``[batch, text_len]``, a ready 4D mask, or None.
            input_tensor: Text-only embeddings ``[batch, text_len, hidden]`` (used for dtype,
                device, batch size and text length).
            cache_position, past_key_values, output_attentions: as in ``LlamaModel._update_causal_mask``.
            graph_attention_mask: ``[batch, nodes]`` node mask (0 = padded node).
            text_length: Per-sample text lengths used for the memory-token layout.

        Returns:
            A 4D additive mask, or (flash_attention_2 only) the 2D mask / None.
        """
        # TODO: modify graph attention mask for flash_attention
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention cannot take an arbitrary 4D mask, so the graph layout is lost here.
            logger.warning_once(
                "flash_attention_2 cannot apply the graph-aware attention mask: graph-node visibility, "
                "memory-token masking and node padding are ignored. Use attn_implementation='eager' "
                "for graph inputs."
            )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # Unlike the text-only path, never fall back to SDPA's `is_causal` shortcut: the graph mask
        # (bidirectional node block, memory-token visibility, node padding) is never a plain causal
        # mask, so returning None here would silently replace it with pure causal attention.

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(attention_mask, torch.Tensor):
            target_length = attention_mask.shape[-1]
        else:
            # No extra slack column: the graph mask builder requires query length == key length.
            target_length = past_seen_tokens + sequence_length

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_graph_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            graph_attention_mask=graph_attention_mask,
            memory_token_nums=self.memory_token_nums,
            text_length=text_length,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_graph_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        graph_attention_mask: torch.Tensor,
        memory_token_nums: int,
        text_length: torch.LongTensor,
        **kwargs,
    ):
        """
        Create a graph-aware additive 4D mask of shape ``(batch_size, 1, nodes + query_length,
        nodes + key_value_length)`` from a 2D text mask, or return ``attention_mask`` unchanged if it
        is already 4D (assumed to be in inverted/additive form).

        Layout, with graph nodes at positions ``[0, nodes)`` followed by the text, and
        ``visible = text_length - memory_token_nums + nodes`` per sample:
          * start from a causal mask over the full sequence (shifted by ``cache_position``);
          * text rows before ``visible`` cannot attend to any graph node;
          * graph-node rows attend to every column before ``visible`` (all nodes and the
            leading text), bidirectionally;
          * rows at or after ``visible`` (presumably the trailing memory tokens) keep their causal
            view, which includes all graph nodes;
          * text columns with ``attention_mask == 0`` and node columns with
            ``graph_attention_mask == 0`` are masked for every row.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D text mask of shape `(batch_size, key_value_length)`, a 4D mask, or None.
            sequence_length (`int`):
                The text sequence length being processed (graph nodes are added internally).
            target_length (`int`):
                The text key length (graph nodes are added internally); must equal
                `sequence_length` when graph inputs are present.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens (nodes included).
            batch_size (`int`):
                Batch size.
            graph_attention_mask (`torch.Tensor` or None):
                `(batch_size, nodes)` node mask (0 = padded node). None means no graph nodes.
            memory_token_nums (`int`):
                Number of memory tokens subtracted from `text_length` to get the visibility boundary.
            text_length (`torch.LongTensor`):
                Per-sample text lengths; required when there are graph nodes.

        Raises:
            ValueError: graph nodes are present but `text_length` is None.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            return attention_mask

        # Previously this name was only bound inside the branch below, raising UnboundLocalError
        # whenever graph_attention_mask was None.
        graph_node_nums = graph_attention_mask.shape[-1] if graph_attention_mask is not None else 0
        if graph_attention_mask is not None:
            sequence_length += graph_node_nums
            target_length += graph_node_nums
            assert sequence_length == target_length, f"seq: {sequence_length}, tar: {target_length}"

        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

        if graph_node_nums > 0:
            if text_length is None:
                raise ValueError("You must specify text_length if use graph_embeds")
            # text_length may be built on CPU; bring it to the mask's device before comparing.
            visible_length = text_length.to(device).view(-1, 1) - memory_token_nums + graph_node_nums
            mask = torch.arange(sequence_length, device=device).expand(batch_size, -1) < visible_length
            col_mask = mask[:, None, None, :].expand(-1, -1, graph_node_nums, -1)  # [b, 1, nodes, seq]
            row_mask = col_mask.transpose(-1, -2)  # [b, 1, seq, nodes]
            # Order matters: first hide the nodes from rows before `visible`, then re-open the
            # node rows to every column before `visible`.
            causal_mask[:, :, :, :graph_node_nums] = causal_mask[:, :, :, :graph_node_nums].masked_fill(row_mask, min_dtype)
            causal_mask[:, :, :graph_node_nums, :] = causal_mask[:, :, :graph_node_nums, :].masked_fill(col_mask, 0)

        # text padding mask (if padding_side=right, can be ignored)
        if attention_mask is not None:
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, graph_node_nums:graph_node_nums+mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, graph_node_nums:graph_node_nums+mask_length] = causal_mask[:, :, :, graph_node_nums:graph_node_nums+mask_length].masked_fill(
                padding_mask, min_dtype
            )

        # node padding mask
        if graph_attention_mask is not None:
            graph_padding_mask = causal_mask[:, :, :, :graph_node_nums] + graph_attention_mask[:, None, None, :]
            graph_padding_mask = graph_padding_mask == 0
            causal_mask[:, :, :, :graph_node_nums] = causal_mask[:, :, :, :graph_node_nums].masked_fill(
                graph_padding_mask, min_dtype
            )

        return causal_mask
    

class GLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head (``lm_head``) on top of ``GLlamaModel``, forwarding the graph inputs."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Skip LlamaForCausalLM.__init__ so the backbone is GLlamaModel.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = GLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        graph_embeds: Optional[torch.FloatTensor] = None,
        graph_attention_mask: Optional[torch.Tensor] = None,
        rel_position: Optional[torch.Tensor] = None,
        text_length: torch.LongTensor = None,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run ``GLlamaModel`` and project its final hidden states to vocabulary logits.

        ``num_logits_to_keep`` limits the non-tensor-parallel projection to the last N
        positions (0 keeps all). When ``labels`` is given, ``self.loss_function`` computes the
        loss with any extra ``loss_kwargs``.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``([loss,] logits, *model_outputs[1:])`` when
            ``return_dict`` is False.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            graph_embeds=graph_embeds,
            graph_attention_mask=graph_attention_mask,
            rel_position=rel_position,
            text_length=text_length,
        )
        hidden_states = model_outputs[0]

        tp_degree = cfg.pretraining_tp
        if tp_degree > 1:
            # Tensor-parallel emulation: project against each vocab slice and concatenate.
            head_slices = self.lm_head.weight.split(self.vocab_size // tp_degree, dim=0)
            logits = torch.cat(
                [F.linear(hidden_states, head_slices[idx]) for idx in range(tp_degree)], dim=-1
            )
        else:
            # Only project the positions that are actually needed.
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=cfg.vocab_size, **loss_kwargs)

        if not return_dict:
            tuple_output = (logits,) + model_outputs[1:]
            return tuple_output if loss is None else (loss,) + tuple_output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
    

class Projector(nn.Module):
    """Two-layer MLP: ``linear_2(act(linear_1(x)))``.

    Args:
        model_args: Config object read for ``node_dim``, ``llm_dim`` and ``projector_hidden_act``
            (the latter is looked up in ``ACT2FN``).
        edge: If True, build a scalar head ``node_dim -> edge_hidden_dim -> 1``; otherwise
            ``node_dim -> llm_dim -> llm_dim``.
        edge_hidden_dim: Hidden width of the edge head. Defaults to 256, the previously
            hard-coded value. Ignored when ``edge`` is False.
    """

    def __init__(self, model_args, edge=False, *, edge_hidden_dim=256):
        super().__init__()

        if edge:
            in_dim, hidden_dim, out_dim = model_args.node_dim, edge_hidden_dim, 1
        else:
            in_dim, hidden_dim, out_dim = model_args.node_dim, model_args.llm_dim, model_args.llm_dim
        # Creation order (linear_1, linear_2, act) matches the original, so parameter
        # names and random-init consumption are unchanged.
        self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=True)
        self.linear_2 = nn.Linear(hidden_dim, out_dim, bias=True)
        self.act = ACT2FN[model_args.projector_hidden_act]

    def forward(self, image_features):
        """Project ``image_features`` (features of size ``node_dim`` on the last axis).

        The parameter name is kept for keyword callers; judging by ``node_dim`` it presumably
        receives node/edge features rather than image features.
        """
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        return self.linear_2(hidden_states)


class UniGTE(nn.Module):
    """Graph-text encoder (``gllama``) that compresses inputs into memory tokens,
    plus a separate causal-LM decoder that consumes those memory embeddings.

    Token-id layout used throughout this class (``vocab_size`` is the encoder's):
      * ``[0, vocab_size)``                          ordinary text tokens
      * ``[vocab_size, vocab_size_with_mem)``        memory slots (``mem_size`` of them)
      * ``vocab_size_with_mem``                      special/task token
    Every id ``>= vocab_size`` is embedded with ``memory_token_embed`` (row
    ``id - vocab_size``), which has ``mem_size + 1`` rows.
    """

    def __init__(self, model_args, training_args):
        """Load the encoder and decoder checkpoints and create the trainable adapters.

        Args:
            model_args: reads ``gt_enocder_path``, ``llm_path``, ``memory_token_nums``,
                ``spatial_pos_max`` and the fields used by ``Projector``
                (``node_dim``, ``llm_dim``, ``projector_hidden_act``).
            training_args: reads ``bf16``, ``use_peft``, ``inference``,
                ``fix_encoder`` and (when LoRA is applied) ``peft_config``.
        """
        super().__init__()
        self.model_args = model_args
        self.training_args = training_args
        # Only an explicit `bf16 is False` selects fp16; any other value selects bf16.
        dtype = torch.float16 if training_args.bf16 is False else torch.bfloat16

        self.projector = Projector(model_args).to(dtype=dtype)
        self.edge_projector = Projector(model_args, edge=True).to(dtype=dtype)
        self.gllama = GLlamaForCausalLM.from_pretrained(
            self.model_args.gt_enocder_path,
            torch_dtype=dtype,
        )
        logger.info("Encoder attention implementation: %s", self.gllama.config._attn_implementation)

        self.vocab_size = self.gllama.config.vocab_size
        self.mem_size = self.model_args.memory_token_nums
        assert self.mem_size == self.gllama.config.memory_token_nums
        self.dim = self.gllama.config.hidden_size
        self.vocab_size_with_mem = self.vocab_size + self.mem_size

        if training_args.use_peft and not training_args.inference and not training_args.fix_encoder:
            self.gllama = get_peft_model(self.gllama, training_args.peft_config)

        self.decoder = AutoModelForCausalLM.from_pretrained(
            self.model_args.llm_path,
            torch_dtype=dtype,
        )
        if not self.training_args.inference:
            self.prepare_decoder()

        # Frozen copy of the decoder's input embeddings, padded with zero rows so that
        # memory/special ids (up to vocab_size + mem_size) are valid indices; those rows
        # are always overwritten before reaching the decoder. At least 30 rows are kept
        # so the table shape matches the previously hard-coded size.
        # NOTE(review): the id offsets assume the decoder vocab size equals
        # self.vocab_size (the encoder's) — confirm when the two checkpoints differ.
        llama_embeds = self.decoder.get_input_embeddings().weight.data
        num_extra_rows = max(30, self.mem_size + 1)
        mem_token = torch.zeros(num_extra_rows, llama_embeds.shape[1], dtype=llama_embeds.dtype)
        self.raw_embeds = nn.Embedding.from_pretrained(torch.cat([llama_embeds, mem_token], dim=0), freeze=True)

        # One extra id after the memory slots for the special/task token.
        self.gllama.resize_token_embeddings(self.vocab_size_with_mem + 1)
        self.memory_token_embed = nn.Embedding(self.mem_size + 1, self.dim, padding_idx=None)

        self.graph_position_theta = nn.Parameter(torch.tensor([2.0]), requires_grad=True)
        self.rel_position_encoder = nn.Embedding(model_args.spatial_pos_max, 1, padding_idx=0)

    def prepare_decoder(self):
        """Freeze the decoder, put it in eval mode and enable non-reentrant gradient checkpointing."""
        freeze(self.decoder)
        self.decoder.eval()
        self.decoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    def prepare_encoder_input(
        self,
        graph_embeds: torch.FloatTensor,
        graph_attention_mask: torch.Tensor,
        rel_position: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_type: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        text_length: torch.LongTensor,
        output_attentions: bool = False,
        output_hidden_states: bool = True,
    ):
        """Build the keyword arguments for ``self.gllama``.

        In ``pure_icae`` mode the graph inputs are ignored and the graph-related
        entries (``graph_embeds``, ``graph_attention_mask``, ``rel_position``,
        ``position_ids``) are ``None``.

        Returns:
            ``(model_inputs, memory_mask)`` where ``memory_mask`` is
            ``input_ids >= vocab_size``: the positions embedded via ``memory_token_embed``.
        """
        batch_size = input_ids.shape[0]

        memory_mask = input_ids >= self.vocab_size
        autoencoder_input_embedding = self.gllama.get_input_embeddings()(input_ids)
        autoencoder_input_embedding[memory_mask] = self.memory_token_embed(
            input_ids[memory_mask] - self.vocab_size
        ).to(autoencoder_input_embedding)

        position_ids = torch.arange(
            0, input_ids.shape[1], device=autoencoder_input_embedding.device
        ).unsqueeze(0).expand(batch_size, -1)

        if not self.training_args.pure_icae:
            num_nodes = graph_embeds.shape[1]
            # Every graph node gets the same position, (text_length - mem_size) plus a
            # learnable offset; the node positions are prepended to the text positions.
            graph_position_ids = torch.zeros([batch_size, num_nodes], device=position_ids.device) \
                + (text_length - self.mem_size).reshape(-1, 1)
            position_ids = torch.cat([graph_position_ids, position_ids], dim=-1)
            position_ids[:, :num_nodes] += self.graph_position_theta

            rel_position = self.rel_position_encoder(rel_position).squeeze(dim=-1)
            graph_embeds = self.projector(graph_embeds)

            # presumably [b, num_edge_types + 1], column 0 being the special score — confirm
            edge_scores = self.edge_projector(edge_attr).squeeze(-1)
            edge_bias = UniGTE.compute_average_scores_batch_with_prepended_special_score(edge_type, edge_scores)
            rel_position = rel_position + edge_bias
        else:
            # Text-only mode: the graph inputs, and the edge bias derived from them, are dropped.
            graph_embeds = None
            graph_attention_mask = None
            rel_position = None
            position_ids = None

        model_inputs = {
            'graph_embeds': graph_embeds,                  # [b, node_length, hidden_dim] or None
            'graph_attention_mask': graph_attention_mask,  # [b, node_length] or None
            'rel_position': rel_position,                  # [b, node_length, node_length] or None
            'position_ids': position_ids,                  # [b, node_length + seq_len] or None
            'inputs_embeds': autoencoder_input_embedding,  # [b, seq_len, hidden_dim]
            'attention_mask': attention_mask,              # [b, seq_len]
            'text_length': text_length,                    # [b]
            'output_attentions': output_attentions,
            'output_hidden_states': output_hidden_states,
            'return_dict': True,
        }

        return model_inputs, memory_mask

    def _encode_memory(
        self,
        graph_embeds,
        graph_attention_mask,
        rel_position,
        edge_attr,
        edge_type,
        input_ids,
        attention_mask,
        text_length,
    ):
        """Run the encoder; return final-layer hidden states at the memory positions, shape [b * mem_size, hidden]."""
        model_inputs, memory_mask = self.prepare_encoder_input(
            graph_embeds=graph_embeds,
            graph_attention_mask=graph_attention_mask,
            rel_position=rel_position,
            edge_attr=edge_attr,
            edge_type=edge_type,
            input_ids=input_ids,
            attention_mask=attention_mask,
            text_length=text_length,
        )
        graph_node_num = model_inputs['graph_embeds'].shape[1] if model_inputs['graph_embeds'] is not None else 0
        assert torch.all(memory_mask.sum(dim=1) == self.mem_size), f"Each row in the mask must sum to {self.mem_size}"
        compress_outputs = self.gllama(**model_inputs).hidden_states[-1]
        # Encoder output is [graph nodes | text]; drop the node prefix so it aligns with memory_mask.
        return compress_outputs[:, graph_node_num:, :][memory_mask]

    def _build_decoder_embeds(self, prompt_answer_ids, memory_embedding):
        """Embed ids with the decoder's table, then splice in memory and special-token embeddings."""
        prompt_answer_embs = self.raw_embeds(prompt_answer_ids)

        decoder_mem_flag = (prompt_answer_ids >= self.vocab_size) & (prompt_answer_ids < self.vocab_size_with_mem)
        assert torch.all(decoder_mem_flag.sum(dim=1) == self.mem_size), f"Each row in the mask must sum to {self.mem_size}"
        prompt_answer_embs[decoder_mem_flag] = memory_embedding

        special_prompt = prompt_answer_ids >= self.vocab_size_with_mem
        prompt_answer_embs[special_prompt] = self.memory_token_embed(
            prompt_answer_ids[special_prompt] - self.vocab_size
        ).to(prompt_answer_embs)
        assert torch.all(special_prompt.sum(dim=1) == 1) or torch.all(special_prompt.sum(dim=1) == 2), \
            "Each row in the mask must sum to 1 or 2"
        return prompt_answer_embs

    def forward(
        self,
        graph_embeds: torch.FloatTensor,
        graph_attention_mask: torch.Tensor,
        rel_position: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_type: torch.Tensor,
        input_ids: torch.Tensor,  # text_tokens + mem_tokens + padding_tokens
        prompt_answer_ids: torch.LongTensor = None,  # mem_tokens + prompt_tokens + answer_tokens + padding_tokens
        labels: Optional[torch.LongTensor] = None,
        attention_mask: torch.Tensor = None,
        text_length: torch.LongTensor = None,
    ):
        """Compress the inputs into memory embeddings and run the decoder on the prompt/answer.

        Returns:
            dict with the decoder's ``"loss"`` and ``"logits"``.
        """
        memory_embedding = self._encode_memory(
            graph_embeds, graph_attention_mask, rel_position, edge_attr, edge_type,
            input_ids, attention_mask, text_length,
        )
        prompt_answer_embs = self._build_decoder_embeds(prompt_answer_ids, memory_embedding)

        decoder_outputs = self.decoder(inputs_embeds=prompt_answer_embs, labels=labels)
        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}

    @torch.no_grad()
    def generate(
        self,
        graph_embeds: torch.FloatTensor,
        graph_attention_mask: torch.Tensor,
        rel_position: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_type: torch.Tensor,
        input_ids: torch.Tensor,  # text_tokens + mem_tokens + padding_tokens
        prompt_answer_ids: torch.LongTensor = None,  # mem_tokens + prompt_tokens + answer_tokens + padding_tokens
        prompt_attention_mask: torch.Tensor = None,
        labels: Optional[torch.LongTensor] = None,  # unused; kept for interface symmetry with forward
        attention_mask: torch.Tensor = None,
        text_length: torch.LongTensor = None,
    ):
        """Compress the inputs and let the decoder generate a continuation of the prompt.

        The decoder inputs are built exactly as in ``forward`` (decoder embedding table),
        so inference sees the same input distribution as training.

        Returns:
            Whatever ``self.decoder.generate`` returns (generated token ids).
        """
        memory_embedding = self._encode_memory(
            graph_embeds, graph_attention_mask, rel_position, edge_attr, edge_type,
            input_ids, attention_mask, text_length,
        )
        prompt_answer_embs = self._build_decoder_embeds(prompt_answer_ids, memory_embedding)

        return self.decoder.generate(
            inputs_embeds=prompt_answer_embs,
            attention_mask=prompt_attention_mask,
            # NOTE(review): without do_sample=True generation is greedy and temperature is ignored.
            temperature=0.7,
            max_new_tokens=80,
        )

    def save_model(
        self,
        save_directory,
        state_dict,
    ):
        """Save the trainable parts of the model under ``save_directory``.

        Writes the encoder's PEFT adapter to ``unigte/`` (only when ``use_peft``), and
        one ``<module>.pt`` per small trainable module holding the matching entries of
        ``state_dict`` with the module prefix stripped. ``graph_position_theta.pt``
        keeps its full key.
        """
        os.makedirs(save_directory, exist_ok=True)

        if self.training_args.use_peft:
            self.gllama.save_pretrained(f'{save_directory}/unigte')

        for prefix in ('projector', 'edge_projector', 'memory_token_embed', 'rel_position_encoder'):
            sub_state = {k.replace(f'{prefix}.', ''): v for k, v in state_dict.items() if k.startswith(prefix)}
            torch.save(sub_state, f'{save_directory}/{prefix}.pt')

        theta_state = {k: v for k, v in state_dict.items() if k.startswith('graph_position_theta')}
        torch.save(theta_state, f'{save_directory}/graph_position_theta.pt')

    @staticmethod
    def load_model(
            model_args,
            training_args,
            saved_model_path,
    ):
        """Rebuild a ``UniGTE`` and load the weights written by ``save_model``.

        When ``use_peft`` is set, the adapter is loaded and merged into the encoder.
        """
        model = UniGTE(model_args, training_args)

        if training_args.use_peft:
            model.gllama = PeftModel.from_pretrained(model.gllama, os.path.join(saved_model_path, 'unigte'))
            model.gllama = model.gllama.merge_and_unload()  # fold LoRA weights in for faster inference

        def _load(name):
            # Load onto CPU so GPU-saved files also open on CPU-only hosts; the copy into
            # each parameter below places the values on that parameter's own device.
            return torch.load(os.path.join(saved_model_path, f'{name}.pt'), map_location='cpu')

        model.projector.load_state_dict(_load('projector'))
        model.edge_projector.load_state_dict(_load('edge_projector'))
        model.memory_token_embed.load_state_dict(_load('memory_token_embed'))
        model.rel_position_encoder.load_state_dict(_load('rel_position_encoder'))
        with torch.no_grad():
            model.graph_position_theta.copy_(_load('graph_position_theta')['graph_position_theta'])

        return model

    @staticmethod
    def compute_average_scores_batch_with_prepended_special_score(
        path_tensor, feature_scores, epsilon=1e-8, output_dtype=torch.bfloat16
    ):
        """Average, for every node pair, the edge-type scores along its shortest path.

        Args:
            path_tensor: LongTensor ``[batch_size, n_node, n_node, max_dist]`` holding the
                edge type of each hop on the shortest path between two nodes. ``-1`` marks
                padding hops, diagonal entries and pairs without a shortest path.
            feature_scores: Tensor ``[batch_size, num_feature_types + 1]``. Column 0 is
                the special score; column ``t + 1`` is the score of edge type ``t``.
            epsilon: added to the hop count in the division (that branch is only selected
                where the count is positive).
            output_dtype: dtype of the result (default ``torch.bfloat16``).

        Returns:
            Tensor ``[batch_size, n_node, n_node]``:
            * pairs with at least one valid hop: the mean score of their valid hops;
            * pairs with none: the special score;
            * diagonal entries: 0.
        """
        batch_size, n_node = path_tensor.shape[:2]
        device = path_tensor.device

        # -1 -> column 0 (special score); edge type t -> column t + 1.
        # Assumes all types lie in [-1, num_feature_types - 1]; larger values index out of range.
        shifted_path = path_tensor + 1

        # Per-batch lookup: the [b, 1, 1, 1] batch index broadcasts against shifted_path.
        batch_indices = torch.arange(batch_size, device=device).view(batch_size, 1, 1, 1)
        scores = feature_scores[batch_indices, shifted_path]  # [b, n, n, max_dist]

        mask_valid = (path_tensor != -1).float()
        sum_scores = (scores * mask_valid).sum(dim=-1)  # [b, n, n]
        counts = mask_valid.sum(dim=-1)  # [b, n, n]

        special_scores = feature_scores[:, 0].view(batch_size, 1, 1)
        average = torch.where(counts > 0, sum_scores / (counts + epsilon), special_scores)

        # [n, n] diagonal mask broadcasts over the batch dimension.
        diagonal_mask = torch.eye(n_node, device=device).bool()
        average = average.masked_fill(diagonal_mask, 0.0)
        return average.to(dtype=output_dtype)

    
