import os
import jsonlines
import random
import json
from pathlib import Path


def _read_jsonl(path):
    """Return the list of JSON objects stored one per line in *path*."""
    with open(path, "r", encoding="utf-8") as f:
        # Blank lines carry no record, so skip them instead of failing.
        return [json.loads(line) for line in f if line.strip()]


def _events_by_sentence(document):
    """Group a document's event mentions by sentence, ordered by offset.

    Returns one list per entry of ``document['content']``; every item is
    ``[trigger_word, event_type, offset]`` and each list is sorted by the
    first element of ``offset``.
    """
    per_sentence = [[] for _ in document['content']]
    for event in document['events']:
        for mention in event['mention']:
            per_sentence[mention['sent_id']].append(
                [mention['trigger_word'], event['type'], mention['offset']]
            )
    for mentions in per_sentence:
        mentions.sort(key=lambda item: item[2][0])
    return per_sentence


def convert_format(input_path, output_folder, split, random_sample=True):
    """Convert a JSON-lines event file into the unified prompt/request format.

    Each input line is a document with a ``content`` list (items holding a
    ``sentence`` string) and an ``events`` list (items holding a ``type`` and
    a ``mention`` list with ``trigger_word``, ``sent_id`` and ``offset``).
    One request state is produced per sentence; its reference output lists
    the sentence's triggers as ``(trigger; is; type)`` joined by ``" | "``.
    The result is written to ``<output_folder>/<split>.json`` and the number
    of request states is printed.

    Args:
        input_path: Path of the JSON-lines file to read.
        output_folder: Existing directory (``str`` or ``Path``) for the output.
        split: Split name; stored on every instance and used as file name.
        random_sample: When true, down-sample the data: for ``"train"`` at
            most 8 documents and then at most 8 sentences are kept; for any
            other split at most 1000 sentences are kept.
    """
    data = _read_jsonl(input_path)

    # Few-shot train set: keep at most 8 documents (clamped so that small
    # inputs do not make random.sample raise ValueError).
    if random_sample and split == "train":
        data = random.sample(data, min(8, len(data)))

    unified_data = {
        "prompt": {
            "instructions": "Event Detection: ",
            "input_prefix": "Text: ",
            "input_suffix": "\n",
            "output_prefix": "Answer: ",
            "output_suffix": "\n",
        },
        "request_states": [],
    }

    # Sentence ids run consecutively across all documents.
    next_id = 0
    for instance in data:
        grouped = _events_by_sentence(instance)
        for sent, sent_events in zip(instance['content'], grouped):
            answer = ' | '.join(f"({e[0]}; is; {e[1]})" for e in sent_events)
            unified_data["request_states"].append({
                "instance": {
                    "input": {
                        "text": sent['sentence']
                    },
                    "references": [
                        {
                            "output": {
                                "text": answer
                            },
                        }
                    ],
                    "split": split,
                    "id": next_id
                },
                "request": {
                    "result": {
                        "completions": [
                            {
                                "text": "",
                            }
                        ],
                    },
                    # Fixed values copied into every request state.
                    "request_time": 1.622053623199463,
                    "request_datetime": 1669584580
                }
            })
            next_id += 1

    if random_sample:
        # 8 sentences for the train split, 1k for every other split.
        num_sample = 8 if split == "train" else 1000
        states = unified_data["request_states"]
        unified_data["request_states"] = random.sample(
            states, k=min(num_sample, len(states))
        )

    # dump json
    print(len(unified_data["request_states"]))
    output_path = Path(output_folder).joinpath(f"{split}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(unified_data, f, indent=4)



if __name__ == "__main__":
    # Destination directory for the converted files; created on demand.
    output_folder = Path("../../unified_data/ED/maven")
    output_folder.mkdir(parents=True, exist_ok=True)

    # (source file, split name) pairs; every split is converted in full,
    # without random down-sampling.
    conversions = (
        ("../../data/maven/train.jsonl", "train"),
        ("../../data/maven/test_gold.jsonl", "test"),
        ("../../data/maven/valid.jsonl", "dev"),
    )
    for source_path, split_name in conversions:
        convert_format(source_path, output_folder, split_name, random_sample=False)

