# Distributional Distance Classifiers for Goal-Conditioned Reinforcement Learning


This repository contains the code for the paper *Distributional Distance Classifiers for Goal-Conditioned Reinforcement Learning*.

This codebase was forked from the official [contrastive RL codebase](https://github.com/google-research/google-research/tree/master/contrastive\_rl) and modified to run the Distributional NCE algorithm.


### Installation

1. Create an Anaconda environment: `conda create -n contrastive_rl python=3.9 -y`
2. Activate the environment: `conda activate contrastive_rl`
3. Install the base dependencies of the contrastive RL codebase: `pip install -r requirements.txt --no-deps`
4. Install the additional dependencies required by our modifications: `conda env update --file environment.yml`

### Running the experiments

```bash
conda activate contrastive_rl

# Distributional NCE
python lp_contrastive.py --env_name sawyer_push --num_classifier_bins 21

# Distributional NCE with 1-step consistency
python lp_contrastive.py --env_name sawyer_push --num_classifier_bins 21 --selfsup_flag

# Distributional NCE with Multi-step consistency
python lp_contrastive.py --env_name sawyer_push --num_classifier_bins 21 --selfsup_flag --selfsup_multi_step_flag
```
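To run all three variants on several environments in one go, a small driver script can loop over the commands above. This is a hypothetical helper, not part of the repository: the `ENVS` list, the `DRY_RUN` switch, and running the jobs one after another in a single shell are all assumptions. With the default `DRY_RUN=1` it only prints the commands, so you can inspect them before launching anything.

```bash
#!/usr/bin/env bash
# Hypothetical sweep helper (not part of the repository).
# Assumptions: the ENVS list, the DRY_RUN switch, and sequential execution.
set -eo pipefail

ENVS=(sawyer_push sawyer_bin fetch_push)   # any values accepted by --env_name
BINS=21                                    # passed to --num_classifier_bins
DRY_RUN="${DRY_RUN:-1}"                    # 1 = only print commands, 0 = run them

# Extra flags per variant, mirroring the three commands above.
VARIANTS=(
  ""                                          # Distributional NCE
  "--selfsup_flag"                            # with 1-step consistency
  "--selfsup_flag --selfsup_multi_step_flag"  # with multi-step consistency
)

run_sweep() {
  local env extra
  local -a extra_flags cmd
  for env in "${ENVS[@]}"; do
    for extra in "${VARIANTS[@]}"; do
      # Split the variant string into separate flags (an empty string yields none).
      read -r -a extra_flags <<< "$extra"
      cmd=(python lp_contrastive.py --env_name "$env" --num_classifier_bins "$BINS" "${extra_flags[@]}")
      if [[ "$DRY_RUN" == "1" ]]; then
        echo "${cmd[*]}"   # dry run: print the command only
      else
        "${cmd[@]}"        # launch the training run and wait for it to finish
      fi
    done
  done
}

run_sweep
```

Building each command as a bash array (rather than one string) keeps every flag a separate argument, so the same array can be printed in dry-run mode or executed directly with `DRY_RUN=0`.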

Some important flags:
- `--env_name`: Select the goal-reaching environment (fetch_reach, fetch_push, sawyer_push, sawyer_bin, fetch_reach_image, fetch_push_image, sawyer_push_image).
- `--num_classifier_bins`: Number of classifier bins in the distributional critic (default: 21).
- `--selfsup_flag`: Enable the consistency loss; by default this applies 1-step consistency regularization.
- `--selfsup_multi_step_flag`: Enable the multi-step consistency loss (requires `--selfsup_flag` to be set as well).
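Since `--selfsup_multi_step_flag` only works together with `--selfsup_flag`, a small wrapper can reject invalid combinations before a long run starts. This is a hypothetical sketch, not the repository's own flag handling: `check_flags` and `VALID_ENVS` are illustrative names, and the function only inspects the flags listed above (accepting both `--env_name value` and `--env_name=value`).

```bash
#!/usr/bin/env bash
# Hypothetical pre-flight check (not part of the repository) that enforces
# the constraints from the flag list before calling lp_contrastive.py.

# Environments accepted by --env_name, as listed in this README.
VALID_ENVS="fetch_reach fetch_push sawyer_push sawyer_bin fetch_reach_image fetch_push_image sawyer_push_image"

check_flags() {
  local env="" selfsup=0 multi=0
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --env_name=*) env="${1#*=}" ;;
      --env_name) env="${2:-}"; shift ;;
      --selfsup_flag) selfsup=1 ;;
      --selfsup_multi_step_flag) multi=1 ;;
    esac
    shift
  done
  # --env_name must be one of the listed goal-reaching environments.
  if [[ " $VALID_ENVS " != *" $env "* ]]; then
    echo "unknown --env_name: '$env'" >&2
    return 1
  fi
  # Per the flag list, --selfsup_multi_step_flag must be combined with --selfsup_flag.
  if [[ $multi -eq 1 && $selfsup -eq 0 ]]; then
    echo "--selfsup_multi_step_flag requires --selfsup_flag" >&2
    return 1
  fi
}
```

A possible usage is `check_flags "$@" && python lp_contrastive.py "$@"`, which forwards the arguments to the training script only when the check passes.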


To run the Contrastive NCE and C-Learning baselines, we used the original [contrastive RL codebase](https://github.com/google-research/google-research/tree/master/contrastive\_rl).

> We will add more detailed instructions for running our experiments in the final submission.