#!/usr/bin/env bash
#
# Launch reconstruction training (se3_transformer.runtime.recon_training)
# on a single node via torch.distributed.run, one process per GPU.
#
# Usage:
#   <script> [EXPDIR] [DATAPATH] [ROTFOLDER] [ROTATION] [CONFIG]
#            [BATCH_SIZE] [KMODE] [ILAYERS] [OLAYERS] [KNNQ]
#            [WEIGHT_DECAY] [LOADCHK] [MINIBATCH] [AMP] [NUM_EPOCHS]
#            [LEARNING_RATE] [NUM_CHAN] [NUM_HEAD]
#
# All arguments are positional and optional. An omitted or empty argument
# falls back to the default below, so pass "" to keep a default while
# overriding a later argument.
set -euo pipefail

# CLI args with defaults ($10 and above require braces).
EXPDIR=${1:-./results/exp0}        # passed as --base_dir
DATAPATH=${2:-/Datasets/ShapeNet}  # passed as --data_path
ROTFOLDER=${3:-/home/rotations}    # passed as --data_rotation_folder
ROTATION=${4:-aligned}             # aligned, so3 or z
CONFIG=${5:-configs/eqat_onet.yaml}
BATCH_SIZE=${6:-4}
KMODE=${7:-15}
ILAYERS=${8:-10}                   # passed as --i_num_layers
OLAYERS=${9:-2}                    # passed as --o_num_layers
KNNQ=${10:-15}
WEIGHT_DECAY=${11:-0.001}
LOADCHK=${12:-trained_model.pth}   # passed as --load_ckpt_path
MINIBATCH=${13:-2}
AMP=${14:-false}
NUM_EPOCHS=${15:-300}
LEARNING_RATE=${16:-0.0002}
NUM_CHAN=${17:-32}
NUM_HEAD=${18:-8}

TASK=recon

# Training arguments, collected in an array so that every value is passed
# as exactly one word and optional flags can be toggled by (un)commenting
# a single line without breaking a backslash-continuation chain.
args=(
  --amp "$AMP"
  --batch_size "$BATCH_SIZE"
  --epochs "$NUM_EPOCHS"
  --lr "$LEARNING_RATE"
  --weight_decay "$WEIGHT_DECAY"
  --use_layer_norm True
  --norm True
  --seed 0
  --task "$TASK"
  --mini_batch "$MINIBATCH"
  --eval_interval 3000
  --num_layers 4
  --base_dir "$EXPDIR"
  --kmode "$KMODE"
  --knnq "$KNNQ"
  --load_ckpt_path "$LOADCHK"
  --visualize_interval 100
  --i_num_layers "$ILAYERS"
  --o_num_layers "$OLAYERS"
  --data_path "$DATAPATH"          # honour argument 2 instead of a fixed path
  --data_rotation "$ROTATION"
  --data_rotation_folder "$ROTFOLDER"
  --iou_number_points 1024
  --config "$CONFIG"
  --num_channels "$NUM_CHAN"
  --num_heads "$NUM_HEAD"
  --val_mini_batch 8
  --accumulate_grad_batches 2
  # Optional flags (disabled):
  # --fixed_percentage True
  # --precompute_bases
)

python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu \
  --module se3_transformer.runtime.recon_training \
  "${args[@]}"

