from gensim.models import Word2Vec
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict

def build_embedding(tokenized_sentences, parameters, use_pretrained, DATA_DIR):
    """Return Word2Vec word vectors, either loaded from disk or freshly trained.

    Args:
        tokenized_sentences: Corpus passed to ``Word2Vec`` as ``sentences``.
            Ignored when ``use_pretrained`` is true.
        parameters: Dict supplying the keys ``'vector_size'``, ``'window'``,
            ``'min_count'``, ``'workers'``, ``'sg'`` and ``'epochs'``, each
            forwarded to ``Word2Vec`` under the same name. Ignored when
            ``use_pretrained`` is true.
        use_pretrained: When truthy, load ``{DATA_DIR}/word2vec.model`` with
            ``Word2Vec.load`` instead of training a new model.
        DATA_DIR: Directory holding the saved ``word2vec.model`` file.

    Returns:
        The ``wv`` attribute of the loaded or trained model.
    """
    if not use_pretrained:
        # Look the hyperparameters up in a fixed order; a missing key raises
        # KeyError before any training starts.
        hyperparams = {
            name: parameters[name]
            for name in ("vector_size", "window", "min_count", "workers", "sg", "epochs")
        }
        trained = Word2Vec(sentences=tokenized_sentences, **hyperparams)
        return trained.wv

    saved_model = Word2Vec.load(f"{DATA_DIR}/word2vec.model")
    return saved_model.wv

def get(word_vectors, parameters, target_words):
    """Rank every other target word by cosine similarity to each target word.

    Builds one row per entry of ``target_words``. The row holds the word's
    vector from ``word_vectors``, or all zeros when the word is not present.
    The function then computes the pairwise cosine similarity of those rows
    and records it for every ordered pair of distinct positions.

    Args:
        word_vectors: Object supporting ``word in word_vectors`` and
            ``word_vectors[word]`` (for example, the ``.wv`` returned by
            ``build_embedding``).
        parameters: Dict providing ``'vector_size'``, the length of each
            word vector.
        target_words: Indexable sequence of words to compare.

    Returns:
        A ``defaultdict(list)`` mapping ``(word, other_word)`` to the cosine
        similarity of the two rows. For each ``word``, keys are inserted in
        order of decreasing similarity, with ties kept in ``target_words``
        order. A word is never paired with its own position. An empty
        ``target_words`` yields an empty mapping.
    """
    vector_size = parameters['vector_size']
    # NOTE(review): stored values are floats; the list factory only affects
    # lookups of missing pairs (they return []). Kept for caller compatibility.
    target_similarity = defaultdict(list)

    n_words = len(target_words)
    if n_words == 0:
        # cosine_similarity rejects an array with zero rows.
        return target_similarity

    # Words absent from word_vectors keep an all-zero row.
    profile = np.zeros((n_words, vector_size))
    for i, word in enumerate(target_words):
        if word in word_vectors:
            profile[i, :] = word_vectors[word]

    similarity = cosine_similarity(profile)
    for i, word in enumerate(target_words):
        # Indices in descending similarity; a stable sort makes tie order
        # deterministic (lower index first).
        ranked = np.argsort(-similarity[i, :], kind="stable")
        for j in ranked:
            # Skip self by position rather than assuming it ranks first.
            # An all-zero row scores 0 against everything, so "self" is not
            # necessarily at rank 0. Identical vectors can also tie with or
            # outrank self.
            if j == i:
                continue
            target_similarity[(word, target_words[j])] = similarity[i, j]

    return target_similarity