#!/bin/bash
# SLURM batch script: runs the `t2i.train_t2i_cc3m` training module
# (experiment "t2i_Vanilla") with `accelerate launch` inside a Singularity
# container on a single 4-GPU node.
#
# Log files are named <job name>_<job id> (%x_%j) so that resubmitting the
# job (e.g. to resume from a later checkpoint) does not overwrite earlier logs.
# NOTE: the slurm_output/ directory must exist before submission; SLURM does
# not create it for --output/--error.
#SBATCH -J t2i_Vanilla
#SBATCH --exclusive
#SBATCH --account=bsc70
#SBATCH --qos=acc_debug
#SBATCH --output=slurm_output/%x_%j.out
#SBATCH --error=slurm_output/%x_%j.err
#SBATCH --nodes=1
# One task per GPU. The script itself runs no srun step; `accelerate launch`
# spawns the per-GPU worker processes on its own.
#SBATCH --ntasks-per-node=4
#SBATCH --time=00:30:00
#SBATCH --gres=gpu:4

echo "START TIME: $(date)"

# Start from a clean module environment and load the container runtime.
# export SINGULARITY_TMPDIR=/dev/shm/  # optional: point Singularity's temp dir at /dev/shm
module purge
module load singularity
# Fail fast with a clear message instead of a later "command not found".
if ! command -v singularity >/dev/null 2>&1; then
  echo "ERROR: 'singularity' not on PATH after 'module load singularity'" >&2
  exit 1
fi

# Distributed-training environment.
# NOTE(review): `accelerate launch` spawns its own workers and sets the
# per-process rank/world-size variables for them; confirm whether these
# exports are actually read by anything in this job.
# Assign separately from `export` so a failing command is not masked (SC2155).
MASTER_ADDR="$(hostname)"
export MASTER_ADDR
export MASTER_PORT=29500           # must be free on this node
export WORLD_SIZE="$SLURM_NTASKS"  # total SLURM tasks (nodes x ntasks-per-node)
export NCCL_DEBUG=INFO             # verbose NCCL logging, useful for diagnosing hangs
# Disable NCCL's direct GPU-to-GPU (P2P) transport. Presumably this works around
# collective hangs seen on this system. Inter-GPU traffic then goes through host
# memory, which can be slower.
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1           # disable NCCL's InfiniBand transport (single-node job)

# Record W&B runs locally; they can be uploaded later with `wandb sync`.
export WANDB_MODE=offline

# Container image the training job runs in.
sif_image="/gpfs/projects/bsc70/heka/singularity/synthetic-data/gm_18_08.sif"

# Checkpoint step to resume from (looked up under --ckpt-dir below).
# Defaults to 250000. Override it at submit time without editing the script, e.g.
#   sbatch --export=ALL,RESUME_STEP=275000 <this script>
RESUME_STEP="${RESUME_STEP:-250000}"

# Launcher options. The process/machine counts are passed explicitly so the run
# does not depend on an accelerate config file or implicit defaults. The process
# count follows the GPUs SLURM allocated, falling back to 4 (--gres=gpu:4).
accelerate_args=(
  --multi_gpu
  --num_machines 1
  --num_processes "${SLURM_GPUS_ON_NODE:-4}"
  --mixed_precision fp16
)

# Arguments forwarded to the t2i.train_t2i_cc3m module. Mixed precision is
# passed to the training script as well as to the launcher above.
train_args=(
  --report-to="wandb"
  --allow-tf32
  --mixed-precision="fp16"
  --seed=0
  --path-type="linear"
  --prediction="v"
  --weighting="uniform"
  --enc-type="dinov2-vit-b"
  --enc-path="/gpfs/projects/bsc70/bsc193242/Models/facebookresearch_dinov2_haste"
  --proj-coeff=0.0
  --attn-coeff=0.0
  --encoder-depth=8
  --output-dir="MMDiT"
  --exp-name="t2i_Vanilla"
  --data-dir="/gpfs/projects/bsc70/bsc131047/data/cc3m"
  --batch-size=256
  --early-stop-point=400000
  --checkpointing-steps=25000
  --max-train-steps=400000
  --sampling-steps=10000
  --num-workers=8
  --resume-step="${RESUME_STEP}"
  --ckpt-dir="/gpfs/projects/bsc70/bsc193242/t2i_models/MMDiT-Vanilla-CFG-CC3M"
)

# Capture the exit status so the end time is always logged and the job still
# reports the training run's status back to SLURM.
status=0
singularity exec --nv "$sif_image" \
  accelerate launch "${accelerate_args[@]}" \
  -m t2i.train_t2i_cc3m "${train_args[@]}" || status=$?

echo "END TIME: $(date) (exit status: ${status})"
exit "$status"