from tqdm import tqdm
import random
import pickle
from tqdm import tqdm
import os

def precompute_alternatives(vectorizer_X, token_embeddings, top_k=5):
    """Precompute one substitute feature per vocabulary word and label.

    For every ``(word, index)`` in ``vectorizer_X.vocabulary_`` that has an entry
    in ``token_embeddings`` (looked up by ``index``), the entry is converted to a
    NumPy array whose positions are used as indices into
    ``vectorizer_X.get_feature_names_out()``. The substitute for label 1 is drawn
    uniformly at random from the ``top_k`` features with the largest values; the
    substitute for label 0 from the ``top_k`` features with the smallest values.

    Args:
        vectorizer_X: Fitted vectorizer exposing ``vocabulary_`` (word -> index)
            and ``get_feature_names_out()``.
        token_embeddings: Mapping from vocabulary index to a score vector.
            NOTE(review): assumed 1-D and aligned with the vectorizer's feature
            indices, since its positions are used to index feature names.
        top_k: How many extreme features to sample from per label. Defaults to
            5, the value previously hard-coded. Capped at the embedding length.

    Returns:
        dict mapping word -> ``{1: feature_name, 0: feature_name}``. Words with
        no embedding, or an empty one, are omitted.

    Raises:
        ValueError: if ``top_k`` is less than 1 (nothing to sample from).
    """
    import random
    import numpy as np

    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Hoisted: this returns the full feature-name array, so calling it per word
    # (twice) made the loop quadratic in vocabulary size.
    feature_names = vectorizer_X.get_feature_names_out()

    def _top_indices(scores, k, descending):
        """Return indices of the k most extreme entries of ``scores``, most extreme first."""
        keys = -scores if descending else scores
        if k < keys.shape[0]:
            # argpartition requires kth < len; the k smallest keys land in the first k slots.
            candidates = np.argpartition(keys, k - 1)[:k]
        else:
            candidates = np.arange(keys.shape[0])
        return candidates[np.argsort(keys[candidates])]

    alternatives_cache = {}

    for word, index in tqdm(vectorizer_X.vocabulary_.items(), desc="Precomputing alternatives"):
        embedding = token_embeddings.get(index, None)
        if embedding is None:
            continue

        # Convert embedding to a NumPy array for vectorized selection.
        embedding = np.array(embedding)
        if embedding.size == 0:
            continue
        k = min(top_k, len(embedding))

        # Label 1: features with the largest values.
        top_features_label_1 = [feature_names[i] for i in _top_indices(embedding, k, descending=True)]
        selected_feature_label_1 = random.choice(top_features_label_1)

        # Label 0: features with the smallest values.
        top_features_label_0 = [feature_names[i] for i in _top_indices(embedding, k, descending=False)]
        selected_feature_label_0 = random.choice(top_features_label_0)

        # One fixed choice per word and label: every occurrence of this word
        # later maps to the same substitute.
        alternatives_cache[word] = {
            1: selected_feature_label_1,
            0: selected_feature_label_0,
        }

    return alternatives_cache
        
def perturb(X_train_text, y_train, model_path, percent_to_change=5):
    """Replace a random fraction of each document's tokens with label-specific alternatives.

    Loads ``token_embeddings`` from ``model_path`` and ``vectorizer_X.pickle``
    from the same directory. It then loads ``omnitm_alternatives_cache.pickle``
    from that directory, or builds it with ``precompute_alternatives`` and writes
    it there if it is missing. For each document, ``int(len(tokens) *
    percent_to_change / 100)`` distinct whitespace-token positions are sampled
    (capped at the token count). Each sampled token that has a cached
    alternative for the document's label is replaced by that alternative.

    Args:
        X_train_text: Iterable of documents (strings); split on whitespace.
        y_train: Labels, indexed positionally as ``y_train[r]`` for document ``r``.
            The cache stores alternatives under keys ``1`` and ``0``.
        model_path: Path to the pickled token-embedding mapping.
        percent_to_change: Percentage of tokens per document to sample for
            replacement.

    Returns:
        List of perturbed documents, re-joined with single spaces. Returns ``[]``
        (after printing a message) if the model or vectorizer file is missing.
    """
    X_train_perturbed = []
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        return X_train_perturbed
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(model_path, 'rb') as f:
        token_embeddings = pickle.load(f)

    parent_dir = os.path.dirname(model_path)
    vectorize_path = os.path.join(parent_dir, "vectorizer_X.pickle")
    if not os.path.exists(vectorize_path):
        print(f"Vectorizer file not found: {vectorize_path}")
        return X_train_perturbed
    with open(vectorize_path, 'rb') as f:
        vectorizer_X = pickle.load(f)

    # NOTE(review): the cache is keyed only by directory, so it is not rebuilt
    # when the model or vectorizer pickle changes. Delete it manually in that case.
    alternatives_cache_path = os.path.join(parent_dir, 'omnitm_alternatives_cache.pickle')
    if not os.path.exists(alternatives_cache_path):
        print("Precomputing alternatives for all words...")
        alternatives_cache = precompute_alternatives(vectorizer_X, token_embeddings)
        with open(alternatives_cache_path, 'wb') as f:
            pickle.dump(alternatives_cache, f)
    else:
        with open(alternatives_cache_path, 'rb') as f:
            alternatives_cache = pickle.load(f)

    for r, doc in enumerate(tqdm(X_train_text)):
        # NOTE(review): y_train[r] is label-based for a pandas Series; assumes a
        # default RangeIndex or a plain sequence.
        label = y_train[r]
        tokens = doc.split()
        # Clamp so percent_to_change > 100 does not make random.sample raise.
        num_words_to_change = min(len(tokens), int(len(tokens) * (percent_to_change / 100)))
        # Distinct positions, so each replacement reads the original token.
        for i in random.sample(range(len(tokens)), num_words_to_change):
            alternatives = alternatives_cache.get(tokens[i])
            if alternatives is not None and label in alternatives:
                tokens[i] = alternatives[label]

        X_train_perturbed.append(' '.join(tokens))

    return X_train_perturbed