#!/usr/bin/env bash

# Strict mode: abort on unhandled command failures, on use of unset variables,
# and when any stage of a pipeline fails.
set -euo pipefail

# Positional CLI arguments. Each one falls back to its default when it is
# unset OR empty (`${N:-default}`), so pass "" to keep the default for an
# argument while still supplying later ones. The trailing comment on each line
# names the recon_inference flag the value is forwarded to.
EXPDIR=${1:-./results/exp0}            # --base_dir
DATAPATH=${2:-/Datasets/ShapeNet}      # --data_path
ROTFOLDER=${3:-/home/rotations}        # --data_rotation_folder
ROTATION=${4:-aligned}                 # --data_rotation: aligned, so3 or z
CONFIG=${5:-configs/eqat_onet.yaml}    # --config
BATCH_SIZE=${6:-4}                     # --batch_size
KMODE=${7:-15}                         # --kmode
ILAYERS=${8:-10}                       # --i_num_layers
OLAYERS=${9:-2}                        # --o_num_layers
KNNQ=${10:-15}                         # --knnq
WEIGHT_DECAY=${11:-0.001}              # --weight_decay
LOADCHK=${12:-trained_model.pth}       # --load_ckpt_path
MINIBATCH=${13:-2}                     # --mini_batch
AMP=${14:-false}                       # --amp
NUM_EPOCHS=${15:-300}                  # --epochs
LEARNING_RATE=${16:-0.0002}            # --lr
NUMCHAN=${17:-32}                      # --num_channels
NUMHEADS=${18:-8}                      # --num_heads
VISDIR=${19:-visualization_eval}       # --vis_dir
LOGDIR=${20:-tensorboard_eval}         # --log_dir

# Fixed task selector; not configurable from the command line.
TASK=recon

# Launch reconstruction inference with torchrun on a single node, with one
# worker process per available GPU (--nproc_per_node=gpu).
#
# Arguments are held in arrays so every value reaches Python as exactly one
# argv word: all expansions are quoted (no word-splitting or globbing), and an
# optional flag is toggled by (un)commenting a single array element instead of
# relying on a trailing backslash that continues into a comment line.
LAUNCHER_ARGS=(
  --nnodes=1
  --nproc_per_node=gpu
  # Run the target as a Python module (like `python -m`) rather than a file.
  --module se3_transformer.runtime.recon_inference
)

# Everything below is forwarded verbatim to recon_inference.
RECON_ARGS=(
  --amp "$AMP"
  --batch_size "$BATCH_SIZE"
  --epochs "$NUM_EPOCHS"
  --lr "$LEARNING_RATE"
  --weight_decay "$WEIGHT_DECAY"
  --use_layer_norm True
  --norm True
  --seed 0
  --task "$TASK"
  --mini_batch "$MINIBATCH"
  --eval_interval 3000
  --num_layers 4
  --base_dir "$EXPDIR"
  --kmode "$KMODE"
  --load_ckpt_path "$LOADCHK"
  --visualize_interval 100
  --i_num_layers "$ILAYERS"
  --o_num_layers "$OLAYERS"
  --vis_dir "$VISDIR"
  --log_dir "$LOGDIR"
  --iou_number_points 100000
  --val_mini_batch 10
  --knnq "$KNNQ"
  --num_channels "$NUMCHAN"
  --num_heads "$NUMHEADS"
  --data_rotation "$ROTATION"
  --data_rotation_folder "$ROTFOLDER"
  --config "$CONFIG"
  --data_path "$DATAPATH"
  --accumulate_grad_batches 2
  # --precompute_bases
)

python -m torch.distributed.run "${LAUNCHER_ARGS[@]}" "${RECON_ARGS[@]}"