from tmu.models.autoencoder.autoencoder import TMAutoEncoder
from time import time
import numpy as np
import os
import tqdm
        
def _clause_literal_embedding(tm, number_of_clauses, number_of_features):
    """Aggregate the TA states of positively weighted clauses into one vector.

    For every clause ``j`` whose weight (``tm.get_weights(0)[j]``) is positive:

    * literal ``i < number_of_features`` adds its TA state to feature ``i``;
    * literal ``i >= number_of_features`` subtracts its TA state from feature
      ``i - number_of_features``.

    Each contribution also increments that feature's count. The signed sum
    is divided by the count and truncated toward zero (``astype(int)``).

    NOTE(review): this assumes ``tm.clause_bank.number_of_literals`` is
    ``2 * number_of_features`` (plain + negated literals). Under that
    assumption every feature's divisor equals
    ``2 * <number of positively weighted clauses>``.

    Returns a float array of length ``number_of_features``. It is all zeros
    when no clause has a positive weight.
    """
    clause_weights = tm.get_weights(0)
    # Loop-invariant; hoisted out of the per-clause loop.
    number_of_literals = tm.clause_bank.number_of_literals

    literal_sums = np.zeros(number_of_features)
    literal_counts = np.zeros(number_of_features)
    for j in range(number_of_clauses):
        if clause_weights[j] <= 0:
            continue
        for i in range(number_of_literals):
            ta_state = tm.get_ta_state(j, i, the_class=1, polarity=1)
            if i < number_of_features:
                literal_sums[i] += ta_state
                literal_counts[i] += 1
            else:
                literal_sums[i - number_of_features] -= ta_state
                literal_counts[i - number_of_features] += 1

    embedding = np.zeros(number_of_features)
    has_counts = literal_counts > 0
    embedding[has_counts] = (
        literal_sums[has_counts] / literal_counts[has_counts]
    ).astype(int)
    return embedding


def generate(X_train, vectorizer_X, parameters, dataset_name):
    """Train one TM autoencoder per vocabulary word and save the embeddings.

    For each word id in ``vectorizer_X.vocabulary_``:

    * a ``TMAutoEncoder`` is trained on ``X_train`` with that single word as
      its active output;
    * the TA states of its positively weighted clauses are aggregated into a
      vector of length ``len(vocabulary)``.

    The result is stored in a dict keyed by word id. That dict is written
    with ``np.save`` as a pickled object array, so it must be read back with
    ``np.load(..., allow_pickle=True).item()``.

    Args:
        X_train: Training data passed unchanged to ``TMAutoEncoder.fit``.
        vectorizer_X: Object exposing ``vocabulary_`` (word -> column id).
        parameters: Dict that must provide the keys ``'clauses'``, ``'T'``,
            ``'s'``, ``'accumulation'``, ``'epochs'`` and
            ``'number_of_examples'``.
        dataset_name: Subdirectory of ``data/`` that receives the output.

    The file is written to ``data/<dataset_name>/omnitm.model.npy``. The
    directory is created if it is missing.
    """
    vocabulary = vectorizer_X.vocabulary_
    number_of_features = len(vocabulary)

    corpus_name = "omnitm"
    model_path = os.path.join("data", dataset_name, f"{corpus_name}.model")
    # Create the output directory before the (expensive) training loop so a
    # missing folder does not discard all the work at save time.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    all_embeddings = {}
    for word_id in tqdm.tqdm(vocabulary.values(), desc="Generating embeddings", unit="word"):
        tm = TMAutoEncoder(
            number_of_clauses=parameters['clauses'],
            T=parameters['T'],
            s=parameters['s'],
            output_active=np.array([word_id], dtype=np.uint32),
            max_included_literals=3,
            accumulation=parameters['accumulation'],
            feature_negation=True,
            platform='CPU',
            output_balancing=0.5
        )

        for _ in range(parameters['epochs']):
            tm.fit(X_train, number_of_examples=parameters['number_of_examples'])

        all_embeddings[word_id] = _clause_literal_embedding(
            tm, parameters['clauses'], number_of_features
        )

    np.save(model_path, all_embeddings)
    # np.save appends ".npy" to a str path that does not already end with
    # it, so report the file that actually exists on disk.
    saved_path = model_path if model_path.endswith(".npy") else f"{model_path}.npy"
    print(f"Model saved to {saved_path}")