import json
import os
import numpy as np
import random
import re
from pathlib import Path
from collections import defaultdict


def gen_prompt_doc(doc):
    """Flatten a MAVEN-ERE style document into numbered mention-level items.

    Every event mention and every TIMEX becomes one numbered item. Numbering
    runs over all mentions of ``doc['events']`` (event order, then mention
    order) followed by the entries of ``doc['TIMEX']``.

    Args:
        doc: Parsed document dict. Reads ``temporal_relations`` and
            ``causal_relations`` (mappings of relation name -> iterable of
            ``[head_id, tail_id]`` pairs), ``subevent_relations`` (an iterable
            of such pairs), ``events`` (each with ``id`` and a ``mention``
            list), ``TIMEX``, ``tokens`` and, optionally, ``sentences``. The
            dict is not modified.

    Returns:
        ``(tokens, num_items, rel_dict, item_names, num2id, item_list,
        sentences)``:
        - ``rel_dict[head_id][tail_id]`` is the list of lower-cased relation
          names from head to tail (empty when unrelated).
        - ``item_names[k]`` is item ``k``'s surface text wrapped in
          ``<event> ... </event>`` or ``<timex> ... </timex>``.
        - ``num2id[k]`` is item ``k``'s event/TIMEX id.
        - ``item_list[k]`` is the raw mention/TIMEX dict.
        When ``sentences`` is absent, it is rebuilt from ``tokens`` with
        ``merge_tokens``.
    """
    # Build a new dict instead of update()-ing doc['temporal_relations'],
    # so the caller's document is left untouched.
    all_rel = {
        **doc['temporal_relations'],
        **doc['causal_relations'],
        'subevent': doc['subevent_relations'],
    }
    tokens = doc['tokens']
    if 'sentences' in doc:
        sentences = doc['sentences']
    else:
        sentences = [merge_tokens([toks]) for toks in tokens]

    event_num2id = {}
    event_names = []
    event_list = []

    def add_item(item_id, span, tag):
        # span carries 'sent_id' and a token slice 'offset' = [start, end) into tokens[sent_id].
        sent_id = span['sent_id']
        offset = span['offset']
        surface = ' '.join(tokens[sent_id][offset[0]:offset[1]])
        event_num2id[len(event_list)] = item_id
        event_list.append(span)
        event_names.append(f'<{tag}> {surface} </{tag}>')

    for event in doc['events']:
        # One item per mention; an event without mentions contributes nothing.
        for mention in event['mention']:
            add_item(event['id'], mention, 'event')
    for timex in doc['TIMEX']:
        add_item(timex['id'], timex, 'timex')
    event_idx = len(event_list)

    # Coreferent mentions share an id, so the relation matrix is keyed by unique id.
    unique_ids = list(dict.fromkeys(event_num2id.values()))
    rel_dict = {head: {tail: [] for tail in unique_ids} for head in unique_ids}
    for rel_name, rel_list in all_rel.items():
        for rel in rel_list:
            rel_dict[rel[0]][rel[1]].append(rel_name.lower())
    return tokens, event_idx, rel_dict, event_names, event_num2id, event_list, sentences


def merge_tokens(tokens):
    """Detokenize a list of token lists into a single readable string.

    Closing punctuation (``, . ; : ! ? > ) ] }``, ``''``, ``'s``, ``%``)
    attaches to the previous token. Opening brackets/quotes and ``$`` attach
    to the next token. ``-`` joins both sides. All other tokens are separated
    by a single space. Sentence boundaries get no special treatment: the
    lists are simply concatenated.
    """
    closers = (',', '.', ';', ':', '!', '?', '>', ')', ']', '}', "''", "'s", '%')
    openers = ('<', '(', '[', '{', '``', '`', '$')
    parts = []
    space_next = False  # whether the next ordinary token needs a leading space
    for word in (w for sentence in tokens for w in sentence):
        if word in closers:
            parts.append(word)
            space_next = True
        elif word == '-':
            parts.append(word)
            space_next = False
        else:
            if space_next:
                parts.append(' ')
            parts.append(word)
            # Openers glue to what follows; everything else wants a space after it.
            space_next = word not in openers
    return ''.join(parts)


def gen_save_json(instructions=""):
    """Return a fresh, empty unified-format dataset container.

    Args:
        instructions: Text stored under ``prompt.instructions``; empty by
            default. (The previous version read ``instruction.txt`` but never
            used its contents, and failed whenever that file was missing.)

    Returns:
        A new dict with a ``prompt`` template and an empty
        ``request_states`` list that callers append instances to.
    """
    return {
        "prompt": {
            "instructions": instructions,
            "input_prefix": "",
            "input_suffix": "\n",
            "output_prefix": "",
            "output_suffix": "\n",
        },
        "request_states": [],
    }


def make_eval_data(split, number, selection='average'):
    """Write every ordered item pair of *number* random documents as eval instances.

    Args:
        split: 'train', 'dev' or 'test'; 'dev' is read from ``valid.jsonl``.
            For 'test', the raw layout (``doc['doc']['tokens']``, ``mentions``,
            ``TIMEX3``, ``*_relation`` keys) is converted into the layout
            ``gen_prompt_doc`` reads.
        number: How many documents to take after shuffling.
        selection: Document selection strategy. Only ``'random'`` is
            implemented, so the default ``'average'`` is rejected and callers
            must pass ``'random'`` explicitly.

    Raises:
        ValueError: if *selection* is not ``'random'``.

    Writes ``../../unified_data/maven-ere/{split}.json`` (creating the
    directory if needed) and prints the number of instances.
    """
    if selection != 'random':
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently fall back to random selection.
        raise ValueError(f"unsupported selection {selection!r}; only 'random' is implemented")
    load_split_name = 'valid' if split == 'dev' else split
    with open(f'../../data/maven-ere/{load_split_name}.jsonl') as load_file:
        lines = load_file.readlines()
    save_json = gen_save_json()
    select_list = []

    sel_idx = list(range(len(lines)))
    random.shuffle(sel_idx)
    for i in sel_idx[0:number]:
        doc = json.loads(lines[i].strip())
        if split == 'test':
            new_doc = {}
            new_doc['tokens'] = doc['doc']['tokens']
            new_doc['events'] = doc['events']
            for event in new_doc['events']:
                event['mention'] = event['mentions']
            new_doc['TIMEX'] = doc['TIMEX3']
            new_doc['temporal_relations'] = doc['temporal_event_relation']
            new_doc['causal_relations'] = doc['causal_relation']
            new_doc['subevent_relations'] = doc['subevent_relation']
            select_list.append(new_doc)
        else:
            select_list.append(doc)

    for doc_idx, doc in enumerate(select_list):
        tokens, num_event, rel_dict, event_names, event_num2id, event_list, sentences = gen_prompt_doc(doc)
        cnt = 0
        for i in range(num_event):
            for j in range(num_event):
                if i == j:
                    continue
                sent_id1 = event_list[i]['sent_id']
                offset1 = event_list[i]['offset']
                sent_id2 = event_list[j]['sent_id']
                offset2 = event_list[j]['offset']
                work_tokens = []

                # Re-render the document, replacing each item's token span with its tagged name.
                for sent_id, toks in enumerate(tokens):
                    if sent_id1 != sent_id and sent_id2 != sent_id:
                        work_tokens.append(sentences[sent_id])
                    else:
                        tmp_tokens = []
                        if sent_id1 == sent_id and sent_id2 != sent_id:
                            for off, tok in enumerate(toks):
                                if offset1[0] == off:
                                    tmp_tokens.append(event_names[i])
                                elif off < offset1[0] or off >= offset1[1]:
                                    tmp_tokens.append(tok)
                        elif sent_id2 == sent_id and sent_id1 != sent_id:
                            for off, tok in enumerate(toks):
                                if offset2[0] == off:
                                    tmp_tokens.append(event_names[j])
                                elif off < offset2[0] or off >= offset2[1]:
                                    tmp_tokens.append(tok)
                        else:
                            for off, tok in enumerate(toks):
                                if offset1[0] == off:
                                    tmp_tokens.append(event_names[i])
                                elif offset2[0] == off:
                                    tmp_tokens.append(event_names[j])
                                elif (off < offset1[0] or off >= offset1[1]) and (off < offset2[0] or off >= offset2[1]):
                                    tmp_tokens.append(tok)
                        work_tokens.append(merge_tokens([tmp_tokens]))
                work_tokens = ' '.join(work_tokens)

                query = 'Document: ' + work_tokens + '\n\nThe first event/timex: ' + event_names[i] \
                    + '\nThe second event/timex: ' + event_names[j] + '\n'
                if event_num2id[i] == event_num2id[j]:
                    label = ['coreference']
                else:
                    label = rel_dict[event_num2id[i]][event_num2id[j]]
                # NOTE(review): label is a list, so the reference embeds its repr,
                # e.g. "(E1; ['before']; E2)". This differs from the ' | '-joined
                # format in make_train_data_for_instrcut_tuning; left unchanged in
                # case the evaluator parses this form — confirm.
                tmp_json = {
                    "input": {
                        "text": query
                    },
                    "references": [
                        {
                            "output": {
                                "text": f"({event_num2id[i]}; {label}; {event_num2id[j]})"
                            }
                        }
                    ],
                    "split": f"{split}",
                    "id": f"{doc_idx}_{cnt}"
                }
                save_json['request_states'].append({'instance': tmp_json, 'request': {}})
                cnt += 1

    print(len(save_json['request_states']))
    output_dir = Path('../../unified_data/maven-ere')
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / f'{split}.json', 'w') as f:
        json.dump(save_json, f)


def valid(doc, selection='small', max_tokens=100):
    """Return whether *doc* passes the size filter named by *selection*.

    Args:
        doc: Document dict whose ``tokens`` is a list of token lists.
        selection: Filter name. Only ``'small'`` is implemented: it keeps
            documents with at most *max_tokens* tokens in total.
        max_tokens: Token budget for the ``'small'`` filter.

    Returns:
        True if the document is kept (its token count is also printed),
        False otherwise.

    Raises:
        ValueError: for an unknown *selection*.
    """
    if selection != 'small':
        raise ValueError(f"unsupported selection {selection!r}; only 'small' is implemented")
    num_tokens = sum(len(toks) for toks in doc["tokens"])
    if num_tokens > max_tokens:
        return False
    print(num_tokens)
    return True


def gen_demo_single(tokens, num_event, rel_dict, event_names, event_num2id, event_list, sentences):
    """Interactively pick one item pair and collect a human explanation for it.

    Random ordered pairs of distinct items are proposed. A pair with no
    relation is skipped with 50% probability. Each proposal is printed with
    its labels until the user accepts one (a non-zero integer, or 'y'/'yes').
    The user then types an explanation. The arguments are the values returned
    by ``gen_prompt_doc``.

    Returns:
        ``(demo_doc, demo_ans)``: the prompt text (document with both items
        tagged inline, followed by the two tagged items) and the answer text,
        which contains an 'Explanation: ...' line and an 'Answer: [...]' line.

    Raises:
        ValueError: if fewer than two items exist, since no pair of distinct
            items can be formed (previously this looped forever for one item).
    """
    if num_event < 2:
        raise ValueError(f'need at least two events/timexes to form a pair, got {num_event}')
    while True:
        idx1 = random.randint(0, num_event - 1)
        idx2 = random.randint(0, num_event - 1)
        if idx1 == idx2:
            continue
        if len(rel_dict[event_num2id[idx1]][event_num2id[idx2]]) == 0:
            if random.random() > 0.5:  # if no relation, refuse it with 50% probability
                continue

        sent_id1 = event_list[idx1]['sent_id']
        offset1 = event_list[idx1]['offset']
        sent_id2 = event_list[idx2]['sent_id']
        offset2 = event_list[idx2]['offset']
        work_tokens = []

        # Re-render the document, replacing each item's token span with its tagged name.
        for sent_id, toks in enumerate(tokens):
            if sent_id1 != sent_id and sent_id2 != sent_id:
                work_tokens.append(sentences[sent_id])
            else:
                tmp_tokens = []
                if sent_id1 == sent_id and sent_id2 != sent_id:
                    for off, tok in enumerate(toks):
                        if offset1[0] == off:
                            tmp_tokens.append(event_names[idx1])
                        elif off < offset1[0] or off >= offset1[1]:
                            tmp_tokens.append(tok)
                elif sent_id2 == sent_id and sent_id1 != sent_id:
                    for off, tok in enumerate(toks):
                        if offset2[0] == off:
                            tmp_tokens.append(event_names[idx2])
                        elif off < offset2[0] or off >= offset2[1]:
                            tmp_tokens.append(tok)
                else:
                    # Both items are in this sentence.
                    for off, tok in enumerate(toks):
                        if offset1[0] == off:
                            tmp_tokens.append(event_names[idx1])
                        elif offset2[0] == off:
                            tmp_tokens.append(event_names[idx2])
                        elif (off < offset1[0] or off >= offset1[1]) and (off < offset2[0] or off >= offset2[1]):
                            tmp_tokens.append(tok)
                work_tokens.append(merge_tokens([tmp_tokens]))

        work_tokens = ' '.join(work_tokens)
        if event_num2id[idx1] == event_num2id[idx2]:
            label = ['coreference']
            label2 = ['coreference']
        else:
            label = rel_dict[event_num2id[idx1]][event_num2id[idx2]]
            label2 = rel_dict[event_num2id[idx2]][event_num2id[idx1]]
        label = '[' + ', '.join(label) + ']'
        # NOTE(review): this prompt layout ('Document:\n', 'event/"timex"') differs from
        # the eval data ('Document: ', 'event/timex'). It is presumably kept so new demos
        # match those already saved in train.json — confirm before changing.
        demo_doc = 'Document:\n' + work_tokens + '\n\nThe first event/"timex": ' + event_names[idx1] + \
            '\nThe second event/"timex": ' + event_names[idx2] + '\n'
        # Display only: show the reverse-direction labels separately from the answer label.
        demo_doc_show = demo_doc + f'label: {label}  reverse: {label2}\n'
        print(demo_doc_show)
        answer = input('choose this or not:\n').strip()
        try:
            choose = int(answer)
        except ValueError:
            # Accept y/yes as well instead of crashing on non-numeric input.
            choose = answer.lower() in ('y', 'yes')
        print()
        if choose:
            break

    demo_ans = input('input the explanation\n')
    demo_ans = 'Explanation: ' + demo_ans + f'\nAnswer: {label}\n'
    print()
    print(demo_ans)
    return demo_doc, demo_ans


def make_train_data(num_demos=8, seed=0):
    """Interactively annotate demonstrations and append them to train.json.

    Documents from ``train.jsonl`` that pass ``valid(doc, 'small')`` are
    shuffled. For each one not yet annotated, up to *num_demos* in total, the
    user picks an item pair and writes an explanation via ``gen_demo_single``.
    ``train.json`` is rewritten after every demonstration, so a run can be
    interrupted and resumed.

    Args:
        num_demos: Total number of demonstrations wanted in train.json.
        seed: Seed for the document shuffle. A fixed seed gives the same order
            on every run, which is what lets a resumed run skip the documents
            already annotated. Pass ``None`` for a fresh random order.
    """
    out_path = Path('../../unified_data/maven-ere/train.json')
    with open('../../data/maven-ere/train.jsonl') as src:
        raw_lines = src.readlines()
    try:
        with open(out_path) as existing:
            save_json = json.load(existing)
    except (FileNotFoundError, json.JSONDecodeError):
        # No readable previous progress: start from an empty container.
        save_json = gen_save_json()
    num_done = len(save_json['request_states'])

    candidates = []
    for line in raw_lines:
        doc = json.loads(line.strip())
        if valid(doc, 'small'):
            candidates.append(doc)
    order = list(range(len(candidates)))
    random.Random(seed).shuffle(order)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Ids continue from num_done so a resumed run does not reuse existing ids.
    for idx, cand_idx in enumerate(order[num_done:num_demos], start=num_done):
        # order indexes the filtered candidate list, not the raw line list.
        doc = candidates[cand_idx]
        tokens, num_event, rel_dict, event_names, event_num2id, event_list, sentences = gen_prompt_doc(doc)
        query, label = gen_demo_single(tokens, num_event, rel_dict, event_names, event_num2id, event_list, sentences)
        tmp_json = {
            "input": {
                "text": query
            },
            "references": [
                {
                    "output": {
                        "text": label
                    },
                    "tags": [
                        "correct"
                    ]
                }
            ],
            "split": "train",
            "id": f"{idx}"
        }
        save_json['request_states'].append({'instance': tmp_json})
        with open(out_path, 'w') as f:
            json.dump(save_json, f)


def make_train_data_for_instrcut_tuning(split, number, selection='average'):
    """Build a relation-classification instruction-tuning set for one split.

    Every ordered pair of distinct items in *number* randomly chosen
    documents becomes one instance. The reference text is the ' | '-joined
    list of ``(head_id; relation; tail_id)`` triples, or empty when the pair
    is unrelated. The instances are then rebalanced with ``sample_instance``.

    Args:
        split: 'train', 'dev' or 'test'; 'dev' is read from ``valid.json``.
            For 'test', the raw layout (``doc['doc']['tokens']``, ``mentions``,
            ``TIMEX3``, ``*_relation`` keys) is converted into the layout
            ``gen_prompt_doc`` reads.
        number: How many documents to take after shuffling.
        selection: Document selection strategy. Only ``'random'`` is
            implemented, so the default ``'average'`` is rejected and callers
            must pass ``'random'`` explicitly.

    Raises:
        ValueError: if *selection* is not ``'random'``.

    Writes ``../../../unified_data/maven-ere/{split}.json`` (creating parent
    directories if needed) and prints instance counts before and after
    sampling.
    """
    if selection != 'random':
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently fall back to random selection.
        raise ValueError(f"unsupported selection {selection!r}; only 'random' is implemented")
    load_split_name = 'valid' if split == 'dev' else split
    # One JSON document per line, despite the .json extension.
    with open(f'../../../data/MAVEN-ERE/{load_split_name}.json') as load_file:
        lines = load_file.readlines()
    save_json = gen_save_json()
    select_list = []

    sel_idx = list(range(len(lines)))
    random.shuffle(sel_idx)
    for i in sel_idx[0:number]:
        doc = json.loads(lines[i].strip())
        if split == 'test':
            new_doc = {}
            new_doc['tokens'] = doc['doc']['tokens']
            new_doc['events'] = doc['events']
            for event in new_doc['events']:
                event['mention'] = event['mentions']
            new_doc['TIMEX'] = doc['TIMEX3']
            new_doc['temporal_relations'] = doc['temporal_event_relation']
            new_doc['causal_relations'] = doc['causal_relation']
            new_doc['subevent_relations'] = doc['subevent_relation']
            select_list.append(new_doc)
        else:
            select_list.append(doc)

    for doc_idx, doc in enumerate(select_list):
        tokens, num_event, rel_dict, event_names, event_num2id, event_list, sentences = gen_prompt_doc(doc)
        cnt = 0
        for i in range(num_event):
            for j in range(num_event):
                if i == j:
                    continue
                sent_id1 = event_list[i]['sent_id']
                offset1 = event_list[i]['offset']
                sent_id2 = event_list[j]['sent_id']
                offset2 = event_list[j]['offset']
                work_tokens = []

                # Re-render the document, replacing each item's token span with its tagged name.
                for sent_id, toks in enumerate(tokens):
                    if sent_id1 != sent_id and sent_id2 != sent_id:
                        work_tokens.append(sentences[sent_id])
                    else:
                        tmp_tokens = []
                        if sent_id1 == sent_id and sent_id2 != sent_id:
                            for off, tok in enumerate(toks):
                                if offset1[0] == off:
                                    tmp_tokens.append(event_names[i])
                                elif off < offset1[0] or off >= offset1[1]:
                                    tmp_tokens.append(tok)
                        elif sent_id2 == sent_id and sent_id1 != sent_id:
                            for off, tok in enumerate(toks):
                                if offset2[0] == off:
                                    tmp_tokens.append(event_names[j])
                                elif off < offset2[0] or off >= offset2[1]:
                                    tmp_tokens.append(tok)
                        else:
                            for off, tok in enumerate(toks):
                                if offset1[0] == off:
                                    tmp_tokens.append(event_names[i])
                                elif offset2[0] == off:
                                    tmp_tokens.append(event_names[j])
                                elif (off < offset1[0] or off >= offset1[1]) and (off < offset2[0] or off >= offset2[1]):
                                    tmp_tokens.append(tok)
                        work_tokens.append(merge_tokens([tmp_tokens]))
                work_tokens = ' '.join(work_tokens)

                query = 'Document: ' + work_tokens + '\n\nThe first event/timex: ' + event_names[i] \
                    + '\nThe second event/timex: ' + event_names[j] + '\n'
                if event_num2id[i] == event_num2id[j]:
                    label = ['coreference']
                else:
                    label = rel_dict[event_num2id[i]][event_num2id[j]]
                tmp_json = {
                    "input": {
                        "text": query
                    },
                    "references": [
                        {
                            "output": {
                                "text": " | ".join([f"({event_num2id[i]}; {_label}; {event_num2id[j]})" for _label in label])
                            }
                        }
                    ],
                    "split": f"{split}",
                    "id": f"{doc_idx}_{cnt}"
                }
                save_json['request_states'].append({'instance': tmp_json, 'request': {}})
                cnt += 1

    print(len(save_json['request_states']))

    # sample 50% NA, 50% non-NA
    save_json["request_states"] = sample_instance(save_json["request_states"])
    print(len(save_json['request_states']))

    output_dir = Path("../../../unified_data/maven-ere")
    # parents=True: unified_data itself may not exist yet.
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / f'{split}.json', 'w') as f:
        json.dump(save_json, f, indent=4)


def sample_instance(data, num_instances=30000):
    """Down-sample request states to a fixed NA / relation-type mix.

    Args:
        data: The list of request states built by the data makers, each shaped
            ``{'instance': {..., 'references': [{'output': {'text': ...}}]}, ...}``
            (a bare instance dict without the ``'instance'`` wrapper is also
            accepted), or a dict holding that list under ``'request_states'``.
        num_instances: Target size of the returned list.

    Returns:
        A new list with:
        - up to ``int(num_instances * 0.5)`` NA states (empty reference text);
        - up to ``int(num_instances * 0.125)`` each of subevent, causal
          (cause / precondition) and coreference states;
        - temporal states filling whatever budget remains.
        A bucket smaller than its quota is taken whole. A multi-relation
        reference is bucketed by its first relation. Unrecognised relations
        are printed and dropped. Progress counts are printed along the way.
    """
    states = data['request_states'] if isinstance(data, dict) else data
    # Reference texts look like "(h; rel; t) | (h; rel2; t)". Non-greedy groups match
    # each triple separately; the raw string avoids invalid-escape warnings.
    pattern = re.compile(r"\((.*?);(.*?);(.*?)\)")
    rel_to_bucket = {
        'subevent': 'subevent',
        'coref': 'coref',
        'coreference': 'coref',  # the label the data makers actually emit
        'precondition': 'causal',
        'cause': 'causal',
    }
    for temporal_rel in ('before', 'overlap', 'contains', 'simultaneous', 'ends-on', 'begins-on'):
        rel_to_bucket[temporal_rel] = 'temporal'

    buckets = defaultdict(list)
    for state in states:
        instance = state.get('instance', state)
        text = instance['references'][0]['output']['text']
        if text == '':
            buckets['na'].append(state)
            continue
        triples = pattern.findall(text)
        # Strip: the "(h; rel; t)" format puts a space before the relation name.
        rel = triples[0][1].strip() if triples else text
        bucket = rel_to_bucket.get(rel)
        if bucket is None:
            print(rel)
        else:
            buckets[bucket].append(state)

    def sample_per_rel(pool, k):
        # Take the whole pool when it is smaller than the quota.
        if len(pool) < k:
            return pool
        return random.sample(pool, k=k)

    final_data = []
    final_data.extend(sample_per_rel(buckets['na'], int(num_instances * 0.5)))
    print(len(final_data))
    for name in ('subevent', 'causal', 'coref'):
        final_data.extend(sample_per_rel(buckets[name], int(num_instances * 0.5 * 0.25)))
        print(len(final_data))
    # Temporal relations fill whatever budget remains.
    remaining = max(num_instances - len(final_data), 0)
    final_data.extend(sample_per_rel(buckets['temporal'], remaining))
    return final_data



if __name__ == "__main__":
    # Script entry point: build the instruction-tuning train split from 500 randomly chosen documents.
    make_train_data_for_instrcut_tuning(split="train", number=500, selection="random")