#!/bin/bash

# Launch train.py for the disentangled-transformer model with the defaults
# below. Any extra command-line arguments given to this script are appended
# verbatim to the train.py invocation (so later flags can override defaults).

# Fail fast: abort on command errors, on use of unset variables (an unset
# variable previously left the run name silently malformed), and on failures
# inside pipelines.
set -euo pipefail

# Default parameters
NUM_NODES=64
NUM_EPOCHS=1000
BATCH_SIZE=1000
# Training / evaluation dataset selections. Each may be a space-separated
# list; every word is passed to train.py as a separate argument.
DATASET_TYPE="erdos_renyi"
EVAL_DATASET="all"
# Passed through to train.py as the literal string "None"; how train.py
# interprets it is not visible here.
RESTRICT_DIAM=None
SEED=42
DATETIME=$(date +"%Y%m%d_%H%M")

#######################################
# Join all arguments into one string separated by "_".
# Arguments: words to join
# Outputs:   the joined string on stdout (no trailing newline)
#######################################
join_by_underscore() {
  local IFS=_
  printf '%s' "$*"
}

# Split the space-separated lists into arrays without glob expansion, so each
# word becomes its own argument to train.py.
read -r -a dataset_args <<< "$DATASET_TYPE"
read -r -a eval_dataset_args <<< "$EVAL_DATASET"

# Underscore-joined tags used to build the wandb run name.
dataset_tag=$(join_by_underscore "${dataset_args[@]}")
eval_tag=$(join_by_underscore "${eval_dataset_args[@]}")

echo "=== Training Disentangled Transformer Model ==="
echo "Model Type: Disentangled Transformer (Post Layer Norm ReLU)"
echo "Dataset: $DATASET_TYPE"
echo "Eval Dataset: $EVAL_DATASET"
echo "Nodes: $NUM_NODES"
echo "Epochs: $NUM_EPOCHS"
echo "Batch Size: $BATCH_SIZE"
echo "=========================="

# The maximum two-chain length is k = 3^L + 1. With L = 3 this gives
# k = 3^3 + 1 = 28, matching --k below.
# NOTE(review): assumes --heads lists one head count per layer, so the three
# entries mean L = 3. Keep --k and --heads in sync. TODO confirm in train.py.
python train.py \
    --model_type disentangled_transformer \
    --dataset "${dataset_args[@]}" \
    --eval_dataset "${eval_dataset_args[@]}" \
    --k 28 \
    --num_nodes "$NUM_NODES" \
    --num_epochs "$NUM_EPOCHS" \
    --batch_size "$BATCH_SIZE" \
    --heads 1 1 1 \
    --init_type randn \
    --readout_type sum \
    --disentangled_final_activation phi_logits \
    --learning_rate 1e-3 \
    --optimizer AdamW \
    --weight_decay 1e-6 \
    --criterion_type bce \
    --num_samples 1_000_000 \
    --eval_before_train \
    --on_the_fly \
    --save_every 100_000 \
    --seed "$SEED" \
    --restrict_diam "$RESTRICT_DIAM" \
    --wandb_run_name "disentangled_transformer_${dataset_tag}_restrict_diam=${RESTRICT_DIAM}_eval_${eval_tag}_n${NUM_NODES}_seed${SEED}_datetime=${DATETIME}" \
    "$@"  # Pass any additional arguments
