Metadata-Version: 2.1
Name: learned-planners
Version: 0.1.0
Summary: Investigate learned planners
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Provides-Extra: torch
Provides-Extra: dev
Provides-Extra: dev-local

# Learned planners

## Installation

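From the repository root, install the package in editable mode with pip (the optional extras `torch`, `dev`, and `dev-local` are also declared and can be requested with pip's extras syntax, e.g. `pip install -e ".[torch]"`):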

```bash
pip install -e .
```

## Model, Probes, and SAEs

The model is available in the `drc_model` directory, the probes in the `probes` directory, and the trained sparse autoencoders (SAEs) in the `sae` directory.
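
Before loading anything, it can help to confirm that these directories are present. The following is a minimal sketch, not part of the package: the helper name `missing_artifact_dirs` is hypothetical, and it assumes the three directories sit directly under the repository root.

```python
from pathlib import Path

# Directory names come from the paragraph above. That they live directly
# under the repository root is an assumption of this sketch.
ARTIFACT_DIRS = ("drc_model", "probes", "sae")


def missing_artifact_dirs(repo_root: Path) -> list[str]:
    """Return the expected artifact directories that are absent under repo_root."""
    return [name for name in ARTIFACT_DIRS if not (repo_root / name).is_dir()]


if __name__ == "__main__":
    # Run from the repository root (assumed layout).
    missing = missing_artifact_dirs(Path.cwd())
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("All artifact directories found.")
```

If the artifacts are stored elsewhere, pass that location as `repo_root` instead of the current working directory.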

## Loading the model

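You can load the trained JAX model and its training state using the following code:

```python
import pathlib

from cleanba import cleanba_impala
from cleanba.environments import BoxobanConfig

# Placeholders, not defined by this README: set them for your local setup.
MODEL_BASE_PATH = pathlib.Path("path/to/model/base")  # directory holding the model files
MODEL_PATH_IN_REPO = "path/in/repo"  # checkpoint path relative to MODEL_BASE_PATH
BOXOBAN_CACHE = pathlib.Path("path/to/boxoban/cache")  # cache location for Boxoban levels

MODEL_PATH = MODEL_BASE_PATH / MODEL_PATH_IN_REPO

# Single synchronous Boxoban environment, used to initialize the policy.
env = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    num_envs=1,
    max_episode_steps=120,
    asynchronous=False,
    tinyworld_obs=True,
).make()
# Load the JAX policy together with its saved training state.
jax_policy, carry_t, jax_args, train_state, _ = cleanba_impala.load_train_state(MODEL_PATH, env)
```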

This repository also provides a PyTorch implementation of the DRC network that is compatible with MambaLens, for interpretability research. To load the JAX checkpoint into the PyTorch model, reuse `MODEL_PATH` and `env` from the previous snippet:

```python
from learned_planner.interp.utils import load_jax_model_to_torch

cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, env)
```

## Reproducing paper results

The behavioral results from the paper can be reproduced using the `plot/behavior_analysis.py` script:

```bash
python plot/behavior_analysis.py
```

The script writes the generated plots to the `plots` directory.
