
\input{figure/fig_method}
In this section, we describe each component of our method in detail.
The overall framework of our method is illustrated in \Cref{fig:method}.
We first describe the state-conditioned action abstraction, i.e., the set of sub-actions relevant to the transition from the current state, and an auxiliary network that infers this relationship (\Cref{sec:method-csi}). We then describe how the latent dynamics model is trained jointly with the auxiliary network (\Cref{sec:method-training}), as depicted in \Cref{fig:method_training}. Finally, we combine MCTS with the state-conditioned action abstraction (\Cref{sec:method-inference}), as depicted in \Cref{fig:method_mcts}.


\subsection{State-conditioned action abstraction}
\label{sec:method-csi}
As described earlier, our method learns the compositional structure between state and action variables to reduce the search space of MCTS. To this end, we devise a \csi{} that infers which action variables are irrelevant to the transition from the current state. Importantly, it operates in the latent space so that it can handle high-dimensional observations.

First, an encoder $f$ maps the observation (i.e., an image) to a latent state representation $z = f(s)$. The \csi{} $h$ then takes $z$ as input and outputs:
\begin{equation}\label{eq:h}
h(z) = [p_z^{1}, \cdots, p_z^{n}] \in [0, 1]^n,
\end{equation}
where each entry $p_z^{i}$ is the parameter of a Bernoulli distribution. The mask is then sampled from $h(z)$ as:
\begin{equation}
M(z) = [m_z^{1}, \cdots, m_z^{n}]\in \{0, 1\}^n,
\end{equation}
where $m_z^{i} \sim \text{Bernoulli}(p_z^{i})$ for all $i \in [n]$.
Here, the action variable $A^i$ is deemed relevant to the state transition if $m_z^i = 1$, and irrelevant otherwise. Based on this mask, we define the state-conditioned action abstraction as:
\begin{equation}\label{eq:action-abstraction-train}
    \phi_{z}(A) = \{A^i \mid m_z^i = 1\} \subseteq A.
\end{equation}
Note that the abstraction depends on the current state. It represents the following CSI relationship:
\begin{equation}
\label{eq:maskcsi}
S' \Perp \phi^c_z(A)\mid S=s, \phi_z(A), 
\end{equation}
where $\phi^c_z(A) \coloneq A\setminus \phi_z(A)$. For example, with three action variables $A=[A^1, A^2, A^3]$ and the inferred mask $M(z)=[1, 1, 0]$, the inferred CSI is $S' \Perp A^3 \mid S = s, \{A^1, A^2\}$ and the abstract action is $\phi_z(A) = [A^1, A^2]$. We denote by $\phi_z(\gA)$ the abstract action space, e.g., $\phi_z(\gA) = \gA^1 \times \gA^2$, and by $\phi_z(a)$ the value of $\phi_z(A)$, e.g., if $a=[a^1, a^2, a^3]$, then $\phi_z(a) = [a^1, a^2]\in \phi_z(\gA)$.
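To make the mask and the projection concrete, the following minimal Python sketch samples $M(z)$ from Bernoulli parameters and splits a full action into $\phi_z(a)$ and its complement $\phi^c_z(a)$. The helper names (\texttt{sample\_mask}, \texttt{abstract\_action}, \texttt{complement\_action}) and the Bernoulli parameters are hypothetical; only the sampling rule and the example mask $[1, 1, 0]$ come from the text.

```python
import numpy as np


def sample_mask(p, rng):
    """Sample M(z) with m_z^i ~ Bernoulli(p_z^i), where p = h(z) (hypothetical helper)."""
    p = np.asarray(p, dtype=float)
    # u < p with u ~ Unif[0, 1) is a Bernoulli(p) draw for each entry.
    return (rng.random(p.shape) < p).astype(int)


def abstract_action(a, mask):
    """phi_z(a): keep the sub-action values a^i whose mask entry m_z^i is 1."""
    return tuple(a_i for a_i, m_i in zip(a, mask) if m_i == 1)


def complement_action(a, mask):
    """phi_z^c(a): the masked-out sub-action values (m_z^i = 0)."""
    return tuple(a_i for a_i, m_i in zip(a, mask) if m_i == 0)


# The example from the text: three action variables and M(z) = [1, 1, 0].
mask = np.array([1, 1, 0])
a = ("a1", "a2", "a3")
print(abstract_action(a, mask))    # ('a1', 'a2')
print(complement_action(a, mask))  # ('a3',)

# Sampling a mask from hypothetical Bernoulli parameters h(z).
rng = np.random.default_rng(0)
print(sample_mask([0.9, 0.8, 0.05], rng))
```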

A similar auxiliary network has been used to uncover CSI relations when the true variables are fully observable and low-dimensional \citep{hwang2023on}. However, it is unclear how to capture such relations from high-dimensional observations. We next describe how to learn the \csi{} $h$ so that the induced state-conditioned action abstraction $\phi_z(A)$ adheres to \Cref{eq:maskcsi}.


\subsection{Training Latent Dynamics Model}
\label{sec:method-training}
The CSI relationship in \Cref{eq:maskcsi} implies that the abstract action $\phi_z(a)$ is sufficient for predicting the future state:
\begin{equation}
\label{eq:maskcsi_eq}
p(s'\mid s, a) = p(s'\mid s, \phi_{z}(a)).
\end{equation}
Thus, we train the latent dynamics model $g$ to use the abstract action for prediction, i.e., $\hat{z}_{t+1} = g(z_t, \phi_{z_t}(a_t))$. As illustrated in \Cref{fig:method_training}, we jointly train the latent dynamics model and the \csi{} with a $K$-step reconstruction loss regularized for mask sparsity:
\begin{align}\label{eq:loss_recon}
\mathcal{L}_{recon}(s_t) = \frac{1}{K}\sum_{k=1}^K \bigg[ &\|s_{t+k} - \texttt{Dec}(\hat{z}_{t+k})\|_2^2 \notag \\
&+ \lambda \| M(\hat{z}_{t+k-1})\|_1 \bigg], 
\end{align}
where $\hat{z}_{t+k} = g(\hat{z}_{t+k-1}, \phi_{\hat{z}_{t+k-1}}(a_{t+k-1}))$, $\hat{z}_t = z_t = f(s_t)$, \texttt{Dec} is a decoder that reconstructs the observation from a latent state, and $\lambda$ is a sparsity coefficient (a hyperparameter). Intuitively, the regularized reconstruction loss encourages the models to predict future states accurately while using only the necessary action variables, i.e., $\phi_z(a)$. This allows us to learn the compositional relationships between the current state and action variables from high-dimensional observations without access to the true environment model.
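The following Python sketch computes the forward value of \Cref{eq:loss_recon} for a single starting step $t$. The encoder, dynamics model, auxiliary network, and decoder are passed in as hypothetical callables, and we assume that $g$ consumes $\phi_z(a)$ as the full action vector with masked-out entries set to zero; the text does not specify how the variable-size abstraction is fed to $g$. The sampled mask here is not differentiable; in training, the straight-through estimator described next provides the gradients.

```python
import numpy as np


def k_step_recon_loss(obs, actions, f, g, h, dec, lam, rng):
    """Forward value of the regularized K-step reconstruction loss (sketch).

    obs:     [s_t, ..., s_{t+K}]      (K + 1 observations)
    actions: [a_t, ..., a_{t+K-1}]    (K actions, each an array of n sub-actions)
    f, g, h, dec: hypothetical encoder, latent dynamics model, auxiliary network
                  and decoder, given as plain callables.
    """
    K = len(actions)
    z_hat = f(obs[0])                                   # \hat z_t = z_t = f(s_t)
    total = 0.0
    for k in range(1, K + 1):
        p = np.asarray(h(z_hat), dtype=float)           # h(\hat z_{t+k-1})
        mask = (rng.random(p.shape) < p).astype(float)  # M(\hat z_{t+k-1}) ~ Bernoulli(p)
        # Assumption: phi_z(a) is fed to g as the full action with masked entries zeroed.
        z_hat = g(z_hat, np.asarray(actions[k - 1], dtype=float) * mask)
        recon = np.sum((obs[k] - dec(z_hat)) ** 2)      # ||s_{t+k} - Dec(\hat z_{t+k})||_2^2
        total += recon + lam * np.sum(np.abs(mask))     # + lambda * ||M(\hat z_{t+k-1})||_1
    return total / K
```

Keeping every sub-action makes the reconstruction easy but pays $\lambda$ per kept entry, so the optimum trades prediction accuracy against the number of action variables used.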


Since $M(z) = [m_z^1, \cdots, m_z^n]$ is not differentiable with respect to $z$ due to the sampling $m_z^i \sim \text{Bernoulli} (p_z^i)$, we use the Straight-Through Gumbel-Softmax estimator \citep{maddison2016concrete,jang2016categorical}, which relaxes each mask entry as
\begin{equation*}
\tilde{m}_z^i = \sigma\left(\frac{1}{\beta}(\log p_z^i-\log (1-p_z^i)+\log u-\log (1-u))\right),
\end{equation*}
where $\sigma$ is the sigmoid function, $u\sim \text{Unif}(0, 1)$, and $\beta$ is a temperature. In the forward pass, the hard sample $m_z^i$ obtained by rounding $\tilde{m}_z^i$ is used, while gradients are computed through the relaxed sample $\tilde{m}_z^i$. This allows us to update the \csi{} $h$ with the regularized reconstruction loss in \Cref{eq:loss_recon}. Intuitively, $h(z)=[p_z^1, \cdots, p_z^n]$ is trained to assign high probabilities to the sub-actions that are necessary for predicting the future state.
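The forward computation of this estimator can be sketched in numpy as follows. The function names are hypothetical, the sigmoid is evaluated through $\tanh$ to avoid overflow at low temperatures, and the straight-through wiring itself needs an autodiff framework, so it appears only as a comment. Rounding at $1/2$ keeps $m_z^i = 1$ with probability exactly $p_z^i$, independently of $\beta$.

```python
import numpy as np


def relaxed_bernoulli(p, beta, rng, eps=1e-8):
    """Binary Gumbel-Softmax (Concrete) relaxation of Bernoulli(p), entrywise."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    u = np.clip(rng.random(p.shape), eps, 1.0 - eps)    # u ~ Unif(0, 1)
    logits = np.log(p) - np.log1p(-p) + np.log(u) - np.log1p(-u)
    # sigmoid(x / beta) written as 0.5 * (1 + tanh(x / (2 beta))) to avoid overflow.
    return 0.5 * (1.0 + np.tanh(0.5 * logits / beta))


def straight_through_mask(p, beta, rng):
    """Forward values of the straight-through estimator: hard mask and soft relaxation."""
    soft = relaxed_bernoulli(p, beta, rng)
    hard = (soft > 0.5).astype(float)                   # m_z^i used in the forward pass
    # In an autodiff framework one would return soft + stop_gradient(hard - soft),
    # so the forward value equals `hard` while gradients flow through `soft`.
    return hard, soft
```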


\subsection{Complete method: MCTS with State-Conditioned Action Abstraction}
\label{sec:method-inference}

We propose to run MCTS over the abstract action $\phi_z(a)$ at each node $z$ instead of the full action $a$, which reduces the search space exponentially in the number of masked-out sub-actions. The overall procedure is illustrated in \Cref{fig:method_mcts}.
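As a worked count, assume that each action variable $A^i$ takes values in a finite set $\gA^i$, so that $\gA = \gA^1 \times \cdots \times \gA^n$. The number of abstract actions at node $z$ is then
\begin{equation*}
|\phi_z(\gA)| = \prod_{i:\, A^i \in \phi_z(A)} |\gA^i| = \frac{|\gA|}{\prod_{i:\, A^i \in \phi^c_z(A)} |\gA^i|}.
\end{equation*}
For instance, with $n$ binary action variables of which $k$ are masked out, the branching factor at $z$ drops from $2^n$ to $2^{n-k}$.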

\paragrapht{Deterministic abstraction.}
State-conditioned action abstraction (\Cref{eq:action-abstraction-train}) involves sampling from Bernoulli distributions. At inference time, we instead use a deterministic abstraction with a threshold $\tau$, which is a hyperparameter:
\begin{equation}\label{eq:action-abstraction-inference}
    \phi_z(A) = \{A^i \mid p_z^i > \tau\} \subseteq A.
\end{equation}
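A minimal sketch of this thresholding, assuming $h(z)$ is available as an array of probabilities (the helper name and the values are hypothetical); note that the inequality is strict, so $p_z^i = \tau$ is masked out.

```python
import numpy as np


def deterministic_mask(p, tau):
    """Inference-time mask: keep A^i iff p_z^i > tau (strict, as in the text)."""
    return (np.asarray(p, dtype=float) > tau).astype(int)


p = np.array([0.9, 0.5, 0.1])            # hypothetical output of h(z)
print(deterministic_mask(p, tau=0.5))    # [1 0 0]
```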

\paragrapht{Selection.} At each node $z$, an abstract action is selected as:
\begin{align}
\widehat{\phi}_z(a) = \argmax_{\phi_z(a)} \bigg[ & Q(z, \phi_z(a)) \notag \\
& + c \cdot \pi_\theta(z, \phi_z(a)) \frac{\sqrt{\sum_b N(z, b)}}{1+N(z, \phi_z(a))} \bigg],
\end{align}
where the use of the abstract action $\phi_z(a)$ in place of $a$ is the key difference from the vanilla action selection of MuZero in \Cref{eq:puct}. Here, the prior of an abstract action is obtained by marginalizing the policy prior $\pi_\theta(z, \cdot)$ over all actions $b$ that are projected onto the same abstract action $\phi_z(a)$:
\begin{equation}
\pi_\theta(z, {\phi_z{(a)}}) 
= \sum_{\{b\in \mathcal{A}\mid \phi_z(b)=\phi_z(a)\} } \pi_\theta(z, b).
\end{equation}
For example, if $A=[A^1, A^2, A^3]$ with binary action variables, $\phi_z(A)=[A^1, A^2]$, and $\phi_z(a)=(0, 0)$, then we marginalize over the third variable: $\pi_\theta(z, \phi_z(a))=\pi_\theta(z, (0,0,0)) + \pi_\theta(z, (0,0,1))$. We model the policy prior as $\pi_\theta(z, a)$ rather than $\pi_\theta(z, \phi_z(a))$ for simplicity: since the abstraction depends on the current state, the dimension of $\phi_z(a)$ varies across states.
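The marginalization and the selection rule can be sketched in plain Python as follows, reproducing the binary example above. The prior is a hypothetical uniform distribution, the function names are illustrative, and unvisited abstract actions are assumed to have $Q = 0$; practical MuZero-style implementations typically normalize $Q$, which this sketch omits.

```python
import itertools
import math
from collections import defaultdict


def marginal_prior(prior, mask):
    """pi(z, phi_z(a)) = sum of pi(z, b) over all full actions b with phi_z(b) = phi_z(a).

    prior: dict mapping full actions (tuples) to pi_theta(z, b)
    mask:  0/1 sequence over action variables; 1 means kept in phi_z(A)
    """
    out = defaultdict(float)
    for b, p in prior.items():
        key = tuple(b_i for b_i, m in zip(b, mask) if m)   # phi_z(b)
        out[key] += p
    return dict(out)


def select(q, n_visits, abs_prior, c):
    """PUCT-style selection over abstract actions (the simplified score in the text).

    Assumption: Q defaults to 0 and N to 0 for abstract actions not yet visited.
    """
    total = sum(n_visits.get(x, 0) for x in abs_prior)     # sum_b N(z, b)

    def score(x):
        u = c * abs_prior[x] * math.sqrt(total) / (1 + n_visits.get(x, 0))
        return q.get(x, 0.0) + u

    return max(abs_prior, key=score)


# Example from the text: three binary action variables and phi_z(A) = [A^1, A^2].
actions = list(itertools.product([0, 1], repeat=3))
prior = {a: 1 / 8 for a in actions}          # hypothetical uniform pi_theta(z, .)
abs_prior = marginal_prior(prior, mask=(1, 1, 0))
print(abs_prior[(0, 0)])                     # 0.25 = pi(z,(0,0,0)) + pi(z,(0,0,1))
```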

\paragrapht{Expansion and backup.} 
If the selected abstract action $\widehat{\phi}_z(a)$ has no corresponding child node, the latent dynamics model predicts the next latent state $z^{\prime} = g(z, \widehat{\phi}_z(a))$, which is added to the search tree as a child of the current node $z$. The remaining procedures, including backup, are identical to MuZero.

\paragrapht{Final action selection at the root node.} After the simulations, the final action is selected based on the visit distribution $\hat{\pi}(z, \phi_z(a))$, i.e., the normalized visit counts of the (abstract) actions at the root node $z$. We unfold this distribution to the original action space $\gA$ as:
\begin{align}
\hat{\pi}(z, a) 
&=\hat{\pi}(z, \phi_z(a)) \times u(\phi_z^c(a)),
\end{align}
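where $u(\phi_z^c(a))$ denotes the uniform distribution over the values of the masked-out action variables $\phi_z^c(A)$, i.e., $u(\phi_z^c(a)) = 1/|\phi_z^c(\gA)|$. Sampling the masked-out variables uniformly provides diverse state-action samples, which supports robust training of the auxiliary network.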
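A minimal Python sketch of this unfolding, assuming finite discrete action variables with sizes $|\gA^i|$ and root visit counts stored per abstract action; the function name and the visit counts are hypothetical. Each abstract action's probability is spread evenly over all completions of its masked-out variables, and the final action is then sampled from the unfolded distribution.

```python
import itertools
import random


def unfold_visit_distribution(visit_counts, mask, sizes):
    """pi_hat(z, a) = pi_hat(z, phi_z(a)) * u(phi_z^c(a)) over the full action space.

    visit_counts: dict abstract action (tuple of kept values) -> root visit count
    mask:         0/1 sequence over the n action variables
    sizes:        |A^i| for each action variable (finite, discrete)
    """
    total = sum(visit_counts.values())
    n_masked = 1
    for s, m in zip(sizes, mask):
        if not m:
            n_masked *= s                  # |phi_z^c(A)|: number of masked-out completions
    dist = {}
    for a in itertools.product(*(range(s) for s in sizes)):
        key = tuple(a_i for a_i, m in zip(a, mask) if m)   # phi_z(a)
        dist[a] = visit_counts.get(key, 0) / total / n_masked
    return dist


visits = {(0, 0): 6, (0, 1): 2, (1, 0): 2, (1, 1): 0}     # hypothetical root visit counts
dist = unfold_visit_distribution(visits, mask=(1, 1, 0), sizes=(2, 2, 2))
print(dist[(0, 0, 0)], dist[(0, 0, 1)])                   # 0.3 0.3
action = random.choices(list(dist), weights=list(dist.values()))[0]
```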

\paragrapht{Training.}
All components of our method are trained jointly in an end-to-end fashion. The \csi{} is trained only with the reconstruction loss so that it faithfully captures the transition structure in \Cref{eq:maskcsi_eq}. The remaining components are trained with a combination of policy, value, reward, and reconstruction losses, similar to MuZero as described in \Cref{sec:preliminary_mcts}.