#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Extracts random constraints from reference files."""

import argparse
import random
import sys


def get_phrase(words, index, length):
    """Cut a run of items out of ``words`` and return it as one string.

    Args:
        words: Mutable list of string tokens; it is modified in place.
        index: Position of the first token of the run.
        length: Number of consecutive tokens to take.

    Returns:
        The ``length`` tokens starting at ``index``, joined by single spaces.

    Raises:
        AssertionError: If the run would extend past the end of ``words``.
    """
    assert index < len(words) - length + 1
    stop = index + length
    phrase = " ".join(words[index:stop])
    # Popping at the same position each time removes the run one item at a
    # time, because the following items shift left after every pop.
    for _ in range(index, stop):
        words.pop(index)
    return phrase


def main(args):

    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []

        def add_constraint(constraint):
            constraints.append(constraint)

        source = line.rstrip()
        if "\t" in line:
            source, target = line.split("\t")
            if args.add_sos:
                target = f"<s> {target}"
            if args.add_eos:
                target = f"{target} </s>"

            if len(target.split()) >= args.len:
                words = [target]

                num = args.number

                choices = {}
                for i in range(num):
                    if len(words) == 0:
                        break
                    segmentno = random.choice(range(len(words)))
                    segment = words.pop(segmentno)
                    tokens = segment.split()
                    phrase_index = random.choice(range(len(tokens)))
                    choice = " ".join(
                        tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
                    )
                    for j in range(
                        phrase_index, min(len(tokens), phrase_index + args.len)
                    ):
                        tokens.pop(phrase_index)
                    if phrase_index > 0:
                        words.append(" ".join(tokens[0:phrase_index]))
                    if phrase_index + 1 < len(tokens):
                        words.append(" ".join(tokens[phrase_index:]))
                    choices[target.find(choice)] = choice

                    # mask out with spaces
                    target = target.replace(choice, " " * len(choice), 1)

                for key in sorted(choices.keys()):
                    add_constraint(choices[key])

        print(source, *constraints, sep="\t")


if __name__ == "__main__":
    # Command-line entry point: build the option parser, parse the options,
    # and pass the resulting namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--number", "-n", default=1, type=int, help="number of phrases"
    )
    cli.add_argument(
        "--len", "-l", default=1, type=int, help="phrase length"
    )
    # "store_true" flags default to False when absent from the command line.
    cli.add_argument("--add-sos", action="store_true", help="add <s> token")
    cli.add_argument("--add-eos", action="store_true", help="add </s> token")
    # main() only seeds the random generator when this value is non-zero.
    cli.add_argument("--seed", "-s", type=int, default=0)

    main(cli.parse_args())
