#!/usr/bin/env python3 -u

import argparse
import fileinput
import logging
import os
import sys

from fairseq.models.transformer import TransformerModel


# Lower the root logger's threshold from its default (WARNING) to INFO so the
# progress messages this script emits via logging.info(...) are displayed.
logging.getLogger(name=None).setLevel('INFO')


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--en2fr', required=True,
                        help='path to en2fr model')
    parser.add_argument('--fr2en', required=True,
                        help='path to fr2en mixture of experts model')
    parser.add_argument('--user-dir',
                        help='path to fairseq examples/translation_moe/src directory')
    parser.add_argument('--num-experts', type=int, default=10,
                        help='(keep at 10 unless using a different model)')
    parser.add_argument('files', nargs='*', default=['-'],
                        help='input files to paraphrase; "-" for stdin')
    args = parser.parse_args()

    if args.user_dir is None:
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            'translation_moe',
            'src',
        )
        if os.path.exists(args.user_dir):
            logging.info('found user_dir:' + args.user_dir)
        else:
            raise RuntimeError(
                'cannot find fairseq examples/translation_moe/src '
                '(tried looking here: {})'.format(args.user_dir)
            )

    logging.info('loading en2fr model from:' + args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer='moses',
        bpe='sentencepiece',
    ).eval()

    logging.info('loading fr2en model from:' + args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer='moses',
        bpe='sentencepiece',
        user_dir=args.user_dir,
        task='translation_moe',
    ).eval()

    def gen_paraphrases(en):
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={'expert': i})
            for i in range(args.num_experts)
        ]

    logging.info('Type the input sentence and press return:')
    for line in fileinput.input(args.files):
        line = line.strip()
        if len(line) == 0:
            continue
        for paraphrase in gen_paraphrases(line):
            print(paraphrase)


if __name__ == '__main__':
    # Run the paraphraser only when executed as a script, not when imported.
    main()
