import json
import os

from typing import List, Tuple


def _load_json(path: str):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    # JSON text is UTF-8 by specification, so do not rely on the platform default.
    with open(path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def read_data(data_folder: str) -> Tuple[List, List, List]:
    """Read the ADE training data, test data and relation types.

    The ten splits ``ade_re_split_0`` .. ``ade_re_split_9`` are loaded and
    concatenated, in split order, into a single training list and a single
    test list.

    Args:
        data_folder: Directory containing ``ade_re_split_<i>_train.json`` and
            ``ade_re_split_<i>_test.json`` for ``i`` in 0..9, plus
            ``ade_types.json``.

    Returns:
        A ``(train_data, test_data, rel_types)`` tuple. ``rel_types`` is the
        parsed content of ``ade_types.json``.
        NOTE(review): the annotation declares it as a list, but
        ``convert_to_unified`` in this file indexes it as a mapping with a
        ``"relations"`` key -- confirm and tighten the annotation.

    Raises:
        FileNotFoundError: If any of the expected files is missing.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    num_splits = 10

    # Each split file must hold a JSON array so the splits can be concatenated.
    train_data = []
    test_data = []
    for split_index in range(num_splits):
        train_data += _load_json(
            os.path.join(data_folder, f"ade_re_split_{split_index}_train.json"))
        test_data += _load_json(
            os.path.join(data_folder, f"ade_re_split_{split_index}_test.json"))

    # Read the relation types from the local path
    rel_types = _load_json(os.path.join(data_folder, "ade_types.json"))

    return train_data, test_data, rel_types


def convert_to_unified(data, rel_types, split):
    """Convert annotated examples into the unified prompt/request format.

    Args:
        data: Iterable of example dicts. Each example is read through the keys
            ``"tokens"`` (list of token strings), ``"entities"`` (list of dicts
            with ``"start"``/``"end"`` token offsets, end exclusive),
            ``"relations"`` (list of dicts with ``"head"``/``"tail"`` indices
            into ``entities`` and a ``"type"`` string) and ``"orig_id"``.
        rel_types: Mapping whose ``"relations"`` entry is keyed by relation
            name; the names are listed in the prompt instructions.
        split: Split label stored verbatim on every instance.

    Returns:
        A dict with a ``"prompt"`` section (instructions, prefixes, suffixes)
        and a ``"request_states"`` list holding one entry per example. The
        reference text of an example is its triples, one per line, formatted
        as ``(head; type; tail)``; an example without relations gets an empty
        reference string.

    The input ``data`` is left unmodified.
    """
    # Define the prompt of the LLM
    instructions = (
        "Please extract the relational triplet in the form (subject; relationship; object) from the given text. "
        "If there is no triplet, please answer \"NA\". The set of relationships is as follows: "
        + str(list(rel_types["relations"])) + "."
    )

    # Construct the structure of the processed data
    request_states = []
    unified_data = {
        "prompt": {
            "instructions": instructions,
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "Answer: ",
            "output_suffix": "\n"
        },
        "request_states": request_states
    }

    # Convert the original dataset into the unified format
    for one_data in data:
        tokens = one_data["tokens"]

        # Surface text of each entity, index-aligned with the entity list.
        # Kept in a local list so the caller's entity dicts are not modified.
        entity_texts = [
            " ".join(tokens[entity["start"]:entity["end"]])
            for entity in one_data["entities"]
        ]

        # Render each relation as "(head; type; tail)"
        triples = [
            "(" + "; ".join([
                entity_texts[relation["head"]],
                relation["type"],
                entity_texts[relation["tail"]],
            ]) + ")"
            for relation in one_data["relations"]
        ]

        # Add the unified annotation to the processed data
        request_states.append({
            "instance": {
                "input": {
                    "text": " ".join(tokens)
                },
                "references": [{
                    "output": {
                        "text": "\n".join(triples)
                    }
                }],
                "split": split,
                "id": one_data["orig_id"]
            },
            "request": dict()
        })

    return unified_data


if __name__ == "__main__":
    # Read the ADE dataset from the local path
    train_data, test_data, rel_types = read_data("../../deepstruct/src/data/ade_re")

    # Convert the ADE dataset into the unified format
    train_unified = convert_to_unified(train_data, rel_types, "train")
    test_unified = convert_to_unified(test_data, rel_types, "test")

    # Make sure the output directory exists so the run does not fail only
    # after all the reading and conversion work has been done.
    os.makedirs("ADE", exist_ok=True)

    # Write the processed data back to the local path; the context manager
    # guarantees each file is flushed and closed.
    outputs = (
        ("ADE/ade.train.json", train_unified),
        ("ADE/ade.test.json", test_unified),
    )
    for output_path, unified in outputs:
        with open(output_path, "w", encoding="utf-8") as output_file:
            json.dump(unified, output_file, indent=4)
