#!/bin/bash
#SBATCH -p gpu20
#SBATCH -t 24:00:00
#SBATCH -o posttrain_procy_qa-%j.out
#SBATCH --gres gpu:2

# Hugging Face settings shared by every run below: the datasets cache
# directory, and offline mode for transformers.
HF_DATASETS_CACHE='/sdb/zke4/dataset_cache'
TRANSFORMERS_OFFLINE=1
export HF_DATASETS_CACHE TRANSFORMERS_OFFLINE
#export TRANSFORMERS_CACHE='./model_cache'

# Value forwarded to finetune.py as --max_samples.
max_samples=640000

# Seed table; the sweep below picks one entry by index (seed[round]).
declare -a seed=(2021 111 222 333 444 555 666 777 888 999)

# Run finetune.py once for every (pt_task, ft_task) pair with
# 0 <= ft_task <= pt_task <= 5, for each listed round and idrandom value.
# Runs are sequential and each is pinned to a single GPU. The seed passed to
# each run is seed[round]. Reads globals: max_samples, seed.
#
# A failed run is reported on stderr and the sweep continues with the next
# pair. The number of failures is summarised at the end.

# Pick the first GPU of the job's allocation instead of a hard-coded 0.
# SLURM may export CUDA_VISIBLE_DEVICES as e.g. "2,3" (nodes without device
# cgroups), in which case physical GPU 0 can belong to another job. If the
# variable is unset or empty (e.g. running outside SLURM), fall back to 0,
# matching the previous behaviour.
gpu_id=${CUDA_VISIBLE_DEVICES:-0}
gpu_id=${gpu_id%%,*}
failed_runs=0

for round in 4; do
  for idrandom in 0; do
    for pt_task in 0 1 2 3 4 5; do
      # ft_task covers 0..pt_task inclusive.
      for (( ft_task = 0; ft_task <= pt_task; ft_task++ )); do
        if ! CUDA_VISIBLE_DEVICES="$gpu_id" python finetune.py \
          --max_seq_length 164 \
          --pt_task "$pt_task" \
          --ft_task "$ft_task" \
          --idrandom "$idrandom" \
          --ntasks 6 \
          --max_samples "$max_samples" \
          --seed "${seed[$round]}" \
          --baseline 'softmask_pipeline_standard_norm_dgi_pre_as_general_first_proxy_all_layer'
        then
          failed_runs=$((failed_runs + 1))
          printf 'finetune.py failed: round=%s idrandom=%s pt_task=%s ft_task=%s\n' \
            "$round" "$idrandom" "$pt_task" "$ft_task" >&2
        fi
      done
    done
  done
done

if (( failed_runs > 0 )); then
  printf '%d finetune.py run(s) failed\n' "$failed_runs" >&2
fi
# Non-zero status when any run failed. If this is the script's last command,
# it becomes the job's exit code, so the scheduler sees the failure.
(( failed_runs == 0 ))
#  --how_to_block ${how_to_block} \
# $(seq 0 ${pt_task})
#  --finetune_type 'teacher_include'
#  --n_tokens 164 \
#  --finetune_type 'distill'
#  --finetune_type 'teacher_include'

# TODO: consider using a different batch size (bsz) for each domain (10 bsz decrease)
# python -m torch.distributed.launch --nproc_per_node 2
#CUDA_VISIBLE_DEVICES=0 python test_posttrain.py
#  --route_type 'id_gate' \
#gradient_accumulation_steps

# post-train depends on the new AE
# -m torch.distributed.launch --nproc_per_node 3