\documentclass{midl} % Include author names

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{mwe} % to get dummy images
\jmlryear{2025}
\jmlrworkshop{Full Paper -- MIDL 2025}
\jmlrvolume{--029}
\editors{Accepted for publication at MIDL 2025}

\title[Spatio-temporal STNAGNN]{STNAGNN: Data-driven Spatio-temporal\\ Brain Connectivity beyond FC}

 % Use \Name{Author Name} to specify the name.
 % If the surname contains spaces, enclose the surname
 % in braces, e.g. \Name{John {Smith Jones}} similarly
 % if the name has a "von" part, e.g \Name{Jane {de Winter}}.
 % If the first letter in the forenames is a diacritic
 % enclose the diacritic in braces, e.g. \Name{{\'E}louise Smith}

 % Two authors with the same address
 % \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\and
 %  \Name{Author Name2} \Email{xyz@sample.edu}\\
 %  \addr Address}

 % Three or more authors with the same address:
 % \midlauthor{\Name{Author Name1} \Email{an1@sample.edu}\\
 %  \Name{Author Name2} \Email{an2@sample.edu}\\
 %  \Name{Author Name3} \Email{an3@sample.edu}\\
 %  \addr Address}


% Authors with different addresses:
% \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\\
% \addr Address 1
% \AND
% \Name{Author Name2} \Email{xyz@sample.edu}\\
% \addr Address 2
% }

%\footnotetext[1]{Contributed equally}

% More complicate cases, e.g. with dual affiliations and joint authorship
\midlauthor{\Name{Jiyao Wang\nametag{$^{1}$}} \Email{jiyao.wang@yale.edu}\\
\Name{Nicha C. Dvornek \nametag{$^{1,2}$}}\\
\Name{Peiyu Duan \nametag{$^{1}$}}\\
\Name{Lawrence H. Staib \nametag{$^{1,2}$}}\\
\Name{Pamela Ventola \nametag{$^{3}$}}\\
\Name{James S. Duncan \nametag{$^{1,2,4}$}}\\
\addr $^{1}$ Department of Biomedical Engineering, Yale University, USA \\
\addr $^{2}$ Radiology \& Biomedical Imaging, Yale School of Medicine, USA \\
\addr $^{3}$ Child Study Center, Yale School of Medicine, USA \\ 
\addr $^{4}$ Electrical Engineering, Yale University, USA\\ 
}

\begin{document}
\maketitle

\begin{abstract}
In recent years, graph neural networks (GNNs) have been widely applied to the analysis of brain fMRI, yet defining the connectivity between regions of interest (ROIs) remains a challenge in noisy fMRI data. Among all approaches, the functional connectome (FC) is the most popular. Computed from the correlation coefficients between ROI time series, FC is a powerful and computationally efficient way to estimate ROI connectivity. However, it is known to neglect structural connections and causality in ROI interactions. FC also becomes much noisier when computed on the short subsequences produced by spatio-temporal sliding windows over fMRI. The effective connectome (EC) has been proposed as a directional alternative, but it is difficult to estimate accurately. Furthermore, for optimal GNN performance, usually only a small percentage of the strongest connections are selected as sparse edges, resulting in an oversimplification of complex brain connections. To tackle these challenges, we propose the Spatio-Temporal Node Attention Graph Neural Network (STNAGNN) as a data-driven alternative that combines sparse predefined FC with dense data-driven spatio-temporal connections, allowing flexible spatio-temporal learning of ROI interaction patterns. Our implementation is publicly available at \url{https://github.com/Jiyao96/STNAGNN-fMRI/}.
\end{abstract}

\begin{keywords}
Graph Neural Network, Functional MRI, Spatio-temporal learning
\end{keywords}

\section{Introduction}
% task/resting fmri 
Functional magnetic resonance imaging (fMRI) is a non-invasive imaging technique that measures brain activity by detecting changes in blood-oxygen-level-dependent~(BOLD) signals. Through the use of fMRI, significant progress has been made in understanding the functional organization of the brain. 
Compared with resting-state fMRI, task-based fMRI exhibits larger fluctuations in the BOLD signal. It has been shown to be superior to resting-state data for applications such as predicting behavioral traits~\cite{task-rest} and detecting individual differences~\cite{vince-task}. Although the diversity of task designs makes it considerably more difficult to construct large task-based fMRI datasets, growing evidence points to the promising potential of task-based fMRI. The task-induced fMRI signal may offer a strong inductive bias for learning an informative model, especially in studies where tasks are designed to enhance disease-specific brain activity. 

% fMRI & machine learning on
In recent years, a wide range of machine learning methods, including recurrent neural networks (RNNs)~\cite{nicha-lstm,dakka-lstm}, convolutional neural networks (CNNs)~\cite{brainetcnn}, and graph neural networks (GNNs)~\cite{BrainGNN,zhao-stgnn,lg-gnn}, have been applied to fMRI analysis. Among these approaches, GNNs have a unique advantage in interpreting ROI-based brain interactions, an important field of research for understanding general brain functions and the mechanisms of neurological disorders such as Autism Spectrum Disorder (ASD). However, efficient message passing and model interpretation in GNNs rely on a high-quality definition of edges, which poses considerable data processing challenges for functional brain networks. 

% GNN
To utilize task-based fMRI data with both temporal task context and spatio-temporal ROI interactions, we formulate our goal as a spatio-temporal graph analysis problem. Specifically, we focus on discrete spatio-temporal graphs, in which the fMRI input is represented as a temporal sequence of sliding-window subsequences, each of which we denote as a graph snapshot. Although temporal GNNs have been frequently studied in recent years~\cite{tgnn,gcgrulstm,gclstm,lrgcn,evolve}, we identify two key challenges unique to spatio-temporal brain graph applications: 

\begin{itemize}
    \item In the temporal dimension, the limited temporal resolution of fMRI acquisition and the sliding-window truncation of the sequence data lead to a short sequence of graph snapshots, which diminishes the usual advantage of RNNs in capturing long-term temporal dependencies. 
    \item For each graph snapshot, FC is more susceptible to noise when computed on the short time series inside each sliding window. These noisy predefined edges are less likely to be accurate or sufficient for describing the brain's functional dynamics.
\end{itemize}

In the field of brain network analysis, methods incorporating multiple brain atlases~\cite{multiatlas} or multiple connectivity measurements~\cite{festgnn} have been proposed to mitigate these challenges by introducing extra features in the input. In the field of graph theory, there have also been attempts to detect and remove low-quality edges before training using geometric constraints~\cite{geometric-edge}. We propose to address these challenges from a data-driven perspective with STNAGNN, a spatio-temporal GNN model that incorporates a node-level attention algorithm for information aggregation on ROI-based brain graphs. To our knowledge, our approach is the first to implement direct spatio-temporal ROI connections at the node level, enabling more flexible information aggregation and model explainability. It can also be applied as a complement to existing designs~\cite{festgnn,multiatlas,geometric-edge}.

\section{Notation and Problem Definition}
We truncate spatio-temporal fMRI data temporally into $T$ subsequences and construct a graph snapshot from each sliding-window subsequence. For each instance, the input is a sequence of undirected weighted graph snapshots $\{{G}_1,{G}_2,\ldots,{G}_{T}\}$, where each ${G}_{i}=({V}_i,{E}_i)$ is a graph with vertex set ${V}_i$ and edge set ${E}_i$. For any edge $(v_{i,j},v_{i,k})\in {E}_i$ connecting vertices $v_{i,j}$ and $v_{i,k}$, we define its edge weight $e_{i,j,k}\in \mathbb{R}^+$. For a vertex set of $N$ vertices, the $d$-dimensional input node features are denoted as
$x_{i,j} \in \mathbb{R}^d$, where $j\in\{0,1,\ldots,N-1\}$. 

Based on the above definitions, our goal of $K$-class instance classification is equivalent to learning a mapping function $f$ that maps a sequence of graph snapshots to a predicted class label $Z$:
\[ f\colon\{{G}_{i} \mid i\in\{1,2,\ldots,T\}\}\mapsto Z\in\{0,1,\ldots,K-1\}. \]


\section{Data and Preprocessing}
\subsection{Biopoint Dataset}
We include a 118-subject task-based fMRI dataset to experiment with an autism spectrum disorder (ASD) classification task. The dataset contains fMRI scans of 75 children with ASD and 43 healthy controls matched in age and IQ. The scans are acquired under the biopoint~\cite{biopoint} task that contains 12 videos of biological or scrambled motions of point light displays. Videos of these two categories are given to subjects in an alternating sequence during the fMRI scan with the intention of highlighting deficits in the perception of biological motion in autistic children.

The scan for each subject has 146 frames acquired with a repetition time (TR) of 2 seconds and an original resolution of 3.2\,mm. The data were collected with approval from the Yale Institutional Review Board. The acquired fMRI data are preprocessed using the pipeline described in~\cite{preprocess}, which includes motion correction, interleaved slice timing correction, BET brain extraction, grand mean intensity normalization, spatial smoothing, and high-pass temporal filtering. The preprocessed data have a 2\,mm resolution in MNI space.

\subsection{Human Connectome Project (HCP) Dataset}
We also evaluate our method on a 7-class brain state classification task using the HCP dataset~\cite{hcp}. We take the 1,025 subjects in the WU-Minn HCP 1200 Subjects data release who have task-based fMRI scans with right-to-left (RL) phase encoding in all 7 fMRI task sessions: emotion, gambling, language, motor, relational, social, and working memory. We use preprocessed fMRI in MNI space with 2\,mm resolution. Models are trained to classify each spatio-temporal graph input into the task performed during its acquisition.

\subsection{Data Processing and Graph Construction}
% augmentation
For our biopoint dataset, we first parcellate the brain fMRI data into 84 ROIs based on the Desikan-Killiany atlas~\cite{desikan}. For network training, we split the subjects into five roughly equal-sized, class-stratified subsets for five-fold cross-validation. Then, the mean time series of each ROI is extracted from a bootstrap random sample of $1/3$ of its voxels~\cite{dvornek2018}. We sample each ROI 30 times as a data augmentation method, resulting in a total of $3540=118\times30$ instances. 

For graph construction, we truncate the mean time series into 12 non-overlapping subsequences, each aligned with one of the 12 video stimuli. Each subsequence has 12 or 13 frames.
% node edge def
Following an approach similar to that of~\cite{BrainGNN}, for each local subsequence time series, we calculate the Pearson correlation between ROIs and use it as node features. Meanwhile, we extract and concatenate all the time series acquired during biological motion videos. Using the concatenated sequence, we calculate a global biological partial correlation and use its top 5\% values to define the edges and their weights for graph sparsity. Edges are shared across all 12 graph snapshots of each instance. As the biological motion viewing task is expected to elicit stronger correlated activity than the more random scrambled motion viewing task, we choose to use the global biological partial correlation for all edges. A graphical illustration of the preprocessing pipeline is shown in \figureref{fig_flowchart}.

For HCP data, we parcellate brain fMRI into 268 ROIs using the Shen atlas~\cite{shen268}. The 1025 subjects are also divided into five subsets of 205 subjects for cross-validation. Each subject has 7 scans, one per task, which leads to a total of $7175=1025\times7$ instances. Data augmentation is not performed on the HCP dataset. For graph construction in HCP, we follow the same approach as for the biopoint data, except that we truncate each HCP fMRI scan into 4 equal-length sliding-window subsequences and compute edges from each entire fMRI sequence.

\begin{figure}[t]
\centering
\includegraphics[width=0.9\textwidth]{st_GNN.drawio.png}
\caption{STNAGNN architecture} \label{fig_STNAGNN}
\end{figure}

% proposed model
\section{Models}
\subsection{STNAGNN Model}
Our proposed STNAGNN model utilizes GNN convolution operation and the attention algorithm for sparse-connection and dense-connection graph information aggregation, respectively. It maintains node identities in aggregating spatio-temporal information from different graph snapshots. As shown in \figureref{fig_STNAGNN}, after performing two layers of graph convolution to extract localized graph information, we add positional encoding to each node and compute nodewise self-attention as a global spatio-temporal information aggregation operation using the dot product attention algorithm~\cite{att}. Essentially, in this operation, we neglect the spatial edges defined by FC and impose a fully connected spatio-temporal graph containing nodes from all node sets $\{{V}_1,{V}_2,\ldots,{V}_{T}\}$.
%The two graph convolution layers serve as feature extractor for the attention operation.

Positional encoding is crucial for the attention algorithm to capture the order information of the data. The Transformer~\cite{att} adds an absolute 1D sinusoidal positional encoding to its input sequence, and the Vision Transformer~\cite{ViT} similarly adds absolute 1D position embeddings to the raster sequence of image patches. Various other designs of positional encoding have also been proposed for transformer architectures, including relative~\cite{ViT} and learnable~\cite{learnable_pe} alternatives. There have also been attempts to apply 2D positional encoding, but mainly to the x and y dimensions of 2D images~\cite{2d_pe}.

For graph-structured data, although adding positional encoding to nodes can potentially further empower GNNs with positional knowledge, additive absolute positional embedding is generally considered inapplicable because it breaks the permutation invariance of graph message passing~\cite{gnn_pe}. However, in the ROI-based brain graph application, the brain graph nodes always follow the same ROI ordering, so permutation invariance is not a required property. In our proposed STNAGNN architecture, to encode both the spatial and temporal information of a node while preserving computational simplicity, we propose an absolute multiplicative 2D positional encoding defined as follows:   
\begin{equation}
\mathit{PE}(j,i,2f)=\sin(j/C_1^{E})\,\sin((C_2+i)/C_1^{E}),
\end{equation}
\begin{equation}
\mathit{PE}(j,i,2f+1)=\cos(j/C_1^{E})\,\cos((C_2+i)/C_1^{E}),
\end{equation}
\begin{equation}
E=2f/d,
\end{equation}
where $j$ and $i$ denote the spatial and temporal positions of the node, respectively, and $f$ indexes the channel pairs $(2f, 2f+1)$ of the $d$-dimensional node features. The constant $C_1=10000$ scales the frequency of the encoding.
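The constant offset $C_2$ prevents duplicated embeddings across nodes. Without it ($C_2=0$), the encoding would be symmetric in $j$ and $i$, so node $j$ at time $i$ and node $i$ at time $j$ would receive identical encodings. In our experiments, we set $C_2=10000$. When the spatial position $j\neq 0$ is fixed, the encoding $\mathit{PE}(j,i,\cdot)$ can be represented as a linear function of $\mathit{PE}(j,i+k,\cdot)$ for any constant temporal offset $k$. The coefficients of this function depend on $j$ and $k$ but not on $i$, which follows from the angle-addition identities. The analogous property holds when the temporal position $i$ is fixed. As illustrated in \figureref{fig_STNAGNN}, the 2D positional encoding is added to the node features of each node before the self-attention operation is computed; only its spatial and temporal factors are combined multiplicatively. Since the proposed positional encoding can be pre-computed, the extra computation during training is negligible.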

Our method is to our knowledge the first to use 2D positional encoding in a spatio-temporal context. There are several advantages in applying the spatio-temporal self-attention operation in the STNAGNN architecture.
\begin{itemize}
    \item By imposing a fully connected self-attention operation, we mitigate the bias from inaccurate edge definition in the functional brain graph application. Data-driven information aggregation based on the similarity between node features and the spatio-temporal adjacency of nodes brings more flexibility to learning ROI interactions. 
    \item A fully connected graph using attention allows information to pass directly from any node to every other node (\figureref{fig_connection}). This alleviates the problem of limited receptive fields in graph convolution operations~\cite{gcn,gat}. We argue that adding this operation enables multi-scale information aggregation from both the local and global neighborhoods of nodes. 
    %\item Similar to the scenario in the Transformer~\cite{att}, all graph convolutions can now be computed in parallel. It is computationally more efficient than some of the RNN-based temporal GNN architectures~\cite{gcgrulstm,gclstm,lrgcn,evolve}.
\end{itemize}

\begin{figure}[t]
\centering
\includegraphics[width=0.9\textwidth]{connection-abc.png}
\caption{Illustration of connectivity types: a) spatial connectivity between neighboring nodes in one graph, usually computed by GNN convolution; b) temporal connectivity between different time points of one node; c) spatio-temporal connectivity (magenta), allowing information flow between nodes in different graphs. Existing architectures usually consider only spatial connectivity~\cite{BrainGNN} or temporal connectivity~\cite{nicha-lstm}. Some spatio-temporal designs consider both spatial and temporal perspectives~\cite{gcgrulstm,gclstm,lrgcn,evolve} but use a two-step spatial-then-temporal approach. Our proposed STNAGNN jointly considers all spatio-temporal connectivity.}
\label{fig_connection}
\end{figure}

\subsection{Baseline models}
There are various existing spatio-temporal GNN models that use discrete graph snapshot structures to incorporate temporal information into GNN. For comparison with the proposed STNAGNN approach, we experiment with SVM and four different spatio-temporal GNN designs: GConvLSTM~\cite{gcgrulstm}, GCLSTM~\cite{gclstm}, LRGCN~\cite{lrgcn}, and EvolveGCN~\cite{evolve}. These compared architectures include approaches of GNN-embedded RNN~\cite{gcgrulstm,gclstm}, stateful GNN~\cite{lrgcn}, and weight-evolving GNN~\cite{evolve}. In Section~\ref{sec_gnnrnn}, we also show the performance of another implementation using LSTM instead of the proposed attention operation.

\section{Evaluation and Interpretation}
\subsection{Classification Task Performance}
For the STNAGNN model, we experiment with alternative two-layer graph convolution backbones, including GCN~\cite{gcn}, GAT~\cite{gat}, GraphSAGE~\cite{sage}, and Graph Transformer (GT)~\cite{graph-transformer}. In the feed-forward modules following the graph convolution and temporal aggregation operations, we apply SiLU~\cite{silu} activation and a dropout rate of 0.2 in each layer. All models are trained with cross-entropy loss on a single RTX A5000 GPU. We perform five-fold cross-validation on both the biopoint and HCP datasets. We tune hyperparameters separately for each dataset. For biopoint data, we use a learning rate of $2\times10^{-5}$ and a large weight decay factor of $0.015$. For HCP, we use
$4\times10^{-6}$ as learning rate and $0.0001$ as weight decay. For both datasets, we use a batch size of $10$. The results for both datasets measured by classification accuracy and Area Under the ROC Curve (AUC) are summarized in \tableref{tab_acc}. 
% TABLE vertical
\begin{table}[t]
\centering
\setlength{\tabcolsep}{3mm}
\begin{tabular}{c | c c | c c }
{} & \multicolumn{2}{c|}{Biopoint}& \multicolumn{2}{c}{HCP}\\ 
{} & Acc(\%) & AUC & Acc(\%) & AUC \\
\hline
\hline
SVM & 68.7(4.97) & 0.608(0.046) & 91.2(0.991)  & 0.993(0.001)\\
\hline
GConvLSTM  & 63.8(3.65) & 0.642(0.031) & 96.7(0.768) & 0.998(0.001)\\
GCLSTM & 72.5(3.67) & 0.675(0.067) & 97.7(0.224) & 0.998(0.000) \\
LRGCN & 72.5(2.83) & 0.741(0.059) & 96.8(0.637) & 0.998(0.000)\\
EvolveGCN & 72.0(7.58) &  0.734(0.123) & 95.7(0.535) & 0.997(0.001) \\
\hline
STNAGNN-GCN & 75.2(4.40) &  0.670(0.130) & 97.3(0.424) & 0.998(0.000) \\
STNAGNN-GAT & \textbf{79.2(3.49)}  & \textbf{0.755(0.099)} & 96.9(0.747) & 0.998(0.000) \\
STNAGNN-SAGE & 74.0(4.26)  & 0.619(0.112) & \textbf{98.1(0.179)} & \textbf{0.999(0.000)}\\
STNAGNN-GT & 74.7(3.36) & 0.664(0.105) &\textbf{98.1(0.407)} & \textbf{0.999(0.000)}\\
\hline
\end{tabular}
\caption{Comparison of classification performance with SVM, temporal GNN baselines, and STNAGNN architectures using different graph convolution backbones. Results on the biopoint and HCP datasets are reported as mean (standard deviation). The best mean performance in each column is bolded.}
\label{tab_acc}
\end{table}

\subsection{Ablation Study}
% TABLE
\begin{table}[t]
\centering
\setlength{\tabcolsep}{1.5mm}
\begin{tabular}{ c | c c | c c | c c | c c | c c}
\hline
Edge & \multicolumn{2}{c|}{ALL} & \multicolumn{2}{c|}{ALL}  & \multicolumn{2}{c|}{ALL}& \multicolumn{2}{c|}{\textbf{BIOL}}  & \multicolumn{2}{c}{SCRAM}  \\
\hline
\# Windows & \multicolumn{2}{c|}{10} & \multicolumn{2}{c|}{12} & \multicolumn{2}{c|}{14} & \multicolumn{2}{c|}{\textbf{12}}  & \multicolumn{2}{c}{12} \\
\hline
 GNN-backbone & GCN & GAT & GCN & GAT & GCN & GAT & \textbf{GCN} & \textbf{GAT}  & GCN & GAT \\
\hline
Acc (\%) & 73.4 & 73.5 & 73.4 & 73.5 & 69.7 & 71.6 & \textbf{75.2} & \textbf{79.2} & 70.0 & 73.0 \\
% f1 & 0.814 & 0.806 & \textbf{0.822} & 0.814 & 0.793 & 0.788 & 0.820 & \textbf{0.848} & 0.794 & 0.806 \\
AUC & \textbf{0.680} & 0.666 & 0.621 & 0.665 & 0.658 & 0.710 & 0.670 & \textbf{0.755}  & 0.676 & 0.666 \\
\hline
\end{tabular}
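\caption{Ablation study on graph construction. BIOL, SCRAM, and ALL denote edges computed from fMRI frames acquired under biological motion, from frames acquired under scrambled motion, and from all fMRI frames, respectively. The configuration used in the main experiments (BIOL edges, 12 windows) and the best result for each backbone and metric are bolded.}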
\label{tab_ablation_graph}
\end{table}

% TABLE
\begin{table}
\centering
\setlength{\tabcolsep}{1mm}
\begin{tabular}{l|c c | c c }
\hline
{} &  \multicolumn{2}{c|}{1D Raster Sequence} &  \multicolumn{2}{c}{2D Spatio-temporal}\\
\cline{2-5}
{} & Acc(\%) & AUC & Acc(\%) & AUC\\
\hline
GCN & 71.6 & 0.653 & 75.2 & 0.670\\
GAT & 74.6 & 0.680 & 79.2 & 0.755\\
\hline
\end{tabular}
\caption{Ablation study on positional encoding. Mean of cross-validation is reported.}
\label{tab_ablation_pe}
\end{table}

\subsubsection{Graph Construction}
We perform ablation studies of graph construction methods on the biopoint ASD classification task for the STNAGNN architecture with GCN and GAT backbones. 
For data truncation, we validate the task-aligned choice of 12 sliding windows by comparing it with 10 and 14 windows, using the whole sequence for edge construction in each case. In addition, under the 12-sliding-window construction, we compare spatial graph edges constructed from fMRI data acquired under biological motion videos, from data acquired under scrambled motion videos, and from the whole sequence. A visualization of the brain connectome under different graph construction methods is shown in \figureref{fig_connectome}. The results are summarized in \tableref{tab_ablation_graph}. 


\subsubsection{Positional Encoding}
We validate the proposed 2D spatio-temporal positional encoding by comparing it to the 1D raster sequence positional encoding. The results are summarized in \tableref{tab_ablation_pe}. Using the proposed 2D spatio-temporal encoding outperforms using the 1D raster sequence option in all metrics. In Appendix~\ref{sec_pe}, we show a visualization of positional encoding methods to facilitate understanding of its interaction with the attention algorithm.

\subsubsection{Attention versus LSTM}
\label{sec_gnnrnn}
To compare the proposed attention-based spatio-temporal aggregation method with RNN structures such as LSTM~\cite{lstm}, we perform an ablation study on the biopoint and HCP datasets using an architecture similar to STNAGNN, except that the global attention operation is replaced with an LSTM. This modification also removes the global node-level connections and relies on the LSTM to aggregate graph-level information over time. The architecture of this GNN-LSTM ablation is described in \figureref{fig_baseline}. We experiment with the same set of GNN backbones as in the STNAGNN architecture. The results for both classification tasks are shown in \tableref{tab_lstm}. In every column, the best GNN-LSTM result does not exceed the best STNAGNN result shown in \tableref{tab_acc}.

% TABLE vertical
\begin{table}[t]
\centering
\setlength{\tabcolsep}{3mm}
\begin{tabular}{c | c c | c c }
{} & \multicolumn{2}{c|}{Biopoint}& \multicolumn{2}{c}{HCP}\\ 
{} & Acc (\%) &  AUC & Acc(\%) & AUC \\
\hline
\hline
GCN-LSTM &73.4(6.67)&0.700(0.089)&97.4(0.387)& 0.998(0.000)\\
GAT-LSTM & \textbf{76.0(5.31)}&0.705(0.090)&97.0(0.546)&0.998(0.001)\\
SAGE-LSTM & 72.1(2.68)&\textbf{0.708(0.076)}&97.8(0.471)&\textbf{0.999(0.000)}\\
GT-LSTM & 73.6(3.63)&0.690(0.078)&\textbf{97.9(0.276)}&\textbf{0.999(0.000)}\\
\hline
\end{tabular}
\caption{Performance of the GNN-LSTM ablation with various GNN backbones. Results on the biopoint and HCP datasets are reported as mean (standard deviation). The best mean performance in each column is bolded.}
\label{tab_lstm}
\end{table}

\subsection{ASD Biomarker Interpretation}
An important advantage of graph-based methods in the analysis of brain fMRI is the capability to identify ROI biomarkers by interpreting the decision-making process of trained GNN models. To interpret our trained STNAGNN model, we apply GNNExplainer~\cite{gnnexplainer}, a post-hoc interpretability method designed for GNN architectures. We consider all $1008=84\times12$ spatio-temporal nodes and derive an importance score for each node empirically, by optimizing a mask function to maximize the mutual information between the outputs generated from masked and unmasked inputs. Using the best-performing trained STNAGNN with a GAT backbone, we plot 12 ROI-importance heatmaps, one per graph snapshot. Six heatmaps sampled from the 12 graph snapshots are shown in \figureref{fig_biomarker}. 

\begin{figure}[t]
\centering
\includegraphics[width=0.9\columnwidth]{biomarker_midl.png}
\caption{Interpreted ROI importance at $T=3, 4, 7, 8, 11, 12$. Temporal indices of the graph snapshots are marked at the top of each plot. Snapshots acquired under biological motion stimuli are marked by the blue dashed square. Darker regions indicate higher importance. Blue and green circles mark the left parietal lobe and the right thalamus, respectively. The complete plots of all 12 sliding-window snapshots are shown in \figureref{fig_all_heatmaps}.} 
\label{fig_biomarker}
\end{figure}

Comparing heatmaps across time frames shows that the brain ROIs important for ASD classification appear to change across graph snapshots. For example, among the sampled heatmaps, the left parietal lobe contributes markedly more to the classification at $T=8$ than at most other time points. Meanwhile, some ROIs recur as prominent, such as the right thalamus at $T=3, 4, 8$ and the left parietal lobe at $T=8, 12$. The thalamus has been widely associated with ASD~\cite{asd-thalamus-2016,asd-thalamus-2017}, and the left parietal lobe has also been found to be indicative of language development in ASD~\cite{parietal}. %The biomarker ROIs interpreted by the model agree with existing research findings. 

%We also observe that snapshots generated using fMRI segments acquired under biological motion video stimuli (T=4, 8, 12) generally have ROIs showing higher importance than those acquired under scrambled motion video stimuli (T=3, 7, 11), which is especially obvious in T=3, 11 where the overall heatmaps of scrambled motion video data show a significantly lighter color. It reassures that the designed biological motion video stimulates brain functions related to ASD-related deficiencies. 

Similar to other applications for interpreting ROI importance using GNN, analyzing recurring salient ROIs can help us identify potential ASD biomarkers useful for ASD diagnosis and subtype classification. Additionally, in a spatio-temporal GNN model such as STNAGNN, spatio-temporal importance aligned with task schemes can guide us in finding the appropriate stimuli to trigger the study-related functional response in the brain, which can potentially help design fMRI sessions. 

\section{Conclusion}
In this paper, we propose the STNAGNN architecture as a spatio-temporal framework for analyzing task-based fMRI data. In addition to enabling spatio-temporal explainability, it also outperforms existing designs in both ASD classification and brain state classification tasks from two datasets. In future research, we intend to further develop this method, including experimenting with additional datasets and extending the tasks to predicting cognitive scores or performances in psychiatric assessments.


\clearpage  % Acknowledgements, references, and appendix do not count toward the page limit (if any)
% Acknowledgments---Will not appear in anonymized version
\midlacknowledgments{This paper is supported by NIH grant R01NS035193.}

\bibliography{midl25_029}

\appendix

\section{Related Work}
\subsection{Graph Neural Network}
Graph Neural Network (GNN) is a class of machine learning models applied to graph-structured data. It aims to learn an aggregation function for a neighborhood of nodes and propagate the function over the entire graph. In each layer, a GNN update of the node feature can usually be described as follows:
\begin{equation}
x_i^k=f(\{x_j^{k-1},\forall j\in N(i)\},x_i^{k-1})
\end{equation}
where $x$ denotes the node feature, $k$ denotes the layer number, and $N(i)$ is the neighborhood of node $i$. From a graph spectral perspective, a GNN layer usually acts similarly to a low-pass filter on the input graph signal, suppressing high-frequency noise and extracting low-frequency information. 

\subsection{Scaled Dot-Product Attention Algorithm}
The scaled dot-product attention algorithm was first introduced in~\cite{att} as a fundamental mechanism of the Transformer architecture. It quantifies the similarity between elements using a dot product scaled by $1/\sqrt{d_k}$ and updates each element as a weighted sum of values using the similarity-based attention scores. A self-attention operation using this algorithm is theoretically similar to message passing on a fully connected graph, which allows information to flow from any node to every other node.
\begin{equation}
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,
\label{eq_att}
\end{equation}
where $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_k$ is the dimension of the keys.

\begin{figure}[t]
\centering
\includegraphics[width=\textwidth]{graph_construction.png}
\caption{Preprocessing and graph construction pipeline on the biopoint data} \label{fig_flowchart}
\end{figure}

\begin{figure}[t]
\centering
\includegraphics[width=0.4\textwidth]{visualize_connectome.png}
\caption{Visualization of calculated brain connectome using different approaches} \label{fig_connectome}
\end{figure}

\section{Visualization of positional encoding on simulated data}
\label{sec_pe}
We use a simulated sample to visualize the interaction between different positional encoding methods and the self-attention operation. We first generate a simulated setting with 84 nodes in each of 12 temporal graphs, so that it has the same shape as our biopoint data. For simplicity, each node feature is generated as a one-dimensional value following a normal distribution $\mathcal{N}(0,0.1)$. We apply the self-attention operation using three approaches: 1) without positional encoding, 2) with 1D raster sequence positional encoding, and 3) with 2D spatio-temporal positional encoding. We extract the attention scores computed during each self-attention operation, which directly show how much the feature of each node contributes to the attention output of every other node. We visualize the attention scores from one node as a heatmap and compare its pattern under the different positional encodings in \figureref{fig_pe}.

\begin{figure}[t]
\centering
\includegraphics[width=0.9\textwidth]{visualize_pe.png}
\caption{Visualization of the attention scores calculated on a simulated sample using different positional encoding methods.} \label{fig_pe}
\end{figure}

Without positional encoding, the attention scores are computed entirely from the randomly generated node features. Therefore, there is no noticeable connection between the attention weights and the underlying spatio-temporal structure. When 1D positional encoding is applied, it drives the attention weights toward an oscillating, sinusoid-like pattern along the flattened raster sequence. Although this pattern captures the affinity of nodes with similar indices in the spatial domain, it is clearly misaligned in the temporal dimension. The attention weights calculated with the proposed 2D spatio-temporal encoding show a 2D pattern aligned with both the spatial and temporal axes. 

\begin{figure}[t]
\centering
\includegraphics[width=0.9\textwidth]{RNN_GNN.png}
\caption{GNN-LSTM Architecture} \label{fig_baseline}
\end{figure}

\begin{figure}[t]
\centering
\includegraphics[width=0.9\textwidth]{biomarker_all.png}
\caption{Interpreted ROI importance from $T=1, 2, \ldots, 12$} \label{fig_all_heatmaps}
\end{figure}
\end{document}
