# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from multiprocessing import Pool

import numpy as np
from fairseq import options
from fairseq.data import dictionary
from fairseq.scoring import bleu

from examples.noisychannel import (
    rerank_generate,
    rerank_options,
    rerank_score_bw,
    rerank_score_lm,
    rerank_utils,
)


def score_target_hypo(
    args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
):
    """Rerank the n-best lists with one weight setting and return its BLEU.

    Every candidate hypothesis is scored with ``rerank_utils.get_score`` using
    the weights ``a``/``b``/``c`` and ``lenpen``. Within each group of
    candidates belonging to one source sentence, the best-scoring one is kept.
    The kept hypotheses are then compared against the references with a BLEU
    scorer.

    Args:
        args: parsed reranking options (also passed to ``load_score_files``).
        a, b, c: weights forwarded positionally to ``rerank_utils.get_score``
            (printed as weight1/weight2/weight3).
        lenpen: length penalty forwarded to ``rerank_utils.get_score``.
        target_outfile: path the references are written to (see write_hypos).
        hypo_outfile: path the selected hypotheses are written to.
        write_hypos: when true, print the BLEU result string and, if every
            shard was loaded, write targets/hypotheses in original order.
        normalize: forwarded to ``rerank_utils.get_score``.

    Returns:
        The value of ``rerank_utils.parse_bleu_scoring`` applied to the
        scorer's 4-gram result string.
    """
    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
    # Fresh dictionary used only to map words to ids for the BLEU scorer.
    # Named so it does not shadow the ``dict`` builtin.
    tgt_dict = dictionary.Dictionary()
    scorer = bleu.Scorer(
        bleu.BleuConfig(
            pad=tgt_dict.pad(),
            eos=tgt_dict.eos(),
            unk=tgt_dict.unk(),
        )
    )

    ordered_hypos = {}
    ordered_targets = {}

    for gen_output, bitext1, bitext2, lm_res in zip(
        gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst
    ):
        total = len(bitext1.rescore_source.keys())
        hypo_lst = []
        j = 1
        best_score = -math.inf
        # Initialised up front: if no score in a group beats -inf (e.g. all
        # -inf or NaN), this avoids an unbound-name error on the append below.
        best_hypo = ""

        for i in range(total):
            # length is measured in terms of words, not bpe tokens, since models may not share the same bpe
            target_len = len(bitext1.rescore_hypo[i].split())

            lm_score = lm_res.score[i] if lm_res is not None else 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(
                a,
                b,
                c,
                target_len,
                bitext1.rescore_score[i],
                bitext2_score,
                lm_score=lm_score,
                lenpen=lenpen,
                src_len=bitext1.source_lengths[i],
                tgt_len=bitext1.target_lengths[i],
                bitext1_backwards=bitext1.backwards,
                bitext2_backwards=bitext2_backwards,
                normalize=normalize,
            )

            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]

            # j counts candidates seen for the current source sentence. The
            # group is closed after num_hypos[i] or num_rescore candidates,
            # whichever is reached first.
            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                hypo_lst.append(best_hypo)
                j = 1
                best_score = -math.inf
                best_hypo = ""
            else:
                j += 1

        gen_keys = sorted(gen_output.no_bpe_target.keys())

        # The k-th selected hypothesis corresponds to the k-th sorted key of
        # the generation output.
        for key, gen_key in enumerate(gen_keys):
            candidates = gen_output.no_bpe_hypo[gen_key]
            if args.prefix_len is None:
                assert hypo_lst[key] in candidates, (
                    "pred and rescore hypo mismatch: i: "
                    + str(key)
                    + ", "
                    + str(hypo_lst[key])
                    + str(gen_key)
                    + str(candidates)
                )
                final_hypo = hypo_lst[key]
            else:
                final_hypo = rerank_utils.get_full_from_prefix(
                    hypo_lst[key], candidates
                )
            target = gen_output.no_bpe_target[gen_key]

            sys_tok = tgt_dict.encode_line(final_hypo)
            ref_tok = tgt_dict.encode_line(target)
            scorer.add(ref_tok, sys_tok)

            # if only one set of hyper parameters is provided, keep the
            # predictions so they can be written out in original order
            if write_hypos:
                ordered_hypos[gen_key] = final_hypo
                ordered_targets[gen_key] = target

    # Write the hypos in the original order from n-best list generation.
    # Guarded by write_hypos: during a multi-configuration search every pool
    # worker would otherwise open (and truncate) the same output files with
    # nothing to write.
    # NOTE(review): assumes the collected keys are exactly 0..N-1 when all
    # shards are loaded.
    if write_hypos and args.num_shards == len(bitext1_lst):
        with open(target_outfile, "w") as t, open(hypo_outfile, "w") as h:
            for key in range(len(ordered_hypos)):
                t.write(ordered_targets[key])
                h.write(ordered_hypos[key])

    res = scorer.result_string(4)
    if write_hypos:
        print(res)
    return rerank_utils.parse_bleu_scoring(res)


def match_target_hypo(args, target_outfile, hypo_outfile, num_workers=32):
    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file

    Index ``i`` of ``args.weight1`` / ``weight2`` / ``weight3`` / ``lenpen``
    forms one weight configuration, evaluated by ``score_target_hypo``.

    - With a single configuration, it is evaluated in this process with
      ``write_hypos=True``.
    - With several configurations, they are evaluated in a process pool with
      ``write_hypos=False``, and the best one (by ``np.argmax``) is printed.

    Args:
        args: parsed reranking options. ``weight1``/``weight2``/``weight3``/
            ``lenpen`` are indexed in parallel.
        target_outfile: output path forwarded to ``score_target_hypo``.
        hypo_outfile: output path forwarded to ``score_target_hypo``.
        num_workers: upper bound on pool processes for the multi-configuration
            search (default 32, the value that used to be hard-coded).

    Returns:
        ``(lenpen, weight1, weight2, weight3, score)`` of the best
        configuration.

    Raises:
        ValueError: if ``args.weight1`` is empty.
    """
    num_configs = len(args.weight1)
    if num_configs == 0:
        raise ValueError("at least one weight configuration is required")

    if num_configs == 1:
        res = score_target_hypo(
            args,
            args.weight1[0],
            args.weight2[0],
            args.weight3[0],
            args.lenpen[0],
            target_outfile,
            hypo_outfile,
            True,
            args.normalize,
        )
        rerank_scores = [res]
    else:
        print("launching pool")
        jobs = [
            (
                args,
                args.weight1[i],
                args.weight2[i],
                args.weight3[i],
                args.lenpen[i],
                target_outfile,
                hypo_outfile,
                False,
                args.normalize,
            )
            for i in range(num_configs)
        ]
        # Never spawn more worker processes than there are configurations.
        with Pool(min(num_workers, num_configs)) as p:
            rerank_scores = p.starmap(score_target_hypo, jobs)

    if len(rerank_scores) > 1:
        # argmax returns the first index on ties.
        best_index = np.argmax(rerank_scores)
        best_score = rerank_scores[best_index]
        print("best score", best_score)
        print("best lenpen", args.lenpen[best_index])
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return (
            args.lenpen[best_index],
            args.weight1[best_index],
            args.weight2[best_index],
            args.weight3[best_index],
            best_score,
        )

    return (
        args.lenpen[0],
        args.weight1[0],
        args.weight2[0],
        args.weight3[0],
        rerank_scores[0],
    )


def load_score_files(args):
    """Load the n-best generation output and the rescoring results per shard.

    The shards processed are ``0 .. args.num_shards - 1`` when
    ``args.all_shards`` is set, and only ``args.shard_id`` otherwise.

    Returns:
        A tuple of four lists, one entry per shard in shard order:

        - the ``rerank_utils.BitextOutputFromGen`` read from the n-best output;
        - the first rescoring model's output, which reuses the generation
          output when ``score_model1`` is the generation model and no source
          prefix is used;
        - the second rescoring model's output, or ``None``;
        - the ``rerank_utils.LMOutput``, or ``None`` without a language model.
    """
    shard_ids = list(range(args.num_shards)) if args.all_shards else [args.shard_id]

    gen_outputs = []
    first_bitexts = []
    second_bitexts = []
    lm_outputs = []

    for shard in shard_ids:
        have_nbest = args.nbest_list is not None
        # Only the generation directory is used here; the preprocessed
        # directories returned alongside it are ignored.
        pre_gen, _, _, _, _ = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )

        # When a scoring model is the generation model (and no source prefix
        # is in play), its scores are taken from the generation output itself.
        no_source_prefix = args.source_prefix_frac is None
        model1_is_gen = no_source_prefix and args.gen_model == args.score_model1
        model2_is_gen = no_source_prefix and args.gen_model == args.score_model2

        prefix_fracs = {
            "target_prefix_frac": args.target_prefix_frac,
            "source_prefix_frac": args.source_prefix_frac,
        }
        score1_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model1_name,
            backwards=args.backwards1,
            **prefix_fracs,
        )
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(
                pre_gen,
                args.prefix_len,
                args.model2_name,
                backwards=args.backwards2,
                **prefix_fracs,
            )
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(
                pre_gen, args.prefix_len, args.lm_name, lm_file=True
            )

        # Generation output: either a user-supplied n-best list or the file
        # produced by the generation step.
        if have_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list
        else:
            predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        gen_output = rerank_utils.BitextOutputFromGen(
            predictions_bpe_file,
            bpe_symbol=args.post_process,
            nbest=have_nbest,
            prefix_len=args.prefix_len,
            target_prefix_frac=args.target_prefix_frac,
        )

        bitext1 = (
            gen_output
            if model1_is_gen
            else rerank_utils.BitextOutput(
                score1_file,
                args.backwards1,
                args.right_to_left1,
                args.post_process,
                args.prefix_len,
                args.target_prefix_frac,
                args.source_prefix_frac,
            )
        )

        if args.score_model2 is None and args.nbest_list is None:
            # No second scoring model: optionally reuse the generation output.
            if args.diff_bpe:
                assert args.score_model2 is None
                bitext2 = gen_output
            else:
                bitext2 = None
        elif model2_is_gen:
            bitext2 = gen_output
        else:
            bitext2 = rerank_utils.BitextOutput(
                score2_file,
                args.backwards2,
                args.right_to_left2,
                args.post_process,
                args.prefix_len,
                args.target_prefix_frac,
                args.source_prefix_frac,
            )
            assert (
                bitext2.source_lengths == bitext1.source_lengths
            ), "source lengths for rescoring models do not match"
            assert (
                bitext2.target_lengths == bitext1.target_lengths
            ), "target lengths for rescoring models do not match"

        lm_res1 = (
            rerank_utils.LMOutput(
                lm_score_file,
                args.lm_dict,
                args.prefix_len,
                args.post_process,
                args.target_prefix_frac,
            )
            if args.language_model is not None
            else None
        )

        gen_outputs.append(gen_output)
        first_bitexts.append(bitext1)
        second_bitexts.append(bitext2)
        lm_outputs.append(lm_res1)

    return gen_outputs, first_bitexts, second_bitexts, lm_outputs


def rerank(args):
    """Run the noisy-channel reranking pipeline and tune its weights.

    Scalar ``lenpen``/``weight1``/``weight2``/``weight3`` values are wrapped
    into one-element lists in place on ``args``. For each requested shard, the
    n-best generation, backward scoring and LM scoring helpers are run. The
    weight configurations are then evaluated by ``match_target_hypo``.

    Returns:
        ``(best_lenpen, best_weight1, best_weight2, best_weight3, best_score)``.
    """
    for attr in ("lenpen", "weight1", "weight2", "weight3"):
        value = getattr(args, attr)
        if type(value) is not list:
            setattr(args, attr, [value])

    shard_ids = list(range(args.num_shards)) if args.all_shards else [args.shard_id]

    for shard in shard_ids:
        pre_gen, _, _, _, _ = rerank_utils.get_directories(
            args.data_dir_name,
            args.num_rescore,
            args.gen_subset,
            args.gen_model_name,
            shard,
            args.num_shards,
            args.sampling,
            args.prefix_len,
            args.target_prefix_frac,
            args.source_prefix_frac,
        )
        # NOTE(review): these helpers receive only ``args`` (not ``shard``), so
        # with all_shards each iteration calls them with identical arguments.
        # Confirm they do not need a per-shard ``args.shard_id``.
        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)

        # Output paths; the last shard's pre_gen wins when looping all shards.
        if args.write_hypos is not None:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset
        else:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"

    if args.all_shards:
        write_targets += "_all_shards"
        write_hypos += "_all_shards"

    lenpen, weight1, weight2, weight3, score = match_target_hypo(
        args, write_targets, write_hypos
    )
    return lenpen, weight1, weight2, weight3, score


def cli_main():
    """Command-line entry point: parse the reranking options and run ``rerank``."""
    reranking_parser = rerank_options.get_reranking_parser()
    rerank(options.parse_args_and_arch(reranking_parser))


if __name__ == "__main__":
    cli_main()
