'''
- word2vec.py
- This file handles the loading and distance measuring of word2vec embeddings
'''

# External imports
import numpy as np
from gensim import models
from gensim.models import Word2Vec
import nltk
from nltk.tokenize import word_tokenize

# Internal imports


'''
----------load_embedding----------
- This function downloads the NLTK 'punkt' data and loads a binary
  word2vec-format embedding file
-----Inputs-----
- file_location - the location of the binary word2vec embedding file to use
-----Output-----
- embedding - a dictionary whose "embedding" key holds the loaded word vectors
'''
def load_embedding(file_location):
    '''
    Request the NLTK 'punkt' data, then read a binary word2vec-format file.

    -----Inputs-----
    - file_location - path of the binary word2vec-format file to load
    -----Output-----
    - a dict with a single "embedding" key holding the loaded vectors
    '''
    # The download is requested on every call, before the vectors are read;
    # presumably it is needed by word_tokenize elsewhere in this module.
    nltk.download('punkt')
    keyed_vectors = models.KeyedVectors.load_word2vec_format(
        file_location,
        binary=True,
    )
    return {"embedding": keyed_vectors}


'''
----------get_avg_text_vector----------
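- This function computes the average embedding of the words in a text that
  are present in the embedding's vocabulary
-----Inputs-----
- text - the text whose average embedding is to be retrieved
- embedding - the loaded word vectors used to look up each word
-----Output-----
- feature_vec - the averaged word vector (all zeros if no word is in the vocabulary)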
'''
def get_avg_text_vector(text, embedding):
    '''
    Average the vectors of every token of `text` that `embedding` knows.

    -----Inputs-----
    - text - the string to split into tokens with word_tokenize
    - embedding - the word-vector object; it must expose vector_size and
      key_to_index and support lookup by word (embedding[word])
    -----Output-----
    - a vector of length embedding.vector_size: the mean of the matched word
      vectors, or the untouched zero vector when no token is matched
    '''
    tokens = word_tokenize(text)
    # The accumulator starts as float32 zeros sized to the embedding.
    running_sum = np.zeros((embedding.vector_size,), dtype="float32")
    matched = 0
    for token in tokens:
        # Out-of-vocabulary tokens contribute neither to the sum nor the count.
        if token not in embedding.key_to_index:
            continue
        matched += 1
        running_sum = np.add(running_sum, embedding[token])
    # Skip the division entirely when nothing matched (avoids dividing by 0).
    if matched == 0:
        return running_sum
    return np.divide(running_sum, matched)


'''
----------load_finetuned_embedding----------
- This function loads in the current feature vector set to work with
-----Inputs-----
- file_location - the location of the fasttext embedding file to use
-----Output-----
- embedding - the dictionary of embeddings
'''


'''
----------get_finetuned_avg_text_vector----------
- This function loads in the current feature vector set to work with
-----Inputs-----
- embedding - the name of the embedding to use
- schema - the currently-active schema
-----Output-----
- embedding - the dictionary of embeddings
'''