import click
import logging
import ujson as json
from collections import namedtuple
from tqdm import tqdm
import torch
import transformers

# Module-level logger. Named after the module (the conventional logging idiom)
# rather than the file path: logger names are dot-separated hierarchies, so a
# path such as "script.py" would be split into unintended parent loggers.
logger = logging.getLogger(__name__)

def extract_first_conv(conv):
    """Return the (input, output) text of the first exchange in a conversation.

    Args:
        conv: Sequence of turn dicts. Each turn has a "from" key naming the
            speaker and a "value" key holding the turn's text.

    Returns:
        A tuple ``(input_conv_str, output_conv_str)``:

        * "human" speaks first: the first turn is the input and the second
          turn is the output (empty string when there is no second turn).
        * "gpt" speaks first and there are at least three turns: the first
          two turns joined by a newline are the input and the third turn is
          the output.
        * "gpt" speaks first with fewer than three turns: the input is empty
          and the opening turn is the output.
        * Empty conversation or any other first speaker: ``("", "")``.
    """
    # Guard against an empty conversation instead of failing on conv[0].
    if not conv:
        return "", ""

    first_speaker = conv[0]["from"]

    if first_speaker == "gpt":
        if len(conv) >= 3:
            # The opening gpt turn is folded into the prompt together with
            # the human turn that follows it.
            return f"{conv[0]['value']}\n{conv[1]['value']}", conv[2]['value']
        return "", conv[0]['value']

    if first_speaker == "human":
        # A lone human turn has no reply yet; report an empty output rather
        # than raising IndexError on conv[1].
        reply = conv[1]['value'] if len(conv) > 1 else ""
        return conv[0]['value'], reply

    return "", ""

def _is_valid_conversation(turns):
    """Return True when every speaker is "human" or "gpt" and no two
    consecutive turns share the same speaker."""
    previous_speaker = None
    for turn in turns:
        speaker = turn["from"]
        if speaker == previous_speaker or speaker not in ("human", "gpt"):
            return False
        previous_speaker = speaker
    return True


def run(args):
    """Annotate each valid conversation record with first-exchange token counts.

    Reads a JSON list of records from ``args.input``, drops records whose
    "conversations" list is empty, contains a speaker other than "human" or
    "gpt", or has two consecutive turns from the same speaker, then adds
    "input_length" and "output_length" (token counts stored as floats) to
    each remaining record and writes the result to ``args.output``.

    Args:
        args: Object exposing ``input`` (path of the JSON file to read) and
            ``output`` (path of the JSON file to write).
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default encoding when reading or writing it.
    with open(args.input, encoding="utf-8") as fd:
        raw_data = json.load(fd)

    model_name = "meta-llama/Llama-2-7b-hf"
    # Only the tokenizer is loaded; it is used purely to count tokens.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name,
        max_length=4096,
    )

    validated_data = [
        one_data
        for one_data in raw_data
        if one_data["conversations"]
        and _is_valid_conversation(one_data["conversations"])
    ]

    save_data = []
    for one_data in tqdm(validated_data):
        inputs, outputs = extract_first_conv(one_data['conversations'])
        # The records are updated in place; lengths are stored as floats.
        one_data["input_length"] = float(len(tokenizer.tokenize(inputs)))
        one_data["output_length"] = float(len(tokenizer.tokenize(outputs)))
        save_data.append(one_data)

    with open(args.output, "w", encoding="utf-8") as fd:
        json.dump(save_data, fd, indent=1)


@click.command()
@click.option('-i', '--input', default="")
@click.option('-o', '--output', default="")
def main(**kwargs):
    # Command-line entry point: packs the parsed options into an immutable,
    # attribute-accessible record, logs them, and hands them to run().
    # (Documented with comments on purpose: a docstring here would be picked
    # up by click as the command's help text.)
    option_names = kwargs.keys()
    arg_type = namedtuple('Arg', option_names)
    parsed = arg_type(**kwargs)
    logger.info(kwargs)
    run(parsed)


if __name__ == '__main__':
    # Script entry: append DEBUG-and-above records to "<this file>.log" using
    # the format below, and also attach a stream handler to the root logger
    # so the same records are emitted on the console.
    log_format = '[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s'
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt='%d-%m-%Y %H:%M:%S',
        filename=__file__ + '.log',
        filemode='a',
    )
    console_handler = logging.StreamHandler()
    root_logger = logging.getLogger()
    root_logger.addHandler(console_handler)
    main()
