# CoPDT

This is the repository for the paper *Adaptable Safe Policy Learning from Multi-task Data with Constraint Prioritized Decision Transformer*.

## Algorithm projects

### Installation Instructions

1. Create Environment
    
    First, create and activate the conda environment:
    
    ```bash
    conda create -n CoPDT python=3.8.19
    conda activate CoPDT
    ```
    
    Then install PyTorch and the other packages listed in `requirements.txt`:
    
    ```bash
    pip install -r requirements.txt
    ```
    
    After that, enter each of the `DSRL`, `FSRL`, and `OSRL` directories and install it in editable mode:
    
    ```bash
    pip install -e .
    ```
    
2. Download Datasets
    
    First, download the datasets for the original OSRL tasks from [http://data.offline-saferl.org/download](http://data.offline-saferl.org/download).
    
    Then, download the additional datasets we used from [https://drive.google.com/file/d/1RgEQttWxxaOFRqxlry1SO7INsawrhYuT/view?usp=sharing](https://drive.google.com/file/d/1RgEQttWxxaOFRqxlry1SO7INsawrhYuT/view?usp=sharing).
    
    Finally, move the `finetune_data` folder into the `OSRL` folder and put all other downloaded data into the `datasets` folder.
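
The installation steps above can be scripted roughly as follows. This is a minimal sketch and not part of the repository: `setup_copdt` is a hypothetical helper. It assumes you call it from the repository root, where `requirements.txt` and the `DSRL`, `FSRL`, and `OSRL` folders are located. It uses `conda run -n` instead of `conda activate`, because activation inside a script requires extra shell initialization. You still have to download and place the datasets by hand.

```bash
# Hypothetical helper (not shipped with the repo): scripted installation.
# Assumption: called from the repository root.
setup_copdt() (
    set -eu                          # stop at the first failing command
    env_name="${1:-CoPDT}"           # environment name, CoPDT by default
    conda create -y -n "$env_name" python=3.8.19
    # PyTorch and the other packages listed in requirements.txt
    conda run -n "$env_name" pip install -r requirements.txt
    # Editable installs of the three bundled packages
    for pkg in DSRL FSRL OSRL; do
        conda run -n "$env_name" pip install -e "./$pkg"
    done
)
```

After `setup_copdt` finishes, run `conda activate CoPDT` before the training commands below. You can pass a different environment name, for example `setup_copdt my_env`.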
    

### CoPDT Training and Deployment

First, enter the OSRL folder:

```bash
cd OSRL
```

In the single-constraint setting, first train the constraint-prioritized return-to-go (RTG) generator:

```bash
python examples/train/train_rtg_model.py --task $task --use_prompt False
```

where `$task` is the name of the target task.

Then train the simple Decision Transformer (DT) policy:

```bash
python examples/train/train_cdt.py --task $task --use_prompt False
```
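
The two training commands can be chained for one or more tasks. This is a minimal sketch, assuming the `CoPDT` environment is active and the function is called from the repository root. `train_single` is a hypothetical name, and the function uses only the scripts and flags shown above. Evaluation is left out because it requires editing `EvalConfig` first.

```bash
# Hypothetical helper: train the RTG generator and the DT policy per task.
# Assumption: called from the repository root with the CoPDT env active.
train_single() (
    set -eu
    if [ "$#" -eq 0 ]; then
        echo "usage: train_single TASK [TASK ...]" >&2
        exit 1
    fi
    cd OSRL
    for task in "$@"; do
        # 1) constraint-prioritized RTG generator
        python examples/train/train_rtg_model.py --task "$task" --use_prompt False
        # 2) simple DT policy
        python examples/train/train_cdt.py --task "$task" --use_prompt False
    done
)
```

For example, `train_single OfflineCarCircle-v0` trains both models for a single task. That task name only illustrates the DSRL/OSRL naming scheme, so check the training configs for the names this repository accepts. Because the function body runs in a subshell `( ... )`, `set -eu` and the `cd` do not affect your interactive shell, and the function stops at the first failed command.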

Finally, in the `EvalConfig` class in `examples/eval/eval_cdt.py`, set `rtg_model_path` to the path of the learned RTG generator and `path` to the path of the learned policy. Then run the evaluation:

```bash
python examples/eval/eval_cdt.py --conservative True
```

In the multi-constraint setting, first train the environment-specific state-action encoders:

```bash
python examples/train/train_sa_encoder_share.py --task_share $env_name
```

where `$env_name` is the name of the task's environment. See `examples/train/train_sa_encoder_share.py` for all valid environment names.

Then set the `state_encoder_paths` and `action_encoder_paths` fields of the `ContextEncoderTrainConfig` class in `examples/configs/context_encoder_configs.py` to the paths of the learned state and action encoders. After that, train the Constraint Prioritized Prompt Encoder:

```bash
python examples/train/train_context_encoder.py
```

Next, set the `context_encoder_path` field to the path of the learned prompt encoder in both the `CDTTrainConfig` class in `examples/configs/cdt_configs.py` and the `RTGTrainConfig` class in `examples/configs/rtg_configs.py`.

Finally, train the RTG generator and the policy:

```bash
python examples/train/train_mtrtg.py
python examples/train/train_mtcdt.py
```

Then fill in the corresponding paths in `examples/eval/eval_mtcdt.py` in the same way and run the evaluation:

```bash
python examples/eval/eval_mtcdt.py --conservative True
```
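
The multi-constraint pipeline alternates between training runs and manual config edits, so a stage-by-stage runner can help keep the steps in order. In the sketch below, the `mt_stage` function and its stage names are hypothetical. It assumes the `CoPDT` environment is active and that it is called from the repository root. It does not edit any config file for you.

```bash
# Hypothetical stage runner for the multi-constraint workflow.
# Assumption: called from the repository root with the CoPDT env active.
# The config edits described above must still be made by hand between stages.
mt_stage() (
    set -eu
    stage="${1:?usage: mt_stage sa|context|train|eval [ENV_NAME ...]}"
    shift
    cd OSRL
    case "$stage" in
        sa)
            # Environment-specific state-action encoders, one run per environment
            for env_name in "$@"; do
                python examples/train/train_sa_encoder_share.py --task_share "$env_name"
            done
            ;;
        context)
            # Prompt encoder; needs the encoder paths in context_encoder_configs.py
            python examples/train/train_context_encoder.py
            ;;
        train)
            # RTG generator, then policy; both need context_encoder_path set
            python examples/train/train_mtrtg.py
            python examples/train/train_mtcdt.py
            ;;
        eval)
            # Needs the paths filled in examples/eval/eval_mtcdt.py
            python examples/eval/eval_mtcdt.py --conservative True
            ;;
        *)
            echo "unknown stage: $stage" >&2
            exit 1
            ;;
    esac
)
```

A full run would proceed as follows, where `ENV_A` and `ENV_B` stand for valid environment names from `train_sa_encoder_share.py`:

1. Run `mt_stage sa ENV_A ENV_B`, then edit `context_encoder_configs.py`.
2. Run `mt_stage context`, then edit `cdt_configs.py` and `rtg_configs.py`.
3. Run `mt_stage train`, then edit `eval_mtcdt.py`.
4. Run `mt_stage eval`.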

### CoPDT Fine-tuning

We also provide code for adapting CoPDT to new tasks using either full fine-tuning or LoRA.

First, in both `examples/train/train_mtcdt_fft.py` and `examples/train/train_mtcdt_lora.py`, set the `path` variable in the `train` function to the path of the learned policy.

Then, one can use the following commands for fine-tuning:

```bash
python examples/train/train_mtcdt_fft.py --target_task $task_id
python examples/train/train_mtcdt_lora.py --target_task $task_id --lora_rank $rank
```

where `$task_id` selects one of the five new tasks we define and `$rank` is the LoRA rank. The task IDs are listed in `examples/train/train_mtcdt_fft.py` and `examples/train/train_mtcdt_lora.py`.
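
To compare full fine-tuning and LoRA across several target tasks and ranks, you can loop over the two commands. This sketch makes several assumptions. `finetune_sweep` and the `LORA_RANKS` variable are hypothetical, and the default rank of 8 is an arbitrary placeholder, not a recommended value. The function must be called from the repository root, and the `path` variable in both scripts must already point to the learned policy.

```bash
# Hypothetical sweep: full fine-tuning plus LoRA at each rank, per task ID.
# Assumptions: called from the repository root with the CoPDT env active;
# LORA_RANKS is a space-separated list, and its default of 8 is arbitrary.
finetune_sweep() (
    set -eu
    if [ "$#" -eq 0 ]; then
        echo "usage: finetune_sweep TASK_ID [TASK_ID ...]" >&2
        exit 1
    fi
    ranks="${LORA_RANKS:-8}"
    cd OSRL
    for task_id in "$@"; do
        python examples/train/train_mtcdt_fft.py --target_task "$task_id"
        for rank in $ranks; do    # unquoted on purpose: split the rank list
            python examples/train/train_mtcdt_lora.py --target_task "$task_id" --lora_rank "$rank"
        done
    done
)
```

For example, `LORA_RANKS="4 16" finetune_sweep ID_1 ID_2` runs one full fine-tuning job and two LoRA jobs for each task. Here `ID_1` and `ID_2` stand for task IDs defined in the fine-tuning scripts.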

## Note

The implementation is based on the open-source [OSRL](https://www.offline-saferl.org/) codebase.