#!/bin/bash

# Fail fast: abort on any failed command, any use of an unset variable, and
# any failed pipeline stage, so a broken setup step never reaches the launch.
set -euo pipefail

# Keep Weights & Biases in offline mode for this run.
export WANDB_MODE=offline
# Pick a rendezvous port in [29000, 29999] at random so that concurrent runs on
# the same host are less likely to collide on MASTER_PORT.
export MASTER_PORT=$((29000 + RANDOM % 1000))
# NOTE(review): presumably set so cuBLAS can run deterministically under
# PyTorch's deterministic mode — confirm against the training code.
export CUBLAS_WORKSPACE_CONFIG=:16:8

# ---------------------------------------------------------------------------
# Model configuration. Keep exactly one config active; swap the comment
# markers below to switch between models.
# ---------------------------------------------------------------------------

# Mistral (inactive):
# model_name_or_path=mistralai/Mistral-7B-Instruct-v0.2
# user_tag="[INST]"
# assistant_tag="[/INST]"
# lorra_alpha=5

# Llama 3 (active):
model_name_or_path=/data/public_models/huggingface/meta-llama/Meta-Llama-3-8B-Instruct
user_tag='<|start_header_id|>user<|end_header_id|>'
assistant_tag='<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
lorra_alpha=10

# Layer selections forwarded to the trainer.
layers='10,20'
transform_layers='-1'

output_dir='./out/llama3_center_constant_rmu_50'
# NOTE(review): valsets is not referenced by the launch command in this
# script — confirm whether it should be passed to the trainer.
valsets='harmful_behaviors#harmless_behaviors#arc-c'

# Print the resolved settings (name=value) so they appear in the run log.
for setting in model_name_or_path user_tag assistant_tag output_dir; do
    printf '%s=%s\n' "$setting" "${!setting}"
done

# Move to the repository root: configs/, ds_hostfile and src_short_circuiting/
# below are relative paths from there. This assumes the script is invoked from
# its own directory, one level below the repo root. Abort if the cd fails
# rather than launching with paths resolved against the wrong directory.
cd .. || { printf 'error: cannot cd to parent directory\n' >&2; exit 1; }

# Every expansion is double-quoted so each value reaches the trainer as exactly
# one argument. This prevents word-splitting on whitespace and pathname
# expansion: e.g. the alternative Mistral tag "[INST]" is a glob bracket pattern
# that, unquoted, could be replaced by a matching filename in the cwd.
# NOTE(review): --eval_steps (1000) exceeds --max_steps (150); if evaluation is
# step-triggered it will never run mid-training — confirm this is intended.
accelerate launch --config_file configs/accelerate_zero1.yaml \
    --num_processes 1 --main_process_port "$MASTER_PORT" --deepspeed_hostfile ds_hostfile \
    src_short_circuiting/lorra_short_circuiting.py \
    --model_name_or_path "$model_name_or_path" \
    --user_tag "$user_tag" \
    --assistant_tag "$assistant_tag" \
    --target_layers "$layers" \
    --transform_layers "$transform_layers" \
    --lorra_alpha "$lorra_alpha" \
    --lora_r 16 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --output_dir "$output_dir" \
    --overwrite_output_dir \
    --max_steps 150 \
    --bf16 True \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 1 \
    --do_eval \
    --evaluation_strategy "steps" \
    --eval_steps 1000 \
    --save_total_limit 0 \
    --learning_rate 1e-4 \
    --weight_decay 0. \
    --lr_scheduler_type "constant" \
    --logging_strategy "steps" \
    --logging_steps 10 \
    --tf32 True \
    --model_max_length 8192 \
    --q_lora False \
    --gradient_checkpointing True \
    --report_to none \
    --log_every 1 \
    --sc_loss_type "center_constant_rmu_50"