from typing import Optional

import torch
import torch.nn as nn

class QFormer(nn.Module):
    """Summarize variable-size graphs into a fixed number of LLM-sized tokens.

    A set of learnable query vectors is concatenated with each graph's node
    embeddings, the sequence is run through a TransformerEncoder (joint
    self-attention over queries + nodes), and the encoder outputs at the query
    positions are linearly projected to the LLM embedding width.
    """

    def __init__(
        self,
        gnn_hidden_dim: int = 768,  # width of GNN node vectors; also the encoder's d_model
        llm_hidden_dim: int = 4096,  # LLM embedding width the outputs are projected to
        num_heads: int = 8,
        num_layers: int = 2,
        num_queries: int = 8,
    ):
        """
        :param gnn_hidden_dim:   Dimension of the node embeddings produced by the GNN;
                                 also used as d_model of the internal TransformerEncoder.
        :param llm_hidden_dim:   Target LLM embedding dimension of the outputs.
        :param num_heads:        Number of attention heads (PyTorch requires it to
                                 divide gnn_hidden_dim).
        :param num_layers:       Number of TransformerEncoder layers.
        :param num_queries:      Number of learnable query vectors, i.e. output tokens per graph.
        """
        super().__init__()
        # num_queries learnable query vectors, each of size gnn_hidden_dim.
        self.query_embeddings = nn.Parameter(
            torch.randn(num_queries, gnn_hidden_dim)
        )

        # 1) Self-attention over [queries; nodes] at width gnn_hidden_dim.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=gnn_hidden_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        # 2) Project the query outputs to the width the LLM expects.
        self.proj = nn.Linear(gnn_hidden_dim, llm_hidden_dim)

    def forward(
        self,
        node_embeds: torch.Tensor,
        batch_index: torch.Tensor,
        num_graphs: Optional[int] = None,
    ) -> torch.Tensor:
        """Encode all graphs of the batch in one padded TransformerEncoder call.

        :param node_embeds:  [num_nodes, gnn_hidden_dim], node vectors of all graphs
                             concatenated together.
        :param batch_index:  [num_nodes] integer tensor giving the graph id of each node.
                             It need not be sorted, and ids must be non-negative.
        :param num_graphs:   Number of graphs in the batch. Defaults to
                             ``batch_index.max() + 1`` (0 for empty input). Pass it
                             explicitly when trailing graphs may have no nodes;
                             otherwise those graphs are missing from the output.
        :return: [num_graphs, num_queries, llm_hidden_dim], the projected query outputs
                 for each graph. A graph with no nodes gets the encoding of the queries alone.
        :raises ValueError: if ``batch_index`` contains a graph id >= ``num_graphs``.
        """
        num_queries = self.query_embeddings.shape[0]
        batch_index = batch_index.long()
        if num_graphs is None:
            num_graphs = int(batch_index.max().item()) + 1 if batch_index.numel() > 0 else 0
        if num_graphs == 0:
            # Nothing to encode; calling max() on an empty tensor would raise.
            return node_embeds.new_zeros((0, num_queries, self.proj.out_features))

        nodes, node_pad_mask = self._pad_by_graph(node_embeds, batch_index, num_graphs)
        device = nodes.device

        # Same queries for every graph: [num_graphs, num_queries, gnn_hidden_dim].
        queries = self.query_embeddings.to(device).unsqueeze(0).expand(num_graphs, -1, -1)
        combined = torch.cat([queries, nodes], dim=1)  # [B, num_queries + max_nodes, d]

        # Queries are never padding. Padded node slots are excluded as attention keys,
        # and the remaining encoder ops act per token. So each graph's query outputs
        # match what encoding that graph on its own would produce.
        query_pad_mask = torch.zeros((num_graphs, num_queries), dtype=torch.bool, device=device)
        key_padding_mask = torch.cat([query_pad_mask, node_pad_mask.to(device)], dim=1)

        encoded = self.encoder(combined, src_key_padding_mask=key_padding_mask)
        # Keep only the query positions, then project to the LLM width.
        return self.proj(encoded[:, :num_queries, :])

    @staticmethod
    def _pad_by_graph(
        node_embeds: torch.Tensor,
        batch_index: torch.Tensor,
        num_graphs: int,
    ):
        """Scatter flat node rows into a zero-padded per-graph tensor.

        :param node_embeds: [num_nodes, dim] node vectors.
        :param batch_index: [num_nodes] int64 graph id per node.
        :param num_graphs:  Number of graphs (rows of the padded output).
        :return: ``(padded, pad_mask)``. ``padded`` is [num_graphs, max_nodes, dim].
                 ``pad_mask`` is a [num_graphs, max_nodes] bool tensor that is True
                 at padding slots.
        :raises ValueError: if ``batch_index`` references a graph id >= ``num_graphs``.
        """
        counts = torch.bincount(batch_index, minlength=num_graphs)  # nodes per graph
        if counts.numel() > num_graphs:
            # bincount grows past minlength only when an id >= num_graphs is present.
            raise ValueError(
                f"batch_index contains graph id {counts.numel() - 1}, "
                f"but num_graphs={num_graphs}"
            )
        max_nodes = int(counts.max().item())

        # A stable sort groups rows by graph and keeps each graph's original node order.
        order = torch.sort(batch_index, stable=True).indices
        graph_of_row = batch_index[order]
        # Offset where each graph starts in sorted order, giving each row's slot in its graph.
        graph_start = torch.cumsum(counts, dim=0) - counts
        slot = torch.arange(order.numel(), device=batch_index.device) - graph_start[graph_of_row]

        padded = node_embeds.new_zeros((num_graphs, max_nodes, node_embeds.shape[-1]))
        padded[graph_of_row, slot] = node_embeds[order]
        pad_mask = (
            torch.arange(max_nodes, device=counts.device).unsqueeze(0)
            >= counts.unsqueeze(1)
        )
        return padded, pad_mask