#!/usr/bin/env bash
#
# Fine-tune and then test a pretrained checkpoint on the cancer_liver_low_risk
# task, once per random seed.
#
# Only -u (unset-variable check) is enabled, not -e: a failure on one seed
# should be reported without aborting the remaining seeds.
set -u

# --- Model / data -----------------------------------------------------------
readonly FLOPS=1e18
readonly MODEL_NAME=pythia-160m
readonly TRAIN_DATA_FILE=/data/cancer_liver_low_risk/train.pkl
readonly VALID_DATA_FILE=/data/cancer_liver_low_risk/valid.pkl
readonly TEST_DATA_FILE=/data/cancer_liver_low_risk/test.pkl

# --- Checkpoint locations ---------------------------------------------------
# Fine-tuned outputs are grouped by pretraining FLOP budget and model name.
readonly OUTPUT_DIR=/data/finetuned_ckpt/flop_${FLOPS}/${MODEL_NAME}/cancer_liver_low_risk
readonly PRETRAINED_MODEL=/data/pretrained_ckpt/flop_${FLOPS}/${MODEL_NAME}/step-4800-ckpt-converted/lit_model.pth

# --- Batch sizes (per device) -----------------------------------------------
readonly TRAIN_BATCH_SIZE=16
readonly EVAL_BATCH_SIZE=16

# --- Optimisation hyperparameters -------------------------------------------
readonly LEARNING_RATE=1e-5
readonly WARMUP_RATIO=0.1
readonly DECAY_RATIO=0.1
readonly WEIGHT_DECAY=0.01
readonly ACCUM_STEP=1
readonly EPOCH=5
readonly NUM_WORKER=8
readonly GRAD_NORM=1.0

# --- Logging / evaluation / hardware ----------------------------------------
readonly LOG_STEP=1
readonly NUM_EVAL=2
readonly DEVICES=8


# Hyperparameter flags passed verbatim to both the training and the test
# invocation, so the two runs cannot silently drift apart.
common_args=(
  --per_device_train_batch_size "$TRAIN_BATCH_SIZE"
  --per_device_eval_batch_size "$EVAL_BATCH_SIZE"
  --learning_rate "$LEARNING_RATE"
  --warmup_ratio "$WARMUP_RATIO"
  --decay_ratio "$DECAY_RATIO"
  --weight_decay "$WEIGHT_DECAY"
  --max_gradient_norm "$GRAD_NORM"
  --gradient_accumulation_steps "$ACCUM_STEP"
  --num_train_epochs "$EPOCH"
  --dataloader_num_workers "$NUM_WORKER"
  --logging_steps "$LOG_STEP"
  --num_eval_per_epoch "$NUM_EVAL"
)

# Runs that failed, recorded as "train:<seed>" / "test:<seed>".
failed_seeds=()

for seed in 28 35 42 49 56; do
  echo "start training on seed $seed"

  export WANDB_NAME="Foundation Model Finetuning ${MODEL_NAME} flop=${FLOPS} cancer_liver on seed $seed low-risk "

  # Multi-device training via srun, validating on VALID_DATA_FILE.
  if ! srun python3 finetune/full_code.py \
    --model_name "$MODEL_NAME" \
    --train_data_file "$TRAIN_DATA_FILE" \
    --eval_data_file "$VALID_DATA_FILE" \
    --pretrained_model_path "$PRETRAINED_MODEL" \
    --output_dir "$OUTPUT_DIR" \
    "${common_args[@]}" \
    --devices "$DEVICES" \
    --is_test False \
    --seed "$seed"; then
    # Testing would evaluate a stale or missing checkpoint, so skip it.
    echo "training failed on seed $seed; skipping test" >&2
    failed_seeds+=("train:$seed")
    continue
  fi

  echo "start testing on seed $seed"

  # Single-device evaluation on TEST_DATA_FILE with the same output_dir/seed.
  if ! python3 finetune/full_code.py \
    --model_name "$MODEL_NAME" \
    --train_data_file "$TRAIN_DATA_FILE" \
    --eval_data_file "$TEST_DATA_FILE" \
    --output_dir "$OUTPUT_DIR" \
    "${common_args[@]}" \
    --devices 1 \
    --is_test True \
    --seed "$seed"; then
    echo "testing failed on seed $seed" >&2
    failed_seeds+=("test:$seed")
  fi
done

# Surface any per-seed failure in the script's exit status instead of only
# the status of the final test run.
if (( ${#failed_seeds[@]} > 0 )); then
  echo "failed runs: ${failed_seeds[*]}" >&2
  exit 1
fi