from tqdm import tqdm
import random
import numpy as np
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two vectors.

    Intended for 1-D vectors of equal length, for which ``dot`` is the
    inner product.

    Args:
        vec1: First vector (any array-like accepted by NumPy).
        vec2: Second vector (any array-like accepted by NumPy).

    Returns:
        ``dot(vec1, vec2) / (norm(vec1) * norm(vec2))``, or ``0.0`` when
        either vector has zero norm and the cosine is therefore undefined.
    """
    denominator = norm(vec1) * norm(vec2)
    if denominator == 0:
        # A zero vector has no direction; dividing would yield NaN (plus a
        # RuntimeWarning), and NaN scores cannot be sorted reliably.
        return 0.0
    return dot(vec1, vec2) / denominator

def most_similar_glove(word, glove_vectors, topn=5):
    """Rank every other vocabulary word by cosine similarity to ``word``.

    Args:
        word: Query token.
        glove_vectors: Mapping from token to its embedding vector.
        topn: Maximum number of entries in each returned list.

    Returns:
        A ``(most_similar, most_dissimilar)`` tuple. Each element is a list
        of ``(token, similarity)`` pairs ordered by descending similarity,
        so the least similar token is the last entry of ``most_dissimilar``.
        Both lists are empty when ``word`` is not in ``glove_vectors`` or
        ``topn`` is not positive. When the vocabulary holds fewer than
        ``2 * topn`` other words, the two lists overlap.
    """
    # Always return a 2-tuple so callers can unpack the result on every
    # path. The ``topn`` guard matters because ``seq[-0:]`` is the whole
    # sequence, not an empty slice.
    if word not in glove_vectors or topn <= 0:
        return [], []

    word_vector = glove_vectors[word]
    similarities = [
        (other_word, cosine_similarity(word_vector, other_vector))
        for other_word, other_vector in glove_vectors.items()
        if other_word != word
    ]
    # Stable sort: tokens with equal scores keep their vocabulary order.
    similarities.sort(key=lambda item: item[1], reverse=True)

    most_similar_words = similarities[:topn]
    most_dissimilar_words = similarities[-topn:]
    return most_similar_words, most_dissimilar_words

def _load_glove_vectors(model_path):
    """Parse a GloVe text file into a ``{token: float32 vector}`` dict."""
    glove_vectors = {}
    with open(model_path, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            if not values:
                # Blank line (e.g. at the end of the file): nothing to parse.
                continue
            glove_vectors[values[0]] = np.asarray(values[1:], dtype='float32')
    return glove_vectors


def _pick_replacement(word, label, glove_vectors, neighbor_cache, topn=5):
    """Pick a random replacement for ``word``: a neighbour if label is 1, else a far word."""
    if word not in neighbor_cache:
        # Ranking scans the whole vocabulary, so compute it once per
        # distinct word and reuse it for every later occurrence.
        neighbor_cache[word] = most_similar_glove(word, glove_vectors, topn=topn)
    similar_words, dissimilar_words = neighbor_cache[word]
    candidates = similar_words if label == 1 else dissimilar_words
    if not candidates:
        return word
    return random.choice(candidates)[0]


def perturb(X_train_text, y_train, model_path, percent_to_change=5):
    """Replace a random share of each document's words using GloVe neighbours.

    For each document, ``percent_to_change`` percent of the whitespace
    separated tokens (rounded down) are selected at random. A selected token
    that exists in the GloVe vocabulary is replaced by one of its 5 most
    similar words when the document's label equals 1, and by one of its 5
    least similar words otherwise. Selected tokens missing from the
    vocabulary, and all unselected tokens, are kept unchanged.

    Args:
        X_train_text: Iterable of document strings.
        y_train: Labels, indexed by the position of the document in
            ``X_train_text``.
        model_path: Path to a GloVe vectors file in text format (one token
            followed by its vector components per line).
        percent_to_change: Percentage of tokens to select per document;
            values outside ``[0, 100]`` are clamped to that range.

    Returns:
        A list with one perturbed document string per input document, with
        tokens joined by single spaces.
    """
    glove_vectors = _load_glove_vectors(model_path)

    # Clamp so random.sample is never asked for a negative count or for
    # more indices than the document has tokens.
    fraction = min(max(percent_to_change, 0), 100) / 100
    neighbor_cache = {}

    X_train_perturbed = []
    for r, doc in enumerate(tqdm(X_train_text)):
        label = y_train[r]
        tokens = doc.split()
        num_words_to_change = int(len(tokens) * fraction)
        indices_to_change = set(random.sample(range(len(tokens)), num_words_to_change))

        # No early exit: every token must be emitted, otherwise the text
        # after the last replaced word would be dropped from the output.
        new_tokens = []
        for i, word in enumerate(tokens):
            if i in indices_to_change and word in glove_vectors:
                new_tokens.append(
                    _pick_replacement(word, label, glove_vectors, neighbor_cache)
                )
            else:
                new_tokens.append(word)

        X_train_perturbed.append(' '.join(new_tokens))

    return X_train_perturbed