#!/bin/bash
# Launch float_train.main on AudioCaps or Clotho, either on a single GPU
# (python -m) or across several GPUs with torchrun.
#
# Usage: <script> [audiocaps|clotho] [optional: run_name] [optional: --single]

# Strict mode: abort on unhandled command failures, on use of unset
# variables, and on failures anywhere in a pipeline.
set -euo pipefail

# Usage check: between one and three arguments are accepted.
if (( $# < 1 || $# > 3 )); then
  echo "Usage: $0 [audiocaps|clotho] [optional: run_name] [optional: --single]" >&2
  exit 1
fi

# ----------------- Shared training configuration -----------------
# Checkpointing and logging.
declare SAVE_FREQ=1 REPORT_TO="tensorboard"
# Optimisation and schedule.
declare BATCH_SIZE=85 LR_SCHEDULER="const_and_nowarm" LR=5e-5 WD=0.1 EPOCHS=20
# Data loading and backbone model.
declare WORKERS=8 MODEL="ResNet38"
# Switch flags forwarded verbatim to the trainer; an empty string omits the flag.
declare OT="--ot" FLOAT_LOSS="--float-loss"
# Value passed to --transfer-weight.
declare TRANSFER_WEIGHT=0.05


# ----------------- Argument parsing -----------------
#######################################
# Parse the script's arguments (accepted in any order) into globals.
# Globals (written):
#   DATASET      - "audiocaps" or "clotho"; "" when not given
#   CUSTOM_NAME  - run name; "" when not given
#   SINGLE_CARD  - "true" if --single was passed, otherwise "false"
# Arguments:
#   The script's positional parameters.
# Exits 1 with a message on stderr when:
#   - an unknown option (anything else starting with '-') is seen, so typos
#     such as --singel are not silently used as the run name;
#   - the dataset or the run name is given more than once, instead of
#     letting the last occurrence silently win.
#######################################
parse_args() {
  local arg
  DATASET=""
  CUSTOM_NAME=""
  SINGLE_CARD="false"

  for arg in "$@"; do
    case "$arg" in
      audiocaps|clotho)
        if [[ -n "$DATASET" ]]; then
          printf 'Error: dataset given more than once (%s, %s).\n' "$DATASET" "$arg" >&2
          exit 1
        fi
        DATASET=$arg
        ;;
      --single)
        SINGLE_CARD="true"
        ;;
      -*)
        printf 'Error: unknown option: %s\n' "$arg" >&2
        exit 1
        ;;
      *)
        if [[ -n "$CUSTOM_NAME" ]]; then
          printf 'Error: run name given more than once (%s, %s).\n' "$CUSTOM_NAME" "$arg" >&2
          exit 1
        fi
        CUSTOM_NAME=$arg
        ;;
    esac
  done
}

parse_args "$@"

# Validate the dataset selection. Besides the "not given" case, reject any
# value other than the two supported keys so the dataset-path lookup that
# follows can never fall through and leave DATA_PATH unset.
case "$DATASET" in
  audiocaps|clotho)
    ;;
  "")
    echo "Error: dataset not specified. Use 'audiocaps' or 'clotho'." >&2
    exit 1
    ;;
  *)
    echo "Error: unknown dataset '$DATASET'. Use 'audiocaps' or 'clotho'." >&2
    exit 1
    ;;
esac

# Map the dataset key to its data directory, the default run name, and the
# value passed to --audio-dataset. Any other key leaves all three untouched.
case "$DATASET" in
  audiocaps)
    DATA_PATH="float_distributed/data/AudioCaps"
    DEFAULT_NAME="audiocaps_distributed"
    AUDIO_DATASET_NAME="AudioCaps"
    ;;
  clotho)
    DATA_PATH="float_distributed/data/Clotho"
    DEFAULT_NAME="clotho_distributed"
    AUDIO_DATASET_NAME="Clotho"
    ;;
esac

# Run name: the one given on the command line, or the dataset's default
# when none (or an empty one) was provided.
NAME=${CUSTOM_NAME:-$DEFAULT_NAME}

# ----------------- Launch mode -----------------
# --single: run the module directly with `python -m` on GPU 0.
# Otherwise: torchrun with one worker per device listed in DEVICES.
# LAUNCHER is an array so every word of the launcher reaches exec intact
# (no reliance on unquoted word-splitting of a string).
if [[ "$SINGLE_CARD" == "true" ]]; then
  DEVICES="0"
  LAUNCHER=(python -m)
  echo "[INFO] Running in single GPU mode (CUDA_VISIBLE_DEVICES=$DEVICES)"
else
  DEVICES="0,1,2"
  # Derive the worker count from DEVICES so editing the device list cannot
  # leave --nproc_per_node out of sync with it.
  IFS=, read -r -a DEVICE_LIST <<< "$DEVICES"
  NPROC=${#DEVICE_LIST[@]}
  # Random rendezvous port in [15000, 16000]. NOTE: the port is NOT probed
  # for availability, so two concurrent runs can (rarely) pick the same one.
  if ! PORT=$(shuf -i 15000-16000 -n 1); then
    echo "Error: failed to pick a MASTER_PORT (is shuf available?)" >&2
    exit 1
  fi
  export MASTER_PORT=$PORT
  LAUNCHER=(torchrun "--nproc_per_node=$NPROC" "--master_port=$MASTER_PORT" -m)
  echo "[INFO] Running in multi-GPU mode (CUDA_VISIBLE_DEVICES=$DEVICES, MASTER_PORT=$MASTER_PORT)"
fi

# ----------------- Launch training -----------------
# FLOAT_LOSS and OT hold optional switch flags; forward only the non-empty ones.
EXTRA_FLAGS=()
if [[ -n "$FLOAT_LOSS" ]]; then
  EXTRA_FLAGS+=("$FLOAT_LOSS")
fi
if [[ -n "$OT" ]]; then
  EXTRA_FLAGS+=("$OT")
fi

# Every expansion is quoted so values such as a run name containing spaces or
# glob characters reach the trainer as a single argument. The
# ${arr[@]+"${arr[@]}"} form keeps an empty EXTRA_FLAGS safe under `set -u`
# on bash versions older than 4.4.
CUDA_VISIBLE_DEVICES="$DEVICES" "${LAUNCHER[@]}" float_train.main \
    --save-frequency "$SAVE_FREQ" \
    --report-to "$REPORT_TO" \
    --train-data "$DATA_PATH" \
    --val-data "$DATA_PATH" \
    --test-data "$DATA_PATH" \
    --audio-dataset "$AUDIO_DATASET_NAME" \
    --dataset-type audio \
    --batch-size "$BATCH_SIZE" \
    --lr-scheduler "$LR_SCHEDULER" \
    --lr "$LR" \
    --wd "$WD" \
    --epochs "$EPOCHS" \
    --workers "$WORKERS" \
    --model "$MODEL" \
    --name "$NAME" \
    --transfer-weight "$TRANSFER_WEIGHT" \
    ${EXTRA_FLAGS[@]+"${EXTRA_FLAGS[@]}"}
