#!/bin/bash
# Slurm batch script: unpacks a dataset tarball into a local directory and then
# launches main_swav.py with 4 processes through torch.distributed.launch.
#
# Resource request read by sbatch. Keep these directives ahead of the first
# executable command, otherwise sbatch stops parsing them.
#SBATCH --nodes=1
#SBATCH --gpus=4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=6
#SBATCH --job-name=swav_200ep_bs256_pretrain
#SBATCH --time=72:00:00
#SBATCH --mem=150G

# Abort on the first failed command, unset variable, or failed pipeline stage,
# so training never starts against a missing or half-extracted dataset.
set -euo pipefail

# Directory the dataset tarball is unpacked into.
# Deliberately not named TMPDIR: that variable is the conventional temp-file
# location consulted by mktemp, Python's tempfile and similar tools, so
# overriding it can redirect child processes' scratch files into the data dir.
DATA_ROOT='/home/ubuntu/data_balanced_user'
DATASET_TAR='/home/ubuntu/tars/user_stratified.tar'
# Location of the dataset directory inside the archive, relative to its root.
ARCHIVE_SUBDIR='scratch/shared/beegfs/yuki/fast/yfcc'
DATASET_PATH="${DATA_ROOT}/train"

# Stage the dataset only when it is not already in place. Re-running the
# extraction on top of an existing train/ directory would make `mv` place the
# freshly extracted folder *inside* train/ instead of replacing it.
if [[ ! -d "${DATASET_PATH}" ]]; then
  mkdir -p -- "${DATA_ROOT}"
  tar xf "${DATASET_TAR}" -C "${DATA_ROOT}"
  # Single rename from the deep archive path straight to train/.
  mv -- "${DATA_ROOT}/${ARCHIVE_SUBDIR}" "${DATASET_PATH}"
fi

# Output directory passed to the trainer as --dump_path.
EXPERIMENT_PATH="/home/ubuntu/experiments/swav/balanced_user-BS256-default"
mkdir -p -- "${EXPERIMENT_PATH}"



# Launch SwAV pre-training: one launcher process that spawns 4 workers.
#
# Optional srun wrapper, currently disabled. To use it, prepend these words to
# the command executed at the bottom of this section:
#   srun --output="${EXPERIMENT_PATH}/%j.out" --error="${EXPERIMENT_PATH}/%j.err" --label
launcher=(
  python -m torch.distributed.launch --nproc_per_node=4 main_swav.py
)

# Arguments are kept in an array so that path expansions stay quoted and each
# group can carry its own comment.
train_args=(
  # Data loading
  --data_path "${DATASET_PATH}"
  --workers 6

  # Crop settings (two values per option)
  --nmb_crops 2 6
  --size_crops 224 96
  --min_scale_crops 0.14 0.05
  --max_scale_crops 1. 0.14

  # SwAV-specific settings
  --crops_for_assign 0 1
  --temperature 0.1
  --epsilon 0.05
  --sinkhorn_iterations 3
  --feat_dim 128
  --nmb_prototypes 3000
  --queue_length 3840
  --epoch_queue_starts 15

  # Optimisation schedule.
  # NOTE(review): batch_size is presumably per process (64 x 4 = 256, matching
  # "bs256" in the job name) -- confirm against main_swav.py.
  --epochs 200
  --batch_size 64
  --base_lr 0.6
  --final_lr 0.0006
  --freeze_prototypes_niters 5005
  --wd 0.000001
  --warmup_epochs 0

  # Model / runtime
  --arch resnet50
  --use_fp16 true
  --sync_bn pytorch

  # Output location
  --dump_path "${EXPERIMENT_PATH}"
)

"${launcher[@]}" "${train_args[@]}"
