#!/usr/bin/env bash
#
# Usage: <this script> [GPU]
#   GPU  value exported as CUDA_VISIBLE_DEVICES for the training commands
#        launched by this script (default: 0).
#
# The script changes into the parent of the directory that contains it, so
# every relative path used afterwards is resolved from that directory.
set -euo pipefail

gpu=${1:-0}

# CDPATH is blanked for this one `cd`: with a user-set CDPATH a relative target
# may be resolved elsewhere and `cd` would print the directory it picked, which
# would end up inside ROOT_DIR. `--` keeps a path that starts with '-' from
# being parsed as an option.
ROOT_DIR="$(CDPATH='' cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)"
cd -- "${ROOT_DIR}"

# Dataset location; each of these can be overridden from the environment.
: "${DATA_BASE:=data/rolebench_subset}"
: "${PERSONA_PREFIX:=persona_rolebench_user0_subset}"
: "${PERSONA_SUFFIX:=.json}"
PERSONA_IDS=(0)

# Config file passed to each of the two training stages.
MASK_CONFIG="configs/mask_training_with_llm_persona_freeze_8b.yaml"
LLM_CONFIG="configs/mask_training_with_llm_persona_unfreeze_8b.yaml"

# Hyper-parameters forwarded on the training command lines.
TARGET_SIZE=10752
MASK_LR="1e-2"
LLM_LR="5e-5"
WARMUP="0.03"
SCHED_SCALE="1.0"

# Epoch whose export directory (epoch_NN, zero-padded to two digits) is handed
# to the second stage.
MASK_EPOCH=3
printf -v MASK_EPOCH_PAD '%02d' "${MASK_EPOCH}"

# Run tag: caller-supplied RUN_TAG or the current epoch seconds, always followed
# by "_<RANDOM>" so that repeated launches get distinct run names.
RUN_TAG="${RUN_TAG:-$(date +%s)}_${RANDOM}"

# persona id -> exported model directory from stage 1. Filled by the loop below
# and read back by the stage-2 loop.
declare -A EXPORT_DIRS=()

#######################################
# Find the export base directory of a stage-1 run.
# Prefers "<root>/<run>"; otherwise picks the most recently modified directory
# matching "<root>/<run>_*". Entries that are not directories are ignored, and
# no `ls` output is parsed, so unusual file names cannot confuse the lookup.
# Arguments: $1 - export root directory
#            $2 - run name (may contain '/')
# Outputs:   the chosen directory on stdout
# Returns:   0 when a directory was found, 1 otherwise
#######################################
resolve_export_base() {
  local root=$1
  local run=$2
  local exact="${root}/${run}"
  local candidate
  local newest=""

  if [[ -d "${exact}" ]]; then
    printf '%s\n' "${exact}"
    return 0
  fi

  # An unmatched glob stays literal; the -d test below discards it.
  for candidate in "${exact}"_*; do
    [[ -d "${candidate}" ]] || continue
    if [[ -z "${newest}" || "${candidate}" -nt "${newest}" ]]; then
      newest="${candidate}"
    fi
  done

  [[ -n "${newest}" ]] || return 1
  printf '%s\n' "${newest}"
}

# Stage 1: run training.train_mask_directly once per persona, then locate the
# export directory for epoch ${MASK_EPOCH_PAD} that the run is expected to
# leave under exports/.
for pid in "${PERSONA_IDS[@]}"; do
  persona_name="${PERSONA_PREFIX}_${pid}"
  chatml_path="${DATA_BASE}/${persona_name}${PERSONA_SUFFIX}"

  # The run name encodes the hyper-parameters plus RUN_TAG; it is used for the
  # output directory, the wandb run name and the export lookup.
  stage1_run="mask_then_llm/maskonly_${persona_name}_${TARGET_SIZE}_m${MASK_LR}_l${LLM_LR}_w${WARMUP}_ms${SCHED_SCALE}_${RUN_TAG}"
  stage1_out="outputs/${stage1_run}"
  CUDA_VISIBLE_DEVICES=${gpu} python -m training.train_mask_directly \
    --config "${MASK_CONFIG}" \
    --output-dir "${stage1_out}" \
    --target-intermediate-size "${TARGET_SIZE}" \
    --chatml-file "${chatml_path}:train:${persona_name}" \
    --test-chatml-file "${chatml_path}:test:${persona_name}" \
    --wandb-run-name "${stage1_run}" \
    --prompt-text-mode system_only \
    --llm-prompt-mode system_user \
    --learning-rate-mask "${MASK_LR}" \
    --learning-rate-llm "${LLM_LR}" \
    --warmup-ratio "${WARMUP}" \
    --mask-scheduler-scale "${SCHED_SCALE}" \
    --export-only-final-epoch

  export_root="exports"
  # `if !` keeps a failed lookup from tripping `set -e` before the error
  # message is printed.
  if ! export_base="$(resolve_export_base "${export_root}" "${stage1_run}")"; then
    echo "[ERROR] export base directory not found for run: ${stage1_run}" >&2
    exit 1
  fi
  export_dir="${export_base}/epoch_${MASK_EPOCH_PAD}"
  if [[ ! -d "${export_dir}" ]]; then
    echo "[ERROR] exported model directory not found: ${export_dir}" >&2
    exit 1
  fi
  EXPORT_DIRS["${pid}"]="${export_dir}"
done

# Stage 2: for every persona, run training.train_exported_llm with the
# directory recorded in EXPORT_DIRS as --base-model.
for pid in "${PERSONA_IDS[@]}"; do
  # Abort unless stage 1 recorded a directory for this persona and it still
  # exists on disk.
  export_dir="${EXPORT_DIRS[${pid}]-}"
  if ! [[ -n "${export_dir}" && -d "${export_dir}" ]]; then
    echo "[ERROR] exported model directory not found for user ${pid}: ${export_dir}" >&2
    exit 1
  fi

  persona_name="${PERSONA_PREFIX}_${pid}"
  chatml_path="${DATA_BASE}/${persona_name}${PERSONA_SUFFIX}"
  stage2_run="mask_then_llm/llm_from_export_${persona_name}_${TARGET_SIZE}_l${LLM_LR}_w${WARMUP}_${RUN_TAG}"
  stage2_out="outputs/${stage2_run}"

  # Assemble the trainer argv first so the launch itself is a single line.
  stage2_cmd=(
    python -m training.train_exported_llm
    --config "${LLM_CONFIG}"
    --base-model "${export_dir}"
    --output-dir "${stage2_out}"
    --chatml-file "${chatml_path}:train:${persona_name}"
    --test-chatml-file "${chatml_path}:test:${persona_name}"
    --wandb-run-name "${stage2_run}"
    --prompt-text-mode system_only
    --llm-prompt-mode system_user
    --learning-rate-llm "${LLM_LR}"
    --warmup-ratio "${WARMUP}"
    --keep-last-epoch-only
    --skip-final-save
  )
  CUDA_VISIBLE_DEVICES=${gpu} "${stage2_cmd[@]}"
done
