import os
import json
from pathlib import Path


def gen_prompt_doc(doc):
    """Mark the events of one document in its text and build its relation table.

    Every event that carries an ``eiid`` is numbered in order of appearance
    (``event0``, ``event1``, ...) and its trigger span is rewritten in place as
    ``event<N> <trigger>``. Events without an ``eiid`` are ignored entirely.

    Args:
        doc: Mapping with the keys
            ``text`` -- list of sentence strings (sliced by the event offsets),
            ``events`` -- list of event dicts; the ones that are marked provide
            ``eiid``, ``sent_id`` and ``offset`` (start/end indices into the
            sentence string),
            ``relations`` -- mapping from relation name to a list of
            ``[head_eiid, tail_eiid]`` pairs.

    Returns:
        A 5-tuple ``(text, num_events, rel_dict, event_names, event_num2id)``:
        the rewritten sentences joined by single spaces, the number of marked
        events, a nested dict ``rel_dict[head_eiid][tail_eiid]`` holding the
        lower-cased relation name (``'none'`` when no relation is listed), the
        list of ``event<N> <trigger>`` names, and a dict mapping each event
        index to its ``eiid``.

    Raises:
        ValueError: If the marked events are not sorted by sentence and start
            offset, or if two of them overlap inside the same sentence.
    """
    sentences = doc['text']
    # Some documents contain events without an eiid; those are never marked.
    marked_events = [event for event in doc['events'] if 'eiid' in event]

    # pieces[i] collects the rewritten fragments of sentence i. It stays empty
    # for sentences that are emitted unchanged.
    pieces = [[] for _ in sentences]
    event_names = []
    event_num2id = {}
    last_sent_id = 0
    last_offset = 0

    for event_idx, event in enumerate(marked_events):
        sent_id = event['sent_id']
        offset = event['offset']
        start, end = offset[0], offset[1]

        # The single left-to-right pass below only works on ordered,
        # non-overlapping events, so reject anything else explicitly
        # (an `assert` would vanish under `python -O`).
        if sent_id < last_sent_id:
            raise ValueError(
                f'event {event["eiid"]!r} is out of order: sentence {sent_id} '
                f'comes after sentence {last_sent_id}'
            )
        if sent_id == last_sent_id and start < last_offset:
            raise ValueError(
                f'event {event["eiid"]!r} starts at {start}, before the end '
                f'({last_offset}) of the previous event in sentence {sent_id}'
            )

        sentence = sentences[sent_id]
        event_name = f'event{event_idx} {sentence[start:end]}'
        event_names.append(event_name)
        event_num2id[event_idx] = event['eiid']

        if sent_id > last_sent_id:
            # Moving on to a later sentence: flush the untouched tail of the
            # previous one and start copying this one from its beginning.
            pieces[last_sent_id].append(sentences[last_sent_id][last_offset:])
            prefix_start = 0
        else:
            prefix_start = last_offset
        pieces[sent_id].append(f'{sentence[prefix_start:start]}{event_name}')

        last_sent_id = sent_id
        last_offset = end

    # Flush the tail of the last sentence that was touched. Guarded so that a
    # document with no sentences yields an empty text instead of IndexError.
    if sentences:
        pieces[last_sent_id].append(sentences[last_sent_id][last_offset:])

    rendered = [
        ''.join(parts) if parts else sentence
        for sentence, parts in zip(sentences, pieces)
    ]
    new_tokens = ' '.join(rendered)

    # Full eiid x eiid table, defaulting to 'none' (self-pairs included).
    eiids = [event['eiid'] for event in marked_events]
    rel_dict = {head: {tail: 'none' for tail in eiids} for head in eiids}
    for rel_name, rel_list in doc['relations'].items():
        label = rel_name.lower()
        for rel in rel_list:
            rel_dict[rel[0]][rel[1]] = label

    return new_tokens, len(marked_events), rel_dict, event_names, event_num2id


def _format_relation_answer(num_event, rel_dict, event_names, event_num2id):
    """Render the labelled event pairs of one document as one answer string.

    Pairs are visited as (earlier event, later event). Each pair whose label
    is not ``'none'`` becomes ``(eventI; label; eventJ)``, and the triples are
    joined with ``' | '``.

    Raises:
        ValueError: If a relation is recorded from a later event to an earlier
            one; such a relation would otherwise be dropped silently.
    """
    triples = []
    for i in range(num_event):
        for j in range(i + 1, num_event):
            label = rel_dict[event_num2id[i]][event_num2id[j]]
            reverse_label = rel_dict[event_num2id[j]][event_num2id[i]]
            if reverse_label != 'none':
                raise ValueError(
                    f'unexpected relation {reverse_label!r} from '
                    f'{event_names[j]!r} back to {event_names[i]!r}'
                )
            if label != 'none':
                # Event names look like "event<N> <trigger>"; keep the marker.
                marker1 = event_names[i].split(' ')[0]
                marker2 = event_names[j].split(' ')[0]
                triples.append(f'({marker1}; {label}; {marker2})')
    return ' | '.join(triples)


def convert_format(split):
    """Convert one MATRES split into the unified prompt/reference JSON file.

    Reads ``../../../data/MATRES/<split>.json`` (one JSON document per line),
    turns every document into an instance whose input is the event-marked text
    and whose reference is the relation answer string, prints the number of
    instances, and writes ``../../../unified_data/ERE/matres/<split>.json``.
    Both paths are relative to the current working directory.

    Args:
        split: Name of the split; used for the input and output file names and
            stored in each instance's ``split`` field.
    """
    save_json = {
        "prompt": {
            "instructions": "Event Relation Extraction: ",
            "input_prefix": "",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        "request_states": [],
    }

    # Explicit encoding so the result does not depend on the platform default.
    with open(f'../../../data/MATRES/{split}.json', encoding='utf-8') as file:
        # Blank lines carry no document; skipping them avoids a JSON error.
        raw_docs = [line for line in (raw.strip() for raw in file) if line]

    for doc_idx, raw_doc in enumerate(raw_docs):
        doc = json.loads(raw_doc)
        tokens, num_event, rel_dict, event_names, event_num2id = gen_prompt_doc(doc)
        answer = _format_relation_answer(
            num_event, rel_dict, event_names, event_num2id
        )
        instance = {
            "input": {
                "text": tokens
            },
            "references": [
                {
                    "output": {
                        "text": answer
                    }
                }
            ],
            "split": f"{split}",
            "id": f"{doc_idx}"
        }
        save_json['request_states'].append({'instance': instance, 'request': {}})

    print(len(save_json['request_states']))

    output_dir = Path('../../../unified_data/ERE/matres')
    output_dir.mkdir(exist_ok=True, parents=True)
    with open(output_dir / f'{split}.json', 'w', encoding='utf-8') as f:
        json.dump(save_json, f, indent=4)


# Convert every split, in this order, as soon as the file is executed.
for _split_name in ('train', 'dev', 'test'):
    convert_format(_split_name)