\documentclass{midl} % Include author names

% The following packages will be automatically loaded:
% jmlr, amsmath, amssymb, natbib, graphicx, url, algorithm2e
% ifoddpage, relsize and probably more
% make sure they are installed with your latex distribution

\usepackage{booktabs}
\usepackage{amsmath}

\jmlrvolume{-- 208}
\jmlryear{2024}
\jmlrworkshop{Full Paper -- MIDL 2024}
\editors{Accepted for publication at MIDL 2024}

%\title[Predicting the risk of AMD Progression]{Predicting the individual risk of AMD progression from retinal OCT with intra-subject temporal consistency}
\title[Predicting the Risk of AMD Progression]{Predicting Age-related Macular Degeneration Progression from Retinal Optical Coherence Tomography with Intra-Subject Temporal Consistency}

%\title[Predicting the risk of progression to late AMD]{Predicting the individual risk of AMD progression from retinal OCT with  intra-subject temporal consistency}

 % Use \Name{Author Name} to specify the name.
 % If the surname contains spaces, enclose the surname
 % in braces, e.g. \Name{John {Smith Jones}} similarly
 % if the name has a "von" part, e.g \Name{Jane {de Winter}}.
 % If the first letter in the forenames is a diacritic
 % enclose the diacritic in braces, e.g. \Name{{\'E}louise Smith}

 % Two authors with the same address
 % \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\and
 %  \Name{Author Name2} \Email{xyz@sample.edu}\\
 %  \addr Address}

 % Three or more authors with the same address:
 % \midlauthor{\Name{Author Name1} \Email{an1@sample.edu}\\
 %  \Name{Author Name2} \Email{an2@sample.edu}\\
 %  \Name{Author Name3} \Email{an3@sample.edu}\\
 %  \addr Address}


% Authors with different addresses:
% \midlauthor{\Name{Author Name1} \Email{abc@sample.edu}\\
% \addr Address 1
% \AND
% \Name{Author Name2} \Email{xyz@sample.edu}\\
% \addr Address 2
% }

%\footnotetext[1]{Contributed equally}

% More complicate cases, e.g. with dual affiliations and joint authorship
\midlauthor{\Name{Arunava Chakravarty\nametag{$^{1}$}} \Email{arunava.chakravarty@meduniwien.ac.at}\\
\Name{Taha Emre\nametag{$^{1}$}} \Email{taha.emre@meduniwien.ac.at}\\
\Name{Dmitrii Lachinov\nametag{$^{1,2}$}} \Email{dmitrii.lachinov@meduniwien.ac.at}\\
\Name{Antoine Rivail\nametag{$^{1,2}$}} \Email{antoine.rivail@meduniwien.ac.at}\\
\Name{Ursula Schmidt-Erfurth\nametag{$^{1}$}} \Email{ursula.schmidt-erfurth@meduniwien.ac.at}\\
\Name{Hrvoje Bogunović\nametag{$^{1,2}$}} \Email{hrvoje.bogunovic@meduniwien.ac.at}\\
\addr $^{1}$ OPTIMA Lab, Department of Ophthalmology, Medical University of Vienna, Austria \\
\addr $^{2}$ Christian Doppler Lab for Artificial Intelligence in Retina, Medical University of Vienna, Austria \\
}


\usepackage{amsmath,amssymb,amsfonts}


\def\HB#1{\textcolor{orange}{#1}} % inserted text
\def\AR#1{\textcolor{cyan}{#1}} % inserted text
\def\AC#1{\textcolor{red}{#1}} % inserted text
\def\TE#1{\textcolor{magenta}{#1}} % inserted text
\def\DLc#1{\textcolor{violet}{\footnotesize\texttt{\textbf{[DL: #1]}}}} % comment
\def\HBc#1{\textcolor{orange}{\footnotesize\texttt{\textbf{[HB: #1]}}}} % comment
\def\ARc#1{\textcolor{cyan}{\footnotesize\texttt{\textbf{[RIVAIL: #1]}}}} % comment
\def\ACc#1{\textcolor{red}{\footnotesize\texttt{\textbf{[AC: #1]}}}} % comment
\def\TEc#1{\textcolor{magenta}{\footnotesize\texttt{\textbf{[TE: #1]}}}} % comment
\def\HBd#1{\textcolor{orange}{{\sout{#1}}}} % deleted
\def\ACd#1{\textcolor{gray}{{\sout{#1}}}} % deleted
\def\ARd#1{\textcolor{cyan}{{\sout{#1}}}} % deleted
\def\TEd#1{\textcolor{magenta}{{\sout{#1}}}} % deleted text

%\begin{itemize}
%\item Eprints such as arXiv papers can of course be cited \cite{Hinton:arXiv:2015:Distilling}. We recomend using a \verb|@misc| bibtex entry 
%\item Note that the JMLR template provides many handy functionalities
%such as \verb|\figureref| to refer to a figure,
%e.g. \figureref{fig:example},  \verb|\tableref| to refer to a table,
%e.g. \tableref{tab:example} and \verb|\equationref| to refer to an equation,
%e.g. \equationref{eq:example}.
%\end{itemize}

%\begin{table}[htbp]
 % The first argument is the label.
 % The caption goes in the second argument, and the table contents  go in the third argument.



\begin{document}

\maketitle

\begin{abstract}
%The prediction of an individual’s risk for progression from intermediate to the late dry stage of AMD poses a significant challenge due to the wide variability in disease progression rates and the absence of well-established clinical biomarkers. In this work, we present an automated approach to address this challenge. An AMD stage classifier to discriminate between the intermediate and late dry stage of AMD is jointly trained with a neural-ODE which models the future trajectory of the disease progression in the learned feature embedding. A temporal ordering is imposed such that the distance of an OCT scan from the decision hyperplane of the AMD stage classifier is inversely related to its time-to-conversion. Additionally, an intra-subject temporal consistency in the predicted conversion risk scores is ensured by incorporating a pair of longitudinal scans from the same eye during training. We evaluated our proposed method on a dataset comprising 240 eyes (200 non-converters and 40 converters). The results demonstrate the effectiveness of our approach, achieving an average dynamic area under the receiver operating characteristic curve (AUC) of 0.81 for predicting conversion within the next 2 years. Additionally, the Concordance Index of 0.765 surpasses the performance of several popular methods for survival analysis. Our automated method holds promise in enhancing patient-specific disease management and the recruitment in clinical trial populations by identifying patients who are at a higher risk of conversion to late dry AMD.

%The prediction of an individual’s risk for progression from intermediate (iAMD) to the late dry stage of AMD (dAMD) poses a significant challenge due to the wide variability in disease progression rates and the absence of well-established clinical biomarkers. 
%The prediction of an individual’s risk for progression of Age-Related Macular Degeneration(AMD) from intermediate (iAMD) to the late dry stage of AMD (dAMD) poses a significant challenge due to the wide variability in disease progression rates and the absence of well-established clinical biomarkers.
The wide variability in the progression rates of Age-Related Macular Degeneration (AMD) and the absence of well-established clinical biomarkers make it difficult to predict an individual's risk of AMD progression from intermediate stage (iAMD) to late dry stage (dAMD) using Optical Coherence Tomography (OCT) scans.
To address this challenge, we propose to jointly train an AMD stage classifier, which discriminates between iAMD and dAMD, with a Neural Ordinary Differential Equation (N-ODE) that models the future trajectory of the disease progression in the learned embedding space. A temporal ordering is imposed such that the distance of a scan from the decision hyperplane of the AMD stage classifier is inversely related to its time-to-conversion. In addition, an intra-subject temporal consistency in the predicted conversion risk scores is ensured by incorporating a pair of longitudinal scans from the same eye during training. We evaluated our proposed method on a longitudinal dataset comprising 235 eyes (3,534 OCT scans) with 40 converters. The results demonstrate the effectiveness of our approach, achieving an average area under the ROC curve of 0.84 for predicting conversion within the next 6, 12, 18 and 24 months. Additionally, the Concordance Index of 0.78 surpasses the performance of several popular methods for survival analysis. %Our automated method holds promise in enhancing patient-specific disease management and the recruitment in clinical trial populations by identifying patients who are at a higher risk of conversion to late dry AMD.
\end{abstract}

\begin{keywords}
Survival Analysis, AMD, OCT, Longitudinal disease progression, Retina
\end{keywords}

\section{Introduction}
Age-related macular degeneration (AMD) is a leading cause of blindness among the elderly population \cite{wong2014global}. It is asymptomatic in its early and intermediate stages (iAMD), characterized by the presence of drusen. AMD gradually advances to the late stage leading to irreversible vision loss which could be categorized as  either neovascular (nAMD) or dry (dAMD). nAMD is caused by abnormal blood vessel growth in the choroid that leaks fluid into the retina. dAMD is more prevalent than nAMD and characterized by Geographic Atrophy (GA) due to the loss of Retinal Pigment Epithelium (RPE). Recently, for the first time, drugs for dAMD~\cite{Khanani2023,Heier2023} were approved by FDA. Patients in the iAMD stage are regularly monitored with longitudinal Optical Coherence Tomography (OCT) imaging across multiple visits to initiate treatment at the earliest onset of late AMD to minimize vision loss. Identifying iAMD patients at a high risk of dAMD conversion enables ophthalmologists to prioritize these cases for enhanced monitoring, facilitating early detection of dAMD onset. However, this is a challenging task due to the absence of well-established clinical biomarkers and significant inter-subject variations in the rate of AMD progression. Deep learning (DL) methods to predict the future risk of conversion of an eye from iAMD to dAMD can play a critical clinical role in supporting personalized treatments and clinical research by categorizing iAMD patients into distinct risk levels for biomarker identification and recruitment in clinical trials.

\textbf{Related Work:} Existing methods for predicting the risk of conversion from iAMD to nAMD or dAMD fall into two main categories: biomarker and image-based approaches. \emph{Biomarker-based} methods \cite{sleiman2017optical, schmidt2018prediction, banerjee2020prediction, de2014quantitative, lad2022machine} involve segmenting retinal tissues and pathologies to extract features, subsequently combined with clinical and demographic data for risk prediction. Notably, \citet{banerjee2020prediction} utilizes multiple past visit biomarkers in an LSTM network for future risk assessment. \emph{Image-based} methods, however, directly utilize DL models on raw OCT scans, bypassing manual segmentation. A hybrid approach using both biomarker and image features for predicting nAMD conversion is presented in \citet{yim2020predicting}, employing an ensemble DL model. Unlabeled longitudinal OCT datasets have been used in \citet{emre2022tinc, rivail2019modeling} for feature learning via temporal self-supervised learning. These methods typically employ a binary classifier for predicting conversion within specific timeframes, (e.g., 2 years \cite{russakoff2019deep}, 6 months \cite{yim2020predicting, emre2022tinc}), or multi-label classification for various discrete time-intervals (e.g., 6, 12, and 18 months \cite{rivail2019modeling}).

The binary classification based approaches are limited by discretization of the conversion time and their inability to manage censoring, which occurs when an eye's actual conversion time is unknown due to missing follow-ups or non-conversion within a limited study duration. Survival analysis addresses these challenges. Discrete survival models are similar to multi-label classification but modify the training loss to incorporate censoring and have recently been applied to predict dAMD conversion \cite{rivail2023deep}. A transformer model has also been used for discrete-time modeling of the hazard function from tabular clinical and demographic data \cite{hu2021transformer}. Traditional non-DL continuous models of survival analysis have also been explored to capture AMD progression with handcrafted biomarkers using the linear Cox Proportional Hazard model (CoxPH) \cite{schmidt2018prediction}. Although CoxPH has been extended with DL using images \cite{katzman2018deepsurv}, such extensions have not yet been explored to model AMD progression. Moreover, these models are inflexible as each patient's hazard function is constrained to be a scaled version of the same baseline hazard across the entire population. SODEN \cite{tang2022soden} overcomes this issue by employing a Neural Ordinary Differential Equation (N-ODE) to model the cumulative hazard function for survival on tabular data. GRU-ODE-Bayes \cite{de2019gru} employs a N-ODE to extend the GRU-based Recurrent Neural Network to continuous time and has been used to predict disability progression in Multiple Sclerosis patients from tabular data of past history. %Predicting the individual risk of the progression to late AMD with intra-subject temporal consistency \cite{de2021longitudinal}. 
Recently, N-ODEs have also been used to model the spatial evolution of GA segmentation in OCT \cite{lachinov2023learning} and Diabetic retinopathy in fundus images \cite{zeghlache2023lmt}. 

\textbf{Contributions:} Our key contributions are: 
(i)  The time-to-conversion from iAMD to dAMD is modeled in continuous time, rather than discrete time-intervals as used in most existing methods. Our model can therefore use actual continuous conversion times as ground-truths during training and also predict conversion probabilities within arbitrary continuous times. %The time-to-conversion from iAMD to nAMD is modeled in continuous time, rather than discrete time-intervals as used in most existing methods for OCT-based retinal disease progression. Our model can therefore use actual continuous conversion times as labels during training and also predict conversion probabilities within arbitrary continuous times.
(ii) Our novel N-ODE based approach directly models the Cumulative Distribution Function (CDF) of the future conversion time instead of the cumulative hazard function used in existing methods like SODEN. Our SMGRU-ODE architecture also extends GRU-ODE by stacking multiple layers with multiple parallel heads. %multiple layers, with multiple parallel heads in each layer.
(iii) We incorporate intra-subject consistency by requiring the N-ODE estimates of the feature and risk at future time-points to be consistent with the values obtained using the actual OCT scan of the future visit. 
(iv) We jointly train a linear AMD stage classifier and employ a rank loss on its logits which is sensitive to censoring, to regularize the feature embedding. This facilitates patient stratification into risk groups based on a scalar risk score derived from the decision hyperplane distance for clinical studies or personalized treatment. %Based on the distance from its decision hyperplane, a scalar risk score can be calculated to stratify patients into low, medium, and high-risk groups for clinical studies or personalized treatment.












\section{Method}

\begin{figure}[!htb]
 \centering
  \includegraphics[width=.8\textwidth]{method.pdf}
%\caption{Training: The same ConvNeXt-Tiny encoder and linear AMD stage classifier models are used in both branches of the Siamese architecture with shared weights. Only the top branch (highlighted in green) is used for inference.}
\caption{The Siamese architecture uses shared weights for the ConvNeXt-Tiny encoder and linear AMD stage classifier in both branches. The top branch predicts the current AMD stage at time-point $t=j$ and evolves features with GRU-ODE for future $t=k$. The bottom branch computes features and stage predictions for $t=k$ directly from the future scan $\vec{I}_k$. Losses $\mathcal{L}_{cns-ftr}$ and $\mathcal{L}_{cns-rsk}$ ensure consistency in the feature and risk predictions between branches. The $\mathcal{L}_{rank}$ loss ranks the logit $r$ in inverse order of conversion time. Only the top branch (shaded in green) is used for inference.}
\label{fig:pipeline} 
 \end{figure}





%%%%%%%%%%%%%%% Problem Setting %%%%%%%%%%%%%%%%%%%%%
Given an input OCT scan of an iAMD patient, the proposed method projects it to a feature embedding where the current feature is evolved over time with a N-ODE to forecast the future trajectory of the disease progression. The features estimated for any (continuous) future time-point are fed to a  linear AMD stage classifier to predict the probability of the eye to have already converted within that time, thereby modeling the CDF of the future conversion time. 
In Survival analysis, the Ground Truth (GT) label for an OCT image $\vec{I}_j$ is defined by the tuple $(E_j, T_j)$. %at acquisition time $j$. 
The event indicator $E_j=1$ signifies that the eye associated with scan $\vec{I}_j$ will progress from iAMD to dAMD, while $E_j=0$ denotes no conversion within the monitoring period. $T_j$ denotes the time of conversion from the current visit if  $E_j=1$ or the censoring time until when the patient was last monitored. Our approach is trained on batches comprising random image pairs $\vec{I}_j$, $\vec{I}_k$ (\figureref{fig:pipeline}) of the same eye captured at different time points $t_j, t_k \in \mathbb{R}_{\ge 0}$ from two visits, such that $\vec{I}_j$ precedes $\vec{I}_{k}$ with $t_j<t_k$. %\TEc{if j and k time stamps, what are $t_j t_k$?}

%%%%%%%%%%%%%%% Stage Classifier %%%%%%%%%%%%%%%%%%%%%
\textbf{AMD Stage Classifier:} Both $\vec{I}_j$ and $\vec{I}_k$ are input to the same ConvNeXt-Tiny encoder to obtain the features $\vec{f}_j$ and $\vec{f}_k$ respectively. ConvNeXt-Tiny with 29M parameters and 4.5 GFLOPs is comparable to ResNet-50 and outperforms similar-sized Vision Transformer (ViT) architectures \cite{liu2022convnet}, making it a suitable choice for our task.
For the scan $\vec{I}_j$, the stage classifier's GT is $y^{cls}_j=1$ if both $T_j \le 0$ and $E_j=1$, and $y^{cls}_j=0$ otherwise. The stage classifier predicts the logit $r_j$. The probability for the current AMD stage for $\vec{I}_j$ is $p_j=\sigma \left( r_j \right)$, where $\sigma(\cdot)$ denotes the sigmoid activation. Notably, $r_j$ is proportional to the distance of $\vec{f}_j$ from the decision hyperplane of the AMD stage classifier and is used below to define a risk score for future conversion. The stage classifier treats each scan independently without considering any correlations between two scans of the same eye from different time-points. While it enables the learned feature to capture pathologies to distinguish dAMD from iAMD, it may fail to capture more subtle retinal changes indicative of how AMD will progress in the future (Appendix \figureref{fig:motivation}(a)). 


% While training the encoder and classifier alone(without the GRU-ODE in \figureref{fig:method} d.) on the stage classification task can help predict the current state of the disease in $\vec{I}_j$, $\vec{I}_k$, it cannot predict the future probability of conversion without accessing the future visit scans. Each scan is treated as an independent sample (see \figureref{fig:method a.}) and the correlations between the scans from different time-points of the same eye is not modeled. While the learned feature embedding could be sensitive to the pathologies such as GA that can distinguish dAMD from iAMD, they may fail to capture the more subtle retinal changes in the iAMD stage that can indicate how the disease will progress in the future.


%%%%%%%%%%%%%% Time-Series Prediction %%%%%%%%%%%%%%%%%
\textbf{Time-Series Prediction:} To address these issues, we incorporate a N-ODE based continuous time-series predictor called Stacked Multihead GRU-ODE (SMGRU-ODE) to model the future trajectory of AMD progression in the feature embedding using the current scan. SMGRU-ODE evolves the current feature $\vec{f}_j$ over a ($t_k-t_j$) time-interval to independently predict the future feature $\hat{\vec{f}}_k$ for time $t_k$ directly from the prior visit $\vec{I}_j$, while the actual feature $\vec{f}_k$ is also obtained from $\vec{I}_k$. The encoder, SMGRU-ODE and stage classifier can now be jointly trained with the AMD stage classification task:
\begin{equation}
%\mathcal{L}_{tot}=\mathcal{L}_{cls}+ \gamma_1 ||a||_2^2 + \gamma_2 || \mtrx{M}\odot ( 1-\mtrx{\widehat{R}} )||_1,
\mathcal{L}_{cls}=L_{bce}\left( y^{cls}_j, p_j \right) + L_{bce}\left( y^{cls}_k, p_k \right) + L_{bce}\left( y^{cls}_k, \hat{p}_k \right),
\label{eqn:cls}
\end{equation}
where $L_{bce}\left( y,p \right)$ %=-{(y\log(p) + (1 - y)\log(1 - p))}$ 
is the binary cross-entropy loss. Here, $p_j$, $p_k$ and $\hat{p}_{k}$ are predictions from the stage classifier for the features $\vec{f}_j$, $\vec{f}_k$ and $\hat{\vec{f}}_k$ respectively. The SMGRU-ODE architecture is detailed in Section~\ref{sec:smgru}. % The details of the SMGRU-ODE architecture is discussed below in Section \ref{sec:smgru}

%%%%%%%%%%%%% Intra-eye consistency %%%%%%%%%%%%%%%%%%%
\textbf{Intra-eye Consistency:} For the disease progression trajectory predicted by the SMGRU-ODE to be consistent, the feature $\hat{\vec{f}}_k$ and its stage prediction $\hat{p}_k$ should match the corresponding $\vec{f}_k$ and $p_k$, obtained directly from $\vec{I}_{k}$. The consistency losses between the features ($\mathcal{L}_{cns-ftr}$) and between the stage predictions ($\mathcal{L}_{cns-rsk}$) are defined as:
\begin{equation}
\mathcal{L}_{cns-ftr}=|| \vec{f}_k - \hat{\vec{f}}_k ||^2_2, \qquad\qquad\qquad \mathcal{L}_{cns-rsk}=L_{bce}\left( p_k, \hat{p}_k \right).
\label{eqn_consistency}
\end{equation}
These losses combined with $\mathcal{L}_{cls}$ ensure that the learned feature embedding (Appendix \figureref{fig:motivation} (a) vs (b)) not only characterizes the current disease stage but is also sensitive to the subtle retinal changes that capture the trajectory of the disease progression in the future. Since the future visit scans are unavailable at test time, only the top branch in \figureref{fig:pipeline} highlighted in green is employed to obtain future predictions with the N-ODE.


%%%%%%%%%%%%%%%%%  Rank Loss %%%%%%%%%%%%%%%
\textbf{Risk Score Ranking: }
%\TEc{this section can be summarized shortly, and the all notation can be moved to appendix, learn-to-rank loss is commonly known. Also j,k becomes m,n here it is kind of confusing}
$\mathcal{L}_{cls}$ ensures that the iAMD and dAMD samples lie on opposite sides of the decision hyperplane of the stage classifier without imposing any ordering among iAMD cases. %, it does not enforce any additional ordering within the iAMD samples. 
We envision a regularized feature manifold (Appendix \figureref{fig:motivation}(b) vs (c)) in which the risk of disease progression of a feature point is inversely related to its distance from the decision hyperplane, i.e., the closer an iAMD sample is to the decision hyperplane, the smaller its time to conversion, until it crosses over the hyperplane to the dAMD class. In this case, the logit $r$ from the stage classifier acts as a risk score for AMD progression as it is proportional to the distance of the sample from the decision hyperplane. While predicting the probability of conversion within specified time-points (CDF) requires the N-ODE during inference, a scalar risk score is directly obtained from the logit of the stage classifier for the current scan.
We consider a loss $\mathcal{L}_{rank}$ which is defined using pairs of samples to encourage such ordering. Given a training batch comprising $B$ pairs of images (i.e., a total of $2 \times B$ scans), we form all possible pairs $(\vec{I}_{m}, \vec{I}_n)$ which may or may not come from the same eye. $\mathcal{L}_{rank}$ defines an auxiliary classification task where the difference of their scalar logits from the AMD stage classifier is fed through a neuron (with a single input and output) to obtain the probability of ranking $r_m>r_n$ as $P_{m>n}=\sigma \left(w\cdot \left( r_m - r_n \right)+b\right)$, where $w$ and $b$ are scalar parameters of the neuron and the loss is defined as $\mathcal{L}_{rank}=L_{bce}\left( y^{m>n}_{rank}, P_{m>n} \right)$. 

The GT $y^{m>n}_{rank}=1$, if $T_m<T_n$ and $E_m=1$ (indicating that $\vec{I}_m$ converts before $\vec{I}_{n}$) or when $T_m<T_n$ and both $\vec{I}_m, \vec{I}_n$ are scans of the same eye (the risk increases in the future visits as damage to the retinal tissue is irreversible). Similarly, $y^{m>n}_{rank}=0$ if $T_m>T_n$ and either $E_n=1$ or $\vec{I}_m,\vec{I}_n$ come from the same eye, which signify cases where $\vec{I}_n$ converts before $\vec{I}_m$. The image pairs that do not fall into either one of these two categories cannot be ranked due to censoring and are considered to have missing labels that are masked out during the loss computation. %Since, the loss to correctly rank the pairs $(r_m, r_n)$ and $(r_n, r_m)$ will be same and only one among them is randomly selected for each $m$ and $n$. 
%%%%%%%%%%%%%%%%%%%TODO to be finalized tomorrow
%\TEc{From here the summary version of the paragraphs, feel free to replace it with the 3 paragprahs above, and move the all defeinitions to the appendix}While $\mathcal{L}_{cls}$ ensures that the iAMD and nAMD samples lie on opposite sides of its decision hyperplane without any ordering among iAMD cases. We envision a regularized feature manifold (Appendix \figureref{fig:motivation}) which correlates the risk of disease progression of a feature point to be inversely related to its distance from decision hyperplane, i.e., the closer an iAMD sample is to the decision hyperplane, the smaller its time to conversion, until it crosses over the hyperplane to the dAMD class. In this case, the logits $r$ from the stage classifier acts as a risk score for AMD progression as it is proportional to the distance of the sample from the decision hyperplane.  While predicting the probability of conversion within specified time-points (CDF) requires the Neural-ODE during inference, a scalar risk score is directly obtained from the logits of the stage classifier. In a batch, we created all possible scan pairs regardless of the source eye. Inspired from \citet{burges2005rank}, we define cross-entropy based $\mathcal{L}_{rank}$ as the probability of correctly ranking the risks for a given pair to encourage ordering among iAMD.
Finally, the total loss to train the proposed model is: 

\begin{equation}
\mathcal{L}_{tot}= \lambda_1 \mathcal{L}_{cls} + \lambda_2 \mathcal{L}_{cns-rsk} + \lambda_3 \mathcal{L}_{cns-ftr} + \lambda_4 \mathcal{L}_{rank},
\label{eqn_totloss}
\end{equation}

 where the loss weights $\lambda_1$, $\lambda_2$, $\lambda_3$ and $\lambda_4$ are not handcrafted but dynamically adapted during training using MTAdam \cite{malkiel2021mtadam} (see Appendix E for more details). 


\subsection{The N-ODE architecture}
\label{sec:smgru}


\begin{figure}[!htb]
 \centering
  \includegraphics[width=0.90\textwidth]{ODE3.pdf}
\caption{Our SMGRU-ODE extends GRU-ODE (eq.~\eqref{eq1}) by stacking multiple layers (a), with multiple parallel heads in each layer (b).}
\label{fig:neural_ode} 
 \end{figure}


%The instantaneous velocity vector is defined as time derivative $\frac{\text{d}f}{\text{d}t}$. The function $f(t)$ itself is unknown and expensive to sample from. Instead of learning $f(t)$ directly, we approximate $\frac{\text{d}f}{\text{d}t}$ using function $\vec{v}_{D}(f)$ which is a function of feature $f$ only and independent of the time $t$ elapsed so far. Provided the initial conditions $\vec{f}(0)=\vec{f}_j$, we solve the initial value problem and recover $f(t)$. The function $\vec{v}_{D}(f)$ is modeled with a DL network and learned with the Neural-ODE framework.

Upon projecting the initial scan $\vec{I}_{j}$ to the feature $\vec{f}_j$, the N-ODE predicts its future trajectory in the feature embedding to model disease progression. Let $\vec{f}(t)$ represent the feature after a time $t$ has elapsed since $\vec{I}_{j}$ was imaged. As time progresses from $t$ to $t+dt$ by an infinitesimal amount, $\vec{f}(t)$ is displaced by $\vec{v}_{D} \cdot dt$ where $\vec{v}_{D}$ denotes the instantaneous velocity vector.
%the instantaneous velocity vector $\vec{v}_{D} \cdot dt$. 
This can be modeled in continuous time using the N-ODE $\frac{d\vec{f}(t)}{dt}=\vec{v}_D(\vec{f}(t))$ with the initial value $\vec{f}(0)=\vec{f}_j$, where $\vec{v}_D$ is modeled with a DL network. We assume a time-invariant system, i.e., $\vec{v}_D$ is solely dependent on the current feature $\vec{f}(t)$ and not on the time $t$ elapsed so far.

We propose the SMGRU-ODE network to model $\vec{v}_{D}(\vec{f})$ which extends GRU-ODE \cite{de2019gru} by stacking $D=3$ layers (\figureref{fig:neural_ode}(a)), and modifying each layer to have $H=12$ parallel pathways (\figureref{fig:neural_ode}(b)) called heads, based on the efficacy of such design in Vision Transformers \cite{dosovitskiy2020image} and recent CNN architectures \cite{liu2022convnet,xie2017aggregated}. In~\figureref{fig:neural_ode}(a), each layer employs $\vec{f}(t)$ as the hidden state and, except for the first layer, also accepts an external input $\vec{v}_{d-1}(t)$ from its previous layer. Additive skip residual connections are applied between the inputs and outputs of each layer ($\vec{v}_{d}(t)=\hat{\vec{v}}_{d}(t)+ \vec{v}_{d-1}(t)$). As depicted in~\figureref{fig:neural_ode}(b), each head $1 \le h \le H$ independently projects the two $n$-dimensional ($n=768$ for ConvNeXt-Tiny \cite{liu2022convnet}) inputs to an $n/H$-dimensional sub-space using the fully connected (FC) layers, $\varphi_{h,d}(\vec{f}(t))$ and $\phi_{h,d}(\vec{v}_{d-1}(t))$ which project $\vec{f}(t)$ and $\vec{v}_{d-1}(t)$ to $\vec{f}^{(h)}_{d}(t)$ and $\vec{v}^{(h)}_{d}(t)$ respectively. Next, the $h$-th head computes the output $\vec{o}^{(h)}_d(t)$ similar to GRU-ODE as
\vspace{-1pt}
\begin{subequations}  \label{eq1}
\begin{align}
\vec{r}^{(h)}_{d}(t) &= \vec{\Psi}^{(rst)}_{h,d} \left( \left[ \vec{v}^{(h)}_{d}(t),  \vec{f}^{(h)}_{d}(t) \right] \right), &
\vec{u}^{(h)}_{d}(t) &=  \vec{\Psi}^{(updt)}_{h,d}\left( \left[ \vec{v}^{(h)}_{d}(t),  \vec{f}^{(h)}_{d}(t) \right]\right),\\
\vec{g}^{(h)}_{d}(t) &= \vec{\Psi}^{(act)}_{h,d} \left(  \left(\vec{r}^{(h)}_{d}(t) \odot  \vec{f}^{(h)}_{d}(t) \right)  \right), \!\! &  
\vec{o}^{(h)}_d(t) &= \left( 1 - \vec{u}^{(h)}_{d}(t)\right) \odot \left( \vec{g}^{(h)}_{d}(t) -  \vec{f}^{(h)}_{d}(t) \right),
\end{align}
\label{eqn:all-lines}
\end{subequations}
where $\vec{\Psi}^{(rst)}_{h,d}(.)$ and $\vec{\Psi}^{(updt)}_{h,d}(.)$ comprise a FC layer followed by Layer Normalization (LN) and sigmoid activation to compute the \textit{reset} and \textit{update gates}, $\vec{r}^{(h)}_{d}(t)$ and $\vec{u}^{(h)}_{d}(t)$ in eq.~\eqref{eq1}, respectively. The $\vec{\Psi}^{(act)}_{h,d}(.)$ used for the \textit{candidate activation vector} $\vec{g}^{(h)}_d(t)$ employs a Softplus activation after the FC and LN layers. Finally, the outputs $\vec{o}^{(h)}_d(t)$ of all heads are concatenated and input to $\Psi_d(.)$, which represents a FC layer \textit{without} LN and activation, to obtain $\hat{\vec{v}}_d(t)$. The LN and activations are instead applied at the beginning of the next layer in $\varphi_{h,d+1}(.)$ and $\phi_{h,d+1}(.)$. This is to ensure that (i) the output of the final layer $\vec{v}_{D}$ can take arbitrary (including negative) values and (ii) the normalization and activation are applied after the additive residual connections in each layer (\figureref{fig:neural_ode}(a)) so that backpropagation gradients can be improved through pre-activation \cite{he2016identity}. 
During the forward pass of the N-ODE, the feature for a future time-point $t_k$ is given by $\hat{\vec{f}}_{k}=\vec{f}_{j}+ \int_0^{t_k-t_j} \vec{v}_{D}(\vec{f}(t))\, dt$ which can be numerically estimated using any black-box ODE-solver. During the backward pass, the computational graph related to each iteration of the ODE-solver is not saved; instead, the required gradients are estimated on the fly by solving another augmented ODE introduced by the adjoint sensitivity analysis of \citet{chen2018neural}. As a result, the training requires a constant amount of memory independent of the solver's step size and integration time, allowing us to evolve the trajectory over long time-intervals, even with limited GPU memory. 




\section{Experiments and Results}
\textbf{Dataset:} The dataset consists of 3,534 OCT scans from 235 eyes (40 converters and 195 censored) from 123 patients, collected at the Department of Ophthalmology, Medical University of Vienna \cite{schlanitz2017drusen} and acquired using a Spectralis scanner at a resolution of 49 B-scans (slices), each with $512$--$1024 \times 496$ pixels.
Each eye was imaged every 3--6 months, with total follow-up periods spanning 2--7 years. For converter eyes, labels for each scan were computed by measuring the time interval between its acquisition and the first conversion visit. Our PyTorch code is available at \url{https://github.com/arunava555/Multihead_GRU_ODE_based_Survival_Analysis}.


%\noindent 
\textbf{Experimental Setup:} A stratified five-fold cross-validation was performed by randomly dividing the scans at the eye level to reduce the bias of a specific train-test data split. Each fold had 47 eyes with 8 converters and the number of scans varied between 667--707 across the folds. The model was trained five times, treating each fold as the test set, while the remaining dataset was randomly divided into $80\%$ for training and $20\%$ for validation. While the already converted dAMD scans were used during training, they were removed from the test set during evaluation. The performance was evaluated for predicting the conversion to dAMD within $6, 12, 18$ and $24$ months using the Area under the receiver operating characteristic curve (AUROC). Balanced Accuracy was used to assess binary predictions obtained by setting a threshold on the conversion probabilities at an optimal operating point determined from the validation set in each fold.
%and the Balanced Accuracy. 
Additionally, the Concordance Index (C-index) was used to evaluate the proposed risk score. It quantifies a model's ability to provide a reliable (inverse) ranking of the conversion time, taking censoring into account.

%The performance was evaluated for predicting the conversion to dAMD within $6,12, 18$ and $24$ months. Area under the receiver operating characteristic curve (AUROC) was used to evaluate the prediction scores.\TEc{remove followıng sentence} Balanced accuracy was employed for the binary predictions obtained by thresholding the prediction scores at an operating point that maximized the Youden's J statistic on the validation set. Concordance Index(C-index) was used to evaluate the proposed risk score which 

%\TEc{this paragraph can be removed completely}The performance was measured both at the Scan-level (provided in the Appendix, Table xx) and Eye-level. The scan-level performance was evaluated by treating each image as an independent sample to report the average performance across the five-folds. 



%\noindent 
\textbf{Ablation Results:} In Table 1, we analyzed the effect of the depth and the number of heads in SMGRU-ODE. Either reducing the depth D from 3 to 1 (in row 1) while keeping H=12, or reducing H from 12 to 1, while keeping D=3 (in row 2) had an adverse impact on the performance at all time-points, both in terms of AUROC and Balanced Accuracy. The C-index also decreased from 0.777 to 0.744 in both cases. This justifies our incorporation of multiple layers and heads in GRU-ODE. In rows 3--6, we perform an ablation on the loss terms. In row 3, we train the model with $\mathcal{L}_{cls}$ loss (see eq. \eqref{eqn:cls}), which is the minimal loss required to predict future conversion without incorporating any other losses to regularize the feature embedding. Introducing $\mathcal{L}_{rank}$ loss to it (in row 4) leads to a significant improvement in C-index (from 0.715 to 0.769), which is expected as $\mathcal{L}_{rank}$ is geared towards improving the rank ordering. Moreover, it improves the conversion prediction performance for all time-points both in terms of AUROC and Balanced Accuracy (except for the 6-month AUROC). Next, introducing the $\mathcal{L}_{cns-rsk}$ loss (in row 5) leads to further improvement in C-index; AUROC also improves for all except the 24-month time-point. However, the Balanced Accuracy shows mixed results with minor improvements for predicting conversion within 12 and 18 months but a slight drop in performance for the 6 and 24-month time-points. Finally, introducing the $\mathcal{L}_{cns-ftr}$ loss leads to our proposed method in row 6. It consistently improves the AUROC, Balanced Accuracy and C-index metrics, with the exception of the AUROC at the 6-month time-point. Overall, the results demonstrate the value of using all loss terms.

\setlength{\tabcolsep}{8pt}
\renewcommand{\arraystretch}{1.2}
\begin{table}[htbp]
\centering
\caption{Ablation experiments of different loss terms and the SMGRU-ODE architecture (mean $\pm$ std. dev.). Best values in each column are highlighted in bold. 
}
\label{Tab:Ablation}
\resizebox{1.0 \textwidth}{!}{
\begin{tabular}{@{}l|llll|llll|l@{}}
\toprule
  & \multicolumn{4}{c|}{AUROC}   & \multicolumn{4}{c|}{Balanced Accuracy} &  \\ 
                                 & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{C-index} \\ \midrule

SMGRU-ODE(D=1) & $0.823\pm0.05$& $0.789\pm0.05$& $0.771\pm0.05$& $0.779\pm0.06$& $0.814\pm0.05$& $0.769\pm0.05$& $0.751\pm0.05$& $0.746\pm0.06$ & $0.744\pm0.06$     \\  

SMGRU-ODE(H=1) & $0.817\pm0.08$& $0.791\pm0.08$& $0.766\pm0.07$& $0.766\pm0.08$& $0.813\pm0.05$& $0.777\pm0.05$& $0.749\pm0.05$& $0.743\pm0.05$ & $0.744\pm0.07$  \\ \midrule

$\mathcal{L}_{cls}$ & $0.854\pm0.06$& $0.827\pm0.06$& $0.795\pm0.04$& $0.799\pm0.04$& $0.832\pm0.06$& $0.784\pm0.05$& $0.756\pm0.04$& $0.764\pm0.05$ & $0.715\pm0.05$\\

$\mathcal{L}_{cls} + \mathcal{L}_{rank}$ & $0.852\pm0.05$& $0.828\pm0.04$& $0.803\pm0.01$& $0.812\pm0.02$& $\mathbf{0.846\pm0.04}$& $0.788\pm0.05$& $0.773\pm0.02$& $0.781\pm0.02$ & $0.769\pm0.04$\\

$\mathcal{L}_{cls} + \mathcal{L}_{rank} + \mathcal{L}_{cns-rsk}$ & $\mathbf{0.857\pm0.05}$& $0.832\pm0.04$& $0.807\pm0.04$& $0.810\pm0.03$& $0.834\pm0.04$& $0.793\pm0.03$& $0.774\pm0.03$& $0.776\pm0.03$ & $0.773\pm0.05$    \\ 

%$\mathcal{L}_{cls} + \mathcal{L}_{rank} + \mathcal{L}_{cns-rsk}$ & $0.857\pm0.05$& $0.832\pm0.04$& $0.807\pm0.04$& $0.810\pm0.03$& $0.834\pm0.04$& $0.793\pm0.03$& $0.774\pm0.03$& $0.776\pm0.03$ & $0.773\pm0.05$    \\ 
 

%H=12
Proposed  & $0.856\pm0.05$& $\mathbf{0.844\pm0.04}$& $\mathbf{0.819\pm0.02}$& $\mathbf{0.822\pm0.03}$& $0.840\pm0.05$& $\mathbf{0.818\pm0.04}$& $\mathbf{0.800\pm0.04}$& $\mathbf{0.803\pm0.04}$ & $\mathbf{0.777\pm0.04}$\\ 
\bottomrule
\end{tabular}
}
\end{table}








%\noindent 
\textbf{Comparison with the State of the Art:} In Table 2, we compare our method against common survival analysis methods. Time windows of 6 months are considered for the discrete-time survival models based on the censored cross-entropy loss \cite{wulczyn2020deep} and the logistic hazard model \cite{rivail2023deep}.
 DeepSurv \cite{katzman2018deepsurv} extends CoxPH with DL,  while SODEN \cite{tang2022soden} is a N-ODE based method, previously used on tabular data. These methods were also trained with the ConvNeXt-Tiny encoder but with modified classification layers and losses. Notably, none of these methods employ intra-subject regularization, and hence they require training only a single-branch network. The results in Table 2 indicate the superiority of our proposed method, which outperforms the existing methods at all time-points. SODEN, another N-ODE-based method, showed signs of overfitting with good performance on the validation set (for selecting the best-performing models in each fold) but led to a drastic drop in performance on the test sets across all folds.
 

\setlength{\tabcolsep}{8pt}
\renewcommand{\arraystretch}{1.2}
\begin{table}[htbp]
\centering
\caption{Comparison with State-of-the-Art. Best performance is highlighted in bold.}%Comparison with State-of-the-Art (mean $\pm$ std. dev.). Best performance in each column is highlighted in bold.}
\label{Tab:sota1}
\resizebox{1.0 \textwidth}{!}{
\begin{tabular}{@{}l|llll|llll|l@{}}
\toprule
  & \multicolumn{4}{c|}{AUROC}   & \multicolumn{4}{c|}{Balanced Accuracy} &  \\ 
                                 & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{C-index} \\ \midrule

%Proposed  & $0.851\pm0.05$& $\mathbf{0.840\pm0.05}$& $\mathbf{0.820\pm0.04}$& $\mathbf{0.827\pm0.05}$& $0.834\pm0.05$& $\mathbf{0.808\pm0.06}$& $\mathbf{0.793\pm0.04}$& $\mathbf{0.794\pm0.04}$ & $\mathbf{0.786\pm0.05}$  \\  
Cens. Cross-Entropy  & $0.787\pm0.06$& $0.779\pm0.06$& $0.776\pm0.05$& $0.789\pm0.04$& $0.764\pm0.05$& $0.739\pm0.04$& $0.731\pm0.03$& $0.741\pm0.02$ & $0.767\pm0.04$   \\ 
%Cens. Cross-Entropy 3 yr  & $0.795\pm0.04$& $0.795\pm0.06$& $0.789\pm0.05$& $0.807\pm0.04$& $0.773\pm0.05$& $0.76\pm0.05$& $0.747\pm0.03$& $0.756\pm0.02$ & $0.785\pm0.04$   \\ 
Logistic Hazard    & $0.787\pm0.06$& $0.787\pm0.04$& $0.779\pm0.04$& $0.797\pm0.03$& $0.780\pm0.06$& $0.766\pm0.03$& $0.745\pm0.04$& $0.755\pm0.04$ & $0.769\pm0.04$ \\ 
%Logistic Hazard 3 yr  & $0.801\pm0.07$& $0.801\pm0.04$& $0.784\pm0.05$& $0.801\pm0.04$& $0.785\pm0.06$& $0.771\pm0.05$& $0.752\pm0.02$& $0.757\pm0.02$ & $0.766\pm0.07$  \\ 
DeepSurv & $0.755\pm0.13$& $0.735\pm0.12$& $0.720\pm0.11$& $0.728\pm0.12$& $0.734\pm0.12$& $0.702\pm0.10$& $0.681\pm0.09$& $0.679\pm0.09$ & $0.768\pm0.04$  \\ 
%DeepRSA   \\ 
SODEN  & $0.673\pm0.09$& $0.707\pm0.05$& $0.703\pm0.04$& $0.721\pm0.05$& $0.676\pm0.05$& $0.691\pm0.03$& $0.685\pm0.04$& $0.698\pm0.04$ & $0.710\pm0.05$ \\ % CI reverse

%H=12
Proposed & $\mathbf{0.856\pm0.05}$& $\mathbf{0.844\pm0.04}$& $\mathbf{0.819\pm0.02}$& $\mathbf{0.822\pm0.03}$& $\mathbf{0.840\pm0.05}$& $\mathbf{0.818\pm0.04}$& $\mathbf{0.800\pm0.04}$& $\mathbf{0.803\pm0.04}$ & $\mathbf{0.777\pm0.04}$\\ 


\bottomrule
\end{tabular}
}
\end{table}


\section{Conclusion}

A wide variability in progression speed and the lack of well-established biomarkers make predicting the progression of AMD challenging. We proposed a novel framework that combines an AMD stage classifier with a N-ODE to forecast dAMD onset at continuous future times. To learn meaningful features from scarce labels, we enforce (i) intra-subject consistency to ensure that the feature embedding is sensitive to temporal changes in the retina to predict the future; 
(ii) temporal ordering, where a scan's proximity to the AMD classifier's decision hyperplane is inversely related to its time-to-conversion.
These constraints enabled our model to outperform several existing deep survival analysis methods. Additionally, temporal ranking allowed us to derive a scalar risk score to stratify eyes into low and high risk groups. While training uses longitudinal OCT scans, only a single scan at test time is needed for future conversion prediction. Our method for predicting dAMD onset can facilitate patient-specific disease management and enrich clinical trial populations with high-risk patients. Currently, the proposed method has been evaluated on a single-center dataset. Further evaluation of our method on multi-center data and adaptation to other survival analysis tasks in the medical domain, such as progression-free survival in cancer patients, are potential directions for future work. Use of a ViT-based encoder and incorporating segmentations of relevant retinal layers and lesions as additional inputs may also be considered in the future to further improve performance. 



%for this task that jointly trains an AMD stage classifier with a N-ODE to predict the onset of nAMD at continuous future time-points. In order to learn a meaningful feature embedding from limited labeled data, we regularize the learned feature embedding with (i) intra-subject consistency to ensure that the feature embedding is sensitive to the temporal changes in the retina to predict the future; (ii) a temporal ordering is imposed such that the distance of a scan from the decision hyperplane of the AMD stage classifier is inversely related to its time-to-conversion. These constraints on the feature embedding enable us to outperform several existing deep Survival analysis methods. Additionally, temporal ranking allows us to derive a scalar risk score which is effective in stratifying eyes coming from low and high risk groups. Although requiring longitudinal OCT scans during training, our method only requires a single OCT scan at test time to predict future conversion. Our method for predicting dAMD onset can facilitate patient-specific disease management and enhance clinical trial populations with high risk patients.
% Future work: competing risks for dAMD and nAMD?
% Demi-supervised : intra subject consistency does not require labels

\clearpage
\midlacknowledgments{This research was funded in whole, or in part, by the Austrian Science Fund (FWF) [10.55776/FG9], and Wellcome Trust Collaborative Award Ref. 210572/Z/18/Z.}

\bibliography{midl24_208}
\appendix





\section{Intuitive Explanation of the Methodology}
\begin{figure}[!htb]
 \centering
  \includegraphics[width=1.0\textwidth]{motivation.pdf}
\caption{(a) The AMD stage classifier learns a decision hyper-plane separating iAMD from dAMD. Each scan is considered to be an independent sample. The learned feature should capture pathologies that distinguish iAMD from dAMD. (b) A N-ODE is introduced along with the stage classifier.  The N-ODE traces the trajectory of disease progression (shown as dotted lines connecting the points of the same color, representing scans coming from the same eye at different time-points). Now the feature also needs to capture the subtle retinal changes indicative of the future disease state. (c) A notion of direction is incorporated in the feature embedding. The closer a point is to the decision hyperplane, the smaller its time-to-conversion. AMD, being an irreversible disease, can only progress in time, so scans from a later visit of an eye (shown by numbered indices) have to successively get closer to the decision hyperplane.}
\label{fig:motivation} 
 \end{figure}



\section{Implementation Details}
All experiments were performed in Python 3.8.16 with PyTorch 2.0.0. 
The proposed method was trained with batches comprising 16 image pairs for 200 epochs (300 batch updates per epoch), using the MTAdam \cite{malkiel2021mtadam} optimizer for dynamic loss tuning. A cyclic learning rate scheduler was employed with a minimum and maximum learning rate of $10^{-6}$ and $10^{-4}$ respectively. The performance on the validation set was monitored at the end of each epoch for early stopping with a patience of 50 epochs. The N-ODE was implemented with the \textit{torchdiffeq}  library \cite{torchdiffeq}. The Euler method was used as the ODE-solver due to its computational efficiency with a step size of 0.06, where the time between 0--3 years was mapped to [0,1]. During training, each batch was constructed with random training image-pairs $(\vec{I}_{j}, \vec{I}_k)$ with a time-interval of 0--3 years between them. The training batches were constructed to ensure that all $\vec{I}_{j}$ were in the iAMD stage while half of the $\vec{I}_k$ in each batch were in the dAMD stage (through oversampling) to enable the training of the AMD stage classifier.

The proposed method required around 6 GB of GPU memory to train using a training batch size of 16 image pairs. The ConvNeXt-Tiny \cite{liu2022convnet} encoder was initialized with the standard ImageNet pre-trained weights for end-to-end fine-tuning. The proposed SMGRU-ODE model with D=3, H=12 has 4,798,848 learnable network parameters. %At inference time, the proposed method needed an average of xx secs/image



\section{Eye-level Performance Comparison with Bootstrapping}



\setlength{\tabcolsep}{7pt}
\renewcommand{\arraystretch}{1.2}
\begin{table}[!htbp]
\centering
\caption{Eye-level Bootstrap Performance (mean $\pm$ std. dev.). Best values in each column are highlighted in bold. }
\label{Tab:bootstrap}
\resizebox{1.0 \textwidth}{!}{
\begin{tabular}{@{}l|llll|llll|l@{}}
\toprule
  & \multicolumn{4}{c|}{AUROC}   & \multicolumn{4}{c|}{Balanced Accuracy} &  \\ 
                                 & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{C-index} \\ \midrule
%H=12
Proposed & $\mathbf{0.863\pm0.10}$ & $\mathbf{0.827\pm0.10}$ & $\mathbf{0.808\pm0.07}$ & $\mathbf{0.816\pm0.07}$ & $\mathbf{0.871\pm0.11}$ & $\mathbf{0.811\pm0.09}$ & $\mathbf{0.789\pm0.07}$ & $\mathbf{0.801\pm0.06}$ & $\mathbf{0.769\pm0.06}$ \\

Cens. Cross-Entropy  & $0.775\pm0.14$& $0.772\pm0.10$& $0.773\pm0.10$& $0.790\pm0.08$& $0.804\pm0.11$& $0.756\pm0.11$& $0.742\pm0.07$& $0.746\pm0.06$ & $0.762\pm0.06$   \\ 

Logistic Hazard   & $0.769\pm0.19$& $0.768\pm0.12$& $0.763\pm0.09$& $0.786\pm0.08$& $0.792\pm0.14$& $0.760\pm0.11$& $0.749\pm0.08$& $0.766\pm0.08$ & $0.749\pm0.08$ \\ 

DeepSurv  & $0.769\pm0.18$& $0.710\pm0.16$& $0.712\pm0.14$& $0.723\pm0.14$& $0.749\pm0.17$& $0.689\pm0.12$& $0.682\pm0.12$& $0.686\pm0.12$ & $0.752\pm0.07$ \\ 

SODEN   & $0.675\pm0.24$& $0.674\pm0.17$& $0.673\pm0.13$& $0.698\pm0.11$& $0.711\pm0.19$& $0.671\pm0.14$& $0.665\pm0.11$& $0.693\pm0.10$ & $0.673\pm0.09$\\

\bottomrule
\end{tabular}
}
\end{table}

Eye-level bootstrapping involves multiple re-samplings of the test set in each fold. In each re-sampling, one OCT scan is selected from each eye (by randomly selecting any one of the patient visits). This re-sampling process is repeated $1000$ times for each of the five folds to report the average performance across the $5 \times 1000=5000$ sample estimates across all folds (see Table 3). 

%\setlength{\tabcolsep}{8pt}
%\renewcommand{\arraystretch}{1.2}
%\begin{table}[htbp]
%\centering
%\label{Tab:Ablation}
%\caption{Ablation experiments for loss terms and SMGRU-ODE architecture. Eye-level Bootstrap performance (mean $\pm$ std. dev.). Best values highlighted in bold.}
%\resizebox{1.0 \textwidth}{!}{
%\begin{tabular}{@{}l|llll|llll|l@{}}
%\toprule
%  & \multicolumn{4}{c|}{AUROC}   & \multicolumn{4}{c|}{Balanced Accuracy} &  \\ 
%                                 & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{CI} \\ \midrule

%SMGRU-ODE(D=1)  & $0.803\pm0.14$& $0.753\pm0.13$& $0.735\pm0.1$& $0.749\pm0.1$& $0.82\pm0.09$& $0.769\pm0.1$& $0.744\pm0.09$& $0.741\pm0.09$ & $0.717\pm0.07$     \\

%SMGRU-ODE(H=1) & $0.847\pm0.12$& $0.805\pm0.11$& $0.769\pm0.09$& $0.76\pm0.09$& $0.849\pm0.12$& $0.796\pm0.1$& $0.756\pm0.08$& $0.745\pm0.08$ & $0.753\pm0.06$  \\

%$\mathcal{L}_{cls}$ & $0.874\pm0.12$& $\mathbf{0.828\pm0.10}$& $0.794\pm0.08$& $0.8\pm0.08$& $0.875\pm0.11$& $0.804\pm0.1$& $0.766\pm0.08$& $0.776\pm0.07$ & $0.706\pm0.05$ \\

%$\mathcal{L}_{cls} + \mathcal{L}_{rank}$ & $\mathbf{0.879\pm0.11}$& $0.815\pm0.12$& $0.79\pm0.08$& $0.802\pm0.07$& $0.873\pm0.12$& $0.784\pm0.11$& $0.763\pm0.08$& $0.776\pm0.07$ & $0.758\pm0.06$\\


%$\mathcal{L}_{cls} + \mathcal{L}_{rank}  + \mathcal{L}_{cns-rsk}$ & $0.868\pm0.12$& $0.817\pm0.12$& $0.795\pm0.09$& $0.805\pm0.08$& $\mathbf{0.874\pm0.11}$& $0.804\pm0.11$& $\mathbf{0.786\pm0.09}$& $0.787\pm0.07$ & $\mathbf{0.767\pm0.07}$    \\ 


%H=12
%Proposed  & $0.863\pm0.10$ & $0.827\pm0.1$ & $0.808\pm0.07$ & $0.816\pm0.07$ & $0.871\pm0.11$ & $0.811\pm0.09$ & $0.789\pm0.07$ & $0.801\pm0.06$ & $0.769\pm0.06$\\
%\bottomrule
%\end{tabular}
%}
%\end{table}













%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


\section{Preprocessing and Data Augmentation}

The top and bottom boundaries delineating the retinal tissue, called the Inner Limiting Membrane (ILM) and the Bruch's Membrane (BM) respectively, were extracted using the automated method in \cite{fazekas2022segmentation}. Thereafter, the curvature of the retinal surface was flattened by shifting each A-scan by an offset such that the BM lies on a flat plane, similar to \cite{emre2022tinc}. The five central B-scans centered around the fovea spanning 5 mm across the A-scans (image columns) were extracted and the region containing the retinal tissue between the ILM and BM was cropped with a margin of $280$ microns at the bottom to include the choroid region and resized to $248 \times 248$. The intensity was linearly scaled to [-1,1]. 

During training, 3 consecutive B-scans (slices) out of the 5 central B-scans extracted during preprocessing were randomly selected from each scan and provided as input to the ConvNeXt-Tiny model in place of the three RGB color channels. The data augmentations during training involved random translations, horizontal flip, random crop-resize, Gaussian noise, random in-painting and random intensity transformations.

During inference, no data augmentation was employed. Of the 5 central B-scans extracted, 3 sets of images were constructed, each using 3 consecutive B-scans as channels similar to RGB in natural images (and the average of the predictions from these 3 images was used). The same approach was also employed for evaluating the other state-of-the-art methods for comparison.


\section{Dynamic Loss Tuning}
\label{dynamic_loss}

Determining the values of the tunable loss weights $\lambda_1$, $\lambda_2$, $\lambda_3$ and $\lambda_4$ in Eq.~\eqref{eqn_totloss} through a systematic grid search is computationally expensive as it requires training multiple model configurations. Instead, we used Multi-Term Adam (MTAdam) \cite{malkiel2021mtadam} to dynamically adapt the loss weights during training. MTAdam extends the Adam optimizer by tracking the derivatives and the first and second order moments of each loss term separately, and continuously balances their gradient magnitudes across all layers during training batch updates. To evaluate the impact of this design choice, we retrained the model with different alternatives presented below in Table 4.
In the case of \textit{Equal Weighting}, we fixed all weights to $\lambda_1=\lambda_2=\lambda_3=\lambda_4=1.0$. In the case of \textit{Handcrafted weights}, we fixed $\lambda_1=1.0$, $\lambda_2=0.1$, $\lambda_3=1.0$ and $\lambda_4=10.0$ by observing the scale and the perceived relative importance of the different loss terms.  The uncertainty weighting based method in \cite{kendall2018multi} is another alternative automatic method for dynamic loss tuning, which was used along with the modifications proposed in \cite{liebel2018auxiliary} to avoid the loss becoming negative during training.

\setlength{\tabcolsep}{8pt}
\renewcommand{\arraystretch}{1.2}
\begin{table}[htbp]
\centering
\caption{Comparison of different loss weighting strategies (mean $\pm$ std. dev.). Best values in each column are highlighted in bold.}
\label{Tab:dynamic_loss}
\resizebox{1.0 \textwidth}{!}{
\begin{tabular}{@{}l|llll|llll|l@{}}
\toprule
  & \multicolumn{4}{c|}{AUROC}   & \multicolumn{4}{c|}{Balanced Accuracy} &  \\ 
                                 & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{6} & \multicolumn{1}{c}{12} & \multicolumn{1}{c}{18} & \multicolumn{1}{c|}{24} & \multicolumn{1}{c}{C-index} \\ \midrule
%\multicolumn{10}{c}{Scan-level performance.} \\ \midrule

Equal weighting   & $\mathbf{0.862\pm0.07}$& $0.828\pm0.04$& $0.802\pm0.03$& $0.807\pm0.04$& $\mathbf{0.855\pm0.05}$& $0.798\pm0.03$& $0.777\pm0.02$& $0.775\pm0.02$ & $0.772\pm0.03$    \\ 
Handcrafted weights & $0.843\pm0.06$& $0.835\pm0.04$& $0.808\pm0.03$& $0.820\pm0.02$& $0.831\pm0.05$& $0.796\pm0.04$& $0.780\pm0.04$& $0.793\pm0.04$ & $0.772\pm0.02$      \\ 
%H=12
MTAdam  & $0.856\pm0.05$& $\mathbf{0.844\pm0.04}$& $\mathbf{0.819\pm0.02}$& $\mathbf{0.822\pm0.03}$& $0.840\pm0.05$& $\mathbf{0.818\pm0.04}$& $\mathbf{0.800\pm0.04}$& $\mathbf{0.803\pm0.04}$ & $\mathbf{0.777\pm0.04}$\\
Uncertainty weighting  & $0.843\pm0.06$& $0.819\pm0.04$& $0.790\pm0.03$& $0.797\pm0.03$& $0.818\pm0.07$& $0.773\pm0.04$& $0.749\pm0.04$& $0.752\pm0.04$ & $0.763\pm0.04$ \\ 

%\midrule
%\multicolumn{10}{c}{Eye-level Bootstrap performance.} \\ \midrule

%Equal weighting   & $\mathbf{0.866\pm0.10}$& $0.808\pm0.11$& $0.787\pm0.08$& $0.797\pm0.07$& $\mathbf{0.890\pm0.10}$& $\mathbf{0.796\pm0.11}$& $0.777\pm0.08$& $0.778\pm0.08$ & $0.753\pm0.06$   \\ 
%Handcrafted weights  & $0.845\pm0.12$& $0.807\pm0.12$& $0.789\pm0.09$& $0.805\pm0.07$& $0.849\pm0.08$& $0.787\pm0.12$& $0.769\pm0.08$& $0.784\pm0.07$ & $0.756\pm0.06$  \\ 
%MT-ADAM  & $0.839\pm0.13$& $0.809\pm0.13$& $\mathbf{0.794\pm0.09}$& $\mathbf{0.807\pm0.09}$& $0.856\pm0.09$& $\mathbf{0.796\pm0.13}$& $\mathbf{0.780\pm0.09}$& $\mathbf{0.788\pm0.07}$ & $\mathbf{0.765\pm0.08}$ \\ 
%Uncertainty weighting  & $0.843\pm0.12$& $\mathbf{0.811\pm0.11}$& $0.783\pm0.09$& $0.791\pm0.08$& $0.874\pm0.12$& $0.795\pm0.11$& $0.759\pm0.08$& $0.763\pm0.07$ & $0.754\pm0.06$ \\ 

\bottomrule
\end{tabular}
}
\end{table}


\section{Identification of Risk Groups}
%%%%%%%%%%OLD
We calibrated the risk scores in each fold to lie in the range $[0,1]$. This was performed with bicubic interpolation to map the $x^{th}$ percentile of the risk scores in the validation set to $\frac{x}{100}$ (e.g., the $10^{th}$ percentile of the risk scores is mapped to 0.1 and so on). The test set predictions of the calibrated risk scores were combined from the five folds to obtain a risk score for each OCT scan.  The scans were then stratified into 3 groups with low risk ($0 \le r \le 0.33$), moderate risk ($0.33 < r \le 0.67$) and high risk ($0.67<r \le 1$). A population-level survival function for these groups is plotted in Fig.~\ref{fig:KM}(a) using the Kaplan--Meier estimator on the GT conversion time. It depicts the mean and standard deviation of the survival probability for each population group, computed across 1000 re-samplings using bootstrapping. The survival curves for the three risk groups show a clear separation, thereby demonstrating the effectiveness of the proposed risk score.
%%%%%%%%%%Shortened
The learned feature embedding (Fig.~\ref{fig:KM}(b)) exhibits a smooth transition from fast (red) to slow (blue) converters along the feature manifold, where the gray dots represent the censored scans.

The saliency maps obtained for the risk scores in Fig.~\ref{fig:qual} show the network to be sensitive to the structural changes around the RPE (e.g., Fig.~\ref{fig:qual}(a),(b)) and Hyperreflective Foci (HRF) (e.g., Fig.~\ref{fig:qual}(d),(e)), which have been clinically linked to dAMD progression.  


\begin{figure}[!htb]
 \centering
  \includegraphics[width=1\textwidth]{KM_plots2.pdf}
\caption{(a) Kaplan--Meier curves for different risk groups; (b) UMAP plot of feature embedding for one of the five folds. The censored scans are depicted with gray dots and the converters are colored by their time to conversion (red indicates fast conversion).}
\label{fig:KM} 
 \end{figure}


 
\begin{figure}[!htb]
 \centering
  \includegraphics[width=1\textwidth]{gradcam.pdf}
\caption{ Grad-CAM Saliency maps for the risk score.}
\label{fig:qual} 
 \end{figure}

\end{document}
