#!/bin/bash

# Launches train.py for a RoBERTa model (ReLU variant, post layer norm) on the
# configured dataset. Any arguments given to this script are appended to the
# train.py command line after the defaults below.
#
# Every value that appears in both the command line and the wandb run name is
# defined exactly once here, so the run name always reflects the flags that
# were actually passed.

# Default parameters
NUM_NODES=20
NUM_LAYERS=2
NUM_EPOCHS=1000
BATCH_SIZE=1000
DATASET_TYPE="erdos_renyi"
FIXED_P=0.16          # forwarded as --fixed_p
RESTRICT_DIAM=9       # forwarded as --restrict_diam
EVAL_DATASET="all"    # forwarded as --eval_dataset
SEED=42               # forwarded as --seed
# two chain max length k is 3^L + 1 (L = number of layers; 3^2 + 1 = 10)
K=10
DATETIME=$(date +"%Y%m%d_%H%M")

# Spaces in the dataset names are replaced with underscores so the run name
# is a single whitespace-free token.
RUN_NAME="roberta_${DATASET_TYPE// /_}_p=${FIXED_P}_restrict_diam=${RESTRICT_DIAM}_eval_${EVAL_DATASET// /_}_n${NUM_NODES}_seed${SEED}_datetime=${DATETIME}"

echo "=== Training RoBERTa Model ==="
echo "Model Type: RoBERTa (Post Layer Norm ReLU)"
echo "Dataset: $DATASET_TYPE"
echo "Nodes: $NUM_NODES"
echo "Layers: $NUM_LAYERS"
echo "Epochs: $NUM_EPOCHS"
echo "Batch Size: $BATCH_SIZE"
echo "=========================="

python train.py --multi_gpu \
    --model_type roberta \
    --dataset "$DATASET_TYPE" --fixed_p "$FIXED_P" \
    --restrict_diam "$RESTRICT_DIAM" \
    --num_layers "$NUM_LAYERS" \
    --num_attention_heads 1 \
    --num_nodes "$NUM_NODES" \
    --num_epochs "$NUM_EPOCHS" \
    --batch_size "$BATCH_SIZE" \
    --eval_dataset "$EVAL_DATASET" \
    --k "$K" \
    --roberta_type relu \
    --layer_norm_type post \
    --hidden_size 512 \
    --learning_rate 1e-4 \
    --optimizer AdamW \
    --weight_decay 1e-4 \
    --num_samples 1_000_000 \
    --on_the_fly \
    --save_every 100_000 \
    --seed "$SEED" \
    --visualize_samples \
    --wandb_run_name "$RUN_NAME" \
    "$@"  # Pass any additional arguments
