import os
import gzip
import json
from tqdm import tqdm
from pathlib import Path
from typing import List
from presidio_analyzer import AnalyzerEngine
from conversation import (
    ConversationSet,
    DataConversation,
    Message,
    PiiResult,
    PiiResults,
)

# Entity types requested from Presidio when a message is analyzed.
relevant_entities = [
    "EMAIL_ADDRESS",
    "PHONE_NUMBER",
    "LOCATION",
    "NRP",
]

# One module-level analyzer instance, reused for every message.
analyzer = AnalyzerEngine()


def get_threads_from_tree(tree):
    """Flatten a message tree into root-to-leaf threads of English messages.

    The tree is walked depth-first. Every node becomes a ``Message`` whose PII
    results come from Presidio, unless the node's ``pii`` label is exactly
    ``0.0``, in which case the analyzer is skipped and the results are empty.

    Args:
        tree: A dict describing one message node. The keys ``role``, ``text``
            and ``lang`` are required. ``labels`` and ``replies`` are
            optional; a missing or empty value is treated as "no labels" and
            "no replies" respectively.

    Returns:
        A list of threads, each thread being a list of ``Message`` objects
        ordered from this node down to a leaf. A non-English node yields an
        empty list, which prunes its whole subtree. A node whose replies are
        all pruned yields an empty list as well, so it is dropped with them.
    """
    # Filter on language first: nothing below is needed for a discarded node.
    if tree["lang"] != "en":
        return []

    role = "human" if tree["role"] == "prompter" else "gpt"
    text = tree["text"]

    # A node without a `pii` label is treated as potentially containing PII
    # (score 1.0) so that it is still run through the analyzer.
    labels = tree.get("labels") or {}
    if "pii" in labels:
        pii = labels["pii"]["value"]
        if pii != 0.0:
            print("Found PII: ", pii)
    else:
        pii = 1.0

    # NOTE(review): the constructor receives the raw label score here; the
    # real results are attached through set_pii_result() just below.
    curr_msg = Message(sender=role, text=text, pii_results=pii)

    if pii != 0.0:
        findings = analyzer.analyze(
            text=text, entities=relevant_entities, language="en"
        )
        pii_results = [
            PiiResult(
                entity=finding.entity_type,
                start=finding.start,
                end=finding.end,
                orig=text[finding.start : finding.end],
                new_value="<ANONYMIZED>",
            )
            for finding in findings
        ]
    else:
        pii_results = []
    curr_msg.set_pii_result(PiiResults(pii_results))

    replies = tree.get("replies") or []
    if not replies:
        # Leaf node: the only thread is the message itself.
        return [[curr_msg]]

    # Prepend this message to every thread produced by the replies.
    threads: List[List[Message]] = []
    for child in replies:
        for child_thread in get_threads_from_tree(child):
            threads.append([curr_msg] + child_thread)

    return threads


if __name__ == "__main__":
    # Read the OpenAssistant tree export (one JSON object per line), flatten
    # every tree into English threads and write them out as a single JSON file.
    data_dir = Path("data/open_assistant")
    input_file_path = data_dir / "2023-04-12_oasst_ready.trees.jsonl.gz"
    output_file_path = data_dir / "oa_en.json"

    # Transparently support both gzip-compressed and plain JSONL input.
    if input_file_path.suffix == ".gz":
        file_in = gzip.open(str(input_file_path), mode="rt", encoding="UTF-8")
    else:
        file_in = input_file_path.open("r", encoding="UTF-8")

    conversation_list = []

    with file_in:
        # read one object per line
        for line in tqdm(file_in):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            dict_tree = json.loads(line)

            threads = get_threads_from_tree(dict_tree["prompt"])

            # Build one conversation per thread; all threads from the same
            # tree share that tree's id.
            for thread in threads:
                if not thread:
                    continue
                conversation_list.append(
                    DataConversation(
                        id=dict_tree["message_tree_id"],
                        origin="openassistant",
                        language="en",
                        messages=thread,
                    )
                )

    conv_json_export = ConversationSet(conversation_list).toJSON()

    output_file_path.parent.mkdir(parents=True, exist_ok=True)
    # Write with an explicit encoding so the output does not depend on the
    # platform default; the input is read as UTF-8 as well.
    with output_file_path.open("w", encoding="UTF-8") as f:
        f.write(conv_json_export)
