#!/bin/bash
#
# Launches a two-GPU fine-tuning run of src/train.py through `accelerate launch`,
# using the finetune/restor/default.yaml experiment and the RESTOR training data.
#
# Usage:
#   <this script> <model>
#     model in {Llama-3.1-8B-Instruct, Olmo-2-7B-stage1-final}
#
# Requirements (invoked by this script): `python` and `accelerate` on PATH.
# The run is pinned to GPUs 0 and 1 via CUDA_VISIBLE_DEVICES.

# Abort on command failure, on use of an unset variable, and on a failure in
# any stage of a pipeline.
set -euo pipefail

# Exactly one positional argument (the model name) is required; extra
# arguments are ignored. Usage text is a diagnostic, so it goes to stderr.
if [[ $# -lt 1 ]]; then
    echo "Usage: $0 <model>" >&2
    echo "  model in {Llama-3.1-8B-Instruct, Olmo-2-7B-stage1-final}" >&2
    exit 1
fi

# The model name is the first positional argument; its presence is guaranteed
# by the argument-count check earlier in the script.
model="$1"

# Reject anything outside the supported set before any expensive work starts.
# Error text goes to stderr so it is not mistaken for regular output.
case "${model}" in
    "Llama-3.1-8B-Instruct" | \
    "Olmo-2-7B-stage1-final")
        # Supported model: nothing to do, fall through to the launch.
        ;;
    *)
        echo "Unsupported model: ${model}" >&2
        echo "Supported models: Llama-3.1-8B-Instruct, Olmo-2-7B-stage1-final" >&2
        exit 1
        ;;
esac

# Ask the OS for a currently free TCP port by binding a throwaway socket to
# port 0 and reading back the port number that was assigned.
#
# The assignment is kept separate from `export`: `export VAR=$(cmd)` returns
# the status of `export`, which hides a failing `cmd` and would let the script
# continue with an empty MASTER_PORT.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()") || {
    echo "Failed to obtain a free port for MASTER_PORT (is python available?)" >&2
    exit 1
}
export MASTER_PORT
echo "Master Port: $MASTER_PORT"


# Training hyperparameters forwarded to the trainer below.
# 4 samples per device x 8 gradient_accumulation_steps = 32 samples per GPU
# per optimizer step.
# NOTE(review): the original comment said "effective batch size 32 on two
# GPUs"; if both GPUs take part in data-parallel training the global batch
# would be 4 x 8 x 2 = 64. The process count comes from the accelerate config,
# which is not visible here - confirm which figure is intended.
per_device_train_batch_size=4
gradient_accumulation_steps=8
finetune_epoch=5

# Launch training on GPUs 0 and 1. MASTER_PORT is the free port chosen earlier
# in the script; every key=value after src/train.py is passed to the training
# script as a configuration override. Expansions are quoted so each override
# always reaches the program as exactly one argument.
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file configs/accelerate/default_config.yaml \
    --main_process_port "${MASTER_PORT}" \
    src/train.py experiment=finetune/restor/default.yaml \
    task_name="RESTOR_${model}" \
    model="${model}" \
    data/datasets@data.train=RESTOR \
    trainer.args.per_device_train_batch_size="${per_device_train_batch_size}" \
    trainer.args.ddp_find_unused_parameters=true \
    trainer.args.gradient_checkpointing=true \
    trainer.args.gradient_accumulation_steps="${gradient_accumulation_steps}" \
    trainer.args.num_train_epochs="${finetune_epoch}" \
    trainer.args.eval_strategy="no"
