import glob
import os
import argparse
import random

# Dataset identifier passed as --dataset to both launched scripts and used in the
# training checkpoint directory.
DATASET: str = "multi_news"
# Checkpoint substituted for --model_name when --baseline is False and the model
# name is anything other than "t5-small".
PEGA_PTM_PATH: str = "google/pegasus-multi_news"
# Checkpoint substituted for --model_name when --baseline is False and the model
# name is "t5-small".
T5_PTM_PATH: str = "pretrained_weights/multi_news_t5_small"


def run(inp_cmd):
    """Echo a shell command, execute it through ``os.system``, and return its status.

    Args:
        inp_cmd: Complete shell command line to execute.

    Returns:
        The value returned by ``os.system``. It is 0 on success and non-zero when the
        command fails; on POSIX it is a wait()-encoded status.
    """
    # Flush so the echoed command is written before the child process starts writing
    # to the same stream (matters when stdout is piped or redirected to a file).
    print(inp_cmd, flush=True)
    return os.system(inp_cmd)


# Launch a test, validation, or training run on the multi_news dataset, e.g.:
#   python run_sh.py --mode train --baseline False
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=["train", "test", "val"])
    parser.add_argument('--baseline', default="False", choices=["True", "False"])
    parser.add_argument('--gpu', default="0,1,2,3,4,5,6,7")
    parser.add_argument('--model_name', default="t5-small")
    # Parsed as int so a malformed value fails here instead of inside the launched script.
    parser.add_argument('--batch_size', type=int, default=8)
    # Directory containing the checkpoints to evaluate; not needed in training mode.
    parser.add_argument('--save_path', default="")
    args = parser.parse_args()

    # Every mode other than "train" (including an omitted --mode) runs inference. An
    # empty save path would leave `--save_path` without a value on the generated
    # command line, so reject it up front with a clear message.
    if args.mode != "train" and not args.save_path:
        parser.error("--save_path is required when --mode is not 'train'")

    # Outside baseline mode, replace the model name with a multi_news checkpoint path.
    if args.baseline == "False":
        if args.model_name == "t5-small":
            args.model_name = T5_PTM_PATH
        else:
            args.model_name = PEGA_PTM_PATH

    # Options appended to both the inference and the training command.
    inference_param = " --alpha 0.5  --max_length 400 --min_length 200 --length_pen 2.0  "
    if args.mode != "train":
        # NOTE(review): args.mode is not forwarded, so "test" and "val" launch the same
        # command -- confirm how inference.py selects the evaluation split.
        # NOTE(review): --PTM is always t5, even when model_name resolves to the Pegasus
        # checkpoint -- confirm this is intended.
        cmd = f"python inference.py --start_gpu {args.gpu} --split_num 2 --dataset {DATASET} " \
              f" --baseline {args.baseline} --batch_size {args.batch_size} --model_name {args.model_name} " \
              f" --PTM t5 --max_src_len 1024 --save_path {args.save_path} " \
              f" --diversity_pen 0.0 --beam_size 8 "
    else:
        # One worker process per comma-separated GPU id.
        num_process = len(args.gpu.split(','))
        # The last digit of --master_port (29500-29509) is randomized, presumably to
        # reduce port clashes between concurrent launches on the same host.
        cmd = f"CUDA_VISIBLE_DEVICES={args.gpu} " \
              f" python -m torch.distributed.launch --master_port 2950{random.randint(0, 9)} --nproc_per_node={num_process} " \
              f" train_distributed.py --save_path checkpoints/{DATASET}/{args.model_name} " \
              f" --max_src_len 1024 --max_tgt_len 300   --baseline {args.baseline} " \
              f" --mode train --accum_count 1  --batch_size {args.batch_size} --n_epochs 2 " \
              f" --save_steps 2000 --dataset {DATASET} --PTM t5 --model_name {args.model_name} " \
              f" --diversity_pen 2.8 --beam_size 14 "
    cmd += inference_param

    # Surface a failed child command as a non-zero exit of this script. A falsy status
    # (0 or None) is treated as success.
    status = run(cmd)
    if status:
        raise SystemExit(f"command failed with status {status}")
