import pickle
import random
from os import fspath
from pathlib import Path
import numpy as np
from preprocess.preprocess_tools.process_utils import jsonl_save, load_glove
import math
import copy

# Seed both NumPy's and the standard library's global RNGs with one value.
SEED = 101
np.random.seed(SEED)
random.seed(SEED)

# Names of the development / test splits; they key the per-split dicts below.
dev_keys = ["normal"]
test_keys = ["normal"]

# Vocabulary / embedding settings used further down in this script:
# a word is kept only if its count is strictly greater than MIN_FREQ.
MIN_FREQ = 2
WORDVECDIM = 300
MAX_VOCAB = 50000

# Input locations (relative to the current working directory).
train_path = Path('../data/ontonotes5_ner/onto.train.ner')
dev_path = {"normal": Path('../data/ontonotes5_ner/onto.development.ner')}
test_path = {"normal": Path('../data/ontonotes5_ner/onto.test.ner')}
embedding_path = Path("../embeddings/glove/glove.840B.300d.txt")


# Make sure the output directory exists before any file is written into it.
Path('../processed_data/ontonotes5_ner').mkdir(parents=True, exist_ok=True)

# One JSONL output file for train, plus one per development / test split key.
train_save_path = Path('../processed_data/ontonotes5_ner/train.jsonl')
dev_save_path = {
    key: Path('../processed_data/ontonotes5_ner/dev_{}.jsonl'.format(key))
    for key in dev_keys
}
test_save_path = {
    key: Path('../processed_data/ontonotes5_ner/test_{}.jsonl'.format(key))
    for key in test_keys
}
metadata_save_path = fspath(Path("../processed_data/ontonotes5_ner/metadata.pkl"))

# Module-level accumulators filled in while the data files are read:
# label string -> integer id, and token -> occurrence count.
labels2idx = {}
vocab2count = {}


def updateVocab(word):
    """Increment the occurrence count of ``word`` in the module-level ``vocab2count``.

    A word that has not been seen before is entered with a count of 1.
    """
    global vocab2count
    if word in vocab2count:
        vocab2count[word] += 1
    else:
        vocab2count[word] = 1


def process_data(filename, update_vocab=True):
    """Read a one-token-per-line NER file into token and label-id sequences.

    Each non-blank line is split on tabs; the first field is taken as the
    token and the last field as its label. Blank lines end the current
    sequence, and any line containing ``"DOCSTART"`` is skipped. Labels are
    mapped to integer ids through the module-level ``labels2idx`` dict, which
    is extended in place (next free id) whenever an unseen label appears.

    The final sequence is kept even when the file does not end with a blank
    line.

    Args:
        filename: Path of the file to read; it is opened as UTF-8 text.
        update_vocab: When true, every token is also passed to
            ``updateVocab``.

    Returns:
        A ``(sequences, labels)`` tuple of two equally long lists.
        ``sequences[i]`` is the list of tokens of the i-th sequence and
        ``labels[i]`` the list of their label ids.
    """
    global labels2idx
    from itertools import chain

    print("\n\nOpening directory: {}\n\n".format(filename))

    sequences = []
    labels = []
    count = 0
    sequence = []
    sequence_label = []

    with open(filename, encoding="utf8") as reader:
        # The extra "" behaves like a closing blank line, so a last sequence
        # that is not followed by one in the file is still flushed.
        for raw_line in chain(reader, ("",)):
            if "DOCSTART" in raw_line:
                continue

            line = raw_line.strip()
            if not line:
                # Sequence boundary; consecutive blank lines are ignored.
                if sequence:
                    sequences.append(sequence)
                    labels.append(sequence_label)
                    sequence = []
                    sequence_label = []
                    count += 1
                    if count % 1000 == 0:
                        print("Processing Data # {}...".format(count))
                continue

            fields = line.split("\t")
            token = fields[0]
            label = fields[-1]

            # Assign ids to labels in order of first appearance.
            if label not in labels2idx:
                labels2idx[label] = len(labels2idx)

            sequence.append(token)
            sequence_label.append(labels2idx[label])

            if update_vocab:
                updateVocab(token)

    return sequences, labels


# Read every split. All calls use process_data's default update_vocab=True,
# so tokens from the dev and test files are counted in vocab2count as well.
train_sequences, train_labels = process_data(train_path)

dev_sequences, dev_labels = {}, {}
for key in dev_keys:
    split_sequences, split_labels = process_data(dev_path[key])
    dev_sequences[key] = split_sequences
    dev_labels[key] = split_labels

test_sequences, test_labels = {}, {}
for key in test_keys:
    split_sequences, split_labels = process_data(test_path[key])
    test_sequences[key] = split_sequences
    test_labels[key] = split_labels

# Report the number of sequences read per split.
print("train len: ", len(train_sequences))
print("dev len: ", len(dev_sequences["normal"]))
print("test len: ", len(test_sequences["normal"]))

# Keep only words whose count is strictly greater than MIN_FREQ; `counts`
# runs parallel to `vocab` and is used for ranking below.
vocab = []
counts = []
for token, freq in vocab2count.items():
    if freq <= MIN_FREQ:
        continue
    vocab.append(token)
    counts.append(freq)

# NOTE(review): presumably returns a word -> vector mapping restricted to the
# keys of `vocab` -- confirm against load_glove.
vocab2embed = load_glove(embedding_path, vocab=vocab2count, dim=WORDVECDIM)

# Order by descending count, drop words that have no entry in vocab2embed,
# then cap the vocabulary at MAX_VOCAB words.
sorted_idx = np.argsort(counts)[::-1]
vocab = [vocab[pos] for pos in sorted_idx if vocab[pos] in vocab2embed]
vocab = vocab[:MAX_VOCAB]

# Special tokens go last, so they receive the two highest indices.
vocab.extend(["<PAD>", "<UNK>"])

print(vocab)
print("vocab_size: ", len(vocab))

# <PAD> is all zeros; <UNK> is drawn uniformly from [-b, b) with b = sqrt(3 / dim).
vocab2embed["<PAD>"] = np.zeros((WORDVECDIM), np.float32)
b = math.sqrt(3 / WORDVECDIM)
vocab2embed["<UNK>"] = np.random.uniform(-b, +b, WORDVECDIM)

# Row i of `embeddings` is the vector of the word with index i.
vocab2idx = {token: position for position, token in enumerate(vocab)}
embeddings = [vocab2embed[token] for token in vocab]


def text_vectorize(text):
    """Convert a sequence of tokens into their indices in ``vocab2idx``.

    Args:
        text: Iterable of tokens.

    Returns:
        A list with one index per token, in order. Tokens that are not keys
        of the module-level ``vocab2idx`` get the index of ``'<UNK>'``.

    Raises:
        KeyError: If ``vocab2idx`` has no ``'<UNK>'`` entry.
    """
    # The fallback index does not change per token, so look it up only once.
    unk_id = vocab2idx['<UNK>']
    return [vocab2idx.get(word, unk_id) for word in text]


def vectorize_data(sequences, labels):
    """Bundle token sequences, their index form and their labels into one dict.

    Args:
        sequences: List of token lists.
        labels: List of label-id lists, aligned with ``sequences``.

    Returns:
        A dict with keys ``"sequence"`` (the input ``sequences``),
        ``"sequence_vec"`` (each sequence passed through ``text_vectorize``)
        and ``"label"`` (the input ``labels``).
    """
    sequences_vec = list(map(text_vectorize, sequences))

    # Sanity check: tokens, indices and labels must line up item by item.
    for position, token_ids in enumerate(sequences_vec):
        assert len(sequences[position]) == len(token_ids)
        assert len(labels[position]) == len(token_ids)

    return {
        "sequence": sequences,
        "sequence_vec": sequences_vec,
        "label": labels,
    }


# Vectorize every split.
train_data = vectorize_data(train_sequences, train_labels)
dev_data = {
    key: vectorize_data(dev_sequences[key], dev_labels[key])
    for key in dev_keys
}
test_data = {
    key: vectorize_data(test_sequences[key], test_labels[key])
    for key in test_keys
}

# Write one JSONL file per split: train first, then dev, then test.
jsonl_save(filepath=train_save_path, data_dict=train_data)
for split_keys, split_paths, split_data in (
        (dev_keys, dev_save_path, dev_data),
        (test_keys, test_save_path, test_data)):
    for key in split_keys:
        jsonl_save(filepath=split_paths[key], data_dict=split_data[key])

# Everything needed to interpret the vectorized files: the label and
# vocabulary index maps, the embedding matrix (one float32 row per vocabulary
# index) and the split keys.
metadata = {
    "labels2idx": labels2idx,
    "vocab2idx": vocab2idx,
    "embeddings": np.asarray(embeddings, dtype=np.float32),
    "dev_keys": dev_keys,
    "test_keys": test_keys,
}

with open(metadata_save_path, 'wb') as outfile:
    pickle.dump(metadata, outfile)
