import os
import json
from pathlib import Path
from collections import defaultdict


def merge_tokens(tokens):
    """Detokenize nested token lists into one string.

    ``tokens`` is an iterable of token sequences; all tokens are concatenated
    in order into a single string. Spacing rules:

    * closing punctuation and suffixes (``,``, ``.``, ``)``, ``'s``, ``%``,
      ...) attach to the preceding text and allow a space after them;
    * ``-`` attaches to the preceding text and suppresses the space after it;
    * opening tokens (``(``, ``[``, `` ` ``, ``$``, ...) may be preceded by a
      space but suppress the space after them;
    * any other token is treated as a word: a space may precede it and may
      follow it.

    Returns the merged text (an empty string when there are no tokens).
    """
    attach_left = (',', '.', ';', ':', '!', '?', '>', ')', ']', '}', "''", "'s", '%')
    attach_right = ('<', '(', '[', '{', '``', '`', '$')

    pieces = []
    space_pending = False  # True when the next word/opening token needs a space
    for sentence in tokens:
        for token in sentence:
            if token in attach_left:
                pieces.append(token)
                space_pending = True
                continue
            if token == '-':
                pieces.append(token)
                space_pending = False
                continue
            if space_pending:
                pieces.append(' ')
            pieces.append(token)
            # Opening tokens glue to whatever follows; ordinary words do not.
            space_pending = token not in attach_right
    return ''.join(pieces)


def gen_prompt_doc(doc, relation_type):
    """Tag event/time mentions in a document and index its relations.

    Every event mention gets an ``event<N>`` marker prepended to its first
    token (``N`` counts mentions in document order of ``doc['events']``).
    When ``relation_type`` is ``"temporal"``, every ``doc['TIMEX']`` entry
    additionally gets a ``time<M>`` marker (``M`` starts at 0).

    Args:
        doc: Dict with keys ``tokens`` (list of per-sentence token lists),
            ``events`` (each with ``id`` and a ``mention`` list whose items
            carry ``sent_id`` and ``offset``), ``TIMEX`` (items with ``id``,
            ``sent_id`` and ``offset``), plus the relation key selected by
            ``relation_type``. ``doc`` is not modified.
        relation_type: ``"temporal"``, ``"causal"`` or ``"subevent"`` select
            ``temporal_relations``, ``causal_relations`` or
            ``subevent_relations``; any other value yields no relations.

    Returns:
        A tuple ``(text, num_markers, rel_dict, names, num2id)``:
        ``text`` is the merged, marker-annotated document string;
        ``num_markers`` is the total number of tagged spans;
        ``rel_dict[id_a][id_b]`` is the list of lower-cased relation names
        from ``id_a`` to ``id_b``; ``names[k]`` is the space-joined tagged
        span of marker ``k`` (so it starts with the marker itself); and
        ``num2id[k]`` is the event/TIMEX id behind marker ``k``.

    Raises:
        KeyError: if a relation refers to an id that was not tagged.
    """
    if relation_type == "temporal":
        all_rel = doc['temporal_relations']
    elif relation_type == "causal":
        all_rel = doc['causal_relations']
    elif relation_type == "subevent":
        all_rel = {'subevent': doc['subevent_relations']}
    else:
        all_rel = {}

    # Copy each sentence so tagging does not rewrite the caller's tokens
    # (a second call on the same doc would otherwise double-tag them).
    tokens = [list(sentence) for sentence in doc['tokens']]
    timexs = doc['TIMEX']

    # Collect (marker, id, span) in tagging order: all event mentions first,
    # then (temporal only) the time expressions. Events with one or many
    # mentions are handled uniformly; an event without mentions adds nothing.
    spans = []
    for event in doc['events']:
        for mention in event['mention']:
            spans.append((f'event{len(spans)}', event['id'], mention))
    if relation_type == "temporal":
        for time_idx, timex in enumerate(timexs):
            spans.append((f'time{time_idx}', timex['id'], timex))

    event_num2id = {}
    event_names = []
    for idx, (marker, span_id, span) in enumerate(spans):
        sent_id = span['sent_id']
        first, end = span['offset'][0], span['offset'][1]
        tokens[sent_id][first] = f'{marker} {tokens[sent_id][first]}'
        event_num2id[idx] = span_id
        # Joined after tagging, so the name begins with the marker.
        event_names.append(' '.join(tokens[sent_id][first:end]))
    event_idx = len(spans)

    text = merge_tokens(tokens)

    # Dense id x id table so callers can look up any ordered pair directly.
    ids = [event_num2id[i] for i in range(event_idx)]
    rel_dict = {head: {tail: [] for tail in ids} for head in ids}
    for rel_name, rel_list in all_rel.items():
        label = rel_name.lower()
        for rel in rel_list:
            rel_dict[rel[0]][rel[1]].append(label)
    return text, event_idx, rel_dict, event_names, event_num2id


def _marker_sort_key(marker):
    """Split a marker like ``event12`` or ``time3`` into ``(first_letter, number)``."""
    if marker[0] == "e":
        return marker[0], int(marker[5:])
    if marker[0] == "t":
        return marker[0], int(marker[4:])
    raise ValueError(f"unrecognized marker: {marker!r}")


def _format_answer(triples):
    """Render ``[head, relation, tail]`` triples as ``(head; rel; t1, t2) | ...``."""
    ordered = sorted(
        triples,
        key=lambda t: (*_marker_sort_key(t[0]), t[1], *_marker_sort_key(t[2])),
    )
    # Group tails under (head, relation); dict order keeps the sorted order.
    grouped = defaultdict(list)
    for head, relation, tail in ordered:
        grouped[(head, relation)].append(tail)
    return " | ".join(
        f"({head}; {relation}; {', '.join(tails)})"
        for (head, relation), tails in grouped.items()
    )


def convert_format(split, relation_type="temporal", *,
                   data_dir='../../../data/MAVEN-ERE',
                   output_root='../../../unified_data/ERE'):
    """Convert one MAVEN-ERE split into the unified prompt/reference JSON.

    Reads ``<data_dir>/<split>.json`` (one JSON document per line), tags each
    document via ``gen_prompt_doc`` and writes
    ``<output_root>/maven-ere-<relation_type>/<split>.json``. Also prints the
    number of converted documents.

    Args:
        split: Split name; used for the input/output file name and stored in
            each instance's ``split`` field.
        relation_type: Relation family forwarded to ``gen_prompt_doc``. With
            ``"coref"``, mentions sharing an id are linked by a
            ``coreference`` label (emitted once per pair).
        data_dir: Directory holding the raw split files.
        output_root: Directory under which the ``maven-ere-*`` folder is
            created.

    Raises:
        ValueError: if a marker is neither an ``event`` nor a ``time`` marker.
    """
    # JSON is UTF-8; do not depend on the platform's locale encoding.
    with open(Path(data_dir) / f'{split}.json', encoding='utf-8') as file:
        raw_data = file.readlines()
    save_json = {
        "prompt": {
            "instructions": "Event Relation Extraction: ",
            "input_prefix": "",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        'request_states': []
    }

    for doc_idx, data in enumerate(raw_data):
        data = data.strip()
        if not data:
            # Tolerate blank lines (e.g. a trailing newline); ids still
            # reflect the original line index.
            continue
        raw_doc = json.loads(data)

        # Rename the raw fields to the keys gen_prompt_doc expects.
        for event in raw_doc['events']:
            event['mention'] = event['mentions']
        doc = {
            'tokens': raw_doc['doc']['tokens'],
            'events': raw_doc['events'],
            'TIMEX': raw_doc['TIMEX3'],
            'temporal_relations': raw_doc['temporal_event_relation'],
            'causal_relations': raw_doc['causal_relation'],
            'subevent_relations': raw_doc['subevent_relation'],
        }

        tokens, num_event, rel_dict, event_names, event_num2id = gen_prompt_doc(doc, relation_type)
        # Each name starts with its marker ("event3 ..." / "time0 ...").
        markers = [name.split(' ')[0] for name in event_names]

        triples = []
        for i in range(num_event):
            for j in range(num_event):
                if i == j:
                    continue
                if relation_type == "coref" and event_num2id[i] == event_num2id[j]:
                    if j < i:
                        # Coreference is symmetric: keep only the i < j pair.
                        continue
                    labels = ['coreference']
                else:
                    labels = rel_dict[event_num2id[i]][event_num2id[j]]
                for label in labels:
                    triples.append([markers[i], label, markers[j]])

        instance = {
            "input": {
                "text": tokens
            },
            "references": [
                {
                    "output": {
                        "text": _format_answer(triples)
                    }
                }
            ],
            "split": f"{split}",
            "id": f"{doc_idx}"
        }
        save_json['request_states'].append({'instance': instance, 'request': {}})

    print(len(save_json['request_states']))
    output_dir = Path(output_root) / f'maven-ere-{relation_type}'
    output_dir.mkdir(exist_ok=True, parents=True)
    with open(output_dir / f'{split}.json', 'w', encoding='utf-8') as f:
        json.dump(save_json, f, indent=4)

# Relation family converted when this file is run as a script.
TYPE = "temporal"

if __name__ == "__main__":
    # Run the conversion only on direct execution, so importing this module
    # (e.g. to reuse its helpers) does not read or write any files.
    for split_name in ('train', 'dev', 'test'):
        convert_format(split_name, TYPE)