#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Extracts random constraints from reference files."""

import argparse
import random
import sys
from sacrebleu import extract_ngrams


def get_phrase(words, index, length):
    assert(index < len(words) - length + 1)
    phr = ' '.join(words[index:index+length])
    for i in range(index, index + length):
        words.pop(index)
    return phr


def main(args):

    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []

        def add_constraint(constraint):
            constraints.append(constraint)

        source = line.rstrip()
        if '\t' in line:
            source, target = line.split('\t')
            if args.add_sos:
                target = f"<s> {target}"
            if args.add_eos:
                target = f"{target} </s>"

            if len(target.split()) >= args.len:
                words = [target]

                num = args.number

                choices = {}
                for i in range(num):
                    if len(words) == 0:
                        break
                    segmentno = random.choice(range(len(words)))
                    segment = words.pop(segmentno)
                    tokens = segment.split()
                    phrase_index = random.choice(range(len(tokens)))
                    choice = " ".join(tokens[phrase_index:min(len(tokens), phrase_index + args.len)])
                    for j in range(phrase_index, min(len(tokens), phrase_index + args.len)):
                        tokens.pop(phrase_index)
                    if phrase_index > 0:
                        words.append(" ".join(tokens[0:phrase_index]))
                    if phrase_index + 1 < len(tokens):
                        words.append(" ".join(tokens[phrase_index:]))
                    choices[target.find(choice)] = choice

                    # mask out with spaces
                    target = target.replace(choice, " " * len(choice), 1)

                for key in sorted(choices.keys()):
                    add_constraint(choices[key])

        print(source, *constraints, sep="\t")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--number', '-n', type=int, default=1, help="number of phrases")
    parser.add_argument('--len', '-l', type=int, default=1, help="phrase length")
    parser.add_argument('--add-sos', default=False, action='store_true', help='add <s> token')
    parser.add_argument('--add-eos', default=False, action='store_true', help='add </s> token')
    parser.add_argument('--seed', "-s", default=0, type=int)
    args = parser.parse_args()

    main(args)
