#!/bin/bash
# shellcheck disable=SC2206
#SBATCH --account=
#SBATCH --output=slurm-%A_%a.out
#SBATCH --error=slurm-%A_%a.err
#SBATCH --mail-user=
#SBATCH --mail-type=ALL
#SBATCH --partition=
#SBATCH --nodes=
#SBATCH --tasks-per-node=1
#SBATCH --cpus-per-task=
#SBATCH --gres=gpu:4
#SBATCH --time=20:00:00
#SBATCH --exclusive

# Experiment grid: each element of `combinations` is one
# "<dataset> <scheduler> <workers>" triple. A Slurm array task runs exactly one
# element, picked by the zero-based SLURM_ARRAY_TASK_ID (dataset-major, then
# scheduler, then worker count).
declare -a combinations=()
##for dataset in 'cifar-10' 'cifar-100' 'imagenet-16'
for dataset in 'imagenet-16'; do
    ##for sched in 'RAND'
    for sched in 'RAND' 'BOHB' 'ASHA' 'HB'; do
        for workers in 1 2 4 8; do
            combinations+=("$dataset $sched $workers")
        done
    done
done

# Outside an array job the id is unset; keep the historical behaviour of
# running the first combination, but say so instead of doing it silently.
task_id=${SLURM_ARRAY_TASK_ID:-0}
if [[ -z ${SLURM_ARRAY_TASK_ID:-} ]]; then
    echo "WARNING: SLURM_ARRAY_TASK_ID is not set; using combination 0" >&2
fi

# Refuse ids that do not address a grid entry: otherwise DATASET/SCHED/WORKERS
# would be empty and the job would burn its allocation on a broken command.
if ! [[ $task_id =~ ^[0-9]+$ ]]; then
    echo "ERROR: SLURM_ARRAY_TASK_ID='${task_id}' is not a non-negative integer" >&2
    exit 1
fi
task_id=$((10#$task_id)) # force base 10 so ids like "08" are not read as octal
if ((task_id >= ${#combinations[@]})); then
    echo "ERROR: SLURM_ARRAY_TASK_ID=${task_id} is out of range (valid: 0-$((${#combinations[@]} - 1)))" >&2
    exit 1
fi

# Split the selected triple on whitespace without exposing it to globbing.
read -r -a parameters <<<"${combinations[task_id]}"

DATASET=${parameters[0]}
SCHED=${parameters[1]}
WORKERS=${parameters[2]}

# Drop every module inherited from the submitting shell before loading the
# software stack this job needs.
ml --force purge

# Abort early if the stack cannot be loaded: continuing would run the
# experiment with whatever python3 happens to be on PATH.
if ! ml Stages/2022  GCC/11.2.0  OpenMPI/4.1.2 PyTorch/1.11-CUDA-11.5 torchvision/0.12.0-CUDA-11.5; then
    echo "ERROR: failed to load the required software modules" >&2
    exit 1
fi

# Virtual environment path is relative to the directory the job runs in.
# Without it the later `ray start` / python3 calls cannot work, so stop here.
if ! source ray_tune_env/bin/activate; then
    echo "ERROR: could not activate virtual environment ray_tune_env" >&2
    exit 1
fi

## SINGLE GPU
# Kept as a single string on purpose: it is word-split when passed to python3
# at the end of the script.
COMMAND="single_gpu_tune.py --scheduler ${SCHED} --num-samples 64 --par-workers ${WORKERS} --dataset ${DATASET} --seed 333 --scale-bs 1"

# Log the exact command line for this array task.
echo "$COMMAND"


sleep 1

# GPUs advertised to Ray per node by the `ray start` calls further down.
num_gpus=1

## Disable Ray Usage Stats
export RAY_USAGE_STATS_DISABLE=1

# Expose GPU ids 0-3 to everything started from this script
# (the job header requests gpu:4).
export CUDA_VISIBLE_DEVICES=0,1,2,3

# Forward the per-task CPU count from the allocation to srun.
export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK}"

## Limit number of max pending trials: four per allocated node
export TUNE_MAX_PENDING_TRIALS_PG=$((4 * $SLURM_NNODES))

####### this part is taken from the ray example slurm script #####
# Trace every command from here on into the job's error log.
set -x

# __doc_head_address_start__

# Expand the Slurm node list into individual hostnames; the first entry is
# used as the head node.
nodes="$(scontrol show hostnames "$SLURM_JOB_NODELIST")"
nodes_array=($nodes)
head_node="${nodes_array[0]}"

# Ask the head node itself which address(es) it has.
head_node_ip="$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)"

# Optional step: a space in the answer means several addresses came back.
# Keep the first one unless it is longer than 16 characters (taken to be
# IPv6), in which case the second one is used.
case "$head_node_ip" in
*" "*)
    IFS=' ' read -ra ADDR <<<"$head_node_ip"
    if ((${#ADDR[0]} <= 16)); then
        head_node_ip=${ADDR[0]}
    else
        head_node_ip=${ADDR[1]}
    fi
    echo "IPV6 address detected. We split the IPV4 address as $head_node_ip"
    ;;
esac
# __doc_head_address_end__

# __doc_head_ray_start__
# Publish "<head ip>:<port>" to child processes through the environment.
port=7563
export ip_head="${head_node_ip}:${port}"
echo "IP Head: $ip_head"

# Options for the Ray head process.
# NOTE(review): the node address is the head hostname with a literal "i"
# appended, not head_node_ip - presumably an alternate network interface name
# on this cluster; confirm before changing.
head_ray_args=(
    --head
    --node-ip-address="${head_node}i"
    --port="$port"
    --num-cpus "${SLURM_CPUS_PER_TASK}"
    --num-gpus $num_gpus
    --block
)

echo "Starting HEAD at $head_node"
# --block keeps `ray start` in the foreground of its job step, so the step is
# sent to the background and the script carries on.
srun --nodes=1 --ntasks=1 -w "$head_node" ray start "${head_ray_args[@]}" &
# __doc_head_ray_end__

# __doc_worker_ray_start__

# optional, though may be useful in certain versions of Ray < 1.0.
sleep 10

# Every allocated node except the head node runs one Ray worker.
worker_num=$((SLURM_JOB_NUM_NODES - 1))

# Options shared by all workers; they join the head through the same
# "<head hostname>i:<port>" address the head was started with.
# NOTE(review): the password looks like Ray's stock default - confirm.
worker_ray_args=(
    --address "${head_node}i:${port}"
    --redis-password='5241590000000000'
    --num-cpus "${SLURM_CPUS_PER_TASK}"
    --num-gpus $num_gpus
    --block
)

# nodes_array[0] is the head, so workers start at index 1.
i=1
while ((i <= worker_num)); do
    node_i="${nodes_array[i]}"
    echo "Starting WORKER $i at $node_i"
    # Blocking `ray start` runs as a background job step, one per node.
    srun --nodes=1 --ntasks=1 -w "$node_i" ray start "${worker_ray_args[@]}" &
    # Stagger worker start-up.
    sleep 5
    i=$((i + 1))
done

echo "Ready"

# Run the experiment from the batch script itself; $COMMAND is deliberately
# unquoted so it splits into the script name and its arguments. This stays the
# last command so its exit status becomes the job's exit status.
python3 -u $COMMAND
