import click
import logging
import ujson as json
from collections import namedtuple
import sys
import os
from tqdm import tqdm
import torch
import transformers
from pynndescent import NNDescent
from sentence_transformers import SentenceTransformer

# Module-level logger named after the module (idiomatic); __file__ would give a
# filesystem path whose dots create a meaningless logger hierarchy.
logger = logging.getLogger(__name__)

# NOTE(review): the triple-quoted string below is inert, commented-out code kept
# for reference only; the implementation actually used is imported from
# `extract_length` right after it. In this copy, a first turn whose "from" is
# neither "gpt" nor "human" falls into `else: pass` and then returns unbound
# names (UnboundLocalError) — verify whether extract_length's version shares
# that gap, and consider deleting this dead copy.
'''
def extract_first_conv(conv):
    if conv[0]["from"] == "gpt":
        # if first round begins by gpt
        input_conv_str = f"{conv[0]['value']}\n{conv[1]['value']}"
        output_conv_str = conv[2]['value']
    elif conv[0]["from"] == "human":
        # if first round begins by human
        input_conv_str = conv[0]['value']
        output_conv_str = conv[1]['value']
    else:
        pass

    return input_conv_str, output_conv_str
'''
from extract_length import extract_first_conv


def run(args, k=6, device=None):
    """Annotate every record with the distance to its k-th nearest neighbour.

    Reads a JSON list of records from ``args.input``, embeds the input side of
    each record's first conversation round (as returned by
    ``extract_first_conv(record["conversations"])``) with
    all-MiniLM-L6-v2, builds an approximate nearest-neighbour graph with
    pynndescent, stores the k-th column distance under ``record[f"knn_{k}"]``
    and writes the annotated list to ``args.output``.

    Args:
        args: object exposing ``input`` and ``output`` path attributes.
        k: 1-based column of the neighbour graph to read. The default 6
            reproduces the historical ``"knn_6"`` key.
        device: torch device string for the encoder. ``None`` selects
            ``"cuda:0"`` when CUDA is available and ``"cpu"`` otherwise.

    Raises:
        ValueError: if ``k < 1`` or the input holds fewer than ``k`` records
            (the neighbour graph would then have no k-th column).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Load and validate the data before loading the encoder, so a bad input
    # file fails fast instead of after the model has been moved to the GPU.
    with open(args.input, encoding="utf-8") as fd:
        raw_data = json.load(fd)
    num_records = len(raw_data)
    if num_records < k:
        raise ValueError(
            f"need at least {k} records to compute knn_{k}, got {num_records}"
        )

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    sent_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    sent_model.to(torch.device(device))
    logger.info("encoder %s loaded on %s", sent_model, sent_model.device)

    # Only the input side of the first round is embedded; the reply is unused.
    db_dataset = []
    for one_data in tqdm(raw_data, desc="preprocess"):
        input_conv_str, _ = extract_first_conv(one_data["conversations"])
        db_dataset.append(input_conv_str)
    logger.info("collected %d first-round inputs", len(db_dataset))

    db_embeddings = sent_model.encode(db_dataset, show_progress_bar=True)
    logger.info("start nndescent on %d embeddings", len(db_embeddings))

    # 30 is pynndescent's default n_neighbors; widen it when k needs more
    # columns, and never ask for more neighbours than there are points.
    n_neighbors = min(max(30, k), num_records)
    anns_index = NNDescent(db_embeddings, n_neighbors=n_neighbors)
    # neighbor_graph is (indices, distances); keep the distances.
    ann_distance = anns_index.neighbor_graph[1]
    logger.info("nndescent done")

    # NOTE(review): pynndescent normally lists each point as its own nearest
    # neighbour (distance 0) in column 0, so column k-1 is the (k-1)-th *other*
    # point — the same convention as the former `query(db_embeddings, k=6)`
    # variant. Confirm this is the intended meaning of "knn_6".
    key = f"knn_{k}"
    for i, one_data in enumerate(tqdm(raw_data, desc="save data")):
        one_data[key] = float(ann_distance[i, k - 1])

    with open(args.output, "w", encoding="utf-8") as fd:
        json.dump(raw_data, fd, indent=1)


@click.command()
@click.option(
    "-i",
    "--input",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="JSON file holding a list of records with a 'conversations' field.",
)
@click.option(
    "-o",
    "--output",
    required=True,
    type=click.Path(dir_okay=False),
    help="Path where the annotated JSON list is written.",
)
def main(**kwargs):
    """CLI entry point: validate options, log them, and hand them to ``run``.

    Both options are required; previously they defaulted to "" and the
    script only failed later with an opaque ``open("")`` error.
    """
    # run() reads options by attribute (args.input / args.output), so wrap
    # click's keyword arguments in a lightweight namedtuple.
    Arg = namedtuple("Arg", kwargs.keys())
    args = Arg(**kwargs)
    logger.info(kwargs)
    run(args)


if __name__ == "__main__":
    # Append detailed records to "<this script>.log", and additionally echo
    # every record through a plain StreamHandler (no custom formatter).
    log_config = dict(
        filename=f"{__file__}.log",
        filemode="a",
        format="[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s",
        datefmt="%d-%m-%Y %H:%M:%S",
        level=logging.INFO,
    )
    logging.basicConfig(**log_config)
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    main()
