# Interpreting Learned Search

## Installation

The repository can be installed with pip in a Python 3.11 virtual environment:

```bash
python3.11 -m venv .venv
source .venv/bin/activate
pip install -e MambaLens -e farconf -e trained-learned-planners -e gym-sokoban -e stable-baselines3 -e 'learned-planners[torch,dev-local]'
```

## Model and dataset

The DRC(3, 3) model analyzed in the paper is available in the `drc_model` directory.

You can load the model using the following code:

```python
from pathlib import Path

from cleanba import cleanba_impala
from cleanba.environments import BoxobanConfig

# Placeholders: point these at your local copies.
MODEL_PATH = Path("drc_model")  # DRC(3, 3) checkpoint; adjust if it is nested inside drc_model
BOXOBAN_CACHE = Path("/path/to/boxoban-cache")  # directory containing the Boxoban levels

env = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    num_envs=1,
    max_episode_steps=120,
    asynchronous=False,
    tinyworld_obs=True,
).make()
jax_policy, carry_t, jax_args, train_state, _ = cleanba_impala.load_train_state(MODEL_PATH, env)
```

This repository provides a PyTorch implementation of the DRC network that is compatible with MambaLens for interpretability research. You can load the PyTorch model with the following code:

```python
from learned_planner.interp.utils import load_jax_model_to_torch

# MODEL_PATH and env are defined in the previous snippet.
cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, env)
```

## Reproducing paper results

### Channel groups

We interpret and group all channels of DRC(3, 3) into five categories by function: `box`, `agent`, `Misc plan`, `entity`, and `no-label`. These groups are defined in `learned_planners/interp/channel_group.py`, which records each channel's description, whether it stores long- or short-term information, and the sign of its activation. Each channel in the model is spatially offset; the offset values are provided in `learned_planners/interp/offset_fns.py`.
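The exact offset convention is defined in `offset_fns.py`. As a rough illustration only, the sketch below assumes an offset is an integer `(dy, dx)` shift between a channel's activation grid and the board, and re-aligns a 2D activation map by shifting it and zero-filling the vacated cells. The helper `shift_activation` and its sign convention are hypothetical, not the repository's API.

```python
import numpy as np


def shift_activation(act: np.ndarray, dy: int, dx: int) -> np.ndarray:
    """Shift a 2D activation map by (dy, dx), zero-filling vacated cells.

    Hypothetical helper: positive dy moves content down, positive dx moves it right.
    Unlike np.roll, values shifted past the border are dropped instead of wrapping.
    """
    h, w = act.shape
    out = np.zeros_like(act)
    if abs(dy) >= h or abs(dx) >= w:
        return out  # everything is shifted off the grid
    # Matching source/destination slices along each axis.
    src_y = slice(max(0, -dy), min(h, h - dy))
    dst_y = slice(max(0, dy), min(h, h + dy))
    src_x = slice(max(0, -dx), min(w, w - dx))
    dst_x = slice(max(0, dx), min(w, w + dx))
    out[dst_y, dst_x] = act[src_y, src_x]
    return out


board = np.arange(9).reshape(3, 3)
print(shift_activation(board, 1, 0))  # first row becomes zeros, rows 0-1 move down
```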

### Causal intervention on channel groups

We verify our interpretation of the channel groups with causal interventions, which are performed by the script `ci_score_direction_channel.py`. Its `--channel_type` argument can be `box`, `agent`, `box_agent`, `nfa`, or `mpa`; see the script's arguments for further options.

```bash
python learned-planners/plot/interp/channels/ci_score_direction_channel.py --channel_type box_agent
```
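As a generic illustration of what a causal intervention on a channel group means (not the script's actual procedure), the sketch below overwrites the activations of a chosen set of channels at one grid cell and measures how much a downstream readout changes. The toy linear readout `readout_w`, the helpers `intervene` and `intervention_effect`, and the intervention value are all hypothetical stand-ins for the DRC network and its hooks.

```python
import numpy as np


def intervene(acts: np.ndarray, channels: list[int], y: int, x: int, value: float) -> np.ndarray:
    """Return a copy of `acts` (C, H, W) with `channels` at cell (y, x) set to `value`."""
    patched = acts.copy()
    patched[channels, y, x] = value
    return patched


def intervention_effect(acts, readout_w, channels, y, x, value) -> float:
    """Change in a toy linear readout (sum of readout_w * acts) caused by the intervention."""
    before = float(np.sum(readout_w * acts))
    after = float(np.sum(readout_w * intervene(acts, channels, y, x, value)))
    return after - before


rng = np.random.default_rng(0)
acts = rng.normal(size=(4, 5, 5))  # toy activations: 4 channels on a 5x5 grid
readout_w = np.zeros_like(acts)
readout_w[1] = 1.0  # toy readout that only depends on channel 1
print(intervention_effect(acts, readout_w, channels=[1], y=2, x=3, value=10.0))
```

Intervening on a channel the readout ignores leaves it unchanged, which is the kind of contrast a causal test relies on.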

### Short- and long-term channel AUC

We verify that the short- and long-term channels store future actions at different horizons by computing their AUC for predicting the action taken `N` steps in the future. The script `future_accuracy_channels.py` performs this analysis.

```bash
python learned-planners/plot/interp/channels/future_accuracy_channels.py
```
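A minimal sketch of the metric, assuming it is computed per channel from an activation summary at step `t` and a binary label for whether the agent takes a given action at step `t + N`. The AUC here uses the Mann-Whitney formulation (the probability that a random positive scores higher than a random negative, counting ties as half). The functions `auc` and `future_action_auc` and their argument layout are hypothetical, not the script's API.

```python
import numpy as np


def auc(scores: np.ndarray, labels: np.ndarray) -> float:
    """ROC AUC via pairwise comparisons (Mann-Whitney U), ties counted as 0.5."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one positive and one negative example")
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def future_action_auc(channel_act: np.ndarray, actions: np.ndarray, action: int, horizon: int) -> float:
    """AUC of a channel's activation at step t for predicting `action` at step t + horizon.

    channel_act: (T,) activation summary per step (e.g. at the agent's cell).
    actions: (T,) integer action taken at each step.
    """
    scores = channel_act[: len(channel_act) - horizon]
    labels = (actions[horizon:] == action).astype(int)
    return auc(scores, labels)


actions = np.array([0, 1, 0, 1, 0, 1])
channel_act = np.array([1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
print(future_action_auc(channel_act, actions, action=1, horizon=1))
```

A channel that encodes the action one step ahead scores well at `horizon=1` but not at `horizon=0`, which is the pattern that separates short- from long-term channels.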

### Kernel visualization

Kernels are visualized with the `kernel_visualization.py` script. It generates plots of the linear and turn plan-extension kernels, as well as the winner-takes-all kernel figure, and saves them to the `new_plots` directory.

```bash
python learned-planners/learned_planners/notebooks/kernel_visualization.py
```
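To inspect a single kernel by hand, a sketch assuming PyTorch's `Conv2d` weight layout `(out_channels, in_channels, kH, kW)`: slice the kernel connecting one input channel to one output channel and print it as a signed grid. The helpers `get_kernel` and `format_kernel` and the toy weights are placeholders; in practice the weights would come from the loaded PyTorch model, whose module names are not assumed here.

```python
import numpy as np


def get_kernel(weight: np.ndarray, out_ch: int, in_ch: int) -> np.ndarray:
    """Return the (kH, kW) kernel from input channel `in_ch` to output channel `out_ch`."""
    return weight[out_ch, in_ch]


def format_kernel(kernel: np.ndarray) -> str:
    """Render a kernel as rows of signed, fixed-width numbers."""
    return "\n".join(" ".join(f"{v:+.2f}" for v in row) for row in kernel)


weight = np.zeros((8, 4, 3, 3))  # toy weights: 8 output channels, 4 inputs, 3x3 kernels
weight[5, 2, 1, 2] = 0.75  # one strong weight to the right of the centre
print(format_kernel(get_kernel(weight, out_ch=5, in_ch=2)))
```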

### Turn stabilization ablation

The headline figure with the two paths and the figure showing the ablation of turn kernels are generated by the `turn_stabiliization_ablation.py` script.

```bash
python learned-planners/learned_planners/notebooks/turn_stabiliization_ablation.py
```
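A minimal sketch of what ablating kernels can look like, assuming a PyTorch-style convolution weight of shape `(out_channels, in_channels, kH, kW)`: copy the weights, zero the kernels for a chosen set of `(out, in)` channel pairs, and rerun the model with the ablated weights. Which pairs count as turn kernels is determined by the script; the helper `ablate_kernels` and the indices below are hypothetical.

```python
import numpy as np


def ablate_kernels(weight: np.ndarray, pairs: list[tuple[int, int]]) -> np.ndarray:
    """Return a copy of an (out, in, kH, kW) weight with the listed (out, in) kernels zeroed."""
    ablated = weight.copy()
    for out_ch, in_ch in pairs:
        ablated[out_ch, in_ch] = 0.0
    return ablated


rng = np.random.default_rng(0)
weight = rng.normal(size=(8, 4, 3, 3))
ablated = ablate_kernels(weight, pairs=[(5, 2), (6, 3)])  # hypothetical "turn" kernel pairs
print(np.abs(ablated[5, 2]).sum(), np.allclose(ablated[0], weight[0]))
```

Working on a copy keeps the original weights available, so ablated and unablated runs can be compared directly.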

### Backtracking mechanism

The backtracking mechanism is visualized using the `backtracking_mechanism.py` script.
The quantitative results are generated using the `backtrack_quant.py` script.

```bash
python learned-planners/learned_planners/notebooks/backtracking_mechanism.py
python learned-planners/learned_planners/notebooks/backtrack_quant.py
```

### Plan stopping

The plan stopping signals are visualized using the `plan_stopping.py` script.

```bash
python learned-planners/learned_planners/notebooks/plan_stopping.py
```