import json

import mlxu

from EasyLM.serving import LMClient

# Command-line flags for this script, with their default values. The call
# yields two objects unpacked into FLAGS (read throughout ``main``) and
# FLAGS_DEF — NOTE(review): presumably the flag definitions/defaults; confirm
# against mlxu.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # Path of the JSON file that ``main`` loads as its request.
    input_file="",
    # Path where ``main`` writes the JSON result.
    output_file="",
    # Key in the input JSON holding the prefix; read for the
    # "loglikelihood", "greedy_until" and "generate" eval types.
    prefix_field="prefix",
    # Key in the input JSON holding the text; read for the
    # "loglikelihood" and "loglikelihood_rolling" eval types.
    text_field="text",
    # Key in the input JSON read for "greedy_until" (presumably the stop
    # sequence(s) — TODO confirm against LMClient.greedy_until).
    until_field="until",
    # Selects the LMClient call made by ``main``: "loglikelihood",
    # "loglikelihood_rolling", "greedy_until" or "generate". Any other
    # value makes ``main`` raise ValueError.
    eval_type="loglikelihood",
    # Configuration object passed to the ``LMClient`` constructor.
    lm_client=LMClient.get_default_config(),
)


def main(argv):
    """Run one evaluation request through ``LMClient`` and save the result.

    The request is loaded as JSON from ``FLAGS.input_file``.
    ``FLAGS.eval_type`` selects which ``LMClient`` method is called, and the
    method's output is written as JSON to ``FLAGS.output_file``.

    Output keys by eval type:
      * ``"loglikelihood"`` / ``"loglikelihood_rolling"``:
        ``{"loglikelihood": ..., "is_greedy": ...}``
      * ``"greedy_until"`` / ``"generate"``: ``{"output_text": ...}``

    Args:
        argv: Command-line arguments forwarded by ``mlxu.run``; not used.

    Raises:
        ValueError: If ``FLAGS.eval_type`` is not one of the four supported
            modes. This happens after the client is built and the input
            file has been read.
    """
    client = LMClient(FLAGS.lm_client)
    with mlxu.open_file(FLAGS.input_file, "r") as in_stream:
        request = json.load(in_stream)

    mode = FLAGS.eval_type
    if mode in ("loglikelihood", "loglikelihood_rolling"):
        # Both scoring modes return a (scores, greedy-flags) pair and share
        # the same output layout.
        if mode == "loglikelihood":
            scores, greedy_flags = client.loglikelihood(
                request[FLAGS.prefix_field], request[FLAGS.text_field]
            )
        else:
            scores, greedy_flags = client.loglikelihood_rolling(
                request[FLAGS.text_field]
            )
        result = {"loglikelihood": scores, "is_greedy": greedy_flags}
    elif mode == "greedy_until":
        completions = client.greedy_until(
            request[FLAGS.prefix_field], request[FLAGS.until_field]
        )
        result = {"output_text": completions}
    elif mode == "generate":
        completions = client.generate(request[FLAGS.prefix_field])
        result = {"output_text": completions}
    else:
        raise ValueError(f"Unknown eval_type: {FLAGS.eval_type}")

    with mlxu.open_file(FLAGS.output_file, "w") as out_stream:
        json.dump(result, out_stream)


# Script entry point: hand ``main`` to ``mlxu.run``. NOTE(review): presumably
# mlxu.run parses the command-line flags before invoking main(argv); confirm
# against mlxu.
if __name__ == "__main__":
    mlxu.run(main)
