#! /usr/bin/env bash

# Make modules in the current working directory importable by the Python job
# this script launches. ${PYTHONPATH:+...} avoids producing a leading ':'
# (an empty PYTHONPATH entry, which Python treats as an extra implicit
# current-directory path) when PYTHONPATH was previously unset or empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)"

DATASET=cifar100
DATADIR=/data1/TL/data

# ===================================

LEVEL=5

# Set CORRUPT, METHOD, BS_ALIGN and QS from the positional arguments.
# Every argument is optional and falls back to its own default, so passing
# only the first one or two no longer leaves BS_ALIGN / QS empty (an empty
# value would otherwise drop out of the python command line and shift the
# following flags).
# Arguments:
#   $1 - corruption name            (default: gaussian_noise)
#   $2 - method: ssl | align | both (default: both)
#   $3 - alignment batch size       (default: 256)
#   $4 - queue size                 (default: 2048)
parse_args() {
	CORRUPT=${1:-gaussian_noise}
	# METHOD=ssl
	# METHOD=align
	METHOD=${2:-both}
	BS_ALIGN=${3:-256}
	QS=${4:-2048}
}

parse_args "$@"

# ===================================

SCALE_EXT=5.0
SCALE_SSH=20.0
LR=0.0008
BS_SSL=256
LAMBDA_A=0.5 # align_loss
LAMBDA_S=1.3 # contrastive_loss
LAMBDA_D=1.0 # w-s distillation loss
DIVERGENCE=all

# COEF multiplies both SCALE_EXT and SCALE_SSH below.
COEF=1.0
NSAMPLE=1000000

# Print the product of two decimal numbers ($1 * $2) on stdout.
# Bash arithmetic is integer-only, so the multiplication is done in awk
# (available even in minimal images that lack bc). Values are passed with -v
# rather than interpolated into the awk program. Up to 12 significant digits
# are printed, and integral results have no trailing ".0" (5.0 * 1.0 -> 5).
fmul() {
	awk -v a="$1" -v b="$2" 'BEGIN { printf "%.12g\n", a * b }'
}

SCALE_EXT=$(fmul "${SCALE_EXT}" "${COEF}")
SCALE_SSH=$(fmul "${SCALE_SSH}" "${COEF}")

# Echo the effective configuration as "NAME: value" lines. ${!_name} reads
# the variable named by _name; the quoting prints each value verbatim, without
# word splitting or glob expansion.
for _name in DATASET CORRUPT METHOD DIVERGENCE LR SCALE_EXT SCALE_SSH \
	BS_SSL BS_ALIGN QS NSAMPLE COEF; do
	printf '%s: %s\n' "${_name}" "${!_name}"
done
unset _name

# ===================================

printf '\n---------------------\n\n'

# Log file for the job; override by exporting LOGFILE.
# NOTE(review): the default name does not include CORRUPT/METHOD/BS_ALIGN/QS,
# so launching this script with different arguments overwrites the same log.
LOGFILE=${LOGFILE:-c100outputlog/v3_00008_05_13.log}

# The log's directory must exist before the redirection below; otherwise bash
# fails the redirection and python is never started.
mkdir -p -- "$(dirname -- "${LOGFILE}")"

# Arguments for slimttt.py, built as an array so each value is passed as
# exactly one argument (an empty value is passed as "" instead of vanishing
# and shifting the following flags).
# NOTE(review): --app points at a machine-specific absolute path.
PY_ARGS=(
	--dataset "${DATASET}"
	--dataroot "${DATADIR}"
	--resume "results/${DATASET}_joint_slim_resnet50"
	--outf "results/${DATASET}_ttt_simclr_joint_slim_resnet50"
	--corruption "${CORRUPT}"
	--level "${LEVEL}"
	--workers 8
	--batch_size "${BS_SSL}"
	--batch_size_align "${BS_ALIGN}"
	--lr "${LR}"
	--scale_ext "${SCALE_EXT}"
	--scale_ssh "${SCALE_SSH}"
	--lambda_a "${LAMBDA_A}"
	--lambda_s "${LAMBDA_S}"
	--lambda_d "${LAMBDA_D}"
	--method "${METHOD}"
	--divergence "${DIVERGENCE}"
	--align_ssh
	--align_ext
	--num_sample "${NSAMPLE}"
	--queue_size "${QS}"
	--app app:/data2/ll/slimTTT/cifar/slim_net/apps/s_resnet50_train_val.yml
)
# Uncomment to also pass --tsne:
# PY_ARGS+=(--tsne)

# Run in the background. stderr is captured together with stdout so that
# tracebacks land in the log instead of the launching terminal.
python -W ignore slimttt.py "${PY_ARGS[@]}" >"${LOGFILE}" 2>&1 &
