import click
import logging
import ujson as json
from collections import namedtuple
from lexicalrichness import LexicalRichness
from tqdm import tqdm
import transformers

# Module logger; named by module (the conventional choice) rather than by file path.
logger = logging.getLogger(__name__)

# extract_first_conv(conversations) comes from the project-local extract_length
# module; run() unpacks its result as (input_str, output_str).
from extract_length import extract_first_conv

def _is_valid_conversation(conversations):
    """Return True if *conversations* is usable for scoring.

    A conversation is usable when it is non-empty, every turn's "from" is
    either "human" or "gpt", and no role speaks twice in a row.
    """
    if not conversations:
        return False
    last_from = None
    for sent in conversations:
        role = sent["from"]
        if role not in ("human", "gpt") or role == last_from:
            return False
        last_from = role
    return True


def run(args):
    """Score the first conversation round of each record with MTLD.

    Reads a JSON list of records from ``args.input``. Records whose
    "conversations" are missing, empty, use an unknown role, or repeat a
    role on consecutive turns are dropped. For every kept record, the first
    round (as returned by ``extract_first_conv``) is joined with a newline,
    and its ``LexicalRichness(...).mtld()`` score is stored under the key
    "first_round_mtld". The kept records are written to ``args.output``.

    Args:
        args: object exposing ``input`` and ``output`` file paths.
    """
    # Explicit UTF-8 so that multilingual chat text does not depend on the locale.
    with open(args.input, encoding="utf-8") as fd:
        raw_data = json.load(fd)

    # .get(): records without a "conversations" key are skipped like empty ones.
    validated_data = [
        one_data
        for one_data in raw_data
        if _is_valid_conversation(one_data.get("conversations"))
    ]

    save_data = []
    for one_data in tqdm(validated_data):
        input_str, output_str = extract_first_conv(one_data["conversations"])
        first_round_str = input_str + "\n" + output_str
        one_data["first_round_mtld"] = LexicalRichness(first_round_str).mtld()
        save_data.append(one_data)

    with open(args.output, "w", encoding="utf-8") as fd:
        json.dump(save_data, fd, indent=1)


@click.command()
@click.option(
    "-i",
    "--input",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Path to the input JSON file (a list of records with a 'conversations' field).",
)
@click.option(
    "-o",
    "--output",
    # NOTE(review): default "n" looks like a typo; kept so existing invocations
    # that omit -o still write to the same file.
    default="n",
    show_default=True,
    help="Path of the JSON file the scored records are written to.",
)
def main(**kwargs):
    """CLI entry point: bundle the parsed options into a namedtuple and call run()."""
    Arg = namedtuple("Arg", kwargs.keys())
    args = Arg(**kwargs)
    logger.info(kwargs)
    run(args)


if __name__ == "__main__":
    # Append DEBUG-and-above records to "<this script>.log" with a detailed format...
    logging.basicConfig(
        level=logging.DEBUG,
        format="[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S",
        filemode="a",
        filename=f"{__file__}.log",
    )
    # ...and also echo them (without that formatter) to stderr via the root logger.
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    main()
