# Launch align.py through torch.distributed.run as a single local process
# (--nproc_per_node=1, rendezvous port 30425) to run "dpo" alignment with the
# "NaiveCompletion" attack on Meta-Llama-3-8B-Instruct.
#
# Settings passed below, in brief:
#   - Training data: data/training/alpaca_data_cleaned.json
#   - Saving: strategy "steps" with --save_steps 1 and --save_only_model True;
#     evaluation strategy is "no"; logging every step.
#   - Precision: fp16 on, bf16 and tf32 off.
#   - Batch: 8 per device with 8 gradient-accumulation steps.
#   - Schedule: cosine LR starting from 1.6e-4, 3 epochs.
#   - FSDP: "full_shard auto_wrap", wrapping LlamaDecoderLayer.
#   - Output directory: "<model>_<alignment>__<attack>_2025", a relative path.
#
# NOTE(review): --save_steps 1 on an 8B model looks like a debugging value;
# confirm it is intended before a long run.

model="meta-llama/Meta-Llama-3-8B-Instruct"
alignment="dpo"
attack="NaiveCompletion"

python -m torch.distributed.run --nproc_per_node=1 --master_port=30425 align.py \
  --model_name_or_path "${model}" \
  --data_path data/training/alpaca_data_cleaned.json \
  --cache_dir /checkpoint_gcg/CacheDir \
  --evaluation_strategy no \
  --save_strategy steps \
  --save_steps 1 \
  --logging_steps 1 \
  --per_device_train_batch_size 8 \
  --learning_rate 1.6e-4 \
  --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
  --fsdp 'full_shard auto_wrap' \
  --lr_scheduler_type cosine \
  --gradient_accumulation_steps 8 \
  --output_dir "${model}_${alignment}__${attack}_2025" \
  --num_train_epochs 3 \
  --attack "${attack}" \
  --alignment "${alignment}" \
  --bf16 False \
  --fp16 True \
  --tf32 False \
  --gradient_checkpointing True \
  --model_max_length 512 \
  --save_only_model True