from text_clustering import TextClustering
from text_clustering_tm import TextClusteringTM
import numpy as np

def get(model_base_path, embedding_model):
    """Build the clustering object backed by the requested embedding model.

    The model file is looked up at
    ``{model_base_path}/{embedding_model}.model``.

    Args:
        model_base_path: Directory prefix under which the model files live.
        embedding_model: Name of the embedding model to use; one of
            ``"word2vec"``, ``"fasttext"``, ``"glove"`` or ``"omnitm"``.

    Returns:
        A ``TextClustering`` wrapping the loaded word vectors for
        ``"word2vec"``, ``"fasttext"`` and ``"glove"``, or a
        ``TextClusteringTM`` constructed from the model path for
        ``"omnitm"``.

    Raises:
        ValueError: If ``embedding_model`` is not one of the supported names.
    """
    print(f"Embedding model: {embedding_model}")

    supported_models = ("word2vec", "fasttext", "glove", "omnitm")
    if embedding_model not in supported_models:
        # Fail with an explicit message instead of falling through every
        # branch and hitting an unbound local at the return statement.
        raise ValueError(
            f"Unsupported embedding model: {embedding_model!r}; "
            f"expected one of {', '.join(supported_models)}"
        )

    model_path = f"{model_base_path}/{embedding_model}.model"

    if embedding_model == "omnitm":
        # This backend receives the path itself rather than loaded vectors.
        return TextClusteringTM(model_path)

    if embedding_model == "word2vec":
        # gensim is imported lazily so it is only required when actually used.
        from gensim.models import Word2Vec
        word_vectors = Word2Vec.load(model_path).wv
    elif embedding_model == "fasttext":
        from gensim.models import FastText
        word_vectors = FastText.load(model_path).wv
    else:  # "glove"
        word_vectors = _load_glove_embeddings(model_path)

    return TextClustering(word_vectors)


def _load_glove_embeddings(file_path):
    """Parse a text embedding file into a ``{word: float32 vector}`` dict.

    Each non-blank line is split on whitespace; the first token is the word
    and the remaining tokens are converted to a ``float32`` NumPy array.
    """
    embeddings = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            if not values:
                # Blank (e.g. trailing) line: there is no word token to read.
                continue
            embeddings[values[0]] = np.asarray(values[1:], dtype='float32')
    return embeddings