# Rethinking Actor-Critic: Successive Actors for Critic Maximization

This is the official PyTorch implementation of the paper "**RETHINKING ACTOR-CRITIC: SUCCESSIVE ACTORS FOR CRITIC MAXIMIZATION**" (ICLR 2024).

-----

# Code

## Directories

The repository is organized as follows:

- `data`: Scripts used for experiments.
- `large_rl`: Implementation of all components.
- `main.py`: Entry point for running all the methods.

## Python Environment

- Python 3.6 or later is required (Python 3.6.9 recommended).

## Dependencies

- All Python package requirements are listed in `requirements.txt`. Install them in a new virtual environment (e.g., pyenv or conda) with:
    - `pip install -r requirements.txt`

# Experiments

- **Important parameters**
    - *Agent*
        + `--agent_type=dqn/wolp/wolp-sac` -> selects the learning agent
        + `--WOLP_ar_type_list_encoder=deepset/lstm/transformer` -> selects the list-encoder architecture
        + `--WOLP_if_ar_actor_share_weight=True/False` -> whether the actors in the cascading architecture share network weights

    - *Env* (All experiments use the default values; the values below are listed for reference only.)
        + `--num_all_actions=5000` -> Size of the action set
        + `--num_envs=16` -> Number of environments run in parallel
        + `--recsim_dim_embed=30` -> Dimension of the action embedding
        + `--recsim_num_categories=30` -> Number of item categories in the action set
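
The cascading architecture that `--WOLP_if_ar_actor_share_weight` and `--WOLP_ar_type_list_encoder` refer to can be pictured with the minimal NumPy sketch below. It is not the repository's implementation: the linear actors, the toy quadratic critic, and every name in it (`deepset_encode`, `make_actor`, `select_action`) are hypothetical. Here each successive actor sees the state plus a DeepSet (sum-pooled) encoding of the actions proposed before it, and the critic keeps the highest-valued candidate. Which components are actually conditioned on earlier actions in this codebase is governed by flags such as `--WOLP_ar_actor_no_conditioning` (used in the commands below) and is not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(0)
STATE_DIM, ACTION_DIM, NUM_ACTORS, ENC_DIM = 4, 2, 3, 8


def deepset_encode(prev_actions, phi):
    """Permutation-invariant summary of earlier actions: sum_i tanh(a_i @ phi)."""
    if len(prev_actions) == 0:
        return np.zeros(phi.shape[1])
    return np.tanh(np.stack(prev_actions) @ phi).sum(axis=0)


def make_actor():
    # Hypothetical linear actor acting on [state, list encoding].
    return rng.normal(scale=0.1, size=(STATE_DIM + ENC_DIM, ACTION_DIM))


def critic(state, action):
    # Toy critic whose value peaks at action == state[:ACTION_DIM].
    return -np.sum((action - state[:ACTION_DIM]) ** 2)


def select_action(state, actors, phi):
    """Run the actors in sequence, then let the critic pick the best candidate."""
    candidates = []
    for actor in actors:
        enc = deepset_encode(candidates, phi)  # summary of actions proposed so far
        action = np.tanh(np.concatenate([state, enc]) @ actor)  # bounded to [-1, 1]
        candidates.append(action)
    q_values = [critic(state, a) for a in candidates]
    best = int(np.argmax(q_values))
    return candidates[best], candidates, q_values


share_weight = True  # loosely analogous to --WOLP_if_ar_actor_share_weight (assumption)
phi = rng.normal(size=(ACTION_DIM, ENC_DIM))
if share_weight:
    shared = make_actor()
    actors = [shared] * NUM_ACTORS  # same weights reused by every actor
else:
    actors = [make_actor() for _ in range(NUM_ACTORS)]

state = rng.normal(size=STATE_DIM)
action, candidates, q_values = select_action(state, actors, phi)
print("chosen candidate:", int(np.argmax(q_values)), "Q-values:", np.round(q_values, 3))
```

Even with shared weights, the actors propose different actions because each one receives a different encoding of the earlier proposals.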

## RecSim

- **Agents**
    - Wolp

        ```
         python main.py --seed=1 --env_name=recsim-10k --method_name=wolp --prefix=recsim-extra-wolp --WOLP_topK=3 --env_dim_extra=5 --run_setup=exp
        ```
    
    - Wolp-Dual

        ```
         python main.py --seed=1 --env_name=recsim-10k --method_name=wolp_dual --prefix=recsim-extra-wolp_dual --WOLP_topK=3 --env_dim_extra=5 --run_setup=exp
        ```
    
    - Joint

        ```
         python main.py --seed=1 --env_name=recsim-10k --method_name flair_joint --prefix recsim-extra-flair_joint --WOLP_cascade_list_len 3 --run_setup=exp-no-video
        ```

    - FLAIR

        ```
         python main.py --seed=1 --env_name=recsim-10k --method_name=flair_inside --prefix=recsim-extra-flair_inside --WOLP_cascade_list_len=3 --env_dim_extra=5 --run_setup=exp
        ```
    
    - FLAIR - Len 1

        ```
         python main.py --seed=1 --env_name=recsim-10k --method_name flair_inside --prefix recsim-extra-new-no_num-taken-len1 --WOLP_cascade_list_len 1 --run_setup=exp-no-video --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --WOLP_if_0th_ref_critic=False --WOLP_if_ar_noise_before_cascade=True
        ```
    
    - FLAIR - No Linkage

        ```
         python main.py --seed=1 --env_name=recsim-10k --method_name flair_no_linkage --prefix recsim-extra-no_linkage --WOLP_cascade_list_len 3 --run_setup=exp-no-video --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --WOLP_if_0th_ref_critic=False --WOLP_if_ar_noise_before_cascade=True
        ```
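
The commands pass boolean options as the strings `True`/`False` and mix the `--flag value` and `--flag=value` forms. The sketch below shows how a standard `argparse` parser can handle both, assuming `main.py` is built on `argparse` with a string-to-bool converter. `str2bool`, `build_parser`, and the defaults shown are hypothetical and not taken from `large_rl`.

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert strings such as 'True'/'False' into real booleans."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "1"):
        return True
    if lowered in ("false", "f", "no", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of options; defaults are illustrative, not the repository's.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--method_name", type=str, default="wolp")
    parser.add_argument("--WOLP_cascade_list_len", type=int, default=3)
    # type=bool would be wrong here: bool("False") evaluates to True.
    parser.add_argument("--do_naive_eval", type=str2bool, default=True)
    return parser


parser = build_parser()
# argparse accepts both "--flag value" and "--flag=value" for long options.
args = parser.parse_args(
    ["--method_name", "flair_inside", "--WOLP_cascade_list_len", "1", "--do_naive_eval=False"]
)
print(args.method_name, args.WOLP_cascade_list_len, args.do_naive_eval)
# prints: flair_inside 1 False
```

With `argparse`'s default `store` action, an option given more than once takes the value of its last occurrence.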

## RecSim-data

- **Agents**
    - Wolp

        ```
         python main.py --seed=1 --env_name=recsim-data-rating5 --method_name=wolp --prefix=recsim-extra-wolp --WOLP_topK=3 --env_dim_extra=5 --run_setup=exp
        ```
    
    - Wolp-Dual

        ```
         python main.py --seed=1 --env_name=recsim-data-rating5 --method_name=wolp_dual --prefix=recsim-extra-wolp_dual --WOLP_topK=3 --env_dim_extra=5 --run_setup=exp
        ```
    
    - Joint

        ```
         python main.py --seed=1 --env_name=recsim-data-rating5 --method_name flair_joint --prefix recsim-extra-flair_joint --WOLP_cascade_list_len 3 --run_setup=exp-no-video
        ```

    - FLAIR

        ```
         python main.py --seed=1 --env_name=recsim-data-rating5 --method_name=flair_inside --prefix=recsim-extra-flair_inside --WOLP_cascade_list_len=3 --env_dim_extra=5 --run_setup=exp
        ```
    
    - FLAIR - Len 1

        ```
         python main.py --seed=1 --env_name=recsim-data-rating5 --method_name flair_inside --prefix recsim-extra-new-no_num-taken-len1 --WOLP_cascade_list_len 1 --run_setup=exp-no-video --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --WOLP_if_0th_ref_critic=False --WOLP_if_ar_noise_before_cascade=True
        ```
    
    - FLAIR - No Linkage

        ```
         python main.py --seed=1 --env_name=recsim-data-rating5 --method_name flair_no_linkage --prefix recsim-extra-no_linkage --WOLP_cascade_list_len 3 --run_setup=exp-no-video --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --WOLP_if_0th_ref_critic=False --WOLP_if_ar_noise_before_cascade=True
        ```
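
The Wolp baselines above pass `--WOLP_topK=3`. Presumably this refers to the Wolpertinger architecture (Dulac-Arnold et al., 2015): the actor emits a continuous proto-action, the `k` nearest discrete actions in embedding space are retrieved, and the critic picks the best of them. The NumPy sketch below illustrates that selection step only. It is not the repository's code; `wolp_select`, the toy critic, and the use of Euclidean distance for the neighbour search are assumptions.

```python
import numpy as np


def wolp_select(proto_action, action_embeddings, q_fn, top_k=3):
    """Wolpertinger-style selection of one discrete action (illustrative sketch).

    proto_action:      (d,) continuous output of the actor
    action_embeddings: (N, d) embedding of every discrete action
    q_fn:              maps a (k, d) array of embeddings to (k,) Q-values
    """
    # 1) Retrieve the top_k actions whose embeddings are closest (L2) to the proto-action.
    dists = np.linalg.norm(action_embeddings - proto_action, axis=1)
    top_k = min(top_k, len(action_embeddings))
    neighbours = np.argpartition(dists, top_k - 1)[:top_k]
    # 2) Let the critic choose among those candidates.
    q_values = q_fn(action_embeddings[neighbours])
    return int(neighbours[np.argmax(q_values)])


embeddings = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
proto = np.array([0.1, 0.1])
preferred = np.array([0.0, 1.0])  # toy critic: Q is highest near this point


def toy_q(candidates):
    return -np.linalg.norm(candidates - preferred, axis=1)


print(wolp_select(proto, embeddings, toy_q, top_k=1))  # 0: nearest neighbour only
print(wolp_select(proto, embeddings, toy_q, top_k=3))  # 2: critic overrides the actor
```

With `top_k=1` the agent simply takes the nearest neighbour of the proto-action; a larger `k` lets the critic correct an imperfect actor at the cost of more Q-evaluations per step.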

## Mine World

- **General Notes**
    - Set `mw_test_save_video` to True to generate videos during evaluation.
    - Videos are saved to `./videos`.

- **Agents**
    - Wolp

        ```
         python main.py --seed=1 --env_name=mine --method_name wolp --prefix mine-wolp --WOLP_topK 3 --run_setup=exp-no-video --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True
        ```
    
    - Wolp-Dual

        ```
         python main.py --seed=1 --env_name=mine --method_name wolp_dual --prefix mine-wolp_dual --WOLP_topK 3 --run_setup=exp-no-video --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True
        ```
    
    - Joint

        ```
         python main.py --seed=1 --env_name=mine --method_name flair_joint --prefix mine-flair_joint-len3 --WOLP_cascade_list_len 3 --run_setup=exp-no-video --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True
        ```

    - FLAIR

        ```
         python main.py --seed=1 --env_name=mine --method_name flair_inside --prefix mine-new-no_num-taken-noise_before_True --WOLP_if_ar_noise_before_cascade=True --WOLP_cascade_list_len 3 --WOLP_ar_actor_no_conditioning=True --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --run_setup=exp-no-video --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True
        ```
    
    - FLAIR - Len 1

        ```
         python main.py --seed=1 --env_name=mine --method_name flair_inside --prefix mine-new-no_num-taken-noise_before-len1 --WOLP_cascade_list_len 1 --WOLP_ar_actor_no_conditioning=True --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --run_setup=exp-no-video --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True
        ```
    
    - FLAIR - No Linkage

        ```
         python main.py --seed=1 --env_name=mine --method_name flair_no_linkage --prefix mine-no_linkage --WOLP_cascade_list_len 3 --WOLP_ar_actor_no_conditioning=True --WOLP_ar_critic_taken_action_update=True --do_naive_eval=False --run_setup=exp-no-video --mw_tool_size=100 --mw_mine_size=20 --mw_minRoomSize=17 --mw_maxRoomSize=17 --mw_grid_size=25 --mw_randomise_grid=True
        ```

## Continuous Environments (Box)
The commands below use Hopper as an example; to run on a different environment, change `--env_name`.

- **Agents**
    
    - Wolp-Dual

        ```
         python main.py --seed=1 --env_name=mujoco-hopper --method_name wolp_dual --prefix hopper-wolp_dual --WOLP_topK 3 --run_setup=exp-no-video --do_naive_eval=True --reacher_validity_type=box
        ```
    
    - Joint

        ```
         python main.py --seed=1 --env_name=mujoco-hopper --method_name flair_joint --prefix hopper-BOX-joint --WOLP_cascade_list_len 3 --run_setup=exp-no-video --WOLP_if_0th_ref_critic True --WOLP_if_ar_noise_before_cascade=False --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=True --reacher_validity_type=box
        ```

    - FLAIR

        ```
         python main.py --seed=1 --env_name=mujoco-hopper --method_name flair_inside --prefix hopper-BOX-new-no_num-taken --WOLP_cascade_list_len 3 --run_setup=exp-no-video --WOLP_if_0th_ref_critic True --WOLP_if_ar_noise_before_cascade=False --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=True --reacher_validity_type=box
        ```
    
    - FLAIR - Len 1

        ```
         python main.py --seed=1 --env_name=mujoco-hopper --method_name flair_inside --prefix hopper-BOX-new-no_num-taken-len1 --WOLP_cascade_list_len 1 --run_setup=exp-no-video --WOLP_if_0th_ref_critic True --WOLP_if_ar_noise_before_cascade=False --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=True --reacher_validity_type=box
        ```
    
    - FLAIR - No Linkage

        ```
         python main.py --seed=1 --env_name=mujoco-hopper --method_name flair_no_linkage --prefix hopper-BOX-no_linkage --WOLP_cascade_list_len 3 --run_setup=exp-no-video --WOLP_if_0th_ref_critic True --WOLP_if_ar_noise_before_cascade=False --WOLP_if_noise_postQ False --WOLP_ar_actor_no_conditioning=True --WOLP_list_concat_state=True --WOLP_ar_list_encoder_deepset_maxpool=True --WOLP_ar_value_loss_if_sum False --WOLP_ar_critic_scaled_num_updates=False --WOLP_ar_critic_taken_action_update=True --do_naive_eval=True --reacher_validity_type=box
        ```
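
To repeat a run over several seeds, or to swap `--env_name` as suggested above, a small launcher can assemble the commands programmatically. This is a hypothetical helper, not part of the repository. It assumes `main.py` accepts every option in the `--flag=value` form (true for `argparse` long options) and that it is run from the repository root.

```python
import shlex
import subprocess
import sys
from typing import Dict, Iterable, List


def build_command(env_name: str, method_name: str, prefix: str, seed: int,
                  extra: Dict[str, str]) -> List[str]:
    """Assemble one main.py invocation using the --flag=value form."""
    cmd = [sys.executable, "main.py",
           "--seed={}".format(seed),
           "--env_name={}".format(env_name),
           "--method_name={}".format(method_name),
           "--prefix={}".format(prefix)]
    cmd += ["--{}={}".format(key, value) for key, value in extra.items()]
    return cmd


def run_seeds(env_name: str, method_name: str, prefix: str, seeds: Iterable[int],
              extra: Dict[str, str], dry_run: bool = True) -> None:
    """Print (and, unless dry_run, launch) one run per seed, sequentially."""
    for seed in seeds:
        cmd = build_command(env_name, method_name, prefix, seed, extra)
        print(" ".join(shlex.quote(part) for part in cmd))  # shlex.join needs Python 3.8+
        if not dry_run:
            subprocess.run(cmd, check=True)


# Mirrors the Wolp-Dual Hopper command above, repeated over three seeds (dry run).
hopper_extra = {"WOLP_topK": "3", "run_setup": "exp-no-video",
                "do_naive_eval": "True", "reacher_validity_type": "box"}
run_seeds("mujoco-hopper", "wolp_dual", "hopper-wolp_dual", seeds=[1, 2, 3], extra=hopper_extra)
```

Passing `dry_run=False` launches the runs one after another. Whether the same `--prefix` should be reused across seeds depends on how the repository names its outputs, which this README does not specify.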



# Acknowledgement

- The Grid world environment is adapted from https://github.com/maximecb/gym-minigrid
- RecSim Simulator: https://github.com/google-research/recsim
