import tqdm
import numpy as np
from collections import defaultdict
import pickle

def build_embedding(X_train, vectorizer_X, parameters, words, use_pretrained, DATA_DIR):
    """Return word embeddings, either loaded from disk or trained per word.

    When *use_pretrained* is true, the object pickled at
    ``{DATA_DIR}/omnitm.model`` is returned unchanged. Otherwise one
    ``TMAutoEncoder`` is trained for each usable word in *words*, with that
    word's vocabulary id as its single active output. The word's embedding is
    then derived from the TA states of every clause with a positive weight.

    Args:
        X_train: Training data passed to ``TMAutoEncoder.fit``.
        vectorizer_X: Fitted vectorizer exposing a ``vocabulary_`` dict that
            maps words to ids.
        parameters: Dict with required keys ``'clauses'``, ``'T'``, ``'s'``,
            ``'accumulation'``, ``'epochs'`` and ``'number_of_examples'``.
            Optional keys ``'max_included_literals'`` (default 3),
            ``'platform'`` (default ``'CPU'``) and ``'output_balancing'``
            (default 0.5) override the autoencoder settings.
        words: Words to embed. Entries that are lists, or that are missing from
            the vocabulary, are skipped with a printed message.
        use_pretrained: If true, load the pickled model instead of training.
        DATA_DIR: Directory containing ``omnitm.model``.

    Returns:
        The unpickled object when *use_pretrained* is true. Otherwise a dict
        mapping word id to a NumPy array of length ``len(vocabulary)``, whose
        entries are per-feature TA-state averages truncated to integers.
    """
    if use_pretrained:
        model_path = f"{DATA_DIR}/omnitm.model"
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only point DATA_DIR at trusted model files.
        with open(model_path, 'rb') as f:
            return pickle.load(f)

    # Imported lazily so the pretrained path does not require tmu.
    from tmu.models.autoencoder.autoencoder import TMAutoEncoder

    vocabulary = vectorizer_X.vocabulary_
    number_of_features = len(vocabulary)
    number_of_clauses = parameters['clauses']
    max_included_literals = parameters.get('max_included_literals', 3)
    platform = parameters.get('platform', 'CPU')
    output_balancing = parameters.get('output_balancing', 0.5)

    # Resolve each requested word to its vocabulary id, skipping unusable ones.
    word_list = []
    for word in words:
        # Ensure word is a string, not a list
        if isinstance(word, list):
            print(f"Warning: word '{word}' is a list, skipping.")
            continue
        word_id = vocabulary.get(word)
        if word_id is None:
            print(f"Word '{word}' not found in vocabulary, skipping.")
            continue
        word_list.append((word, word_id))

    all_embeddings = {}
    for _, word_id in tqdm.tqdm(word_list, desc="Generating embeddings", unit="word"):
        # The autoencoder is trained to reconstruct only this word.
        single_output_active = np.empty(1, dtype=np.uint32)
        single_output_active[0] = word_id

        tm = TMAutoEncoder(
            number_of_clauses=number_of_clauses,
            T=parameters['T'],
            s=parameters['s'],
            output_active=single_output_active,
            max_included_literals=max_included_literals,
            accumulation=parameters['accumulation'],
            feature_negation=True,
            platform=platform,
            output_balancing=output_balancing,
        )

        for _ in range(parameters['epochs']):
            tm.fit(X_train, number_of_examples=parameters['number_of_examples'])
        clauses_weights = tm.get_weights(0)

        literal_sums = np.zeros(number_of_features)
        literal_counts = np.zeros(number_of_features)
        number_of_literals = tm.clause_bank.number_of_literals
        for j in range(number_of_clauses):
            # Only clauses with a positive weight contribute to the embedding.
            if clauses_weights[j] > 0:
                for i in range(number_of_literals):
                    ta_state = tm.get_ta_state(j, i, the_class=1, polarity=1)
                    if i < number_of_features:
                        literal_sums[i] += ta_state
                        literal_counts[i] += 1
                    else:
                        # Literals past the feature count are presumably the
                        # negated copies (feature_negation=True) of feature
                        # i - number_of_features; they subtract from that feature.
                        literal_sums[i - number_of_features] -= ta_state
                        literal_counts[i - number_of_features] += 1

        # Average the accumulated states for features that were seen, truncating
        # each average to int; unseen features stay 0.
        seen = literal_counts > 0
        embedding = np.zeros(number_of_features)
        embedding[seen] = (literal_sums[seen] / literal_counts[seen]).astype(int)
        all_embeddings[word_id] = embedding

    print(f"Total embeddings built: {len(all_embeddings)}")
    return all_embeddings

def get(word_vectors, vectorizer_X, target_words):
    """Score every ordered pair of distinct target-word positions.

    The score for ``(word1, word2)`` is the component of ``word1``'s embedding
    at ``word2``'s vocabulary index.

    Args:
        word_vectors: Mapping from vocabulary id to embedding vector, such as
            the dict returned by ``build_embedding``.
        vectorizer_X: Fitted vectorizer exposing a ``vocabulary_`` dict that
            maps words to ids.
        target_words: Sized, re-iterable sequence of words to compare.

    Returns:
        A ``defaultdict(list)`` keyed by ``(word1, word2)`` tuples. A word that
        is missing from the vocabulary or from *word_vectors* is given an
        all-zero embedding. A ``word2`` missing from the vocabulary scores
        ``0.0``. If *target_words* is empty, the result is empty.
    """
    target_similarity = defaultdict(list)
    vocabulary = vectorizer_X.vocabulary_
    vocabulary_size = len(vocabulary)

    # Every word below receives an embedding (zeros as fallback), so the only
    # "nothing to compare" case is empty input. len() is used rather than
    # truthiness so that NumPy arrays of words are also accepted.
    if len(target_words) == 0:
        print("No token embeddings found for the target words.")
        return target_similarity

    # Prepare token_embeddings: word -> embedding (using word id)
    token_embeddings = {}
    for word in tqdm.tqdm(target_words, desc="Loading token embeddings"):
        word_id = vocabulary.get(word, None)
        if word_id is not None and word_id in word_vectors:
            token_embeddings[word] = word_vectors[word_id]
        else:
            print(f"Word '{word}' not found in vocabulary or word_vectors.")
            token_embeddings[word] = np.zeros(vocabulary_size)

    # One float64 row per target position. The assignment raises if an
    # embedding's length differs from the vocabulary size.
    profile = np.empty((len(target_words), vocabulary_size))
    for i, word in enumerate(target_words):
        profile[i, :] = token_embeddings[word]

    # Resolve each word's column once, rather than once per pair inside the
    # quadratic loop.
    column_indices = [vocabulary.get(word, None) for word in target_words]

    for i, word1 in enumerate(target_words):
        for j, (word2, word2_index) in enumerate(zip(target_words, column_indices)):
            if i == j:
                continue
            if word2_index is None:
                target_similarity[(word1, word2)] = 0.0
            else:
                target_similarity[(word1, word2)] = profile[i, word2_index]
    return target_similarity