import os
import json
import random
from pathlib import Path


def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*.

    Blank lines (for example a trailing empty line) are skipped instead of
    being handed to ``json.loads``, which would raise on them.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _reference_text(instance):
    """Render the triggers of *instance* as ``(word; is; type)`` items.

    Items are ordered by the first element of each trigger's ``position``
    (presumably the start offset -- confirm against the data) and joined
    with ``" | "``. An instance without triggers yields an empty string.
    """
    mentions = [
        (trigger["position"][0], trigger["trigger_word"], event["type"])
        for event in instance["events"]
        for trigger in event["triggers"]
    ]
    # Sort on the position only; the sort is stable, so triggers sharing a
    # position keep their order of appearance in the input.
    mentions.sort(key=lambda mention: mention[0])
    return " | ".join(f"({word}; is; {label})" for _, word, label in mentions)


def convert_format(input_path, label2id_path, output_folder, split, random_sample=True,
                   *, train_sample_size=8, eval_sample_size=1000):
    """Convert a JSONL event-detection file into the unified prompt format.

    Reads one JSON instance per line from ``input_path``, builds one request
    state per instance and writes the result to
    ``<output_folder>/<split>.json``. The number of request states written is
    printed to stdout.

    Args:
        input_path: Path of the JSONL file. Every instance must provide
            ``text`` and ``events``; every event a ``type`` and ``triggers``;
            every trigger a ``trigger_word`` and a ``position``.
        label2id_path: Path of a JSON object whose keys are the label names.
        output_folder: Existing directory (``str`` or ``Path``) that receives
            the output file.
        split: Split name. It is stored in every instance, selects the
            sampling rule and names the output file.
        random_sample: When true, ``"train"`` is sub-sampled before conversion
            and ``"dev"``/``"test"`` are sub-sampled after conversion.
        train_sample_size: Maximum number of instances kept for ``"train"``
            when sampling is enabled.
        eval_sample_size: Maximum number of request states kept for
            ``"dev"``/``"test"`` when sampling is enabled.

    Raises:
        OSError: If an input file cannot be read or the output file cannot be
            written.
        json.JSONDecodeError: If an input file contains invalid JSON.
        KeyError: If an instance lacks one of the required fields.
    """
    data = _read_jsonl(input_path)

    # Train: keep only a handful of instances. The sample size is capped at
    # the population size because random.sample raises ValueError otherwise.
    if random_sample and split == "train":
        data = random.sample(data, min(train_sample_size, len(data)))

    # Label names, without the 'NA' label when the map contains one.
    with open(label2id_path, encoding="utf-8") as f:
        label2id = json.load(f)
    label_names = [label for label in label2id if label != 'NA']

    # NOTE(review): this detailed instruction text is assembled but never
    # written out; the prompt below uses the short "Event Detection: "
    # literal. It is kept so the emitted prompt stays unchanged -- confirm
    # which of the two is intended.
    instructions = 'Please identify the events in the text ' + \
        'and classify them into appropriate categories; ' + \
        f'The collection of categories is [{", ".join(label_names)}]'
    input_prefix = "Text: "
    output_prefix = "Answer: "

    request_states = []
    for i, instance in enumerate(data):
        request_states.append({
            "instance": {
                "input": {
                    "text": instance['text']
                },
                "references": [
                    {
                        "output": {
                            "text": _reference_text(instance)
                        },
                    }
                ],
                "split": split,
                # Index within the (possibly sub-sampled) input list.
                "id": i
            },
            "request": {
                "result": {
                    # Empty completion slot.
                    "completions": [
                        {
                            "text": "",
                        }
                    ],
                },
                # Hard-coded constants, identical for every record.
                "request_time": 1.622053623199463,
                "request_datetime": 1669584580
            }
        })

    # Dev/test: cap the number of converted records.
    if random_sample and split in ["dev", "test"]:
        request_states = random.sample(
            request_states, k=min(eval_sample_size, len(request_states))
        )

    unified_data = {
        "prompt": {
            "instructions": "Event Detection: ",
            "input_prefix": input_prefix,
            "input_suffix": "\n",
            "output_prefix": output_prefix,
            "output_suffix": "\n",
        },
        "request_states": request_states,
    }

    # Dump JSON; the context manager guarantees the file is flushed and closed.
    print(len(unified_data["request_states"]))
    with open(Path(output_folder) / f"{split}.json", "w", encoding="utf-8") as f:
        json.dump(unified_data, f, indent=4)



if __name__ == "__main__":
    # Convert all three ace2005-dygie splits with random sampling disabled.
    data_dir = "../../data/ace2005-dygie"
    label2id_path = f"{data_dir}/label2id.json"

    output_folder = Path("../..//unified_data/ED/ace2005-dygie")
    output_folder.mkdir(exist_ok=True, parents=True)

    # (input file stem, split name used for the output file)
    split_files = (("train", "train"), ("valid", "dev"), ("test", "test"))
    for file_stem, split_name in split_files:
        convert_format(f"{data_dir}/{file_stem}.unified.jsonl", label2id_path, output_folder, split_name, False)

