# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from contextlib import redirect_stdout

from fairseq import options
from fairseq_cli import generate

from examples.noisychannel import rerank_options, rerank_utils


def score_bw(args):
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang

    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )

    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    else:
        rerank_data1 = left_to_right_preprocessed_dir

    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = [
            "--path",
            args.score_model1,
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
        ]
        gen_model1_param = [rerank_data1] + gen_param + model_param1

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        with open(score1_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print("STEP 4: score the translations for model 2")

        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data2 = left_to_right_preprocessed_dir

        model_param2 = [
            "--path",
            args.score_model2,
            "--source-lang",
            scorer2_src,
            "--target-lang",
            scorer2_tgt,
        ]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        with open(score2_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    score_bw(args)


if __name__ == "__main__":
    cli_main()
