import os
import json
import random
from pathlib import Path
import copy





def get_rel(rel):
    """Build an instruction prompt that enumerates the allowed relation names.

    Args:
        rel: mapping whose *values* are relation names; keys are ignored.
            Each value is appended (in iteration order) followed by a comma,
            inside a ``{...}`` pair.

    Returns:
        The instruction text, ending with the brace-delimited relation list
        and a newline.
    """
    instruction = (
        "Please follow the above demonstration, extract relations from the [Question text].\n"
        "Note the relation need to be in the predefined set of relations.\n"
        "The output format required to is the same as the demonstration, format:(<entity_ID>, relation, <entity_ID>).\n"
    )
    # Every name keeps its trailing comma (including the last one), as in the original prompt.
    relation_names = "".join(name + "," for name in rel.values())
    return instruction + "The predefined set of relations:\n{" + relation_names + "}\n"


def get_context_all(ALL_data, num):
    """Return document ``num`` as a single string with entity markers inserted.

    For every mention of entity ``i`` (``vertexSet[i]``), a token
    ``entity<i>`` is inserted immediately before the token at index
    ``mention['pos'][0]`` of sentence ``mention['sent_id']``. When several
    markers target the same token, the marker assigned later (higher entity
    index, or later mention of the same entity) ends up first. All tokens of
    all sentences are then emitted, each preceded by a single space (so a
    non-empty result starts with a space).

    The input document is not modified.

    Args:
        ALL_data: indexable collection of documents; each document is read
            for ``'sents'`` (an iterable of token sequences) and
            ``'vertexSet'`` (list of entities, each a list of mention dicts
            with ``'sent_id'`` and ``'pos'``; only ``pos[0]`` is used).
        num: index of the document in ``ALL_data``.

    Returns:
        The marked-up document text.
    """
    doc = ALL_data[num]

    # Collect (insert_index, marker) per sentence, in the order markers are assigned.
    inserts = {}
    for ent_id, mentions in enumerate(doc['vertexSet']):
        marker = 'entity' + str(ent_id)
        for mention in mentions:
            inserts.setdefault(mention['sent_id'], []).append((mention['pos'][0], marker))

    # Work on copies so the caller's token lists are untouched.
    sents = [list(sent) for sent in doc['sents']]
    for sent_id, items in inserts.items():
        tokens = sents[sent_id]
        # Insert right-to-left so indices of pending insertions stay valid.
        # sorted() is stable even with reverse=True, so for equal indices the
        # later-assigned marker is inserted later at the same spot and ends up
        # in front, which is the order the old shift-everything loop produced.
        for start, marker in sorted(items, key=lambda item: item[0], reverse=True):
            tokens.insert(start, marker)

    # Each token is prefixed by one space, reproducing the previous
    # " ".join([acc, word]) accumulation without its quadratic cost.
    return "".join(" " + word for sent in sents for word in sent)

def get_Demonstration(ALL_data, num, rel):
    """Return the gold relation triples of document ``num`` as one target string.

    Labels are ordered by (head entity, tail entity); labels with the same
    pair keep their original relative order. Each label becomes
    ``(entity<h>; <relation name>; entity<t>)`` and the triples are joined
    with ``" | "``. A document with no labels yields ``""``.

    Args:
        ALL_data: indexable collection of documents; only ``['labels']`` of
            the selected one is read (dicts with ``'h'``, ``'t'``, ``'r'``).
        num: index of the document in ``ALL_data``.
        rel: mapping from a label's ``'r'`` value to its relation name.

    Returns:
        The ``" | "``-separated triple string.
    """
    # Read-only access: no need to deep-copy the whole document.
    labels = sorted(ALL_data[num]["labels"], key=lambda label: (label["h"], label["t"]))
    return " | ".join(
        f"(entity{label['h']}; {rel[label['r']]}; entity{label['t']})" for label in labels
    )


def _load_json(path):
    """Read and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as reader:
        return json.load(reader)


def _make_request_state(index, text, output, split):
    """Wrap one converted document in the unified request-state record.

    The ``request`` part is a fixed placeholder: an empty completion and
    constant timing values.
    """
    return {
        "instance": {
            "input": {
                "text": text
            },
            "references": [
                {
                    "output": {
                        "text": output
                    },
                }
            ],
            "split": split,
            "id": index
        },
        "request": {
            "result": {
                "completions": [
                    {
                        "text": "",
                    }
                ],
            },
            "request_time": 1.622053623199463,
            "request_datetime": 1669584580
        }
    }


def convert_format(input_path, rel2id_path, output_folder, split, sample_data=True,
                   *, train_sample_size=4, eval_sample_size=200, seed=None):
    """Convert one DocRED-style split into the unified JSON format and write it.

    Each document becomes one request state. Its input text comes from
    ``get_context_all`` and its reference output from ``get_Demonstration``.
    The result is written to ``<output_folder>/<split>.json``, and the number
    of written request states is printed.

    Args:
        input_path: path to a JSON list of documents.
        rel2id_path: path to a JSON object mapping each label's ``'r'`` value
            to a relation name (passed to ``get_Demonstration`` as ``rel``).
        output_folder: directory (str or Path) that receives the output file.
        split: split name; also used as the output file stem.
        sample_data: if true, randomly subsample the "train" split to
            ``train_sample_size`` items and "dev"/"test" to
            ``eval_sample_size`` items. Other split names are never sampled.
        train_sample_size: sample size for the "train" split.
        eval_sample_size: sample size for the "dev" and "test" splits.
        seed: optional seed for a private RNG. With ``None`` the module-level
            ``random`` state is used.

    Sample sizes larger than the split are clamped to the split size instead
    of raising ``ValueError``.
    """
    data = _load_json(input_path)
    rel = _load_json(rel2id_path)

    unified_data = {
        "prompt": {
            # get_rel(rel) can build a longer instruction that lists every relation name.
            "instructions": "Document-Level Relation Extraction: ",
            "input_prefix": "",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        "request_states": [
            _make_request_state(i, get_context_all(data, i), get_Demonstration(data, i, rel), split)
            for i in range(len(data))
        ],
    }

    if sample_data:
        if split == "train":
            sample_size = train_sample_size
        elif split in ("dev", "test"):
            sample_size = eval_sample_size
        else:
            sample_size = None
        if sample_size is not None:
            states = unified_data["request_states"]
            rng = random if seed is None else random.Random(seed)
            unified_data["request_states"] = rng.sample(states, k=min(sample_size, len(states)))

    print(len(unified_data["request_states"]))
    # Context manager guarantees the output is flushed and the handle closed.
    with open(os.path.join(output_folder, f"{split}.json"), "w", encoding="utf-8") as writer:
        json.dump(unified_data, writer, indent=4)


if __name__ == "__main__":
    # All paths are relative to the current working directory.
    data_dir = "../../../data/docred"
    output_folder = Path("../../../unified_data/RC/docred")
    output_folder.mkdir(parents=True, exist_ok=True)

    # (source file, split name) pairs, converted in this order without sampling.
    for file_name, split_name in (
        ("train_annotated.json", "train"),
        ("dev.json", "dev"),
        ("dev_test.json", "test"),
    ):
        convert_format(
            f"{data_dir}/{file_name}",
            f"{data_dir}/rel_info.json",
            output_folder,
            split_name,
            False,
        )

