\documentclass{midl} % Include author names

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{mwe} % to get dummy images
\usepackage{bbm}
\usepackage{tikz}
\usepackage{standalone}
\usepackage{placeins}

% \usepackage{amsmath,amssymb,amsfonts}
% \usepackage{algorithmic}
% \usepackage{graphicx}
\usepackage[outline]{contour}
\usepackage{booktabs}
\usepackage{multirow}
\usepackage{microtype}
\usepackage{adjustbox}
\usepackage{colortbl}
\usepackage{hyperref}
\usepackage{array}
% "C{<width>}": a fixed-width paragraph column whose contents are centered.
\newcolumntype{C}[1]{>{\centering\arraybackslash}p{#1}}
% Helper: typeset the math symbol #1 with a thin black outline (contour
% package), which makes a thin arrow glyph look bold.
\newcommand*{\contouredsymbol}[1]{%
  \contourlength{0.03em}%
  \contour{black}{\ensuremath{#1}}%
}
% Outlined arrows, used in the results-table header next to each metric to
% indicate whether higher or lower values are better.
\newcommand*{\bolduparrow}{\contouredsymbol{\uparrow}}
\newcommand*{\bolddownarrow}{\contouredsymbol{\downarrow}}


% \eg: italic "e.g.," followed by an explicit interword space. The control
% space (\ ) is needed because TeX discards the space after the control word
% \eg, so "\eg age" still prints as "e.g., age". \def (not \newcommand) is
% used, so any earlier definition of \eg is silently replaced.
\def\eg{\emph{e.g.,\ }}

% Usage: \cbstart\lmnote{REVIEWER #, POINT #} TEXT \cbend
\usepackage{marginnote}
\usepackage{mcb2}

% --- Review margin notes ---------------------------------------------------
% \ifshownotes gates every review note and starts out true here. \nonotes and
% \usenotes flip it; the state in effect when a note is typeset decides
% whether it appears (note that \nonotes is invoked further down in this
% preamble).
\newif\ifshownotes
\shownotestrue
% \shownotesfalse         % alternative: start with notes hidden

\newcommand{\nonotes}{\shownotesfalse}  % hide review notes from here on
\newcommand{\usenotes}{\shownotestrue}  % show review notes from here on

% Internal helper shared by \rmnote and \lmnote.
%   #1: paragraph-alignment command applied inside the note
%   #2: note text, typeset in bold at \scriptsize
% Produces nothing when \ifshownotes is false.
\newcommand{\reviewmarginnote}[2]{%
  \ifshownotes
    \marginnote{\scriptsize#1\textbf{#2}}%
  \fi
}
% Margin note with ragged-right text, e.g. \rmnote{REVIEWER 1, POINT 2}.
\newcommand{\rmnote}[1]{\reviewmarginnote{\raggedright}{#1}}
% Margin note with ragged-left text, e.g. \lmnote{REVIEWER 1, POINT 2}.
\newcommand{\lmnote}[1]{\reviewmarginnote{\raggedleft}{#1}}

% Change-tracking configuration. \nochangebars, \nochangecolor, \usechangecolor
% and \cbsetchangecolor come from the local mcb2 package, whose source is not
% visible here; presumably they switch change bars / change coloring off (or
% on) and set the highlight color -- TODO confirm against mcb2.
\nochangebars % disable changebar
\nochangecolor % disable changecolor
\nonotes % sets \ifshownotes false, so \rmnote/\lmnote print nothing
% \usechangecolor
\cbsetchangecolor{magenta} 


\input{tikzstyles}

\jmlrvolume{-- nnn}
\jmlryear{2026}
\jmlrworkshop{Full Paper -- MIDL 2026}
\editors{Accepted for publication at MIDL 2026}

\title[MetaVoxel]{MetaVoxel: Joint Diffusion Modeling of Imaging and Clinical Metadata}

 % Use \Name{Author Name} to specify the name.
 % If the surname contains spaces, enclose the surname
 % in braces, e.g. \Name{John {Smith Jones}} similarly
 % if the name has a "von" part, e.g \Name{Jane {de Winter}}.
 % If the first letter in the forenames is a diacritic
 % enclose the diacritic in braces, e.g. \Name{{\'E}louise Smith}

 % Two authors with the same address
 % \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\and
 %  \Name{Author Name2} \Email{xyz@sample.edu}\\
 %  \addr Address}

 % Three or more authors with the same address:
 % \midlauthor{\Name{Author Name1} \Email{an1@sample.edu}\\
 %  \Name{Author Name2} \Email{an2@sample.edu}\\
 %  \Name{Author Name3} \Email{an3@sample.edu}\\
 %  \addr Address}


% Authors with different addresses:
% \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\\
% \addr Address 1
% \AND
% \Name{Author Name2} \Email{xyz@sample.edu}\\
% \addr Address 2
% }

%\footnotetext[1]{Contributed equally}

% More complicate cases, e.g. with dual affiliations and joint authorship
\midlauthor{\Name{Yihao Liu\nametag{$^{1}$}} \orcid{0000-0003-3187-9903} \Email{yihao.liu@vanderbilt.edu}
\AND
\Name{Chenyu Gao\nametag{$^{1}$}}
\orcid{0000-0003-2098-3035}
\Email{chenyu.gao@vanderbilt.edu}
\AND
\Name{Lianrui Zuo\nametag{$^{1}$}}
\orcid{0000-0002-5923-9097}
\Email{lianrui.zuo@vanderbilt.edu}
\AND
\Name{Michael E. Kim\nametag{$^{2}$}}
\orcid{0009-0006-3562-2688}
\Email{michael.kim@vanderbilt.edu}
\AND
\Name{Brian D. Boyd\nametag{$^{3}$}}
\Email{brian.d.boyd@vumc.org}
\AND
\Name{Lisa L. Barnes\nametag{$^{4}$}}
\orcid{0000-0002-0072-9817}
\Email{Lbarnes1@rush.edu}
\AND
\Name{Walter A. Kukull\nametag{$^{5}$}}
\orcid{0000-0001-8761-9014}
\Email{kukull@wustl.edu}
\AND
\Name{Lori L. Beason-Held\nametag{$^{6}$}}
\orcid{0000-0001-9057-6270}
\Email{heldlo@grc.nia.nih.gov}
\AND
\Name{Susan M. Resnick\nametag{$^{6}$}}
\orcid{0000-0003-1115-7145}
\Email{resnicks@grc.nia.nih.gov}
\AND
\Name{Timothy J. Hohman\nametag{$^{7,8}$}}
\orcid{0000-0002-3377-7014}
\Email{timothy.j.hohman@vumc.org}
\AND
\Name{Warren D. Taylor\nametag{$^{3,9}$}}
\orcid{0000-0002-9975-3082}
\Email{warren.d.taylor@vumc.org}
\AND
\Name{Bennett A. Landman\nametag{$^{1,2,10,11}$}}
\orcid{0000-0001-5733-2127}
\Email{bennett.landman@vanderbilt.edu}
\AND
\Name{for the Alzheimer’s Disease Neuroimaging Initiative\footnotemark[1]\ and the BIOCARD Study Team\footnotemark[2]\,\mbox{} }
\AND
\addr $^{1}$Department of Electrical and Computer Engineering, Vanderbilt University, Nashville, TN, US.\\
\addr $^{2}$Department of Computer Science, Vanderbilt University, Nashville, TN, US.\\
\addr $^{3}$Center for Cognitive Medicine, Department of Psychiatry and Behavioral Science, Vanderbilt University Medical Center, Nashville, TN, US.\\
\addr $^{4}$Department of Neurological Sciences and Rush Alzheimer's Disease Center, Rush University Medical Center, Chicago, IL.\\
\addr $^{5}$Washington University in St. Louis, St Louis, MO, US.\\
\addr $^{6}$Laboratory of Behavioral Neuroscience, National Institute on Aging, National Institutes of Health, Baltimore, MD.\\
\addr $^{7}$Vanderbilt Memory and Alzheimer’s Center, Vanderbilt University Medical Center, Nashville, TN.\\
\addr $^{8}$Department of Neurology, Vanderbilt University Medical Center, Nashville, TN, US.\\
\addr $^{9}$Geriatric Research, Education, and Clinical Center, Veterans Affairs Tennessee Valley Health System, Nashville, TN, US.\\
\addr $^{10}$Department of Biomedical Engineering, Vanderbilt University, Nashville, TN, US.\\
\addr $^{11}$Department of Radiology, Vanderbilt University Medical Center, Nashville, TN.
}

\begin{document}

\maketitle

\begin{abstract}
Modern deep learning methods have achieved impressive results on tasks ranging from disease classification and continuous biomarker estimation to realistic medical image generation.
Most of these approaches are trained to model conditional distributions defined by a specific predictive direction with a specific set of input variables.
We introduce MetaVoxel, a generative joint diffusion modeling framework that models the joint distribution over imaging data and clinical metadata by learning a single diffusion process spanning all variables. By capturing the joint distribution, MetaVoxel unifies tasks that traditionally require separate conditional models and supports flexible zero-shot inference using arbitrary subsets of inputs without task-specific retraining.
Using more than 10,000 T1-weighted MRI scans paired with clinical metadata from nine datasets, we show that a single MetaVoxel model can perform image generation, age estimation, and sex prediction, achieving performance comparable to established task-specific baselines. Additional experiments highlight its capabilities for flexible inference.
Together, these findings demonstrate that joint multimodal diffusion offers a promising direction for unifying medical AI models and enabling broader clinical applicability.

\end{abstract}

\begin{keywords}
Diffusion Model, Joint distribution, Multimodal.
\end{keywords}

\renewcommand{\thefootnote}{\fnsymbol{footnote}}
\footnotetext[1]{Data used in preparation of this article were obtained from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (adni.loni.usc.edu). As such, the investigators within the ADNI contributed to the design and implementation of ADNI and/or provided data but did not participate in analysis or writing of this report. A complete listing of ADNI investigators can be found at: \texttt{\url{http://adni.loni.usc.edu/wp-content/uploads/how_to_apply/ADNI_Acknowledgement_List.pdf}}}
\footnotetext[2]{Data used in preparation of this article were derived from BIOCARD study data, supported by grant U19-AG033655 from the National Institute on Aging. The BIOCARD study team did not participate in the analysis or writing of this report; however, they contributed to the design and implementation of the study. A listing of BIOCARD investigators may be accessed at: \texttt{\url{https://www.biocard-se.org/public/Core\%20Groups.html}}}

\section{Introduction}
Clinicians routinely face a diverse set of questions when evaluating patients: Does this individual show signs of a particular disease? What is this patient's disease risk given their age? How might this patient's imaging look in two years? How would the imaging appear if disease status were different? These questions highlight the multidimensional nature of clinical assessment, where imaging and clinical metadata together form a complex and interdependent portrait of health.

Deep learning has provided tools to help address these questions. Modern classifiers can accurately predict disease status from imaging~\cite{li2015robust,paul2017deep}. Regression models can estimate continuous attributes such as age or risk scores with impressive precision~\cite{venkadesh2021deep, gao2025brain}. Generative models can synthesize realistic medical images for data augmentation~\cite{guo2025maisi,zhang2025msrepaint} and predict disease trajectories~\cite{puglisi2025brain}. While each of these task-specific methods has advanced considerably, every task requires a separate model with a predefined set of inputs.
This contrasts with clinical reasoning, where the boundaries between questions are fluid, and reasoning often shifts from diagnosis to prognosis to hypothetical scenarios within the course of a single patient encounter.

In this work, we present MetaVoxel, a generative diffusion framework designed to capture the multidimensional nature of patients. Unlike existing approaches that learn a conditional distribution of a target variable given a predefined set of input variables, we model the joint distribution over all variables, including imaging and clinical metadata.
Although this is substantially more challenging than focusing on a single predictive direction, joint modeling provides broader opportunities.
Because the joint distribution encompasses \textbf{all possible conditionals}, a single trained diffusion model can perform diverse tasks beyond image generation. Moreover, by capturing the complete set of dependencies among variables, our model enables flexible conditioning on arbitrary subsets of information at test time. Our contributions are summarized as follows:
\begin{itemize}
    \item We introduce MetaVoxel, a multimodal diffusion model that learns the joint distribution over imaging data and clinical metadata.
    \item We develop a zero-shot inference procedure that allows MetaVoxel to access conditional distributions and perform diverse tasks using arbitrary subsets of input variables.
    \item Using more than 10,000 paired MRI–metadata samples, we demonstrate that MetaVoxel achieves competitive performance on image generation, age estimation, and sex prediction, while supporting flexible conditioning scenarios.
\end{itemize}

\section{Background and Related Work}
\label{s:related_works}
\textbf{Denoising Diffusion Probabilistic Models (DDPMs)}~\cite{sohl2015deep, ho2020denoising} are generative models that aim to learn the underlying data distribution \( p(x) \) by defining a forward diffusion process that gradually corrupts data, and a reverse process that learns to invert this corruption process.
The forward process is a fixed Markov chain that incrementally adds noise to a data sample \( x_0 \) over a sequence of discrete time steps \( t = 1, \dots, T \). In practice, Gaussian noise is the most common choice for image generation tasks, where the forward process provides a smooth trajectory from data \( x_0 \) to a latent variable \( x_T \) that follows a simple prior distribution. Each successive latent variable \( x_t \) along this trajectory follows
\begin{equation}
    q(x_t|x_{t-1}) = \mathcal{N}\!\left(x_t \mid \sqrt{1 - \beta_t^{\text{DDPM}}}\, x_{t-1},\, \beta_t^{\text{DDPM}} I \right),
    \label{e:gaussian_diffusion}
\end{equation}
where \(\beta_t^{\text{DDPM}}\in(0,1]\) controls the noise magnitude at each step.
The reverse process, parameterized by a time-conditioned neural network, learns to denoise step-by-step by predicting the mean of the conditional distribution \( p_\theta(x_{t-1}|x_t) \). The model is trained by maximizing a variational lower bound on the data likelihood. In practice, this objective simplifies to a mean squared error between the noise added to produce \( x_t \) and the noise predicted by the network from \( x_t \), averaged over randomly sampled time steps \( t \). Once trained, new samples can be generated by drawing an initial latent variable \( x_T \) from a standard Gaussian distribution and iteratively applying the learned reverse process to obtain a realistic data sample.

Several extensions have been proposed to improve the efficiency and scalability of DDPMs. Denoising Diffusion Implicit Models~(DDIMs)~\cite{songdenoising} introduce an implicit probabilistic formulation that shares the same training objective as DDPMs but enables deterministic sampling that requires significantly fewer denoising steps.
Latent Diffusion Models~(LDMs)~\cite{rombach2022high} enhance the framework by performing diffusion in a latent space learned via a variational autoencoder, achieving substantial computational efficiency for high-resolution image generation.

\textbf{Discrete diffusion models}, such as Discrete Denoising Diffusion Probabilistic Models~(D3PMs)~\cite{austin2021structured} and multinomial diffusion~\cite{hoogeboom2021argmax}, generalize the diffusion process to categorical variables. In this setting, the forward process is defined as
\begin{equation}
    q(\mathbf{x}_t|\mathbf{x}_{t-1}) = \mathrm{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_{t-1} \mathbf{Q}_t),
    \label{e:d3pm}
\end{equation}
where \(\mathrm{Cat}(\mathbf{x}; \mathbf{p})\) denotes a categorical distribution over the one-hot row vector \(\mathbf{x}\) with class probabilities given by the row vector \(\mathbf{p}\), and the transition matrices \(\mathbf{Q}_t\) control the corruption process. The reverse process typically adopts an \(\mathbf{x}_0\)-parameterization, where a time-conditioned neural network models \(p_\theta(\mathbf{x}_0|\mathbf{x}_t)\). From here, one can derive \(p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)\), allowing sampling to proceed from an initial random one-hot vector through iterative denoising, analogous to the procedure used in DDPMs.

\textbf{Conditional DDPMs} extend the diffusion framework to learn conditional distributions \( p(x|y) \), enabling sampling guided by auxiliary information \(y\). This formulation has become foundational across many tasks, such as text-to-image and structure-guided image synthesis. Prominent examples include Stable Diffusion~\cite{rombach2022high} and ControlNet~\cite{zhang2023adding}, which enable flexible and controllable image generation from text prompts or structural inputs.
In medical research, conditional diffusion models have been explored to address the data scarcity problem in AI model development~\cite{guo2025maisi} and to model disease progression~\cite{puglisi2025brain, mcmaster2025technical}.
% For instance, MAISI~\cite{guo2025maisi} generates realistic medical images conditioned on metadata or masks, enabling the creation of paired segmentation and images for downstream tasks such as image segmentation. Generating realistic synthetic images also facilitates disease progression modeling. For instance, Brain Latent Progression~\cite{puglisi2025brain, mcmaster2025technical} conditions on baseline Magnetic Resonance~(MR) scans and clinical metadata to model longitudinal brain changes.

\textbf{Unconditioned DDPMs} can also be adapted for \textbf{conditional generation} during the sampling process, as demonstrated by methods such as RePaint~\cite{lugmayr2022repaint}.
RePaint enables free-form image inpainting, where the goal is to fill in unknown regions of an image specified by an arbitrary binary mask, while maintaining fidelity in the known regions and ensuring global coherence. At each time step \(t\) during sampling, the updated sample \(x_{t-1}\) is composed by combining the re-noised known region ($x_{t-1}^{\text{known}}$), obtained through the forward process, with the denoised unknown region ($\tilde{x}_{t-1}^{\text{unknown}}$) generated by the reverse process. This process can be written as:
\begin{equation}
    x_{t-1} = m \odot x_{t-1}^{\text{known}} + (1 - m) \odot \tilde{x}_{t-1}^{\text{unknown}},
\end{equation}
where \(m\) is a binary mask indicating the known regions. This composition step, applied at every time step $t$ during sampling, preserves consistency in the known regions while generating coherent and realistic content in the unknown regions.

\section{Method}
Clinical data inherently involve multiple interdependent variables that extend beyond imaging alone. While existing diffusion-based models primarily focus on images, we treat all variables as equally important components. 
For concreteness, we consider three specific variables in this work: T1-weighted~(T1w) magnetic resonance~(MR) image $I$, age $A$, and sex $S$. These variables were selected because they represent the types of information most commonly available in existing datasets, spanning both high-dimensional imaging data and scalar metadata, encompassing both continuous and categorical data types. MetaVoxel naturally extends to additional variables with minimal modification.
The remainder of this section is organized as follows. Section~\ref{ss:training} introduces the MetaVoxel framework and its formulation for learning the joint distribution over $(I, A, S)$. A diagram is provided in Figure~\ref{f:diagram}.
Section~\ref{ss:inference} describes how MetaVoxel enables flexible inference using arbitrary subsets of inputs to
perform image generation, regression, and classification. 

\begin{figure}[!t]
    \centering
    \includestandalone[width=0.9\textwidth]{diagram}
    \caption{Schematic of the MetaVoxel diffusion framework. Continuous variables such as the image $I$ and age $A$ undergo Gaussian diffusion, while discrete variables such as sex $S$ follow a discrete diffusion process. A single denoising network $f_\theta$ models the shared reverse process.\label{f:diagram}}
\end{figure}

\subsection{Learning the Joint Distribution}
\label{ss:training}

\textbf{Encoding.} To model the joint distribution, MetaVoxel must encode each variable in a form suitable for inclusion in a single joint diffusion process, while allowing each variable to retain an appropriate representation.
For the imaging variable $I$, which consists of high-resolution 3D MR volumes, dimensionality reduction is essential to make diffusion modeling computationally feasible. We adopt a KL-regularized variational autoencoder~(VAE) in which an encoder network $\mathcal{E}$ maps an image $i\in\mathbb{R}^{H\times W\times D}$ to a compact latent representation
$z^I = \mathcal{E}(i)$,  $z^I\in\mathbb{R}^{\frac{H}{8}\times \frac{W}{8}\times\frac{D}{8}}$,
and a decoder $\mathcal{D}$ reconstructs the image from this latent space.
A light KL penalty is applied to keep the distribution of $z^I$ close to a standard normal, which prevents arbitrarily high-variance latent spaces.
We implement this VAE by extending the 2D architecture used in LDM~\cite{rombach2022high} to full 3D, and we train it using the same loss formulation employed in MAISI~\cite{guo2025maisi}.

Scalar metadata, in contrast, do not require dimensionality reduction for computational reasons. We apply a lightweight encoding process to place them in numerically stable ranges and representations compatible with MetaVoxel’s diffusion process.
Continuous variables such as age are linearly scaled to lie approximately within the interval $[-1, 1]$, ensuring numerical stability similar to that imposed on the image latent space. 
Categorical variables such as sex are converted to one-hot vectors.
% , which aligns with the representation in Equation~\ref{e:d3pm}.
We denote these encoded scalar variables as $z^A$ and $\mathbf{z}^S$, respectively.

\textbf{Joint Diffusion.} 
Given the encoded representations, MetaVoxel applies a forward diffusion process to each variable using a corruption mechanism suited to its representation.
For image latents $z^I$ and continuous scalar variables $z^A$, the forward process follows the standard Gaussian diffusion defined in Equation~\ref{e:gaussian_diffusion}.
Categorical variables $\mathbf{z}^S$ follow the discrete corruption process defined in Equation~\ref{e:d3pm}, for which we adopt the uniform transition matrix
\begin{equation}
    \mathbf{Q}_t = (1-\beta_t^{\text{D3PM}})\,\mathbf{I} + \frac{\beta_t^{\text{D3PM}}}{K}\,\mathbbm{1}\mathbbm{1}^{\top},
\end{equation}
where $K$ is the number of categories, $\mathbbm{1}$ is the $K$-dimensional all-ones column vector, and $\beta_t^{\text{D3PM}}\in[0,1]$ follows a cosine schedule.
All variables share the same time index $t$, so each diffusion step produces a single noisy tuple $z_t = (z_t^I, z_t^A, \mathbf{z}_t^S)$.

MetaVoxel learns a unified reverse process that jointly denoises all variables.
A single time-conditioned denoising network $f_\theta(z_t, t)$ takes the full noisy tuple $z_t$ as input and predicts the quantities required to reverse their respective corruption processes: Gaussian noise for continuous components and logits for categorical components.
During training, the model is optimized using a combined objective that mirrors the simplified evidence lower bound~(ELBO) formulations for both Gaussian and discrete diffusion.
For the image latent and continuous scalar variables, ELBO collapses into mean-squared error losses between the true and predicted noise.
For categorical variables, the ELBO under the $x_0$-parameterization reduces to a cross-entropy loss between the true class label and the logits predicted by $f_\theta(z_t, t)$.
The overall training objective is therefore
\begin{align}
\mathcal{L}
&=
\mathbb{E}_{z_0,\, t,\, \epsilon^I \sim \mathcal{N}(0, \mathbf{I}),\, \epsilon^A \sim \mathcal{N}(0, 1)}
\left[
    \frac{1}{\dim(z^I)}
    \left\|
        \epsilon^I
        -
        f_\theta(z_t, t)^{I}
    \right\|_2^2
    +
    \left\|
        \epsilon^A
        -
        f_\theta(z_t, t)^{A}
    \right\|_2^2
\right]
\nonumber\\
&+
\mathbb{E}_{z_0,\, t,\, z_t \sim q(\cdot|z_0)}
\left[
    \mathrm{CE}\left(
        \mathbf{z}_0^{S},
        \mathrm{softmax}(f_\theta(z_t, t)^{S})
    \right)
\right],
\end{align}
where $z_t$ is obtained from $z_0$ via the forward processes (with Gaussian noise $\epsilon^I$ and $\epsilon^A$ for the image and age components), $\dim(z^I)$ denotes the number of elements in the image latent, the superscripts on $f_\theta(z_t, t)$ indicate the component of the network output corresponding to each variable, and $\mathrm{CE}(\cdot,\cdot)$ refers to the cross-entropy loss.

The denoising network $f_\theta$ must process all variables jointly and produce outputs aligned with the structure of each variable.
To accommodate this multimodal input–output structure, we adapt the U-Net architecture used in LDM into a 3D model.
The scalar variables are integrated into the model through a combination of input-channel concatenation and cross-attention~(CA). At the input, discrete scalar variables are represented as single scalars rescaled to the interval $[-1, 1]$, consistent with continuous variables.
On the output side, the standard U-Net decoder produces the prediction for image variables, while additional lightweight decoding heads are attached to the bottleneck features for the scalar variables~(see Figure~\ref{f:diagram}). 
Each scalar head consists of two (GroupNorm $\rightarrow$ SiLU $\rightarrow$ Conv) blocks with a skip connection, followed by global average pooling to produce a single output value.

\begin{figure}[!t]
    \centering
    \includestandalone[width=0.9\textwidth]{unconditional_sampling}
    \caption{Unconditional sampling in MetaVoxel. At time step $t=T=1000$, Gaussian noise is sampled for image latents and continuous variables, and categorical noise for categorical variables.
    The variables are then jointly denoised from $t=T$ to $t=0$ to generate coherent samples from the learned joint distribution. All images, ages, and sexes shown are decoded from their encoded representations before visualization. Additional examples can be found in the Appendix~(Figure~\ref{f:image}).\label{f:unconditional}}
\end{figure}


\subsection{Zero-shot Inference with Arbitrary Conditioning}
\label{ss:inference}

A trained MetaVoxel model can be \textit{unconditionally sampled} by drawing Gaussian or categorical noise for each variable at timestep $T$, and iteratively applying the denoising network $f_\theta$ until reaching $t=0$. Unlike conventional diffusion models that focus solely on image generation, MetaVoxel can generate coherent synthetic patient profiles from the joint distribution $p(I, A, S)$, as shown in Figure~\ref{f:unconditional}.

A distinctive strength of MetaVoxel is its ability to perform flexible zero-shot inference for a broad spectrum of tasks with arbitrary subsets of inputs.
To enable this, we reinterpret the RePaint strategy described in Section~\ref{s:related_works}: just as RePaint performs conditional image generation by specifying a binary mask that marks which pixels are fixed, we can further designate any subset of MetaVoxel’s variables as ``known'' conditions.
At each denoising step, these known variables are overwritten with their re-noised values, ensuring that they remain \textbf{fixed} throughout the sampling trajectory, while the remaining variables evolve according to the learned reverse process.
This keeps the synthesized variables consistent with the conditioning, analogous to how RePaint produces inpainted regions consistent with the known region.
Although MetaVoxel learns the joint distribution $p(I, A, S)$, different downstream tasks can be realized simply by choosing which variables to fix during sampling. Image inpainting is achieved by fixing pixel regions of the image variable\footnote{see examples in Appendix (Figure~\ref{f:inpaint})}; conditional image generation by fixing variables such as age or sex; regression tasks (\eg age estimation) by fixing any combination of image and sex; and classification tasks (\eg sex prediction) by fixing any combination of image and age. 
As more variables are included in the joint model, the range of possible zero-shot tasks naturally expands. For example, when multiple imaging modalities are present, fixing one modality and sampling the other enables image-to-image translation.




\section{Experiments}
\label{s:exp}
MetaVoxel introduces several architectural and functional components that are individually rare and, in combination, not seen in existing work. 
% Architecturally, it implements a unified diffusion process that jointly models continuous variables through Gaussian diffusion and discrete variables through discrete diffusion.
% Functionally, a single trained MetaVoxel model can perform image generation, regression, and classification without task-specific retraining. In addition, it supports flexible zero-shot conditioning at inference, enabling the model to respond to arbitrary subsets of known variables.
While prior efforts have addressed isolated components, to our knowledge none provides a unified framework with the combined capabilities of MetaVoxel. As a result, a direct head-to-head comparison with a single baseline is not feasible. Given this context, our experiments were structured around three key questions:
\begin{enumerate}
    \item How does introducing discrete diffusion influence the quality of image generation compared to image-only diffusion baselines?
    \item How does MetaVoxel perform on regression versus specialized regression models?
    \item How does MetaVoxel perform on classification versus specialized classification models?
\end{enumerate}

\textbf{Datasets}: We compiled a cohort of 10,154 T1-weighted~(T1w) brain MR scans with age and sex information. These scans were drawn from nine datasets: ADNI~\cite{jack2008alzheimer}, BIOCARD~\cite{sacktor2017biocard}, BLSA~\cite{shock1984normal}, HCPA~\cite{bookheimer2019lifespan,harms2018extending}, ICBM~\cite{mazziotta2001probabilistic}, NACC~\cite{beekly2007national}, OASIS-3~\cite{lamontagne2019oasis}, ROS/MAP/MARS~\cite{bennett2018religious,l2012minority}, and WRAP~\cite{johnson2018wisconsin}.
Additional information on ADNI and NACC can be found in Appendix~\ref{app:dataset}. 
\cbstart
The age and sex distributions of the dataset are shown in the Appendix~(Figure~\ref{app:data_demo}). 
\cbend
Only cognitively unimpaired individuals were included. All images underwent standardized preprocessing consisting of N4 bias-field correction~\cite{tustison2010n4itk} and skull stripping using HD-BET~\cite{isensee2019hbm}. Each subject's baseline scan was then rigidly registered using ANTs~\cite{avants2008symmetric} to an MNI template~\cite{fonov-unbiased-2009}, and, for subjects with multiple imaging sessions, all follow-up scans were subsequently rigidly registered to the registered baseline. Final preprocessed images underwent manual quality assurance by the authors using a custom Python-based application interface~\cite{kim2025scalable}.
We then performed a \textbf{subject-level} data split, allocating approximately 89\% of subjects for training, 1\% for validation, and 10\% for testing, resulting in 9,078 cases for training, 108 for validation, and 968 for testing.

We trained a single MetaVoxel model to jointly model the T1w image, age, and sex using the training set. \textbf{Unless otherwise stated, all MetaVoxel results were generated from this single model.}
The validation set was used to select the optimal training epoch.
To accelerate the sampling process, we used DDIM with 50 steps for continuous variables and k-step sampling ($k=20$) for discrete variables.
All baseline methods used the same data split and the same protocol for model selection and evaluation.
\cbstart
For the VAE, we trained on the training set and then froze the weights during MetaVoxel training.
\cbend



\subsection{Impact of Discrete Diffusion on Image Generation}

Discrete diffusion is essential for incorporating categorical variables, such as sex.
However, introducing a discrete diffusion pathway alongside Gaussian diffusion may influence image generation quality.
Because joint models that combine Gaussian and discrete diffusion over imaging and metadata have, to our knowledge, not been explored in prior work, the effect of combining these two diffusion processes remains unknown.
To evaluate how discrete diffusion affects image generation, we compared MetaVoxel to two baselines: (a) LDM: a latent diffusion model trained solely on T1w images, and (b) Continuous-Sex MetaVoxel: a variant of MetaVoxel in which the sex variable is treated as an extra continuous scalar and modeled using Gaussian diffusion.

We drew 100 samples from each of the synthetic and real distributions and evaluated sample quality using the Fr\'echet Inception Distance~(FID). FID measures the distance between the distribution of generated images and the distribution of real images from the held-out test set by comparing their activations in a pretrained network.
To account for the 3D nature of the data, we computed FID using the approach implemented in MAISI~\cite{guo2025maisi}: each 3D T1w scan was sliced along the axial, coronal, and sagittal planes, FID was computed separately for each plane using the corresponding slice-level feature distributions, and the three values were averaged to obtain the final FID.
\cbstart
Using the same strategy, we additionally computed the Maximum Mean Discrepancy (MMD) between generated and real images. Specifically, MMD was evaluated on slice-level feature distributions for each anatomical plane using a Gaussian radial basis function kernel, and the resulting axial, coronal, and sagittal MMD values were averaged to obtain a single 3D MMD score.

Although LDM converges in roughly one-third of the training time required by MetaVoxel and Continuous-Sex MetaVoxel, the three models achieve comparable FID and MMD~(Table~\ref{t:results}), with no degradation observed when adding either continuous or discrete diffusion pathways.
\cbend
Given that the continuous-sex variant cannot naturally extend beyond binary categories, discrete diffusion offers a practical and scalable way to represent categorical data within the joint model. These results indicate that discrete diffusion can be integrated into a joint generative framework without compromising image sample quality.


\begin{table}[!t]
    \centering
    \caption{Quantitative results across image generation, regression, and classification tasks. Bold indicates the best result in each column. A dash (--) indicates that a model, under its specific training setup, is not applicable to or cannot perform the corresponding task.\label{t:results}}
    \adjustbox{max width = 0.8 \textwidth}
    {{\renewcommand{\arraystretch}{1.5}
        \begin{tabular}{r C{0.15\textwidth} C{0.15\textwidth} r C{0.2\textwidth} r C{0.2\textwidth}
        }
            \toprule
            & \multicolumn{2}{c}{\textbf{Image Generation}} & \quad & \textbf{Age Estimation} & \quad & \textbf{Sex Prediction}
            \\
            \cmidrule(lr){2-3} \cmidrule(lr){5-5} \cmidrule(lr){7-7}
            \textbf{Method} & FID $\bolddownarrow$  & MMD $\bolddownarrow$ &  & MAE $\bolddownarrow$ & & ACC $\bolduparrow$
            \\
            \cmidrule(lr){1-1} \cmidrule(lr){2-3} \cmidrule(lr){5-5} \cmidrule(lr){7-7}
            LDM~\cite{rombach2022high} & 10.94 & \textbf{0.15} && -- && --
            \\
            \cmidrule(lr){1-1} \cmidrule(lr){2-3} \cmidrule(lr){5-5} \cmidrule(lr){7-7}
            \rowcolor{gray!10}
            3D-DenseNet(MSE)~\cite{huang2017densely}& \multicolumn{2}{c}{--} && $\mathbf{3.96\pm2.92}$ && --
            \\
            3D-ViT(MSE)~\cite{dosovitskiy2020image} & \multicolumn{2}{c}{--} && $7.99\pm6.72$ && --
            \\
            \rowcolor{gray!10}
            BRAID-T1w~\cite{gao2025brain} & \multicolumn{2}{c}{--} && $4.01\pm3.31$ && --
            \\
            \cmidrule(lr){1-1} \cmidrule(lr){2-3} \cmidrule(lr){5-5} \cmidrule(lr){7-7}
            3D-DenseNet(CE)~\cite{huang2017densely}& \multicolumn{2}{c}{--} && -- && \textbf{0.884}
            \\
            \rowcolor{gray!10}
            3D-ViT(CE)~\cite{dosovitskiy2020image} & \multicolumn{2}{c}{--} && -- && 0.788
            \\
            \cmidrule(lr){1-1} \cmidrule(lr){2-3} \cmidrule(lr){5-5} \cmidrule(lr){7-7}
            MetaVoxel & \textbf{10.84} & 0.16 && $4.50\pm3.46$ && $0.815$
            \\
            \rowcolor{gray!10}
            Continuous-Sex MetaVoxel & 11.18 & 0.16 && $4.95\pm4.16$ && $0.855$
            \\
        \bottomrule
        \end{tabular}
    }}
\end{table}


\subsection{Regression and Classification}
Although diffusion models have recently matured as powerful image generative methods, there has been little incentive to develop dedicated conditional diffusion models for regression or classification. These tasks already have well-established solutions, and in most settings the diffusion sampling process offers no clear advantage.
MetaVoxel sidesteps the need for dedicated conditional models when accessing
conditional distributions such as $p(A|I,S)$ or $p(S|I,A)$.
To compare MetaVoxel with established approaches, we examine its performance on age estimation~(regression) and sex prediction~(classification).

\textbf{Age estimation}: we sampled MetaVoxel three times with $I$ and $S$ treated as known variables during the sampling process. The predicted age was obtained by averaging the three sampled values. We then reported the mean absolute error~(MAE) between predicted and true age for MetaVoxel and the following comparison methods:
(a) 3D-DenseNet(MSE): A 3D DenseNet that receives the T1w image with sex concatenated at the input channel level and is trained using mean squared error~(MSE) loss; (b) 3D-ViT(MSE): A 3D Vision Transformer that receives the T1w image with sex concatenated at the input channel level and is trained with MSE loss;
(c) BRAID-T1w: A ResNet-based architecture that incorporates the T1w image and sex information at the feature level and is trained with MSE loss.

\textbf{Sex prediction}: we sampled MetaVoxel three times with $I$ and $A$ as known variables in the sampling process. The predicted sex was determined by majority vote over the three samples.
We then reported classification accuracy for MetaVoxel and the following methods: (a) 3D-DenseNet(CE): A 3D DenseNet that receives the T1w image with age concatenated at the input channel level and is trained with cross-entropy~(CE) loss; (b) 3D-ViT(CE): A 3D Vision Transformer that receives the T1w image with age concatenated at the input channel level and is trained with CE loss.

Table~\ref{t:results} shows that MetaVoxel's performance on both age estimation and sex prediction is well within the range of established discriminative models. Its primary drawback is computational cost: generating a single prediction requires roughly $30$ seconds on an NVIDIA A6000 GPU, whereas the baseline models produce outputs in under a second. Despite this disadvantage, the results demonstrate that a trained MetaVoxel model can be used for regression and classification without any task-specific retraining. Moreover, because MetaVoxel models the entire joint distribution rather than only conditional means~(in regression) or decision boundaries~(in classification), it naturally supports flexible conditioning scenarios and sample-based uncertainty that are not directly available to conventional discriminative models.
To illustrate this, we use age estimation as an example. With the same MetaVoxel model, we generated age samples under every subset of the observed variables~(both the T1w scan and sex, the T1w scan only, sex only, or none), thereby accessing $p(A|I,S)$, $p(A|I)$, $p(A|S)$, and $p(A)$ without retraining. For each conditioning choice, we used the average of three samples of $A$ to obtain the predicted age, and compared it with the true age. We summarize the results in Table~\ref{t:condition}. The MAE depends almost entirely on whether the image is provided. Adding sex as an additional variable offers no measurable benefit, and using only sex or no information yields MAEs comparable to using the population mean of the training data as the prediction. We also find that, when the image is absent, the mean sample variance increases substantially, reflecting greater uncertainty that closely tracks the rise in MAE.


\begin{table}[!t]
    \centering
    \caption{Flexible conditioning of MetaVoxel. MetaVoxel is evaluated by sampling age under four conditioning settings: $(I, S)$: conditioned on both the T1w MR scan and sex; $(I)$: conditioned only on the T1w MR scan; $(S)$: conditioned only on sex; and $(\varnothing)$: the unconditional setting with no fixed variables. Population Mean uses the mean age of the training data as the prediction.\label{t:condition}}
    \adjustbox{max width = 0.9 \textwidth}
    {{\renewcommand{\arraystretch}{1.5}
        \begin{tabular}{r C{0.2\textwidth} C{0.2\textwidth}C{0.2\textwidth} C{0.2\textwidth}
        C{0.25\textwidth}
        }
            \toprule
            &
            MetaVoxel$(I, S)$ & MetaVoxel$(I)$  & MetaVoxel$(S)$ & MetaVoxel$(\varnothing)$
            & Population Mean
            \\
             \cmidrule(lr){2-2} \cmidrule(lr){3-3} \cmidrule(lr){4-4}
            \cmidrule(lr){5-5}
            \cmidrule(lr){6-6}
            MAE & $4.50\pm3.46$ & $4.57\pm3.54$ & $11.34\pm9.15$ & $11.27\pm8.96$ & $10.76\pm7.76$
            \\
            \rowcolor{gray!10}
            Mean Sample Variance & $9.76$ & $9.88$ & $62.89$ & $64.61$ & --
            \\
        \bottomrule
        \end{tabular}
    }}
\end{table}

\section{Conclusion}

We introduced MetaVoxel, a multimodal diffusion framework that learns a single joint generative model over imaging and clinical metadata. By modeling the full joint distribution, MetaVoxel unifies tasks that traditionally require separate conditional architectures. Experiments on more than $10{,}000$ T1-weighted magnetic resonance scans demonstrate that this single generative model can perform image generation, age estimation, and sex prediction with performance comparable to established baselines, despite relying solely on diffusion-based sampling at inference.
Beyond matching task-specific models, MetaVoxel enables zero-shot inference from arbitrary subsets of imaging and metadata.
These results highlight the potential of joint multimodal diffusion modeling as a foundation for general-purpose medical AI systems.


\clearpage  % Acknowledgements, references, and appendix do not count toward the page limit (if any)
% Acknowledgments---Will not appear in anonymized version
\input{acknowledgement}


\bibliography{midl26_41}

\input{appendix}
% \appendix
% \section{test}
% \lipsum{10}

\end{document}
