#!/bin/bash

# The SBATCH directives must appear before any executable line in this script.

#SBATCH --qos high2         # QOS (priority).
#SBATCH -N 1               # Number of nodes requested.
#SBATCH -n 1               # Number of tasks (i.e. processes).
#SBATCH --cpus-per-task=1  # Number of cores per task.
#SBATCH --gres=gpu:4       # Number of GPUs.
#SBATCH -t 60-00:00          # Time requested (D-HH:MM).
#SBATCH --nodelist=em5    # Pins the job to one machine; use ##SBATCH to disable.

# Have Slurm cd to this directory before running the script (use ##SBATCH to disable).
# You can also just run the script from the directory you want to be in.
#SBATCH -D /home/test_time_training/ttt_mae_v1

# Uncomment to control the output files. By default stdout and stderr go to
# the same place, but if you use both commands below they'll be split up.
# %N is the hostname (if used, will create output(s) per node).
# %j is jobid.
##SBATCH -o slurm.%N.%j.out    # STDOUT
##SBATCH -e slurm.%N.%j.err    # STDERR

# Set up the Python environment and working directory.
# Both steps are checked explicitly: if either fails, training would otherwise
# run with the wrong interpreter or from the wrong directory.
# NOTE(review): a ~/.bashrc that returns early for non-interactive shells will
# leave `conda activate` unconfigured here — confirm on the cluster.
source ~/.bashrc

# Print a message to stderr and abort the job with a non-zero status.
die() {
        printf 'error: %s\n' "$*" >&2
        exit 1
}

conda activate taming || die "conda activate taming failed"
cd /home/test_time_training/ttt_mae_v1 || die "cannot cd to /home/test_time_training/ttt_mae_v1"

# Print some info for context.
nvidia-smi
# Python will buffer output of your script unless you set this.
# If you're not using python, figure out how to turn off output
# buffering when stdout is a file, or else when watching your output
# script you'll only get updated every several lines printed.
export PYTHONUNBUFFERED=1
DATA_PATH='/scratch/data/yearbook_faces/'

PRETRAIN_CHKPT='/home/test_time_training/ttt_mae_v1/models/imagenet/vit_head_2_layers/mae_pretrain_vit_large_full.pth'
# PRETRAIN_CHKPT='/home/test_time_training/ttt_mae_v1/models/imagenet/visualize/mae_visualize_vit_large.pth'
OUTPUT_DIR='/home/test_time_training/ttt_mae_v1/models/faces/vit_head/'

# Class labels to train (space-separated) and the GPU index exported as
# CUDA_VISIBLE_DEVICES. Defaults: label 3 on GPU 4. Override through the
# environment, e.g. `TRAIN_LABELS="1 3" GPU_ID=0 sbatch <script>`
# (sbatch forwards the submitting environment unless --export says otherwise).
# NOTE(review): --gres=gpu:4 reserves four GPUs, but only GPU_ID is used and it
# replaces whatever CUDA_VISIBLE_DEVICES Slurm assigned — confirm GPU 4 belongs
# to this job's allocation.
read -r -a train_labels <<< "${TRAIN_LABELS:-3}"
gpu_id="${GPU_ID:-4}"

# Labels whose training run exited non-zero.
failed_labels=()

# Do all the research.
# For multi-GPU training, launch with:
#   python -m torch.distributed.launch --nproc_per_node=4 main_prob.py ...
for train_label in "${train_labels[@]}"; do
        # Millisecond timestamp, used to give each run its own --dist_url file.
        run_stamp=$(date +%s%3N)
        CUDA_VISIBLE_DEVICES="${gpu_id}" python main_prob.py \
                --batch_size 256 \
                --accum_iter 1 \
                --xp_class_index "${train_label}" \
                --model mae_vit_large_patch16 \
                --dataset_name 'faces' \
                --finetune "${PRETRAIN_CHKPT}" \
                --epochs 20 \
                --input_size 224 \
                --norm_pix_loss \
                --head_type vit_head \
                --weight_decay 0.2 \
                --blr 1e-3 \
                --dist_eval --data_path "${DATA_PATH}" \
                --output_dir "${OUTPUT_DIR}/split_${train_label}" \
                --dist_url "file://${OUTPUT_DIR}/${run_stamp}"
        status=$?

        if (( status == 0 )); then
                echo "Done train_${train_label}!"
        else
                echo "FAILED train_${train_label} (exit status ${status})" >&2
                failed_labels+=("${train_label}")
        fi
done
# Print completion time.
date

# Propagate failure so Slurm records the job as failed rather than completed.
if (( ${#failed_labels[@]} > 0 )); then
        die "training failed for label(s): ${failed_labels[*]}"
fi


# OUTPUT_DIR='/home/test_time_training/ttt_mae_v1/models/chexpert/vit_head_vis_with_aug/'

# Do all the research.
# -m torch.distributed.launch --nproc_per_node=4 
# for train_label in 5 
#         do 
#         for split_label in 6
#                 do 
#                 for value in 0
#                         do
#                         TIME=$(date +%s%3N)
#                         CUDA_VISIBLE_DEVICES=7 python main_prob.py \
#                                 --batch_size 256 \
#                                 --accum_iter 1 \
#                                 --xp_class_index ${train_label} \
#                                 --xp_split_by_class_index ${split_label} \
#                                 --xp_split_by_class_value ${value} \
#                                 --model mae_vit_large_patch16 \
#                                 --dataset_name 'chexpert' \
#                                 --finetune ${PRETRAIN_CHKPT} \
#                                 --epochs 20 \
#                                 --input_size 224 \
#                                 --head_type vit_head \
#                                 --blr 1e-3 \
#                                 --dist_eval --data_path ${DATA_PATH} \
#                                 --output_dir "${OUTPUT_DIR}/train_${train_label}_split_${split_label}_value_${value}" \
#                                 --dist_url "file://$OUTPUT_DIR/$TIME" \
#                                 --use_augmentations \
#                                 --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25
                                
#                         echo "Done train_${train_label}_split_${split_label}_value_${value}!"
#                         done
#                 done
#         done