

# Emergence of Spatial Representation in an Actor-Critic Agent with Prewired Hippocampal Sequences

This repository is the official implementation of [Emergence of Spatial Representation in an Actor-Critic Agent with Prewired Hippocampal Sequences](no_url_for_anonymity). 


## Requirements

### Install Python 3.8



### Sample Factory

This folder contains a modified version of Sample Factory.

Install the custom Sample Factory from this folder:

```setup
pip install -e .
```
The Sample Factory authors provide a [pre-built wheel](https://drive.google.com/file/d/1hAKAkl85HE8JsHXfXbdkF0CrLdiGyuoL/view) for the
DeepMind Lab Python package.
To install it:
```
pip install deepmind_lab-1.0-py3-none-any.whl
```

Then copy the contents of `to_deepmindlab/` into the `deepmind_lab` folder of your environment, e.g. `miniforge3/envs/<env_name>/lib/python3.8/site-packages/deepmind_lab/`.
- This step adds the custom maps (`openfield_*`) and the custom model for the transparent goal object.
```
cd to_deepmindlab/assets
zip -ur <home>/miniforge3/envs/<env_name>/lib/python3.8/site-packages/deepmind_lab/baselab/assets.pk3 models/
cd ..
cp -r game_scripts/ <home>/miniforge3/envs/<env_name>/lib/python3.8/site-packages/deepmind_lab/baselab/
```
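
If you are unsure where `deepmind_lab` is installed in your environment, a small Python helper can print the target paths for the copy step above. This is only a sketch: `package_dir` is a hypothetical helper that is not part of this repository, and the `baselab/` layout is taken from the commands above.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def package_dir(name: str) -> Optional[Path]:
    """Return the directory of an installed top-level package, or None if it is missing."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        return None
    # For a regular package, spec.origin points at its __init__.py file.
    return Path(spec.origin).parent


if __name__ == "__main__":
    lab_dir = package_dir("deepmind_lab")
    if lab_dir is None:
        print("deepmind_lab is not installed in this environment")
    else:
        # Paths assumed from the zip/cp commands in this README.
        print("assets archive:", lab_dir / "baselab" / "assets.pk3")
        print("game scripts:  ", lab_dir / "baselab" / "game_scripts")
```

Run it with the same Python interpreter (conda environment) in which you installed the wheel.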


The pre-built wheel may not work in every environment, because the DeepMind Lab Python package depends on the specific setup of your machine. If it does not work, build the package yourself by following the instructions [below](#deepmind-lab).




### Deepmind Lab

Clone DeepMind Lab from the [official repo](https://github.com/google-deepmind/lab).

Copy the contents of `./to_deepmindlab` into the DeepMind Lab folder `./lab`.
- It contains the custom .lua map files and the custom transparent model for the goal object.


Build and install the pip wheels following the [official build doc](https://github.com/google-deepmind/lab/blob/master/python/pip_package/README.md).

**Note for bazel build:** The build rules are using a few compiler settings that are specific to GCC. If some flags are not recognized by your compiler (typically those would be specific warning suppressions), you may have to edit those flags. The warnings should be noisy but harmless.




### Install RL bindings for deepmind lab:
```
pip install dm_env
```


## Model architecture

The actor-critic architecture follows the original Sample Factory implementation for DMLab-30. The custom modules are defined here:

- Dentate gyrus projection to CA3: [code](sf_examples/dmlab/hippo2025_model.py#L142)
- CA3 sequences: [code](sf_examples/dmlab/hippo2025_model.py#L1021)
- Miscellaneous modules, including the bypass, the depth sensor, and the one-hot encoding of the reward location: [code](sf_examples/dmlab/hippo2025_model.py#L209)


## Pre-trained Models

The models (including milestones), full config files, and evaluation results used for the analyses in the paper can be downloaded from [this link](https://doi.org/10.5281/zenodo.15496416).
- [main exp](train_dir/Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2/)
- [new reward location](train_dir/ResumeLoc2_Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2/)
- [LSTM with same DG input](train_dir/Epi_Ins9_Depth_Fix3_Tr_AR_LSTM_DGBN243_Hippo16_skip8_3pbt_64_rollout_Gamma99_epoch2/)
- [LSTM with dense input](train_dir/Epi_Ins9_Depth_Fix3_Tr_AR_LSTM_DG_linear_Hippo16_skip8_3pbt_64_rollout_Gamma99_epoch2/)

Put the downloaded files in `./train_dir/` to use the analyses notebook.
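
To confirm that the downloaded experiments ended up in the right place, a quick check like the following can help. It is a minimal sketch: `missing_experiments` is a hypothetical helper, and the folder names are copied from the list above.

```python
from pathlib import Path
from typing import List, Sequence

# Experiment folder names as listed in this README.
EXPECTED_EXPERIMENTS = [
    "Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2",
    "ResumeLoc2_Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2",
    "Epi_Ins9_Depth_Fix3_Tr_AR_LSTM_DGBN243_Hippo16_skip8_3pbt_64_rollout_Gamma99_epoch2",
    "Epi_Ins9_Depth_Fix3_Tr_AR_LSTM_DG_linear_Hippo16_skip8_3pbt_64_rollout_Gamma99_epoch2",
]


def missing_experiments(train_dir: str,
                        expected: Sequence[str] = EXPECTED_EXPERIMENTS) -> List[str]:
    """Return the expected experiment folders that are not present under train_dir."""
    root = Path(train_dir)
    return [name for name in expected if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_experiments("./train_dir")
    if missing:
        print("Missing experiment folders:")
        for name in missing:
            print("  ", name)
    else:
        print("All expected experiment folders are present.")
```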

## Training

### main exp

To train the models in the paper, run this command:
<details>
<summary>run command line </summary>

```train
python -m sf_examples.dmlab.train_hippo2025 --env=openfield_map2_fixed_loc3 --experiment=Replicating_main --async_rl=True --train_for_env_steps=1000000000 --gamma=0.99 --use_rnn=True --num_workers=12 --num_envs_per_worker=2 --num_epochs=2 --rollout=64 --recurrence=64 --batch_size=1536 --benchmark=False --max_grad_norm=0.0 --dmlab_renderer=hardware --decorrelate_experience_max_seconds=120 --encoder_conv_architecture=pretrained_resnet --encoder_conv_mlp_layers=256 --nonlinearity=relu --dmlab_extended_action_set=False --dmlab_one_task_per_worker=True --set_workers_cpu_affinity=True  --dmlab_use_level_cache=True  --num_policies=3 --pbt_replace_reward_gap=0.05 --pbt_replace_reward_gap_absolute=0.2 --pbt_period_env_steps=1000000 --pbt_start_mutation=5000000 --with_pbt=True --dmlab_one_task_per_worker=True --max_policy_lag=35 --use_record_episode_statistics=True --keep_checkpoints=10 --save_every_sec=120 --save_milestones_sec=4000 --decoder_mlp_layers 128 128 --Hippo_L=64 --env_frameskip=8 --with_pbt=True --dmlab_reduced_action_set=True --core_name=BypassSS --rnn_size=1149 --rnn_type=gru --DG_name=batchnorm_relu --learning_rate=0.0001 --fix_encoder_when_load=True --encoder_load_path=./models/best_000025288_203030528_reward_94.185.pth  --pbt_mix_policies_in_one_env=False --wandb_project=hippo2025 --worker_num_splits=2 --pbt_target_objective=lenweighted_score --with_number_instruction=True --save_best_metric=avg_z_00_openfield_map2_fixed_loc3_lenweighted_score --device=gpu --Hippo_n_feature=16 --number_instruction_coef=9 --DG_BN_intercept=2.43 --depth_sensor=True
```

</details>

Optionally, enable wandb logging by appending these flags to the command:

```
--with_wandb=True --wandb_user=<your_id>
```


### LSTM exp

<details>

```
python -m sf_examples.dmlab.train_hippo2025 \
--env=openfield_map2_fixed_loc3 --experiment=Epi_Ins9_Depth_Fix3_Tr_AR_LSTM_DGBN243_Hippo16_skip8_3pbt_64_rollout_Gamma99_epoch2 --async_rl=True --train_for_env_steps=1000000000 --gamma=0.99 --use_rnn=True --num_workers=12 --num_envs_per_worker=2 --num_epochs=2 --rollout=64 --recurrence=64 --batch_size=1536 --benchmark=False --max_grad_norm=0.0 --dmlab_renderer=hardware --decorrelate_experience_max_seconds=120 --encoder_conv_architecture=pretrained_resnet --encoder_conv_mlp_layers=256 --nonlinearity=relu --dmlab_extended_action_set=False --dmlab_one_task_per_worker=True --set_workers_cpu_affinity=True  --dmlab_use_level_cache=True  --num_policies=3 --pbt_replace_reward_gap=0.05 --pbt_replace_reward_gap_absolute=0.2 --pbt_period_env_steps=1000000 --pbt_start_mutation=5000000 --with_pbt=True --dmlab_one_task_per_worker=True --max_policy_lag=35 --use_record_episode_statistics=True --keep_checkpoints=10 --save_every_sec=120 --save_milestones_sec=4000 --decoder_mlp_layers 128 128 --env_frameskip=8 --with_pbt=True --dmlab_reduced_action_set=True --core_name=BypassLSTM --rnn_size=347 --rnn_type=gru --DG_name=batchnorm_relu --learning_rate=0.0001 --fix_encoder_when_load=True --encoder_load_path=./models/best_000025288_203030528_reward_94.185.pth  --pbt_mix_policies_in_one_env=False --wandb_project=SF_dmlab --worker_num_splits=2 --pbt_target_objective=lenweighted_score --with_number_instruction=True --save_best_metric=lenweighted_score --device=gpu --Hippo_n_feature=16 --number_instruction_coef=9 --DG_BN_intercept=2.43 --depth_sensor=True --seed=42
```

</details>

### LSTM dense exp


<details>

```
python -m sf_examples.dmlab.train_hippo2025 \
--env=openfield_map2_fixed_loc3 --experiment=LSTM_dense_in --async_rl=True --train_for_env_steps=1000000000 --gamma=0.99 --use_rnn=True --num_workers=12 --num_envs_per_worker=2 --num_epochs=2 --rollout=64 --recurrence=64 --batch_size=1536 --benchmark=False --max_grad_norm=0.0 --dmlab_renderer=hardware --decorrelate_experience_max_seconds=120 --encoder_conv_architecture=pretrained_resnet --encoder_conv_mlp_layers=256 --nonlinearity=relu --dmlab_extended_action_set=False --dmlab_one_task_per_worker=True --set_workers_cpu_affinity=True  --dmlab_use_level_cache=True  --num_policies=3 --pbt_replace_reward_gap=0.05 --pbt_replace_reward_gap_absolute=0.2 --pbt_period_env_steps=1000000 --pbt_start_mutation=5000000 --with_pbt=True --dmlab_one_task_per_worker=True --max_policy_lag=35 --use_record_episode_statistics=True --keep_checkpoints=10 --save_every_sec=120 --save_milestones_sec=4000 --decoder_mlp_layers 128 128 --env_frameskip=8 --with_pbt=True --dmlab_reduced_action_set=True --core_name=BypassLSTM --rnn_size=347 --rnn_type=gru --DG_name=linear_relu --learning_rate=0.0001 --fix_encoder_when_load=True --encoder_load_path=./models/best_000025288_203030528_reward_94.185.pth  --pbt_mix_policies_in_one_env=False --wandb_project=SF_dmlab --worker_num_splits=2 --pbt_target_objective=lenweighted_score --with_number_instruction=True --save_best_metric=lenweighted_score --device=gpu --Hippo_n_feature=16 --number_instruction_coef=9 --DG_BN_intercept=2.43 --depth_sensor=True --seed=42
```

</details>
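
The main, LSTM, and LSTM dense commands share most of their flags and differ in only a few (for example `--core_name`, `--rnn_size`, and `--DG_name`). Because the commands are long, a small script can make the differences visible. The following is a sketch only: `parse_flags` and `diff_flags` are hypothetical helpers, not part of this repository, and the two example commands are shortened versions of the ones above.

```python
import shlex
from typing import Dict, Optional, Tuple


def parse_flags(command: str) -> Dict[str, str]:
    """Parse '--key=value' and '--key v1 v2' flags; later duplicates override earlier ones."""
    flags: Dict[str, list] = {}
    key = None
    # Join shell line continuations before tokenizing.
    for token in shlex.split(command.replace("\\\n", " ")):
        if token.startswith("--"):
            name, sep, value = token[2:].partition("=")
            flags[name] = [value] if sep else []
            key = name
        elif key is not None:
            flags[key].append(token)  # space-separated value, e.g. "128 128"
    return {name: " ".join(values) for name, values in flags.items()}


def diff_flags(cmd_a: str, cmd_b: str) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Return flags whose values differ; None marks a flag absent from one command."""
    a, b = parse_flags(cmd_a), parse_flags(cmd_b)
    return {k: (a.get(k), b.get(k)) for k in sorted(set(a) | set(b)) if a.get(k) != b.get(k)}


# Shortened example commands; paste the full commands to compare real runs.
MAIN_CMD = ("python -m sf_examples.dmlab.train_hippo2025 --core_name=BypassSS "
            "--rnn_size=1149 --Hippo_L=64 --decoder_mlp_layers 128 128")
LSTM_CMD = ("python -m sf_examples.dmlab.train_hippo2025 --core_name=BypassLSTM "
            "--rnn_size=347 --decoder_mlp_layers 128 128 --seed=42")

if __name__ == "__main__":
    for flag, (main_value, lstm_value) in diff_flags(MAIN_CMD, LSTM_CMD).items():
        print(f"--{flag}: main={main_value} lstm={lstm_value}")
```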

### changing reward location

Duplicate the main experiment folder and rename the copy to the experiment name used below, e.g. `ResumeLoc2_Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2`. Also change the experiment names in the `config.json` file inside that folder.
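
The duplicate-and-rename step can also be scripted. The sketch below is an assumption-laden illustration, not the authors' tooling: `clone_experiment` and `replace_in_strings` are hypothetical helpers, and because the exact keys in `config.json` are not documented here, the script replaces the old experiment name in every string value it finds.

```python
import json
import shutil
from pathlib import Path
from typing import Any, List


def replace_in_strings(value: Any, old: str, new: str) -> Any:
    """Recursively replace `old` with `new` in every string of a JSON-like value."""
    if isinstance(value, str):
        return value.replace(old, new)
    if isinstance(value, list):
        return [replace_in_strings(v, old, new) for v in value]
    if isinstance(value, dict):
        return {k: replace_in_strings(v, old, new) for k, v in value.items()}
    return value


def clone_experiment(train_dir: str, old_name: str, new_name: str) -> List[Path]:
    """Copy train_dir/old_name to train_dir/new_name and rename the experiment in config.json."""
    src = Path(train_dir) / old_name
    dst = Path(train_dir) / new_name
    shutil.copytree(src, dst)  # raises FileExistsError if dst already exists
    updated = []
    for cfg_path in dst.rglob("config.json"):
        cfg = json.loads(cfg_path.read_text())
        cfg = replace_in_strings(cfg, old_name, new_name)
        cfg_path.write_text(json.dumps(cfg, indent=4))
        updated.append(cfg_path)
    return updated


if __name__ == "__main__":
    main_exp = "Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2"
    if (Path("./train_dir") / main_exp).is_dir():
        for path in clone_experiment("./train_dir", main_exp, "ResumeLoc2_" + main_exp):
            print("updated", path)
    else:
        print("main experiment folder not found in ./train_dir")
```

Check the rewritten `config.json` by hand before resuming training, since other fields may also need to change for the new reward location.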


<details>

```
python -m sf_examples.dmlab.train_hippo2025 \
--env=openfield_map2_fixed_loc2 --experiment=ResumeLoc2_Epi_Ins9_Depth_Fix3_Tr_AR_SS_DGBN243_Hippo16_L64_skip8_3pbt_64_rollout_Gamma99_epoch2 --async_rl=True --train_for_env_steps=1000000000 --gamma=0.99 --use_rnn=True --num_workers=12 --num_envs_per_worker=2 --num_epochs=2 --rollout=64 --recurrence=64 --batch_size=1536 --benchmark=False --max_grad_norm=0.0 --dmlab_renderer=hardware --decorrelate_experience_max_seconds=120 --encoder_conv_architecture=pretrained_resnet --encoder_conv_mlp_layers=256 --nonlinearity=relu --dmlab_extended_action_set=False --dmlab_one_task_per_worker=True --set_workers_cpu_affinity=True  --dmlab_use_level_cache=True  --num_policies=3 --pbt_replace_reward_gap=0.05 --pbt_replace_reward_gap_absolute=0.2 --pbt_period_env_steps=1000000 --pbt_start_mutation=5000000 --with_pbt=True --dmlab_one_task_per_worker=True --max_policy_lag=35 --use_record_episode_statistics=True --keep_checkpoints=10 --save_every_sec=120 --save_milestones_sec=4000 --decoder_mlp_layers 128 128 --Hippo_L=64 --env_frameskip=8 --with_pbt=True --dmlab_reduced_action_set=True --core_name=BypassSS --rnn_size=1149 --rnn_type=gru --DG_name=batchnorm_relu --learning_rate=0.0001 --fix_encoder_when_load=True --encoder_load_path=./models/best_000025288_203030528_reward_94.185.pth  --pbt_mix_policies_in_one_env=False --wandb_project=SF_dmlab --worker_num_splits=2 --pbt_target_objective=lenweighted_score --with_number_instruction=True --save_best_metric=avg_z_00_openfield_map2_fixed_loc2_lenweighted_score --device=gpu --Hippo_n_feature=16 --number_instruction_coef=9 --DG_BN_intercept=2.43 --depth_sensor=True
```

</details>

### training in batch

The command below runs experiments in batch using Slurm on a computing cluster.

The parameters are the ones used in Fig. 1A.

You need to modify [training_template.sh](training_template.sh) to match your paths and environment settings.

```
python -m sample_factory.launcher.run --backend=slurm --slurm_workdir=./slurm_grid --slurm_gpus_per_job=0 --slurm_cpus_per_gpu=48 --slurm_sbatch_template=./training_template.sh --pause_between=1 --slurm_print_only=False --run=sf_examples.dmlab.experiments.hippo2025_batch_run_control --slurm_partition=genoa --slurm_timeout=30:05:00
```


## Evaluation





### run evaluation scripts

To evaluate the checkpoints throughout training, run:

```eval
python enjoy_multi_thread.py
```
Before running, set the experiment name in `enjoy_multi_thread.py` to the experiment you would like to evaluate.

Saving the evaluation data requires additional packages, e.g. pandas and h5py.
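
A quick way to see which of these packages are still missing before starting a long evaluation run is sketched below. `missing_packages` is a hypothetical helper, and the package list only covers the two packages named above; extend it as needed.

```python
import importlib.util
from typing import Iterable, List


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level package names that cannot be imported in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(["pandas", "h5py"])
    if missing:
        print("Install before evaluating: pip install " + " ".join(missing))
    else:
        print("pandas and h5py are available.")
```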



## Analyses

The analyses are in the [analyses notebook](sf_examples/analyses.ipynb).



