#!/bin/bash
# Slurm batch job: a single-GPU run of main.py with the Hydra overrides listed
# at the bottom of this file. Submit with `sbatch <this file>` from the
# directory that contains main.py (it is invoked by a relative path).
#SBATCH --job-name=bash          # "bash" makes Lightning treat the job as interactive and skip its SLURM auto-configuration
#SBATCH --nodes=1                # single node
#SBATCH --ntasks-per-node=1      # one task; NOTE(review): Lightning's SLURM DDP expects this to equal GPUs per node — revisit if --gres grows
#SBATCH --cpus-per-task=1        # CPU cores per task; also used as OMP_NUM_THREADS below
#SBATCH --mem=8GB                # memory per node
#SBATCH --time 5:59:59           # wall-clock limit 5h59m59s (Slurm format [D-]HH:MM:SS)
#SBATCH --gres=gpu:a100:1        # one A100 GPU

#########################
####### Configs #########
#########################
CONDA_ENV_NAME=molmdgm
CONDA_HOME=$HOME/miniconda3
# NOTE(review): WORKDIR is not referenced anywhere else in this script.
WORKDIR=$(pwd)

#########################
####### Env loader ######
#########################
# Print an error to stderr and abort the job with status 1.
die() { printf '[%s] ERROR: %s\n' "$0" "$*" >&2; exit 1; }

# Fail fast if the environment cannot be set up; otherwise main.py would run
# with the wrong interpreter / CUDA toolkit and waste the GPU allocation.
# Explicit checks are used instead of `set -eu` because conda's shell hooks
# may reference unset variables.
# shellcheck source=/dev/null
source "$CONDA_HOME/etc/profile.d/conda.sh" \
  || die "cannot source $CONDA_HOME/etc/profile.d/conda.sh"
conda activate "${CONDA_ENV_NAME}" \
  || die "cannot activate conda env '${CONDA_ENV_NAME}'"
module load cuda/12.4.0 \
  || die "cannot load module cuda/12.4.0"


# NOTE: under sbatch, $0 is typically Slurm's spooled copy of this script.
dt=$(date '+%d/%m/%Y-%H:%M:%S')
echo "[$0] >>> Starttime => ${dt}"

#########################
####### Routine #########
#########################
# One OpenMP thread per CPU allocated to this task (1 when run outside Slurm).
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-1}"
# Ask Hydra for the complete stack trace on errors.
export HYDRA_FULL_ERROR=1

ROOT_DIR="/home/mila/s/stephen.lu/scratch/mols/syncogen"
DATA_DIR="$ROOT_DIR/data/10mols"

# Hydra command-line overrides for main.py. Path expansions are quoted so a
# path containing whitespace stays a single argument.
python main.py \
    sampling.steps=1000 \
    model.length=5 \
    trainer.precision="16-mixed" \
    loader.batch_size=128 \
    loader.eval_batch_size=2 \
    trainer.val_check_interval=null \
    trainer.check_val_every_n_epoch=null \
    data.cache_dir="$DATA_DIR" \
    spatial.overfit=True \
    spatial.random_conformer=True \
    spatial.sample_conformer=False \
    model.use_global_features=True \
    paths.root_dir="$ROOT_DIR"
exit_status=$?

# Record when the run ended and how, then hand python's exit status back to
# Slurm so failed runs are reported as failed jobs.
echo "[$0] >>> Endtime => $(date '+%d/%m/%Y-%H:%M:%S') (exit status ${exit_status})"
exit "${exit_status}"
