import os
import json
import random
from pathlib import Path


def random_sample_data(data, k=8):
    """Draw a shuffled sample of ``k`` instances from ``data``.

    The sample always contains exactly 2 instances whose ``"relation"`` is
    ``"no_relation"`` and ``k - 2`` instances with any other relation.

    Args:
        data: Collection of dict-like instances with a ``"relation"`` key.
        k: Total number of instances to return.

    Returns:
        A new list of ``k`` instances in random order.

    Raises:
        ValueError: Propagated from ``random.sample`` when either pool holds
            fewer instances than requested (or ``k - 2`` is negative).
    """

    def is_negative(item):
        return item["relation"] == "no_relation"

    # Positive pool first, then the negative pool, so the random stream is
    # consumed in the same order on every run with a fixed seed.
    positives = [item for item in data if not is_negative(item)]
    chosen = random.sample(positives, k - 2)

    negatives = list(filter(is_negative, data))
    chosen = chosen + random.sample(negatives, 2)

    random.shuffle(chosen)
    return chosen


def sample_uniform_relaton():
    """Unimplemented placeholder: takes no arguments and returns ``None``.

    NOTE(review): the name suggests a relation-uniform sampler was planned;
    nothing is implemented here yet -- confirm intent before relying on it.
    """
    return None


def mark_entity(tokens, h_pos, t_pos, h_type, t_type):
    """Render a token list as text with the two entity spans replaced by typed tags.

    Tokens inside the head span ``[h_pos[0], h_pos[1])`` are dropped and
    replaced by ``<head> {h_type.lower()} </head>``; tokens inside the tail
    span ``[t_pos[0], t_pos[1])`` are replaced by
    ``<tail> {t_type.lower()} </tail>``. All other tokens are kept, with
    ``-LRB-`` / ``-RRB-`` rewritten to ``(`` / ``)``.

    Args:
        tokens: Sequence of string tokens.
        h_pos: ``(start, end)`` token indices of the head entity, end exclusive.
        t_pos: ``(start, end)`` token indices of the tail entity, end exclusive.
        h_type: Head entity type; emitted lower-cased.
        t_type: Tail entity type; emitted lower-cased.

    Returns:
        The space-joined text with surrounding whitespace stripped.
    """
    brackets = {"-LRB-": "(", "-RRB-": ")"}
    head_span = range(h_pos[0], h_pos[1])
    tail_span = range(t_pos[0], t_pos[1])

    pieces = []
    for idx, tok in enumerate(tokens):
        # Opening tags (followed by the entity type) go before the span.
        if idx == h_pos[0]:
            pieces += ["<head>", h_type.lower()]
        if idx == t_pos[0]:
            pieces += ["<tail>", t_type.lower()]

        # Only tokens outside both entity spans are emitted verbatim.
        if not (idx in head_span or idx in tail_span):
            pieces.append(brackets.get(tok, tok))

        # Closing tags go right after the last token of each span.
        if idx == h_pos[1] - 1:
            pieces.append("</head>")
        if idx == t_pos[1] - 1:
            pieces.append("</tail>")

    return " ".join(pieces).strip()


def _build_request_state(instance, split, index):
    """Build one unified ``request_states`` entry for a single raw instance."""
    text = mark_entity(
        instance["token"],
        instance["h"]["pos"],
        instance["t"]["pos"],
        instance["h"]["type"],
        instance["t"]["type"],
    )
    relation = instance["relation"]
    # Instances without a relation get an empty reference output.
    if relation != "no_relation":
        output = "({}; {}; {})".format("<head>", relation, "<tail>")
    else:
        output = ""

    return {
        "instance": {
            "input": {"text": text},
            "references": [{"output": {"text": output}}],
            "split": split,
            "id": index,
        },
        "request": {
            # The completion is left empty; the timing fields are hard-coded
            # constants, not values measured by this script.
            "result": {"completions": [{"text": ""}]},
            "request_time": 1.622053623199463,
            "request_datetime": 1669584580,
        },
    }


def convert_format(input_path, rel2id_path, output_folder, split, random_sample=True):
    """Convert a JSON-lines relation dataset into the unified prompt format.

    Reads one JSON object per line from ``input_path``, renders each instance
    with ``mark_entity`` and writes the result to
    ``<output_folder>/<split>.json``.

    Args:
        input_path: Path to the JSON-lines input. Each instance must provide
            ``"token"``, ``"relation"`` and ``"h"`` / ``"t"`` dicts with
            ``"pos"`` and ``"type"``. Blank lines are skipped.
        rel2id_path: Path to a JSON file mapping relation names to ids. It is
            parsed so that a missing or malformed file still fails, but its
            content is not written to the output.
        output_folder: Existing directory (``str`` or ``Path``) that receives
            ``{split}.json``.
        split: Split name; also stored in every instance and used as the
            output file stem.
        random_sample: When true, the ``"train"`` split is reduced with
            ``random_sample_data`` and the ``"dev"`` / ``"test"`` splits are
            reduced to at most 1000 randomly chosen request states.

    Raises:
        OSError: If an input file cannot be opened or the output cannot be
            written.
        json.JSONDecodeError: If a non-blank input line or the rel2id file is
            not valid JSON.
    """
    # Iterate the file lazily and skip blank lines (e.g. a trailing newline),
    # which would otherwise make json.loads fail.
    with open(input_path, encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    # train
    if random_sample and split == "train":
        data = random_sample_data(data)

    # rel2id: parsed only as a sanity check on the label file; the prompt
    # below uses a fixed instruction string and does not list the relations.
    with open(rel2id_path, encoding="utf-8") as f:
        json.load(f)

    # convert data
    unified_data = {
        "prompt": {
            "instructions": "Relation Classification: ",
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        "request_states": [
            _build_request_state(instance, split, i)
            for i, instance in enumerate(data)
        ],
    }

    # random sample 1k samples; clamp to the population size so that splits
    # with fewer than 1000 instances do not make random.sample raise.
    if random_sample and split in ("dev", "test"):
        states = unified_data["request_states"]
        unified_data["request_states"] = random.sample(states, k=min(1000, len(states)))

    # dump json; the context manager guarantees the file is flushed and closed.
    output_path = os.path.join(output_folder, f"{split}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(unified_data, f, indent=4)



if __name__ == "__main__":
    # Convert the three TACRED splits without any random sub-sampling.
    rel2id_file = "../../../data/tacred/rel2id.json"
    output_folder = Path("../../../unified_data/RC/tacred")
    output_folder.mkdir(exist_ok=True, parents=True)

    split_files = (
        ("train", "../../../data/tacred/train.txt"),
        ("dev", "../../../data/tacred/dev.txt"),
        ("test", "../../../data/tacred/test.txt"),
    )
    for split_name, split_file in split_files:
        convert_format(split_file, rel2id_file, output_folder, split_name, False)

