import json
import pytreebank
import os
from pathlib import Path

# Lookup from an SST-5 label id (as a string, "0" through "4") to the
# sentiment name written into the table files. Despite the variable name,
# the keys are the ids and the values are the human-readable labels.
label2id = {
    str(index): name
    for index, name in enumerate(
        ("very negative", "negative", "neutral", "positive", "very positive")
    )
}

# Lookup from a split name to the code that convert_to_unified compares
# against the second comma-separated field of each data-split line.
split2id = dict(train="1", test="2", dev="3")


def _read_lines(path):
    """Returns all lines of the text file at ``path``, newlines included."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.readlines()


def convert_to_table(input_folder, output_folder, source_folder):
    """Converts the SST-5 dataset into table format.

    Every split returned by ``pytreebank.load_sst`` is written to
    ``<output_folder>/sst_<split>.txt`` as one tab-separated
    ``<sentiment name>, <text>`` row per item. The label and text come from
    the first entry of the item's ``to_labeled_lines()``; the numeric label
    is translated to its name through ``label2id``.

    Args:
        input_folder: Folder handed to ``pytreebank.load_sst``.
        output_folder: Folder that receives ``sst_train.txt``,
            ``sst_dev.txt`` and ``sst_test.txt``. Created if it is missing.
        source_folder: Folder that contains ``datasetSplit.txt``.

    Returns:
        A tuple ``(train_data, valid_data, test_data, data_split)``. The
        first three are the lines of the table files just written and the
        last is the lines of ``datasetSplit.txt``; every line keeps its
        trailing newline.
    """
    # Load the SST-5 dataset from the local path
    dataset = pytreebank.load_sst(input_folder)

    # Create the destination up front so a fresh checkout does not fail on
    # the first write. An empty path means "current directory" and needs
    # no creation (os.makedirs("") would raise).
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    # Convert and save the SST-5 dataset
    output_path = os.path.join(output_folder, "sst_{}.txt")
    for split in ["train", "dev", "test"]:
        # Explicit UTF-8 keeps the files independent of the platform's
        # default encoding; they are read back with the same encoding below.
        with open(output_path.format(split), "w", encoding="utf-8") as file:
            for item in dataset[split]:
                # Build the labeled lines once per item instead of once per
                # field; only the first entry is used.
                first_line = item.to_labeled_lines()[0]
                file.write("{}\t{}\n".format(
                    label2id[str(first_line[0])],
                    first_line[1]
                ))

    # Read the training, validation, and testing data back; each file is
    # closed as soon as its lines have been loaded.
    train_data = _read_lines(os.path.join(output_folder, "sst_train.txt"))
    valid_data = _read_lines(os.path.join(output_folder, "sst_dev.txt"))
    test_data = _read_lines(os.path.join(output_folder, "sst_test.txt"))
    # Read the data split of the dataset
    data_split = _read_lines(os.path.join(source_folder, "datasetSplit.txt"))

    return train_data, valid_data, test_data, data_split


def convert_to_unified(data, data_split, split):
    """Converts the SST-5 dataset into unified format.

    Args:
        data: Iterable of table lines of the form
            ``<sentiment>\\t<text>`` (a trailing newline is allowed).
        data_split: Iterable of comma-separated lines whose first field is a
            sentence id and whose second field is a split code. Lines with
            fewer than two fields are ignored.
        split: Name of the split being converted; looked up in ``split2id``
            to select the matching ``data_split`` lines.

    Returns:
        A dict with a ``"prompt"`` section and a ``"request_states"`` list
        holding one entry per line of ``data``. Each entry stores the text as
        the instance input, ``"(emotion; is; <sentiment>)"`` as the single
        reference output, the split name, and the sentence id taken, in
        order, from the matching ``data_split`` lines.

    Raises:
        ValueError: If a line of ``data`` contains no tab separator.
        KeyError: If a ``data_split`` line is examined and ``split`` is not
            a key of ``split2id``.
        AssertionError: If the number of matching sentence ids differs from
            the number of converted lines.
    """
    # Define the prompt of the LLM
    input_prefix = "Text: "
    output_prefix = "Answer: "

    # Construct the structure of the processed data
    unified_data = {
        "prompt": {
            "instructions": "Sentiment Classification: ",
            "input_prefix": input_prefix,
            "input_suffix": "\n",
            "output_prefix": output_prefix,
            "output_suffix": "\n"
        },
        "request_states": list()
    }

    # Convert the SST-5 dataset into the unified format
    for line in data:
        # Split on the first tab only, so a tab inside the text is kept as
        # part of the text instead of breaking the unpacking.
        sentiment, text = line.strip("\n").split("\t", 1)
        # Add the piece of data into the unified data
        unified_data["request_states"].append({
            "instance": {
                "input": {
                    "text": text
                },
                "references": [{
                    "output": {
                        "text": f"(emotion; is; {sentiment})"
                    }
                }],
                "split": split,
                # Placeholder, overwritten with the sentence id below
                "id": "NA"
            },
            "request": dict()
        })

    # Obtain the sentence ids of the data split
    sent_ids = []
    for line in data_split:
        fields = line.rstrip("\r\n").split(",")
        # Blank or otherwise comma-less lines carry no split code; skip them
        # rather than failing with an IndexError.
        if len(fields) > 1 and fields[1] == split2id[split]:
            sent_ids.append(fields[0])

    # Raised explicitly (not via ``assert``) so the check also runs under
    # ``python -O``; the exception type is unchanged.
    if len(sent_ids) != len(unified_data["request_states"]):
        raise AssertionError(
            "Found {} sentence ids for split '{}' but {} data lines".format(
                len(sent_ids), split, len(unified_data["request_states"])
            )
        )

    # Assign sentence ids for the unified data
    for one_data, sent_id in zip(unified_data["request_states"], sent_ids):
        one_data["instance"]["id"] = sent_id

    print(len(unified_data["request_states"]))
    return unified_data


if __name__ == "__main__":
    # Read and convert SST-5 into table format
    input_folder = "../../data/sst-5"
    train_data, valid_data, test_data, data_split = convert_to_table(
        os.path.join(input_folder, "trees"),
        os.path.join(input_folder, "sst-5"),
        os.path.join(input_folder, "stanfordSentimentTreebank")
    )

    # Convert the SST-5 dataset into unified format
    train_unified = convert_to_unified(train_data, data_split, "train")
    valid_unified = convert_to_unified(valid_data, data_split, "dev")
    test_unified = convert_to_unified(test_data, data_split, "test")

    # Save the processed data to the local path
    output_folder = Path("../../unified_data/SC/sst-5")
    output_folder.mkdir(parents=True, exist_ok=True)
    # Write each split inside a ``with`` block so the file is flushed and
    # closed deterministically instead of whenever the handle is collected.
    for file_name, unified in (
        ("train.json", train_unified),
        ("dev.json", valid_unified),
        ("test.json", test_unified),
    ):
        with open(os.path.join(output_folder, file_name), "w") as file:
            json.dump(unified, file, indent=4)
