#!/usr/bin/env bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# Prediction launcher: runs task_retrive.py with `--mission KGT` to produce
# predictions from a trained model. Older PyTorch-DDP / DeepSpeed invocations
# are kept further down as commented-out references.
#
# Every setting below can be overridden from the environment; the defaults are
# the values this script has always used, e.g.
#   MODEL=gpt2-medium OUTPUT_DIR=runs/causal/ bash <this script>
#
#   DATA_DIR    default data/         exported only; not passed as a flag on the
#                                     command line here — presumably read from
#                                     the environment by the Python side
#                                     (TODO confirm)
#   OUTPUT_DIR  default test_causal/  used to build --predict_dir
#   MODEL       default gpt2          used for --preset_model_type and
#                                     --retrival_model_type
set -euo pipefail

export DATA_DIR="${DATA_DIR:-data/}"
export OUTPUT_DIR="${OUTPUT_DIR:-test_causal/}"
export MODEL="${MODEL:-gpt2}"

# make predictions on the test set for a model trained with PyTorch DDP
# python task.py --append_descr 1 --append_triples --append_retrieval 1 --data_version csqa_ret_3datasets --append_answer_text 1 --model_type electra --batch_size 1 --max_seq_length 512 --vary_segment_id --bert_model_dir test/ --mission output --predict_dir $OUTPUT_DIR/prediction/ --pred_file_name pred_test.csv --bert_vocab_dir google/electra-large-discriminator


# make predictions on the test set for a model trained with DeepSpeed
# deepspeed --include="localhost:0" task.py --append_descr 1 --append_triples --append_retrieval 1 --data_version csqa_ret_3datasets --append_answer_text 1 --model_type debertav2 --batch_size 1 --max_seq_length 512 --vary_segment_id --ddp --deepspeed --bert_model_dir test/ --predict_dir $OUTPUT_DIR/prediction/ --pred_file_name pred_test.csv --mission output --deepspeed_config debertav3-test --predict_dev 
# deepspeed --include="localhost:0" task.py --append_descr 1 --append_triples --append_retrieval 1 --data_version csqa_ret_3datasets --append_answer_text 1 --preset_model_type debertav3  --batch_size 1 --max_seq_length 512 --vary_segment_id --ddp --deepspeed --bert_model_dir test/ --predict_dir $OUTPUT_DIR/prediction/ --pred_file_name pred_test.csv --mission output --deepspeed_config debertav3-test --predict_dev
# Live run: predict with `--mission KGT` on data version
# csqa_counterfact_causal_new, using the checkpoint in $bert_model_dir.
# NOTE(review): --predict_dev presumably selects the dev split rather than the
# test split — confirm in task_retrive.py.
#
# Optional overrides (defaults reproduce the previously hard-coded values):
#   GPU_ID          device index exported as CUDA_VISIBLE_DEVICES (default 2)
#   BERT_MODEL_DIR  value for --bert_model_dir (default test_causal/)
#
# "task_retrive.py" and "--retrival_model_type" are kept exactly as in the
# original invocation: they must match the script's file name and its
# argument parser, so do not correct their spelling here.
gpu_id="${GPU_ID:-2}"
bert_model_dir="${BERT_MODEL_DIR:-test_causal/}"
# OUTPUT_DIR already ends in "/"; strip it so the path is not "dir//prediction/".
predict_dir="${OUTPUT_DIR%/}/prediction/"

# One flag per line; an array keeps every value a single, quoted argument.
kgt_args=(
  --append_descr 0
  --append_triples
  --append_retrieval 0
  --data_version csqa_counterfact_causal_new
  --data_version_KGT csqa_counterfact_causal_new
  --append_answer_text 0
  --preset_model_type "$MODEL"
  --retrival_model_type "$MODEL"
  --batch_size 1
  --max_seq_length 512
  --vary_segment_id
  --bert_model_dir "$bert_model_dir"
  --predict_dir "$predict_dir"
  --pred_file_name pred_test.csv
  --mission KGT
  --predict_dev
)
CUDA_VISIBLE_DEVICES="$gpu_id" python task_retrive.py "${kgt_args[@]}"