#!/bin/bash
#
# Run DM/main_DM.py on the given GPUs with the CPU thread caps and
# environment variables set below, locking the GPUs for the duration of the run.
#
# Usage: <this script> <gpu_ids> [overrides...]
#   gpu_ids    comma-separated GPU ids; exported as CUDA_VISIBLE_DEVICES
#   overrides  remaining arguments, forwarded verbatim to main_DM.py
#
# Environment:
#   CPU_THREADS  thread cap applied to every numerical library (default: 8)

set -e

# Cap the threads each numerical library may spawn (override via CPU_THREADS).
cpu_threads=${CPU_THREADS:-8}
export OMP_NUM_THREADS=${cpu_threads}        # OpenMP
export MKL_NUM_THREADS=${cpu_threads}        # Intel MKL
export OPENBLAS_NUM_THREADS=${cpu_threads}   # OpenBLAS
export NUMEXPR_NUM_THREADS=${cpu_threads}    # NumExpr
export VECLIB_MAXIMUM_THREADS=${cpu_threads} # macOS Accelerate/vecLib
# NOTE(review): PyTorch's intra-op pool normally follows OMP_NUM_THREADS;
# confirm that something actually reads PYTORCH_NUM_THREADS.
export PYTORCH_NUM_THREADS=${cpu_threads}

# Let the PyTorch CUDA caching allocator use expandable segments
# (reduces fragmentation).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

GPU_IDS=$1
if [[ -z "${GPU_IDS}" ]]; then
  # The Hydra config is fixed via --config-name further down, so the only
  # positional argument consumed here is the GPU list.
  echo "Usage: $0 <gpu_ids> [overrides...]" >&2
  echo "Example: $0 0,1,2,3 batch_train=128" >&2
  exit 1
fi
shift 1

export CUDA_VISIBLE_DEVICES=${GPU_IDS}
export NCCL_IB_DISABLE=1   # do not use the InfiniBand transport in NCCL
export PYTHONHASHSEED=0    # fixed str/bytes hash seed for reproducibility

#######################################
# Put every GPU listed in GPU_IDS into EXCLUSIVE_PROCESS compute mode so only
# one process at a time can use each of them while this job runs.
# Globals:   GPU_IDS (read) - comma-separated GPU ids
# Outputs:   one "Locking GPU <id>" line per GPU on stdout
# Requires sudo rights for nvidia-smi. Under `set -e` a failure aborts the
# script, and the EXIT trap then restores the compute mode via unlock_gpu.
#######################################
lock_gpu() {
  local id
  # Deliberately unquoted: turn "0,1,2" into the words 0 1 2.
  for id in ${GPU_IDS//,/ }; do
    printf 'Locking GPU %s\n' "$id"
    sudo nvidia-smi -i "$id" -c EXCLUSIVE_PROCESS
  done
}
#######################################
# Restore every GPU listed in GPU_IDS to the DEFAULT compute mode.
# Intended to run from the exit trap. For each GPU it first waits up to about
# 5 seconds for compute processes to disappear, then resets the mode.
# Resetting is best-effort: a failure on one GPU is reported on stderr and the
# loop continues, so under `set -e` one bad GPU cannot leave the others locked.
# Globals:   GPU_IDS (read) - comma-separated GPU ids
# Returns:   always 0
#######################################
unlock_gpu() {
  local id attempt
  # Deliberately unquoted: turn "0,1,2" into the words 0 1 2.
  for id in ${GPU_IDS//,/ }; do
    # Wait up to 5 seconds for compute processes on this GPU to exit.
    for attempt in {1..5}; do
      if sudo nvidia-smi -i "$id" --query-compute-apps=pid \
           --format=csv,noheader | grep -q '[0-9]'; then
        sleep 1
      else
        break
      fi
    done
    # Checked in an `if` so errexit cannot abort the loop on one failure.
    if ! sudo nvidia-smi -i "$id" -c DEFAULT; then
      printf 'Warning: failed to reset compute mode of GPU %s\n' "$id" >&2
    fi
  done
  return 0
}
# Restore the GPUs' compute mode on every exit path. INT/TERM just exit with
# the conventional 128+signal status, so the EXIT trap runs unlock_gpu exactly
# once instead of once in the signal trap and again on exit.
# NOTE(review): bash defers signal traps until a foreground child finishes;
# a TERM sent only to this shell is handled after the training process exits.
trap unlock_gpu EXIT
trap 'exit 130' INT
trap 'exit 143' TERM

lock_gpu


# Note: batch_train=128 and batch_real=256; both could be scaled up if more
# GPU memory is available.
# Use ./store_local/buffer when it exists, otherwise ./store/buffer.
buffer_path="./store_local/buffer"
if [[ ! -d "${buffer_path}" ]]; then
  buffer_path="./store/buffer"
fi
# Not exec'd: the shell must stay alive so its EXIT trap can unlock the GPUs.
# NOTE(review): buffer_path is appended after the caller's overrides; confirm
# how main_DM.py's config system treats a duplicate buffer_path= from "$@".
python DM/main_DM.py --config-name imagenette128 "$@" "buffer_path=${buffer_path}"


# Example invocations (first argument: GPU id(s); every further argument is
# forwarded to main_DM.py as a key=value override):
# gpc=80
# bash DM/scripts/imagenette.sh 5 init_num_points=136 max_num_points=136 batch_size=800 gpc=80 batch_train=128
# gpc=100
# bash DM/scripts/imagenette.sh 4 init_num_points=109 max_num_points=109 batch_size=1000 gpc=100 batch_train=128
# gpc=40
# bash DM/scripts/imagenette.sh 6 init_num_points=273 max_num_points=273 batch_size=400 gpc=40 batch_train=128

