import argparse
import os


def run(cmd):
    """Echo a shell command, execute it through the system shell, and return its status.

    ``cmd`` is handed to the shell verbatim (pipes, redirects and ``VAR=value``
    prefixes are interpreted), so it must only be built from trusted input.

    Returns:
        The value returned by ``os.system``: 0 on success. On POSIX this is an
        encoded wait status, not the raw exit code.
    """
    # Flush so the echoed command reaches stdout before the child process
    # starts writing to the same file descriptor. Without it, when stdout is
    # a pipe or a file, Python's buffered text can show up after the
    # command's own output.
    print(cmd, flush=True)
    return os.system(cmd)


def preprocess_cmd():
    """Binarize the tokenized WMT16 corpus for the configured language pair.

    Reads ``args.source`` / ``args.target`` from the module-level ``args``,
    reads ``wmt16.tokenized.<src>-<tgt>/{train,valid,test}`` and writes the
    binarized data to ``data-bin/wmt16.tokenized.<src>-<tgt>``.
    """
    corpus = f"wmt16.tokenized.{args.source}-{args.target}"
    # Four-space separators keep the echoed command identical to the
    # previous backslash-continued literal.
    gap = "    "
    pieces = [
        f"python preprocess.py --source-lang {args.source} --target-lang {args.target} ",
        f"{gap}--trainpref {corpus}/train --validpref {corpus}/valid --testpref {corpus}/test ",
        f"{gap}--destdir data-bin/{corpus} ",
        f"{gap}--workers 20",
    ]
    run("".join(pieces))


def train_cmd():
    """Launch baseline Transformer training (``--arch transformer_repro_wmt``).

    Reads ``args.source`` and ``args.target`` from the module-level ``args``
    (assigned in ``__main__``). Exposes GPUs 0-7 via CUDA_VISIBLE_DEVICES,
    writes checkpoints to ``checkpoints/wmt16-{source}-{target}``, and keeps
    the best checkpoints by BLEU (``--best-checkpoint-metric bleu``).
    """
    # The command is an f-string concatenated with a plain (non-f) string.
    # Because the second half is not an f-string, the JSON braces in
    # --eval-bleu-args are emitted literally instead of being parsed as
    # replacement fields. Backslash-newline inside the literals continues the
    # string, so the next line's leading spaces end up in the command; the
    # shell ignores the extra whitespace.
    cmd = f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  python train.py \
    data-bin/wmt16.tokenized.{args.source}-{args.target}  \
    --save-dir checkpoints/wmt16-{args.source}-{args.target} \
    --criterion label_smoothed_cross_entropy \
    --arch transformer_repro_wmt --max-epoch 100 \
    --share-decoder-input-output-embed  \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 5e-4 --lr-scheduler inverse_sqrt \
    --stop-min-lr '1e-09' --warmup-updates 4000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.0001 \
    --log-interval 100 \
    --max-tokens 8000 " + " --eval-bleu \
    --eval-bleu-args '{\"beam\": 5, \"max_len_a\": 1, \"max_len_b\": 50}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --keep-last-epochs 10 --keep-best-checkpoints 10  --fp16"
    run(cmd)


def gen_cmd():
    """Decode with fairseq ``generate.py`` and save the log under a results directory.

    Uses the module-level ``args``.

    * ``--avg_ckpt True`` (case-insensitive): first average the last 5 epoch
      checkpoints in ``args.ckpt_dir`` into ``{ckpt_dir}/averaged_model.pt``,
      then decode with that model.
    * ``--nbest 1``: write ``gen_results/wmt16_{ckpt}.out`` and echo its last
      line (presumably generate.py's BLEU summary).
    * ``--nbest > 1``: decode with decoder dropout retained and write
      ``gen_results_multi/wmt16_{ckpt}_{decoding_method}_{nbest}.out``.
    """
    # Shrink the batch as the beam grows so beam * batch stays near 2048.
    batch_size = round((32 / args.beam) * 64)
    ckpt_name = args.ckpt_name
    if str(args.avg_ckpt).lower() == "true":
        # Average the checkpoints of the model actually being decoded and use
        # the result. Previously this read a hard-coded "checkpoints_wmt16/"
        # and wrote "averaged_model.pt" into the CWD, where generation never
        # loaded it.
        ckpt_name = "averaged_model"
        run(f"python scripts/average_checkpoints.py --inputs {args.ckpt_dir}/ "
            f"--num-epoch-checkpoints  5 --output {args.ckpt_dir}/{ckpt_name}.pt")
    shared_cmd = f"CUDA_VISIBLE_DEVICES={args.gpu} python generate.py \
        data-bin/wmt16.tokenized.{args.source}-{args.target} \
        --path {args.ckpt_dir}/{ckpt_name}.pt \
        --batch-size {batch_size} --lenpen 0.1 \
        --max-len-a 1.0 --max-len-b 50 --remove-bpe  \
        --beam {args.beam} --diverse-beam-groups {args.diverse_beam_groups} \
        --diverse-beam-strength {args.diverse_beam_strength} \
        --diversity-rate {args.diversity_rate} "
    if args.nbest > 1:
        out_dir = "gen_results_multi"
        # The shell redirect fails if the target directory does not exist.
        os.makedirs(out_dir, exist_ok=True)
        cmd_all = shared_cmd + f" --retain-dropout --retain-dropout-modules '[\"TransformerDecoder\"]' --nbest {args.nbest} > {out_dir}/wmt16_{ckpt_name}_{args.decoding_method}_{args.nbest}.out"
        run(cmd_all)
    else:
        out_dir = "gen_results"
        os.makedirs(out_dir, exist_ok=True)
        out_file = f"{out_dir}/wmt16_{ckpt_name}.out"
        run(shared_cmd + f" > {out_file}")
        # show the results
        run(f"tail -1 {out_file}")


def gen_cl_cmd():
    """Decode with ``my_generate.py --task translation_cl`` and save the log.

    Uses the module-level ``args``.

    * ``--avg_ckpt True`` (case-insensitive): first average the last 5 epoch
      checkpoints in ``args.ckpt_dir`` into ``{ckpt_dir}/averaged_model.pt``,
      then decode with that model.
    * ``--nbest 1``: write ``gen_results/wmt16_cl_{ckpt}.out`` and echo its
      last line.
    * ``--nbest > 1``: decode with decoder dropout retained and write
      ``gen_results_multi/wmt16_{ckpt}_{decoding_method}_{nbest}.out``.
    """
    # Shrink the batch as the beam grows so beam * batch stays near 2048.
    # NOTE(review): args.beam sizes the batch, but --beam is not passed to
    # my_generate.py below (nor --diverse-beam-groups). Confirm the
    # translation_cl task picks its beam size up some other way.
    batch_size = round((32 / args.beam) * 64)
    ckpt_name = args.ckpt_name
    if str(args.avg_ckpt).lower() == "true":
        # Average the checkpoints of the model actually being decoded and use
        # the result. Previously this read a hard-coded "checkpoints_wmt16/"
        # and wrote "averaged_model.pt" into the CWD, where generation never
        # loaded it.
        ckpt_name = "averaged_model"
        run(f"python scripts/average_checkpoints.py --inputs {args.ckpt_dir}/ "
            f"--num-epoch-checkpoints  5 --output {args.ckpt_dir}/{ckpt_name}.pt")
    shared_cmd = f"CUDA_VISIBLE_DEVICES={args.gpu} python my_generate.py  --task translation_cl \
        data-bin/wmt16.tokenized.{args.source}-{args.target} \
        --path {args.ckpt_dir}/{ckpt_name}.pt \
        --batch-size {batch_size} --lenpen 0.1 \
        --max-len-a 1.0 --max-len-b 50 --remove-bpe  \
        --diversity-rate {args.diversity_rate}  "
    if args.nbest > 1:
        out_dir = "gen_results_multi"
        # The shell redirect fails if the target directory does not exist.
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): unlike the nbest == 1 branch this name has no "cl"
        # tag, so it is the same file gen_cmd writes for identical settings.
        # Kept as-is because score_multi_cmd's default file name may rely on it.
        cmd_all = shared_cmd + f" --retain-dropout --retain-dropout-modules '[\"TransformerDecoder\"]' --nbest {args.nbest} > {out_dir}/wmt16_{ckpt_name}_{args.decoding_method}_{args.nbest}.out"
        run(cmd_all)
    else:
        out_dir = "gen_results"
        os.makedirs(out_dir, exist_ok=True)
        out_file = f"{out_dir}/wmt16_cl_{ckpt_name}.out"
        run(shared_cmd + f" > {out_file}")
        # show the results
        run(f"tail -1 {out_file}")


def score_cmd(file_name="wmt16_checkpoint_best", base_dir="gen_results"):
    """Split a generate.py log into reference/hypothesis files and score them.

    Args:
        file_name: Log stem inside ``base_dir`` (without ``.out``). The
            default matches gen_cmd's output for ``--ckpt_name checkpoint_best``.
        base_dir: Directory holding the log; ``.ref`` and ``.sys`` are written
            next to it.
    """
    stem = f"{base_dir}/{file_name}"
    # Rewrites "a-b" as "a ##AT##-##AT## b" (presumably to match the
    # compound-split tokenization used for WMT BLEU). This is a raw string
    # because "\S" is an invalid Python escape sequence and warns in a
    # normal literal. It is kept out of the f-strings because its braces
    # would be parsed as replacement fields.
    hyphen_split = r"perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g'"
    # T lines carry the text from field 2; H lines have one extra column
    # before the text, hence field 3.
    run(f"grep ^T {stem}.out | cut -f2- | {hyphen_split} > {stem}.ref")
    run(f"grep ^H {stem}.out | cut -f3- | {hyphen_split} > {stem}.sys")

    run(f"python score_multi.py --beam 1 --sys {stem}.sys --ref {stem}.ref")


def score_multi_cmd(file_name="wmt16_checkpoint_best_div_beam_search_16",
                    base_dir="gen_results_multi"):
    """Split an n-best generate.py log into reference/hypothesis files and score them.

    Args:
        file_name: Log stem inside ``base_dir`` (without ``.out``). Its last
            ``_``-separated token is used as the n-best size, matching the
            ``wmt16_{ckpt}_{method}_{nbest}`` names gen_cmd writes.
        base_dir: Directory holding the log; ``.ref`` and ``.sys`` are written
            next to it.
    """
    stem = f"{base_dir}/{file_name}"
    # Rewrites "a-b" as "a ##AT##-##AT## b" (presumably to match the
    # compound-split tokenization used for WMT BLEU). This is a raw string
    # because "\S" is an invalid Python escape sequence and warns in a
    # normal literal. It is kept out of the f-strings because its braces
    # would be parsed as replacement fields.
    hyphen_split = r"perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g'"
    run(f"grep ^T {stem}.out | cut -f2- | {hyphen_split} > {stem}.ref")
    run(f"grep ^H {stem}.out | cut -f3- | {hyphen_split} > {stem}.sys")
    # Line counts for a quick sanity check (presumably .sys should have
    # nbest lines per .ref line).
    run(f"wc -l {stem}.sys")
    run(f"wc -l {stem}.ref")
    n_best = file_name.split('_')[-1]
    run(f"python score_multi.py --beam {n_best} --sys {stem}.sys --ref {stem}.ref")


def train_cl_cmd():
    """Launch ``translation_cl`` training with a contrastive loss, starting from a warm-up checkpoint.

    Reads ``args.source``, ``args.target``, ``args.cl_loss`` and ``args.n_gram``
    from the module-level ``args``. Exposes GPUs 0-7 via CUDA_VISIBLE_DEVICES
    and writes checkpoints to
    ``checkpoints/wmt16_{cl_loss}{n_gram}-{source}-{target}``.
    """
    # The warm-up checkpoint directory is named after the language pair.
    # Previously "wmt16-en-de" was hard-coded, so other directions silently
    # started from the en-de model. For the default en/de pair the path is
    # unchanged.
    model_dict = f"skip_warmup_ckpts/wmt16-{args.source}-{args.target}/model-warmed-up.pt"
    # f-string part + plain-string part. Backslash-newline continues the
    # literal; the extra whitespace it embeds is harmless to the shell.
    cmd = f"CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --task translation_cl \
    data-bin/wmt16.tokenized.{args.source}-{args.target}  \
    --save-dir checkpoints/wmt16_{args.cl_loss}{args.n_gram}-{args.source}-{args.target} \
    --criterion label_smoothed_cross_entropy \
    --max-epoch 150 \
    --arch transformer_clg_wmt \
    --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 5e-4 --lr-scheduler inverse_sqrt  --seed 2022  \
    --stop-min-lr '1e-09' --warmup-updates 4000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.0001 \
    --diverse_bias 3.0  --find-unused-parameters \
    --log-interval 10 --cl_loss {args.cl_loss} --n_gram {args.n_gram}  --skip_warmup_ckpt {model_dict} \
    --max-tokens 10000  --update-freq 1  " + " --eval-bleu  \
    --max_len_a 1 --max_len_b 50 --lenpen 0.1 --beam_size 14 \
    --eval-bleu-detok moses --keep-best-checkpoints 5  --validate-interval-updates 60 --save-interval-updates 60 --keep-interval-updates 2 \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples   \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric  \
    --fp16  "
    run(cmd)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='run cmd python wrapper'
    )
    # Explicit dispatch table for --mode. It replaces eval(f"{mode}_cmd()")
    # and makes score_multi_cmd reachable; it previously had no --mode value.
    commands = {
        "preprocess": preprocess_cmd,
        "train": train_cmd,
        "train_cl": train_cl_cmd,
        "gen": gen_cmd,
        "gen_cl": gen_cl_cmd,
        "score": score_cmd,
        "score_multi": score_multi_cmd,
    }
    # training config
    # --mode is required: when it was omitted, the old eval("None_cmd()")
    # crashed with a NameError instead of a usage message.
    parser.add_argument('--mode', required=True, choices=list(commands))
    parser.add_argument('--source', default="en")
    parser.add_argument('--target', default="de")
    parser.add_argument('--cl_loss', default="ranking")
    parser.add_argument('--n_gram', default=2, type=int)
    # gen parameter
    parser.add_argument('--avg_ckpt', default="False")
    parser.add_argument('--ckpt_dir', default="checkpoints/wmt16_ranking3-en-de")
    parser.add_argument('--ckpt_name', default="checkpoint_best")
    parser.add_argument('--nbest', default=1, type=int)
    parser.add_argument('--gpu', default="0")
    parser.add_argument('--beam', default=5, type=int)
    parser.add_argument('--decoding_method', default="beam_search",
                        choices=["beam_search", "div_sibling_search", "div_beam_search"])

    # no need to set
    parser.add_argument('--diversity_rate', default=-1.0, type=float)
    parser.add_argument("--diverse_beam_strength", default=3.5, type=float)
    parser.add_argument('--diverse_beam_groups', default=-1, type=int)
    # The *_cmd functions read this module-level ``args`` global.
    args = parser.parse_args()
    # Decoding-method presets override the corresponding low-level knobs.
    if args.decoding_method == "div_sibling_search":
        args.diversity_rate = 1.0
    elif args.decoding_method == "div_beam_search":
        args.diverse_beam_groups = args.beam

    commands[args.mode]()
