# Learned planners

## Installation

Install the repository with pip in a Python 3.11 virtual environment:

```bash
python3.11 -m venv .venv
source .venv/bin/activate
pip install -e MambaLens -e farconf -e gym-sokoban -e stable-baselines3 -e 'learned-planners[torch,dev-local]'
```
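As a quick sanity check after installation, the sketch below reports which packages cannot be found in the active environment. The helper `missing_modules` is hypothetical and not part of the repository; the module names `cleanba` and `learned_planner` are taken from the import statements later in this README.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found in the current environment."""
    # find_spec returns None when a top-level module is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumption: these are the top-level packages the editable installs provide.
    missing = missing_modules(["cleanba", "learned_planner"])
    print("missing:", missing if missing else "none")
```

If a name is reported as missing, the virtual environment is probably not activated or the corresponding editable install failed.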

## Video of generalization to a larger level

The following command plays the larger level X-sokoban-31 shown in the paper, using DRC(3,3) with the actions predicted using probes. It should run in under a minute and creates the video file `level_22_031.mp4`.

```bash
python learned-planners/plot/play_bigger_levels.py
```

## Model, Probes, and SAEs

The model is available in the `drc_model` directory. The probes are available in the `probes` directory. The trained SAEs are available in the `sae` directory.
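Before loading anything, it can help to confirm that the three directories are present. The sketch below assumes they sit directly under a single root directory, which is a guess about the layout; `missing_dirs` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def missing_dirs(root, names=("drc_model", "probes", "sae")):
    """Return the expected directory names that do not exist under `root`."""
    root = Path(root)
    return [name for name in names if not (root / name).is_dir()]


if __name__ == "__main__":
    # Assumption: the artifacts were downloaded into the current working directory.
    print("missing directories:", missing_dirs("."))
```

Pass a different `root` if the model, probes, and SAEs were downloaded elsewhere.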

## Loading the model

You can load the JAX model using the following code:

```python
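from cleanba import cleanba_impala
from cleanba.environments import BoxobanConfig

# MODEL_BASE_PATH, MODEL_PATH_IN_REPO, and BOXOBAN_CACHE are not defined in this snippet.
# Set them to the location of the downloaded model and the Boxoban level cache before running.
MODEL_PATH = MODEL_BASE_PATH / MODEL_PATH_IN_REPO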

env = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    num_envs=1,
    max_episode_steps=120,
    asynchronous=False,
    tinyworld_obs=True,
).make()
jax_policy, carry_t, jax_args, train_state, _ = cleanba_impala.load_train_state(MODEL_PATH, env)
```

This repository also provides a PyTorch implementation of the DRC network that is compatible with MambaLens for interpretability research. To load the model into PyTorch, reuse `MODEL_PATH` and `env` from the previous snippet:

```python
from learned_planner.interp.utils import load_jax_model_to_torch

cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, env)
```

## Reproducing paper results

The behavioral results from the paper can be reproduced using the `behavior_analysis.py` script:

```bash
python plot/behavior_analysis.py
```

The script writes the generated plots to the `plots` directory.
