\begin{figure}
  \centering
  \includegraphics[width=0.45\textwidth]{figs/airl_model.png}
  \caption{Overall architecture of \texttt{AIRL} and the visualization of its learning process.}
  \label{fig:2}
\end{figure}

\section{Proposed Algorithm}\label{sec:method}

\paragraph{Overview.}
Based on Remark~\ref{thm-5}, we propose \texttt{AIRL}, a novel model that learns adaptive invariant representations from a sequence of $T$ source domains. \texttt{AIRL} includes two components: (i) a \textit{representation network}, which is an instantiation of the mechanisms $F^{\ast}$ and $G^{\ast}$ that generate the representation mapping sequences $[f^*_1, \cdots, f^*_{T+K}]$ and $[g^*_1, \cdots, g^*_{T+K}]$ from the input space to the representation space, and (ii) a \textit{classification network} that learns the sequence of classifiers $H^{\ast} = [{h}^*_1, \cdots, {h}^*_{T+K}]$ from the representation space to the output space. Figure~\ref{fig:2} shows the overall architecture of \texttt{AIRL}; the technical details of each component are presented in Appendix~\ref{app:model}. The learning and inference processes of \texttt{AIRL} are formally stated as follows.

\subsection{Learning} Non-stationary mechanisms $F^{\ast}$ and $G^{\ast}$, and classifiers $H^{\ast}$ in \texttt{AIRL} can be learned by solving an optimization problem over $T$ source domains $\{D_t\}_{t=1}^T$:
\begin{align}\label{eq:obj}
  F^{\ast}, G^{\ast}, H^{\ast} = \underset{F , G , H }{\argmin} \sum_{t=1}^{T-1} \left( \mathcal{L}^t_{cls} + \alpha \mathcal{L}^t_{inv} \right)
\end{align} 
where $\mathcal{L}^t_{cls}$ is the prediction loss on source domains $D_t$ and $D_{t+1}$; $\mathcal{L}^t_{inv}$ enforces that the representations are invariant across the pair of consecutive domains $D_t,D_{t+1}$; and the hyper-parameter $\alpha$ controls the trade-off between the two objectives. We note that enforcing pairwise invariance as in objective (\ref{eq:obj}) does not imply global invariance (i.e., representations that are invariant across all domains). This is because we use distinct mappings for different pairs of domains. In particular, $D_{t-1}$ and $D_{t}$ are aligned by the two mappings $f_{t-1}$ and $g_{t-1}$, while $D_{t}$ and $D_{t+1}$ are aligned by the two mappings $f_{t}$ and $g_{t}$.

Next, we present the detailed architecture of the \textit{representation network} and the \textit{classification network}. In practice, the representation mappings are often complex (e.g., ResNet~\citep{he2016deep} for image data, Transformer~\citep{vaswani2017attention} for text data), so explicitly capturing the evolution of these mappings is challenging. We circumvent this bottleneck by instead capturing the evolution of the representation space induced by these mappings. Formally, our \textit{representation network} consists of an encoder $\enc$, which maps from the input space to the representation space, and a Transformer layer $\trans$, which learns the non-stationary pattern from a sequence of source domains using the attention mechanism. Given a batch $\mathcal{B} := \{x_t, y_t \}_{t \leq T}$ from the $T$ source domains, where $\{x_t, y_t \} = \{x^j_t, y^j_t\}_{j=1}^n$ are the samples for domain $D_t$, the encoder first maps each input $x_t^j$ to a representation $z_t^j = \enc \left( x_t^j \right)$, $\forall t \leq T, j \leq n$. Then, the Transformer layer $\trans$ is used to generate the representation $\widehat{z}_t^j$ from the sequence $z_{\leq t}^j = \left[z_1^j, z_2^j, \cdots, z_t^j  \right]$. Specifically, $\forall j, t$, $\trans$ leverages four feed-forward networks $Q, K, V, U$ to compute $\widehat{z}_t^j$ as follows:
\begin{align}  \widehat{z}_t^j &= \left(a^j_{\leq t}\right)^\top V\left(z_{\leq t}^j\right) + U\left(z_t^j\right) \nonumber \\ & ~~~\text{with}~~~    a^j_{\leq t} = \frac{K\left(z_{\leq t}^j\right)^\top Q\left(z_{t}^j\right)}{\sqrt{d}}
\end{align}
% In our design, the computational path from $x_{t+1}^j$ to $z_{t+1}^j$ is considered as {$g_t$} ($z_{t+1}^j = g_t( x_{t+1}^j )$) and the computational path from $x_t^j$ to $\widehat{z}_t^j$ for all $j$ is considered as $f_{t}$. 

It is worth pointing out that we do not assume that data are aligned across the domain sequence. In particular, due to randomness in data loading, there is no alignment between the $j^{th}$ sample in domain $D_{t-1}$ and the $j^{th}$ sample in domain $D_t$. In our design, the computational paths from $x_{t+1}^j$ to $z_{t+1}^j$ and from $x_t^j$ to $\widehat{z}_t^j$ are considered as $g_t$ and $f_{t}$, respectively. In particular, $z_{t+1}^j = g_t( x_{t+1}^j )$ and $\widehat{z}_{t}^j = f_t( x_{t}^j )$. The main goal of this design is as follows: by incorporating historical data into the computation, the representation space constructed by $f_{t}$ can capture the evolving pattern across the domain sequence. However, this design requires access to the historical data during inference, which might not be feasible in practice. To avoid this, we train $g_t$, which obviates the need to access historical data, to mimic the representation space constructed by $f_{t}$.

% To this end, our model constructs a sequence of $\{f_{t}\}_{t < T}$ with each $f_{t}$ a function learned from the domain sequence $\{ D_{i} \}_{i \leq t}$ (i.e., $f_{t}$ evolves over $t$), while it constructs $g_{t}= g, \forall t$ (i.e., $g_{t}$ is fixed across domains). This design serves two purposes in \textit{representation network}: (i) a sequence of $\{f_{t}\}_{t < T}$ learned from source domains can sufficiently capture the non-stationary patterns; (ii) the fixed $g_{t}= g$ enables the inference for any target domains $D_{{t}'}, \forall t'> T$ without the need to access to the whole domain sequence $\{ D_{i} \}_{i \leq {t}'}$. 

As shown in Remark~\ref{thm-5}, our goal is to enforce the invariant representation constraint (i.e., $\mathcal{L}_{inv}^t$ in objective (\ref{eq:obj})) for every pair of consecutive domains $D_t,D_{t+1}$ constructed by $f_{t}$ and $g_t$, instead of learning a network that achieves invariant representations for all source domains together. Thus, the representations constructed by $f_{{t}}$ and $g_{{t}}$ might not be aligned with the ones constructed by $f_{{t}'}$ and $g_{{t}'}$ for ${t}' \neq t$. After mapping data to the representation spaces, the \textit{classification network} is used to generate the classifier sequence $H$. Due to the simplicity of the classifier in practice (e.g., a 1- or 2-layer network), we leverage long short-term memory~\citep{hochreiter1997long} $\lstm$ to explicitly capture the evolution of the classifier over the domain sequence. Specifically, the weights of previous classifiers $h_{< t} = \left[h_1, h_2, \cdots, h_{t-1} \right]$ are vectorized and put into $\lstm$ to generate the weights of $h_{t}$.
% To achieve it, we  leverage long short-term memory layer~\citep{hochreiter1997long} $\lstm$ to generate $h_t$ from the sequence of previous classifier $h_{< t} = \left[h_1, h_2, \cdots, h_{t-1} \right]$. In particular, the weights of $h_{<t}$ are vectorized and put into $\lstm$ to generate the weights of $h_{t}$.

Because $f_t$, $g_t$, $h_t$, $\forall t < T$, are functions of the \textit{representation network} and the \textit{classification network}, the weights of these two networks are updated using the backpropagated gradients of objective (\ref{eq:obj}). The pseudo-code of the complete learning process for \texttt{AIRL} is shown in Algorithm~\ref{alg:train}. Next, we present the details of each loss term used in optimization.

\textbf{Prediction loss $\mathcal{L}_{cls}^t$:} We adopt cross-entropy loss for classification tasks. Specifically, $\mathcal{L}_{cls}^t$  for the optimization over domains $D_{t}$,  $D_{t+1}$ is defined as follows.
\begin{align}\label{eq:cls}
    \mathcal{L}_{cls}^t &= \mathbb{E}_{{D^W_t}} \left [ -\log \left (\frac{h_t(f_t(X))_Y}{\sum_{{y}' \in \mathcal{Y}} h_t(f_t(X))_{{y}'}} \right ) \right ] \nonumber \\ 
    &+ \mathbb{E}_{{D_{t+1}}} \left [ -\log \left( \frac{h_t\left(g_t(X)\right)_Y}{\sum_{{y}' \in \mathcal{Y}} h_t\left(g_t(X)\right)_{{y}'}} \right) \right ]
\end{align}

\textbf{Invariant representation constraint $\mathcal{L}^{t}_{inv}$:} It aims to minimize the distance between $f_{t} \sharp P_{D^W_{t}}^{Z|Y=y}$ and $g_{t} \sharp P_{D_{t+1}}^{Z|Y=y}$, $\forall y \in \mathcal{Y}$, two conditional distributions induced from domains $D^W_t$ and $D_{t+1}$ using representation mappings $f_t,g_t$, respectively. In other words, for any inputs $X$ from domains $D^W_t$ and ${X}'$ from $D_{t+1}$  whose labels are the same, we need to find representation mappings $f_t, g_t$ such that the representations $f_t(X),g_t(X')$ have similar distributions. Inspired by correlation alignment loss \citep{sun2016deep}, we enforce this constraint by using the following as loss $\mathcal{L}^{t}_{inv}$:
\begin{align}\label{eq:inv}
\mathcal{L}^{t}_{inv} = \sum_{y \in \mathcal{Y}} \frac{1}{4d^2} \left\|C^y_t - C^y_{t+1} \right\|^2_{F}
\end{align}
where $d$ is the dimension of representation space $\mathcal{Z}$, $\|\cdot\|^2_F$ is the squared matrix Frobenius norm, and $C^y_t$ and $C^y_{t+1}$ are covariance matrices defined as follows:
\begin{align}
C^y_t &= \frac{1}{n^t_y - 1} \left (  f_t \left( \mathbf{X}^y_{t} \right)^\top f_t \left( \mathbf{X}^y_{t} \right) \right. \nonumber \\
&\left. - \frac{1}{n^t_y} \left ( \mathbf{1}^\top f_t \left( \mathbf{X}^y_{t} \right)  \right )^\top \left ( \mathbf{1}^\top f_t \left( \mathbf{X}^y_{t} \right)  \right ) \right) \\
C^y_{t+1} &= \frac{1}{n^{t+1}_y - 1} \left(  g_t \left( \mathbf{X}^y_{t+1} \right)^\top g_t \left( \mathbf{X}^y_{t+1} \right) \right. \nonumber \\
&\left. - \frac{1}{n^{t+1}_y} \left ( \mathbf{1}^\top g_t \left( \mathbf{X}^y_{t+1} \right)  \right )^\top \left ( \mathbf{1}^\top g_t \left( \mathbf{X}^y_{t+1} \right)  \right ) \right)
\end{align}
where $\mathbf{1}$ is the column vector with all elements equal to $1$, $\mathbf{X}^y_t = \{ x_i: x_i \in D^W_{t},  y_i=y\}$ is the matrix whose rows are $\{x_i\}$ (and $\mathbf{X}^y_{t+1}$ is defined analogously from $D_{t+1}$), $f_t$ and $g_t$ are row-wise operations applied to $\mathbf{X}^y_t$ and $\mathbf{X}^y_{t+1}$, respectively, and $n^t_y$ is the cardinality of $\mathbf{X}^y_t$.

\begin{algorithm2e}[t]
\SetAlgoLined
\KwIn{Training datasets from $T$ source domains $\{D_t\}_{t=1}^T$, \textit{representation network} = $\{\enc$, $\trans\}$, \textit{classification network} = $\{ \lstm, h_1 \}$, $\alpha$, $n$}
\KwOut{Trained $\enc, \trans, \lstm, h^{\ast}_1$}

\setcounter{AlgoLine}{0}
$L_{inv} = 0, L_{cls} = 0$ \\
\tcc{Estimate $\{w^t_y\}_{y \in \mathcal{Y}, t < T}$  for importance weighting}
\For{$t=1:T-1$}{
\For{$y \in \mathcal{Y}$}{
$w^t_y = P^{Y=y}_{D_{t+1}} / P^{Y=y}_{D_{t}}$
}
}
\tcc{Learn weights for $\enc, \trans, \lstm$}
\While{learning has not ended}{
Sample batch $\mathcal{B} = \left\{ x_t, y_t \right\}_{t=1}^T \sim \{D_t\}_{t=1}^T$ where $\left\{ x_t, y_t \right\} = \left\{ x^j_t, y^j_t \right\}_{j=1}^n$\\
$z_{1} = \enc\left( x_{1} \right)$ \\
\For{$t=1:T-1$}{
    $z_{t+1} = \enc\left( x_{t+1} \right)$ \\
    $\widehat{z}_t = \trans\left( z_{\leq t} \right) $ \\
    $\{\widehat{z}_t(w), y_t(w)\} =$ Reweight $\{\widehat{z}_t, y_t\}$ with $w^t = \{w^t_y\}_{y \in \mathcal{Y}}$ \\
    % \For{$y=1:|\mathcal{Y}|$}{
    % $\widehat{z}_t(y) = \{\widehat{z}^j_t: y^j_t = y  \}$ \\
    % Reweight $\widehat{z}_t(y)$ with $w_y$ \\
    % }
    Calculate $L_{inv}^t$ from $\widehat{z}_t(w), z_{t+1}$ by Eq. (\ref{eq:inv}) \\
    $L_{inv} = L_{inv} + L_{inv}^t$ \\
    \If{$t>1$}{
    $h_t = \lstm \left( h_{<t} \right)$ \\
    }
    Calculate $L_{cls}^t$ from $y_t(w), y_{t+1}, h_t\left( \widehat{z}_t(w) \right), h_t\left( z_{t+1} \right)$ by Eq. (\ref{eq:cls}) \\
    $L_{cls} = L_{cls} + L_{cls}^t$ \\
}
Update $\enc, \trans, \lstm, h_1$ by optimizing $L_{cls} + \alpha L_{inv}$ \\
}
\caption{Learning process for \texttt{AIRL}}
\label{alg:train}
\end{algorithm2e}

\subsection{Inference}

At the inference stage, the well-trained \textit{representation network} and \textit{classification network} can be used to make predictions about input $x$ from target domain sequence $\left\{D_{t}\right\}_{t=T+1}^{T+K}$. In particular, we first map input $x$ in domain $D_t$ to representation $z$ using the encoder $\enc$ in the \textit{representation network} (i.e., $g^{\ast}_{t-1}$). Then the \textit{classification network} (i.e., LSTM) is used sequentially to generate $h^{\ast}_{t-1}$ from the sequence of classifiers $\left [h^{\ast}_{1}, \cdots, h^{\ast}_{t-2} \right]$, and the prediction about $z$ can be made by $h^{\ast}_{t-1}$. Note that at the learning stage, both $g^{\ast}_{t-1}$ and $f^{\ast}_{t}$ are used to map input $x$ from domain $D_t$ to the representation space while at the inference stage, only $g^{\ast}_{t-1}$ is needed for target domain $D_{t}$ (we do not use $f^{\ast}_{t}$ because it requires access to data from all domains $\{D_{t'}\}_{t' \leq t}$ which generally are not available during inference). The complete inference process is shown in Algorithm \ref{alg:test}.

\begin{algorithm2e}[t]
\label{alg:fatdm}
\SetAlgoLined
\KwIn{Testing dataset from target domain $D_t (t \in \{ T+1, \cdots, T+K \})$, trained $\enc, \lstm, h^{\ast}_1$}
\KwOut{Predictions for testing dataset}

\setcounter{AlgoLine}{0}

\For{$t'=2:(t-1)$}{
$h^{\ast}_{t'} = \lstm \left ( h^{\ast}_{< t'} \right )$ 
}
\While{inference has not ended}{
Sample batch $\mathcal{B} = x_{t} \sim D_{t}$ \\
$z_{t} = \enc \left( x_{t} \right)$ \\
Generate predictions $h^{\ast}_{t-1} \left ( z_{t} \right )$ \\
}

\caption{Inference process for \texttt{AIRL}}
\label{alg:test}
\end{algorithm2e}