from dataclasses import dataclass
from typing import Any, Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

@dataclass
class CausalLMOutputWithPastExtended(CausalLMOutputWithPast):
    """
    Extension of `CausalLMOutputWithPast` with four extra optional tensors.

    The class is declared as a dataclass so that the extra fields are
    registered exactly like the inherited ones. `ModelOutput` builds its
    mapping view from the dataclass fields, so non-``None`` extras are
    reachable by attribute (``out.A_logits``), by key (``out["A_logits"]``)
    and through ``keys()`` / ``items()`` / ``to_tuple()``. Setting them as
    plain attributes in a hand-written ``__init__`` would leave them out of
    that view, and code that rebuilds the output from its items would
    silently drop them.

    The generated constructor takes the inherited fields first, followed by
    the fields declared below; every field defaults to ``None``.
    NOTE(review): positional order matches the previous hand-written
    ``__init__`` only as long as the parent declares ``loss``, ``logits``,
    ``past_key_values``, ``hidden_states``, ``attentions`` in that order --
    prefer keyword arguments at call sites.

    Args:
        A_logits: Extra output tensor; shape noted by the original author as
            ``[B, d_max, d_max]``.
        q: Extra output tensor; shape noted as ``(B, d_max, d_model)``.
        node_tokens: Extra output tensor; shape noted as
            ``(B, d_max, d_model)``.
        global_tokens: Extra output tensor; shape noted as
            ``(B, num_graph_tokens - d_max, d_model)``.
    """

    # Shape comments are carried over from the original source; they are not
    # validated at runtime.
    A_logits: Optional[torch.Tensor] = None  # [B, d_max, d_max]
    q: Optional[torch.Tensor] = None  # (B, d_max, d_model)
    node_tokens: Optional[torch.Tensor] = None  # (B, d_max, d_model)
    global_tokens: Optional[torch.Tensor] = None  # (B, num_graph_tokens - d_max, d_model)