import csv
import json

from copy import deepcopy
import random

import spacy
import en_core_web_md
from tqdm import tqdm
import pickle
from termcolor import cprint

import os
import argparse
# Command-line interface: --data_dir is the directory holding the input files
# and receiving the generated *.jsonl files (defaults to the current directory).
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default='.')
args = parser.parse_args()

data_dir = args.data_dir

# Load every data row of opendialkg.csv as a (row, unique_id) pair; ids are
# assigned sequentially starting at 1000000 in file order.
dataset = []
# newline='' hands line-ending handling to the csv module, as its documentation
# requires, so newlines embedded in quoted fields are parsed correctly.
with open(f"{data_dir}/opendialkg.csv", newline='') as csvfile:
    reader = csv.reader(csvfile)
    next(reader, None)  # skip the header row (no-op on an empty file)

    unique_id = 1000000
    for rows in reader:
        dataset.append((rows, unique_id))
        unique_id += 1

# Shuffle with a fixed seed so the split is reproducible, then slice the list
# into consecutive train / valid / test portions.
random.seed(42)
random.shuffle(dataset)

# NOTE: test_ratio is not read below; the test split is whatever remains after
# the train and valid slices.
train_ratio, valid_ratio, test_ratio = 0.7, 0.15, 0.15

# train_len and valid_len are the exclusive end indices of the train and valid
# slices respectively (valid_len is an offset into dataset, not a count).
num_rows = len(dataset)
train_len = int(num_rows * train_ratio)
valid_len = train_len + int(num_rows * valid_ratio)

train_dataset, valid_dataset, test_dataset = (
    dataset[:train_len],
    dataset[train_len:valid_len],
    dataset[valid_len:],
)

print(f"train Size: {len(train_dataset) }")
print(f"valid Size: {len(valid_dataset) }")
print(f"test Size: {len(test_dataset) }")

# Load the pickled lookup tables (used below as lower-cased name -> code) and
# derive the opposite direction (code -> name) for each of them.
# NOTE: pickle.load can execute arbitrary code; only load codebook files from
# a trusted source.
with open(f"{data_dir}/entity_codebook.pkl", 'rb') as entity_file:
    entity_codebook = pickle.load(entity_file)
reverse_entity_codebook = dict(
    (code, name) for name, code in entity_codebook.items()
)
with open(f"{data_dir}/relation_codebook.pkl", 'rb') as relation_file:
    relation_codebook = pickle.load(relation_file)
reverse_relation_codebook = dict(
    (code, name) for name, code in relation_codebook.items()
)

# spaCy pipeline whose named-entity output is consumed by find_entity.
nlp = en_core_web_md.load()

# Map an entity name to its corresponding code
def map_entity(entity):
    """Return the code stored in ``entity_codebook`` for ``entity``.

    The name is lower-cased before the lookup.

    Args:
        entity: Entity name to look up.

    Returns:
        The code for the entity, or ``None`` when the lower-cased name is not
        in the codebook or ``entity`` has no ``lower()`` method (e.g. ``None``).
    """
    try:
        return entity_codebook[entity.lower()]
    except (KeyError, AttributeError):
        # Only "unknown name" / "not a string" mean "no code"; anything else
        # (including KeyboardInterrupt) must propagate.
        return None

# Map a code back to its corresponding entity name
def map_code(code):
    """Return the entity name stored in ``reverse_entity_codebook`` for ``code``.

    Args:
        code: Entity code to look up.

    Returns:
        The entity name, or ``None`` when the code is unknown or cannot be
        used as a dictionary key (unhashable).
    """
    try:
        return reverse_entity_codebook[code]
    except (KeyError, TypeError):
        # KeyError: unknown code; TypeError: unhashable code. Other errors
        # (including KeyboardInterrupt) must propagate.
        return None

# Inverse Triplet
def inverse(triplet):
    """Return the reversed-direction version of a coded triplet.

    Head and tail are swapped, and the relation is replaced by its
    counterpart: a relation name starting with ``'~'`` loses that prefix,
    any other relation name gains it.

    Args:
        triplet: Dict with keys ``'h'``, ``'r'`` and ``'t'`` holding the head
            code, relation code and tail code.

    Returns:
        A new dict with the same keys describing the inverse triplet.

    Raises:
        KeyError: If the relation code, or the name of its inverse relation,
            is missing from the relation codebooks.
    """
    relation_text = reverse_relation_codebook[triplet['r']]
    # Only a leading '~' marks an inverse relation; test the prefix because
    # exactly the first character is stripped.
    if relation_text.startswith('~'):
        inverse_relation_text = relation_text[1:]
    else:
        inverse_relation_text = '~' + relation_text
    return {
        'h': triplet['t'],
        'r': relation_codebook[inverse_relation_text],
        't': triplet['h'],  # the original head becomes the tail
    }

def _in_codebook(name, codebook):
    """Return True if ``name`` is a string whose lower-cased form is a key of ``codebook``."""
    return isinstance(name, str) and name.lower() in codebook


# Map triplets to their corresponding codes (entity, relation, entity)
def map_triplet(triplets):
    """Convert textual triplets into coded ``{'h', 'r', 't'}`` dicts.

    Each triplet is indexed as ``[head, relation, tail]``; every element is
    lower-cased and looked up in ``entity_codebook`` / ``relation_codebook``.
    Triplets with an element that cannot be mapped are skipped, and a short
    report of which parts were found is printed.

    Args:
        triplets: Iterable of indexable triplets of names.

    Returns:
        List of dicts with keys ``'h'``, ``'r'``, ``'t'`` for the triplets
        that could be mapped, in input order.
    """
    coded_triplets = []
    for triplet in triplets:
        try:
            coded_triplets.append(
                {
                    'h': entity_codebook[triplet[0].lower()],
                    'r': relation_codebook[triplet[1].lower()],
                    't': entity_codebook[triplet[2].lower()],
                }
            )
        except (KeyError, AttributeError):
            # KeyError: unknown name; AttributeError: element is not a string.
            # The report applies the same lower-casing as the lookup above so
            # it shows which part actually failed.
            print()
            head_exist = _in_codebook(triplet[0], entity_codebook)
            relation_exist = _in_codebook(triplet[1], relation_codebook)
            tail_exist = _in_codebook(triplet[2], entity_codebook)
            print(f"Head exist:{head_exist} {triplet[0]}\nRelation exist:{relation_exist} {triplet[1]}\nTail exist:{tail_exist} {triplet[2]}")
    return coded_triplets

# Character offsets are not needed here; only the entity codes are collected.
# 1. Find entities among the candidate entities
# 2. Run spaCy and map the entity spans it reports through the codebook
def find_entity(message, entities):
    """Return the set of entity codes mentioned in ``message``.

    Two sources are combined:

    1. every candidate in ``entities`` whose lower-cased name occurs as a
       substring of the lower-cased message, and
    2. every entity span reported by the spaCy pipeline ``nlp`` for the
       message.

    In both cases the text is mapped through ``map_entity`` and dropped when
    it has no code.

    Args:
        message: Utterance text.
        entities: Iterable of candidate entity names.

    Returns:
        Set of entity codes found in the message.
    """
    lowered_message = message.lower()  # loop-invariant, computed once
    found_entities = set()

    for entity in entities:
        if entity.lower() not in lowered_message:
            continue
        entity_code = map_entity(entity)
        if entity_code is not None:
            found_entities.add(entity_code)

    for ent in nlp(message).ents:
        entity_code = map_entity(ent.text)
        if entity_code is not None:
            found_entities.add(entity_code)
    return found_entities

# Locate entity within history (used in model mention_positions)
def locate_entity(message, entities):
    """Find the first mention of each entity code inside ``message``.

    Each code is turned back into its name with ``map_code`` and searched for
    case-insensitively; only the first occurrence is reported.

    Args:
        message: Text to search (the caller passes the joined dialogue history).
        entities: Iterable of entity codes.

    Returns:
        List of dicts with keys ``'start'`` and ``'end'`` (inclusive character
        offsets), ``'text'`` (the matched slice of ``message``) and ``'id'``
        (the entity code). Codes whose name is unknown or does not occur in
        the message are omitted.
    """
    # NOTE: offsets are found in the lower-cased text and applied to the
    # original; they can drift for the rare characters whose lower-case form
    # has a different length.
    lowered_message = message.lower()  # loop-invariant, computed once
    located_entities = []

    for entity_code in entities:
        entity = map_code(entity_code)
        if entity is None:
            # No name for this code: nothing to search for.
            continue
        start_char = lowered_message.find(entity.lower())
        if start_char < 0:
            continue
        end_char = start_char + len(entity) - 1
        # Sanity check that the codebook round-trips name -> code.
        assert entity_code == map_entity(entity)

        located_entities.append(
            {
                'start': start_char,
                'end': end_char,
                'text': message[start_char:end_char+1],
                'id': entity_code
            }
        )
    return located_entities

# Raw knowledge-graph triples: one list element per line of the file, with the
# trailing newline kept (build_database strips it).
with open(f"{data_dir}/opendialkg_triples.txt", 'r') as triples_file:
    entire_triplets = list(triples_file)
print(f"# Entire Triples: {len(entire_triplets)}")

def build_database(prop='head'):
    """Index all knowledge-graph triples by one of their components.

    Every line of ``entire_triplets`` is stripped and split on tabs into
    ``head, relation, tail``; lines with fewer than three fields are skipped.
    The three names are lower-cased and converted to codes through
    ``entity_codebook`` / ``relation_codebook``.

    Args:
        prop: Which component to use as the index key: ``"head"``,
            ``"tail"`` or ``"relation"``.

    Returns:
        Dict mapping the chosen component's code to the set of
        ``(head_id, relation_id, tail_id)`` tuples that contain it there.

    Raises:
        ValueError: If ``prop`` is not one of the supported values, or if a
            line has more than three tab-separated fields.
        KeyError: If a name is missing from the corresponding codebook.
    """
    # Resolve the key position once and fail early on a bad ``prop`` instead
    # of hitting an unbound variable inside the loop.
    key_positions = {"head": 0, "relation": 1, "tail": 2}
    if prop not in key_positions:
        raise ValueError(
            f"prop must be one of {sorted(key_positions)}, got {prop!r}"
        )
    key_index = key_positions[prop]

    database = dict()
    for line in entire_triplets:
        fields = line.strip().split('\t')
        if len(fields) < 3:
            continue
        head, relation, tail = fields
        fact = (
            entity_codebook[head.lower()],
            relation_codebook[relation.lower()],
            entity_codebook[tail.lower()],
        )
        database.setdefault(fact[key_index], set()).add(fact)
    return database

# Upper bound on the one-hop facts kept per entity in preprocess_triplets.
maximum_triplets = 1000

# Triple indexes keyed by head entity code and by tail entity code.
head_database = build_database(prop="head")
tail_database = build_database(prop="tail")

def preprocess_triplets(entire_entities):
    """Collect the one-hop facts around a group of entities.

    For each entity code, every triple in which it is the head
    (``head_database``) or the tail (``tail_database``) is gathered. If one
    entity has more than ``maximum_triplets`` such facts, a random sample of
    that size is kept. Facts shared between entities appear once.

    Args:
        entire_entities: Iterable of entity codes.

    Returns:
        List of dicts with keys ``'h'``, ``'r'``, ``'t'``, one per distinct
        collected triple.
    """
    _triplets = set()
    for entity_id in entire_entities:
        head_set = head_database.get(entity_id, set())
        tail_set = tail_database.get(entity_id, set())
        one_hop_facts = head_set.union(tail_set)
        if len(one_hop_facts) > maximum_triplets:
            # random.sample needs a sequence: passing a set is deprecated
            # since Python 3.9 and raises TypeError from 3.11. tuple() is the
            # conversion older versions applied internally, so the sampled
            # facts are unchanged there.
            one_hop_facts = random.sample(tuple(one_hop_facts), maximum_triplets)
        _triplets.update(one_hop_facts)
    return [{'h': h, 'r': r, 't': t} for h, r, t in _triplets]
    
def _collect_candidate_entities(turns):
    """Return the head and tail names of every triplet in the turns' metadata paths."""
    candidates = set()
    for turn in turns:
        if 'metadata' in turn and 'path' in turn['metadata']:
            for triplet in turn['metadata']['path'][1]:
                candidates.add(triplet[0])
                candidates.add(triplet[-1])
    return candidates


def _label_triplets(triplets_codes, gold_triplets, fold):
    """Return 0/1 labels for ``triplets_codes``; for the train fold, append missing gold triplets in place."""
    labels = [
        1 if (_triplet in gold_triplets or inverse(_triplet) in gold_triplets) else 0
        for _triplet in triplets_codes
    ]
    if fold == "train":
        # Make sure every gold triplet is present (in either direction) so the
        # training example always contains its positives.
        for _triplet in gold_triplets:
            if _triplet not in triplets_codes and inverse(_triplet) not in triplets_codes:
                triplets_codes.append(_triplet)
                labels.append(1)
    return labels


def _preprocess(dataset, fold):
    """Turn raw dialogues into per-assistant-turn examples and write ``<fold>.jsonl``.

    Each dataset item is a ``(rows, unique_id)`` pair whose ``rows[0]`` is a
    JSON list of dialogue rows. Rows are processed in order:

    * a row with ``'metadata'`` is never treated as a message; if it has a
      ``'path'``, that path's triplets become the gold triplets for the next
      assistant message;
    * a row with ``'message'`` is appended to the history after its entities
      are extracted with ``find_entity``; if its sender is ``'assistant'``,
      an example is emitted first, built from the history *before* this
      message.

    Each example holds the history, the assistant message as ``label``, the
    entities located in the history, the one-hop triplets around the history
    entities and a 0/1 label per triplet (1 when the triplet or its inverse
    is a gold triplet).

    Args:
        dataset: Iterable of ``(rows, unique_id)`` pairs.
        fold: Split name; used for the output file name, and ``"train"``
            additionally appends unretrieved gold triplets to each example.

    Returns:
        None. Examples are written as JSON lines to ``<data_dir>/<fold>.jsonl``.
    """
    new_dataset = []
    # One dataset item is one episode, so the enumerate index is the episode id.
    for episode_id, (rows, unique_id) in enumerate(tqdm(dataset, desc="Preprocessing...")):
        pp_rows = json.loads(rows[0])

        # Entity names taken from every metadata path of the whole dialogue;
        # they are the candidates searched for in each message.
        candidate_entities = _collect_candidate_entities(pp_rows)

        turn_id = 0
        history = []  # messages seen so far
        history_entities = set()  # entity codes found in those messages

        ## Triplets and Entities
        gold_triplets = []
        has_gold = False
        for row in pp_rows:
            if 'metadata' in row:
                if 'path' in row['metadata']:
                    gold_triplets = row['metadata']['path'][1]
                    has_gold = True
                continue
            if 'message' not in row:
                continue

            message = row['message']
            if row['sender'] == 'assistant':
                # Wrap-up and make the data
                dialog_text = '\n'.join(history)
                triplets_codes = preprocess_triplets(history_entities)
                located_entities = locate_entity(dialog_text, history_entities)
                # Gold triplets are consumed by the first assistant message
                # after their metadata row; later messages get none.
                gold_triplets = map_triplet(gold_triplets) if has_gold else []
                has_gold = False

                triplets_label = _label_triplets(triplets_codes, gold_triplets, fold)
                assert len(triplets_codes) == len(triplets_label)

                new_dataset.append(
                    {
                        "episode_id": episode_id,
                        "turn_id": turn_id,
                        # Snapshot: history keeps growing after this example.
                        # A shallow copy suffices because the items are strings.
                        "history": list(history),
                        "label": message,
                        "entities": located_entities,
                        "triplets": triplets_codes,
                        "unique_id": unique_id,
                        "triplets_label": triplets_label,
                    }
                )
                turn_id += 1

            ## Common actions
            ## 1. Find entities
            ## 2. Extend the history used for later turns
            found_entities = find_entity(message, list(candidate_entities))
            history_entities.update(found_entities)
            history.append(message)
    print(f"{fold} Size: {len(new_dataset) }")

    filename_out = f"{fold}.jsonl"
    with open(os.path.join(data_dir, filename_out), 'w') as outfile:
        for data in new_dataset:
            outfile.write(json.dumps(data) + '\n')

# Preprocess each split in turn; every call writes its own <fold>.jsonl file.
for _fold_name, _fold_rows in (
    ("train", train_dataset),
    ("valid", valid_dataset),
    ("test", test_dataset),
):
    _preprocess(_fold_rows, fold=_fold_name)