import json
import os

from typing import List, Tuple


def read_data(data_folder: str) -> Tuple[List, List, List, dict]:
    """Read the CoNLL04 training, validation and test splits plus the relation schema.

    Args:
        data_folder: Directory containing ``conll04_re_train.json``,
            ``conll04_re_dev.json``, ``conll04_re_test.json`` and
            ``conll04_types.json``.

    Returns:
        A tuple ``(train_data, valid_data, test_data, rel_types)`` with the
        parsed JSON contents of those four files, in that order.
        ``rel_types`` is used as a mapping by this module
        (``rel_types["relations"]``).

    Raises:
        FileNotFoundError: If any of the four files is missing.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    def _load(file_name: str):
        # Decode explicitly as UTF-8 (the JSON standard encoding) rather than
        # the platform's locale default, which differs e.g. on Windows.
        with open(os.path.join(data_folder, file_name), "r", encoding="utf-8") as f:
            return json.load(f)

    train_data = _load("conll04_re_train.json")
    valid_data = _load("conll04_re_dev.json")
    test_data = _load("conll04_re_test.json")
    rel_types = _load("conll04_types.json")

    return train_data, valid_data, test_data, rel_types


def convert_to_unified(data, rel_types, split):
    """Convert CoNLL04 relation-extraction records into the unified prompt format.

    Args:
        data: Iterable of records. Each record provides ``"tokens"`` (list of
            str), ``"entities"`` (list of dicts with ``"start"``/``"end"`` token
            offsets, sliced as ``tokens[start:end]``), ``"relations"`` (list of
            dicts with ``"type"`` and ``"head"``/``"tail"`` indices into
            ``"entities"``) and ``"orig_id"``.
        rel_types: Mapping whose ``"relations"`` entry is keyed by relation
            name; those names are listed in the prompt instructions.
        split: Split label stored on every instance (e.g. ``"train"``).

    Returns:
        A dict with a shared ``"prompt"`` template and one ``"request_states"``
        entry per record. Each reference output lists the record's triples
        as ``(subject; relationship; object)``, one per line. A record with
        no relations gets ``"NA"``, matching what the instructions ask the
        model to answer.

    The input records are not modified.
    """
    # Define the prompt of the LLM
    instructions = (
        "Please extract the relational triplet in the form (subject; relationship; object) from the given text. "
        "If there is no triplet, please answer \"NA\". The set of relationships is as follows: "
        + str(list(rel_types["relations"].keys())) + "."
    )

    # Construct the structure of the processed data
    unified_data = {
        "prompt": {
            "instructions": instructions,
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "Answer: ",
            "output_suffix": "\n",
        },
        "request_states": [],
    }

    # Convert the original dataset into the unified format
    for one_data in data:
        tokens = one_data["tokens"]

        # Surface text of each annotated entity. This is kept in a local list
        # instead of being written back onto the entity dicts, so the caller's
        # data is left untouched.
        entity_texts = [
            " ".join(tokens[entity["start"]:entity["end"]])
            for entity in one_data["entities"]
        ]

        # Render each relation as "(subject; relationship; object)"
        triples = [
            "(" + "; ".join([entity_texts[relation["head"]],
                             relation["type"],
                             entity_texts[relation["tail"]]]) + ")"
            for relation in one_data["relations"]
        ]

        # The instructions tell the model to answer "NA" when there is no
        # triplet, so the gold reference must say the same.
        reference_text = "\n".join(triples) if triples else "NA"

        # Add the unified annotation to the processed data
        unified_data["request_states"].append({
            "instance": {
                "input": {
                    "text": " ".join(tokens)
                },
                "references": [{
                    "output": {
                        "text": reference_text
                    }
                }],
                "split": split,
                "id": str(one_data["orig_id"])
            },
            "request": {}
        })

    return unified_data


if __name__ == "__main__":
    # Read the CoNLL04 dataset from the local path (relative to the current
    # working directory)
    train_data, valid_data, test_data, rel_types = read_data("../../deepstruct/src/data/conll04_re")

    # Convert the CoNLL04 dataset into the unified format
    train_unified = convert_to_unified(train_data, rel_types, "train")
    valid_unified = convert_to_unified(valid_data, rel_types, "dev")
    test_unified = convert_to_unified(test_data, rel_types, "test")

    # Write the processed data back to the local path. Create the output
    # directory if needed so a fresh checkout does not crash. Write each file
    # inside a `with` block so it is flushed and closed deterministically
    # rather than whenever the file object happens to be garbage-collected.
    output_folder = "CoNLL04"
    os.makedirs(output_folder, exist_ok=True)
    for file_name, unified in (
        ("conll04.train.json", train_unified),
        ("conll04.valid.json", valid_unified),
        ("conll04.test.json", test_unified),
    ):
        with open(os.path.join(output_folder, file_name), "w", encoding="utf-8") as f:
            json.dump(unified, f, indent=4)
