# This is an example slurm launcher config that should be added to the main config.yaml file under the hydra section. This cannot be run directly.
hydra:
  launcher:
    # Job name, produced by the custom `get_slurm_name` resolver (registered outside this file).
    name: ${get_slurm_name:}
    # See https://hydra.cc/docs/configure_hydra/workdir/
    # One folder per job under the sweep dir; `%j` is the job-id placeholder.
    submitit_folder: ${hydra.sweep.dir}/%j
    nodes: ${nodes} # Number of nodes
    mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # mem_per_gpu (GB) x GPUs per node. This value is *per* node
    gpus_per_node: ${trainer.devices}
    partition: ${partition}
    constraint: ${constraint}
    exclude: ${exclude_nodes:}

    timeout_min: ${timeout_min}
    max_num_timeout: 12 # Max number of requeues after a timeout, excluding pre-emptions
    comment: aswerdlo
    stderr_to_stdout: true

    # Be careful with changing anything below.
    # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
    # see: https://github.com/huggingface/accelerate/issues/1918

    # A single task is started per node; the launcher command below (torchrun)
    # then spawns one worker process per GPU via --nproc_per_node.
    tasks_per_node: 1
    cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
    # Command prefix used in place of the plain python interpreter.
    # - The double quote opened after `bash -c` is intentionally left open and the
    #   line ends with a backslash continuation; `python_suffix` supplies the
    #   closing quote. Presumably the launcher inserts its own arguments in
    #   between -- confirm against the launcher plugin.
    # - Unescaped `$VAR`s sit inside double quotes, so they are expanded by the
    #   outer shell that issues the command; the escaped `\$(hostname ...)` and
    #   `\$SLURM_PROCID` are left for the inner `bash -c` to expand.
    python: |
            bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \

    # Closes the quote opened in `python`, tags the command line with
    # `--dummy-arg <job id>` (the SIGUSR2 trap in `setup` matches on this marker),
    # and backgrounds the command with `&`.
    python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
    # NOTE(review): looks like an sbatch `--signal` spec -- `B:` targets the batch
    # shell only, USR2 is sent 360 s before the time limit. Confirm how the
    # launcher forwards this value.
    signal: 'B:USR2@360'
    # Because the command above is backgrounded, `wait` presumably keeps the batch
    # script alive (and able to run the trap) until it exits -- TODO confirm how
    # the launcher emits these lines. The purpose of the empty first entry is not
    # visible from this file.
    post_srun_commands:
      - ''
      - wait

    # Extra srun arguments: pin srun to the current job id.
    srun_args:
      - '--jobid $SLURM_JOB_ID'

    # Shell snippet for the launcher's `setup` hook. It exports the rendezvous and
    # NCCL environment, writes diagnostics next to the job's stdout file, and
    # installs a SIGUSR2 trap that forwards the signal to the direct children of
    # the launcher processes tagged with `--dummy-arg $SLURM_JOB_ID`.
    # NOTE: do not use the brace form of shell variable expansion in this script;
    # a dollar sign followed by an opening brace is parsed as a config
    # interpolation before the shell ever sees it.
    setup:
      - |
        # Rendezvous endpoint: first host of the allocation, port in 30000-50000 derived from the job id.
        export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
        export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
        export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
        export NCCL_DEBUG=INFO
        export NCCL_NSOCKS_PERTHREAD=4
        export NCCL_SOCKET_NTHREADS=2
        export OMP_NUM_THREADS=2
        export PYTHONUNBUFFERED=1
        # Per-job artifacts go in the directory that holds the job's stdout file.
        export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
        export LOCAL_JOB_FOLDER=$(dirname "$STDOUT_PATH")
        export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
        if [ -n "$SLURM_RESTART_COUNT" ]; then
          export RESTART_COUNT=$SLURM_RESTART_COUNT
        else
          export RESTART_COUNT=0
        fi
        export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"

        mkdir -p "$LOCAL_JOB_FOLDER"
        # The quotes are split around the underscore on purpose: written as one word,
        # the shell would read the underscore as part of the variable name (an unset
        # variable) and drop the local id from the file name.
        printenv > "$LOCAL_JOB_FOLDER/env_$SLURM_LOCALID"_"$RESTART_COUNT.txt"

        echo "ibstatus: $(ibstatus)"
        echo "ibdev2netdev: $(ibdev2netdev)"
        echo "rdma device: $(rdma link)"
        echo "environment: $(env | grep NCCL)"
        echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
        echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
        echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"

        # On SIGUSR2: find the launcher processes of this job (matched by the
        # dummy-arg marker on their command line), send USR2 to each of their
        # direct children, then wait.
        trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
        if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
        if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
        # ps auxww | grep $USER; \
        pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
        echo "Found parent PIDs: $pid"; \
        for p in $pid; do \
          echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
          children=$(pgrep -P $p); \
          echo "Children: $children"; \
          if [ -n "$children" ]; then \
            for child in $children; do \
              ppid=$(ps -o ppid= -p $child | tr -d " ")
              if [ "$ppid" -eq "$p" ]; then
                echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
                kill -USR2 $child &
              else
                echo "Skipping non-direct child process: PID $child with PPID $ppid"
              fi
            done; \
            echo "Sent kill signals to children of $p"; \
          else \
            echo "No children found for $p"; \
          fi; \
        done; \
        wait;' SIGUSR2