import json
import os

from typing import List, Tuple


def read_data(data_folder: str) -> Tuple[List, List, List, List]:
    """Read the train/dev/test splits and the relation schemas from ``data_folder``.

    Args:
        data_folder: Directory containing ``train.json``, ``dev.json``,
            ``test.json`` and ``schemas.json``.

    Returns:
        A tuple ``(train_data, valid_data, test_data, schemas)`` holding the
        parsed JSON content of each file, in that order.

    Raises:
        FileNotFoundError: If any of the four files is missing.
        json.JSONDecodeError: If any file does not contain valid JSON.
    """
    def _load(file_name: str):
        # Decode explicitly as UTF-8 so non-ASCII text is read identically on
        # every platform instead of depending on the locale's default encoding.
        with open(os.path.join(data_folder, file_name), "r", encoding="utf-8") as f:
            return json.load(f)

    train_data = _load("train.json")
    valid_data = _load("dev.json")
    test_data = _load("test.json")
    schemas = _load("schemas.json")

    return train_data, valid_data, test_data, schemas


def convert_to_unified(data, schemas, split):
    """Convert relation-extraction records into the unified prompt format.

    Args:
        data: Iterable of records. Each record's ``"tokens"`` list is joined
            with spaces to form the input text, and each entry of its
            ``"spo_list"`` is joined with ``"; "`` and wrapped in parentheses
            to form one answer line. Presumably each entry is a
            [subject, relation, object] triple of strings — TODO confirm
            against the dataset.
        schemas: Parsed ``schemas.json``; the keys of ``schemas[1]`` are listed
            in the instructions as the allowed set of relationships.
        split: Split label stored on every instance (e.g. ``"train"``).

    Returns:
        A dict with a ``"prompt"`` template and a ``"request_states"`` list
        holding one entry per record, with string ids ``"0"``, ``"1"``, ... in
        input order. Records without any relation get the reference answer
        ``"NA"``, matching what the instructions ask the model to output.
    """
    # Shared by the instructions and the references so the two cannot drift.
    no_triplet_answer = "NA"

    # Define the prompt of the LLM
    instructions = (
        "Please extract the relational triplet in the form (subject; relationship; object) from the given text. "
        "If there is no triplet, please answer \"" + no_triplet_answer + "\". "
        "The set of relationships is as follows: "
        + str(list(schemas[1].keys())) + "."
    )

    # Convert the original dataset into the unified format
    request_states = []
    for unified_id, one_data in enumerate(data):
        triples = ["(" + "; ".join(relation) + ")" for relation in one_data["spo_list"]]
        # An empty reference would contradict the instructions, which tell the
        # model to answer "NA" when there is nothing to extract.
        answer = "\n".join(triples) if triples else no_triplet_answer

        request_states.append({
            "instance": {
                "input": {
                    "text": " ".join(one_data["tokens"])
                },
                "references": [{
                    "output": {
                        "text": answer
                    }
                }],
                "split": split,
                "id": str(unified_id)
            },
            "request": dict()
        })

    return {
        "prompt": {
            "instructions": instructions,
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "Answer: ",
            "output_suffix": "\n"
        },
        "request_states": request_states
    }


if __name__ == "__main__":
    # Read the NYT dataset from the local path
    train_data, valid_data, test_data, schemas = read_data("../../deepstruct/src/data/nyt_re")

    # Convert each split into the unified format, each from its own data.
    train_unified = convert_to_unified(train_data, schemas, "train")
    valid_unified = convert_to_unified(valid_data, schemas, "dev")
    test_unified = convert_to_unified(test_data, schemas, "test")

    # Write the processed data back to the local path. The output directory
    # is created if needed, and each file is closed deterministically.
    output_dir = "NYT"
    os.makedirs(output_dir, exist_ok=True)
    outputs = (
        ("nyt.train.json", train_unified),
        ("nyt.valid.json", valid_unified),
        ("nyt.test.json", test_unified),
    )
    for file_name, unified in outputs:
        with open(os.path.join(output_dir, file_name), "w", encoding="utf-8") as f:
            json.dump(unified, f, indent=4)
