#!/bin/sh
# Description: Script to train a TNP classifier on the synthetic data.
#
# Usage: sh <this-script> [extra flags...]
#   Extra arguments are appended after the defaults below and forwarded
#   verbatim to train_causal_classify.py. If that script parses flags with
#   argparse (assumed, not visible here), a repeated flag keeps its last
#   value, so e.g. `--seed=1` overrides the default `--seed=0`.
#
# Environment:
#   CUDA_VISIBLE_DEVICES  GPU(s) exposed to the run; defaults to 0 when
#                         unset or empty.
#
# NOTE: train_causal_classify.py is resolved relative to the current working
# directory, so invoke this from the directory that contains it.
set -eu

# Respect a caller-provided GPU selection instead of clobbering it.
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
export CUDA_VISIBLE_DEVICES

# exec replaces this shell with Python, so signals (Ctrl-C, a scheduler's
# SIGTERM) reach the training process directly and its exit status becomes
# the script's. "$@" closes the argument list; no dangling `\` is left to
# swallow a following line.
# NOTE(review): --num_nodes presumably must match the "20var" in
# --data_file — confirm before changing either one.
exec python3 train_causal_classify.py \
    --learning_rate=2e-4 \
    --batch_size=32 \
    --max_epochs=2 \
    --run_name="test" \
    --data_file="gplvm_20var_er20" \
    --num_workers=12 \
    --num_layers_encoder=4 \
    --num_layers_decoder=4 \
    --dim_model=128 \
    --dim_feedforward=256 \
    --decoder="probabilistic" \
    --seed=0 \
    --lr_warmup_ratio=0.1 \
    --num_nodes=20 \
    --nhead=8 \
    --n_perm_samples=25 \
    --sinkhorn_iter=1000 \
    "$@"
