import torch
import torch.nn as nn
import torch.nn.functional as F
from third_party.phi_4.modeling_phi4mm import Phi4MMForCausalLM, Phi4MMImageEmbedding, _IMAGE_SPECIAL_TOKEN_ID
from transformers.modeling_outputs import CausalLMOutputWithPast
from processor_wrapper import _GRAPH_SPECIAL_TOKEN_ID, _GRAPH_SPECIAL_TOKEN
from typing import *
import transformers
import math
from models.utils import CausalLMOutputWithPastExtended

# _IMAGE_SPECIAL_TOKEN_ID and _GRAPH_SPECIAL_TOKEN_ID are bound by the imports above;
# re-assigning them to themselves (as this module used to) had no effect.

# NOTE(review): presumably the id of the Phi-4 user-turn marker token; not referenced
# in this module — confirm against the tokenizer before relying on it.
_USER_SPECIAL_TOKEN_ID = 200021

class MLP(nn.Module):
    """
    Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Used in this module as the position-wise FFN of the attention / message-passing layers.

    Args:
        input_dim (int): Size of the last dimension of the input.
        hidden_dim (int): Width of the intermediate layer.
        output_dim (int): Size of the last dimension of the output.
        dropout (float, optional): Dropout probability applied after the GELU. Defaults to 0.0.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.0):
        super().__init__()
        # All layers live in one nn.Sequential named `net`, so the state_dict keys are
        # net.0.* (first Linear) and net.3.* (second Linear).
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block along the last dimension of ``x``."""
        out = self.net(x)
        return out
    
class AHeadSetAttention(nn.Module):
    """
    Predicts low-rank logits of a directed adjacency matrix from a set of image embeddings.

    ``d_max`` learnable node queries cross-attend, through ``num_layers`` post-LN blocks
    (attention -> residual + LayerNorm -> FFN -> residual + LayerNorm), over the image
    embeddings plus one always-valid learnable "null" key/value. The null entry means a
    sample with zero valid images never has every key masked. The refined node states
    ``q`` are projected to factors ``L`` and ``R``, and ``A_logits = L @ R^T + bias``.

    Args:
        d_model (int): Dimension of the input image embeddings.
        d_max (int): Maximum number of nodes in the target causal graph.
        rank_r (int): Rank used for the low-rank factorization of the adjacency matrix.
        num_heads (int, optional): Number of attention heads. Defaults to 8.
        num_layers (int, optional): Number of attention layers. Defaults to 2.
        dropout (float, optional): Drop probability within attention layers. Defaults to 0.1.
        use_positional_encoding (bool, optional): Whether to add learned positional encodings
            (by image index) to the image embeddings. Defaults to False.
        max_images (int, optional): Size of the positional-encoding table; only used if
            use_positional_encoding is True. Defaults to 8.
    """

    def __init__(
        self,
        d_model: int,
        d_max: int,
        rank_r: int,
        num_heads: int = 8,
        num_layers: int = 2,
        dropout: float = 0.1,
        use_positional_encoding: bool = False,
        max_images: int = 8,  # only used if use_positional_encoding is True
    ):
        super(AHeadSetAttention, self).__init__()
        self.d_max = d_max
        self.rank_r = rank_r
        self.use_positional_encoding = use_positional_encoding

        # NOTE: parameter creation order is kept stable so seeded initialization and
        # state_dict keys do not change.
        self.node_query = nn.Parameter(torch.randn(d_max, d_model) / math.sqrt(d_model))
        self.null_kv = nn.Parameter(torch.zeros(1, 1, d_model))

        self.pos_emb = nn.Embedding(max_images, d_model) if use_positional_encoding else None

        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            attn = nn.MultiheadAttention(
                embed_dim=d_model,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=True,
            )
            ln1 = nn.LayerNorm(d_model)
            ff = MLP(d_model, 4 * d_model, d_model, dropout=dropout)
            ln2 = nn.LayerNorm(d_model)
            self.blocks.append(nn.ModuleDict({
                'attn': attn,
                'ln1': ln1,
                'ff': ff,
                'ln2': ln2,
            }))

        self.to_L = nn.Linear(d_model, rank_r, bias=False)
        self.to_R = nn.Linear(d_model, rank_r, bias=False)
        # -0.693 ~= -ln(2): with L @ R^T ~= 0, sigmoid(A_logits) starts near 1/3.
        self.bias = nn.Parameter(torch.tensor(-0.693))

    def forward(
        self,
        img_embeds: torch.Tensor,  # (B, N_images, d_model), N_images can be 0
        img_mask: Optional[torch.Tensor] = None,  # (B, N_images), True = valid image
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            img_embeds: (B, N_images, d_model) image embeddings; N_images may be 0.
            img_mask: optional (B, N_images) mask, truthy = valid image. Defaults to all valid.

        Returns:
            L: (B, d_max, rank_r) left factor.
            R: (B, d_max, rank_r) right factor.
            A_logits: (B, d_max, d_max) adjacency logits ``L @ R^T + bias``.
            q: (B, d_max, d_model) refined node states.

        Raises:
            ValueError: if positional encoding is enabled and N_images exceeds ``max_images``.
        """
        B = img_embeds.size(0)
        n_images = img_embeds.size(1)
        device = img_embeds.device
        dtype = img_embeds.dtype

        if img_mask is None:
            img_mask = torch.ones(B, n_images, device=device, dtype=torch.bool)
        else:
            # Move to the embeddings' device so the concatenation below cannot fail.
            img_mask = img_mask.to(device=device, dtype=torch.bool)

        if self.pos_emb is not None and n_images > 0:
            max_images = self.pos_emb.num_embeddings
            if n_images > max_images:
                # Without this check the embedding lookup fails with an opaque IndexError
                # (or a device-side assert on CUDA).
                raise ValueError(
                    f"Got {n_images} images per sample, but positional encoding supports at most {max_images}."
                )
            pos = torch.arange(n_images, device=device)
            img_embeds = img_embeds + self.pos_emb(pos).unsqueeze(0).to(device=device, dtype=dtype)

        # Prepend the always-valid null key/value so every query has at least one key.
        null = self.null_kv.expand(B, 1, -1).to(device=device, dtype=dtype)  # (B, 1, d_model)
        img_embeds = torch.cat([null, img_embeds], dim=1)  # (B, 1 + N_images, d_model)

        img_mask = torch.cat([torch.ones(B, 1, device=device, dtype=torch.bool), img_mask], dim=1)  # (B, 1 + N_images)

        key_padding_mask = ~img_mask  # (B, 1 + N_images), True = ignore this key

        q = self.node_query.unsqueeze(0).expand(B, -1, -1).to(device=device, dtype=dtype)  # (B, d_max, d_model)

        for block in self.blocks:
            attn_out, _ = block['attn'](
                query=q,
                key=img_embeds,
                value=img_embeds,
                key_padding_mask=key_padding_mask,
                need_weights=False,
            )  # (B, d_max, d_model)
            q = block['ln1'](q + attn_out)
            ff_out = block['ff'](q)  # (B, d_max, d_model)
            q = block['ln2'](q + ff_out)

        L = self.to_L(q)  # (B, d_max, r)
        R = self.to_R(q)  # (B, d_max, r)
        A_logits = torch.bmm(
            L,
            R.transpose(1, 2)
        ) + self.bias  # (B, d_max, d_max)

        return L, R, A_logits, q


class CausalGraphMessageGateLayers(nn.Module):
    """
    One round of directed message passing followed by a feed-forward update.

    Messages flow along both edge directions, each with its own projection:
        m_out = P_out @ W_out(h)
        m_in  = P_in  @ W_in(h)
        h <- LN1(h + Dropout(m_out + m_in))
        h <- LN2(h + Dropout(FFN(h)))

    ``P_out`` and ``P_in`` are both supplied by the caller. This layer does not derive one
    from the other.

    Args:
        d_model (int): Node feature size.
        dropout (float, optional): Dropout on both residual branches and inside the FFN.
            Defaults to 0.1.
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        # Per-direction projections of the node states (no bias). Creation order is kept
        # stable so seeded initialization and state_dict keys do not change.
        self.out_lin = nn.Linear(d_model, d_model, bias=False)
        self.in_lin = nn.Linear(d_model, d_model, bias=False)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = MLP(d_model, 4 * d_model, d_model, dropout=dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        h: torch.Tensor,  # (B, d_max, d_model)
        P_out: torch.Tensor,  # (B, d_max, d_max)
        P_in: torch.Tensor,  # (B, d_max, d_max)
    ) -> torch.Tensor:
        """Return updated node states of shape (B, d_max, d_model)."""
        outgoing = torch.bmm(P_out, self.out_lin(h))  # (B, d_max, d_model)
        incoming = torch.bmm(P_in, self.in_lin(h))  # (B, d_max, d_model)

        mixed = self.ln1(h + self.dropout(outgoing + incoming))
        refined = self.ff(mixed)
        return self.ln2(mixed + self.dropout(refined))
    
class GraphEmbEncoder(nn.Module):
    r"""
    Encode a soft directed graph into node tokens and global graph tokens.

    Steps:
        P = sigmoid(A) \odot (m m^T), optional self-loops; row-normalised into
            P_out (rows of P) and P_in (rows of P^T)
        h <- CausalGraphMessageGateLayers(h, P_out, P_in), repeated ``mp_layers`` times;
            invalid nodes are reset to a learnable null embedding after every layer
        g_tokens = CrossAttn(Q=learnable global queries, K=h, V=h)   (global <- nodes)
        n_tokens = CrossAttn(Q=h, K=g, V=g)                          (nodes <- global)

    Args:
        d_model (int): Token dimension.
        d_max (int): Number of node slots.
        num_graph_tokens (int): Total graph tokens; must be >= d_max. The
            ``num_graph_tokens - d_max`` non-node tokens are global tokens.
        mp_layers (int, optional): Number of message-passing layers. Defaults to 2.
        attn_heads (int, optional): Heads of both cross-attentions. Defaults to 8.
        dropout (float, optional): Dropout probability. Defaults to 0.1.
        add_self_loops (bool, optional): Force the diagonal of P to 1. Defaults to True.

    Input:
        A_logits: (B, d_max, d_max) - logits of adjacency matrix for directional graph
        q: (B, d_max, d_model) - initial node embeddings from set attention
        node_mask: (B, d_max) - mask for valid nodes (None = all valid)

    Output:
        node_tokens: (B, d_max, d_model) - updated node embeddings after message passing
        global_tokens: (B, num_graph_tokens - d_max, d_model) - global graph tokens
    """

    def __init__(
        self,
        d_model: int,
        d_max: int,
        num_graph_tokens: int,
        mp_layers: int = 2,
        attn_heads: int = 8,
        dropout: float = 0.1,
        add_self_loops: bool = True,  # whether to add self-loops to adjacency matrix
    ):
        super(GraphEmbEncoder, self).__init__()
        assert num_graph_tokens >= d_max, f"num_graph_tokens ({num_graph_tokens}) must be >= d_max ({d_max})"
        self.d_model = d_model
        self.d_max = d_max
        self.num_global_tokens = num_graph_tokens - d_max
        self.add_self_loops = add_self_loops

        # NOTE: module/parameter creation order is kept stable so seeded initialization and
        # state_dict keys do not change.
        self.mp_blocks = nn.ModuleList([
            CausalGraphMessageGateLayers(d_model=d_model, dropout=dropout)
            for _ in range(mp_layers)
        ])

        # global <- nodes
        self.global_queries = nn.Parameter(torch.randn(self.num_global_tokens, d_model) / math.sqrt(d_model))
        self.g_cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=attn_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.g_ln1 = nn.LayerNorm(d_model)
        self.g_ff = MLP(d_model, 4 * d_model, d_model, dropout=dropout)
        self.g_ln2 = nn.LayerNorm(d_model)

        # nodes <- global
        self.n_cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=attn_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.n_ln1 = nn.LayerNorm(d_model)
        self.n_ff = MLP(d_model, 4 * d_model, d_model, dropout=dropout)
        self.n_ln2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.node_type = nn.Parameter(torch.zeros(1, 1, d_model))  # type embedding added to node tokens
        self.global_type = nn.Parameter(torch.zeros(1, 1, d_model))  # type embedding added to global tokens

        self.null_node = nn.Parameter(torch.zeros(1, 1, d_model))  # state used for invalid node slots

    def _build_transition(
        self,
        A_logits: torch.Tensor,  # (B, d_max, d_max)
        node_mask: torch.Tensor,  # (B, d_max)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build row-normalised transition matrices from adjacency logits.

        ``P = sigmoid(A)`` restricted to edges between valid nodes (computed in float32),
        optionally with the diagonal forced to 1. ``P_out`` row-normalises P and ``P_in``
        row-normalises P^T. Both are returned in ``A_logits``' dtype, each of shape
        (B, d_max, d_max).
        """
        B, D, _ = A_logits.shape
        device = A_logits.device

        A = A_logits.float()
        P = torch.sigmoid(A)  # (B, d_max, d_max)

        m = node_mask.float()  # (B, d_max)
        P = P * (m.unsqueeze(-1) * m.unsqueeze(-2))  # keep only edges between valid nodes

        if self.add_self_loops:
            eye = torch.eye(D, device=device).unsqueeze(0)  # (1,D,D)
            P = P * (1.0 - eye) + eye

        eps = 1e-6  # keeps all-zero rows (e.g. invalid nodes without self-loops) at 0 instead of NaN
        P_out = P / (P.sum(dim=-1, keepdim=True) + eps)  # row-normalize

        P_in = P.transpose(1, 2)
        P_in = P_in / (P_in.sum(dim=-1, keepdim=True) + eps)  # row-normalize

        P_out = P_out.to(dtype=A_logits.dtype, device=device)  # (B, d_max, d_max)
        P_in = P_in.to(dtype=A_logits.dtype, device=device)  # (B, d_max, d_max)

        return P_out, P_in

    def forward(
        self,
        A_logits: torch.Tensor,  # (B, d_max, d_max)
        q: torch.Tensor,  # (B, d_max, d_model)
        node_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run message passing and the two cross-attentions.

        Returns:
            node_tokens: (B, d_max, d_model). Invalid slots are set to ``null_node + node_type``.
            global_tokens: (B, num_graph_tokens - d_max, d_model).
        """
        B, D, H = q.shape
        assert D == self.d_max, f"Expected d_max={self.d_max}, got {D}"
        assert H == self.d_model, f"Expected d_model={self.d_model}, got {H}"

        if node_mask is None:
            node_mask = torch.ones(B, D, device=q.device, dtype=torch.bool)
        else:
            node_mask = node_mask.bool()

        # invalid node -> null node
        null = self.null_node.expand(B, D, H).to(device=q.device, dtype=q.dtype)  # (B, d_max, d_model)
        h = torch.where(node_mask.unsqueeze(-1), q, null)

        # Msg Passing
        P_out, P_in = self._build_transition(A_logits, node_mask)  # (B, d_max, d_max) each
        for mp in self.mp_blocks:
            h = mp(h, P_out, P_in)  # (B, d_max, d_model)
            h = torch.where(node_mask.unsqueeze(-1), h, null)

        # node tokens
        node_tokens = h + self.node_type.to(device=h.device, dtype=h.dtype)  # (B, d_max, d_model)

        # global <- nodes
        g = self.global_queries.unsqueeze(0).expand(B, -1, -1).to(device=h.device, dtype=h.dtype)  # (B, num_global_tokens, d_model)
        global_tokens = g + self.global_type.to(device=h.device, dtype=h.dtype)

        key_padding_mask = ~node_mask  # (B, d_max), True = ignore this key
        # A sample with no valid node would have every key masked, which makes the attention
        # softmax NaN. Such samples attend over their (null) node tokens instead.
        has_valid_node = node_mask.any(dim=1, keepdim=True)  # (B, 1)
        key_padding_mask = key_padding_mask & has_valid_node

        # NOTE(review): the query is `g` (without global_type) while the residual uses
        # `global_tokens`; kept as-is to preserve trained behaviour — confirm it is intended.
        g_attn_out, _ = self.g_cross_attn(
            query=g,
            key=node_tokens, value=node_tokens, key_padding_mask=key_padding_mask,
            need_weights=False,
        )  # (B, num_global_tokens, d_model)
        global_tokens = self.g_ln1(global_tokens + g_attn_out)
        global_tokens = self.g_ln2(global_tokens + self.dropout(self.g_ff(global_tokens)))

        # nodes <- global, update node tokens with global info
        n_attn_out, _ = self.n_cross_attn(
            query=node_tokens,
            key=global_tokens, value=global_tokens, key_padding_mask=None,
            need_weights=False,
        )  # (B, d_max, d_model)
        node_tokens = self.n_ln1(node_tokens + n_attn_out)
        node_tokens = self.n_ln2(node_tokens + self.dropout(self.n_ff(node_tokens)))

        # Reset invalid nodes to null node
        null_node_tokens = (self.null_node + self.node_type).expand(B, D, H).to(device=h.device, dtype=h.dtype)
        node_tokens = torch.where(node_mask.unsqueeze(-1), node_tokens, null_node_tokens)

        return node_tokens, global_tokens


class Phi4CausalQA(Phi4MMForCausalLM):
    """
    Phi-4 multimodal wrapper for causal question answering with an inferred causal graph.

    At prefill (empty KV cache):
      1. Each image is pooled into one embedding.
      2. ``AHeadSetAttention`` predicts adjacency logits and node states.
      3. ``GraphEmbEncoder`` turns them into ``num_graph_tokens`` graph tokens (``d_max``
         node tokens followed by the global tokens).
      4. Those tokens overwrite the graph placeholder tokens in the embedded prompt before
         the base language model runs.

    Decode steps delegate to the base model.

    Args:
        config: Model configuration object (``config.hidden_size`` sets ``d_model``).
        d_max (int, optional): Maximum number of nodes in the causal graph. Defaults to 10.
        rank_r (int, optional): Rank for low-rank factorization of the adjacency matrix. Defaults to 8.
        num_graph_tokens (int, optional): Number of graph embedding tokens to insert; must be
            >= d_max. Defaults to 16.
        use_positional_encoding (bool, optional): Whether to add positional encodings to the
            image embeddings (for sequential images). Defaults to True.
    """

    def __init__(
        self,
        config,
        d_max: int = 10,
        rank_r: int = 8,
        num_graph_tokens: int = 16,
        use_positional_encoding: bool = True,
    ):
        super().__init__(config)

        # Names of forward() inputs; not read inside this class (presumably used by
        # external data/trainer code — confirm before changing).
        self.signature = ['input_ids', 'input_image_embeds', 'image_sizes', 'image_attention_mask', 'input_audio_embeds', 'audio_embed_sizes', 'audio_attention_mask', 'attention_mask', 'input_mode', 'num_images_per_sample', "node_mask"]

        self.d_model = int(config.hidden_size)
        self.d_max = int(d_max)
        self.rank_r = int(rank_r)
        self.num_graph_tokens = int(num_graph_tokens)
        self.graph_token_id = int(_GRAPH_SPECIAL_TOKEN_ID) if _GRAPH_SPECIAL_TOKEN_ID is not None else None

        self.use_positional_encoding = bool(use_positional_encoding)

        # AHead + GraphEmb
        self.a_head = AHeadSetAttention(
            d_model=self.d_model,
            d_max=self.d_max,
            rank_r=self.rank_r,
            use_positional_encoding=self.use_positional_encoding,
        )
        self.graph_emb = GraphEmbEncoder(
            d_model=self.d_model,
            d_max=self.d_max,
            num_graph_tokens=self.num_graph_tokens,
        )

        # Image embedding summary projection (vision feature dim -> d_model)
        img_embed: Phi4MMImageEmbedding = self.model.embed_tokens_extend.image_embed
        image_dim_out = int(getattr(img_embed, "image_dim_out", self.d_model))
        self.img_summary_proj = nn.Linear(image_dim_out, self.d_model, bias=False)

        # A_logits from the last eval-mode prefill, returned again by decode steps
        self._cached_A_logits: Optional[torch.Tensor] = None

    def _get_past_len(self, past_key_values: Optional[transformers.cache_utils.DynamicCache]) -> int:
        """Return the number of cached tokens (0 when there is no cache)."""
        if past_key_values is None:
            return 0
        get_len = getattr(past_key_values, "get_seq_length", None)
        if callable(get_len):
            return int(get_len())
        if isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
            k0 = past_key_values[0][0]  # legacy cache: (B, heads, T, dim)
            return int(k0.shape[-2])
        return 0

    def _is_prefill(self, past_key_values: Optional[transformers.cache_utils.DynamicCache]) -> bool:
        """True when nothing is cached yet, i.e. this call processes the full prompt."""
        return self._get_past_len(past_key_values) == 0

    def _mask_graph_labels(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Return a copy of ``labels`` with graph placeholder positions set to -100 (the usual ignore index)."""
        if labels is None or self.graph_token_id is None:
            return labels
        labels = labels.clone()
        labels[input_ids == self.graph_token_id] = -100
        return labels

    def _replace_graph_tokens(
        self,
        hidden_states: torch.Tensor,  # (B,S,D)
        input_ids: torch.Tensor,      # (B,S)
        graph_tokens: torch.Tensor,   # (B,M,D)
    ) -> torch.Tensor:
        """
        Overwrite the graph placeholder positions of ``hidden_states`` with ``graph_tokens``.

        Row ``b`` of ``input_ids`` must contain exactly ``M`` placeholders. They are filled
        left to right with ``graph_tokens[b, 0..M-1]``.

        Raises:
            ValueError: if the graph token id is undefined or ``num_graph_tokens <= 0``, if no
                placeholder exists, if the total or per-row placeholder count is wrong, or if
                the batch sizes of ``input_ids`` and ``graph_tokens`` differ.
        """
        if self.graph_token_id is None or self.num_graph_tokens <= 0:
            raise ValueError("Graph token ID is not defined or num_graph_tokens <= 0.")

        is_graph = input_ids == self.graph_token_id  # (B, S)
        if not is_graph.any():
            raise ValueError("No graph placeholder tokens found in input_ids.")

        with torch.no_grad():
            positions = torch.nonzero(is_graph, as_tuple=True)  # row-major order
        num_pos = positions[0].numel()

        B, M, D = graph_tokens.shape
        expected = B * M
        if num_pos != expected:
            raise ValueError(
                f"Graph placeholder token count mismatch: found {num_pos}, expected {expected} "
                f"(B={B}, M={M})"
            )
        # The total alone is not enough: the flat row-major fill below would silently put
        # one sample's graph tokens into another row if the per-row counts (or batch sizes) differ.
        if input_ids.size(0) != B:
            raise ValueError(
                f"Batch size mismatch: input_ids has {input_ids.size(0)} rows, graph_tokens has {B}."
            )
        per_row = is_graph.sum(dim=1)
        if bool((per_row != M).any()):
            raise ValueError(
                f"Each sample must contain exactly {M} graph placeholder tokens; "
                f"found per-sample counts {per_row.tolist()}"
            )

        merged = graph_tokens.reshape(B * M, D).to(device=hidden_states.device, dtype=hidden_states.dtype)

        # Same as phi4mm image/audio embedding replacement, but for graph tokens
        # Avoid bf16 index_put issues by disabling autocast
        with torch.autocast(device_type=hidden_states.device.type, enabled=False):
            out = hidden_states.index_put(indices=positions, values=merged, accumulate=False)
        return out

    def _extract_single_image_embedding(
        self,
        input_image_embeds: torch.Tensor,  # (N_images, Max_Crops, C, H, W) or (N_images, C, H, W)
        image_attention_mask: Optional[torch.Tensor],  # (N_images, Max_Crops, Ph, Pw)
    ) -> torch.Tensor:
        """
        Compute one d_model-sized embedding per image.

        Vision features of every crop are mean-pooled over their second axis, then averaged
        over all ``Max_Crops`` crops and projected with ``img_summary_proj``.

        Args:
            input_image_embeds (torch.Tensor): Pixel crops of shape (N_images, Max_Crops, C, H, W)
                or (N_images, C, H, W) for a single crop.
            image_attention_mask (Optional[torch.Tensor]): Attention mask of shape
                (N_images, Max_Crops, Ph, Pw), forwarded to ``get_img_features``.

        Returns:
            torch.Tensor: Embeddings of shape (N_images, d_model), on the device/dtype of the
            model's first parameter (an empty (0, d_model) tensor when input is None).
        """
        p = next(self.parameters())
        if input_image_embeds is None:
            return torch.zeros(0, self.d_model, device=p.device, dtype=p.dtype)

        if input_image_embeds.dim() == 4:
            input_image_embeds = input_image_embeds.unsqueeze(1)  # (N,1,C,H,W)

        N_img, max_crops, C, H, W = input_image_embeds.shape
        img_embed: Phi4MMImageEmbedding = self.model.embed_tokens_extend.image_embed
        img_param = next(img_embed.parameters())
        v_device = img_param.device
        v_dtype = img_param.dtype

        flat = input_image_embeds.flatten(0, 1).to(device=v_device, dtype=v_dtype)  # (N_images*Max_Crops, C, H, W)

        if image_attention_mask is not None and image_attention_mask.numel() > 0:
            attn = image_attention_mask.flatten(0, 1).to(device=v_device)  # (N_images*Max_Crops, Ph, Pw)
            feats = img_embed.get_img_features(flat, attention_mask=attn.bool())
        else:
            feats = img_embed.get_img_features(flat)

        pooled = feats.mean(dim=1)  # (N_images*Max_Crops, D_img)
        restored = pooled.view(N_img, max_crops, -1)  # (N_images, Max_Crops, D_img)
        # NOTE(review): padded crops are included in this mean; masking them out would change
        # trained behaviour, so it is left as-is.
        restored = restored.mean(dim=1)  # (N_images, D_img)

        # Use summary projection to get final image embedding
        proj = self.img_summary_proj(restored.to(self.img_summary_proj.weight.dtype))  # (N_images, D_model)
        return proj.to(device=p.device, dtype=p.dtype)

    def _extract_padded_image_embeddings(
        self,
        num_images_per_sample: List[int],
        input_image_embeds: Optional[torch.Tensor],
        image_attention_mask: Optional[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract per-image embeddings and pad them into a per-sample batch.

        Args:
            num_images_per_sample (List[int]): Number of images for each sample in the batch.
            input_image_embeds (Optional[torch.Tensor]): Images of shape (N_total, Max_Crops, C, H, W),
                ordered sample by sample.
            image_attention_mask (Optional[torch.Tensor]): Attention mask for images of shape [N_total, Max_crops, Ph, Pw].
            device (torch.device): Target device for the output tensors.
            dtype (torch.dtype): Target data type for the output tensors.

        Returns:
            (imgs, mask): imgs (B0, Nmax, d_model) zero-padded embeddings and mask (B0, Nmax)
            bool, True for real images.

        Raises:
            ValueError: if ``input_image_embeds.size(0) != sum(num_images_per_sample)``.
        """
        B0 = len(num_images_per_sample)
        if input_image_embeds is None or (sum(num_images_per_sample) == 0):
            imgs = torch.zeros(B0, 0, self.d_model, device=device, dtype=dtype)
            mask = torch.zeros(B0, 0, device=device, dtype=torch.bool)
            return imgs, mask

        N_total = int(input_image_embeds.size(0))
        if N_total != sum(num_images_per_sample):
            raise ValueError(
                f"input_image_embeds.size(0) ({N_total}) does not match sum(num_images_per_sample) ({sum(num_images_per_sample)})"
            )

        single_img_emb = self._extract_single_image_embedding(
            input_image_embeds=input_image_embeds,
            image_attention_mask=image_attention_mask,
        )  # (N_images, D_model)

        Nmax = max(num_images_per_sample) if B0 > 0 else 0
        imgs = torch.zeros(B0, Nmax, self.d_model, device=single_img_emb.device, dtype=single_img_emb.dtype)
        mask = torch.zeros(B0, Nmax, device=single_img_emb.device, dtype=torch.bool)

        idx = 0
        for b, n in enumerate(num_images_per_sample):
            if n > 0:
                imgs[b, :n] = single_img_emb[idx:idx + n]
                mask[b, :n] = True
            idx += n

        # move to target device/dtype
        imgs = imgs.to(device=device, dtype=dtype)
        mask = mask.to(device=device)
        return imgs, mask

    @staticmethod
    def _take_base(x: Optional[torch.Tensor], B0: int, expand: int) -> Optional[torch.Tensor]:
        """
        Reduce a beam-expanded tensor (``B0 * expand`` rows, copies of a sample consecutive)
        to its first copy per sample.

        ``None``, ``expand == 1`` and tensors whose leading size is not ``B0 * expand``
        (e.g. already per-sample) are returned unchanged.
        """
        if x is None or expand == 1 or x.size(0) != B0 * expand:
            return x
        return x.reshape(B0, expand, *x.shape[1:])[:, 0].contiguous()

    def _prepare_node_mask(
        self,
        node_mask: Optional[torch.Tensor],
        B0: int,
        expand: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Normalize ``node_mask`` to a boolean (B0, d_max) tensor on ``device``.

        - None: all nodes are valid.
        - A 1-D mask is treated as a single row.
        - A beam-expanded mask is reduced to one row per sample.
        - A single row is broadcast to all samples.
        - The node axis is zero-padded (padding is invalid) or truncated to ``d_max``.
        """
        if node_mask is None:
            return torch.ones(B0, self.d_max, device=device, dtype=torch.bool)

        mask = node_mask.to(device=device, dtype=torch.bool)
        if mask.dim() == 1:
            mask = mask.unsqueeze(0)
        mask = self._take_base(mask, B0, expand)
        if mask.size(0) == 1 and B0 > 1:
            mask = mask.expand(B0, -1)

        cur_d = mask.size(1)
        if cur_d < self.d_max:
            padded = torch.zeros(mask.size(0), self.d_max, device=device, dtype=torch.bool)
            padded[:, :cur_d] = mask
            mask = padded
        elif cur_d > self.d_max:
            mask = mask[:, : self.d_max]
        return mask

    # ---------- forward ----------
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,  # Ignore external position_ids
        past_key_values: Optional[Any] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,  # Do not use inputs_embeds directly
        input_image_embeds: Optional[torch.FloatTensor] = None,
        num_images_per_sample: Optional[List[int]] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        input_audio_embeds: Optional[torch.FloatTensor] = None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        input_mode=None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        node_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPastExtended:
        """
        Run the base Phi-4 model with graph tokens injected at prefill time.

        Prefill (empty KV cache):
          - Compute image embeddings, adjacency logits and graph tokens once per original
            sample, and repeat them for beam-expanded rows.
          - Embed the prompt with the base multimodal embedder.
          - Overwrite the graph placeholders and mask their labels.
          - Run the base model on ``inputs_embeds``.
          - In eval mode, cache the adjacency logits for later decode steps.

        Decode (non-empty KV cache): delegate to the base model on ``input_ids``.

        Returns:
            CausalLMOutputWithPastExtended. Prefill also sets ``A_logits``, ``q``,
            ``node_tokens`` and ``global_tokens``; decode sets only ``A_logits`` (the cached
            value, possibly None).

        Raises:
            ValueError: if ``inputs_embeds`` is given or ``input_ids`` is None, if image tokens
                appear without images, or if graph placeholders are invalid.
        """
        if inputs_embeds is not None:
            raise ValueError("Phi4CausalQA wrapper does not support passing `inputs_embeds` directly.")
        if input_ids is None:
            raise ValueError("input_ids cannot be None.")

        B_full = input_ids.size(0)
        device = input_ids.device
        dtype = self.get_input_embeddings().weight.dtype

        prefill = self._is_prefill(past_key_values)

        # == generate() beam search handling ==
        # num_images_per_sample describes the original batch (B0). When input_ids is expanded to
        # B_full = B0 * num_beams (copies of a sample on consecutive rows), graph tokens are
        # computed once from the first copy and then repeated for every beam.
        if num_images_per_sample is None:
            # default: no images
            num_images_per_sample = [0] * B_full

        B0 = len(num_images_per_sample)
        expand = (B_full // B0) if (B0 > 0 and B_full % B0 == 0) else 1

        # While pre-filling, calculate A_logits + graph_tokens, and replace graph placeholders
        if prefill:
            # Clear cached values during prefill
            self._cached_A_logits = None

            # Token sequence without beams
            base_input_ids = self._take_base(input_ids, B0, expand)

            imgs, img_mask = self._extract_padded_image_embeddings(
                num_images_per_sample=num_images_per_sample,
                input_image_embeds=input_image_embeds,
                image_attention_mask=image_attention_mask,
                device=device,
                dtype=dtype,
            )  # (B0, Nmax, D_model), (B0, Nmax)

            _, _, A_logits_base, node_q_base = self.a_head(
                img_embeds=imgs,
                img_mask=img_mask,
            )  # A_logits: (B0, d_max, d_max), q: (B0, d_max, d_model)

            node_mask_base = self._prepare_node_mask(
                node_mask,
                B0=node_q_base.size(0),
                expand=expand,
                device=node_q_base.device,
            )  # (B0, d_max)

            node_tokens_base, global_tokens_base = self.graph_emb(
                A_logits=A_logits_base, q=node_q_base, node_mask=node_mask_base
            )  # (B0, d_max, d_model), (B0, num_global_tokens, d_model)
            graph_tokens_base = torch.cat([node_tokens_base, global_tokens_base], dim=1)  # (B0, num_graph_tokens, d_model)

            def to_beams(t: torch.Tensor) -> torch.Tensor:
                # Repeat per-sample tensors for each beam (consecutive rows, matching _take_base).
                return t.repeat_interleave(expand, dim=0) if expand > 1 else t

            A_logits = to_beams(A_logits_base)
            graph_tokens = to_beams(graph_tokens_base)
            node_q = to_beams(node_q_base)
            node_tokens = to_beams(node_tokens_base)
            global_tokens = to_beams(global_tokens_base)

            # Work around the original Phi-4 implementation, which needs at least one image during
            # training: feed a dummy zero image when the batch has neither images nor audio.
            has_img_token = bool((base_input_ids == _IMAGE_SPECIAL_TOKEN_ID).any().item())
            if has_img_token and input_image_embeds is None:
                raise ValueError("input_ids contains image special tokens but input_image_embeds is None.")

            dummy_image = None
            if self.training and (input_image_embeds is None) and (input_audio_embeds is None):
                crop = getattr(self.model.embed_tokens_extend.image_embed, 'crop_size', 336)
                dummy_image = torch.zeros(1, 3, crop, crop, device=device, dtype=torch.float32)

            base_hidden_states = self.model.embed_tokens_extend(
                input_ids=base_input_ids,
                input_embeds=None,
                input_image_embeds=input_image_embeds if input_image_embeds is not None else dummy_image,
                input_audio_embeds=input_audio_embeds,
                image_sizes=image_sizes,
                image_attention_mask=image_attention_mask,
                audio_embed_sizes=audio_embed_sizes,
                audio_attention_mask=audio_attention_mask,
                audio_projection_mode='vision',
                wte=self.model.embed_tokens,
            )  # (B0, seq_len, d_model)
            hidden_states = to_beams(base_hidden_states)  # (B_full, seq_len, d_model)

            # Replace graph placeholder tokens with computed graph tokens
            hidden_states = self._replace_graph_tokens(
                hidden_states=hidden_states,
                input_ids=input_ids,  # (B_full, seq_len)
                graph_tokens=graph_tokens,  # (B_full, num_graph_tokens, d_model)
            )

            # Prepare labels, mask graph token labels
            labels_full = labels
            if labels_full is not None and expand > 1 and labels_full.size(0) == B0:
                labels_full = labels_full.repeat_interleave(expand, dim=0)
            labels_full = self._mask_graph_labels(input_ids, labels_full)

            if not self.training:
                self._cached_A_logits = A_logits.detach()

            outputs = super().forward(
                input_ids=None,
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=past_key_values,
                input_image_embeds=None,
                image_sizes=None,
                image_attention_mask=None,
                input_audio_embeds=None,
                audio_embed_sizes=None,
                audio_attention_mask=None,
                input_mode=input_mode,
                labels=labels_full,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                num_logits_to_keep=num_logits_to_keep,
            )

            return CausalLMOutputWithPastExtended(
                loss=getattr(outputs, 'loss', None),
                logits=outputs.logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                A_logits=A_logits,
                q=node_q,
                node_tokens=node_tokens,
                global_tokens=global_tokens,
            )

        # == Decode step: graph tokens are already in the KV cache; no insertion needed ==
        outputs = super().forward(
            input_ids=input_ids,
            inputs_embeds=None,
            attention_mask=attention_mask,
            position_ids=None,
            past_key_values=past_key_values,
            input_image_embeds=None,
            image_sizes=None,
            image_attention_mask=None,
            input_audio_embeds=None,
            audio_embed_sizes=None,
            audio_attention_mask=None,
            input_mode=input_mode,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return CausalLMOutputWithPastExtended(
            loss=getattr(outputs, "loss", None),
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            A_logits=self._cached_A_logits,
        )
