#!/usr/bin/env bash
# Run the training job once per rank, strictly one after another.
# A failing rank must not stop the sweep, so errexit (`-e`) is deliberately
# NOT enabled here; only nounset and pipefail are.
set -o nounset
set -o pipefail

# ===== GPU selection & debugging =====
# Only GPUs 4-7 are made visible to the job.
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous: error
# locations become accurate, but training is slower. Intended for debugging.
CUDA_VISIBLE_DEVICES=4,5,6,7
CUDA_LAUNCH_BLOCKING=1
export CUDA_VISIBLE_DEVICES CUDA_LAUNCH_BLOCKING

# ===== Per-user base directory =====
# Temp files and caches created by this script live under BASE.
BASE="/data2/xinluo"

# 1) Temporary files: point every common temp-dir variable at $BASE/tmp and
#    make that directory private to the current user.
export TMPDIR="${BASE}/tmp"
export TMP="${TMPDIR}" TEMP="${TMPDIR}"
if mkdir -p "${TMPDIR}"; then
  chmod 700 "${TMPDIR}"
fi

# 2) Caches: move Hugging Face, PyTorch and XDG caches under $BASE/cache.
#    NOTE(review): recent transformers releases warn that TRANSFORMERS_CACHE
#    is deprecated in favour of HF_HOME; it is kept so that the existing
#    cache location keeps being used — confirm before removing.
export HF_HOME="${BASE}/cache/hf"
export TRANSFORMERS_CACHE="${HF_HOME}/transformers"
export HF_DATASETS_CACHE="${HF_HOME}/datasets"
export TORCH_HOME="${BASE}/cache/torch"
export XDG_CACHE_HOME="${BASE}/cache/xdg"
mkdir -p \
  "${HF_HOME}" \
  "${TRANSFORMERS_CACHE}" \
  "${HF_DATASETS_CACHE}" \
  "${TORCH_HOME}" \
  "${XDG_CACHE_HOME}"

# (Optional) Pin the main-process port passed to `accelerate launch`, so
# that this job does not clash with others using a different/default port.
MASTER_PORT=29501
export MASTER_PORT

# ===== Dataset / checkpoint paths =====
# SA-1B training and validation roots.
SA1B_TRAIN="/data1/xinluo/datasets/sa_000000"
SA1B_VAL="/data1/xinluo/datasets/sa_000001"
# SAM-Med2D dataset root.
# NOTE(review): this is on /data while every other path is on /data1 —
# confirm it is not a typo.
SAMED2D="/data/xinluo/Datasets/SAMED/"
# Model checkpoints: SAM-Med2D (ViT-B) and original SAM (ViT-B).
MED2D_CKPT="/data1/xinluo/projects/SAM-Med2D/pretrain_model/sam-med2d_b_qkv.pth"
SAM_CKPT="/data1/xinluo/projects/SlimSAM/checkpoints/sam_vit_b_qkv.pth"

# ===== Rank sweep (strictly sequential) =====
# Each `accelerate launch` runs in the foreground, so the next rank only
# starts after the previous one has exited, successfully or not. A failing
# rank is recorded and the sweep continues; failed ranks are listed on
# stderr at the end. The script's own exit code stays 0 either way.
failed_ranks=()

for R in 4 8 16 32; do
  TS="$(date +%Y%m%d_%H%M%S)"
  EXP_NAME="assp_rank${R}_${TS}"
  LOG_DIR="${BASE}/logs/${EXP_NAME}"
  # NOTE(review): RUN_DIR is created but never handed to main.py, which is
  # given --exp_root "${BASE}/ASSPv4/runs" instead — confirm which location
  # is intended.
  RUN_DIR="${BASE}/runs/${EXP_NAME}"
  mkdir -p "${LOG_DIR}" "${RUN_DIR}"

  echo "=============================="
  echo "[$(date '+%F %T')] Start experiment: rank=${R}"
  echo "Exp name: ${EXP_NAME}"
  echo "Logs: ${LOG_DIR}"
  echo "=============================="

  # Blocking run; stdout+stderr are shown on the console and copied to
  # train.log. The failure is consumed in a `||` context, so it can never
  # trip errexit. (The previous `set +e` ... `set -e` toggle switched
  # errexit ON after the first rank, contradicting the no-`-e` policy set at
  # the top of the script.) With pipefail active, a failing `accelerate`
  # makes the pipeline fail, and PIPESTATUS[0] then holds the exit code of
  # `accelerate` itself rather than that of `tee`.
  status=0
  accelerate launch \
    --num_processes=4 \
    --mixed_precision=bf16 \
    --main_process_port="$MASTER_PORT" \
    main.py \
      --exp_root "${BASE}/ASSPv4/runs" \
      --exp_name "${EXP_NAME}" \
      --sa1b_train_root "${SA1B_TRAIN}" \
      --sa1b_val_root   "${SA1B_VAL}" \
      --samed2d_root    "${SAMED2D}" \
      --sammed2d_checkpoint "${MED2D_CKPT}" \
      --sam_checkpoint       "${SAM_CKPT}" \
      --epochs_per_round_mode1 40 \
      --epochs_per_round_mode3 40 \
      --pruning_ratio       0.75 \
      --pruning_rounds      1 \
      --rank                "${R}" \
    2>&1 | tee "${LOG_DIR}/train.log" || status=${PIPESTATUS[0]}

  echo "[$(date '+%F %T')] Finished: rank=${R} (exit_code=${status})" | tee -a "${LOG_DIR}/train.log"

  if (( status != 0 )); then
    failed_ranks+=("${R}")
    # Optional: to stop the whole sweep on the first failure, uncomment:
    # echo "Abort: rank=${R} failed with exit_code=${status}"; exit "${status}"
  fi

  # Optional: short pause between ranks.
  # sleep 10
done

echo "All experiments done."
if (( ${#failed_ranks[@]} > 0 )); then
  echo "Failed ranks: ${failed_ranks[*]}" >&2
fi
