# Mitigating Suboptimality of Deterministic Policy Gradients in Complex Q-functions

This is the official PyTorch implementation of the paper "**Mitigating Suboptimality of Deterministic Policy Gradients in Complex Q-functions**" (ICLR 2025).

## Problem

In value-based actor-critic reinforcement learning, the actor is trained to maximize the critic (Q-function) via gradient ascent. However, in complex tasks such as dexterous manipulation, the Q-function landscape can have many locally optimal actions. This makes the actor prone to getting stuck in local optima, which leads to sample-inefficient training and a suboptimal policy at convergence.

<img src="./data/assets/teaser.png" style="zoom: 45%;" />

Plot of Q-value versus action $a$ (projected to 2D) at a single state. In the restricted Walker control task (left), various motions are locally optimal, such as avoiding a fall or moving forward slowly. Similarly, in recommender systems (right), finding the global optimum among the representations of many items (black dots) is challenging.
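
The failure mode can be reproduced on a toy problem. The sketch below is illustrative only and not taken from the repository: a hand-picked 1D function `q` stands in for $Q(s, a)$ at one fixed state, and plain gradient ascent on the action plays the role of the actor's update. The function, step size, and starting points are all assumptions.

```python
import numpy as np


def q(a):
    # Hypothetical multimodal Q-landscape at one fixed state (toy function, not from the paper).
    return np.sin(3.0 * a) + 0.5 * a


def dq_da(a):
    # Analytic gradient of the toy Q with respect to the action.
    return 3.0 * np.cos(3.0 * a) + 0.5


def gradient_ascent(a, lr=0.01, steps=500, low=-3.0, high=3.0):
    # Move the action uphill on Q, as a deterministic actor does via the critic's gradient.
    for _ in range(steps):
        a = float(np.clip(a + lr * dq_da(a), low, high))
    return a


results = []
for a0 in (-2.0, 0.0, 2.0):
    a_final = gradient_ascent(a0)
    results.append((a_final, float(q(a_final))))
    print(f"init {a0:+.1f} -> action {a_final:+.3f}, Q = {q(a_final):.3f}")
```

Each start converges to the nearest local maximum (near $a \approx -1.51$, $0.58$, and $2.67$). Only the start at $a = 2.0$ reaches the highest Q-value; from the other starts the local gradient never points toward it.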

## Approach: Successive Actors for Value Optimization (SAVO)

<img src="./data/assets/method.png" style="zoom:45%;" />

An actor $\mu$ trained with gradient ascent on a challenging Q-landscape gets stuck in local optima. Our approach learns a sequence of surrogates $\Psi_i$ of the Q-function that successively prune the Q-landscape below the current best Q-value, leaving fewer local optima. The actors $\nu_i$ trained to ascend these surrogates therefore produce actions with higher Q-values.
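
The effect of pruning can be illustrated on a discretised toy landscape. This is a simplified sketch under stated assumptions, not the paper's training procedure: the Q-function is a hand-picked 1D function, the first actor is replaced by discrete hill climbing, and the surrogate is taken literally as $\Psi_1(a) = \max(Q(a), Q(a_0))$, where $a_0$ is the first actor's action. Counting strict local maxima before and after pruning shows why a later actor faces an easier landscape.

```python
import numpy as np

# Hypothetical 1D Q-landscape at one state, evaluated on a grid of 601 actions.
actions = np.linspace(-3.0, 3.0, 601)
q_values = np.sin(3.0 * actions) + 0.5 * actions


def count_local_maxima(values):
    # Interior grid points that are strictly higher than both neighbours.
    mid = values[1:-1]
    return int(np.sum((mid > values[:-2]) & (mid > values[2:])))


def hill_climb(values, idx):
    # Discrete stand-in for gradient ascent: step to a better neighbour until none exists.
    while True:
        neighbours = [j for j in (idx - 1, idx + 1) if 0 <= j < len(values)]
        best = max(neighbours, key=lambda j: values[j])
        if values[best] <= values[idx]:
            return idx
        idx = best


# First actor: climb from a = 0 and record the best Q-value found so far.
idx_1 = hill_climb(q_values, int(np.argmin(np.abs(actions))))
best_q = q_values[idx_1]

# Surrogate: flatten every part of the landscape below best_q.
psi = np.maximum(q_values, best_q)

print(count_local_maxima(q_values), round(float(actions[idx_1]), 2), count_local_maxima(psi))
```

The raw landscape has three interior local maxima and the first actor stops at the one near $a = 0.58$; after pruning, everything at or below that value becomes a flat plateau and only the global maximum remains. In SAVO the actors $\nu_i$ are neural networks trained on such surrogates; this toy only shows the effect of the pruning itself.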



-----

# Code

## Directories

The structure of the repository:

- `data`: Scripts used for experiments.
- `large_rl`: Implementation of all components.
- `main.py`: Entry point for running all the methods.

## Python Environment

- Python 3.6 or later is required (Python 3.6.9 recommended).

## Dependencies

- All Python package requirements are listed in `requirements.txt`. Install them in a new virtual environment (e.g., pyenv or conda) with:
    - `pip install -r requirements.txt`

# Experiments

## Mine world

- **General Notes**
    - Set `mw_test_save_video` to `True` to generate videos during evaluation.
    - Videos are saved to `./videos`.

- **Baseline Agents**
    - SAVO

        ```shell
        python main.py --env_name=mine --method_name savo_refined --WOLP_cascade_list_len 3 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --TD3_target_policy_smoothing=False --seed=5
        ```

    - TD3

        ```shell
        python main.py --env_name=mine --method_name savo_refined --WOLP_cascade_list_len 1 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --seed=5
        ```

    - Ensemble

        ```shell
        python main.py --env_name=mine --method_name ensemble --WOLP_cascade_list_len 3 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --seed=5
        ```

    - TD3 + Sampling

        ```shell
        python main.py --env_name=mine --method_name wolp_dual --WOLP_topK 3 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --seed=5
        ```

    - Joint

        ```shell
        python main.py --env_name=mine --method_name flair_joint --WOLP_cascade_list_len 3 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --seed=5
        ```

    - Wolpertinger

        ```shell
        python main.py --env_name=mine --method_name wolp --WOLP_topK 3 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --seed=5
        ```

    - CEM

        ```shell
        python main.py --env_name=mine --method_name cem --CEM_rescale_actions=True --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --seed=5
        ```

- **Ablation Agents**
    - DeepSet / Transformer / LSTM: set `--WOLP_ar_type_list_encoder` to `deepset`, `transformer`, or `lstm`.
        ```shell
        # Same as SAVO above
        python main.py --env_name=mine --method_name savo_refined --WOLP_cascade_list_len 3 --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True --TD3_target_policy_smoothing=False --WOLP_ar_type_list_encoder=deepset --seed=5
        ```
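
The ablation above only changes the list encoder. Judging from the flag name and the method description, this encoder presumably summarizes the actions already proposed by earlier actors so that actor $\nu_i$ can condition on them; that reading is an assumption. A DeepSet encoder computes $\rho\left(\sum_j \phi(a_j)\right)$, which does not depend on the order of the actions. The sketch below uses random NumPy weights and hypothetical sizes in place of the repository's trained PyTorch modules.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes; the repository's actual dimensions are not specified here.
ACTION_DIM, HIDDEN_DIM, OUT_DIM = 4, 16, 8

# Random weights stand in for trained parameters of phi and rho.
W_phi = rng.normal(size=(ACTION_DIM, HIDDEN_DIM))
W_rho = rng.normal(size=(HIDDEN_DIM, OUT_DIM))


def deepset_encode(prev_actions):
    """Encode a list of previous actions, shape (n, ACTION_DIM), into one vector."""
    # phi: the same ReLU layer is applied to every action independently.
    h = np.maximum(prev_actions @ W_phi, 0.0)  # shape (n, HIDDEN_DIM)
    # Sum pooling discards the order of the actions.
    pooled = h.sum(axis=0)  # shape (HIDDEN_DIM,)
    # rho: map the pooled summary to the conditioning vector.
    return np.tanh(pooled @ W_rho)  # shape (OUT_DIM,)


prev_actions = rng.normal(size=(3, ACTION_DIM))  # e.g. actions a_0, a_1, a_2
encoded = deepset_encode(prev_actions)
encoded_shuffled = deepset_encode(prev_actions[[2, 0, 1]])
print(np.allclose(encoded, encoded_shuffled))  # True: the order does not matter
```

An LSTM, by contrast, reads the list sequentially, so its output generally depends on the order; a Transformer is order-invariant only if it uses no positional information and a symmetric pooling step.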

## RecSim

- **Baseline Agents**
    - SAVO

        ```shell
        python main.py --env_name=recsim-10k --method_name=savo_refined --WOLP_cascade_list_len=3 --env_dim_extra=5 --seed=5
        ```

    - TD3

        ```shell
        python main.py --seed=3 --env_name=recsim-10k --method_name=savo_refined --WOLP_cascade_list_len=1 --env_dim_extra=5
        ```

    - Ensemble

        ```shell
        python main.py --seed=1 --env_name=recsim-10k --method_name ensemble --WOLP_cascade_list_len=3
        ```

    - TD3 + Sampling

        ```shell
        python main.py --seed=1 --env_name=recsim-10k --method_name=wolp_dual --WOLP_topK=3 --env_dim_extra=5
        ```

    - Joint

        ```shell
        python main.py --seed=5 --env_name=recsim-10k --method_name flair_joint --WOLP_cascade_list_len 3
        ```

    - Wolpertinger

        ```shell
        python main.py --seed=5 --env_name=recsim-10k --method_name=wolp --WOLP_topK=3 --env_dim_extra=5
        ```

    - CEM

        ```shell
        python main.py --seed=5 --env_name=recsim-10k --method_name cem --CEM_rescale_actions=True
        ```

- **Ablation Agents**
    - DeepSet / Transformer / LSTM: set `--WOLP_ar_type_list_encoder` to `deepset`, `transformer`, or `lstm`.
        ```shell
        # Same as SAVO above
        python main.py --env_name=recsim-10k --method_name=savo_refined --WOLP_cascade_list_len=3 --env_dim_extra=5 --WOLP_ar_type_list_encoder=deepset --seed=5
        ```

- **Key parameters**
    - *Env* (all experiments use the default values; listed here for reference only)
        + `--num_all_actions=5000` -> size of the action set
        + `--num_envs=16` -> number of environments run in parallel
        + `--recsim_dim_embed=30` -> dimension of the action embeddings
        + `--recsim_num_categories=30` -> number of item categories in the action set
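
The Wolpertinger baseline takes `--WOLP_topK`. In the standard Wolpertinger scheme, the actor outputs a continuous proto-action, the $k$ nearest item embeddings are retrieved, and the critic picks the best of those $k$ candidates. The NumPy sketch below shows that selection step with sizes matching the default `--num_all_actions` and `--recsim_dim_embed` values listed above; the dot-product critic, the random embeddings, and the function names are hypothetical, and the repository's implementation may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sizes mirror the default flags above; the embeddings themselves are random placeholders.
NUM_ALL_ACTIONS, DIM_EMBED, TOP_K = 5000, 30, 3
item_embeddings = rng.normal(size=(NUM_ALL_ACTIONS, DIM_EMBED))


def toy_critic(state, embeddings):
    # Hypothetical Q(s, a): a dot product between the state and each item embedding.
    return embeddings @ state


def wolpertinger_select(state, proto_action, k=TOP_K):
    # 1) k-nearest-neighbour lookup around the actor's continuous proto-action.
    dists = np.linalg.norm(item_embeddings - proto_action, axis=1)
    candidates = np.argpartition(dists, k)[:k]
    # 2) Re-rank the k candidates with the critic and keep the best one.
    q_vals = toy_critic(state, item_embeddings[candidates])
    return int(candidates[np.argmax(q_vals)]), candidates


state = rng.normal(size=DIM_EMBED)
proto_action = rng.normal(size=DIM_EMBED)  # stand-in for the actor's output
item, candidates = wolpertinger_select(state, proto_action)
print(item, sorted(candidates.tolist()))
```

`np.argpartition` avoids fully sorting all 5000 distances, which matters once the action set grows large; a production system would typically use an approximate nearest-neighbour index instead.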

## MuJoCo

- **Baseline Agents**
    - SAVO

        ```shell
        python main.py --env_name=mujoco-walker2d --method_name savo --reacher_validity_type=box --seed 33
        ```

    - TD3

        ```shell
        python main.py --env_name=mujoco-walker2d --method_name savo --WOLP_cascade_list_len 1 --reacher_validity_type=box --seed 33
        ```

    - Joint

        ```shell
        python main.py --seed=2 --env_name=mujoco-hopper --method_name flair_joint --reacher_validity_type box --TD3_target_policy_smoothing=False
        ```

    - Greedy-AC

        ```shell
        python main.py --env_name=mujoco-walker2d --method_name greedy_ac --reacher_validity_type box --greedy_td3_exploration=False --seed 4
        ```

    - Greedy-TD3

        ```shell
        python main.py --seed=1 --env_name=mujoco-walker2d --method_name greedy_ac --reacher_validity_type box --greedy_td3_exploration=True --TD3_target_policy_smoothing=False
        ```

    - TD3 + Sampling

        ```shell
        python main.py --env_name=mujoco-walker2d --method_name wolp_dual --WOLP_topK 3 --do_naive_eval=True --reacher_validity_type=box --seed 42
        ```

    - CEM

        ```shell
        python main.py --seed=4 --env_name=mujoco-inverted_double_pendulum --method_name cem --reacher_validity_type box
        ```

    - Ensemble

        ```shell
        python main.py --seed=5 --env_name=mujoco-hopper --method_name ensemble --reacher_validity_type box --TD3_target_policy_smoothing=False
        ```
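
The CEM baseline (`--method_name cem`) appears in every experiment above. As a generic illustration of the cross-entropy method for maximizing Q over actions, and not a description of this repository's implementation (the effect of `--CEM_rescale_actions` is not modelled here), the NumPy sketch below runs CEM on a hypothetical 1D Q-function: sample actions from a Gaussian, keep the highest-Q elites, refit the Gaussian, and repeat.

```python
import numpy as np


def q(a):
    # Hypothetical multimodal Q-landscape over a 1D action (toy, not from the repository).
    return np.sin(3.0 * a) + 0.5 * a


def cem_maximize(q_fn, mean=0.0, std=2.0, pop=200, n_elite=20, iters=30,
                 low=-3.0, high=3.0, seed=0):
    rng = np.random.default_rng(seed)
    for _ in range(iters):
        # Sample candidate actions from the current Gaussian, clipped to the action bounds.
        samples = np.clip(rng.normal(mean, std, size=pop), low, high)
        # Keep the n_elite samples with the highest Q-values.
        elite = samples[np.argsort(q_fn(samples))[-n_elite:]]
        # Refit the Gaussian to the elites; the floor stops the std from reaching zero.
        mean, std = elite.mean(), max(elite.std(), 1e-3)
    return float(mean)


a_best = cem_maximize(q)
print(round(a_best, 2), round(float(q(a_best)), 2))
```

Because CEM samples across the whole action range instead of following a local gradient, on this toy it ends near the global maximum at $a \approx 2.67$ even though its initial mean lies in the basin of a local maximum.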

# Acknowledgements

- The Mine world (grid world) environment is adapted from https://github.com/maximecb/gym-minigrid
- RecSim simulator: https://github.com/google-research/recsim

