import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import random
from tqdm import tqdm

# Load the ELMo v3 module from TF Hub once, at import time. The helpers below
# call its 'default' signature through this module-level handle.
elmo_model = hub.load("https://tfhub.dev/google/elmo/3")

def get_elmo_embeddings(sentences):
    """Run ELMo's ``default`` signature on a batch of strings.

    Args:
        sentences: Sequence of strings. Each string becomes one element of the
            input batch handed to the TF Hub signature.

    Returns:
        The ``"elmo"`` output tensor of the signature. Per the TF Hub ELMo
        module docs, its shape is ``[batch_size, max_length, 1024]``.
    """
    elmo_fn = elmo_model.signatures['default']
    batch = tf.constant(sentences)
    outputs = elmo_fn(batch)
    return outputs["elmo"]
    
def get_word_embeddings(doc):
    """Map each whitespace-separated token of *doc* to an ELMo embedding.

    Every distinct token is sent to ``get_elmo_embeddings`` as its own
    one-word "sentence". Each value in the returned dict is therefore the
    token's row of the batched ELMo output, with the padded sequence axis
    kept.

    Args:
        doc: Document string; it is tokenized with ``str.split()``.

    Returns:
        Dict mapping each distinct token to its embedding array, in
        first-occurrence order. Returns an empty dict if *doc* has no tokens.
    """
    # NOTE(review): each token is embedded in isolation, so these vectors carry
    # no sentence context. Embedding ``[doc]`` and indexing by position would
    # give contextual vectors, but it would change the value shapes callers see.
    tokens = doc.split()
    if not tokens:
        # tf.constant([]) is a float tensor, which the string-input ELMo
        # signature cannot accept, so short-circuit empty documents.
        return {}
    # Identical one-word inputs should yield identical rows, and the dict keeps
    # only one vector per word anyway, so embed each distinct token once.
    unique_tokens = list(dict.fromkeys(tokens))
    embeddings = get_elmo_embeddings(unique_tokens).numpy()
    return dict(zip(unique_tokens, embeddings))

def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Singleton dimensions are squeezed away first, so e.g. ``(1, 1024)``
    inputs are compared as ``(1024,)`` vectors.

    Args:
        vec1: Array-like vector, possibly with extra length-1 axes.
        vec2: Array-like vector of the same length after squeezing.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``. If either vector has zero
        norm, returns 0.0 instead of NaN. A NaN score would make any
        sort-based similarity ranking ill-defined.
    """
    vec1 = np.squeeze(vec1)
    vec2 = np.squeeze(vec2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    return np.dot(vec1, vec2) / norm_product

def _choose_replacement(index, tokens, word_embeddings, label, top_k=5):
    """Pick a replacement for ``tokens[index]`` among the document's other tokens.

    Other tokens are ranked by cosine similarity to the target token's
    embedding. The ranking is most-similar first when ``label == 1`` and
    least-similar first otherwise. One of the ``top_k`` best-ranked words is
    returned at random. If the document has no other tokens, the original
    word is kept.
    """
    target = word_embeddings[tokens[index]]
    scored = [
        (cosine_similarity(target, word_embeddings[other]), other)
        for j, other in enumerate(tokens)
        if j != index  # do not compare a position with itself
    ]
    if not scored:
        # Single-token document: nothing to swap with (random.choice([]) would raise).
        return tokens[index]
    scored.sort(key=lambda pair: pair[0], reverse=(label == 1))
    return random.choice([word for _, word in scored[:top_k]])


def perturb(X_train_text, y_train, model_path, percent_to_change=5):
    """Randomly replace a fraction of each document's words with other words from it.

    For each document, ``int(len(tokens) * percent_to_change / 100)`` token
    positions are sampled without replacement. Each sampled token is replaced
    by a word drawn at random from the 5 other tokens whose ELMo embeddings
    are most similar to it (label 1) or least similar to it (any other
    label). All unsampled tokens are kept in place.

    Args:
        X_train_text: Iterable of document strings, tokenized with ``str.split()``.
        y_train: Labels, aligned positionally with ``X_train_text``.
        model_path: Unused; kept for backward compatibility.
        percent_to_change: Percentage of tokens per document to replace.

    Returns:
        List of perturbed documents, with tokens re-joined by single spaces.

    Raises:
        ValueError: From ``random.sample`` if the computed count is negative
            or exceeds the token count (e.g. ``percent_to_change`` > 100).
        IndexError: If ``y_train`` has fewer labels than there are documents.
    """
    X_train_perturbed = []
    # Positional labels: ``y_train[r]`` would be a label-based lookup on a
    # pandas Series with a non-default index (e.g. after a shuffled split).
    labels = list(y_train)

    for r, doc in enumerate(tqdm(X_train_text)):
        label = labels[r]
        tokens = doc.split()
        num_words_to_change = int(len(tokens) * (percent_to_change / 100))
        indices_to_change = set(random.sample(range(len(tokens)), num_words_to_change))

        if not indices_to_change:
            # Nothing to replace: skip the expensive ELMo call (and never embed empty docs).
            X_train_perturbed.append(' '.join(tokens))
            continue

        word_embeddings = get_word_embeddings(doc)
        # No early ``break``: the sampled index set already bounds the number of
        # replacements, and breaking out dropped every token after the last change.
        new_tokens = [
            _choose_replacement(i, tokens, word_embeddings, label)
            if i in indices_to_change
            else word
            for i, word in enumerate(tokens)
        ]
        X_train_perturbed.append(' '.join(new_tokens))

    return X_train_perturbed