# Entropy-Regularized Actor-Critic

This repository contains a **tabular entropy-regularized actor–critic** implementation written in **JAX**.
It supports two experiment modes:

- **Standard actor–critic**: learn a critic via TD-style updates for `H` inner steps, then do one actor update.
- **Exact-critic oracle**: replace the learned critic with the **exact** regularized Q-function.

The code targets **small tabular MDPs** (GridWorld and a synthetic random MDP). It reproduces the experiments by
running a small hyperparameter grid search and saving the learning curves to disk.

---

## Project structure

```
.
├── main.py                    # runs experiments + saves .pkl results
├── plot_experiments.py         # loads .pkl results + produces comparison PDFs
├── gridword_env/
│   ├── gridworld.py            # tabular GridWorld model builder (P, R, mu0)
│   └── gridworld_utils.py
└── synthetic_env/
    ├── Env.py                  # random tabular MDP generator (P, R, mu0)
    └── utils.py
```

## Method overview

### Policy parametrization

The policy is **tabular softmax**:
- parameters: `theta ∈ R^{S×A}`
- policy: `π_θ(a|s) = softmax(theta[s, :])`
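As a minimal sketch of this parametrization, the snippet below computes a row-wise softmax over a parameter table. It uses NumPy to stay dependency-light and is not taken from the repository; in JAX, `jax.nn.softmax(theta, axis=-1)` plays the same role. The function name `softmax_policy` is hypothetical.

```python
import numpy as np

def softmax_policy(theta):
    """Row-wise softmax: theta has shape (S, A); returns pi with the same shape."""
    z = theta - theta.max(axis=1, keepdims=True)  # shift each row for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

theta = np.zeros((3, 2))
theta[0, 1] = 1.0            # make action 1 more likely in state 0
pi = softmax_policy(theta)   # each row of pi is a distribution over actions
```

An all-zero row of `theta` yields the uniform policy for that state, which is a common initialization for tabular softmax policies.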

### Entropy-regularized reward

For entropy coefficient `λ > 0`, the regularized per-step reward is
$$
\tilde r(s,a) = r(s,a) - \lambda \log \pi_\theta(a\mid s).
$$
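The regularized reward can be computed for all state-action pairs at once. The sketch below assumes the reward table `R` and the policy `pi` both have shape `(S, A)`; that layout, and the name `regularized_reward`, are assumptions rather than the repository's actual interface.

```python
import numpy as np

def regularized_reward(R, pi, lam):
    """r_tilde(s, a) = r(s, a) - lam * log pi(a | s), for R and pi of shape (S, A)."""
    return R - lam * np.log(pi)

R = np.array([[1.0, 0.0], [0.0, 2.0]])
pi = np.full((2, 2), 0.5)                     # uniform policy over two actions
r_tilde = regularized_reward(R, pi, lam=0.1)  # adds lam * log(2) to every entry here
```

Because `log π ≤ 0`, the correction is non-negative and grows as an action becomes less likely under the policy.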

### Discounted occupancy used for sampling

The actor and critic minibatches sample states from the **discounted state distribution**
$$
d_\pi = (1-\gamma) (I - \gamma P_\pi)^{-\top} \mu_0,
$$
where `P_π(s'|s) = Σ_a π(a|s) P(s'|s,a)`.

This is computed **exactly** at every iteration by solving an `S × S` linear system, which is affordable only because `S` is small.
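The linear solve behind `d_π` can be sketched as follows. This is a NumPy illustration under the assumption that the transition tensor is stored as `P[s, a, s']`; the function name `discounted_occupancy` is hypothetical and the repository's JAX version may be organized differently.

```python
import numpy as np

def discounted_occupancy(P, pi, mu0, gamma):
    """d_pi = (1 - gamma) * (I - gamma * P_pi)^{-T} mu0.

    P: (S, A, S) with P[s, a, s'], pi: (S, A), mu0: (S,).
    """
    S = P.shape[0]
    P_pi = np.einsum("sa,sat->st", pi, P)  # P_pi[s, s'] = sum_a pi(a|s) P(s'|s,a)
    # Solving with the transposed matrix avoids forming an explicit inverse.
    return (1.0 - gamma) * np.linalg.solve(np.eye(S) - gamma * P_pi.T, mu0)

rng = np.random.default_rng(0)
S, A = 4, 2
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)      # normalize each P[s, a, :] to a distribution
pi = np.full((S, A), 1.0 / A)
mu0 = np.full(S, 1.0 / S)
d = discounted_occupancy(P, pi, mu0, gamma=0.9)
```

The result is a probability vector that satisfies the recursion `d = (1-γ) μ_0 + γ P_π^⊤ d`, which is a convenient sanity check.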

### Critic update (standard mode)

Given current critic table `q`, we run `H` inner updates. Each inner update:
1. sample `(s,a,s')` using `d_π` and the model `P`
2. compute TD error using the entropy-regularized bootstrap target
$$
\delta = \tilde r(s,a) + \gamma \sum_{a'} \pi(a'|s')\big(q(s',a') - \lambda \log \pi(a'|s')\big) - q(s,a).
$$
3. update a single table entry. JAX arrays are immutable, so this uses the functional `.at[...]` indexed-update syntax rather than in-place assignment:
$$
q(s,a) \leftarrow q(s,a) + \eta_c \; \mathbb{E}[\delta].
$$
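The three steps above can be sketched for a single sampled transition. This NumPy version is illustrative only: it assumes the action is drawn from `π(·|s)`, replaces the expectation by one sample, and uses hypothetical names (`td_error`, `critic_step`). Copying the table mimics JAX's functional updates.

```python
import numpy as np

def td_error(q, R, pi, s, a, s_next, gamma, lam):
    """TD error for one transition, following the entropy-regularized target above."""
    log_pi = np.log(pi)
    r_tilde = R[s, a] - lam * log_pi[s, a]
    bootstrap = np.sum(pi[s_next] * (q[s_next] - lam * log_pi[s_next]))
    return r_tilde + gamma * bootstrap - q[s, a]

def critic_step(q, R, P, pi, d, gamma, lam, lr_c, rng):
    """One inner update: sample (s, a, s'), then move q[s, a] along the TD error."""
    s = rng.choice(d.shape[0], p=d)             # s  ~ d_pi
    a = rng.choice(pi.shape[1], p=pi[s])        # a  ~ pi(.|s)  (assumption)
    s_next = rng.choice(P.shape[2], p=P[s, a])  # s' ~ P(.|s, a)
    delta = td_error(q, R, pi, s, a, s_next, gamma, lam)
    q_new = q.copy()                            # new table, the old one is left untouched
    q_new[s, a] += lr_c * delta
    return q_new

rng = np.random.default_rng(0)
S, A = 3, 2
P = np.full((S, A, S), 1.0 / S)
R = np.ones((S, A))
pi = np.full((S, A), 1.0 / A)
d = np.full(S, 1.0 / S)      # stand-in for the exact discounted occupancy
q = np.zeros((S, A))
for _ in range(8):           # H = 8 inner steps
    q = critic_step(q, R, P, pi, d, gamma=0.9, lam=0.1, lr_c=0.5, rng=rng)
```

With a minibatch, the same update would average `delta` over the sampled transitions before applying the step.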

### Actor update

We use an advantage-style update with baseline
$$
v(s) = \sum_a \pi(a|s)\big(q(s,a) - \lambda \log \pi(a|s)\big),
$$
and
$$
\text{adv}(s,a) = q(s,a) - \lambda\log \pi(a|s) - v(s).
$$

In code, one actor step performs:
$$
\theta(s,a) \leftarrow \theta(s,a) + \eta_a \; \mathbb{E}\left[\frac{\text{adv}(s,a)}{1-\gamma}\right].
$$
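One possible reading of this update, for a minibatch of sampled `(s, a)` pairs, is sketched below in NumPy. How the expectation is distributed over table entries is an assumption here (a plain minibatch average), and the names `advantage` and `actor_step` are hypothetical.

```python
import numpy as np

def advantage(q, pi, lam):
    """adv(s, a) = q(s, a) - lam * log pi(a|s) - v(s), with the baseline v from above."""
    log_pi = np.log(pi)
    v = np.sum(pi * (q - lam * log_pi), axis=1)  # baseline v(s)
    return q - lam * log_pi - v[:, None]

def actor_step(theta, q, pi, states, actions, gamma, lam, lr_a):
    """Add lr_a * adv / (1 - gamma), averaged over the minibatch, to the sampled entries."""
    adv = advantage(q, pi, lam)
    step = lr_a * adv[states, actions] / ((1.0 - gamma) * len(states))
    theta_new = theta.copy()
    np.add.at(theta_new, (states, actions), step)  # accumulates repeated (s, a) pairs
    return theta_new

rng = np.random.default_rng(0)
S, A = 3, 2
theta = np.zeros((S, A))
pi = np.full((S, A), 1.0 / A)      # softmax of an all-zero theta
q = rng.normal(size=(S, A))
states = np.array([0, 0, 2])       # minibatch of sampled states
actions = np.array([1, 1, 0])      # and the actions sampled in them
theta_new = actor_step(theta, q, pi, states, actions, gamma=0.9, lam=0.1, lr_a=0.01)
adv = advantage(q, pi, lam=0.1)
```

By construction the advantage has zero mean under the policy in every state, so the baseline changes the variance of the update but not its expected direction.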

### Exact critic oracle (ablation)

When `use_exact=True`, the critic is set to the **exact** entropy-regularized Q-function by solving the Bellman system
(via `compute_exact_q`). The actor step is the same, but uses this oracle `q`.
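For intuition, the sketch below solves for the table at which the expected TD error from the critic section vanishes for every `(s, a)`. Treating the oracle as this fixed point is an assumption; the repository's `compute_exact_q` may use a different but related formulation. The name `exact_q_fixed_point` and the `P[s, a, s']` layout are hypothetical.

```python
import numpy as np

def exact_q_fixed_point(P, R, pi, gamma, lam):
    """Solve q = r_tilde + gamma * P [ sum_a' pi(a'|s') (q(s',a') - lam * log pi(a'|s')) ]."""
    S, A = R.shape
    log_pi = np.log(pi)
    r_tilde = R - lam * log_pi                   # (S, A) regularized reward
    ent = -lam * np.sum(pi * log_pi, axis=1)     # (S,) entropy bonus at the next state
    # M[(s, a), (s', a')] = P(s'|s, a) * pi(a'|s'), flattened row-major to (S*A, S*A)
    M = (P[:, :, :, None] * pi[None, None, :, :]).reshape(S * A, S * A)
    b = (r_tilde + gamma * (P @ ent)).reshape(S * A)
    q = np.linalg.solve(np.eye(S * A) - gamma * M, b)
    return q.reshape(S, A)

rng = np.random.default_rng(0)
S, A = 4, 2
P = rng.random((S, A, S))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((S, A))
pi = rng.random((S, A))
pi /= pi.sum(axis=1, keepdims=True)
q_star = exact_q_fixed_point(P, R, pi, gamma=0.9, lam=0.1)
```

The system has `S·A` unknowns, so a dense solve like this is only sensible for small tabular problems.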

### Logged metric

We log the exact entropy-regularized objective
$$
\tilde J_\lambda(\theta) = \mathbb{E}_{s_0\sim \mu_0}[V^{\pi_\theta}_\lambda(s_0)],
$$
computed in `compute_exact_v` by solving `(I - γP_π) v = r̃_π`, where `r̃_π(s) = Σ_a π(a|s) r̃(s,a)` is the policy-averaged regularized reward.
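A NumPy sketch of this computation is shown below. It assumes `P[s, a, s']` and `R[s, a]` layouts, and the function name `exact_objective` is hypothetical; it illustrates the linear system above rather than reproducing the repository's `compute_exact_v`.

```python
import numpy as np

def exact_objective(P, R, pi, mu0, gamma, lam):
    """Return (J, v): v solves (I - gamma * P_pi) v = r_tilde_pi and J = mu0 . v."""
    S = P.shape[0]
    P_pi = np.einsum("sa,sat->st", pi, P)               # state-to-state kernel under pi
    r_pi = np.sum(pi * (R - lam * np.log(pi)), axis=1)  # policy-averaged regularized reward
    v = np.linalg.solve(np.eye(S) - gamma * P_pi, r_pi)
    return float(mu0 @ v), v

S, A = 3, 2
P = np.full((S, A, S), 1.0 / S)
R = np.ones((S, A))
pi = np.full((S, A), 1.0 / A)
mu0 = np.full(S, 1.0 / S)
J, v = exact_objective(P, R, pi, mu0, gamma=0.9, lam=0.1)
```

In this toy setup the regularized reward is the constant `1 + 0.1·log 2` in every state, so `v` is that constant divided by `1 - γ`.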

---

## Installation

Create a Python environment (3.10+ recommended) and install dependencies.

### Minimal dependencies

- `jax` / `jaxlib`
- `flax` (used by the environments for `struct.dataclass`)
- `numpy`
- `matplotlib`

Example (CPU-only JAX):

```bash
pip install -U "jax[cpu]" flax numpy matplotlib
```

If you want GPU/TPU JAX, follow the official JAX installation instructions for your platform.

---

## Running experiments

From the project root (the folder containing `main.py`):

```bash
python main.py
```

By default, `main.py` runs **four** jobs:

1. GridWorld, standard actor–critic (`use_exact=False`)
2. GridWorld, exact-critic oracle (`use_exact=True`)
3. Synthetic MDP, standard actor–critic (`use_exact=False`)
4. Synthetic MDP, exact-critic oracle (`use_exact=True`)

### Outputs

Each job writes a pickle file to the current directory, e.g.

- `results_gridworld_3x4_gridsearch.pkl`
- `results_gridworld_3x4_EXACT.pkl`
- `results_synthetic_S16_gridsearch.pkl`
- `results_synthetic_S16_EXACT.pkl`

The pickle format is:

```python
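{
  "results": {label: histories, ...},  # one entry per hyperparameter configuration
  "iters": ITERS + 1,                  # number of logged points per learning curve
  "env": "gridworld",                  # or "synthetic"
  "size": "...",
  "Hs": [...],                         # values of H covered by the grid search
}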
```

where:
- `label` is a string like `"H=32_lra=0.01_lrc=0.03"`
- `histories` is an array/list of shape `(NUM_RUNS, ITERS+1)` (learning curves of $\tilde J_\lambda$)
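To inspect a results file by hand, something like the following should work. The sketch first writes a small stand-in file with the documented layout so that it runs on its own; the file name `results_demo.pkl` and the placeholder values are made up for illustration.

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in file with the documented layout (placeholder values, not real results).
payload = {
    "results": {"H=32_lra=0.01_lrc=0.03": np.zeros((5, 101))},
    "iters": 101,
    "env": "gridworld",
    "size": "demo",
    "Hs": [32],
}
path = os.path.join(tempfile.mkdtemp(), "results_demo.pkl")
with open(path, "wb") as f:
    pickle.dump(payload, f)

# Load the file and summarize each configuration across runs.
with open(path, "rb") as f:
    data = pickle.load(f)

summary = {}
for label, histories in data["results"].items():
    curves = np.asarray(histories)                              # (NUM_RUNS, ITERS + 1)
    summary[label] = (curves.mean(axis=0), curves.std(axis=0))  # per-iteration mean / std

mean_curve, std_curve = summary["H=32_lra=0.01_lrc=0.03"]
```

Pickle files can execute arbitrary code when loaded, so only open results files you produced yourself.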

---

## Plotting

After you have result `.pkl` files, run:

```bash
python plot_experiments.py
```

This script:
- selects the **best hyperparameter configuration per `H`** (based on the last `tail_avg` iterations),
- plots **mean ± std** across runs,
- overlays the **exact-critic** curve as a dashed black line,
- saves figures as PDFs.
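The selection step can be sketched as follows. It assumes labels follow the `H=..._lra=..._lrc=...` pattern, that a configuration is scored by its mean over all runs of the last `tail_avg` logged values, and that higher is better; the helper names and the toy data are hypothetical, and `plot_experiments.py` may rank configurations differently.

```python
import numpy as np

def parse_label(label):
    """Split a label like 'H=32_lra=0.01_lrc=0.03' into a dict of float values."""
    return {k: float(v) for k, v in (part.split("=") for part in label.split("_"))}

def best_per_H(results, tail_avg):
    """For each H, keep the (label, score) pair with the highest tail-averaged objective."""
    best = {}
    for label, histories in results.items():
        H = int(parse_label(label)["H"])
        score = float(np.asarray(histories)[:, -tail_avg:].mean())
        if H not in best or score > best[H][1]:
            best[H] = (label, score)
    return best

# Toy results: 2 runs x 10 logged points per configuration (made-up numbers).
results = {
    "H=1_lra=0.01_lrc=0.03": np.full((2, 10), 0.2),
    "H=1_lra=0.1_lrc=0.03": np.full((2, 10), 0.5),
    "H=32_lra=0.01_lrc=0.03": np.full((2, 10), 0.7),
}
best = best_per_H(results, tail_avg=3)
```

Averaging over a tail window rather than taking the final value makes the ranking less sensitive to noise in the last iteration.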

### Outputs

- `comparison_gridworld_3x4.pdf`
- `comparison_synthetic_S16.pdf`
- `legend_unified.pdf` (a standalone legend you can include in LaTeX figures)

You can edit the `datasets` list at the bottom of `plot_experiments.py` to add/remove result files.

---

## Reproducing / modifying experiments

All knobs are in `run_experiment(...)` inside `main.py`, including:

- environment:
  - GridWorld size (`rows`, `cols`)
  - synthetic MDP size (`num_states`, `num_actions`)
- discount and entropy: `GAMMA`, `LAMBDA`
- inner critic steps: `Hs`
- batches: `ACTOR_BATCH`, `CRITIC_BATCH`
- learning-rate grids: `LR_ACTOR_VALS`, `LR_CRITIC_VALS`
- number of runs / iterations: `NUM_RUNS`, `ITERS`

---