\section{Methodology}

\subsection{Problem Formulation}
\label{sec:problem}
Let $G_x = (\mathcal{V}_x,\mathcal{E}_x,\mathcal{R}_x)$ be the 1-hop star subgraph centered on the query entity $c$, extracted from Wikidata~\cite{wikidata}.
Given a query $x$ and its corresponding graph $G_x$, the model must generate an answer $y$ grounded in the factual triples contained in $G_x$.
The task can be seen as a \textit{knowledge-conditioned} language-modelling objective, where the conditioning signal is a compact representation of $G_x$ that fits within the LLM's context.



\subsection{Model Architecture}
\label{sec:arch}

\begin{figure}
    \centering
    \includegraphics[width=1\linewidth]{images/KoRe.pdf}
    \caption{KoRe architecture: the query entity is used to extract a star-graph from Wikidata (1), which gets encoded by a TransformerConv GNN with GraphNorm (2), compressed into $Q$ discrete tokens via directional residual vector quantization (3), aligned to the LLM embedding space (4), and injected at the \texttt{<KG\_EMBEDDING>} placeholder before autoregressive generation (5).}
    \label{fig:arch}
\end{figure}

\textsc{KoRe} is composed of four trainable modules grafted onto a largely frozen Qwen3-8B backbone with LoRA adaptation (Figure~\ref{fig:arch}):
\begin{enumerate}
  \item \textbf{Graph extraction:} 
    This module uses the relevant entities in the query to extract the corresponding subgraph from Wikidata;
  \item \textbf{Graph Encoding:} 
    First, entities and relations are mapped to dense vectors using a frozen sentence encoder (Qwen3-Embedding-8B); then, a TransformerConv network, with edge-type embeddings and GraphNorm, produces a single $d_{\text{gnn}}$-dimensional graph summary;
  \item \textbf{Vector Quantization:} 
    The summary is compressed into $Q$ discrete codebook indices $(i_1,\dots,i_Q)$;
  \item \textbf{Alignment:} 
    The quantised discrete embeddings are projected into the LLM token space, mean-standardisation matching is applied, and a \texttt{<KG\_EMBEDDING>} placeholder is replaced in the tool-response template.
\end{enumerate}


\paragraph{Graph Extraction \& Encoding}
\label{sec:GNN}
For each textual instance, we extract relevant knowledge in the form of star graphs centered around key entities mentioned in the text. 
In our datasets, central entities are pre-annotated. In production scenarios, entity linking tools or LLM-based entity recognition would identify and disambiguate entities before graph extraction.
To manage computational costs and focus on the most relevant information, we implement a neighbor selection strategy that ranks entities by their global PageRank scores.

Before feeding the graph to a graph neural network (GNN), we embed its nodes and edges by passing their labels through a frozen sentence encoder $\phi:\text{text}\rightarrow\mathbb{R}^{d_\phi}$, whose hidden size $d_\phi$ determines both the node and edge feature dimensions ($d_\phi = d_{\text{node}} = d_{\text{edge}}$).
Using this approach, the initial node and edge representations are already aligned with the text domain, making the mapping to the LLM embedding space easier.

The computed embeddings are fed into a \texttt{TransformerConv} GNN layer from \cite{shi2020masked}, followed by \texttt{GraphNorm} from \cite{cai2021graphnormprincipledapproachaccelerating} and residual connections.
This process aggregates the graph information into the central node. 
To obtain the final graph embedding $\mathbf{g} \in \mathbb{R}^{d_{\text{gnn}}}$, we perform central-node pooling by selecting only the representation of the central node.
This summary $\mathbf{g}$ is then passed to the residual vector quantization (RVQ) layer to be converted into $Q$ knowledge tokens.
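
As a minimal sketch of one encoder block, assuming the standard single-head form of \texttt{TransformerConv} with edge features and the usual \texttt{GraphNorm} definition, the update of node $i$ can be written as below. The symbols $h_i$, $e_{ij}$, $W_1,\dots,W_4$, $W_e$, $a_{ij}$, $d_h$, $\lambda$, $\gamma$, $b$, and $L$ are notation introduced here for illustration only, and the exact placement of the normalisation and of the residual connection is an assumption:
\[
\begin{aligned}
a_{ij} &= \operatorname{softmax}_{j\in\mathcal{N}(i)}\!\left(\frac{(W_3 h_i)^{\top}(W_4 h_j + W_e e_{ij})}{\sqrt{d_h}}\right),\\
m_i &= W_1 h_i + \sum_{j\in\mathcal{N}(i)} a_{ij}\,\bigl(W_2 h_j + W_e e_{ij}\bigr),\\
h_i' &= h_i + \gamma\odot\frac{m_i-\lambda\odot\mu}{\sigma} + b,
\end{aligned}
\]
where $h_i$ and $e_{ij}$ are the node and edge features, $d_h$ is the attention channel dimension, $\mu$ is the mean of the $m_i$ over the nodes of the graph, and $\sigma$ is the standard deviation of the shifted features $m_i-\lambda\odot\mu$ over the same nodes. Under this sketch, central-node pooling after $L$ blocks amounts to $\mathbf{g} = h_c^{(L)}$, with $c$ the central entity.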

\paragraph{Directional Residual Vector Quantisation}
\label{sec:rvq}
To compress the graph representation into multiple discrete indices, we follow Meta's GQT \cite{wang2025learninggraphquantizedtokenizers} and use a residual vector quantization (RVQ) strategy.
Standard RVQ iteratively subtracts the chosen code vector $e_t$ from the residual: $r_{t+1}=r_t - e_t$.
However, in our preliminary experiments this formulation consistently led to codebook collapse.
To mitigate this issue, we modified the formulation to use an $\ell_2$-normalized codebook and a cosine-similarity selection mechanism, as proposed in \cite{yu2022vectorquantizedimagemodelingimproved}, together with a \textbf{directional} update, where $c_t$ denotes the unit-norm code selected at step $t$:
\[
r_{t+1}=r_t - \langle r_t, c_t\rangle\,c_t,
\]
which removes only the \textit{component} of the residual along the chosen code direction. 
This incentivises codes to span orthogonal subspaces, resulting in higher codebook capacity and diversity.
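
To make this concrete, the following short derivation uses only the update rule above and the unit norm of the selected code, $\|c_t\|=1$. It shows that the new residual is orthogonal to the chosen direction and that its norm shrinks by a factor governed by the cosine similarity:
\[
\begin{aligned}
\langle r_{t+1}, c_t\rangle &= \langle r_t, c_t\rangle - \langle r_t, c_t\rangle\,\|c_t\|^2 = 0,\\
\|r_{t+1}\|^2 &= \|r_t\|^2 - \langle r_t, c_t\rangle^2 = \|r_t\|^2\bigl(1-\cos^2(r_t,c_t)\bigr).
\end{aligned}
\]
As an illustrative example with made-up values, take $r_t=(3,4)$ and $c_t=(1,0)$: then $r_{t+1}=(0,4)$, and $\|r_{t+1}\|^2 = 16 = 25\,(1-9/25)$.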

The training loss combines a directional commitment term and a final residual norm:
\[
\mathcal{L}_{\text{RVQ}}=
\beta\sum_{t=0}^{Q-1}\bigl(1-\cos(r_t,c_t)\bigr)
+\|r_Q\|^2.
\]
Using the straight-through estimator (STE), this loss propagates gradients to the encoder, encouraging it to produce representations that use the codebooks effectively.
The codes are updated using an exponential moving average (EMA) of the input residuals, and a dead code is reset after going unused for $N_{dead}\times\texttt{codebook\_size}$ training graphs (i.e., we require each code to have a probability of being sampled of $1/N_{dead}$).

\paragraph{Alignment and Injection}
\label{sec:alignment}
The $Q$ discrete quantized tokens $\mathbf{Z} = [\hat{k}_1; \dots; \hat{k}_Q]$ are then passed through a residual MLP $f_{\text{out}}$ and a linear skip connection, which project them into the LLM's embedding dimension $d_{\text{llm}}$.
To keep the injected tokens numerically stable, we standardize each graph's token sequence to zero mean and unit variance across the $Q$ and $d_\text{llm}$ dimensions.
The normalized tensor is then scaled by the standard deviation of the text embeddings and shifted by a learned mean $\mu_{\text{llm}}$:
\[
\tilde{\mathbf{Z}} = \text{StdMatch}(\text{LN}(f_{\text{out}}(\mathbf{Z}) + \text{skip}(\mathbf{Z}))) + \mu_{\text{llm}}
\]
The aligned tokens $\tilde{\mathbf{Z}}$ replace a special placeholder token \texttt{<KG\_EMBEDDING>} within the prompt, using a prefix mechanism similar to \cite{barmettler2025conceptformerefficientuseknowledgegraph}.
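
One way to spell out the two operators in the equation above, as a sketch under stated assumptions: write $\mathbf{H} = f_{\text{out}}(\mathbf{Z}) + \text{skip}(\mathbf{Z}) \in \mathbb{R}^{Q\times d_{\text{llm}}}$, let $\sigma_{\text{text}}$ be a scalar standard deviation of the LLM text embeddings, and let $\epsilon$ be a small stability constant. The symbols $\mathbf{H}$, $m$, $s$, $\sigma_{\text{text}}$, and $\epsilon$ are introduced here for illustration, and whether $\mu_{\text{llm}}$ is a scalar or a $d_{\text{llm}}$-dimensional vector is not fixed by this sketch:
\[
\begin{aligned}
m &= \frac{1}{Q\,d_{\text{llm}}}\sum_{q=1}^{Q}\sum_{j=1}^{d_{\text{llm}}} H_{qj},
\qquad
s^2 = \frac{1}{Q\,d_{\text{llm}}}\sum_{q=1}^{Q}\sum_{j=1}^{d_{\text{llm}}} (H_{qj}-m)^2,\\
\text{LN}(\mathbf{H}) &= \frac{\mathbf{H}-m}{\sqrt{s^2+\epsilon}},
\qquad
\text{StdMatch}(\mathbf{X}) = \sigma_{\text{text}}\,\mathbf{X}.
\end{aligned}
\]
Under these assumptions, the entries of $\tilde{\mathbf{Z}}$ for each graph have a standard deviation of approximately $\sigma_{\text{text}}$ before the shift, so the injected tokens share the scale of ordinary text embeddings.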


\paragraph{Training Objective}
The final composite loss used is
\[
\mathcal{L}=
\mathcal{L}_{\text{LM}}
+\mathcal{L}_{\text{RVQ}},
\]
where $\mathcal{L}_{\text{LM}}$ is the standard causal language-modelling cross-entropy computed solely on the target answer span.
This focus prevents the model from wasting capacity on reconstructing the prompt structure.
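
Spelled out, and assuming mean reduction over the answer positions, the language-modelling term can be sketched as follows. The index set $\mathcal{A}$ of answer-token positions and the parameter symbol $\theta$ are notation introduced here for illustration:
\[
\mathcal{L}_{\text{LM}} = -\frac{1}{|\mathcal{A}|}\sum_{t\in\mathcal{A}} \log p_\theta\bigl(y_t \mid y_{<t},\, x,\, \tilde{\mathbf{Z}}\bigr).
\]
Prompt positions contribute no term to this sum, while the aligned tokens $\tilde{\mathbf{Z}}$ still receive gradient because every answer token is conditioned on them.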



\section{Experimental Design}
\label{sec:experiments}
To train and evaluate our model, we use both synthetic corpora and QA benchmarks.

\subsection{Datasets}
\label{sec:datasets}


\paragraph{Tri-REx~\cite{barmettler_2025_15166163}}
The dedicated dataset Tri-REx Star \cite{barmettler_2025_15165974} provides the star graphs extracted from Wikidata for each sentence in Tri-REx.
Following the ConceptFormer \cite{barmettler2025conceptformerefficientuseknowledgegraph} approach, we use Tri-REx for training our knowledge graph encoder to convey factual information to the LLM backbone. 
For graph extraction, we reuse the data from Tri-REx Star \cite{barmettler_2025_15165974}, keeping a maximum of 100 edges per graph.
This dataset is well-suited for testing the factual recall ability gained by the model, given 
the absence of overlap between training, validation, and testing entities, which penalizes models that rely on memorization rather than exploiting the available knowledge.

\paragraph{SimpleQuestions~\cite{bordes2015largescalesimplequestionanswering}}
We use this dataset to evaluate our model on simple one-hop question answering and to experiment with continual finetuning.
In particular, we use the answerable split provided by \cite{wikidata-benchmark}, which maps the dataset to the Wikidata KG and keeps only the questions whose answers could be found in Wikidata after mapping the properties.
This leaves $14894$ training samples, $2210$ validation samples, and $4295$ test samples.
However, due to retrieval limitations from the public SPARQL endpoint\footnote{\url{https://query.wikidata.org/bigdata/namespace/wdq/sparql}}, we excluded cases where the central or target entity was absent from the retrieved subgraph. This filtering resulted in a final dataset comprising $9294$ training samples, $1377$ validation samples, and $2677$ test samples.
For this dataset, in the graph extraction step, we keep up to 10,000 edges, allowing us to evaluate the models under much noisier input conditions.

\paragraph{WebQSP~\cite{yih-etal-2016-value}}
We use this dataset to test the zero-shot capabilities of our model. We never train or validate on it; it is held out exclusively for testing.
Similarly to SimpleQuestions, we mapped the dataset to the Wikidata KG using the preprocessed files from \cite{C18-1280} 
and retained up to 10,000 edges during graph extraction.

\subsection{Baselines}
We compare our model against the following baselines: 
\begin{itemize}
    \item \textbf{Vanilla Qwen3-8B} (parametric only) to measure the performance attributable to knowledge the LLM memorized during pretraining.
    \item \textbf{Textualization} (graph triples $\to$ natural-language prompt) as the most straightforward injection methodology.
    \item \textbf{LoRA-only} (no KG) to separate the contribution of our injection mechanism from gains due to simply matching the distribution of the training data or memorizing the answers.
    \item \textbf{ConceptFormer} (GPT-2 baseline from literature~\cite{barmettler2025conceptformerefficientuseknowledgegraph}) to evaluate the scalability of the approach.
\end{itemize}

\subsection{Metrics}
\label{sec:eval-met}
Our evaluation focuses on the model's ability to correctly predict object entities mentioned in ground truth answers, which are the most critical factual components.
For this reason, the primary evaluation metric is \textbf{Hit@k}, which measures the proportion of test instances where the correct answer token appears among the top $k$ predictions. 
For each test instance containing a query $x$ and target answer $y$, we:
\begin{enumerate}
    \item \textbf{Token-level ranking}: For each position $t$ in the target answer sequence, we compute the rank of the true token $y_t$ among all vocabulary tokens based on the model's output logits. The rank is defined as:
    \[
    \text{rank}_t = 1 + \sum_{v \in \mathcal{V}} \mathbf{1}[\text{logit}(v) > \text{logit}(y_t)]
    \]
    where $\mathcal{V}$ is the vocabulary and $\mathbf{1}[\cdot]$ is the indicator function.
    
    \item \textbf{Object boundary identification}: We identify the span of tokens corresponding to the target object entity in the answer sequence, focusing evaluation on factual content rather than auxiliary tokens like articles or prepositions.
    
    \item \textbf{Sequence-level rank}: The sequence-level rank is the maximum rank across all object token positions:
    \[
    \text{rank}_{\text{seq}} = \max_{t \in \text{object positions}} \text{rank}_t
    \]
    This conservative approach ensures that a sequence counts as a hit at $k$ only if \emph{every} object token is ranked within the top $k$.
\end{enumerate}
We compute Hit@k for $k \in \{1, 3, 5, 10\}$ to capture both strict accuracy (Hit@1) and more lenient retrieval performance (Hit@10).
For the WebQSP dataset, as one question can have multiple correct answers, 
we consider a prediction correct if it matches any of the ground truth answers.
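
Putting the steps above together, and writing $N$ for the number of test instances and $\text{rank}_{\text{seq}}^{(n)}$ for the sequence-level rank of instance $n$ (both symbols are introduced here for illustration), the metric can be sketched as
\[
\text{Hit@}k = \frac{1}{N}\sum_{n=1}^{N}\mathbf{1}\bigl[\text{rank}_{\text{seq}}^{(n)}\le k\bigr].
\]
As an illustrative example with made-up ranks, an object span whose three tokens have ranks $(1,4,2)$ has $\text{rank}_{\text{seq}}=4$, so it counts toward Hit@5 and Hit@10 but not toward Hit@1 or Hit@3. For questions with a set of gold answers $\mathcal{A}$, one plausible way to express the any-match rule is to threshold $\min_{a\in\mathcal{A}}\text{rank}_{\text{seq}}(a)$; this is an assumed formalisation, not a definition stated above.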

\subsection{Training Protocol}
We evaluate two model configuration checkpoints at different training stages:
\begin{enumerate}
  \item \textbf{KoRe-base}: the model is first trained using only the synthetic sentences from the Tri-REx dataset. This stage serves as the foundation on which the GNN and adaptation layers learn the mapping between knowledge graphs and the LLM token embedding space.
  \item \textbf{KoRe-QA}: KoRe-base is then finetuned on the SimpleQuestions training split, specializing the model for question answering.
\end{enumerate}

\paragraph{Implementation Details}
We performed multiple ablation studies to determine the optimal configuration for our KG-LM architecture.
We analyzed the impact of codebook size, EMA aggressiveness, and quantizer depth (Appendix~\ref{app:ablations}). Our results indicate that a moderate codebook size of $128$ codes prevents under-utilization while maintaining representational richness. We found that a less aggressive EMA replacement strategy ($N_{dead}=4$) stabilizes training and improves cross-dataset generalization. Furthermore, increasing the number of quantizers to $20$ consistently improved performance on the larger SimpleQuestions graphs compared to Tri-REx, justifying our final choice of $Q=20$.
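
For intuition about these values, the following back-of-the-envelope arithmetic assumes that each of the $Q$ quantizer levels selects one of $128$ codes and applies the dead-code reset rule of $N_{dead}\times\texttt{codebook\_size}$ training graphs. The bit count is an upper bound on the information carried by the index sequence, not a measured quantity:
\[
\begin{aligned}
N_{dead}\times\texttt{codebook\_size} &= 4\times 128 = 512 \text{ training graphs before a reset},\\
Q\,\log_2 128 &= 20\times 7 = 140 \text{ bits per graph},\\
128^{20} &= 2^{140}\approx 1.4\times 10^{42} \text{ possible index sequences}.
\end{aligned}
\]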

Based on these analyses, we fix our final hyperparameters to codebook size $128$, $Q=20$, EMA dead-code threshold $N_{dead}=4$, and commitment weight $\beta=0.25$ in the RVQ loss.
LoRA is applied to the query, key, value, and output projection matrices with rank $r=4$, scaling factor $\alpha=8$, and a dropout of $0.2$.
As the text encoder for node and edge features, we use the Qwen3-Embedding-8B model \cite{zhang2025qwen3}, and as the LLM backbone, we use Qwen3-8B \cite{yang2025qwen3}.
For training, we use the AdamW optimizer with a batch size of $8$ per GPU ($32$ total), $2$ gradient accumulation steps, and gradient clipping with a max norm of $1.0$. Validation runs every $8196$ batches for Tri-REx and after every full pass over the training set for SimpleQuestions.
Given the heterogeneous nature of our model components, we apply distinct learning rates optimized for each parameter group:
\begin{itemize}
    \item \textbf{LoRA parameters}: $1\times10^{-5}$ to ensure stable adaptation of the pretrained language model
    \item \textbf{Knowledge Graph Encoder parameters}: $5\times10^{-4}$ to enable efficient learning of graph representations
\end{itemize}
We also apply a weight decay of $1\times 10^{-2}$ and a reduce-LR-on-plateau scheduler with a reduction factor of $0.5$ and a patience of $1$, monitoring Hit@10 on the validation set; early stopping uses the same metric with a patience of $2$.
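
As a worked reading of this configuration, assume that the reported total of $32$ counts one forward and backward pass across the four GPUs and that gradient accumulation is applied on top of it. The number of graphs per optimizer step, and the learning rate after $n$ scheduler reductions, would then be as follows, where $\eta_0$ and $\eta_n$ are symbols introduced here for illustration:
\[
\begin{aligned}
8 \times 4 &= 32 \text{ graphs per pass},
\qquad
32 \times 2 = 64 \text{ graphs per optimizer step},\\
\eta_n &= \eta_0 \cdot 0.5^{\,n}:
\quad 10^{-5}\to 5\times10^{-6}\to 2.5\times10^{-6} \text{ (LoRA)},
\quad 5\times10^{-4}\to 2.5\times10^{-4}\to 1.25\times10^{-4} \text{ (encoder)}.
\end{aligned}
\]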

Each experimental run utilizes:
\begin{itemize}
    \item \textbf{GPU setup}: 4 NVIDIA A100 GPUs with 64\,GB of memory each.
    \item \textbf{Training budget}: an 8-hour limit per experiment to ensure efficient resource usage.
    \item \textbf{Distributed framework}: We use the Accelerate library \cite{accelerate} integrated with DeepSpeed \cite{deepspeed} for coordinated multi-GPU training with ZeRO~\cite{rajbhandari2020zeromemoryoptimizationstraining} stage 2, which shards the optimizer states and gradients across all four GPUs.
    \item \textbf{Memory optimization}: To reduce the memory footprint and accelerate training, we use the \texttt{bf16} precision option in DeepSpeed.
\end{itemize}
The distributed training setup enables efficient parallelization of the training for both the language model and knowledge graph encoder components.
