# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES
from agents import build_agent

# Fallback server address, used as the defaults of --hostname and --port.
DEFAULT_HOSTNAME, DEFAULT_PORT = "localhost", 12321


def get_args(argv=None):
    """Build the evaluation command-line parser and parse the arguments.

    Parsing happens in two passes. A first, tolerant pass
    (``parse_known_args``) reads the base options. Every registry entry whose
    name is set on the parsed args then gets its selected class's ``add_args``
    hook called on the parser. A final, strict pass (``parse_args``) parses
    everything, including the options those classes added.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the base options plus any options added
        by the selected registered classes.

    Raises:
        SystemExit: via ``parser.error`` when a registry option names a
            choice that is not registered, or when the final parse fails.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
                        help='server hostname')
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='server port number')
    parser.add_argument('--agent-type', default='simul_trans_text',
                        help='Agent type')
    parser.add_argument('--scorer-type', default='text',
                        help='Scorer type')
    parser.add_argument('--start-idx', type=int, default=0,
                        help='Start index of the sentence to evaluate')
    # An explicit value is parsed as int. The float('inf') default is passed
    # through unchanged, presumably meaning "no upper bound".
    parser.add_argument('--end-idx', type=int, default=float('inf'),
                        help='End index of the sentence to evaluate')
    parser.add_argument('--scores', action="store_true",
                        help='Request scores from server')
    parser.add_argument('--reset-server', action="store_true",
                        help='Reset the server')
    parser.add_argument('--num-threads', type=int, default=10,
                        help='Number of threads used by agent')
    parser.add_argument('--local', action="store_true",
                        help='Local evaluation')

    # First pass: ignore options we don't know yet. They may belong to a
    # registered class whose add_args has not run so far.
    args, _ = parser.parse_known_args(argv)

    # Registry names are looked up as attributes of the parsed args, so they
    # presumably match argparse dests such as ``agent_type``.
    for registry_name, registry_entry in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is None:
            continue
        registered = registry_entry["registry"]
        if choice not in registered:
            # Report a usage error instead of an opaque KeyError traceback.
            parser.error(
                "invalid choice for {}: {!r} (choose from {})".format(
                    registry_name,
                    choice,
                    ", ".join(sorted(map(str, registered))),
                )
            )
        cls = registered[choice]
        if hasattr(cls, "add_args"):
            cls.add_args(parser)

    # Second pass: strict, now that the selected classes added their options.
    return parser.parse_args(argv)


def main(args=None):
    """Run one evaluation session and print the scores it reports.

    Args:
        args: Namespace as returned by ``get_args``. When ``None`` (the
            default), the command line is parsed with ``get_args()``.
    """
    if args is None:
        args = get_args()

    # --local selects the local evaluation service. Otherwise connect to a
    # server at --hostname/--port.
    if args.local:
        session = SimulSTLocalEvaluationService(args)
    else:
        session = SimulSTEvaluationService(args.hostname, args.port)

    if args.reset_server:
        session.new_session()

    if args.agent_type is not None:
        agent = build_agent(args)
        agent.decode(session, args.start_idx, args.end_idx, args.num_threads)

    # Request the scores exactly once. An extra get_scores() call whose
    # result is discarded only adds a redundant request to the service.
    # NOTE(review): scores are printed whether or not --scores is given, to
    # keep the existing output. --scores therefore has no separate effect;
    # consider gating the print on it.
    print(session.get_scores())


if __name__ == "__main__":
    main()
