import click
import logging
import ujson as json
from collections import namedtuple
from tqdm import tqdm
import torch
import transformers

import sys
import os
# Make the "UniEval" directory that sits next to this script importable
# (presumably where the `metric` and `utils` imports below resolve from).
# The path is joined onto the script's absolute directory: plain string
# concatenation with "./UniEval/" produced "<dir>./UniEval/" whenever the
# directory part of __file__ was non-empty.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "UniEval"))
from metric.evaluator import get_evaluator as get_unieval_evaluator
from utils import convert_to_json as unieval_convert_to_json


# Module-level logger; it is named after this file's path (``__file__``).
logger = logging.getLogger(__file__)

'''
Disabled local copy of extract_first_conv, kept for reference only; the
function actually used by this script is imported from extract_length below.

def extract_first_conv(conv):
    if conv[0]["from"] == "gpt":
        # the first round is opened by gpt
        input_conv_str = f"{conv[0]['value']}\n{conv[1]['value']}"
        output_conv_str = conv[2]["value"]
    elif conv[0]["from"] == "human":
        # the first round is opened by the human
        input_conv_str = conv[0]["value"]
        output_conv_str = conv[1]["value"]
    else:
        pass

    return input_conv_str, output_conv_str
'''
from extract_length import extract_first_conv

def run(args):
    """Score the first exchange of every conversation with UniEval and save it.

    The JSON document at ``args.input`` is loaded and iterated; each record
    must provide a ``"conversations"`` entry that ``extract_first_conv`` turns
    into an ``(input, output)`` pair. The "dialogue" evaluator is then run
    once per dimension ("understandability", "naturalness", "coherence") and
    each record receives its score under the dimension's name.

    Files written (all with ``indent=1``):
        * ``args.output + "-" + dim`` after each dimension, containing the
          records with every score computed so far;
        * ``args.output`` once all dimensions are done.

    Args:
        args: Object exposing ``input`` (path of the JSON file to read) and
            ``output`` (path, also used as prefix, of the JSON files to write).
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(args.input, encoding="utf-8") as fd:
        raw_data = json.load(fd)

    evaluator = get_unieval_evaluator("dialogue")

    # One (input, output) string pair per record, in record order.
    conv_pairs = [extract_first_conv(item["conversations"]) for item in raw_data]
    unieval_input = [src for src, _ in conv_pairs]
    unieval_output = [out for _, out in conv_pairs]

    unieval_data = unieval_convert_to_json(
        output_list=unieval_output,
        src_list=unieval_input,
        # No separate context is supplied: every entry is an empty string.
        context_list=[""] * len(unieval_input),
    )

    # Evaluating one dimension per call lets a snapshot be written after each.
    for dim in ("understandability", "naturalness", "coherence"):
        eval_scores = evaluator.evaluate(unieval_data, dims=[dim])
        # Scores are matched to records by position.
        for i, item in enumerate(raw_data):
            item[dim] = eval_scores[i][dim]
        with open(args.output + "-" + dim, "w", encoding="utf-8") as fd:
            json.dump(raw_data, fd, indent=1)

    with open(args.output, "w", encoding="utf-8") as fd:
        json.dump(raw_data, fd, indent=1)


# Command-line entry point. Comments are used instead of a docstring on
# purpose: click would display a docstring as the command's --help text.
#   -i / --input   forwarded to run() as args.input  (default "")
#   -o / --output  forwarded to run() as args.output (default "n")
@click.command()
@click.option("-i", "--input", default="")
@click.option("-o", "--output", default="n")
def main(**kwargs):
    # Wrap the parsed options in an ad-hoc namedtuple so that run() can read
    # them as attributes.
    arg_type = namedtuple("Arg", kwargs.keys())
    parsed = arg_type(**kwargs)
    logger.info(kwargs)
    run(parsed)


if __name__ == "__main__":
    # Root logging setup: records at DEBUG and above are appended to a log
    # file whose name is this file's path plus ".log".
    log_settings = {
        "filename": __file__ + ".log",
        "filemode": "a",
        "format": "[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s",
        "datefmt": "%d-%m-%Y %H:%M:%S",
        "level": logging.DEBUG,
    }
    logging.basicConfig(**log_settings)

    # Additionally attach a stream handler to the root logger.
    stream_handler = logging.StreamHandler()
    logging.getLogger().addHandler(stream_handler)

    main()
