# hcoll has to be disabled because it raises errors otherwise. Errors still appear even with it disabled, though (TODO: investigate).
#source wandbkey
#source master_def
# --- Open MPI / CGX environment (exported so child processes inherit it) -----
# Allow mpirun to be launched as the root user (both variables are required).
export OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
# Enable Open MPI's CUDA support.
export OMPI_MCA_opal_cuda_support=true
# CGX compression bucket size (previously tried: 8192).
export CGX_COMPRESSION_BUCKET_SIZE=1024
# Uncomment to make CUDA kernel launches synchronous while debugging.
#export CUDA_LAUNCH_BLOCKING=1

# --- Compression mode identifiers (values for MODE) --------------------------
DUMMY_COMPRESSION=0
UNIFORM=1
ADAPTIVE=2

# --- Experiment parameters ---------------------------------------------------
BITS=5                  # quantization bit width
NODES=4                 # number of MPI ranks (mpirun -np)
ITERS=400000            # training iterations (--num-iter)
SCORE_EVERY_EPOCHS=50   # scoring interval (--score-every)
# Gradient penalty is off; use the commented value to turn it on.
#GRAD_PENALTY="--gradient-penalty 10"
GRAD_PENALTY=""
# Set to "" to drop the --mpi-flag argument.
DIST_RUN="--mpi-flag"
SEED=1234

# Uncomment to debug with anomaly detection (ANOMALY is forwarded to every rank
# via `-x ANOMALY` in MPI_FLAGS).
#export ANOMALY="YES"
# Multi-host variant: explicit host list and TCP restricted to eth0.
#MPI_FLAGS="-v -np $NODES -H localhost:1,host2:1,host3:1,host4:1 -x PATH -x WANDB_API_KEY -x ANOMALY -mca coll ^hcoll --mca btl tcp,self --mca btl_tcp_if_include eth0 --mca pml ob1"

# mpirun options: $NODES ranks; forward PATH, WANDB_API_KEY and ANOMALY to each
# rank; exclude the hcoll collective component; use the ob1 PML over the
# tcp/self BTLs.
MPI_FLAGS="-v -np $NODES -x PATH -x WANDB_API_KEY -x ANOMALY -mca coll ^hcoll --mca btl tcp,self --mca pml ob1"

# Trainer arguments shared by every experiment. The backslash-newlines inside
# the quotes are line continuations, so this is one single-line string.
BASE_COMMAND="python train_extraadam.py --num-iter ${ITERS} --default --model resnet --cuda \
--fid-score --inception-score --quantization-bucket-size ${CGX_COMPRESSION_BUCKET_SIZE} --layernorm --batch-size 1024 \
--score-batch-size 4096 --num-threads=5 --score-every ${SCORE_EVERY_EPOCHS} ${DIST_RUN} ${GRAD_PENALTY} --seed ${SEED} --dataset cifar100 "
#BASE_COMMAND="python train_extraadam.py  --num-iter 500000 --default --model resnet --dist-backend nccl --cuda --fid-score --inception-score --quantization-bucket-size $CGX_COMPRESSION_BUCKET_SIZE --save-gen-samples --layernorm --batch-size 1024 --score-batch-size 4096 --num-threads=5 --score-every 10 --mpi-flag "

# Per-experiment trainer arguments.
NUQ_FLAGS="--nuq --warmup-milestones 0 200 1000 5000 --nuq-method=alq --nuq-every=10000 --quantization-bits ${BITS} "
UNIFORM_FLAGS="--quantization-bits ${BITS} --log-path uniform_${BITS} --dist-backend nccl "
FULL_PREC_FLAGS="--quantization-bits 32 --log-path baseline_cifar100 --dist-backend nccl "
# Braces are required: a bare `$BITS_cifar100` would expand the (unset)
# variable BITS_cifar100 and truncate the log path to "lgreco_3_8_".
LGRECO_FLAGS=" --lgreco --quantization-bits ${BITS} --log-path lgreco_3_8_${BITS}_cifar100 --dist-backend nccl"

# Address passed as --master-addr to the benchmark runs; override it by
# setting BENCH_MASTER_ADDR in the environment.
BENCH_MASTER_ADDR=${BENCH_MASTER_ADDR:-192.168.10.125}
BENCH_FLAGS="--quantization-bits ${BITS} --log-path bench_${BITS} --dist-backend cgx --quantization-bucket-size ${CGX_COMPRESSION_BUCKET_SIZE} --master-addr ${BENCH_MASTER_ADDR} "
# NOTE(review): this baseline uses the same --log-path as BENCH_FLAGS
# (bench_${BITS}); confirm that sharing the log directory is intended.
BENCH_BL_FLAGS="--quantization-bits 32 --log-path bench_${BITS} --dist-backend nccl --master-addr ${BENCH_MASTER_ADDR} "

# Full launch commands. They are plain strings and are run through unquoted
# expansion (word splitting), so no argument may contain spaces.
FULL_CMD="mpirun $MPI_FLAGS -- $BASE_COMMAND $FULL_PREC_FLAGS"
UNIFORM_CMD="mpirun $MPI_FLAGS -- $BASE_COMMAND $UNIFORM_FLAGS"
NUQ_CMD="mpirun $MPI_FLAGS -- $BASE_COMMAND $NUQ_FLAGS"
LGRECO_CMD="mpirun $MPI_FLAGS -- $BASE_COMMAND $LGRECO_FLAGS"
BENCH_CMD="mpirun $MPI_FLAGS -- $BASE_COMMAND $BENCH_FLAGS"
BENCH_BL_CMD="mpirun $MPI_FLAGS -- $BASE_COMMAND $BENCH_BL_FLAGS"
# Single-process run (no mpirun) with the baseline bench flags.
SINGLE_CMD="$BASE_COMMAND $BENCH_BL_FLAGS"

# --- Run selection -----------------------------------------------------------
# MODE takes one of DUMMY_COMPRESSION (0), UNIFORM (1) or ADAPTIVE (2).
# To switch experiments, change MODE and the command invoked at the bottom:
#   uniform     : MODE=$UNIFORM  with  $UNIFORM_CMD
#   full prec.  : $FULL_CMD
#   lgreco      : $LGRECO_CMD
#   adaptive    : MODE=$ADAPTIVE with  $NUQ_CMD
#   benchmarks  : $BENCH_CMD, $BENCH_BL_CMD, or $SINGLE_CMD (no mpirun)
MODE=${UNIFORM}

# While the exports below stay commented out, MODE is not read anywhere else
# in this script. Presumably the CGX backend reads these variables; confirm
# before re-enabling.
#export CGX_COMPRESSION_MODE=$MODE
#export CGX_QUANTIZATION_BITS=$BITS
#printf 'Mode %s (0=dummy,1=uniform,2=adaptive)\nbits %s\nCommand: %s\n' "$MODE" "$BITS" "$LGRECO_CMD"

# The command is stored as a single string; the unquoted expansion is
# deliberate so it word-splits into the program and its arguments.
# shellcheck disable=SC2086
${LGRECO_CMD}
