import json
import jsonlines
import os

from typing import List, Tuple


def _read_jsonl(path: str) -> List:
    """Return every record of the JSON-lines file at ``path`` as a list."""
    # Use the reader as a context manager so the underlying file is closed as
    # soon as the records are materialised, rather than leaking the handle.
    with jsonlines.open(path, "r") as reader:
        return list(reader)


def read_data(data_folder: str) -> Tuple[List, List, List, List]:
    """Read the training, validation, and test splits and their relation labels.

    The splits are read from ``train.json``, ``dev.json`` and ``test.json``
    inside ``data_folder``; each file is parsed as JSON lines (one record per
    line).

    Args:
        data_folder: Directory containing the three split files.

    Returns:
        ``(train_data, valid_data, test_data, rel_types)``: the parsed records
        of each split, plus the sorted list of distinct relation labels found
        across all three splits.
    """
    train_data = _read_jsonl(os.path.join(data_folder, "train.json"))
    valid_data = _read_jsonl(os.path.join(data_folder, "dev.json"))
    test_data = _read_jsonl(os.path.join(data_folder, "test.json"))

    # Collect the relation labels: "relations" is a list of lists of relation
    # entries, and the label is the last item of each entry.
    rel_types = set()
    for one_data in train_data + valid_data + test_data:
        for relations in one_data["relations"]:
            rel_types.update(relation[-1] for relation in relations)

    # Sort so the label list is identical across runs (set iteration order of
    # strings varies with hash randomization); callers embed this list in the
    # generated prompt text.
    return train_data, valid_data, test_data, sorted(rel_types)


def convert_to_unified(data, rel_types, split):
    """Convert documents into the unified prompt / request-state format.

    Args:
        data: Iterable of documents. Each document provides ``"sentences"``
            (a list of token lists), ``"relations"`` (a list of lists of
            relation entries whose first four items are the inclusive head and
            tail token spans, indexed into the document's flattened token
            list, and whose last item is the relation label) and
            ``"doc_key"`` (used as the instance id).
        rel_types: Relation labels listed in the instructions; rendered with
            ``str()``.
        split: Split name stored on every instance (e.g. ``"train"``).

    Returns:
        A dict with a ``"prompt"`` section and one ``"request_states"`` entry
        per document. Each reference output has one
        ``(subject; relationship; object)`` line per relation, or ``"NA"``
        when the document has no relations, matching the instructions.
    """
    # Answer the model is told to give when there is no triplet. The same
    # string is used as the reference so instructions and labels agree.
    no_answer = "NA"

    # Define the prompt of the LLM
    instructions = (
        "Please extract the relational triplet in the form (subject; relationship; object) from the given text. "
        + "If there is no triplet, please answer \"" + no_answer + "\". "
        + "The set of relationships is as follows: " + str(rel_types) + "."
    )

    # Construct the structure of the processed data
    unified_data = {
        "prompt": {
            "instructions": instructions,
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "Answer: ",
            "output_suffix": "\n",
        },
        "request_states": [],
    }

    for one_data in data:
        sentences = one_data["sentences"]
        # Relation offsets are looked up in the flattened document token list.
        tokens = [token for sentence in sentences for token in sentence]

        triples = []
        for relations in one_data["relations"]:
            for relation in relations:
                # Span ends are inclusive, hence the +1 on each slice end.
                head_ent = " ".join(tokens[relation[0]:relation[1] + 1])
                tail_ent = " ".join(tokens[relation[2]:relation[3] + 1])
                triples.append("(" + "; ".join([head_ent, relation[-1], tail_ent]) + ")")

        unified_data["request_states"].append({
            "instance": {
                "input": {
                    "text": "\n".join(" ".join(sentence) for sentence in sentences)
                },
                "references": [{
                    "output": {
                        # Documents without relations get the no-answer
                        # token promised by the instructions, not "".
                        "text": "\n".join(triples) if triples else no_answer
                    }
                }],
                "split": split,
                "id": one_data["doc_key"],
            },
            "request": {},
        })

    return unified_data


if __name__ == "__main__":
    # Read the ACE2005 dataset from the local path
    train_data, valid_data, test_data, rel_types = read_data("../../deepstruct/src/data/ace2005_joint_er_re")

    output_dir = "ACE2005"
    # open(..., "w") does not create missing parent directories, so make sure
    # the output directory exists before writing.
    os.makedirs(output_dir, exist_ok=True)

    # (records, split label stored in each instance, output file suffix)
    splits = [
        (train_data, "train", "train"),
        (valid_data, "dev", "valid"),
        (test_data, "test", "test"),
    ]
    for split_data, split_name, file_suffix in splits:
        # Convert the split into the unified format
        unified = convert_to_unified(split_data, rel_types, split_name)

        # Write through a context manager so the file is flushed and closed
        # deterministically instead of whenever the handle is collected.
        output_path = os.path.join(output_dir, "ace2005." + file_suffix + ".json")
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(unified, f, indent=4)
