import os
import json
import random
from pathlib import Path

def find_schema(data):
    """Collect, per event type, every argument role seen in *data*.

    Args:
        data: Iterable of samples. Each sample has an ``'events'`` list; each
            event has a ``'type'`` and a ``'triggers'`` list; each trigger has
            an ``'arguments'`` list whose items carry a ``'role'``.

    Returns:
        A dict mapping each event type to the set of role names observed for
        it. An event type with no arguments still gets an (empty) entry.
    """
    roles_by_type = {}
    for record in data:
        for ev in record['events']:
            # setdefault registers the type even when it has no arguments.
            known_roles = roles_by_type.setdefault(ev['type'], set())
            known_roles.update(
                argument['role']
                for trig in ev['triggers']
                for argument in trig['arguments']
            )
    return roles_by_type
                
def mark_event(input_text, h_pos, t_pos, markers=("<event>", "</event>")):
    """Return *input_text* with an event span wrapped in marker tags.

    The opening marker (followed by a space) is inserted before the character
    at index ``h_pos``; the closing marker (preceded by a space) is inserted
    before the character at index ``t_pos``, or appended at the very end when
    ``t_pos`` equals ``len(input_text)``.

    Args:
        input_text: The text to annotate.
        h_pos: Character index where the span starts.
        t_pos: Character index just past the span end (exclusive).
        markers: Two-item sequence ``(opening, closing)``. The default is a
            tuple rather than a list so it cannot be mutated across calls.

    Returns:
        The annotated text. Positions that match no index (and, for ``t_pos``,
        not the text length either) insert nothing.
    """
    # Collect pieces and join once instead of repeated string concatenation.
    pieces = []
    for idx, ch in enumerate(input_text):
        if idx == h_pos:
            pieces.append(f"{markers[0]} ")
        if idx == t_pos:
            pieces.append(f" {markers[1]}")
        pieces.append(ch)
    # A span that runs to the end of the text has no character at t_pos.
    if t_pos == len(input_text):
        pieces.append(f" {markers[1]}")
    return "".join(pieces)

def format_sample(sample, schema):
    """Build one (input, output) text pair per event trigger in *sample*.

    Args:
        sample: Dict with a ``'text'`` string and an ``'events'`` list. Each
            event has a ``'type'`` and ``'triggers'``; each trigger has a
            ``'position'`` pair and ``'arguments'``; each argument has a
            ``'role'`` and ``'mentions'``; each mention has ``'mention'`` text
            and a ``'position'`` whose first item is used for ordering.
        schema: Mapping from event type to an iterable of role names, as
            produced by ``find_schema``.

    Returns:
        A tuple ``(inputs, outputs)`` of equal-length lists. Each input is the
        sample text with the trigger marked, followed by a ``Roleset`` line.
        Each output is the ``' | '``-joined list of
        ``(<event>; role; mention)`` triples ordered by mention start
        position (an empty string when the trigger has no argument mentions).

    Raises:
        KeyError: If an event with at least one trigger has a type that is
            missing from *schema*.
    """
    inputs, outputs = [], []
    for event in sample['events']:
        for trigger in event['triggers']:
            start, end = trigger['position'][0], trigger['position'][1]
            text = mark_event(sample['text'], start, end)
            # Role sets have no stable iteration order across interpreter runs
            # (string hashing is randomized), so sort them to keep the
            # generated prompt reproducible.
            roleset = ', '.join(sorted(schema[event['type']]))
            text += f"\nRoleset: [{roleset}]"
            inputs.append(text)

            roles = [
                (mention['mention'], arg['role'], mention["position"])
                for arg in trigger['arguments']
                for mention in arg['mentions']
            ]
            # Stable sort by mention start offset so answers follow text order.
            roles.sort(key=lambda item: item[2][0])
            outputs.append(
                ' | '.join(f"(<event>; {role[1]}; {role[0]})" for role in roles)
            )
    return inputs, outputs

def convert_format(data, schema, output_folder, split, random_sample=True):
    """Convert samples to the unified request-state layout and write it to disk.

    Every (input, output) pair returned by ``format_sample`` becomes one entry
    in ``request_states`` with a running integer id. The result is written to
    ``<output_folder>/<split>.json`` and the number of entries is printed.

    Args:
        data: Iterable of samples accepted by ``format_sample``.
        schema: Event-type -> roles mapping passed through to ``format_sample``.
        output_folder: ``pathlib.Path`` of an existing directory to write into.
        split: Split name; stored on every instance and used as the file stem.
        random_sample: When true, keep a random subset of at most 8 entries
            for the ``"train"`` split and at most 1000 for any other split.
    """
    # NOTE(review): a long-form instruction text was previously assembled here
    # but never written to the output; only the short header below is emitted.
    # The unused text was:
    #   Please extract event arguments and their roles for the events marked
    #   with <event> and </event> in the text, the possible roles must be
    #   chosen from the Roleset. If there is no roles for the event, place
    #   output "NA".
    input_prefix = "Text: "
    output_prefix = "Answer: "

    unified_data = {
        "prompt": {
            "instructions": "Event Argument Extraction: ",
            "input_prefix": input_prefix,
            "input_suffix": "\n",
            "output_prefix": output_prefix,
            "output_suffix": "\n",
        },
        "request_states": [],
    }
    request_states = unified_data["request_states"]

    count = 0  # number of instances emitted so far; doubles as the id offset
    for instance in data:
        inputs, outputs = format_sample(instance, schema)
        for i, (ipt, opt) in enumerate(zip(inputs, outputs)):
            request_states.append({
                "instance": {
                    "input": {
                        "text": ipt,
                    },
                    "references": [
                        {
                            "output": {
                                "text": opt,
                            },
                        },
                    ],
                    "split": split,
                    "id": i + count,
                },
                "request": {
                    "result": {
                        "completions": [
                            {
                                "text": "",
                            },
                        ],
                    },
                    # Fixed placeholder values carried on every entry.
                    "request_time": 1.622053623199463,
                    "request_datetime": 1669584580,
                },
            })
        count += len(inputs)

    if random_sample:
        # The former check `split == "test" or "dev"` was always true (a
        # non-empty string is truthy). Its effective behavior is spelled out
        # here: 8 examples for "train", 1000 for every other split.
        num_sample = 8 if split == "train" else 1000
        unified_data["request_states"] = random.sample(
            request_states, k=min(num_sample, len(request_states))
        )

    print(len(unified_data["request_states"]))
    # Use a context manager so the output file is flushed and closed promptly.
    with open(output_folder.joinpath(f"{split}.json"), "w", encoding="utf-8") as f:
        json.dump(unified_data, f, indent=4)

if __name__ == '__main__':
    # Script entry point: read the RichERE splits and write them in the
    # unified event-argument-extraction format.
    input_folder = Path('../../data/RichERE')
    output_folder = Path('../../unified_data/EAE/RichERE')
    output_folder.mkdir(exist_ok=True, parents=True)

    # The role schema is derived from the training split only and reused for
    # the other splits.
    with open(input_folder.joinpath('train.json'), 'r', encoding='utf-8') as f:
        train_data = json.load(f)
    schema = find_schema(train_data)
    convert_format(train_data, schema, output_folder, 'train', False)

    # Output split name -> input file name (the "dev" split is read from
    # valid.json). Files are opened with `with` so each handle is closed.
    for split, filename in (('test', 'test.json'), ('dev', 'valid.json')):
        with open(input_folder.joinpath(filename), 'r', encoding='utf-8') as f:
            split_data = json.load(f)
        convert_format(split_data, schema, output_folder, split, False)