#!/bin/bash

set -euo pipefail

# Expect one required argument (model name) and one optional argument
# (path to an MSA base checkpoint). Usage is an error path, so print to stderr.
if (( $# < 1 || $# > 2 )); then
    {
        echo "Usage: $0 <model> [path_to_msa_base_model]"
        echo "  model in {Llama-3.1-8B-Instruct, Llama-3.1-8B, Olmo-2-7B-stage1-{final,3859B,3691B,2207B,500B}}"
    } >&2
    exit 1
fi

model="$1"
path_to_msa_base_model="${2:-}"

# Models accepted by this script. The name is forwarded verbatim as the
# `model=` override to src/train.py, so it presumably has to match a model
# config there (TODO confirm). A single list drives both the membership check
# and the error message, so the two cannot drift apart.
supported_models=(
    "Llama-3.1-8B-Instruct"
    "Llama-3.1-8B"
    "Olmo-2-7B-stage1-final"
    "Olmo-2-7B-stage1-3859B"
    "Olmo-2-7B-stage1-3691B"
    "Olmo-2-7B-stage1-2207B"
    "Olmo-2-7B-stage1-500B"
)

model_is_supported=false
for candidate in "${supported_models[@]}"; do
    # Quoted RHS => literal string comparison, not a glob match.
    if [[ "${model}" == "${candidate}" ]]; then
        model_is_supported=true
        break
    fi
done

if [[ "${model_is_supported}" != true ]]; then
    # Join the list with ", " and drop the trailing separator.
    printf -v supported_list '%s, ' "${supported_models[@]}"
    {
        echo "Unsupported model: ${model}"
        echo "Supported models: ${supported_list%, }"
    } >&2
    exit 1
fi

# Ask the OS for a currently free TCP port (bind to port 0) and use it as the
# accelerate main-process port, instead of a fixed port that may be in use.
# The assignment is kept separate from `export`: with `export VAR=$(cmd)` the
# export's own success status masks a failing python call, and `set -e` would
# not abort the script.
# NOTE(review): the probe socket is closed before accelerate binds the port,
# so another process could take it in between. This is unlikely but possible.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# Training hyperparameters.
# Effective (global) batch size =
#   per_device_train_batch_size * num_GPUs * gradient_accumulation_steps
#   = 4 * 2 * 8 = 64 with the two GPUs exposed below.
# NOTE(review): this assumes the accelerate config launches one process per
# visible GPU. If a global batch of 32 is intended, halve one of the factors.
per_device_train_batch_size=4
gradient_accumulation_steps=8
finetune_epoch=3

# Run name. A trailing "_RESTOR" marks runs that start from the MSA base
# checkpoint passed as the optional second argument.
if [[ -n "${path_to_msa_base_model}" ]]; then
    task_name="msa_RESTOR_${model}_RESTOR"
else
    task_name="msa_RESTOR_${model}"
fi

printf '%s\n' "${task_name}"

# Restrict training to the first two GPUs.
export CUDA_VISIBLE_DEVICES=0,1

# Build the launch command as an array so each argument stays a single word.
# Every expansion is quoted: unquoted expansions inside an array literal are
# word-split and glob-expanded. Everything after src/train.py is passed to the
# training script as key=value overrides (Hydra-style, by the look of them).
cmd=(
    accelerate
    launch
    --config_file configs/accelerate/default_config.yaml
    --main_process_port "${MASTER_PORT}"
    src/train.py
    experiment=finetune/restor/default.yaml
    "task_name=${task_name}"
    "model=${model}"
    data/datasets@data.train=RESTOR
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
    "trainer.args.num_train_epochs=${finetune_epoch}"
    # The shell strips these quotes; the script receives eval_strategy=no.
    trainer.args.eval_strategy="no"
)

# Optionally start from an MSA base checkpoint instead of the model's default
# weights.
# NOTE(review): paths containing characters special to the override grammar
# (spaces, '=', quotes, brackets) may need extra escaping. Confirm if needed.
if [[ -n "${path_to_msa_base_model}" ]]; then
    cmd+=("model.model_args.pretrained_model_name_or_path=${path_to_msa_base_model}")
fi

"${cmd[@]}"
