\section{Predictive State Distribution}
\label{sec:predict}

% Problem we address in this section.
We now focus on the following problem: given a (trained) discrete-time model, an initial state distribution $\bm{\pi}_0$ and a time horizon $T$, predict the marginal state distribution after $T$ steps, $\bm{\pi}^\star_T$.
This distribution is the multistate equivalent of the survival distribution in survival analysis, and it is of central importance in many applications.
For example, in a disease progression model, it can be used to predict the number of patients that have recovered after a given time, irrespective of the patients' particular trajectories.
% Markov processes.
In the case of Markov chains, there is an efficient algorithm that computes the state distribution exactly, with running time $O(N^2 T)$.
Given a transition matrix $\bm{\Theta}$ and starting from $\bm{\pi}^\star_0 = \bm{\pi}_0$, we can use the identity $\bm{\pi}^\star_t = \bm{\pi}^{\star \Tr}_{t-1} \bm{\Theta}$ iteratively $T$ times to obtain $\bm{\pi}^\star_T$.
The identity is a consequence of the Markov property.
To the best of our knowledge, there is no similar iterative procedure applicable to our model in the general case.
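One way to see the difficulty is sketched below, under the assumption that the mixture draws a single transition matrix $\bm{\Theta}$ per sequence, and writing $\mathbf{E}$ for the expectation over that draw (a symbol introduced here for illustration).
The predictive distribution averages matrix powers,
\[
  \bm{\pi}^\star_T = \mathbf{E}\bigl[ \bm{\pi}_0^\Tr \bm{\Theta}^T \bigr],
\]
whereas for $T \ge 2$ we have, in general,
\[
  \mathbf{E}\bigl[ \bm{\Theta}^T \bigr] \neq \bigl( \mathbf{E}[\bm{\Theta}] \bigr)^T .
\]
Hence, iterating the one-step recursion with the mean transition matrix does not, in general, recover $\bm{\pi}^\star_T$.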

A special case of practical importance is when the directed graph of admissible transitions has no cycles of length greater than one
(in other words, the graph is acyclic except for self-loops).
In this case, we can derive a simple iterative procedure with running time quadratic in $T$.
\begin{proposition}
\label{thm:dagpredict}
Let $(\bm{A}, \bm{B})$ be any generalized Dirichlet mixture of Markov chains on a graph $\mathcal{G} = ([N], \mathcal{E})$, and let $\bm{\pi}_0$ be an initial state distribution.
If $\mathcal{G}$ has no cycle of length greater than one, then $\bm{\pi}^\star_T$ can be computed exactly in time $O(T^2 N^2)$.
\end{proposition}
We provide an explicit algorithm as well as complete proofs of the results presented in this section in Appendix~\ref{app:predict}.
By way of example, the transition graph underpinning the \textsc{ebmt} dataset, depicted in Figure~\ref{fig:chains}, satisfies the condition of the proposition.

In the general case where arbitrary cycles are allowed, it is still possible to compute $\bm{\pi}^\star_T$ exactly with running time polynomial in $T$, but the algorithm is impractical for all but the simplest cases (see Appendix~\ref{app:predict}).
However, the specific structure of our mixture model suggests an effective sampling-based approach.
The key observation is that, conditioned on a transition matrix $\bm{\Theta}$ sampled from the mixture distribution, we can compute the predictive state probability after $T$ steps exactly by using the efficient recursive algorithm for Markov chains.
Therefore, we propose to estimate the state distribution by averaging samples obtained as follows.
First, sample $\bm{\Theta}$ from the mixture distribution, and then compute the exact state distribution conditional on $\bm{\Theta}$.
We formalize this procedure in Algorithm~\ref{alg:predict}.

\begin{algorithm}[t]
  \caption{Predictive state distribution.}
  \label{alg:predict}
  \begin{algorithmic}[1]
    \Require $\bm{A}, \bm{B}$, horiz. $T$, init. dist. $\bm{\pi}_0$, \# samples $L$
    \For{$\ell = 1, \ldots, L$}
      \State $\bm{\Theta} \gets$ sample from $\prod_i \mathrm{GDir}(\bm{\theta}_i \mid \bm{\alpha}_i, \bm{\beta}_i)$
      \State $\bm{\pi}_{\ell,0} \gets \bm{\pi}_0$ \label{line:start}
      \For{$t = 1, \ldots, T$}
        %\Comment{Update items.}
        \State $\bm{\pi}_{\ell,t} \gets \bm{\pi}_{\ell,t-1}^\Tr \bm{\Theta}$ \label{line:end}
      \EndFor
    \EndFor
    \State \Return $\hat{\bm{\pi}}_T = \frac{1}{L} \sum_\ell \bm{\pi}_{\ell,T}$
  \end{algorithmic}
\end{algorithm}
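
The output of Algorithm~\ref{alg:predict} can also be written in closed form.
As a sketch, writing $\bm{\Theta}_\ell$ for the matrix drawn in iteration $\ell$ (a label that does not appear in the algorithm itself), the returned estimate is
\[
  \hat{\bm{\pi}}_T = \frac{1}{L} \sum_{\ell=1}^{L} \bm{\pi}_0^\Tr \bm{\Theta}_\ell^T .
\]
Each inner loop costs $O(N^2 T)$, as for a single Markov chain, so the total running time is $O(L N^2 T)$ plus the cost of drawing the $L$ samples.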

A natural question is how many samples are necessary to achieve a desired level of accuracy.
We answer this question with a proposition that provides an upper bound on the sample complexity.
\begin{proposition}
\label{thm:predict}
For any $\bm{A}, \bm{B}$, horizon $T$, and initial distribution $\bm{\pi}_0$, let $\hat{\bm{\pi}}_T$ be the output of Algorithm~\ref{alg:predict}.
Then, for any $\varepsilon, \delta > 0$, we have
$\mathbf{P}[\lVert \hat{\bm{\pi}}_T - \bm{\pi}^\star_T \rVert < \varepsilon] > 1 - \delta$,
as long as $L > \frac{11}{\varepsilon^2} \log \frac{N+1}{\delta}$.
\end{proposition}
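As an illustrative calculation, assume that $\log$ denotes the natural logarithm and take the hypothetical values $N = 5$, $\varepsilon = 0.1$ and $\delta = 0.05$, chosen purely for illustration.
The bound of Proposition~\ref{thm:predict} then reads
\[
  L > \frac{11}{0.1^2} \log \frac{5 + 1}{0.05} = 1100 \log 120 \approx 5266.2,
\]
so $L = 5267$ samples suffice in this example.
Note that the bound does not depend on the horizon $T$ and grows only logarithmically in $N$ and in $1 / \delta$.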

A useful viewpoint is to think of our sampling scheme as approximating an infinite mixture model with a finite mixture of $L$ components.
In contrast to \emph{learning} a finite mixture model, where large values of $L$ can lead to overfitting, increasing $L$ in Algorithm~\ref{alg:predict} can only increase the accuracy of the resulting estimate.
In Appendix~\ref{app:predict}, we run experiments comparing Algorithm~\ref{alg:predict} to a naive scheme that directly samples trajectories, instead of sampling $\bm{\Theta}$ and averaging over all possible sequences as we do.
We find that our algorithm is significantly more efficient.

\paragraph{Continuous-Time Model.}
We can adapt the same idea to the continuous-time setting as follows.
The exact state distribution of a CTMC at time $T$ is given by $\bm{\pi}^\star_T = \bm{\pi}_0^\Tr e^{T \bm{\Lambda}}$, where the matrix exponential $e^{\bm{X}}$ can be well-approximated by a simple iterative procedure.
We adapt lines~\ref{line:start}--\ref{line:end} of Algorithm~\ref{alg:predict} accordingly.
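As one possible instantiation, given here only as an assumption for illustration, the product $\bm{\pi}_0^\Tr e^{T \bm{\Lambda}}$ can be approximated without forming the matrix exponential, by truncating its power series after $K$ terms, where $K$ is a hypothetical truncation parameter:
\[
  \bm{v}_0 = \bm{\pi}_0,
  \qquad
  \bm{v}_k = \frac{T}{k} \, \bm{v}_{k-1}^\Tr \bm{\Lambda} \quad (k = 1, \ldots, K),
  \qquad
  \bm{\pi}_0^\Tr e^{T \bm{\Lambda}} \approx \sum_{k=0}^{K} \bm{v}_k .
\]
By induction, $\bm{v}_k = \bm{\pi}_0^\Tr (T \bm{\Lambda})^k / k!$, so the sum is the order-$K$ truncation of the series, and each step is a single vector-matrix product with cost $O(N^2)$.
The truncated series can lose accuracy when $T \lVert \bm{\Lambda} \rVert$ is large; alternatives such as uniformization or scaling and squaring may be preferable in that regime.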
%The full algorithm and a discussion are provided in Appendix~\ref{app:predict}.
