%\documentclass[anon]{midl} % Anonymized submission
\documentclass{midl} % Include author names

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{mwe}% to get dummy images
\usepackage{bbding}
\usepackage{amsfonts}
\usepackage{mathtools}
\usepackage{float}
\usepackage[normalem]{ulem}
\usepackage{multirow}
\graphicspath{{figures/}}
\usepackage{color}
%\usepackage[svgnames]{xcolor}

\jmlryear{2024}
\jmlrworkshop{Full Paper -- MIDL 2024}
\jmlrvolume{-- nnn}
\editors{Accepted for publication at MIDL 2024}

\title[Conditional Generation of 3D Brain Tumor ROIs]{Conditional Generation of 3D Brain Tumor ROIs via VQGAN and Temporal-Agnostic Masked Transformer}


\DeclarePairedDelimiter{\ceil}{\lceil}{\rceil}

 \midlauthor{\Name{Meng Zhou\nametag{$^{1,2}$}} \Email{simonzhou@cs.toronto.edu}\\
  \Name{Farzad Khalvati\nametag{$^{1,2,3}$}} \Email{farzad.khalvati@utoronto.ca}\\
  \addr $^{1}$ Neurosciences \& Mental Health Research Program, SickKids Research Institute, Toronto, ON, Canada \\
  \addr $^{2}$ Department of Computer Science, University of Toronto, Toronto, ON, Canada \\
  \addr $^{3}$ Department of Medical Imaging, University of Toronto, Toronto, ON, Canada}

\begin{document}

\maketitle

\begin{abstract}
Neuroradiology studies often suffer from a lack of sufficient data to properly train deep learning models. Generative Adversarial Networks (GANs) can mitigate this problem by generating synthetic images to augment training datasets. However, GANs are sometimes unstable and struggle to produce high-fidelity images. Diffusion Probabilistic Models are an alternative, but they require extensive computational resources. Additionally, most existing generative models are designed to generate entire image volumes rather than regions of interest (ROIs) such as the tumor region. Research shows that it is easier to classify tumor types based on ROIs than on entire image volumes. To this end, we present a class-conditioned ROI generation framework that combines a vector-quantization GAN and a class-conditioned masked Transformer to generate high-resolution and diverse 3D brain tumor ROIs. We further propose a temporal-agnostic masking strategy to effectively learn relationships between semantic tokens in the latent space. Our experiments demonstrate that the proposed method can generate high-quality 3D MRIs of brain tumor regions for both low- and high-grade glioma (LGG/HGG) in the BraTS 2019 dataset. Using the generated data for pre-training, our approach outperforms several baselines in a downstream brain tumor type classification task. Our proposed method has the potential to facilitate accurate diagnosis of rare brain tumors using MRI-based machine learning models.
\end{abstract}

\begin{keywords}
Generative Adversarial Networks, Transformer, Image Generation, 3D MRI, Data Augmentation
\end{keywords}

\section{Introduction}

Gliomas are the most frequent primary adult brain tumors of the central nervous system \cite{menze2014multimodal,bakas2017advancing}. Among gliomas, high-grade glioma (HGG) accounts for the majority of cases, whereas low-grade glioma (LGG) is less common. For both, a commonly used diagnostic technique is multi-parametric Magnetic Resonance Imaging (MRI) with different sequences such as T1-weighted, T2-weighted, and Fluid Attenuated Inversion Recovery (FLAIR) \cite{menze2014multimodal}. Each sequence provides distinct biological information about the tumor, aiding radiologists in determining the tumor type. However, distinguishing between HGG and LGG remains challenging, and misdiagnosis may lead to suboptimal treatment and patient outcomes \cite{mzoughi2020deep}.

In recent years, deep learning-based methods have proven effective for adult brain tumor classification using brain MR images \cite{ge2020deep,hao2021transfer,namdar2022tumor,tandel2020multiclass}. However, the requirement for large training datasets poses challenges in medical imaging, especially for rare diseases such as LGG, leading to potential overfitting and poor generalization to unseen datasets. Several works aim to mitigate this data scarcity and imbalance problem. One line of work uses transfer learning, pre-training models on large datasets (e.g., ImageNet) and then fine-tuning them on domain-specific datasets \cite{ghazal2022alzheimer,tak2023noninvasive,ullah2022effective}. Another line of work synthesizes MRIs using Generative Adversarial Networks (GANs) \cite{volokitin2020modelling,kwon2019generation,sun2020adversarial,xia2020pseudo} or diffusion-based methods \cite{khader2022medical,peng2022generating,dorjsembe2023conditional,sanchez2022healthy} to alleviate the need for extensive datasets. However, GANs for image generation can be unstable, produce blurry images, and suffer from mode collapse \cite{kwon2019generation}. As an alternative, Diffusion Probabilistic Models have been proposed and have demonstrated superior performance over GANs \cite{muller2022diffusion}, but they are extremely computationally expensive when synthesizing full-resolution MRIs, posing challenges in both the training and inference stages. More recently, autoregressive transformer models have attracted increasing attention in image generation tasks \cite{esser2021imagebart,esser2021taming,huang2023not}. The key idea behind such models is to obtain discretized feature maps from a Vector-Quantization GAN (VQGAN) model and then use a transformer model to learn their compositions. 
Autoregressive transformers have been extended to medical images \cite{pinaya2023generative,tudosiu2022morphology,zhou2023generating} to unconditionally generate brain MR images. However, one limitation of these models is the lack of conditioning: separate models need to be trained for different classes or tasks, which is time-consuming and resource-intensive. Therefore, a conditional image generation model is important for practical applications in real clinical settings. Moreover, several works \cite{sajjad2019multi,mzoughi2020deep,srinivasan2023grade} have demonstrated the effectiveness of using the tumor region of interest (ROI) for classifying tumor types, because ROIs exclude irrelevant information in the whole image that may negatively affect the results. Hence, this work aims to generate brain MRI tumor ROIs conditioned on their pathology labels.
To this end, we introduce the first class-conditional generation framework for synthesizing 3D brain tumor MRI ROIs. Our model builds upon previous work \cite{zhou2023generating} and extends it to a conditional generation paradigm. Our framework has three modules: a 3D-VQGAN image encoder that extracts high-level feature maps while concurrently learning an importance score for each region in the feature maps; an Exponential Moving Average (EMA) codebook with $l_2$-normalized lookup for converting feature maps into discrete semantic tokens \cite{peng2022beit}; and a temporal-agnostic masked transformer that learns the relationships between discrete tokens. We evaluated our proposed method on the BraTS 2019 dataset and demonstrated superior performance over several baselines in both image generation quality and the downstream HGG vs.\ LGG classification task. \textbf{Our contributions are as follows:} \textbf{(1). }We propose the first image generation framework for different tumor types based on a given class label. \textbf{(2). }We use a \textit{classifier-guidance approach} to learn the importance score for each region in the encoded feature maps. \textbf{(3). }We propose a novel \textit{temporal-agnostic hybrid masking strategy} which uses the importance scores to select the tokens to mask and masks blocks of neighboring tokens to prevent information leakage. \textbf{(4). }Experiments show our proposed method outperforms several baselines in both image generation quality and the downstream classification task. 

\section{Materials and Methods}
\subsection{Model Architecture} \label{model_arc}
We adopted and extended the VQGAN \cite{esser2021taming} and recently proposed 3D-VQGAN \cite{zhou2023generating} with some modifications detailed below for class-conditional generation of brain tumor ROIs. 

\begin{figure}[htbp]
	\floatconts
    {fig:model}
	{\vspace{-0.8cm}\caption{Detailed overview of the proposed method. Our method consists of two stages. \textit{Top: }a 3D-VQGAN model that encodes 3D inputs, generates an importance score for each region, and quantizes the encoded features into discrete tokens. \textit{Bottom left: }a class-conditional masked transformer that captures long-range dependencies via masked token modeling based on the importance scores and class label information. \textit{Bottom right: }comparison between random masking and our proposed masking strategy.}}
	{\includegraphics[width=0.78\linewidth]{figures/main_all.png}}
\end{figure}

\noindent \textbf{Stage 1. 3D-VQGAN:} The first stage is shown at the top of \figureref{fig:model}. In this stage, we train all the modules shown through a reconstruction task to learn an efficient data representation. Our encoder, decoder, and discriminator follow the same design as in \cite{zhou2023generating}, except that we replace batch normalization with group normalization to stabilize training with small batch sizes \cite{wu2018group}.

\noindent \textbf{Importance Score Map:} We use a lightweight scoring network $f$ before quantization to assign an importance score to each region in the encoded feature maps. Let $z_e \in \mathbb{R}^{H\times W \times Dp \times n_z}$ denote the encoded feature map, where $H, W, Dp, n_z$ denote the height, width, depth, and the number of feature maps, respectively. Then, for each region $r_i \in \mathbb{R}^{n_z}$, its score is defined as $s_i = f(r_i)$, where $i = 1,\dots,H \times W \times Dp$. The larger the score $s_i$, the more important the feature region $r_i$. To learn $f$, we use a \textit{classifier-guidance} approach: an auxiliary MLP-based classifier $f_{cls}$ placed after $f$ classifies tumor types (here, HGG vs.\ LGG). We hypothesize that regions with higher scores (i.e., important regions) are key to differentiating between the two tumor types, and thus, by optimizing both $f$ and $f_{cls}$, the model can identify important feature regions that are specific to either LGG or HGG tumors.

\noindent \textbf{Quantization:} In the quantization step, the latent feature maps are quantized by replacing each feature vector with its closest vector in the codebook $C$. Formally, we train a learnable codebook $C = \{c_i\}_{i=1}^{K}$ with $K$ entries and, via nearest-neighbor search in $C$, transform the feature vectors of $z_e$ into $L := H\times W \times Dp$ discrete tokens $c_q, q \in [1,K]$, where each token $c_q$ is associated with an embedding vector $c_z \in \mathbb{R}^{n_z}$. We use $l_2$ normalization for codebook lookup, as done in \cite{yu2021vector}. Finally, we stack the $L$ quantized feature vectors back into the original latent shape and feed them into the decoder $D$ to produce reconstructed images.

\noindent\textbf{Stage 2. Class-conditional Masked Transformer:} In this stage, we propose a novel \textbf{temporal-agnostic hybrid masking} strategy based on the importance scores computed in Stage 1, inspired by DropBlock \cite{ghiasi2018dropblock} and BERT \cite{devlin2018bert}. The 3D images are represented in the latent space; the encoder $E$, decoder $D$, and codebook $C$ are frozen, and only the transformer is trained. The encoded feature map $z_e$ of size $H \times W \times Dp \times n_z$ is quantized into a set of $L$ discrete tokens, where $L = H \times W \times Dp$. We first set the masking ratio $\alpha$, which determines the total number of tokens to be masked, $N = L \times \alpha$. This budget is split equally into $N_1$ and $N_2$, the numbers of important and unimportant tokens to be masked, respectively, according to their importance scores. Let $\mathbf{Y} = \{y_i\}_{i=1}^{L}$ be the raster-scan linearized discrete tokens, where each $y_i$ is associated with its importance score $s_i$. Sorting the tokens by importance score in descending order yields $\mathbf{Y}^{'}$. For the $N_1$ important tokens, we randomly sample $\ceil{\frac{N_{1}}{2}}$ tokens from the \textbf{\textit{top-k = 25\%}} of $\mathbf{Y}^{'}$. For each selected token, we also mask its spatial or temporal neighboring tokens. Unlike random masking, this \textit{blockwise masking} around each selected token prevents information leakage from neighbors, enhancing the model's ability to learn important tokens and preventing short-cut learning. The special \texttt{[MASK]} token is used to mask out these important tokens and their associated blocks. For the $N_2$ unimportant tokens, we randomly sample them from the remaining $(1-k)$ fraction of $\mathbf{Y}^{'}$ and replace them with randomly selected tokens from the codebook $C$. Importantly, we ensure that the set of masked important tokens $\mathcal{S}_1$ and the set of replaced unimportant tokens $\mathcal{S}_2$ are non-overlapping ($\mathcal{S}_1 \cap \mathcal{S}_2 = \emptyset$). 
Let $\mathbf{M} = \{m_i\}_{i=1}^{L}$ be the mask for the discrete tokens, where $m_i = 1$ if token $i$ is \textit{unmasked} and $m_i = 0$ if token $i$ is \textit{masked out}. Finally, we prepend a class label indicating an HGG or LGG sample to the start of each index sequence. During training, the objective is to reconstruct the masked tokens from the unmasked ones. The noise introduced by the masked-out tokens, in particular the randomly replaced ones, is hypothesized to enhance the transformer model's ability to learn relationships between semantic tokens, improving overall robustness. The proposed masking strategy and how it differs from random masking are depicted in the bottom right of \figureref{fig:model}.

\noindent \textbf{Classification: }For the downstream classification task between LGG and HGG tumor types, we use a standard 3D ResNet-50 model \cite{hara2017learning} that takes 3D tumor ROIs as inputs and outputs two class probabilities for two tumor types.

% \vspace{-0.3cm}
\subsection{Loss Function}

We employ the same loss function as in \cite{zhou2023generating}. In the first stage, we use a combination of the pixel difference loss ($\mathcal{L}_{pixel}$), perceptual loss ($\mathcal{L}_{perp}$) \cite{johnson2016perceptual}, GAN-based feature matching loss ($\mathcal{L}_{match}$) \cite{ge2022long}, 3D image gradient loss ($\mathcal{L}_{grad}$), codebook loss ($\mathcal{L}_{codebook}$) \cite{esser2021taming}, and the discriminator loss ($\mathcal{L}_{Dis}$). See \equationref{l1perpmatchgrad} and \equationref{discodebook} for details.

\begin{equation} \label{l1perpmatchgrad}
    \begin{gathered}
    \mathcal{L}_{pixel} = \|x - \hat{x}\|_1, \
    \mathcal{L}_{perp} = \sum_{j=1}^{6}\|f^{i}(x_{j}) - f^{i}(\hat{x}_{j})\|_{2}^{2}, \ \mathcal{L}_{match} = \|f_{Dis}^{i}(x) - f_{Dis}^{i}(\hat{x})\|_{1}, \\
    \mathcal{L}_{grad} = \|\nabla(A(x))-\nabla(A(\hat{x}))\|_2^2
        + \|\nabla(R(x))-\nabla(R(\hat{x}))\|_2^2
        + \|\nabla(S(x))-\nabla(S(\hat{x}))\|_2^2
    \end{gathered}
\end{equation}

\begin{equation} \label{discodebook}
    \begin{gathered}
        \mathcal{L}_{Dis} = \mathbb{E}_{x\sim p_{d}}[\max(0,1-\mathit{Dis}(x))] + \mathbb{E}_{\hat{x}\sim p_{\hat{d}}}[\max(0,1+\mathit{Dis}(\hat{x}))], \\
        \mathcal{L}_{codebook} = \|sg[E(x)]-c_z\|_2^2 + \beta\|sg[c_z]-E(x)\|_2^2
    \end{gathered}
\end{equation}
where $x$ is the original image and $\hat{x}$ is the reconstructed image, $\nabla(\cdot)$ computes the $x$- and $y$-direction gradients of each 2D slice, and $A(x), R(x), S(x)$ denote slicing along the Axial, Coronal, and Sagittal planes, respectively. Additionally, we use the standard cross-entropy loss $\mathcal{L}_{ce}$ between class logits and class labels for our auxiliary classifier $f_{cls}$. Aggregating all the loss terms yields the objective in Equation~\eqref{loss_obj} for the first stage of the framework:

\begin{equation}
    \begin{gathered}
    \min_{E,D,C}(\max_{Dis}(\mathcal{L}_{Dis})) \\
    \min_{E,D,C}\; C_1(\lambda_1\mathcal{L}_{pixel} + \lambda_2\mathcal{L}_{perp} + \lambda_3\mathcal{L}_{match} + \lambda_4\mathcal{L}_{grad} + \lambda_5\mathcal{L}_{codebook}) + C_2\mathcal{L}_{ce} \label{loss_obj}
    \end{gathered}
\end{equation}
where $\lambda_i$, $i\in\{1,\dots,5\}$, are weighting factors between the loss terms. Following previous publications \cite{ge2022long,khader2022medical}, we set $\lambda_1 = \lambda_3 = 4$ and $\lambda_2 = \lambda_5 = 1$. We also set $\lambda_4 = 4$ and $\beta$ in $\mathcal{L}_{codebook}$ to 1. $C_1$ and $C_2$ are balancing factors between the main task and the auxiliary task; we empirically set $C_1 = 0.8$ and $C_2 = 0.2$.
% Since the image gradient loss is as important as the pixel $\mathcal{L}_1$ loss ($\mathcal{L}_{pixel}$), we set $\lambda_4 = 4$ as well. We also set $\beta$ in $\mathcal{L}_{codebook}$ to be 1 instead of 0.25.

For the transformer model, we use the cross entropy loss between the reconstructed token sequence and the ground truth token sequence as shown in Equation \eqref{masked_loss} to optimize the transformer.

\begin{align} \label{masked_loss}
    %\mathcal{L}_{transformer} = \mathbb{E}_{z\sim p(z_{data})}[-\log p(z)]
    \mathcal{L}_{transformer} = -\mathbb{E}_{\mathbf{Y}\in \mathcal{D}}\Bigl(\sum_{\forall i, m_i = 0}\log p(y_i \mid \mathbf{Y}_{M})\Bigr)
\end{align}
where $\mathcal{D}$ is the training dataset and $\mathbf{Y}_{M}$ denotes the \textit{unmasked} tokens, so that the masked tokens are predicted conditioned on the unmasked tokens during training.

\subsection{Data and Preprocessing}

We used the FLAIR sequence data from the BraTS 2019 dataset \cite{bakas2017advancing,bakas2018identifying,menze2014multimodal}, which contains 259 HGG patients and 76 LGG patients. We chose this dataset because it is the latest version of BraTS that provides labels for brain tumor pathology classification, i.e., HGG vs.\ LGG. We converted the data from $240 \times 240 \times 155$ to $128 \times 128 \times 128$ and normalized all pixel intensities to $[-1, 1]$. To achieve this, we first removed all zero-valued slices in both the brain images and the segmentations, since we are interested in the slices where the brain tumor is present. Then, we obtained the ROIs by multiplying the images with the segmentation masks. Finally, we center-cropped the region based on the segmentation mask to a target size of $128 \times 128 \times 128$. 

\raggedbottom
\section{Experiments} \label{exp_det_brats}
For the first stage of the proposed 3D-VQGAN-cond model, we train for 10k epochs with an initial learning rate of 0.0001 with cosine decay to 0 for all sub-modules, a mini-batch size of 3, and the Adam optimizer \cite{kingma2014adam}. We set the codebook size $K = 1024$. For the second stage, we train the transformer for 5k epochs using a learning rate of $4.5\times10^{-6}$, a mini-batch size of 3, and the AdamW optimizer \cite{loshchilov2017decoupled}. We set the mask ratio $\alpha = 0.5$. We randomly held out 25 patients from both HGG and LGG as a standalone test set. The rest of the data is used to train our model.

To assess the usability of our generated data, we conducted two sets of classification experiments: \textbf{Experiment (1). }We aimed to determine whether models pre-trained on fully synthetic data outperform those trained without synthetic data or with only partially synthetic data. Following the approach in \cite{zhou2023generating}, we compared our classification model pre-trained with \textit{both synthetic LGG and HGG} to one pre-trained with \textit{real HGG and synthetic/real LGG}. We use the same amount of data for pre-training and the \textbf{same data} for fine-tuning. \textbf{Experiment (2). }Our goal was to investigate whether increasing the amount of synthetic data for pre-training enhances classification performance. We generated 250 synthetic HGG and LGG ROIs with the baseline models and our proposed model to pre-train the classification model. The fine-tuning data remained the same. More details can be found in Appendix~\ref{app:train_det}. Due to the page limit, ablations on the \textit{top-k} ratio are provided in Appendix~\ref{app:abl}.

\noindent \textbf{Baseline Model \& Comparison. }For comparison of image generation results, we consider five state-of-the-art methods: 3D-WGAN-GP \cite{gulrajani2017improved}, 3D-$\alpha$WGAN \cite{kwon2019generation}, 3D-Med-DDPM \cite{dorjsembe2023conditional}, Medical Diffusion \cite{khader2022medical}, and 3D-VQGAN \cite{zhou2023generating}. We re-implemented Medical Diffusion to make it class-conditioned (Medical Diffusion-C) so that it can be jointly trained on LGG and HGG data, serving as our class-conditioned baseline. The other baselines were rerun \textit{separately} for each of the two classes. For classification, we establish a baseline that uses only traditional augmentations, ensuring a fair comparison with other methods. We evaluate the quality of generated MRI ROIs using three metrics commonly used in previous publications \cite{kwon2019generation,peng2022generating,dorjsembe2023conditional,zhou2023generating}: maximum mean discrepancy (MMD) \cite{gretton2012kernel}, multi-scale structural similarity (MS-SSIM) \cite{rosca2017variational}, and the Fr\'echet Inception Distance (FID) \cite{heusel2017gans}. The FID score is computed in three views (Axial, Coronal, Sagittal) to reflect the 3D nature of medical images. The classification performance is evaluated using AUC, F1-Score, and Accuracy.

\noindent{\textbf{Generating Synthetic MRI Data. }}To generate 3D tumor ROIs, we start with a class token, either 0 for LGG or 1 for HGG, and then have the transformer model predict and complete the rest of the indices. Next, we obtain the embedding vector of each index from the codebook $C$ and feed them into the decoder $D$ to produce the final images. For the other baselines, we follow the exact procedures described in their original papers.

\section{Results and Discussions} \label{qual_analysis}
\subsection{Results for Generated Images} \label{qua_ana_brats}

In \figureref{fig:qual_res}, we compare LGG and HGG ROIs generated by the baseline models and our proposed method. The center three slices in the Axial plane are shown for better visualization. We observed that both 3D-WGAN-GP and 3D-$\alpha$WGAN produce images that lack detail and exhibit major artifacts. 3D-Med-DDPM and Medical Diffusion/-C introduce noise and checkerboard artifacts. Images generated by both 3D-VQGAN and our proposed method contain detailed attributes of the tumor and exhibit high image fidelity. Quantitative metrics are computed over 250 generated HGG and LGG samples, as shown in \tableref{tab:quan_res}. Our proposed method performs best in terms of preserving diversity, as measured by the MS-SSIM score (closest to that of the real data). For the MMD score on LGG data, our method ties with Medical Diffusion-C and outperforms all other methods except 3D-$\alpha$WGAN; we argue that this is not a fair comparison because 3D-$\alpha$WGAN exhibits a severe mode collapse problem (99.4 in MS-SSIM). For the MMD score on HGG data, our method is slightly worse than Medical Diffusion/-C, but it still outperforms the other baselines. For FID, our method achieves the best FID-A score on both LGG and HGG data, as well as the best FID-S score on LGG and the best FID-C score on HGG; its remaining FID scores (FID-C on LGG, FID-S on HGG) are competitive with the best-performing baselines. It is also worth noting that all baselines except Medical Diffusion-C are trained \textit{separately}, whereas ours and Medical Diffusion-C are trained \textit{jointly} on both LGG and HGG data. Our method 1) outperforms Medical Diffusion-C in most of the quantitative metrics, and 2) exceeds or is on par with the other baselines, indicating that it can effectively learn and distinguish between the two tumor types while avoiding the time needed to train separate models. More visualizations can be found in Appendix~\ref{app:gen_img}.

\begin{figure}[htbp]
\floatconts
    {fig:qual_res}
    {\vspace{-1cm}\caption{Qualitative comparison between generated and real LGG and HGG ROIs. We show the center three consecutive slices in the Axial plane for each ROI sample. Zoom in for a better view.}}
    {\includegraphics[width=1\linewidth]{figures/qual_res_cond_new4.png}}
\end{figure}

\begin{table}
\floatconts
{tab:quan_res}
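{\caption{Quantitative results of class-conditioned generation. Values in `$()$' are the absolute difference to the real MS-SSIM score (85.3 for LGG, 88.6 for HGG); a smaller difference is better. Best results are in bold; for the LGG MMD score, 3D-$\alpha$WGAN (1.61) is not bolded because of its severe mode collapse.}}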
{\vspace{-0.5cm}
    \scalebox{0.75}{
        {
        \begin{tabular}{cccccc}
        \hline
        \multicolumn{1}{c|}{Method}                          & \multicolumn{1}{c|}{MMD (10$^4$) $\downarrow$}      & \multicolumn{1}{c|}{MS-SSIM (\%)}                         & \multicolumn{1}{c|}{FID-A $\downarrow$}             & \multicolumn{1}{c|}{FID-C $\downarrow$}             & FID-S $\downarrow$             \\ \hline
        \colorbox{lightgray}{LGG Results} & \multicolumn{1}{l}{}                                & \multicolumn{1}{l}{}                                      & \multicolumn{1}{l}{}                                & \multicolumn{1}{l}{}                                & \multicolumn{1}{l}{}           \\ \hline
        \multicolumn{1}{c|}{3D-WGAN-GP}                      & \multicolumn{1}{c|}{1.98}                           & \multicolumn{1}{c|}{93.4 (8.1)}                           & \multicolumn{1}{c|}{65.9}                           & \multicolumn{1}{c|}{55.4}                           & 45.5                           \\
        \multicolumn{1}{c|}{3D-$\alpha$WGAN}                 & \multicolumn{1}{c|}{1.61}                           & \multicolumn{1}{c|}{99.4 (14.1)}                          & \multicolumn{1}{c|}{79.3}                           & \multicolumn{1}{c|}{69.1}                           & 73.6                           \\
        \multicolumn{1}{c|}{3D-Med-DDPM}                     & \multicolumn{1}{c|}{1.83}                              & \multicolumn{1}{c|}{93.3 (8.0)}                                    & \multicolumn{1}{c|}{62.6}                              & \multicolumn{1}{c|}{46.5}                              & 43.3                              \\
        \multicolumn{1}{c|}{Medical Diffusion}               & \multicolumn{1}{c|}{1.78}                           & \multicolumn{1}{c|}{92.9 (7.6)}                           & \multicolumn{1}{c|}{31.6}                           & \multicolumn{1}{c|}{\textbf{30.5}} & 37.1                           \\
        \multicolumn{1}{c|}{Medical Diffusion-C}               & \multicolumn{1}{c|}{\textbf{1.72}}                           & \multicolumn{1}{c|}{89.8 (4.5)}                           & \multicolumn{1}{c|}{26.2}                           & \multicolumn{1}{c|}{32.6} & 37.2                           \\
        \multicolumn{1}{c|}{3D-VQGAN}                        & \multicolumn{1}{c|}{1.79}                           & \multicolumn{1}{c|}{92.7 (7.4)}                           & \multicolumn{1}{c|}{24.1}                           & \multicolumn{1}{c|}{31.4}                           & 36.1                           \\
        \multicolumn{1}{c|}{3D-VQGAN-cond}                   & \multicolumn{1}{c|}{\textbf{1.72}} & \multicolumn{1}{c|}{\textbf{87.9 (2.6)}} & \multicolumn{1}{c|}{\textbf{23.7}} & \multicolumn{1}{c|}{32.9}                           & \textbf{35.2} \\ \hline
        \colorbox{lightgray}{HGG Results} &                                                     &                                                           &                                                     &                                                     &                                \\ \hline
        \multicolumn{1}{c|}{3D-WGAN-GP}                      & \multicolumn{1}{c|}{2.44}                           & \multicolumn{1}{c|}{97.5 (8.9)}                                 & \multicolumn{1}{c|}{53.5}                           & \multicolumn{1}{c|}{50.2}                           & 50.6                           \\
        \multicolumn{1}{c|}{3D-$\alpha$WGAN}                 & \multicolumn{1}{c|}{2.34}                           & \multicolumn{1}{c|}{98.9 (10.3)}                                 & \multicolumn{1}{c|}{122.3}                          & \multicolumn{1}{c|}{145.7}                          & 153.1                          \\
        \multicolumn{1}{c|}{3D-Med-DDPM}                     & \multicolumn{1}{c|}{2.50}                              & \multicolumn{1}{c|}{95.9 (7.3)}                                    & \multicolumn{1}{c|}{84.6}                              & \multicolumn{1}{c|}{61.4}                              & 58.6                              \\
        \multicolumn{1}{c|}{Medical Diffusion}               & \multicolumn{1}{c|}{\textbf{1.41}}                           & \multicolumn{1}{c|}{89.4 (0.8)}                           & \multicolumn{1}{c|}{30.0}                           & \multicolumn{1}{c|}{26.3}                           & 23.2                          \\
        \multicolumn{1}{c|}{Medical Diffusion-C}               & \multicolumn{1}{c|}{1.44}                           & \multicolumn{1}{c|}{90.7 (2.1)}                           & \multicolumn{1}{c|}{32.3}                           & \multicolumn{1}{c|}{27.2}                           & \textbf{20.5}                           \\
        \multicolumn{1}{c|}{3D-VQGAN}                        & \multicolumn{1}{c|}{1.63}                              & \multicolumn{1}{c|}{90.6 (2.0)}                                    & \multicolumn{1}{c|}{32.1}                              & \multicolumn{1}{c|}{29.3}                              & 31.4                              \\
        \multicolumn{1}{c|}{3D-VQGAN-cond}                   & \multicolumn{1}{c|}{1.57}                           & \multicolumn{1}{c|}{\textbf{88.5 (0.1)}}                           & \multicolumn{1}{c|}{\textbf{29.1}}                           & \multicolumn{1}{c|}{\textbf{24.4}}                           & 26.3                           \\ \hline
        \end{tabular}
        }
    }
}
\end{table}

\vspace{-0.5cm}
\subsection{Classification Results} \label{class_res_brats}
% Our proposed method addresses the imbalanced training data problem by generating data for the minority class (LGG).
We trained a classification model to validate the usefulness of data generated by the proposed 3D-VQGAN-cond model for distinguishing between HGG and LGG brain tumor types. \tableref{tab:cls_res} shows the classification results for Experiments (1) and (2) from Section~\ref{exp_det_brats}. For Experiment (1), the model pre-trained with both synthetic HGG and LGG data outperforms all baselines, including the model trained with traditional augmentations and the models pre-trained with real HGG and synthetic LGG data. This highlights the substantial improvement in classification performance achieved through pre-training with synthetic data, alleviating the demand for extensive real data to train effective classification models. For Experiment (2), we demonstrated that the classification performance improves as the amount of synthetic data used for pre-training increases. These results collectively indicate that synthetic ROI data can be effectively used to pre-train deep models, and only a small amount of real data is needed for fine-tuning. In addition, the improvements marked with $^*$ in \tableref{tab:cls_res} are statistically significant (two-sided t-test, $p < 0.05$) compared to the previous SOTA (3D-VQGAN), which further suggests that both the LGG and HGG samples generated by our proposed 3D-VQGAN-cond model have better image quality and fidelity than those of the other baselines. More results can be found in Appendix~\ref{app:cls_res}.

\begin{table}[htbp]
\floatconts
{tab:cls_res}
{\caption{Results for all experiments as described in Section~\ref{exp_det_brats}. We run each experiment for three trials and report mean$\pm$standard deviation. Trad.\ Aug.\ is short for traditional augmentations. $^*$ marks a statistically significant improvement over 3D-VQGAN (two-sided t-test, $p<0.05$).}}
{\vspace{-0.5cm}
    \subtable[Experiment (1)][c]{
        \parbox{.45\linewidth}{
            \scalebox{0.72}{
                \begin{tabular}{c|c|c|c}
                \hline
                Method              & AUC                                          & F1-Score                                     & Accuracy                                     \\ \hline
                Trad. Aug. & 0.66$\pm$0.03                              & 0.63$\pm$0.03                              & 0.59$\pm$0.03                              \\ \hline
                3D-WGAN-GP               & 0.64$\pm$0.08                              & 0.62$\pm$0.05                              & 0.56$\pm$0.01                              \\
                3D-$\alpha$WGAN           & 0.70$\pm$0.09                              & 0.59$\pm$0.09                              & 0.58$\pm$0.06                              \\ \hline
                3D-Med-DDPM               & 0.69$\pm$0.03                              & 0.64$\pm$0.01                              & 0.61$\pm$0.06                              \\
                Medical Diffusion        & 0.71$\pm$0.09                              & 0.62$\pm$0.07                              & 0.59$\pm$0.02                              \\
                Medical Diffusion-C & 0.66$\pm$0.03 & 0.65$\pm$0.02 & 0.58$\pm$0.08 \\
                3D-VQGAN             & 0.72$\pm$0.03 & 0.67$\pm$0.02 & 0.65$\pm$0.04 \\ \hline
                Ours & \begin{tabular}[c]{@{}c@{}}\textbf{0.77$\pm$0.03}$^*$ \\ (\textbf{p=0.02})\end{tabular} & \begin{tabular}[c]{@{}c@{}}\textbf{0.71$\pm$0.02}$^*$ \\ (\textbf{p=0.03})\end{tabular} & \textbf{0.67$\pm$0.04}   \\ \hline
                \end{tabular}
            }
        }
    } \quad
    \subtable[Experiment (2)][c]{
        \parbox{.45\linewidth}{
            \scalebox{0.72}{
                \begin{tabular}{c|c|c|c}
                \hline
                Method                 & AUC                                          & F1-Score                                     & Accuracy                                     \\ \hline
                % Trad. Aug. & 0.657$\pm$0.028                              & 0.634$\pm$0.031                              & 0.593$\pm$0.034                              \\ \hline
                % 3D-WGAN-GP               & 0.643$\pm$0.081                              & 0.623$\pm$0.052                              & 0.562$\pm$0.014                              \\
                3D-$\alpha$WGAN           &  0.71$\pm$0.04        & 0.65$\pm$0.02                             & 0.64$\pm$0.06                             \\ \hline
                3D-Med-DDPM  & 0.71$\pm$0.09 & 0.69$\pm$0.02 & 0.65$\pm$0.03     \\
                Medical Diffusion         & 0.73$\pm$0.04                 & 0.67$\pm$0.02           
                & 0.61$\pm$0.02                             \\
                Medical Diffusion-C & 0.75$\pm$0.03 & 0.68$\pm$0.02 & 0.66$\pm$0.03 \\
                3D-VQGAN           & 0.78$\pm$0.04 & 0.71$\pm$0.06 & 0.70$\pm$0.06 \\ \hline
                Ours & \begin{tabular}[c]{@{}c@{}}\textbf{0.80$\pm$0.02}$^*$ \\ (\textbf{p=0.04})\end{tabular} & \begin{tabular}[c]{@{}c@{}}\textbf{0.74$\pm$0.02}$^*$ \\ (\textbf{p=0.03})\end{tabular} & \textbf{0.70$\pm$0.01}   \\ \hline
            \end{tabular}
            }
        }
    }
}
\end{table}

\section{Conclusions}

We propose the first class-conditional generation framework for LGG and HGG brain tumor types based on VQGAN and a masked Transformer. The conditional scheme enables generating different tumor types in a unified framework, rather than training separate models, which requires considerable time and resources. Our proposed method performs better than or on par with several baseline models in image quality metrics such as MS-SSIM, slice-wise FID, and MMD. Using the generated data, our method yields the best classification performance compared to all other baselines.  

\midlacknowledgments{This research has been made possible with the financial 
support of the Canadian Institutes of Health Research (CIHR) 
(Funding Reference Number: 184015). The authors would like to thank Landy Xu for her suggestions on the figures.}


\bibliography{midl24_116}


\appendix

\section{Training Details} \label{app:train_det}

\subsection{Implementation Details of the Scoring Network}
The scoring network $f$ is learned using an auxiliary MLP, $f_{cls}$, through the classifier-guidance approach described in Section \ref{model_arc}. To learn $f$, we first use the encoder $E$ to obtain the latent feature map $z_e \in \mathbb{R}^{H\times W \times Dp \times n_z}$; then, for each region $r_i \in \mathbb{R}^{n_z}$, its importance score is defined as $s_i = f(r_i)$, where $i = 1,\dots,H \times W \times Dp$. Next, we normalize the encoded feature map, $z_{e}^{norm} = \mathrm{LayerNorm}(z_e)$, and multiply it by the predicted importance scores as modulating factors, $z_{e}^{'} = z_{e}^{norm} * \mathbf{S}$, where $\mathbf{S}$ is the set of importance scores $\{s_i\}_{i=1}^{H \times W \times Dp}$ obtained from $f$. Finally, $f_{cls}$ takes $z_{e}^{'}$ as input and outputs class probabilities.

\subsection{Model Training}
All programs were implemented in PyTorch, and all models were trained on a single Tesla V100 GPU. Additionally, we used PyTorch's automatic mixed precision during training \cite{subramaniam2022generating}. Training our 3D-VQGAN-cond model takes about 7 GPU days in total. For the transformer, we used the same architecture as in \cite{esser2021taming}. Our code is available at \url{https://github.com/IMICSLab/Brain_VQGAN_TATrans}.

For classification, we pre-trained all models on synthetic data for 50 epochs using the Adam optimizer with a batch size of 8 and a learning rate of 0.001. These models were then fine-tuned with a batch size of 10 and a learning rate of 0.01. Models that do not use synthetic data (i.e., the traditional augmentation baseline) were trained with the Adam optimizer, a batch size of 8, and a learning rate of 0.01. All classification models are optimized with the focal loss \cite{lin2017focal}, as we found that it performed better than the standard cross-entropy loss.

\textbf{Data Split. }As discussed in the main text, we randomly hold out 25 HGG and 25 LGG patients (50 in total) as the standalone test data. For the remaining 234 HGG and 51 LGG patients, we design two sets of training data combinations for traditional and non-traditional augmentation methods. For the traditional augmentation baseline, we augment the LGG data with a 30-degree rotation, 1.5$\times$ scaling, left-right flipping, and elastic deformation to form a balanced dataset of 234 cases each for HGG and LGG. For all other (non-traditional augmentation) models in \textbf{Experiment (1)} in Section \ref{exp_det_brats}, following the setup in \cite{zhou2023generating}, we pre-trained on 183 real HGGs and 183 synthetic LGGs, and then fine-tuned with 51 real HGGs and 51 real LGGs. For \textbf{Experiment (2)}, we pre-trained on 250 synthetic HGGs and LGGs, and then fine-tuned with 51 real HGGs and 51 real LGGs. We ensure that the data for fine-tuning is the same across experiments for a fair comparison. Furthermore, during the fine-tuning stage in both experiments, we used 85\% of the data for optimizing the model and the remaining 15\% for validation, and we repeated this process three times to assess the robustness of our results. We also ensure that there is no overlap between the validation data in the three runs.

% \noindent{\textbf{Generating Synthetic MRI Data. }}To generate 3D tumor ROIs, we start with a class token, either 0 for LGG or 1 for HGG, and then have the transformer model predict and complete the rest of the indices. Next, we obtain the embedding vectors of each index from the codebook $C$ and then feed them into decoder $D$ to produce the final images. For other baselines, we follow the exact procedure as stated in their paper.

% {\color{red} pretrain with 183 fake hgg and lgg: f1: 0.706 pm 0.021, acc: 0.673 pm 0.037, auc: 0.767 pm 0.035, precision: 0.649 pm 0.051, recall: 0.786 pm 0.075}

\section{Ablation Study} \label{app:abl}

\subsection{Ablation on top-k ratio}
Recall that the top-$k$ ratio is the fraction of all discrete semantic tokens obtained from the encoded feature maps (512 tokens in this work) that are treated as important tokens. We argue that this ratio controls the \textbf{trade-off between the diversity of generated images and their quality}, i.e., how close the synthetic distribution is to the real one. In the extreme case $k = 0$, every token is treated as unimportant, and the proposed masking strategy \textit{degrades to random masking}. Random masking poses an issue of information leakage, impeding the transformer model's ability to effectively learn relationships between regions, especially those pertinent to tumors. Therefore, generated images may lack diversity in critical regions, even though the overall distribution may be close to the real one. To study this effect, we vary $k$ for our proposed 3D-VQGAN-cond model, comparing $k = 0\%$, $k = 25\%$, and $k = 50\%$, and compute the MS-SSIM score to evaluate diversity and the MMD score to evaluate the distance between the synthetic and real distributions, as shown in \tableref{apptab:abl}. We fix the overall masking ratio $\alpha = 0.5$ for all ablations. When $k = 0\%$ (random masking), the MMD score is the lowest, but the MS-SSIM score is higher (i.e., diversity is lower). For $k = 25\%$, our method exhibits a slightly higher MMD score than $k = 0\%$ but a noticeably lower MS-SSIM score (note that the MS-SSIM score is computed over 1000 randomly selected pairs). When $k = 50\%$, the MMD score degrades dramatically, while the MS-SSIM score changes little compared to $k = 25\%$. To balance this trade-off, we select $k = 25\%$ in our study. We believe that block-wise masking on important tokens helps the transformer model better learn the relationships between the other tokens, increasing the possibility of generating diverse images based on important regions.

\begin{table}[htbp]
\floatconts
{apptab:abl}
{\caption{Ablation study on the \textit{top-k} ratio used in our masked transformer model. Values in `$()$' are the absolute difference to the real MS-SSIM score (85.3 for LGG, 88.6 for HGG).}}
{
    \begin{tabular}{c|cc|cc}
             & \multicolumn{2}{c|}{LGG}        & \multicolumn{2}{c}{HGG}         \\ \hline
    \textit{top-k} ratio         & MMD ($10^4$) $\downarrow$ & MS-SSIM & MMD($10^4$) $\downarrow$ & MS-SSIM \\ \hline
    0\%    & 1.67                  & 89.4 (4.1)   & 1.46                  & 89.7 (1.1)       \\ \hline
    25\% & 1.72                  & 87.9 (2.6)   & 1.57                  & 88.5 (0.1)    \\ \hline 
    50\% & 2.14                  & 87.3 (2.0)  & 1.80           & 89.2 (0.6) \\
    \hline
    \end{tabular}
}
\end{table}
%x        & x                     & x       & x                     & x       \\ \hline

\subsection{Ablation on number of neighbor tokens}

We also conducted an ablation on the number of neighboring tokens masked around each important token, as described in Section \ref{model_arc}. In our work, we mask only one neighboring token (denoted single-sided) in either the spatial or the temporal dimension, because our latent feature maps are relatively compact ($8 \times 8 \times 8$), and masking more tokens around the important one may remove too much feature information for the model to learn and reconstruct effectively during training. In the ablation, we fix the mask ratio $\alpha = 0.5$ and the top-$k$ ratio $k = 25\%$, and mask \textit{two neighboring tokens}, i.e., for a given important token, we mask the tokens on both its left and right in the spatial dimension, or both its predecessor and successor in the temporal dimension (denoted double-sided). The quantitative results are reported in \tableref{apptab:abl_nei}. We observe that masking more tokens leads to higher MMD scores and to MS-SSIM scores further from the real values for both LGG and HGG, which supports our claim above.

\begin{table}[htbp]
\floatconts
{apptab:abl_nei}
{\caption{Ablation study on the number of neighbor tokens in our masked transformer model. Values in `$()$' are the absolute difference to the real MS-SSIM score (85.3 for LGG, 88.6 for HGG).}}
{
    \begin{tabular}{c|cc|cc}
             & \multicolumn{2}{c|}{LGG}        & \multicolumn{2}{c}{HGG}         \\ \hline
             & MMD ($10^4$) $\downarrow$ & MS-SSIM & MMD($10^4$) $\downarrow$ & MS-SSIM \\ \hline
    single-sided   & 1.72                  & 87.9 (2.6)   & 1.57                  & 88.5 (0.1)       \\ \hline
    double-sided & 2.07                  & 88.9 (3.6)   & 1.91                  & 89.9 (1.3)    \\ \hline 
    \end{tabular}
}
\end{table}


\section{More on Generated Images} \label{app:gen_img}

In this section, we provide more visualizations of the generated LGG and HGG from our proposed method and other baseline methods. The additional visualization of generated LGG samples from all methods is shown in \figureref{appfig:lgg}; additional HGG samples visualization is shown in \figureref{appfig:hgg}.

We observe that all GAN-based baselines produce low-quality images with noise and blurry edges. For 3D-Med-DDPM, the intensity range of the generated samples appears not to match that of the real samples, making them look unrealistic compared with the real ROIs. For Medical Diffusion/-C, the generated images suffer from minor checkerboard artifacts (visible when zoomed in). 3D-VQGAN sometimes produces blurry images (e.g., sample 2 in HGG), but overall its generated images are smooth and free of checkerboard artifacts and noise. Finally, the images generated by our method are high-resolution, with no noise, blurry edges, or checkerboard artifacts, and the contrast inside the generated ROIs closely resembles that of the real ROIs.

\begin{figure}[htbp]
	% \centering
\floatconts
    {appfig:lgg}
    {\caption{Additional generated samples of LGG data. We show the three central consecutive slices in the axial plane for each ROI sample. Zoom in for a better view.}}
    {\includegraphics[width=1\linewidth]{figures/app_lgg_sample_large2.png}}
\end{figure}


\begin{figure}[htbp]
	% \centering
\floatconts
    {appfig:hgg}
    {\caption{Additional generated samples of HGG data. We show the three central consecutive slices in the axial plane for each ROI sample. Zoom in for a better view.}}
    {\includegraphics[width=1\linewidth]{figures/app_hgg_sample_large2.png}}
\end{figure}

We also provide a visualization of the importance score map learned by the proposed scoring network $f$ in Section \ref{model_arc}, as depicted in \figureref{appfig:is}. Visualizing the importance map provides insight into how much importance the model assigns to each region of the latent feature map. Lighter regions indicate higher importance. The score map is obtained by interpolating the importance scores in the latent space ($8 \times 8 \times 8$) to the original image size and overlaying the result on the original image. We observe that our model effectively identifies important regions within the tumor, providing a reliable reference for the masked transformer model. 

\begin{figure}[htbp]
	% \centering
\floatconts
    {appfig:is}
    {\caption{Visualization of the importance score map. We show the three central consecutive slices in the axial plane and their corresponding importance maps for two randomly selected LGG and HGG samples.}}
    {\includegraphics[width=1\linewidth]{figures/is_vis.png}}
\end{figure}

\section{More on Classification Results} \label{app:cls_res}

\subsection{Full Results}

In this section, we provide more details on classification performance, including precision and recall scores for all experiments we performed. Results for \textbf{Experiment (1)} are shown in \tableref{apptab:cls_res}, and results for \textbf{Experiment (2)} are shown in \tableref{apptab:cls_res_exp2}.

\begin{table}[htbp]
\floatconts
{apptab:cls_res}
{\caption{Detailed classification results for Experiment (1). We run each experiment for three trials and report mean$\pm$standard deviation. \textbf{Bold} values represent the best results, and \underline{underlined} values represent the second-best results.}}
{\begin{tabular}{c|c|c|c|c|c}
           & AUC         & F1-Score & Accuracy & Precision & Recall \\ \hline
% Reference  & $(a)$ & \XSolidBrush       & \XSolidBrush        & 0.641$\pm$0.110 & 0.529$\pm$0.223 & 0.566$\pm$0.010 & 0.533$\pm$0.094 & 0.603$\pm$0.202   \\
Trad. Aug.    & 0.66$\pm$0.03 & 0.63$\pm$0.03 & 0.59$\pm$0.03 & 0.59$\pm$0.07 & 0.72$\pm$0.14 \\ \hline
3D-WGAN-GP  & 0.64$\pm$0.08            & 0.62$\pm$0.05 & 0.56$\pm$0.01 & 0.61$\pm$0.05 & 0.65$\pm$0.24     \\
%3D-$\alpha$GAN & Yes & $(c)$ & Yes & 0.648$\pm$0.064 & x \\
3D-$\alpha$WGAN & 0.70$\pm$0.09 & 0.59$\pm$0.09 & 0.58$\pm$0.06 & 0.63$\pm$0.12 & 0.67$\pm$0.28 \\ \hline
3D-Med-DDPM & 0.70$\pm$0.03 & 0.64$\pm$0.01 & 0.61$\pm$0.06 & 0.62$\pm$0.09 & 0.71$\pm$0.13   \\
Medical Diffusion & 0.71$\pm$0.09 & 0.62$\pm$0.08 & 0.59$\pm$0.02 & 0.59$\pm$0.04 & 0.71$\pm$0.20 \\
Medical Diffusion-C & 0.66$\pm$0.03 & 0.65$\pm$0.02 & 0.58$\pm$0.08 & 0.52$\pm$0.02 & \underline{0.76$\pm$0.11} \\ \hline
% 3D-VQGAN-lat4  & 0.683$\pm$0.066 & 0.602$\pm$0.065 & 0.600$\pm$0.038 & 0.610$\pm$0.087 & 0.680$\pm$0.246    \\
3D-VQGAN & \underline{0.72$\pm$0.03}            & \underline{0.67$\pm$0.02} & \underline{0.65$\pm$0.04} & \underline{0.64$\pm$0.08} & 0.73$\pm$0.14 \\
% 3D-VQGAN-cond & \textbf{0.772$\pm$0.066}            & \textbf{0.691$\pm$0.027} & \textbf{0.693$\pm$0.009} & \textbf{0.705$\pm$0.054} & 0.693$\pm$0.098
3D-VQGAN-cond (Ours) & \textbf{0.77$\pm$0.03}   & \textbf{0.71$\pm$0.02} & \textbf{0.67$\pm$0.04} & \textbf{0.65$\pm$0.05} & \textbf{0.79$\pm$0.07}
\\ \hline
\end{tabular}}
\end{table}


\begin{table}[htbp]
\floatconts
{apptab:cls_res_exp2}
{\caption{Detailed classification results for Experiment (2). We run each experiment for three trials and report mean$\pm$standard deviation. \textbf{Bold} values represent the best results, and \underline{underlined} values represent the second-best results.}}
{\begin{tabular}{c|c|c|c|c|c}
           & AUC         & F1-Score & Accuracy & Precision & Recall \\ \hline
% Reference  & $(a)$ & \XSolidBrush       & \XSolidBrush        & 0.641$\pm$0.110 & 0.529$\pm$0.223 & 0.566$\pm$0.010 & 0.533$\pm$0.094 & 0.603$\pm$0.202   \\
%3D-$\alpha$GAN & Yes & $(c)$ & Yes & 0.648$\pm$0.064 & x \\
3D-$\alpha$WGAN & 0.71$\pm$0.04 & 0.65$\pm$0.02 & 0.64$\pm$0.06 & 0.65$\pm$0.08 & 0.68$\pm$0.10 \\ \hline
3D-Med-DDPM & 0.71$\pm$0.09 & 0.69$\pm$0.02 & 0.65$\pm$0.03 & 0.62$\pm$0.04 & 0.77$\pm$0.04   \\
Medical Diffusion & 0.73$\pm$0.04 & 0.67$\pm$0.02 & 0.61$\pm$0.02 & 0.55$\pm$0.04 & \underline{0.84$\pm$0.11} \\
Medical Diffusion-C & 0.75$\pm$0.03 & 0.68$\pm$0.02 & 0.66$\pm$0.03 & 0.64$\pm$0.04 & 0.73$\pm$0.04 \\ \hline
% 3D-VQGAN-lat4  & 0.683$\pm$0.066 & 0.602$\pm$0.065 & 0.600$\pm$0.038 & 0.610$\pm$0.087 & 0.680$\pm$0.246    \\
3D-VQGAN & \underline{0.78$\pm$0.04} & \underline{0.71$\pm$0.06} & \underline{0.70$\pm$0.06} & \textbf{0.69$\pm$0.06} & 0.72$\pm$0.07 \\
% 3D-VQGAN-cond & \textbf{0.772$\pm$0.066}            & \textbf{0.691$\pm$0.027} & \textbf{0.693$\pm$0.009} & \textbf{0.705$\pm$0.054} & 0.693$\pm$0.098
3D-VQGAN-cond (Ours) & \textbf{0.80$\pm$0.02}   & \textbf{0.74$\pm$0.02}    & \textbf{0.70$\pm$0.01} & \underline{0.66$\pm$0.05} & \textbf{0.87$\pm$0.13}
\\ \hline
\end{tabular}}
\end{table}

\subsection{Comparison with Transfer Learning}

We also compare our proposed method with a traditional transfer learning approach, using the MedicalNet \cite{chen2019med3d} pre-trained weights. MedicalNet was originally designed for 3D medical image segmentation; therefore, we adapted it to our classification task by replacing its segmentation head with a classification head. This new head, implemented as a two-layer MLP, takes the latent representations as input and produces class logits for the two tumor types. Due to our computational limitations, we used the ResNet-34 backbone instead of ResNet-50. During transfer learning, we froze the feature extractor and trained only the new classification head. The results are included in \tableref{apptab:cls_res_tl}. Our method outperforms MedicalNet by a large margin in all metrics, demonstrating the effectiveness of our proposed approach. Although we used a smaller backbone, we hypothesize that this change would not dramatically alter the results: the datasets used to pre-train MedicalNet consist of whole 3D volumes of various organs, including but not limited to the brain, so the resulting domain gap is likely to matter more than the choice of backbone. Moreover, the substantial difference between whole 3D volumes and regions of interest (ROIs) can also affect the results.

\begin{table}[htbp]
\floatconts
{apptab:cls_res_tl}
{\caption{Comparison between our method's best performance and the transfer learning approach. \textbf{Bold} values represent the best results. $^*$: results computed using the ResNet-34 backbone instead of ResNet-50.}}
{\begin{tabular}{c|c|c|c|c|c}
           & AUC         & F1-Score & Accuracy & Precision & Recall \\ \hline
% Reference  & $(a)$ & \XSolidBrush       & \XSolidBrush        & 0.641$\pm$0.110 & 0.529$\pm$0.223 & 0.566$\pm$0.010 & 0.533$\pm$0.094 & 0.603$\pm$0.202   \\
%3D-$\alpha$GAN & Yes & $(c)$ & Yes & 0.648$\pm$0.064 & x \\
MedicalNet$^*$ & 0.61$\pm$0.07 & 0.55$\pm$0.03 & 0.55$\pm$0.05 & 0.58$\pm$0.04 & 0.63$\pm$0.05 \\ \hline
3D-VQGAN-cond (Ours) & \textbf{0.80$\pm$0.02}   & \textbf{0.74$\pm$0.02}    & \textbf{0.70$\pm$0.01} & \textbf{0.66$\pm$0.05} & \textbf{0.87$\pm$0.13}
\\ \hline
\end{tabular}}
\end{table}

\end{document}
