# Augmented Intermediate Representations (AIR)

All experiments were run on a node with 8 x A100 (80 GB) GPUs.

## 0. Setup
a. uv venv
b. source .venv/bin/activate
c. uv pip install -r requirements.txt
d. Edit .venv/lib/python*/site-packages/alpaca_eval/evaluators_configs/alpaca_eval_vllm_llama3_70b_fn/configs.yaml and set:
    model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
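
To confirm that the edit in step d took effect, a small check along these lines may help. This is a sketch, not part of the repository: `check_model_name` and `EXPECTED_MODEL` are hypothetical names introduced here, and the check assumes `model_name` sits on its own line in the YAML file.

```shell
# Expected value, copied from step d above.
EXPECTED_MODEL='meta-llama/Meta-Llama-3-70B-Instruct'

# check_model_name FILE (hypothetical helper):
# succeeds if FILE contains a model_name line set to EXPECTED_MODEL.
check_model_name() {
  grep -Eq "^[[:space:]]*model_name:[[:space:]]*[\"']?${EXPECTED_MODEL}[\"']?[[:space:]]*$" "$1"
}

# The glob is left unquoted so the shell expands python* to the real directory.
for cfg in .venv/lib/python*/site-packages/alpaca_eval/evaluators_configs/alpaca_eval_vllm_llama3_70b_fn/configs.yaml; do
  if [ ! -f "$cfg" ]; then
    echo "config not found: $cfg"
  elif check_model_name "$cfg"; then
    echo "ok: $cfg"
  else
    echo "model_name not set as expected in: $cfg"
  fi
done
```

Run it from the repository root after step c, since the path is relative to `.venv`.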

## 1. Download, create training/eval datasets
python ./src/create_data.py
python ./src/get_sep.py

## 2. Non-Adversarial Instruction Tuning
(You may need to run "hf auth login" and enter your Hugging Face token to access the model.)
accelerate launch --config_file=./configs/deepspeed.yaml ./src/instr_tune.py --model=llama3.2_3b --lr=2e-5 --defense=air

## 3. Adversarial Robustness Training
For DPO with parameter-efficient fine-tuning (PEFT):
accelerate launch --config_file=./configs/fsdp.yaml ./src/align.py --model=llama3.2_3b --lr=2e-4 --defense=air --trainer=dpo --peft

For SFT with full fine-tuning (FFT):
accelerate launch --config_file=./configs/deepspeed.yaml ./src/align.py --model=llama3.2_3b --lr=2e-4 --defense=air --trainer=sft

## 4. Utility Evaluation
accelerate launch --config_file=./configs/ddp.yaml ./src/eval.py --model=llama3.2_3b --trainer=dpo --defense=air
./src/alpaca_eval.sh llama3.2_3b dpo air

Results can be found in ./exp/alpaca/llama3.2_3b/dpo/air/eval/alpaca_eval_vllm_llama3_70b_fn/leaderboard.csv
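
The results path above appears to follow the pattern `./exp/alpaca/<model>/<trainer>/<defense>/...`. That pattern is inferred from this one example and is not confirmed by the scripts, so treat the hypothetical helper `leaderboard_path` below as a convenience sketch for locating the file when you evaluate other configurations.

```shell
# leaderboard_path MODEL TRAINER DEFENSE (hypothetical helper):
# prints the leaderboard location, assuming the directory layout
# matches the single example path given above.
leaderboard_path() {
  printf './exp/alpaca/%s/%s/%s/eval/alpaca_eval_vllm_llama3_70b_fn/leaderboard.csv\n' "$1" "$2" "$3"
}

csv="$(leaderboard_path llama3.2_3b dpo air)"
if [ -f "$csv" ]; then
  cat "$csv"
else
  echo "no leaderboard yet at $csv"
fi
```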

## 5. Static Attack Evaluation
accelerate launch --config_file=./configs/ddp.yaml ./src/static_attack.py --model=llama3.2_3b --trainer=dpo --defense=air --attack=ignore

## 6. Compute sensitivity (needed for Astra attack)
python ./src/compute_sensitivity.py --model=llama3.2_3b --trainer=dpo --defense=air

## 7. GCG Attack
python ./src/grad_attack.py --model=llama3.2_3b --trainer=dpo --defense=air --n_adv_tok=100 --n_steps=200 --mode=gcg

## 8. Astra Attack
python ./src/grad_attack.py --model=llama3.2_3b --trainer=dpo --defense=air --n_adv_tok=100 --n_steps=200 --mode=astra

## 9. SEP Evaluation
accelerate launch --config_file=./configs/ddp.yaml ./src/eval_sep.py --model=llama3.2_3b --trainer=dpo --defense=air
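
Steps 4 through 9 all take the same model, trainer, and defense values, so they can be chained for one checkpoint. The wrapper below is a sketch and not part of the repository: `run`, `run_all_evals`, and the `DRY_RUN` switch are hypothetical names. By default it only prints the commands, which are copied from the steps above.

```shell
# Values mirror the flags used throughout this README; override via the environment.
MODEL="${MODEL:-llama3.2_3b}"
TRAINER="${TRAINER:-dpo}"
DEFENSE="${DEFENSE:-air}"
DRY_RUN="${DRY_RUN:-1}"   # hypothetical switch: 1 = print commands only, 0 = execute them

# run CMD...: echo the command, and execute it unless DRY_RUN is 1.
run() {
  echo "+ $*"
  if [ "$DRY_RUN" != "1" ]; then
    "$@"
  fi
}

run_all_evals() {
  common="--model=$MODEL --trainer=$TRAINER --defense=$DEFENSE"
  # $common is intentionally unquoted so it splits into three separate flags.
  run accelerate launch --config_file=./configs/ddp.yaml ./src/eval.py $common
  run ./src/alpaca_eval.sh "$MODEL" "$TRAINER" "$DEFENSE"
  run accelerate launch --config_file=./configs/ddp.yaml ./src/static_attack.py $common --attack=ignore
  run python ./src/compute_sensitivity.py $common
  run python ./src/grad_attack.py $common --n_adv_tok=100 --n_steps=200 --mode=gcg
  run python ./src/grad_attack.py $common --n_adv_tok=100 --n_steps=200 --mode=astra
  run accelerate launch --config_file=./configs/ddp.yaml ./src/eval_sep.py $common
}

run_all_evals
```

Set `DRY_RUN=0` to execute the steps. The sketch does not stop when a step fails; add `set -e` at the top of the script if you want it to abort on the first error.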
