import json
import os
from pathlib import Path

def merge_tokens(tokens):
    """Join tokenized sentences back into a single string.

    Args:
        tokens: An iterable of sentences, each an iterable of string tokens.

    Returns:
        One string with all tokens concatenated in order. A single space is
        inserted before a token only when the previous token asked for one:
        closing punctuation (',', ')', "'s", '%', ...) is glued to what
        precedes it, nothing is inserted after an opening token ('(', '``',
        '$', ...) or after '-', and '-' itself is glued on both sides.
    """
    # Tokens glued to the preceding text; a space may follow them.
    glue_left = (',', '.', ';', ':', '!', '?', '>', ')', ']', '}', "''", "'s", '%')
    # Tokens that may be preceded by a space but are never followed by one.
    glue_right = ('<', '(', '[', '{', '``', '`', '$')

    pieces = []
    space_pending = False
    for sentence in tokens:
        for token in sentence:
            if token in glue_left:
                pieces.append(token)
                space_pending = True
                continue
            if token == '-':
                pieces.append(token)
                space_pending = False
                continue
            # Opening tokens and ordinary words share the same prefix rule;
            # they differ only in whether a space may follow.
            if space_pending:
                pieces.append(' ')
            pieces.append(token)
            space_pending = token not in glue_right
    return ''.join(pieces)


def gen_prompt_doc(doc):
    """Build the event-marked text and the pairwise causal table for a document.

    Args:
        doc: A mapping with the keys
            'text'      - list of sentences, each a list of string tokens;
            'events'    - list of mappings with 'id', 'sent_id' and 'offset',
                          where offset[0] is the first token index of the
                          event and offset[1] is one past its last token;
            'relations' - mapping from relation name to a list of pairs
                          (source event id, target event id). Only the
                          relation name 'CAUSE' is accepted.

    Returns:
        A 5-tuple ``(text, num_events, rel_dict, event_names, event_num2id)``:
            text         - the detokenized document with every event wrapped
                           as ``<Event_i> ... </Event_i>``;
            num_events   - ``len(doc['events'])``;
            rel_dict     - ``rel_dict[src_id][dst_id]`` is 'Yes' for each
                           listed relation pair and 'No' otherwise;
            event_names  - for each event, its (already marked) tokens joined
                           by single spaces;
            event_num2id - maps the running event index to the event's id.

    Raises:
        AssertionError: If a relation name other than 'CAUSE' is present.
        KeyError: If a relation refers to an id that is not in 'events'.
    """
    events = doc['events']
    # Copy each sentence so the markers are not written into the caller's
    # document; otherwise a second call on the same doc would nest markers.
    tokens = [list(sentence) for sentence in doc['text']]
    event_num2id = {}
    event_names = []

    for event_idx, event in enumerate(events):
        event_num2id[event_idx] = event['id']
        sentence = tokens[event['sent_id']]
        offset = event['offset']
        start, last = offset[0], offset[1] - 1
        # Open first, then close: for a single-token event both markers end
        # up around the same token.
        sentence[start] = f'<Event_{event_idx}> {sentence[start]}'
        sentence[last] = f'{sentence[last]} </Event_{event_idx}>'
        # Captured after marking, so the name starts with this event's tag.
        event_names.append(' '.join(sentence[start:offset[1]]))
    text = merge_tokens(tokens)

    # Every ordered pair defaults to 'No'; listed relations flip it to 'Yes'.
    event_ids = [event['id'] for event in events]
    rel_dict = {src: {dst: 'No' for dst in event_ids} for src in event_ids}
    for rel_name, rel_list in doc['relations'].items():
        assert rel_name == 'CAUSE', f'unexpected relation type: {rel_name!r}'
        for rel in rel_list:
            rel_dict[rel[0]][rel[1]] = 'Yes'
    return text, len(events), rel_dict, event_names, event_num2id


def _format_cause_answer(num_event, rel_dict, event_names, event_num2id):
    """Render all ordered event pairs labelled 'Yes' as ' | '-joined triples."""
    triples = []
    for i in range(num_event):
        for j in range(num_event):
            if i == j:
                continue
            if rel_dict[event_num2id[i]][event_num2id[j]] == 'Yes':
                # The first space-separated piece of an event name is used as
                # the event's handle in the answer string.
                marker1 = event_names[i].split(' ')[0]
                marker2 = event_names[j].split(' ')[0]
                triples.append(f'({marker1}; cause; {marker2})')
    return ' | '.join(triples)


def convert_format(split, data_dir='../../../data/CausalTimeBank',
                   output_dir='../../../unified_data/ERE/causaltimebank'):
    """Convert one split from JSON-lines documents to the unified prompt file.

    Reads ``<data_dir>/<split>.json`` (one JSON document per line), turns each
    document into an instance whose input is the event-marked text and whose
    reference output lists the causal pairs, prints the number of instances,
    and writes the collection to ``<output_dir>/<split>.json``.

    Args:
        split: Name of the split; used as the file stem for both input and
            output and stored in every instance's "split" field.
        data_dir: Directory containing the input file. Defaults to the
            original hard-coded location (relative to the working directory).
        output_dir: Directory for the converted file; created if missing.
            Defaults to the original hard-coded location.

    Raises:
        FileNotFoundError: If the input file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 keeps the result independent of the platform locale.
    with open(Path(data_dir) / f'{split}.json', encoding='utf-8') as file:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in
        # json.loads; ids stay consecutive over the actual documents.
        raw_data = [line for line in file if line.strip()]

    save_json = {
        "prompt": {
            "instructions": "Event Relation Extraction: ",
            "input_prefix": "",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        'request_states': []
    }

    for doc_idx, data in enumerate(raw_data):
        doc = json.loads(data.strip())
        tokens, num_event, rel_dict, event_names, event_num2id = gen_prompt_doc(doc)
        answer = _format_cause_answer(num_event, rel_dict, event_names, event_num2id)
        tmp_json = {
            "input": {
                "text": tokens
            },
            "references": [
                {
                    "output": {
                        "text": answer
                    }
                }
            ],
            "split": f"{split}",
            "id": f"{doc_idx}"
        }
        save_json['request_states'].append({'instance': tmp_json, 'request': {}})
    print(len(save_json['request_states']))

    target_dir = Path(output_dir)
    target_dir.mkdir(exist_ok=True, parents=True)
    with open(target_dir / f'{split}.json', 'w', encoding='utf-8') as f:
        json.dump(save_json, f, indent=4)



# Convert the training split first, then the development split.
for split_name in ('train', 'dev'):
    convert_format(split_name)