# Reinforcement Learning Training Dynamics

> (Not the _dynamics of the environment_, but the _dynamics of training and learning_, i.e. _how does the network train?_ or _how do the weights change with respect to training iterations?_)

## Essential requirements:

```text
ipykernel
matplotlib
jax==0.4.16
jaxlib==0.4.16
dm-haiku==0.0.10
gymnasium==0.29.0
neural_tangents==0.6.4
```

## Replicating results:

Make sure the folders `data` and `plots` exist before running any of the scripts.
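
The folders can also be created from Python. This is a minimal sketch that assumes the scripts are run from the repository root and resolve `data` and `plots` relative to the current working directory:

```python
import os

# Output folders expected by the scripts (names taken from the instruction above).
for folder in ("data", "plots"):
    os.makedirs(folder, exist_ok=True)  # does nothing if the folder already exists
```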

### Toy environment:

Training an ensemble of Q-networks:
```bash
python toy-nn.py --gamma 1.0 --lr 0.01 --num_layers 2 --layer_width 1024 --ensemble_size 50 --times 0 16 256 2048 65536
```
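
The `--times` values appear to be the training iterations at which predictions are recorded (this is our assumption, not documented behaviour). As a toy illustration of that idea only, and not the repository's training code, the sketch below runs full-batch gradient descent on a one-parameter least-squares problem from several random initializations and stores the weight at each requested iteration. The helper `train_with_snapshots` and the data are hypothetical.

```python
import random

def train_with_snapshots(w0, xs, ys, lr, times):
    """Fit y ~ w * x by full-batch gradient descent on the mean squared error.

    Returns {t: w_t}, where w_t is the weight after t updates, for each t in `times`.
    """
    wanted = set(times)
    snapshots = {}
    w = w0
    n = len(xs)
    for step in range(max(wanted) + 1):
        if step in wanted:
            snapshots[step] = w  # weight after `step` updates
        # d/dw of mean((w*x - y)^2) is mean(2 * (w*x - y) * x)
        grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        w -= lr * grad
    return snapshots

# Hypothetical toy data whose least-squares solution is w = 2.
xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
times = [0, 16, 256, 2048]

# An "ensemble": the same training procedure from different random initializations.
ensemble = [
    train_with_snapshots(random.Random(seed).gauss(0.0, 1.0), xs, ys, lr=0.01, times=times)
    for seed in range(5)
]
```

Comparing `ensemble[k][t]` across members `k` at a fixed `t` shows how the spread of the ensemble evolves with training time.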

Computing Q-values from the linearization (mean & variance):
```bash
python toy-wide.py --gamma 1.0 --lr 0.01 --num_layers 2 --layer_width 1024 --times 0 16 256 2048 65536
```
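
In the standard sense, the linearization is the first-order Taylor expansion of the network output in its parameters around the initialization `θ0`: `f_lin(θ, x) = f(θ0, x) + ∇θ f(θ0, x) · (θ - θ0)`. For sufficiently wide networks, gradient-descent training stays close to this linear model. The sketch below only illustrates the expansion on a hypothetical one-unit network `f(θ, x) = a * tanh(w * x)` with a hand-derived gradient; it is not the repository's implementation. In JAX, `neural_tangents.linearize(f, params)` from the listed `neural_tangents` package builds such a linearized function automatically.

```python
import math

def f(theta, x):
    """Hypothetical tiny network with one hidden unit: f(theta, x) = a * tanh(w * x)."""
    w, a = theta
    return a * math.tanh(w * x)

def grad_f(theta, x):
    """Gradient of f with respect to the parameters (w, a), derived by hand."""
    w, a = theta
    t = math.tanh(w * x)
    return (a * (1.0 - t * t) * x, t)

def linearize(theta0):
    """Return f_lin(theta, x) = f(theta0, x) + grad_f(theta0, x) . (theta - theta0)."""
    def f_lin(theta, x):
        g = grad_f(theta0, x)
        return f(theta0, x) + sum(gi * (ti - t0i) for gi, ti, t0i in zip(g, theta, theta0))
    return f_lin

theta0 = (0.5, 1.5)       # hypothetical "initialization"
f_lin = linearize(theta0)
```

The linear model agrees exactly with `f` at `theta0` and deviates from it only at second order in `theta - theta0`.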

Computing Q-values from the linearization (MC, via an ensemble):
```bash
python toy-wide-mc.py --gamma 1.0 --lr 0.01 --num_layers 2 --layer_width 1024 --ensemble_size 50 --times 0 16 256 2048 65536
```
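
The Monte Carlo variant estimates the mean and variance empirically from an ensemble instead of computing them in closed form. Assuming the predictions are held as one list of Q-values per ensemble member (a hypothetical layout, not necessarily the repository's file format), the per-input statistics could be computed like this:

```python
from statistics import fmean, pvariance

def ensemble_mean_var(predictions):
    """predictions[k][i] is the Q-value of ensemble member k on input i.

    Returns one (mean, variance) pair per input, taken across members.
    pvariance divides by K (the number of members); statistics.variance
    would give the unbiased K - 1 estimator instead.
    """
    return [(fmean(vals), pvariance(vals)) for vals in zip(*predictions)]
```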

Plotting the time progression:
```bash
python toy-plots.py --lr 0.01 --num_layers 2 --layer_width 256 --ensemble_size 50 --show_nn --show_wide --show_wide_mc --times 0 16 256 2048 65536
```

Showing individual runs from the ensemble of NNs:
```bash
python toy-plots.py --lr 0.01 --num_layers 2 --layer_width 256 --ensemble_size 50 --show_nn_individual --times 0 16 256 2048 65536
```

### Cartpole environment:

Generating the dataset:
```bash
python cartpole-generate-dataset.py --seed 10 --epsilon 0.0 --dataset_size 10240
```
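
We assume `--epsilon` is the exploration rate of an epsilon-greedy behaviour policy used to collect the dataset; the script's actual policy is not documented here. Under that assumption, `--epsilon 0.0` makes every action greedy. A minimal selector, with the hypothetical name `epsilon_greedy`:

```python
import random

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Greedy action: index of the largest Q-value (first one on ties).
    return max(range(len(q_values)), key=lambda a: q_values[a])
```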

Getting the NN parameters for a single run (to replicate the results in the paper, run seeds 1 to 50):
```bash
python cartpole-single.py --seed 1
```
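
To cover all 50 seeds, a small driver such as the following could be used. It only invokes the command above with `--seed 1` through `--seed 50`; `seed_commands` and `run_all` are hypothetical helper names. Runs execute sequentially and stop at the first failure.

```python
import subprocess
import sys

def seed_commands(first=1, last=50, script="cartpole-single.py"):
    """Build one `python cartpole-single.py --seed N` command per seed."""
    return [[sys.executable, script, "--seed", str(s)] for s in range(first, last + 1)]

def run_all(commands):
    """Run the commands one after another; check=True raises on a non-zero exit code."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

Usage: `run_all(seed_commands())`.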

Getting the Q-values of the 50 NN runs (add `--test` to evaluate on the test set):
```bash
python cartpole-single-qvals.py
```

Getting the Q-values of 50 runs of the Wide MC estimate (add `--test` to evaluate on the test set):
```bash
python cartpole-wide-mc.py
```

Getting the true returns by rolling out the policy in the environment (add `--test` to get the true returns for the test set):
```bash
python cartpole-get-returns.py
```
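
The true return from a state-action pair is the (discounted) sum of the rewards collected while rolling out the policy from it. The discount factor used for CartPole is not stated here, so this minimal sketch leaves `gamma` as a parameter:

```python
def discounted_return(rewards, gamma):
    """G = sum_k gamma**k * r_k, accumulated backwards from the last reward."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g
```

CartPole gives a reward of +1 per step, so with `gamma=1.0` the return equals the number of remaining steps in the episode.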

Computing the RMSE and NLL of the estimated Q-values with respect to the true returns (add `--test` to evaluate on the test set):
```bash
python cartpole-get-metrics.py
```
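
RMSE compares the predicted mean Q-values with the true returns, while a negative log-likelihood also scores the predicted variance. The sketch below assumes a Gaussian NLL with the predicted mean and variance, averaged over samples; the repository may use a different convention (for example, summing instead of averaging). The `min_var` floor is a hypothetical safeguard against zero variances.

```python
import math

def rmse(pred_means, returns):
    """Root-mean-squared error between predicted means and true returns."""
    return math.sqrt(sum((m - g) ** 2 for m, g in zip(pred_means, returns)) / len(returns))

def gaussian_nll(pred_means, pred_vars, returns, min_var=1e-6):
    """Average NLL of the returns under N(mean, var), with variances floored at min_var."""
    total = 0.0
    for m, v, g in zip(pred_means, pred_vars, returns):
        v = max(v, min_var)
        total += 0.5 * (math.log(2.0 * math.pi * v) + (g - m) ** 2 / v)
    return total / len(returns)
```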
