#!/bin/bash
#SBATCH --job-name=sinkhorn_l96
#SBATCH --output=logs/job%j.log
#SBATCH --error=logs/job%j.err
#SBATCH --time=22:00:00
#SBATCH --partition=A100
#SBATCH --gpus=1
#SBATCH --chdir=/home/ids/silva-21/ot4dynsys/neural_operators_for_chaos

# Diagnostic / debugging switches inherited by the training process.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 and the verbose NCCL / torch.distributed
# settings are normally debugging aids — confirm they are intended for a
# full-length (22h) run.
export \
  NCCL_DEBUG=info \
  NCCL_P2P_DISABLE=1 \
  CUDA_LAUNCH_BLOCKING=1 \
  TORCH_DISTRIBUTED_DEBUG=DETAIL \
  PYKEOPS_VERBOSE=1

# Drop any cached PyKeOps build artifacts so that binaries compiled for a
# different GPU architecture are not reused on this node.
rm -rf ~/.cache/keops*

# Experiment name: first positional argument, defaulting to "l96_sinkhorn".
# NOTE(review): EXP_NAME is not referenced again in the visible part of this
# script and is not exported — confirm whether it is still needed.
EXP_NAME=${1:-l96_sinkhorn}

# Draw a random port from the dynamic range 49152-65535 and redraw for as
# long as a local socket is already using it.
#
# Only the local-address column ($4 of `netstat -atun`) is inspected, and the
# match is anchored to ":<port>" at the end of that field. A plain substring
# search over the whole line would also hit queue sizes, foreign ports or
# address fragments and reject ports that are actually free.
while
  port=$(shuf -n 1 -i 49152-65535)
  netstat -atun \
    | awk -v p="$port" '$4 ~ (":" p "$") { hit = 1 } END { exit !hit }'
do
  continue
done

# Record the selected port in the job log.
echo "$port"
# Log the allocated node list, then take the first hostname that
# `scontrol show hostnames` expands from it and export it as MASTER_ADDR
# for child processes.
# Expansions are quoted: SLURM node lists commonly look like "node[01-04]",
# which an unquoted expansion would expose to word splitting and globbing.
echo "NODELIST=${SLURM_NODELIST}"
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR=${MASTER_ADDR}"

# Launch scripts/main.py through the torch.distributed launcher with a single
# process on this node, using the rendezvous port held in $port.
#
# Arguments are kept in arrays rather than one long backslash-continued
# command: options can then be commented out or appended without leaving a
# dangling "\" or gluing two tokens together.
# NOTE(review): torch.distributed.launch is deprecated upstream in favour of
# torchrun; switching changes how the script receives its local rank, so it
# is left as-is here.

# Options consumed by the launcher itself.
launcher_args=(
  --nproc_per_node=1
  "--master_port=${port}"
)

# Options forwarded to scripts/main.py.
train_args=(
  --l96
  --batch_size 25
  --modes 28
  --width 64
  --x_len 100
  --with_geomloss_kd 0
  --with_geomloss 1
  --blur 0.02
  --lambda_geomloss 3
  --noisy_scale 0.3
  --prefix "state_sinkhorn_dim3_mlp_1gpu_A100"
  --train_operator
  --wandb
  --loss_mode learnable_sinkhorn
  --wgan_critic_steps 5
  --summary_clip 0.1
  --summary_dim 3
  --summary_mode statewise
  # --state_dim 60
)

python -m torch.distributed.launch "${launcher_args[@]}" \
  scripts/main.py "${train_args[@]}"
