#!/bin/bash
#
# This script is for training with updated ann driver
#
# The design for this ann driver is to have 2 separate processes for training: one that runs passage/query
# inference with a trained checkpoint to generate ann data and calculate NDCG, and another that trains the model
# on the generated ann data. The processes share data through common directories: model_dir for checkpoints
# and model_ann_data_dir for ann data.
#
# This script initializes training and starts the model training process.
# It first preprocesses the MS MARCO data into an indexable cache, then generates a single initial ann data
# version to train on, after which it starts training on the generated ann data, continuously looking for
# the newest ann data generated in model_ann_data_dir.
#
# To start training, run this script first.
# After the initial ann data is created (you can tell either by finding "successfully created
# initial ann training data" in the console output or by seeing a new model on tensorboard),
# start run_ann_data_gen.sh in another DLTS job (or in the same DLTS job using split GPUs).
#
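# Note: if the preprocessing directory or the ann data directory already exists, those steps will be
# skipped and training will start immediately.
#
# This script does not assign its inputs itself; set these before running it:
#   base_data_dir, preprocessed_data_dir, model_dir, model_ann_data_dir, pretrained_checkpoint_dir, gpu_no
# job_name and DLWS_JOB_ID are only used to build the tensorboard log directory.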

##################################### Data Preprocessing ################################
# Converts the raw MS MARCO data in $base_data_dir into the cache in $preprocessed_data_dir.
#
# Fail fast with a clear message instead of launching the driver with missing arguments:
# these variables are not assigned anywhere in this script.
: "${base_data_dir:?must be set before running this script}"
: "${preprocessed_data_dir:?must be set before running this script}"

# Only the last uncommented preprocess_cmd assignment takes effect; the passage
# configuration is active by default. The string is re-parsed by eval below, so the
# paths it contains must not include whitespace.

# Passage ANCE(FirstP)
preprocess_cmd="\
python ../data/msmarco_data.py --data_dir $base_data_dir --out_data_dir $preprocessed_data_dir --model_type rdot_nll \
--model_name_or_path roberta-base --max_seq_length 512 --data_type 1\
"

# # Document ANCE(FirstP) 
# preprocess_cmd="\
# python ../data/msmarco_data.py --data_dir $base_data_dir --out_data_dir $preprocessed_data_dir --model_type rdot_nll \
# --model_name_or_path roberta-base --max_seq_length 512 --data_type 0\
# "

# # Document ANCE(MaxP) 
# preprocess_cmd="\
# python ../data/msmarco_data.py --data_dir $base_data_dir --out_data_dir $preprocessed_data_dir --model_type rdot_nll_multi_chunk \
# --model_name_or_path roberta-base --max_seq_length 2048 --data_type 0\
# "

printf '%s\n' "$preprocess_cmd"
# Capture the exit status right away: any later command (even echo) overwrites $?.
preprocess_status=0
eval "$preprocess_cmd" || preprocess_status=$?

if (( preprocess_status == 0 )); then
    echo "successfully created preprocessed data"
else
    echo "preprocessing failed" >&2
    echo "failure: $preprocess_status" >&2
    exit 1
fi

##################################### Initial ANN Data generation ################################
# Generates the first version of ANN training data into $model_ann_data_dir, starting
# from $pretrained_checkpoint_dir, so that the training step below has data to consume.
#
# Fail fast with a clear message instead of launching the driver with missing arguments:
# these variables are not assigned anywhere in this script.
: "${gpu_no:?must be set before running this script}"
: "${model_dir:?must be set before running this script}"
: "${pretrained_checkpoint_dir:?must be set before running this script}"
: "${model_ann_data_dir:?must be set before running this script}"
: "${preprocessed_data_dir:?must be set before running this script}"

# Only the last uncommented initial_data_gen_cmd assignment takes effect; the passage
# configuration is active by default. The string is re-parsed by eval below, so the
# paths it contains must not include whitespace.
# --cache_dir is built by plain concatenation, so model_ann_data_dir presumably needs a
# trailing '/' (NOTE(review): confirm against how the directory is configured).
# --end_output_num 0 is presumably what limits this run to a single initial ann data
# version (per the header comment) -- confirm in run_ann_data_gen.py.

# Passage ANCE(FirstP)
initial_data_gen_cmd="\
python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
--init_model_dir $pretrained_checkpoint_dir --model_type rdot_nll --output_dir $model_ann_data_dir \
--cache_dir ${model_ann_data_dir}cache/ --data_dir $preprocessed_data_dir --max_seq_length 512 \
--per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 --end_output_num 0 \
"

# # Document ANCE(FirstP) 
# initial_data_gen_cmd="\
# python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
# --init_model_dir $pretrained_checkpoint_dir --model_type rdot_nll --output_dir $model_ann_data_dir \
# --cache_dir ${model_ann_data_dir}cache/ --data_dir $preprocessed_data_dir --max_seq_length 512 \
# --per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 --end_output_num 0 \
# "

# # Document ANCE(MaxP) 
# initial_data_gen_cmd="\
# python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann_data_gen.py --training_dir $model_dir \
# --init_model_dir $pretrained_checkpoint_dir --model_type rdot_nll_multi_chunk --output_dir $model_ann_data_dir \
# --cache_dir ${model_ann_data_dir}cache/ --data_dir $preprocessed_data_dir --max_seq_length 2048 \
# --per_gpu_eval_batch_size 16 --topk_training 200 --negative_sample 20 --end_output_num 0 \
# "

printf '%s\n' "$initial_data_gen_cmd"
# Capture the exit status right away: any later command (even echo) overwrites $?.
initial_data_gen_status=0
eval "$initial_data_gen_cmd" || initial_data_gen_status=$?

if (( initial_data_gen_status == 0 )); then
    echo "successfully created initial ann training data"
else
    echo "initial data generation failed" >&2
    echo "failure: $initial_data_gen_status" >&2
    exit 1
fi

############################################# Training ########################################
# Trains the model on the ann data in $model_ann_data_dir, writing checkpoints to $model_dir.
#
# Fail fast with a clear message instead of launching the driver with missing arguments:
# these variables are not assigned anywhere in this script.
: "${gpu_no:?must be set before running this script}"
: "${pretrained_checkpoint_dir:?must be set before running this script}"
: "${preprocessed_data_dir:?must be set before running this script}"
: "${model_ann_data_dir:?must be set before running this script}"
: "${model_dir:?must be set before running this script}"

# Only the last uncommented train_cmd assignment takes effect; the passage
# configuration is active by default. The string is re-parsed by eval below, so the
# paths it contains must not include whitespace. The leading ~ of --log_dir is stored
# literally in the string and expanded to $HOME when eval re-parses the command.

# Passage ANCE(FirstP)
train_cmd="\
python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann.py --model_type rdot_nll \
--model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
--ann_dir $model_ann_data_dir --max_seq_length 512 --per_gpu_train_batch_size=8 \
--gradient_accumulation_steps 2 --learning_rate 1e-6 --output_dir $model_dir \
--warmup_steps 5000 --logging_steps 100 --save_steps 10000 --optimizer lamb --single_warmup --fp16 --log_dir ~/tensorboard/${DLWS_JOB_ID}/logs/${job_name}\
"

# # Document ANCE(FirstP) 
# train_cmd="\
# python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann.py --model_type rdot_nll \
# --model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
# --ann_dir $model_ann_data_dir --max_seq_length 512 --per_gpu_train_batch_size=8 \
# --gradient_accumulation_steps 2 --learning_rate 5e-6 --output_dir $model_dir \
# --warmup_steps 3000 --logging_steps 100 --save_steps 10000 --optimizer lamb --single_warmup --fp16 --log_dir ~/tensorboard/${DLWS_JOB_ID}/logs/${job_name}\
# "

# # Document ANCE(MaxP) 
# train_cmd="\
# python -m torch.distributed.launch --nproc_per_node=$gpu_no ../drivers/run_ann.py --model_type rdot_nll_multi_chunk \
# --model_name_or_path $pretrained_checkpoint_dir --task_name MSMarco --triplet --data_dir $preprocessed_data_dir \
# --ann_dir $model_ann_data_dir --max_seq_length 2048 --per_gpu_train_batch_size=2 \
# --gradient_accumulation_steps 8 --learning_rate 1e-5 --output_dir $model_dir \
# --warmup_steps 500 --logging_steps 100 --save_steps 10000 --optimizer lamb --single_warmup --log_dir ~/tensorboard/${DLWS_JOB_ID}/logs/${job_name}\
# "

printf '%s\n' "$train_cmd"
# Last command of the script: its exit status becomes the script's exit status.
eval "$train_cmd"