#!/bin/bash
#SBATCH --job-name=USW-GAMMA-alpha0.1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=10
#SBATCH --mem-per-cpu=5000
#SBATCH --gres=gpu:1
#SBATCH --time=2-00:00:00
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --mail-user=
#SBATCH --output=logs/%x.out
#SBATCH --error=logs/%x.err
#SBATCH --partition=A100

output_dir='output/stage5_fn_2epoch_alpha0.1'
mkdir -p $output_dir
cp "$0" ${output_dir}/$(date +"%Y-%m-%d-%H-%M-%S").sh

conda init bash
source ~/.bashrc

python finetune.py \
    --base_model 'ckpt/stage4_epoch2/pytorch_model.bin' \
    --data_path 'CompA-R.json' \
    --output_dir $output_dir \
    --batch_size 164 \
    --micro_batch_size 1 \
    --num_epochs 2 \
    --learning_rate 1e-4 \
    --cutoff_len 108 \
    --val_set_size 0 \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --lora_target_modules '[q_proj,v_proj]' \
    --train_on_inputs \
    --group_by_length \
    --wandb_run_name ${output_dir} \
    --save_steps 100 \
    --trainable_params qformer_all

# torchrun --nproc_per_node=2 --master_port=1234 finetune.py \
#     --base_model 'ckpt/stage4_epoch2/pytorch_model.bin' \
#     --data_path 'CompA-R.json' \
#     --output_dir $output_dir \
#     --batch_size 164 \
#     --micro_batch_size 2 \
#     --num_epochs 1 \
#     --learning_rate 1e-4 \
#     --cutoff_len 108 \
#     --val_set_size 0 \
#     --lora_r 8 \
#     --lora_alpha 16 \
#     --lora_dropout 0.05 \
#     --lora_target_modules '[q_proj,v_proj]' \
#     --train_on_inputs \
#     --group_by_length \
#     --wandb_run_name ${output_dir} \
#     --save_steps 100 \
#     --trainable_params qformer_all
