#!/bin/bash

set -euo pipefail

# Pick a free TCP port for accelerate's --main_process_port: bind to port 0,
# read back the port the OS assigned, then release it. (The port is free again
# after close(), so another process could grab it before accelerate starts.)
# Assignment and export are split so a failing python call is not masked by
# `export`'s own exit status and `set -e` can abort the script (SC2155).
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

############################################
# Defaults (OLMo stage-1 models)
############################################
per_device_train_batch_size=4
forget_split="forget10"
retain_split="retain90"
holdout_split="holdout10"
msa_epochs=5
train_devices="0,1"
accelerate_config="configs/accelerate/default_config.yaml"

# Default checkpoint: Hydra model config, checkpoint name, and the directory
# holding its TOFU finetune.
model="Olmo-2-7B"
model_name="Olmo-2-7B-stage1-final"
path_to_tofu_model="saves/finetune/tofu_${model_name}_full"
use_constant_lr=0
tag="${model_name}_tofu"

# Checkpoints to process; overridable with --model_names. Defaults to the
# single checkpoint above (this is what the help text reports as the default).
model_names=("${model_name}")

#######################################
# Print the CLI help text (with current default values) to stdout.
# Globals (read): forget_split, retain_split, holdout_split, model_names,
#   model_name, msa_epochs, per_device_train_batch_size, accelerate_config,
#   train_devices
#######################################
usage() {
  # ${model_names[*]:-...} falls back to the single default checkpoint when
  # model_names is unset/empty, so --help also works under `set -u`.
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --model_names "NAMES"        Space/comma-separated OLMo model names (default: ${model_names[*]:-${model_name}})
  --epochs N                   Number of finetune epochs (default: ${msa_epochs})
  --batch_size N               Per-device batch size (default: ${per_device_train_batch_size})
  --accelerate_config PATH     Accelerate config file (default: ${accelerate_config})
  --train_devices DEVICES      CUDA devices (default: ${train_devices})
  --use_constant_lr            Use constant LR scheduler and suffix task_name
  -h | --help                  Show this help

Examples:
  # Run for a subset of OLMo stage-1 checkpoints
  ./scripts/experiments/Olmo/msa_finetune.sh \\
    --model_names "Olmo-2-1B-stage1-210B,Olmo-2-1B-stage1-final" \\
    --epochs 5 --batch_size 4
EOF
}

#######################################
# Parse command-line options into the global settings declared above.
# Globals (written): forget_split, retain_split, holdout_split, model_names,
#   msa_epochs, per_device_train_batch_size, accelerate_config,
#   train_devices, use_constant_lr
# Arguments: the script's command-line arguments
# Exits: 0 after -h/--help; 1 (with usage on stderr) on an unknown option,
#   an option missing its value, or an empty --model_names list.
#######################################
parse_args() {
  local opt val
  while [[ $# -gt 0 ]]; do
    opt="$1"
    case "${opt}" in
      --use_constant_lr)
        use_constant_lr=1
        shift
        continue
        ;;
      -h|--help)
        usage
        exit 0
        ;;
      --forget_split|--retain_split|--holdout_split|--model_names|--epochs|--batch_size|--accelerate_config|--train_devices)
        ;;
      *)
        echo "Unknown arg: ${opt}" >&2
        usage >&2
        exit 1
        ;;
    esac

    # Every remaining option takes exactly one value; check explicitly so a
    # trailing flag gives a clear message instead of an "unbound variable".
    if [[ $# -lt 2 ]]; then
      echo "Missing value for ${opt}" >&2
      usage >&2
      exit 1
    fi
    val="$2"
    shift 2

    case "${opt}" in
      --forget_split) forget_split="${val}" ;;
      --retain_split) retain_split="${val}" ;;
      --holdout_split) holdout_split="${val}" ;;
      --epochs) msa_epochs="${val}" ;;
      --batch_size) per_device_train_batch_size="${val}" ;;
      --accelerate_config) accelerate_config="${val}" ;;
      --train_devices) train_devices="${val}" ;;
      --model_names)
        # Accept comma- and/or whitespace-separated lists.
        read -r -a model_names <<< "${val//,/ }"
        if [[ ${#model_names[@]} -eq 0 ]]; then
          echo "--model_names needs at least one model name" >&2
          exit 1
        fi
        ;;
    esac
  done
}

parse_args "$@"

echo "Finetuning MSA components for OLMo models | epochs=$msa_epochs | batch_size=$per_device_train_batch_size"

# task_name suffix that keeps constant-LR runs apart from default-schedule
# runs. Computed once; it is the same for every model.
task_suffix=""
if [[ "${use_constant_lr}" -eq 1 ]]; then
  task_suffix="_constant_lr"
fi

# --model_names is optional; without it (or with an empty list) only the
# default checkpoint configured above is finetuned.
if [[ -z "${model_names+x}" ]]; then
  model_names=("${model_name}")
fi

for run_name in "${model_names[@]}"; do
  if [[ "${run_name}" == "${model_name}" ]]; then
    # Default checkpoint: use the settings configured above verbatim.
    run_model="${model}"
    run_path="${path_to_tofu_model}"
    run_tag="${tag}"
  else
    # Other checkpoints follow the same naming as the default, e.g.
    #   Olmo-2-1B-stage1-210B -> model config Olmo-2-1B,
    #   path saves/finetune/tofu_Olmo-2-1B-stage1-210B_full,
    #   tag Olmo-2-1B-stage1-210B_tofu.
    # NOTE(review): assumes a Hydra model config named after the part before
    # "-stage1" exists and the TOFU finetune was saved at this path — confirm.
    run_model="${run_name%%-stage1*}"
    run_path="saves/finetune/tofu_${run_name}_full"
    run_tag="${run_name}_tofu"
  fi

  echo "[OLMo] Finetune MSA for ${run_name}"

  # Forget MSA (currently disabled; --forget_split has no effect until this
  # command is re-enabled).
  # cmd_f=(accelerate launch --config_file "${accelerate_config}" --main_process_port "${MASTER_PORT}"
  #   src/train.py experiment=finetune/msa/forget.yaml
  #   "task_name=msa_${run_tag}_${forget_split}${task_suffix}"
  #   "model=${run_model}"
  #   data/datasets@data.train=TOFU_QA_forget
  #   "data.train.TOFU_QA_forget.args.hf_args.name=${forget_split}"
  #   "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
  #   trainer.args.ddp_find_unused_parameters=true
  #   trainer.args.gradient_checkpointing=true
  #   "trainer.args.num_train_epochs=${msa_epochs}"
  #   "model.model_args.pretrained_model_name_or_path=${run_path}"
  #   trainer.args.eval_strategy=no)
  # if [[ "${use_constant_lr}" -eq 1 ]]; then
  #   cmd_f+=(trainer.args.lr_scheduler_type=constant)
  # fi
  # CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd_f[@]}"

  # Retain MSA. Built as an array so every override reaches train.py as a
  # single argv word even if a value contains spaces.
  cmd_r=(accelerate launch --config_file "${accelerate_config}" --main_process_port "${MASTER_PORT}"
    src/train.py experiment=finetune/msa/retain.yaml
    "task_name=msa_${run_tag}_${retain_split}${task_suffix}"
    "model=${run_model}"
    data/datasets@data.train=TOFU_QA_retain
    "data.train.TOFU_QA_retain.args.hf_args.name=${retain_split}"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.num_train_epochs=${msa_epochs}"
    "model.model_args.pretrained_model_name_or_path=${run_path}"
    trainer.args.eval_strategy=no)
  if [[ "${use_constant_lr}" -eq 1 ]]; then
    cmd_r+=(trainer.args.lr_scheduler_type=constant)
  fi
  CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd_r[@]}"
done

# How to run (cheatsheet):
#   ./scripts/experiments/Olmo/msa_finetune.sh
#   ./scripts/experiments/Olmo/msa_finetune.sh --model_names "Olmo-2-1B-stage1-210B,Olmo-2-1B-stage1-final" \
#       --epochs 3 --batch_size 2 --train_devices 0,1
