# Stable Alignment

[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

This is the official repo for the Stable Alignment project. We aim to provide an RLHF alternative that offers better alignment performance, learns efficiently from data, and is easy to deploy at scale. Instead of training an extra reward model that can be gamed during optimization, we train directly on interaction data recorded in simulated social games. We find that high-quality data combined with a reliable algorithm is the recipe for stable alignment learning.

The repo contains:

- The code for [running social simulation in Sandbox](#sandbox-simulation).
- The [94K interaction data](#data-release) used for alignment training.
- The code for [training with stable alignment](#training-with-stable-alignment).

## Sandbox Simulation

### Installation

```bash
# install the required dependencies
pip install -r requirements.txt
# install this package in editable (development) mode
pip install -e .
```

### Data preparation and API key setup

- Initial data is stored at `assets/<dataset_name>/labeled_prior.jsonl` (tracked with Git LFS). Download the `jsonl` file and place it in the corresponding folder (e.g. `assets/hh-rlhf/`). \
  After a round of simulation, the simulated interaction data and metrics are saved in `data/cache/world_<world_id>/`.
- Place your OpenAI API key in `.env` inside the project root folder.
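
To sanity-check the downloaded file before starting a simulation, a minimal sketch like the one below may help. It assumes only that the file is standard JSON Lines (one JSON object per line). The helper name `load_jsonl` is hypothetical and not part of this repo, and no field names are assumed.

```python
import json
from pathlib import Path


def load_jsonl(path):
    """Read a JSON Lines file into a list of objects, skipping blank lines."""
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # An unfetched Git LFS pointer file is plain text, not JSON,
                # so it ends up here.
                raise ValueError(f"{path}:{line_no}: invalid JSON ({err})") from err
    return records


# Example usage; the path follows the layout described above.
prior_path = Path("assets/hh-rlhf/labeled_prior.jsonl")
if prior_path.exists():
    records = load_jsonl(prior_path)
    print(f"loaded {len(records)} records")
    if records:
        print("fields in first record:", sorted(records[0]))
```

If the check raises a `ValueError` on the first line, the file is probably still a Git LFS pointer rather than the actual data.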

### Run simulation

Navigate to the project root folder and run the simulation with your own settings. For example:

```bash
python stable_alignment/simulation.py -model_type 'text-davinci-002' -obs_model_type 'gpt-3.5-turbo' -world_id 1 -init_setting 'all_bad' -n_round '2' -size '4' -dataset_name 'hh-rlhf'
```

## Data Release

We release the 94K interaction data used for alignment training. The data is in the `assets` folder.

<details>
<summary> <strong> The Detailed Format of Alignment Data </strong> </summary>

Alignment Imitation

```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{Question}

### Response:
{Draft Response / Revised Response}
```

Self-Critic

```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{Question} Please comment on the response in the input in terms of how well it is aligned with shared social values.

### Input:
{Draft Answer}

### Response:
{Feedback}
```

Realignment

```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{Question} {Draft Response} Please first comment on this response, and then revise it to make it better aligned with shared social values.

### Response:
{Feedback} Here is the revised response: {Revised Response}
```

</details>
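
As an illustration of the layout above, the following sketch renders one sample from its parts. The function `render_sample` and the placeholder strings are hypothetical; how this repo actually stores or assembles these fields is not shown here, so treat this as a reading aid rather than the project's own code.

```python
PREAMBLE = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request."
)

CRITIC_SUFFIX = (
    "Please comment on the response in the input in terms of how well it "
    "is aligned with shared social values."
)


def render_sample(instruction, response, input_text=None):
    """Render one sample with the section headers used in the templates.

    The "### Input:" section is emitted only when input_text is given,
    as in the Self-Critic template; the other two templates omit it.
    """
    parts = [PREAMBLE, f"### Instruction:\n{instruction}"]
    if input_text is not None:
        parts.append(f"### Input:\n{input_text}")
    parts.append(f"### Response:\n{response}")
    return "\n\n".join(parts)


# A Self-Critic sample built from made-up placeholder strings.
sample = render_sample(
    instruction=f"How do I apologize to a friend? {CRITIC_SUFFIX}",
    response="The draft is brief and could show more empathy.",
    input_text="Just say sorry and move on.",
)
print(sample)
```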

Here are the statistics of the alignment data we collected from the simulated interactions in Sandbox. For training, we sample a mixture of Alignment Imitation, Self-Critic, and Realignment data at a ratio of 5:1:1, respectively. The sampled data is already included at `assets/sandbox_v1.json`.

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | ChatGPT | Total |
| ------------------------ | ---------------- | ---------------- | ------- | ----- |
| Alignment Imitation      | 9.8k             | 10k              | 10k     | 29.8k |
| Self-Critic              | 17k              | 20k              | 20k     | 57k   |
| Realignment              | 3.3k             | 3k               | 0.7k    | 7k    |
| Total                    | 30.1k            | 33k              | 30.7k   | 93.8k |
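
One plausible way to draw a 5:1:1 mixture from three pools is sketched below. The exact sampling procedure used to build `assets/sandbox_v1.json` is not described here, so the function `sample_mixture`, the fixed-total allocation, and the toy pools are all assumptions made for illustration.

```python
import random


def sample_mixture(pools, ratio, total, seed=0):
    """Draw up to `total` items from `pools` with counts following `ratio`.

    pools: dict mapping a data type to a list of samples
    ratio: dict mapping the same data types to positive integer weights
    Sampling is without replacement within each pool. Counts are floored,
    so fewer than `total` items may be returned when the weights do not
    divide `total` evenly.
    """
    rng = random.Random(seed)
    weight_sum = sum(ratio.values())
    mixture = []
    for name, weight in ratio.items():
        count = total * weight // weight_sum
        if count > len(pools[name]):
            raise ValueError(f"pool '{name}' has fewer than {count} samples")
        mixture.extend(rng.sample(pools[name], count))
    rng.shuffle(mixture)
    return mixture


# Toy pools standing in for the three data types.
pools = {
    "imitation": [("imitation", i) for i in range(100)],
    "self_critic": [("self_critic", i) for i in range(100)],
    "realignment": [("realignment", i) for i in range(100)],
}
ratio = {"imitation": 5, "self_critic": 1, "realignment": 1}

mixture = sample_mixture(pools, ratio, total=70)
counts = {name: sum(1 for kind, _ in mixture if kind == name) for name in pools}
print(counts)  # {'imitation': 50, 'self_critic': 10, 'realignment': 10}
```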

## Training with Stable Alignment

```bash
# Notes on the flags below (inline comments would break the line continuations):
#   --model_name_or_path            path to your SFT model
#   --data_path                     path to the alignment data
#   --per_device_train_batch_size   has to be 1 for alignment training
#   --fsdp                          change to "full_shard auto_wrap" if OOM
#   --model_max_length              change to a shorter length if OOM
#   --rating_scale                  the scale of the ratings: 7 for 1-7, 10 for 1-10, etc.
#   --margin                        constant, see the paper
#   --max_flow                      mean or max for the penalty
#   --ratio                         controls the ratio of the penalty
torchrun --nproc_per_node=4 --master_port=36646 train_alignment.py \
      --model_name_or_path "/workspace/hhh_sft" \
      --data_path "/assets/sandbox_v1.json" \
      --bf16 True \
      --output_dir "/workspace/<your_output_lm_name>" \
      --num_train_epochs 7 \
      --per_device_train_batch_size 1 \
      --per_device_eval_batch_size 1 \
      --gradient_accumulation_steps 8 \
      --evaluation_strategy "no" \
      --save_strategy "steps" \
      --save_steps 200 \
      --save_total_limit 1 \
      --learning_rate 2e-5 \
      --weight_decay 0. \
      --warmup_ratio 0.03 \
      --lr_scheduler_type "cosine" \
      --logging_steps 1 \
      --fsdp "shard_grad_op auto_wrap" \
      --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
      --tf32 True \
      --model_max_length 360 \
      --rating_scale 7 \
      --margin 10 \
      --max_flow False \
      --ratio 0.2 \
      --num_comp 3
```
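
The batch-size flags above interact, so it can help to work out the effective batch size. The sketch below assumes that `train_alignment.py` follows the usual Hugging Face `Trainer` conventions for these flags (the script is not shown here, so this is an assumption) and that training runs on a single node. The helper `effective_batch_size` is hypothetical.

```python
def effective_batch_size(num_processes, per_device_batch_size, grad_accum_steps):
    """Number of samples that contribute to one optimizer update."""
    return num_processes * per_device_batch_size * grad_accum_steps


# Values taken from the command above.
ebs = effective_batch_size(
    num_processes=4,          # --nproc_per_node
    per_device_batch_size=1,  # --per_device_train_batch_size
    grad_accum_steps=8,       # --gradient_accumulation_steps
)
print(ebs)  # 32

# With --save_steps 200, a checkpoint is written every 200 optimizer updates.
samples_per_checkpoint = ebs * 200
print(samples_per_checkpoint)  # 6400
```

If you change the number of GPUs, adjusting `--gradient_accumulation_steps` so that the product stays the same keeps the effective batch size unchanged.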
