# Categorical DQN (C51)

## Overview

C51 brings a distributional perspective to DQN: instead of learning a single expected return for each action, it learns a categorical distribution over returns on a fixed support of atoms (51 in the original paper, hence the name). Empirically, C51 performs strongly on the Arcade Learning Environment (ALE).
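
To make the idea concrete, here is a minimal NumPy sketch of how per-atom outputs can be turned into per-action distributions and expected values. The function name `q_values_from_logits`, the tensor shapes, and the random logits are assumptions for illustration; the 51 atoms on $[-10, 10]$ mirror the Atari setting of the original paper, not necessarily the defaults of every script here.

```python
import numpy as np


def q_values_from_logits(logits, atoms):
    """Convert logits of shape (batch, n_actions, n_atoms) into PMFs and expected values."""
    # Softmax over the atom axis, shifted by the max for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    pmfs = np.exp(shifted)
    pmfs /= pmfs.sum(axis=-1, keepdims=True)
    # Expected return of each action: sum over atoms of probability * atom value.
    q_values = (pmfs * atoms).sum(axis=-1)
    return pmfs, q_values


# Illustrative support (assumption): 51 evenly spaced atoms on [-10, 10].
atoms = np.linspace(-10.0, 10.0, 51)

# Random logits stand in for a network output: 4 states, 2 actions, 51 atoms.
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 2, 51))

pmfs, q_values = q_values_from_logits(logits, atoms)
greedy_actions = q_values.argmax(axis=1)  # act greedily on the expected return
```

Action selection still uses a scalar per action (the mean of the distribution), so the distributional view changes what is learned rather than how actions are chosen.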


Original paper:

* [A Distributional Perspective on Reinforcement Learning](https://arxiv.org/abs/1707.06887)

## Implemented Variants


| Variants Implemented      | Description |
| ----------- | ----------- |
| :material-github: [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_ataripy) |  For playing Atari games. It uses convolutional layers and common Atari-based pre-processing techniques. |
| :material-github: [`c51.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py), :material-file-document: [docs](/rl-algorithms/c51/#c51py) | For classic control tasks like `CartPole-v1`. |
| :material-github: [`c51_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari_jax.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_atari_jaxpy) |  For playing Atari games. It uses convolutional layers and common Atari-based pre-processing techniques. |
| :material-github: [`c51_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_jax.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_jaxpy) | For classic control tasks like `CartPole-v1`. |

Below are our single-file implementations of C51:


## `c51_atari.py`

The [c51_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py) has the following features:

* For playing Atari games. It uses convolutional layers and common Atari-based pre-processing techniques.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space
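
As a rough illustration of what "common Atari-based pre-processing" usually involves, the sketch below converts a raw `(210, 160, 3)` frame to grayscale, downsamples it, and stacks the most recent frames. The `84 x 84` size, the stack of 4 frames, and all helper names (`to_grayscale`, `resize_nearest`, `FrameStacker`) are assumptions based on the usual DQN convention; the script itself relies on environment wrappers rather than this code, and real pipelines typically use a proper image-resize routine instead of nearest-neighbour sampling.

```python
from collections import deque

import numpy as np


def to_grayscale(frame):
    """Convert an (H, W, 3) uint8 RGB frame to an (H, W) uint8 luminance image."""
    weights = np.array([0.299, 0.587, 0.114])
    return (frame.astype(np.float32) @ weights).astype(np.uint8)


def resize_nearest(gray, size=(84, 84)):
    """Nearest-neighbour downsampling by index selection (a simplification)."""
    rows = np.arange(size[0]) * gray.shape[0] // size[0]
    cols = np.arange(size[1]) * gray.shape[1] // size[1]
    return gray[rows[:, None], cols[None, :]]


class FrameStacker:
    """Keep the k most recent processed frames as one (k, H, W) observation."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        # Fill the whole stack with the first frame of the episode.
        for _ in range(self.k):
            self.frames.append(frame)
        return np.stack(self.frames)

    def push(self, frame):
        # The deque drops the oldest frame automatically.
        self.frames.append(frame)
        return np.stack(self.frames)


raw = np.zeros((210, 160, 3), dtype=np.uint8)  # placeholder raw Atari frame
stacker = FrameStacker(k=4)
obs = stacker.reset(resize_nearest(to_grayscale(raw)))
```

Stacking consecutive frames gives the convolutional network access to motion information that a single frame cannot convey.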

### Usage


=== "poetry"

    ```bash
    poetry install -E atari
    poetry run python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4
    poetry run python cleanrl/c51_atari.py --env-id PongNoFrameskip-v4
    ```

=== "pip"

    ```bash
    pip install -r requirements/requirements-atari.txt
    python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4
    python cleanrl/c51_atari.py --env-id PongNoFrameskip-v4
    ```


### Explanation of the logged metrics

Running `python cleanrl/c51_atari.py` will automatically record various metrics, such as the loss and the estimated Q-values, in TensorBoard. Below is the documentation for these metrics:

* `charts/episodic_return`: episodic return of the game
* `charts/SPS`: number of steps per second
* `losses/loss`: the cross-entropy loss between the predicted return distribution at step $t$ and the projected target distribution computed from step $t+1$
* `losses/q_values`: implemented as `(old_pmfs * q_network.atoms).sum(1)`, i.e. the expected return: each possible return $x$ (`q_network.atoms`) weighted by its predicted probability (`old_pmfs`), averaged over the batch sampled from the replay buffer; useful for gauging whether under- or overestimation happens
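
The projection behind `losses/loss` can be sketched in NumPy as follows. This is an illustrative re-statement of the categorical projection from (Bellemare et al., 2017)[^1], not a copy of the script: the function names, argument layout, and toy batch are assumptions, and `next_pmfs` is assumed to already hold the target network's distribution for the chosen next action.

```python
import numpy as np


def project_distribution(next_pmfs, rewards, dones, atoms, gamma):
    """Project r + gamma * z onto the fixed support `atoms`.

    next_pmfs: (batch, n_atoms); rewards and dones: (batch,).
    """
    v_min, v_max = atoms[0], atoms[-1]
    n_atoms = atoms.shape[0]
    delta_z = atoms[1] - atoms[0]
    # Shift and shrink every atom by the Bellman update, then clip to the support.
    tz = rewards[:, None] + gamma * atoms[None, :] * (1.0 - dones[:, None])
    tz = np.clip(tz, v_min, v_max)
    # Fractional index of each shifted atom on the original grid.
    b = (tz - v_min) / delta_z
    lower = np.clip(np.floor(b), 0, n_atoms - 1).astype(int)
    upper = np.clip(np.ceil(b), 0, n_atoms - 1).astype(int)
    # Split each atom's mass between its two neighbours; when b is an exact
    # integer (lower == upper), all of the mass goes to that single atom.
    mass_lower = (upper + (lower == upper) - b) * next_pmfs
    mass_upper = (b - lower) * next_pmfs
    target = np.zeros_like(next_pmfs)
    rows = np.broadcast_to(np.arange(next_pmfs.shape[0])[:, None], lower.shape)
    np.add.at(target, (rows, lower), mass_lower)  # accumulates repeated indices
    np.add.at(target, (rows, upper), mass_upper)
    return target


def cross_entropy_loss(target_pmfs, old_pmfs):
    """Mean cross-entropy between target and predicted PMFs (clipped for log safety)."""
    clipped = np.clip(old_pmfs, 1e-5, 1 - 1e-5)
    return float((-(target_pmfs * np.log(clipped)).sum(axis=-1)).mean())


# Toy batch (assumed values): one ordinary transition and one terminal transition.
atoms = np.linspace(-10.0, 10.0, 51)
next_pmfs = np.full((2, 51), 1.0 / 51)
rewards = np.array([1.0, 0.0])
dones = np.array([0.0, 1.0])

target_pmfs = project_distribution(next_pmfs, rewards, dones, atoms, gamma=0.99)
loss = cross_entropy_loss(target_pmfs, next_pmfs)
```

For the terminal transition, every atom collapses onto the reward, so the target distribution concentrates its mass around the atom nearest to that reward.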


### Implementation details

[c51_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py) is based on (Bellemare et al., 2017)[^1] but presents a few implementation differences:

1. (Bellemare et al., 2017)[^1] injects stochasticity: "on each frame the environment rejects the agent’s selected action with probability $p = 0.25$". `c51_atari.py` does not do this.
1. `c51_atari.py` uses a self-contained evaluation scheme: it reports the episodic returns obtained throughout training, whereas (Bellemare et al., 2017)[^1] trains with `--end-e=0.01` but reports episodic returns from a separate evaluation process with `--end-e=0.001` (see "5.2. State-of-the-Art Results" on page 7).
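
For readers who want to see what the omitted stochasticity would look like, here is a hedged sketch of an action-rejection wrapper. It assumes that a rejected action means the previously executed action is repeated (the usual "sticky actions" reading); `StickyActionWrapper` and the toy `EchoEnv` are hypothetical names with a simplified `reset`/`step` interface, and neither is part of `c51_atari.py`.

```python
import random


class EchoEnv:
    """Toy stand-in environment (hypothetical): step() returns the executed action."""

    def reset(self):
        return 0

    def step(self, action):
        return action


class StickyActionWrapper:
    """With probability p, ignore the new action and repeat the previous one."""

    def __init__(self, env, p=0.25, seed=None):
        self.env = env
        self.p = p
        self.rng = random.Random(seed)
        self.last_action = 0  # assumed default action before the first step

    def reset(self):
        self.last_action = 0
        return self.env.reset()

    def step(self, action):
        # Reject the agent's choice with probability p.
        if self.rng.random() < self.p:
            action = self.last_action
        self.last_action = action
        return self.env.step(action)


env = StickyActionWrapper(EchoEnv(), p=0.25, seed=0)
env.reset()
requested = list(range(1, 10_001))  # distinct actions make rejections visible
executed = [env.step(a) for a in requested]
rejection_rate = sum(e != a for e, a in zip(executed, requested)) / len(requested)
```

Over many steps, `rejection_rate` should land near the configured probability of 0.25.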


### Experiment results

To run benchmark experiments, see :material-github: [benchmark/c51.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/c51.sh). Specifically, execute the following command:

<script src="https://emgithub.com/embed.js?target=https%3A%2F%2Fgithub.com%2Fvwxyzjn%2Fcleanrl%2Fblob%2Fmaster%2Fbenchmark%2Fc51.sh%23L8-L13&style=github&showBorder=on&showLineNumbers=on&showFileMeta=on&showCopy=on"></script>

Below are the average episodic returns for `c51_atari.py`. 


| Environment | `c51_atari.py` 10M steps | (Bellemare et al., 2017, Figure 14)[^1] 50M steps | (Hessel et al., 2017, Figure 5)[^3] |
| ----------- | ----------- | ----------- | ---- |
| BreakoutNoFrameskip-v4 | 461.86 ± 69.65 | 748 | ~500 at 10M steps, ~600 at 50M steps |
| PongNoFrameskip-v4 | 19.46 ± 0.70 | 20.9 | ~20 at 10M steps, ~20 at 50M steps |
| BeamRiderNoFrameskip-v4 | 9592.90 ± 2270.15 | 14,074 | ~12000 at 10M steps, ~14000 at 50M steps |


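Note that we save computational time by reducing timesteps from 50M to 10M, so the `c51_atari.py` scores should be compared against the 10M-step figures of (Hessel et al., 2017, Figure 5)[^3] rather than the 50M-step results of (Bellemare et al., 2017, Figure 14)[^1]. At 10M steps, `c51_atari.py` is close to that reference on Breakout and Pong and lower on BeamRider.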


Learning curves:

<div class="grid-container">
<img src="../c51/BeamRiderNoFrameskip-v4.png">

<img src="../c51/BreakoutNoFrameskip-v4.png">

<img src="../c51/PongNoFrameskip-v4.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Atari-CleanRL-s-C51--VmlldzoxNzI0NzQ0" style="width:100%; height:500px" title="CleanRL C51 Tracked Experiments"></iframe>



## `c51.py`

The [c51.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py) has the following features:

* Works with the `Box` observation space of low-level features
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`


### Usage

=== "poetry"

    ```bash
    poetry run python cleanrl/c51.py --env-id CartPole-v1
    ```

=== "pip"

    ```bash
    python cleanrl/c51.py --env-id CartPole-v1
    ```


### Explanation of the logged metrics

See [related docs](/rl-algorithms/c51/#explanation-of-the-logged-metrics) for `c51_atari.py`.

### Implementation details

[c51.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py) shares the same implementation details as [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), except that it runs with different hyperparameters and a different neural network architecture. Specifically,

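1. `c51.py` uses a simpler neural network as follows:

    ```python
    # One logit per (action, atom) pair; a softmax over the atom axis
    # then yields each action's probability mass function.
    self.network = nn.Sequential(
        nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, env.single_action_space.n * n_atoms),
    )
    ```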
2. `c51.py` runs with different hyperparameters:

    ```bash
    python c51.py --total-timesteps 500000 \
        --learning-rate 2.5e-4 \
        --buffer-size 10000 \
        --gamma 0.99 \
        --target-network-frequency 500 \
        --max-grad-norm 0.5 \
        --batch-size 128 \
        --start-e 1 \
        --end-e 0.05 \
        --exploration-fraction 0.5 \
        --learning-starts 10000 \
        --train-frequency 10
    ```


### Experiment results

To run benchmark experiments, see :material-github: [benchmark/c51.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/c51.sh). Specifically, execute the following command:

<script src="https://emgithub.com/embed.js?target=https%3A%2F%2Fgithub.com%2Fvwxyzjn%2Fcleanrl%2Fblob%2Fmaster%2Fbenchmark%2Fc51.sh%23L2-L6&style=github&showBorder=on&showLineNumbers=on&showFileMeta=on&showCopy=on"></script>


Below are the average episodic returns for `c51.py`. 


| Environment      | `c51.py`  | 
| ----------- | ----------- | 
| CartPole-v1      | 481.20 ± 20.53      |
| Acrobot-v1  | -87.70 ± 5.52     | 
| MountainCar-v0   | -166.38 ± 27.94        | 


Note that C51 has no official benchmark on classic control environments, so we do not include a comparison. That said, our `c51.py` achieves near-perfect scores in `CartPole-v1` and `Acrobot-v1`; further, it can obtain successful runs in the sparse-reward environment `MountainCar-v0`.


Learning curves:

<div class="grid-container">
<img src="../c51/CartPole-v1.png">

<img src="../c51/Acrobot-v1.png">

<img src="../c51/MountainCar-v0.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Classic-Control-CleanRL-s-C51--VmlldzoxODIwMTE4" style="width:100%; height:500px" title="CleanRL C51 Tracked Experiments"></iframe>


## `c51_atari_jax.py`

The [c51_atari_jax.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari_jax.py) has the following features:

* Uses [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Optax](https://github.com/deepmind/optax) instead of `torch`. [c51_atari_jax.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari_jax.py) is roughly 25% faster than [c51_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py).
* For playing Atari games. It uses convolutional layers and common Atari-based pre-processing techniques.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space

### Usage


=== "poetry"

    ```bash
    poetry install -E "atari jax"
    poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    poetry run python cleanrl/c51_atari_jax.py --env-id BreakoutNoFrameskip-v4
    poetry run python cleanrl/c51_atari_jax.py --env-id PongNoFrameskip-v4
    ```

=== "pip"

    ```bash
    pip install -r requirements/requirements-atari.txt
    pip install -r requirements/requirements-jax.txt
    pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    python cleanrl/c51_atari_jax.py --env-id BreakoutNoFrameskip-v4
    python cleanrl/c51_atari_jax.py --env-id PongNoFrameskip-v4
    ```


### Explanation of the logged metrics
See [related docs](/rl-algorithms/c51/#explanation-of-the-logged-metrics) for `c51_atari.py`.


### Implementation details
See [related docs](/rl-algorithms/c51/#implementation-details) for `c51_atari.py`.


### Experiment results

To run benchmark experiments, see :material-github: [benchmark/c51.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/c51.sh). Specifically, execute the following command:

<script src="https://emgithub.com/embed.js?target=https%3A%2F%2Fgithub.com%2Fvwxyzjn%2Fcleanrl%2Fblob%2Fmaster%2Fbenchmark%2Fc51.sh%23L23-L29&style=github&showBorder=on&showLineNumbers=on&showFileMeta=on&showCopy=on"></script>

Below are the average episodic returns for `c51_atari_jax.py`. 


| Environment             | `c51_atari_jax.py` 10M steps | `c51_atari.py` 10M steps | (Bellemare et al., 2017, Figure 14)[^1] 50M steps | (Hessel et al., 2017, Figure 5)[^3]      |
| ----------------------- | ---------------------------- | ------------------------ | ------------------------------------------------- | ---------------------------------------- |
| BreakoutNoFrameskip-v4  | 448.56 ± 17.02               | 461.86 ± 69.65           | 748                                               | ~500 at 10M steps, ~600 at 50M steps     |
| PongNoFrameskip-v4      | 19.88 ± 0.31                 | 19.46 ± 0.70             | 20.9                                              | ~20 at 10M steps, ~20 at 50M steps       |
| BeamRiderNoFrameskip-v4 | 9504.91 ± 709.69             | 9592.90 ± 2270.15        | 14,074                                            | ~12000 at 10M steps, ~14000 at 50M steps |



Learning curves:

<div class="grid-container">
<img src="../c51/jax/BeamRiderNoFrameskip-v4.png">
<img src="../c51/jax/BeamRiderNoFrameskip-v4-time.png">

<img src="../c51/jax/BreakoutNoFrameskip-v4.png">
<img src="../c51/jax/BreakoutNoFrameskip-v4-time.png">

<img src="../c51/jax/PongNoFrameskip-v4.png">
<img src="../c51/jax/PongNoFrameskip-v4-time.png">
</div>

Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Atari-CleanRL-s-C51-JAX--VmlldzozMjM4MTIy" style="width:100%; height:500px" title="CleanRL C51 Tracked Experiments"></iframe>


## `c51_jax.py`

The [c51_jax.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_jax.py) has the following features:

* Uses [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Optax](https://github.com/deepmind/optax) instead of `torch`. [c51_jax.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_jax.py) is roughly 55% faster than [c51.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py).
* Works with the `Box` observation space of low-level features
* Works with the `Discrete` action space
* Works with envs like `CartPole-v1`


### Usage


=== "poetry"

    ```bash
    poetry install -E jax
    poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    poetry run python cleanrl/c51_jax.py --env-id CartPole-v1
    ```

=== "pip"

    ```bash
    pip install -r requirements/requirements-jax.txt
    pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    python cleanrl/c51_jax.py --env-id CartPole-v1
    ```


### Explanation of the logged metrics

See [related docs](/rl-algorithms/c51/#explanation-of-the-logged-metrics) for `c51_atari.py`.

### Implementation details

See [related docs](/rl-algorithms/c51/#implementation-details_1) for `c51.py`.


### Experiment results

To run benchmark experiments, see :material-github: [benchmark/c51.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/c51.sh). Specifically, execute the following command:

<script src="https://emgithub.com/embed.js?target=https%3A%2F%2Fgithub.com%2Fvwxyzjn%2Fcleanrl%2Fblob%2Fmaster%2Fbenchmark%2Fc51.sh%23L15-L21&style=github&showBorder=on&showLineNumbers=on&showFileMeta=on&showCopy=on"></script>


Below are the average episodic returns for `c51_jax.py`. 


| Environment    | `c51_jax.py`    | `c51.py`        |
| -------------- | --------------- | --------------- |
| CartPole-v1    | 491.07 ± 9.70   | 481.20 ± 20.53  |
| Acrobot-v1     | -86.74 ± 2.19   | -87.70 ± 5.52   |
| MountainCar-v0 | -174.30 ± 36.35 | -166.38 ± 27.94 |


Learning curves:

<div class="grid-container">
<img src="../c51/jax/CartPole-v1.png">
<img src="../c51/jax/CartPole-v1-time.png">

<img src="../c51/jax/Acrobot-v1.png">
<img src="../c51/jax/Acrobot-v1-time.png">

<img src="../c51/jax/MountainCar-v0.png">
<img src="../c51/jax/MountainCar-v0-time.png">
</div>


Tracked experiments and game play videos:

<iframe src="https://wandb.ai/openrlbenchmark/openrlbenchmark/reports/Classic-Control-CleanRL-s-C51-JAX--VmlldzozMjM4MTM4" style="width:100%; height:500px" title="CleanRL C51 Tracked Experiments"></iframe>


[^1]: Bellemare, M.G., Dabney, W., & Munos, R. (2017). A Distributional Perspective on Reinforcement Learning. ICML.
[^2]: \[Proposal\] Formal API handling of truncation vs termination. https://github.com/openai/gym/issues/2510
[^3]: Hessel, M., Modayil, J., Hasselt, H.V., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M.G., & Silver, D. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.