#!/bin/bash
#SBATCH -J t2i_train
#SBATCH --account=bsc70
#SBATCH --qos=acc_debug
#SBATCH --output=slurm_output/out.txt
#SBATCH --error=slurm_output/err.txt
#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=80
#SBATCH --time=02:00:00
#SBATCH --exclusive

echo "START TIME: $(date)"

# ---------------------------
# Git commit info (optional)
# ---------------------------
#######################################
# Print the commit hash recorded in a file.
# Arguments: $1 - path of the hash file (default: commit_hash.txt)
# Outputs:   the file contents on stdout, or "unknown" when the file is
#            missing, unreadable or empty
# Returns:   0 always (the commit hash is informational only)
#######################################
read_commit_hash() {
    local hash_file=${1:-commit_hash.txt}
    local hash=""
    if [[ -f "$hash_file" && -r "$hash_file" ]]; then
        hash=$(<"$hash_file")
    fi
    # An empty file must not produce an empty hash downstream.
    printf '%s\n' "${hash:-unknown}"
}

# Assign first, export second: `export VAR=$(cmd)` would hide a failure of cmd.
GIT_COMMIT_SHORT=$(read_commit_hash commit_hash.txt)
export GIT_COMMIT_SHORT
echo "GIT_COMMIT_SHORT: ${GIT_COMMIT_SHORT}"

# ---------------------------
# Distributed setup
# ---------------------------
GPUS_PER_NODE=4   # keep in sync with "#SBATCH --gres=gpu:4"
# Default to a single node so an unset SLURM_NNODES can never make the
# arithmetic below evaluate to 0 processes.
NNODES=${SLURM_NNODES:-1}
NUM_PROCESSES=$((NNODES * GPUS_PER_NODE))

# Pick first node as master
MASTER_HOST=""
if [[ -n "${SLURM_JOB_NODELIST:-}" ]]; then
    MASTER_HOST=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
fi
if [[ -z "$MASTER_HOST" ]]; then
    # No node list (or scontrol failed): fall back to the host running this script.
    MASTER_HOST=${HOSTNAME:-localhost}
fi

# Convert to IP. The key is quoted and always non-empty: `getent hosts` with
# no key would list every known host. Only the first address is kept.
MASTER_ADDR=$(getent hosts "$MASTER_HOST" | awk 'NR == 1 { print $1 }')
if [[ -z "$MASTER_ADDR" ]]; then
    echo "WARNING: could not resolve '${MASTER_HOST}' to an IP; using the hostname as MASTER_ADDR" >&2
    MASTER_ADDR=$MASTER_HOST
fi
export MASTER_ADDR
export MASTER_PORT=29500

# NCCL / PyTorch distributed
# Both spellings of the async-error-handling switch are exported; which one is
# honoured presumably depends on the PyTorch version in the container
# (TODO confirm and drop the unused one).
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_ASYNC_ERROR_HANDLING=1
export TORCH_DISTRIBUTED_DEBUG=INFO

# Verbose NCCL logging. Every rank on every node writes a log, so the file
# name carries NCCL's %h (hostname) and %p (PID) placeholders to give each
# process its own file instead of all of them sharing a single one.
mkdir -p slurm_output
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export NCCL_DEBUG_FILE=slurm_output/nccl_debug.%h.%p.txt

# NCCL transport selection
export NCCL_SOCKET_IFNAME=ib0    # or eth0 depending on your cluster
export NCCL_IB_DISABLE=0
export NCCL_P2P_DISABLE=1

# WandB offline
export WANDB_MODE=offline

# Reset the module environment, then load the container runtime.
module purge
module load singularity

# Container image the training command is executed in.
SINGULARITY_IMAGE_PATH="/gpfs/projects/bsc70/heka/singularity/synthetic-data/gm_18_08.sif"

# Log the distributed settings as NAME=value lines, followed by a separator.
for setting in MASTER_ADDR MASTER_PORT SLURM_NODEID SLURM_NNODES NUM_PROCESSES; do
    printf '%s=%s\n' "$setting" "${!setting}"
done
printf '%s\n' "================================================"

# ---------------------------
# Launch training
# ---------------------------
# Options for `accelerate launch` that are identical on every node.
launcher_args=(
    --num_processes "$NUM_PROCESSES"
    --num_machines "$SLURM_NNODES"
    --main_process_ip "$MASTER_ADDR"
    --main_process_port "$MASTER_PORT"
    --rdzv_backend c10d
    --max_restarts 0
    --mixed_precision fp16
)

# Arguments forwarded to the training module (t2i.HASTE.train_t2i).
train_args=(
    --report-to="wandb"
    --allow-tf32
    --seed=0
    --path-type="linear"
    --prediction="v"
    --weighting="uniform"
    --enc-type="dinov2-vit-b"
    --enc-path="/gpfs/projects/bsc70/bsc193242/Models/facebookresearch_dinov2_haste"
    --proj-coeff=0.5
    --attn-coeff=0.5
    --encoder-depth=8
    --output-dir="MMDiT"
    --exp-name="t2i_haste"
    --data-dir="/gpfs/projects/bsc70/bsc193242/Data/coco256_features"
    --early-stop-point=150000
    --checkpointing-steps=25000
    --max-train-steps=150000
    --gradient-accumulation-steps=2
)

# --machine_rank has to differ per node, so $SLURM_NODEID is expanded by the
# `sh -c` that runs inside each srun task (note the single quotes). Expanding
# it here, in the batch script, would hand the same value to every node.
# All remaining arguments reach the inner shell untouched through "$@".
# SRUN_ARGS is optional and deliberately left unquoted so that it can carry
# several space-separated srun options.
# shellcheck disable=SC2086
srun ${SRUN_ARGS:-} singularity exec -B /gpfs/projects/bsc70 --nv "$SINGULARITY_IMAGE_PATH" \
    sh -c 'exec accelerate launch --machine_rank "$SLURM_NODEID" "$@"' accelerate_launch \
    "${launcher_args[@]}" \
    -m t2i.HASTE.train_t2i \
    "${train_args[@]}"
