import os
import json
import random
from pathlib import Path


def random_sample(data, k=8):
    """Draw ``k`` distinct instances from ``data`` and return them shuffled.

    Args:
        data: Iterable of instances. It is copied into a new list first, so
            the caller's container is never reordered.
        k: Number of instances to draw.

    Returns:
        A new list holding the ``k`` sampled instances in random order.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``k`` is negative
            or larger than the number of instances.
    """
    pool = list(data)
    chosen = random.sample(pool, k)
    # Shuffle the sampled list in place before handing it back.
    random.shuffle(chosen)
    return chosen


def sample_uniform_relaton():
    """Unimplemented placeholder: does nothing and returns ``None``.

    NOTE(review): the name is presumably a misspelling of "relation"; it is
    left unchanged so any existing callers keep working.
    """
    return None


def mark_entity(tokens, h_pos, t_pos):
    """Render ``tokens`` as one string with head/tail mentions wrapped in tags.

    For every token index, ``<head>`` / ``<tail>`` is emitted before the token
    when some span starts there, and ``</head>`` / ``</tail>`` after it when
    some span ends there. Head tags always precede tail tags at the same
    position.

    Args:
        tokens: Sequence of token strings.
        h_pos: Iterable of head mention spans; each span is an index sequence
            whose first element is the start token index and whose last
            element is the end token index.
        t_pos: Iterable of tail mention spans, same layout as ``h_pos``.

    Returns:
        A ``(text, flag)`` tuple. ``text`` is the space-separated, stripped
        sentence including the tags. ``flag`` is ``0`` or ``1``, obtained
        from ``random.random() >= 0.5``.
    """
    coin = int(random.random() >= 0.5)

    pieces = []
    for idx, tok in enumerate(tokens):
        # Opening tags go in front of the token.
        if any(idx == span[0] for span in h_pos):
            pieces.append("<head>")
        if any(idx == span[0] for span in t_pos):
            pieces.append("<tail>")

        pieces.append(tok)

        # Closing tags follow the token that ends a span.
        if any(idx == span[-1] for span in h_pos):
            pieces.append("</head>")
        if any(idx == span[-1] for span in t_pos):
            pieces.append("</tail>")

    return " ".join(pieces).strip(), coin

def convert_usual_format(data, rel2id, flag):
    """Flatten a relation-keyed mapping into a single list of labelled instances.

    Args:
        data: Mapping from relation key to a list of instance dicts.
        rel2id: Mapping from relation key to a sequence; its first element is
            used as the relation label when ``flag == 0``.
        flag: When equal to ``0`` (``False`` included) the label is taken from
            ``rel2id``; otherwise the relation key itself is the label.

    Returns:
        A list of the same instance dicts, in iteration order of ``data``.
        Each dict is modified in place to carry a ``"relation"`` entry.
    """
    flattened = []
    for rel_key, instances in data.items():
        for inst in instances:
            # Resolved per instance, so a key with no instances is never
            # looked up in rel2id.
            inst["relation"] = rel2id[rel_key][0] if flag == 0 else rel_key
            flattened.append(inst)
    return flattened


def convert_format(input_path, rel2id_path, output_folder, split, sample_data=True):
    """Convert one split of relation-keyed data into the unified JSON format.

    The input file maps each relation key to a list of instances. Instances
    are flattened, their head/tail mentions are tagged inline, and the result
    is written to ``<output_folder>/<split>.json`` as a dict with a ``prompt``
    section and a ``request_states`` list.

    Args:
        input_path: Path to the UTF-8 JSON file holding the split's data.
        rel2id_path: Path to the JSON file mapping a relation key to a
            sequence whose first element is the relation label.
        output_folder: Existing directory that receives ``<split>.json``.
        split: Split name. ``"test_pubmed"`` and ``"val_pubmed"`` use the
            relation key itself as label; every other split resolves the
            label through ``rel2id``. ``"train"``, ``"test_wiki"`` and
            ``"test_pubmed"`` are additionally subject to sampling.
        sample_data: When true, the ``"train"`` split is reduced through
            ``random_sample`` and the test splits are reduced to at most
            1000 request states.

    Returns:
        None. The converted data is written to disk and the number of request
        states is printed.

    Notes:
        Each instance must provide ``"tokens"``, ``"h"`` and ``"t"``; index 2
        of ``"h"`` and ``"t"`` is passed to ``mark_entity`` as the span list.
        The relation inventory collected for the test splits is only printed;
        the emitted prompt always uses the fixed instruction string below.
    """
    test_sample_size = 1000

    with open(input_path, "r", encoding="utf-8") as reader:
        raw = json.load(reader)

    # Read with an explicit encoding and close the handle deterministically.
    with open(rel2id_path, "r", encoding="utf-8") as reader:
        rel2id = json.load(reader)

    # Relation inventory, printed for inspection on the test splits only.
    if split == "test_wiki":
        relations = [rel2id[key][0] for key in raw]
        print(relations)
    if split == "test_pubmed":
        relations = list(raw)
        print(relations)

    # The pubmed splits are labelled with the raw relation key instead of
    # going through rel2id.
    use_raw_keys = split in ("test_pubmed", "val_pubmed")
    data = convert_usual_format(raw, rel2id, use_raw_keys)

    if sample_data and split == "train":
        data = random_sample(data)

    unified_data = {
        "prompt": {
            "instructions": "Relation Classification: ",
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        "request_states": [],
    }

    for i, instance in enumerate(data):
        # mark_entity also returns a random 0/1 flag, which is not used here.
        text, _ = mark_entity(instance["tokens"], instance["h"][2], instance["t"][2])
        relation = instance["relation"]
        # Instances labelled "no_relation" get an empty reference output.
        if relation != "no_relation":
            output = f"(<head>; {relation}; <tail>)"
        else:
            output = ""

        unified_data["request_states"].append({
            "instance": {
                "input": {
                    "text": text
                },
                "references": [
                    {
                        "output": {
                            "text": output
                        },
                    }
                ],
                "split": split,
                "id": i
            },
            # No request is issued here: the completion is empty and the
            # timing fields are fixed constants.
            "request": {
                "result": {
                    "completions": [
                        {
                            "text": "",
                        }
                    ],
                },
                "request_time": 1.622053623199463,
                "request_datetime": 1669584580
            }
        })

    # Down-sample the test splits; a split that is already small enough is
    # kept whole instead of making random.sample raise ValueError.
    if sample_data and split in ("test_pubmed", "test_wiki"):
        states = unified_data["request_states"]
        if len(states) > test_sample_size:
            unified_data["request_states"] = random.sample(states, k=test_sample_size)

    print(len(unified_data["request_states"]))
    output_path = os.path.join(output_folder, f"{split}.json")
    # The context manager guarantees the file is flushed and closed.
    with open(output_path, "w", encoding="utf-8") as writer:
        json.dump(unified_data, writer, indent=4)


if __name__ == "__main__":
    # Script entry point: convert every split into the unified RC folder.
    data_dir = "../../../data/fewrel"
    pid2name_path = f"{data_dir}/pid2name.json"

    output_folder = Path("../../../unified_data/RC/fewrel")
    output_folder.mkdir(parents=True, exist_ok=True)

    # (input file stem, split name), listed in processing order.
    jobs = [
        ("train_wiki", "train"),
        # fewrel1.0
        ("val_wiki", "val_wiki"),
        ("test_wiki", "test_wiki"),
        # fewrel2.0
        ("val_pubmed", "val_pubmed"),
        ("test_pubmed", "test_pubmed"),
    ]
    for stem, split_name in jobs:
        # Sampling is switched off, so each split is converted in full.
        convert_format(f"{data_dir}/{stem}.json", pid2name_path, output_folder, split_name, False)
