#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

from tqdm import tqdm


def main():
    """Extract back-translated sentence pairs from fairseq-generate output.

    Parses the command line, then scans the input files (or stdin when no
    files are given, per ``fileinput`` semantics) line by line:

    * the second tab-separated field of each ``S-*`` line is remembered as
      the target sentence;
    * the third tab-separated field of the first ``H-*`` line that follows
      is taken as the source sentence; further ``H-*`` lines for the same
      ``S-*`` line are ignored.

    Pairs that pass the optional length/ratio filters are written, one
    sentence per line, to ``<output>.<srclang>`` and ``<output>.<tgtlang>``.
    """
    parser = argparse.ArgumentParser(description=(
        'Extract back-translations from the stdout of fairseq-generate. '
        'If there are multiply hypotheses for a source, we only keep the first one. '
    ))
    parser.add_argument('--output', required=True, help='output prefix')
    parser.add_argument('--srclang', required=True, help='source language (extracted from H-* lines)')
    parser.add_argument('--tgtlang', required=True, help='target language (extracted from S-* lines)')
    parser.add_argument('--minlen', type=int, help='min length filter')
    parser.add_argument('--maxlen', type=int, help='max length filter')
    parser.add_argument('--ratio', type=float, help='ratio filter')
    parser.add_argument('files', nargs='*', help='input files')
    args = parser.parse_args()

    def validate(src, tgt):
        """Return True if the pair passes every filter that was requested.

        Lengths are counted in space-separated tokens; an empty string has
        length 0. Filters whose option was not given are skipped.
        """
        srclen = len(src.split(' ')) if src != '' else 0
        tgtlen = len(tgt.split(' ')) if tgt != '' else 0
        shorter = min(srclen, tgtlen)
        longer = max(srclen, tgtlen)

        if args.minlen is not None and shorter < args.minlen:
            return False
        if args.maxlen is not None and longer > args.maxlen:
            return False
        if args.ratio is not None:
            # An empty side has no finite length ratio; reject the pair
            # instead of dividing by zero.
            if shorter == 0 or longer / float(shorter) > args.ratio:
                return False
        return True

    def safe_index(toks, index, default):
        """Return ``toks[index]``, or ``default`` if the index is out of range."""
        try:
            return toks[index]
        except IndexError:
            return default

    with open(args.output + '.' + args.srclang, 'w') as src_h, \
            open(args.output + '.' + args.tgtlang, 'w') as tgt_h:
        # No target is pending until an S-* line has been seen, so a stray
        # leading H-* line is skipped rather than raising UnboundLocalError.
        tgt = None
        for line in tqdm(fileinput.input(args.files)):
            if line.startswith('S-'):
                tgt = safe_index(line.rstrip().split('\t'), 1, '')
            elif line.startswith('H-'):
                if tgt is None:
                    # Either no S-* line yet, or this is an additional
                    # hypothesis for a source that was already handled.
                    continue
                src = safe_index(line.rstrip().split('\t'), 2, '')
                if validate(src, tgt):
                    print(src, file=src_h)
                    print(tgt, file=tgt_h)
                # Mark the pending target as consumed so only the first
                # hypothesis per S-* line is kept.
                tgt = None


# Run the extraction only when this file is executed as a script, so that
# importing the module does not parse arguments or write any files.
if __name__ == '__main__':
    main()
