#!/bin/bash
#
# Runs an unlearning job through `accelerate launch`, then evaluates every
# checkpoint the run produced and deletes the checkpoint *.safetensors files.
# See usage() below for the command-line interface.
#
# All project paths used by this script (configs/, src/, saves/,
# restor_scripts/) are relative, so they resolve against the current
# working directory.

set -euo pipefail

# NOTE(review): SAVE_BASE is not referenced anywhere else in this script;
# kept as-is in case the environment or a sourced file relies on it.
SAVE_BASE="/cmlscratch//open-unlearning/saves/finetune"

# Ask the OS for a currently free TCP port (bind to port 0, read back the
# assigned port). The assignment is kept separate from `export` so that a
# failing python invocation is not masked by export's own exit status and
# `set -e` aborts the script instead of continuing with an empty port.
# The port is released again before the launcher binds it, so a collision
# is still possible in principle.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# usage: print the command-line help text for this script on stdout.
# Takes no arguments; each printf argument below is emitted as one line.
usage() {
    printf '%s\n' \
        'Usage: restor_unlearn.sh --model <name> --trainer <name> --lr <float> [--alpha <float>] [--gamma <float>]' \
        '' \
        'Required arguments:' \
        '  --model      Model name to unlearn' \
        '  --trainer    Trainer method to use' \
        '  --lr         Learning rate for training' \
        '' \
        'Optional arguments:' \
        '  --alpha      Alpha coefficient for the trainer method' \
        '  --gamma      Gamma coefficient for the trainer method' \
        '  -h, --help   Show this help message and exit' \
        '' \
        'Example:' \
        '  ./restor_scripts/restor_unlearn.sh \' \
        '      --model Llama-3.2-1B-Instruct \' \
        '      --trainer NPO \' \
        '      --lr 1e-5 \' \
        '      --alpha 4 \' \
        '      --gamma 4'
}

# Parsed option values; an empty string means "not provided".
model=""
trainer=""
learning_rate=""
alpha=""
gamma=""

# die_usage MESSAGE
# Report a command-line error: prints "Error: MESSAGE" followed by the usage
# text on stderr (so diagnostics never mix with regular stdout output), then
# exits with status 1.
die_usage() {
    printf 'Error: %s\n' "$1" >&2
    usage >&2
    exit 1
}

# Consume the positional parameters two at a time (option + value).
while [[ $# -gt 0 ]]; do
    case "$1" in
        --model|--trainer|--lr|--alpha|--gamma)
            # Every value-taking option shares one check: the next word must
            # exist, be non-empty, and not look like another long option.
            # ${2:-} keeps `set -u` from aborting when the value is missing.
            if [[ -z "${2:-}" || "${2:-}" == --* ]]; then
                die_usage "$1 requires a value"
            fi
            case "$1" in
                --model)   model="$2" ;;
                --trainer) trainer="$2" ;;
                --lr)      learning_rate="$2" ;;
                --alpha)   alpha="$2" ;;
                --gamma)   gamma="$2" ;;
            esac
            shift 2
            ;;
        -h|--help)
            # Help was requested explicitly: usage goes to stdout, status 0.
            usage
            exit 0
            ;;
        *)
            die_usage "Unknown option $1"
            ;;
    esac
done

# --alpha and --gamma are optional; the other three must all be present.
if [[ -z "$model" || -z "$trainer" || -z "$learning_rate" ]]; then
    die_usage "--model, --trainer, and --lr are required"
fi

# Report the parsed options; optional coefficients left empty show as N/A.
printf '%s\n' "[setup] Parsed arguments: model=${model}, trainer=${trainer}, lr=${learning_rate}, alpha=${alpha:-N/A}, gamma=${gamma:-N/A}"

# Fixed settings for the run.
per_device_train_batch_size=4
gradient_accumulation_steps=8
experiment="unlearn/restor/default.yaml"

# Model path to start from, derived from the model name.
path_to_target_model="saves/finetune/RESTOR_${model}"

# Run tag: trainer name, then "_a<alpha>" and "_g<gamma>" only when the
# corresponding option is non-empty, then "_lr<learning rate>".
trainer_config="${trainer}${alpha:+_a${alpha}}${gamma:+_g${gamma}}_lr${learning_rate}"
task_name="RESTOR_${model}_${trainer_config}"

printf '%s\n' "[setup] Task name resolved to ${task_name}"

echo "[1/3] Launching unlearning job"

# The training processes are restricted to GPUs 0 and 1.
export CUDA_VISIBLE_DEVICES=0,1

# Build the launcher command as an array, one argument per element.
# Every expansion is double-quoted: inside an array literal an unquoted
# ${var} undergoes word splitting and globbing, so a value containing
# whitespace or glob characters would otherwise turn into several
# arguments (or be replaced by matching file names).
cmd=(
    accelerate
    launch
    --config_file
    configs/accelerate/default_config.yaml
    --main_process_port
    "$MASTER_PORT"
    src/train.py
    --config-name=unlearn.yaml
    "experiment=${experiment}"
    "trainer=${trainer}"
    "task_name=${task_name}"
    "model=${model}"
    "model.model_args.pretrained_model_name_or_path=${path_to_target_model}"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.learning_rate=${learning_rate}"
    # The shell removes these quotes; the program receives eval_strategy=no.
    trainer.args.eval_strategy="no"
)

# The method coefficients are passed on only when they were given.
if [[ -n "$alpha" ]]; then
    cmd+=("trainer.method_args.alpha=${alpha}")
fi
if [[ -n "$gamma" ]]; then
    cmd+=("trainer.method_args.gamma=${gamma}")
fi


# Run the job in the foreground; its exit status is the status of this line.
"${cmd[@]}"

echo "[2/3] Unlearning run completed"

# Directory in which checkpoints for this task are looked up.
checkpoint_root="saves/unlearn/${task_name}"
if [[ ! -d "${checkpoint_root}" ]]; then
    # Nothing to evaluate is treated as a warning, not a failure.
    echo "Warning: No checkpoint directory found at ${checkpoint_root}" >&2
    exit 0
fi

# Collect the checkpoint-* entries that are directories. nullglob makes an
# unmatched pattern expand to nothing instead of the literal pattern; the
# directory test keeps a stray regular file named checkpoint-* from being
# handed to the evaluation script as a model path.
checkpoints=()
shopt -s nullglob
for candidate in "${checkpoint_root}"/checkpoint-*; do
    if [[ -d "${candidate}" ]]; then
        checkpoints+=("${candidate}")
    fi
done
shopt -u nullglob

if [[ ${#checkpoints[@]} -eq 0 ]]; then
    echo "Warning: No checkpoints found under ${checkpoint_root}" >&2
    exit 0
fi

echo "[3/3] Found ${#checkpoints[@]} checkpoint(s) under ${checkpoint_root}"

for checkpoint_path in "${checkpoints[@]}"; do
    # "checkpoint-<id>" -> "<id>", used to label the evaluation task.
    checkpoint_dir=${checkpoint_path##*/}
    ckpt=${checkpoint_dir#checkpoint-}

    echo "--> Evaluating checkpoint ${ckpt}"
    ./restor_scripts/restor_evaluate.sh \
        --model "${model}" \
        --task_name "${task_name}_ckpt_${ckpt}" \
        --path_to_model "${checkpoint_path}" \
        --batch_size 64 \
        --eval_devices 0

    # Runs only after the evaluation above succeeded (set -e aborts the
    # script otherwise). Deletes the *.safetensors files directly inside
    # the checkpoint directory; `--` guards against odd file names and `+`
    # batches all matches into a single rm invocation.
    echo "--> Cleaning safetensors for checkpoint ${ckpt}"
    find "${checkpoint_path}" -maxdepth 1 -type f -name "*.safetensors" -exec rm -f -- {} +

done

# Also delete *.safetensors files located directly in the task directory.
find "${checkpoint_root}" -maxdepth 1 -type f -name "*.safetensors" -exec rm -f -- {} +

echo "All checkpoints evaluated and cleaned"
