# Efficient Multi-Task Learning via Selective Behavior Sharing
This is a PyTorch implementation of Efficient Multi-Task Learning via Selective Behavior Sharing.
We propose a sample-efficient multi-task learning method, Q-switch Mixture of Policies (QMP), that selectively shares exploratory behaviors between tasks through a Q-function-based criterion, the Q-switch.

The ability to leverage shared behaviors between tasks is critical for sample-efficient multi-task reinforcement learning (MTRL). Prior approaches based on parameter sharing or policy distillation share behaviors uniformly across tasks and states, or focus on learning a single optimal policy. They are therefore fundamentally limited when tasks have conflicting behaviors, because no single optimal policy exists. Our key insight is that we can instead share exploratory behavior, which can be helpful even when the optimal behaviors differ. Furthermore, as we learn each task, we can guide exploration by sharing behaviors in a task- and state-dependent way. To this end, we propose a novel MTRL method, Q-switch Mixture of Policies (QMP), that learns to selectively share exploratory behavior between tasks by using a mixture of policies, based on estimated discounted returns, to gather training data. Experimental results on manipulation and locomotion tasks demonstrate that our method outperforms prior behavior-sharing methods, highlighting the importance of task- and state-dependent sharing.
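
The selection step described above can be illustrated with a minimal sketch. This is not the code in `learning/`; the function name `q_switch_action`, the type aliases, and the toy policies and Q-functions are hypothetical, and the sketch assumes an argmax selection rule (as the `--Qfilter argmax` option in the example commands suggests). Each task policy proposes one action, and the Q-function of the task being learned picks the proposal with the highest estimated return.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical type aliases for illustration only.
State = Sequence[float]
Action = float
Policy = Callable[[State], Action]
QFunction = Callable[[State, Action], float]


def q_switch_action(
    task_id: int,
    state: State,
    policies: List[Policy],
    q_functions: List[QFunction],
) -> Tuple[int, Action]:
    """Return (index of the chosen policy, chosen action) for `task_id`."""
    # Every task policy proposes one candidate action for the current state.
    candidates = [policy(state) for policy in policies]
    # Score all candidates with the Q-function of the task being learned.
    scores = [q_functions[task_id](state, action) for action in candidates]
    # Argmax over candidates; ties resolve to the lowest policy index.
    best = max(range(len(candidates)), key=scores.__getitem__)
    return best, candidates[best]


# Toy setup: two tasks with scalar actions (assumed values, not real results).
policies = [lambda s: 0.2, lambda s: 0.9]
q_functions = [
    lambda s, a: -(a - 1.0) ** 2,  # task 0 prefers actions near 1.0
    lambda s, a: -(a - 0.8) ** 2,  # task 1 prefers actions near 0.8
]

# Task 0 borrows the proposal of policy 1 because its own Q-function rates it higher.
chosen_policy, chosen_action = q_switch_action(0, [0.0], policies, q_functions)
```

In this toy case the policy for task 0 has not yet found a good action, so the Q-switch gathers data with the behavior proposed by the policy for task 1. The borrowed action is used only for data collection; each task still trains its own policy.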

## Directories
* `run.py`: takes arguments and initializes experiments
* `garage_experiments.py`: defines experiments and starts training
* `learning/`: contains all learning code, baseline implementations, and our method
* `environments/`: registers environments

## Dependencies
* Ubuntu 18.04 or above
* Python 3.7 or above
* [MuJoCo 2.1](https://github.com/deepmind/mujoco/releases)

## Installation

To install the Python dependencies:
```bash
pip install -r requirements.txt
```

## Example Commands

### Multistage Reacher
* Our Method
  ```bash
  python run.py hmop_dnc --env=JacoReachMT5-v1 --n_policies 5 --n_epochs=2000 --vis_num=5 --num_evaluation_episodes=50
  ```
* Separated
  ```bash
  python run.py dnc_sac --env=JacoReachMT5-v1 --n_policies 5 --n_epochs=2000 --vis_num=5 --num_evaluation_episodes=50 --kl_coeff 0
  ```
* Shared
  ```bash
  python run.py sac --env=JacoReachMT5-v1 --n_policies 5 --n_epochs=2000 --lr 0.0003 --qf_lr 0.0003 --alpha_lr 0.0003 --hidden_sizes 600 600
  ```
* DnC
  ```bash
  python run.py dnc_sac --env=JacoReachMT5-v1 --n_policies 5 --n_epochs=2000 --vis_num=5 --num_evaluation_episodes=50 --kl_coeff 0.001 --distillation_period 200 --distillation_n_epochs 500
  ```
* DnC (Reg)
  ```bash
  python run.py dnc_sac --env=JacoReachMT5-v1 --n_policies 5 --n_epochs=2000 --vis_num=5 --num_evaluation_episodes=50 --kl_coeff 0.001
  ```
  
* UDS
  ```bash
  python run.py cds_dnc --env=JacoReachMT5-v1 --n_policies 5 --n_epochs=2000 --vis_num=5 --num_evaluation_episodes=50 --kl_coeff 0 --unsupervised True --sharing_quantile=80
  ```

### Maze Navigation
* Our Method
  ```bash
  python run.py hmop_dnc --env=MazeLarge-10-v0 --n_policies 10 --batch_size 6000 --num_evaluation_episodes 100 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 1 --min_buffer_size 3000 --buffer_batch_size 256 --gradient_steps_per_itr 1000 --n_epochs 2000 --Qfilter argmax --resample 1 --mixture_warmup 0
  ```
* Separated
  ```bash
  python run.py dnc_sac --env MazeLarge-10-v0 --n_policies 10 --kl_coeff 0 --batch_size 6000 --num_evaluation_episodes 100 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 1 --min_buffer_size 3000 --buffer_batch_size 256 --gradient_steps_per_itr 1000 --n_epochs 2000
  ```
* Shared
  ```bash
  python run.py sac --env=MazeLarge-10-v0 --policy_architecture shared --Q_architecture shared --n_policies 10 --batch_size 6000 --num_evaluation_episodes 100 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 1 --min_buffer_size 3000 --buffer_batch_size 256 --lr 0.0003 --qf_lr 0.0003 --alpha_lr 0.0003 --gradient_steps_per_itr 1000 --n_epochs 2000 --hidden_sizes 832 832
  ```
* DnC
  ```bash
  python run.py dnc_sac --env=MazeLarge-10-v0 --n_policies 10 --batch_size 6000 --num_evaluation_episodes 100 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 1 --min_buffer_size 3000 --buffer_batch_size 256 --gradient_steps_per_itr 1000 --n_epochs 2000 --kl_coeff 0.001 --distillation_period 200 --distillation_n_epochs 500
  ```
* DnC (Reg)
  ```bash
  python run.py dnc_sac --env=MazeLarge-10-v0 --n_policies 10 --batch_size 6000 --num_evaluation_episodes 100 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 1 --min_buffer_size 3000 --buffer_batch_size 256 --gradient_steps_per_itr 1000 --n_epochs 2000 --kl_coeff 0.0001
  ```
* UDS
  ```bash
  python run.py cds_dnc --env=MazeLarge-10-v0 --n_policies 10 --batch_size 6000 --num_evaluation_episodes 100 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 1 --min_buffer_size 3000 --buffer_batch_size 256 --gradient_steps_per_itr 1000 --n_epochs 2000 --kl_coeff 0 --unsupervised True --sharing_quantile=0
  ```

### Meta-World Manipulation
* Our Method
  ```bash
  python run.py hmop_dnc --env=MetaWorldCDS-v1 --n_policies 4 --batch_size 2000 --num_evaluation_episodes 40 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 10 --min_buffer_size 5000 --buffer_batch_size 256 --lr 0.0015 --alpha_lr 0.0015 --qf_lr 0.0015 --gradient_steps_per_itr 200 --n_epochs 500 --Qfilter argmax
  ```
* Separated
  ```bash
  python run.py dnc_sac --env=MetaWorldCDS-v1 --n_policies 4 --kl_coeff 0.0 --batch_size 2000 --num_evaluation_episodes 40 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 10 --min_buffer_size 5000 --buffer_batch_size 256 --lr 0.0015 --alpha_lr 0.0015 --qf_lr 0.0015 --gradient_steps_per_itr 200 --n_epochs 500 
  ```
* Shared
  ```bash
  python run.py sac --env=MetaWorldCDS-v1 --policy_architecture=shared --Q_architecture=shared --n_policies 4 --batch_size 2000 --num_evaluation_episodes 40 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 10 --min_buffer_size 5000 --buffer_batch_size 256 --gradient_steps_per_itr 200 --n_epochs 500 --lr 0.0003 --qf_lr 0.0003 --alpha_lr 0.0003 --hidden_sizes 540 540
  ```
* DnC
  ```bash
  python run.py dnc_sac --env=MetaWorldCDS-v1 --n_policies 4 --batch_size 2000 --num_evaluation_episodes 40 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 10 --min_buffer_size 5000 --buffer_batch_size 256 --lr 0.0015 --alpha_lr 0.0015 --qf_lr 0.0015 --gradient_steps_per_itr 200 --n_epochs 500 --kl_coeff 0.1 --distillation_period 20 --distillation_n_epochs 500 
  ```
* DnC (Reg)
  ```bash
  python run.py dnc_sac --env=MetaWorldCDS-v1 --kl_coeff 0.1 --n_policies 4 --batch_size 2000 --num_evaluation_episodes 40 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 10 --min_buffer_size 5000 --buffer_batch_size 256 --lr 0.0015 --alpha_lr 0.0015 --qf_lr 0.0015 --gradient_steps_per_itr 200 --n_epochs 500 
  ```
  
* UDS
  ```bash
  python run.py cds_dnc --env=MetaWorldCDS-v1 --kl_coeff 0 --unsupervised True --sharing_quantile=80 --n_policies 4 --batch_size 2000 --num_evaluation_episodes 40 --vis_freq 50 --snapshot_gap 50 --steps_per_epoch 10 --min_buffer_size 5000 --buffer_batch_size 256 --lr 0.0015 --alpha_lr 0.0015 --qf_lr 0.0015 --gradient_steps_per_itr 200 --n_epochs 500 
  ```
  
## Quantitative Results
<p align="center">
    <img src="main_results.png">
</p>