#!/bin/bash
#
# Finetunes a model on the MUSE Books data with `accelerate launch` and then
# evaluates the resulting checkpoint with src/eval.py.

set -euo pipefail

# Ask the OS for a currently unused TCP port: binding to port 0 makes the
# kernel pick one, which the Python one-liner prints before closing the socket.
# The port is released again before it is used, so another process could in
# principle grab it in between.
#
# The assignment is kept separate from `export` on purpose: with
# `export VAR=$(cmd)` the exit status is that of `export`, so a failing
# `python` would go unnoticed by `set -e` and the script would carry on with
# an empty port.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Batch settings forwarded to the trainer as trainer.args.* overrides.
gradient_accumulation_steps=8
per_device_train_batch_size=4

# Base model selection. Swap in the alternative by editing model_name:
# model_name="Llama-2-7b-hf"
model_name="Llama-3.1-8B"
# Only referenced by the commented-out base-model evaluation further down.
path_to_base_model="meta-llama/${model_name}"

# Values assigned to CUDA_VISIBLE_DEVICES for the evaluation and training steps.
eval_devices="0"
train_devices="0,1"

# Config file handed to `accelerate launch --config_file`.
accelerate_config="configs/accelerate/default_config.yaml"

# Value passed as the data_split override to both training and evaluation.
data_split="Books"

printf '%s\n' "Finetuning on MUSE Books"


# CUDA_VISIBLE_DEVICES=${eval_devices} python src/eval.py \
#           experiment=eval/muse/default.yaml \
#           data_split=${data_split} \
#           task_name=base_model \
#           model=${model_name} \
#           model.model_args.pretrained_model_name_or_path=${path_to_base_model} \
#           paths.output_dir="saves/muse_llama3_base_model/evals" \
#           retain_logs_path=saves/eval/muse_${model_name}_${data_split}_retrain/MUSE_EVAL.json

# ---------------------------------------------------------------------------
# Retain-subset finetuning (command is built, but the launch is disabled)
# ---------------------------------------------------------------------------
data_sub_set="retain"

# Training command for the retain subset, kept as an array so every override
# stays a single argv word. Newlines are plain whitespace inside ( ... ), so no
# line-continuation backslashes are needed.
#
# task_name uses the same "muse_books_llama3_<subset>" scheme as the full run
# and as path_to_retain_model below; previously it was "muse_books_<subset>",
# which did not match the checkpoint path this script reads from later.
# NOTE(review): presumably src/train.py writes to saves/finetune/<task_name>;
# confirm against the finetune experiment config.
#
# gradient_accumulation_steps is passed once (it used to be listed twice).
cmd=(
  accelerate launch
  --config_file "${accelerate_config}"
  --main_process_port "${MASTER_PORT}"
  src/train.py
  experiment=finetune/muse/default.yaml
  "task_name=muse_books_llama3_${data_sub_set}"
  "data_split=${data_split}"
  "data_sub_set=${data_sub_set}"
  "model=${model_name}"
  "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
  "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
  trainer.args.ddp_find_unused_parameters=true
  trainer.args.gradient_checkpointing=true
  trainer.args.eval_strategy="no"
)

# Retain training is currently switched off; uncomment to run it.
# CUDA_VISIBLE_DEVICES=${train_devices} "${cmd[@]}"

# Directory of the retain-model checkpoint; its evals/MUSE_EVAL.json is used
# as retain_logs_path when the full model is evaluated.
path_to_retain_model="saves/finetune/muse_books_llama3_${data_sub_set}"

# CUDA_VISIBLE_DEVICES=${eval_devices} python src/eval.py \
#           experiment=eval/muse/default.yaml \
#           data_split=${data_split} \
#           task_name=${path_to_retain_model} \
#           model=${model_name} \
#           model.model_args.pretrained_model_name_or_path=${path_to_retain_model} \
#           paths.output_dir=${path_to_retain_model}/evals \
#           retain_logs_path=saves/eval/muse_${model_name}_${data_split}_retrain/MUSE_EVAL.json



# ---------------------------------------------------------------------------
# Full-subset finetuning
# ---------------------------------------------------------------------------
data_sub_set="full"

# Training command for the full subset, kept as an array so every override
# stays a single argv word. Newlines are plain whitespace inside ( ... ), so no
# line-continuation backslashes are needed.
#
# gradient_accumulation_steps is passed once (it used to be listed twice).
cmd=(
  accelerate launch
  --config_file "${accelerate_config}"
  --main_process_port "${MASTER_PORT}"
  src/train.py
  experiment=finetune/muse/default.yaml
  "task_name=muse_books_llama3_${data_sub_set}"
  "data_split=${data_split}"
  "data_sub_set=${data_sub_set}"
  "model=${model_name}"
  "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
  "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
  trainer.args.ddp_find_unused_parameters=true
  trainer.args.gradient_checkpointing=true
  trainer.args.eval_strategy="no"
)

# Launch training with only the training GPUs visible to the process.
CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd[@]}"

# ---------------------------------------------------------------------------
# Evaluation of the full-subset model
# ---------------------------------------------------------------------------

# Checkpoint directory for the full-subset model; it is both the model loaded
# for evaluation and the parent of the evals/ output directory.
path_to_full_model="saves/finetune/muse_books_llama3_${data_sub_set}"

# Evaluation command, assembled as an array (one override per element).
# retain_logs_path points at the retain model's MUSE_EVAL.json.
# NOTE(review): the retain-model evaluation is commented out in this script,
# so that file presumably has to exist from an earlier run - confirm.
eval_cmd=(
  python src/eval.py
  experiment=eval/muse/default.yaml
  "data_split=${data_split}"
  "task_name=${path_to_full_model}"
  "model=${model_name}"
  "model.model_args.pretrained_model_name_or_path=${path_to_full_model}"
  "paths.output_dir=${path_to_full_model}/evals"
  "retain_logs_path=${path_to_retain_model}/evals/MUSE_EVAL.json"
)

# Run the evaluation with only the evaluation GPU(s) visible.
CUDA_VISIBLE_DEVICES="${eval_devices}" "${eval_cmd[@]}"
