#! /usr/bin/env bash

# Make the current working directory importable by slimttt.py.
# The ':' separator is only inserted when PYTHONPATH already has a value, so an
# unset/empty PYTHONPATH does not yield a leading empty path component.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}"

# Dataset name (also used below to build the --resume/--outf result paths) and
# the dataset root directory. DATADIR may be overridden from the environment;
# it defaults to the original hard-coded location.
DATASET=cifar10
DATADIR=${DATADIR:-/data1/TL/data}

# ===================================

LEVEL=5

#######################################
# Resolve the run configuration from positional arguments.
# Globals:   CORRUPT, METHOD, NSAMPLE (written)
# Arguments: $1 - corruption name (default: gaussian_noise)
#            $2 - method, e.g. ssl / align / both (default: both)
#            $3 - value for --num_sample (default: 100000)
# Each argument falls back to its own default. Previously, passing exactly two
# arguments left NSAMPLE empty (malformed --num_sample), and passing a single
# argument silently ignored it.
#######################################
parse_args() {
	CORRUPT=${1:-gaussian_noise}
	METHOD=${2:-both}
	NSAMPLE=${3:-100000}
}

parse_args "$@"

# ===================================

# Optimisation / loss hyperparameters, forwarded verbatim to slimttt.py.
SCALE_EXT=0.1
# SCALE_EXT=0.01
SCALE_SSH=0.2
LAMBDA_A=0.5 # align_loss
LAMBDA_S=1.2 # contrastive_loss
LAMBDA_D=1.0 # w-s distillation loss
LR=0.0005
BS_SSL=256
BS_ALIGN=256
QS=1536
DIVERGENCE=all

# Print the effective configuration before launching. Expansions are quoted so
# user-supplied values (CORRUPT, METHOD, NSAMPLE) are shown verbatim rather
# than being word-split or glob-expanded.
printf '%s\n' "DATASET: ${DATASET}"
printf '%s\n' "CORRUPT: ${CORRUPT}"
printf '%s\n' "METHOD: ${METHOD}"
printf '%s\n' "DIVERGENCE: ${DIVERGENCE}"
printf '%s\n' "LR: ${LR}"
printf '%s\n' "SCALE_EXT: ${SCALE_EXT}"
printf '%s\n' "SCALE_SSH: ${SCALE_SSH}"
printf '%s\n' "BS_SSL: ${BS_SSL}"
printf '%s\n' "NSAMPLE: ${NSAMPLE}"

# ===================================

printf '\n---------------------\n\n'

# The log is written under outputlog/; create it up front, otherwise the
# redirection below fails and slimttt.py is never started.
mkdir -p outputlog

# NOTE(review): the log name appears to hand-encode LR / LAMBDA_A / LAMBDA_S
# ("00005_05_12"); it will not follow changes to those variables — confirm and
# keep in sync.
log_file=outputlog/v3_cifar10_00005_05_12.log

# Arguments are collected in an array so each value is passed to Python as
# exactly one argument, regardless of spaces or glob characters.
ttt_args=(
	--dataroot "${DATADIR}"
	--resume "results/${DATASET}_joint_slim_resnet50"
	--outf "results/${DATASET}_ttt_simclr_slim_joint_resnet50"
	--corruption "${CORRUPT}"
	--level "${LEVEL}"
	--workers 8
	--batch_size "${BS_SSL}"
	--batch_size_align "${BS_ALIGN}"
	--queue_size "${QS}"
	--lr "${LR}"
	--scale_ext "${SCALE_EXT}"
	--scale_ssh "${SCALE_SSH}"
	--lambda_a "${LAMBDA_A}"
	--lambda_s "${LAMBDA_S}"
	--lambda_d "${LAMBDA_D}"
	--method "${METHOD}"
	--divergence "${DIVERGENCE}"
	--align_ssh
	--align_ext
	--num_sample "${NSAMPLE}"
	--app "app:/data2/ll/slimTTT/cifar/slim_net/apps/s_resnet50_train_val.yml"
)

# Run in the background with stdout captured in the log; the script itself
# returns immediately, so report the PID for later monitoring.
python -W ignore slimttt.py "${ttt_args[@]}" > "${log_file}" &
printf 'Started slimttt.py (PID %s), stdout -> %s\n' "$!" "${log_file}"
# Disabled variant from an earlier run: add --tsne and redirect to
# outputlog/ttt++_ensemblenew_tsne.log instead.