#!/bin/bash

set -euo pipefail

# Print the command-line help text to stdout.
# Takes no arguments. The heredoc delimiter is quoted, so the text is emitted
# literally (no variable or command expansion happens inside it).
# The defaults listed here mirror the default assignments made by this script
# before argument parsing; keep both in sync when changing either.
usage() {
    cat <<'USAGE'
Usage: restor_msa_evaluate.sh --base_model <name> --forget_task_name <name> [--base_model_hf_path <hf_id>] [--alphas "a b"] [--betas "x y"] [--unlearn_devices <ids>] [--eval_devices <ids>]

Required arguments:
  --base_model          Base model family (e.g. Llama-3.1-8B-Instruct, Olmo-2-7B-stage1-final)
  --forget_task_name    Identifier of the finetuned forget task (used for save paths)

Optional arguments:
  --base_model_hf_path  Override HF identifier to load the base model (defaults inferred from base_model)
  --alphas              Space-separated list of alpha values (defaults to "1.0 1.5 2.0")
  --betas               Space-separated list of beta values (defaults to "0.0")
  --unlearn_devices     CUDA devices for msa_unlearn.py (default "0,1", or $UNLEARN_DEVICES if set)
  --eval_devices        CUDA devices for evaluation script (default "0", or $EVAL_DEVICES if set)
  -h, --help            Show this help message
USAGE
}

# ---------------------------------------------------------------------------
# Option defaults.
# The device lists may be pre-seeded through the UNLEARN_DEVICES / EVAL_DEVICES
# environment variables; the matching command-line flags override them.
# ---------------------------------------------------------------------------
base_model=""
forget_task_name=""
base_model_hf_path=""
alpha_values="1.0 1.5 2.0"
beta_values="0.0"
unlearn_devices="${UNLEARN_DEVICES:-0,1}"
eval_devices="${EVAL_DEVICES:-0}"

# Abort with a usage error when the option in $1 is not followed by a value.
# Call as: require_option_value "$@"
# Without this guard, `set -u` would kill the script with an opaque
# "unbound variable" message when a flag is the last word on the command line.
require_option_value() {
    if [[ $# -lt 2 ]]; then
        echo "Error: option $1 requires a value." >&2
        usage
        exit 1
    fi
}

# Consume the command line two words at a time (flag + value).
while [[ $# -gt 0 ]]; do
    case "$1" in
        --base_model)
            require_option_value "$@"
            base_model="$2"
            shift 2
            ;;
        --forget_task_name)
            require_option_value "$@"
            forget_task_name="$2"
            shift 2
            ;;
        --base_model_hf_path)
            require_option_value "$@"
            base_model_hf_path="$2"
            shift 2
            ;;
        --alphas)
            require_option_value "$@"
            alpha_values="$2"
            shift 2
            ;;
        --betas)
            require_option_value "$@"
            beta_values="$2"
            shift 2
            ;;
        --unlearn_devices)
            require_option_value "$@"
            unlearn_devices="$2"
            shift 2
            ;;
        --eval_devices)
            require_option_value "$@"
            eval_devices="$2"
            shift 2
            ;;
        -h|--help)
            usage
            exit 0
            ;;
        *)
            echo "Unknown option: $1" >&2
            usage
            exit 1
            ;;
    esac
done

if [[ -z "$base_model" || -z "$forget_task_name" ]]; then
    echo "Error: --base_model and --forget_task_name are required." >&2
    usage
    exit 1
fi

# Report the evaluation devices actually in effect (printed after parsing so
# that a --eval_devices override is reflected, and quoted so the value is
# emitted verbatim).
printf '%s\n' "$eval_devices"

# Split the space-separated sweep lists into arrays.
IFS=' ' read -r -a alphas <<< "$alpha_values"
IFS=' ' read -r -a betas <<< "$beta_values"

if [[ ${#alphas[@]} -eq 0 || ${#betas[@]} -eq 0 ]]; then
    echo "Error: alpha and beta lists must be non-empty." >&2
    exit 1
fi

# Lookup table: supported OLMo-2 7B --base_model aliases -> the revision
# string stored in base_revision when that alias is selected.
declare -A OLMO_REVISIONS
OLMO_REVISIONS["Olmo-2-7B-stage1-final"]="stage1-step928000-tokens3893B"
OLMO_REVISIONS["Olmo-2-7B-stage1-3859B"]="stage1-step920000-tokens3859B"
OLMO_REVISIONS["Olmo-2-7B-stage1-3691B"]="stage1-step880000-tokens3691B"
OLMO_REVISIONS["Olmo-2-7B-stage1-2207B"]="stage1-step526000-tokens2207B"
OLMO_REVISIONS["Olmo-2-7B-stage1-500B"]="stage1-step119000-tokens500B"

base_revision=""
restor_eval_model="$base_model"

# An explicit --base_model_hf_path is used as-is and leaves base_revision
# empty. Otherwise the HF identifier (and, for OLMo aliases, the revision) is
# derived from base_model; anything unrecognised is a fatal error.
if [[ -z "$base_model_hf_path" ]]; then
    case "$base_model" in
        Llama*)
            base_model_hf_path="meta-llama/${base_model}"
            ;;
        *)
            olmo_revision="${OLMO_REVISIONS[$base_model]:-}"
            if [[ -z "$olmo_revision" ]]; then
                echo "Unsupported base_model: ${base_model}" >&2
                exit 1
            fi
            base_model_hf_path="allenai/OLMo-2-1124-7B"
            base_revision="$olmo_revision"
            ;;
    esac
fi

# Ask the OS for a currently-free TCP port by binding a throwaway socket to
# port 0 and reading back the assigned port number.
# The assignment is kept separate from `export` on purpose: `export VAR=$(cmd)`
# returns the status of `export`, hiding a failing `python` from `set -e`.
# NOTE: the socket is closed before the port is used, so another process could
# in principle grab the port in between.
# NOTE(review): MASTER_PORT is presumably consumed by the child processes
# launched later in this script; confirm against those scripts.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"


# Locations of the finetuned checkpoints used by the sweep.
# NOTE(review): retain_model is set to the very same directory as
# forget_model; confirm this is intended rather than a copy/paste slip.
finetune_root="saves/finetune"
forget_model="${finetune_root}/${forget_task_name}"
retain_model="${finetune_root}/${forget_task_name}"

printf 'Using base_model_hf_path=%s\n' "${base_model_hf_path}"
printf 'forget_model=%s\n' "${forget_model}"
printf 'retain_model=%s\n' "${retain_model}"

# Choose the RESTOR target checkpoint and the model name handed to the
# evaluation script: one fixed pair for Llama* base models, and the
# Olmo-2-7B-stage1-final pair for everything else.
case "$base_model" in
    Llama*)
        restor_eval_model="Llama-3.1-8B-Instruct"
        ;;
    *)
        restor_eval_model="Olmo-2-7B-stage1-final"
        ;;
esac
path_to_target_model="${finetune_root}/RESTOR_${restor_eval_model}"

# Emit a non-fatal warning on stderr when directory $2 (described by the
# label in $1) is missing; the script keeps going either way.
warn_if_missing_dir() {
    local label="$1"
    local dir="$2"
    if [[ -d "$dir" ]]; then
        return 0
    fi
    echo "Warning: ${label} directory ${dir} does not exist." >&2
}

warn_if_missing_dir "forget model" "$forget_model"
warn_if_missing_dir "target model" "$path_to_target_model"

# Sweep every (alpha, beta) combination: build the unlearned model with
# msa_unlearn.py, evaluate it with restor_evaluate.sh, then remove the saved
# weight files. Because the script runs under `set -e`, the first failing
# step aborts the whole sweep.
for alpha in "${alphas[@]}"; do
    for beta in "${betas[@]}"; do
        path_to_unlearned_model="saves/unlearn_msa/${forget_task_name}/msa_alpha_${alpha}_beta_${beta}"
        echo "path_to_target_model=${path_to_target_model}"
        echo "path_to_unlearned_model=${path_to_unlearned_model}"
        echo "alpha=${alpha}"
        echo "beta=${beta}"
        mkdir -p "${path_to_unlearned_model}"

        # The command is assembled as an array so that every argument is
        # passed as exactly one word. Note that alpha is handed over with a
        # leading minus sign, while beta is passed unchanged.
        msa_cmd=(
            python src/utils/msa_unlearn.py
            --base_model "${base_model_hf_path}"
            --forget_model "${forget_model}"
            --retain_model "${retain_model}"
            --target_model "${path_to_target_model}"
            --alpha "-${alpha}"
            --beta "${beta}"
            --save_path "${path_to_unlearned_model}"
        )

        # Only forward a revision when one was resolved for the base model.
        if [[ -n "$base_revision" ]]; then
            msa_cmd+=(--base_revision "${base_revision}")
        fi

        echo "Running MSA unlearn with alpha=${alpha}, beta=${beta}"
        CUDA_VISIBLE_DEVICES="${unlearn_devices}" "${msa_cmd[@]}"

        echo "Evaluating RESTOR metrics for ${path_to_unlearned_model}"
        # eval_devices is quoted so the device list always reaches the
        # evaluation script as a single argument (no word splitting/globbing,
        # and an empty value is not silently dropped).
        ./restor_scripts/restor_evaluate.sh \
            --model "${restor_eval_model}" \
            --task_name "${forget_task_name}_alpha_${alpha}_beta_${beta}" \
            --path_to_model "${path_to_unlearned_model}" \
            --batch_size 64 \
            --eval_devices "${eval_devices}"

        # Delete the *.safetensors files sitting directly in the output
        # directory once evaluation is done (presumably to reclaim disk space
        # between sweep points). `{} +` batches the files into a single rm
        # call and `--` guards against file names starting with a dash.
        find "${path_to_unlearned_model}" -maxdepth 1 -type f -name "*.safetensors" -exec rm -f -- {} +

    done
done
