#!/bin/bash
#
# Finetune the MSA forget and retain components on MUSE Books with
# `accelerate launch src/train.py`, once per requested mode:
#   target     - start from the checkpoint given by --muse_model
#   pretrained - start from meta-llama/<model_name>
# Run with --help for the full option list.

set -euo pipefail

# Pick a free TCP port for accelerate's main process: bind to port 0 so the
# OS assigns one, print it, and release it. Prefer `python`, fall back to
# `python3`. Assignment and export are separate so a failing probe is not
# masked by `export`'s own exit status (ShellCheck SC2155).
py_bin=$(command -v python || command -v python3) || {
  echo "error: neither python nor python3 found on PATH" >&2
  exit 1
}
MASTER_PORT=$("$py_bin" -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# Defaults (overridable from the command line, see usage)
per_device_train_batch_size=4
gradient_accumulation_steps=8

forget_split="forget"
retain_split="retain"

msa_forget_epochs=10
msa_retain_epochs=4

# Alternative base model:
# model_name="Llama-2-7b-hf"
model_name="Llama-3.1-8B"
train_devices="0,1"

accelerate_config="configs/accelerate/default_config.yaml"

# Starting checkpoint for mode "target" (override with --muse_model).
muse_model="saves/finetune/muse_books_llama3_full"

# Modes to run, in order (override with --run).
run_what=(target pretrained)

#######################################
# Print command-line help to stdout, showing the current default values.
# Globals (read): forget_split, retain_split, model_name, muse_model,
#   accelerate_config, train_devices, run_what
#######################################
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --model_name NAME            Model name; base weights come from meta-llama/NAME (default: ${model_name})
  --muse_model PATH            Path or HF id for the MUSE model used by mode 'target'
                               (default: ${muse_model})
  --accelerate_config PATH     Accelerate config file (default: ${accelerate_config})
  --train_devices IDS          CUDA_VISIBLE_DEVICES for training (default: ${train_devices})
  --run "modes"                Which to run: target,pretrained (default: ${run_what[*]})
  -h | --help                  Show this help
EOF
}

# Parse CLI
#######################################
# Parse command-line options into the global settings defined above.
# Globals (written): forget_split, retain_split, model_name, muse_model,
#   accelerate_config, train_devices, run_what
# Arguments: the script's command-line arguments ("$@")
# Exits 0 after printing help and 1 on an unknown option. An option given
# without a value aborts with a "missing value" message.
#######################################
parse_args() {
  local modes
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --forget_split) forget_split="${2:?missing value for $1}"; shift 2 ;;
      --retain_split) retain_split="${2:?missing value for $1}"; shift 2 ;;
      --model_name) model_name="${2:?missing value for $1}"; shift 2 ;;
      # --muse_model_path is accepted as an alias because earlier help text
      # advertised that spelling.
      --muse_model|--muse_model_path) muse_model="${2:?missing value for $1}"; shift 2 ;;
      --accelerate_config) accelerate_config="${2:?missing value for $1}"; shift 2 ;;
      --train_devices) train_devices="${2:?missing value for $1}"; shift 2 ;;
      --run)
        # Accept a comma- and/or space-separated list of modes.
        modes="${2:?missing value for $1}"
        read -r -a run_what <<< "${modes//,/ }"
        shift 2
        ;;
      -h|--help) usage; exit 0 ;;
      *) echo "Unknown arg: $1" >&2; usage >&2; exit 1 ;;
    esac
  done
}
parse_args "$@"

# Derived model identifiers: the short base-model name, and the
# meta-llama/<name> id used as the starting point for mode "pretrained".
base_pretrained_model="${model_name}"
path_to_base_pretrained_model="meta-llama/${base_pretrained_model}"


#######################################
# Launch one MSA finetuning stage with accelerate on ${train_devices}.
# Globals (read): accelerate_config, MASTER_PORT, train_devices,
#   per_device_train_batch_size, gradient_accumulation_steps
# Arguments:
#   $1    - experiment config passed as `experiment=`
#   $2    - Hydra task_name
#   $3    - number of training epochs
#   $4    - model config name passed as `model=`
#   $5    - value for model.model_args.pretrained_model_name_or_path
#   $6... - extra Hydra overrides, placed right after task_name
#######################################
_run_msa_stage() {
  local experiment="$1" task_name="$2" epochs="$3" model="$4" pretrained="$5"
  shift 5
  # Every override is quoted so values with spaces stay a single argument.
  local -a cmd=(
    accelerate launch
    --config_file "${accelerate_config}"
    --main_process_port "${MASTER_PORT}"
    src/train.py
    "experiment=${experiment}"
    "task_name=${task_name}"
    "$@"
    "model=${model}"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.num_train_epochs=${epochs}"
    trainer.args.eval_strategy=no
    "model.model_args.pretrained_model_name_or_path=${pretrained}"
  )
  CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd[@]}"
}

#######################################
# Finetune the MSA forget component, then the retain component, for one mode.
# Globals (read): base_pretrained_model, muse_model,
#   path_to_base_pretrained_model, forget_split, retain_split,
#   msa_forget_epochs, msa_retain_epochs, per_device_train_batch_size
# Arguments:
#   $1 - mode: "target" (start from ${muse_model}) or
#        "pretrained" (start from ${path_to_base_pretrained_model})
# Exits 1 on an unknown mode.
#######################################
run_finetune_mode() {
  local mode="$1"  # target | pretrained

  local model_name="$base_pretrained_model"
  local pretrained_path method_tag

  case "$mode" in
    target)
      pretrained_path="$muse_model"
      method_tag="${model_name}_muse_books"
      ;;
    pretrained)
      pretrained_path="$path_to_base_pretrained_model"
      method_tag="${base_pretrained_model}"
      ;;
    *) echo "Unknown mode: $mode" >&2; exit 1 ;;
  esac

  printf '%s\n' "$pretrained_path" "$model_name"

  echo "Finetuning MSA components for mode=$mode | model=$model_name | epochs=$msa_forget_epochs | batch_size=$per_device_train_batch_size"

  # Forget stage (the only stage that sets data_split).
  _run_msa_stage finetune/msa/forget_books.yaml \
    "msa_muse_books_${method_tag}_${forget_split}" \
    "${msa_forget_epochs}" "${model_name}" "${pretrained_path}" \
    data_split=Books

  # Retain stage.
  _run_msa_stage finetune/msa/retain_books.yaml \
    "msa_muse_books_${method_tag}_${retain_split}" \
    "${msa_retain_epochs}" "${model_name}" "${pretrained_path}"
}

# Run every requested mode, in the order given (default or --run).
for requested_mode in "${run_what[@]}"; do
  run_finetune_mode "${requested_mode}"
done
