#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

from tqdm import tqdm


def main():
    """Extract parallel back-translation data from fairseq-generate output.

    Reads the input files given on the command line (``fileinput`` falls back
    to stdin when the list is empty) and pairs each ``S-*`` line with the
    first ``H-*`` line that follows it. For every pair that passes the
    optional ``--minlen`` / ``--maxlen`` / ``--ratio`` filters, the third
    tab-separated field of the ``H-*`` line is written to
    ``<output>.<srclang>`` and the second tab-separated field of the ``S-*``
    line is written to ``<output>.<tgtlang>``, one sentence per line.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiply hypotheses for a source, we only keep the first one. "
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        """Return True if the pair passes every filter that was requested.

        Lengths are counted in space-separated tokens; an empty string has
        length 0.
        """
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        shorter, longer = sorted((srclen, tgtlen))
        if args.minlen is not None and shorter < args.minlen:
            return False
        if args.maxlen is not None and longer > args.maxlen:
            return False
        if args.ratio is not None:
            # An empty side makes the length ratio unbounded: reject the pair
            # rather than dividing by zero.
            if shorter == 0 or longer / float(shorter) > args.ratio:
                return False
        return True

    def safe_index(toks, index, default):
        """Return ``toks[index]``, or ``default`` when the field is missing."""
        try:
            return toks[index]
        except IndexError:
            return default

    # ``tgt`` holds the text of the most recent S-* line that has not yet been
    # paired with a hypothesis. It starts as None so that an H-* line seen
    # before any S-* line is skipped instead of raising UnboundLocalError.
    tgt = None
    with open(args.output + "." + args.srclang, "w") as src_h, open(
        args.output + "." + args.tgtlang, "w"
    ) as tgt_h, fileinput.input(args.files) as lines:
        for line in tqdm(lines):
            if line.startswith("S-"):
                # rstrip() also removes a trailing tab, so an empty sentence
                # leaves no text field; safe_index maps that case to "".
                tgt = safe_index(line.rstrip().split("\t"), 1, "")
            elif line.startswith("H-"):
                if tgt is not None:
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    # Only the first hypothesis per S-* line is kept.
                    tgt = None


# Run the extraction only when this file is executed as a script, so that
# importing the module does not parse arguments or write any files.
if __name__ == "__main__":
    main()
