import re
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from nltk.corpus import stopwords
import nltk

from cuml.cluster import KMeans as cuKMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import cupy as cp

# Ensure the NLTK English stop-word corpus is available. Only call
# nltk.download() when the corpus is actually missing: the downloader may
# contact the NLTK data server (and prints progress) even when the corpus is
# already installed, which would happen on every import of this module.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')
# Set of English stop words used by TextClustering.preprocess_text.
stop_words = set(stopwords.words('english'))

class TextClustering:
    """Cluster documents by averaged word embeddings and score against labels.

    Each document is tokenised, stripped of stop words and out-of-vocabulary
    tokens, and represented by the mean of its word vectors. The resulting
    matrix is clustered with k-means (scikit-learn on CPU, or cuML on GPU) and
    the clustering is compared with ground-truth labels using normalized
    mutual information (NMI) and the adjusted Rand index (ARI).
    """

    # Compiled once for the class instead of being resolved on every call.
    _NON_WORD_RE = re.compile(r'\W+')

    def __init__(self, word_vectors, *, stop_word_set=None):
        """Wrap a word-vector lookup.

        Args:
            word_vectors: Either a ``dict`` mapping word -> vector (e.g. GloVe
                vectors loaded into memory), or an object that exposes a
                ``vector_size`` attribute and supports ``word in obj`` and
                ``obj[word]`` (presumably Gensim ``KeyedVectors``).
            stop_word_set: Optional collection of tokens to drop during
                preprocessing. When omitted, the module-level NLTK English
                ``stop_words`` set is used.

        Raises:
            ValueError: If ``word_vectors`` is an empty dict, because the
                vector dimensionality cannot be inferred from it.
        """
        self.word_vectors = word_vectors
        # None means "use the module-level stop_words", resolved at call time.
        self._stop_words = None if stop_word_set is None else frozenset(stop_word_set)

        if isinstance(word_vectors, dict):
            if not word_vectors:
                raise ValueError(
                    "word_vectors dict is empty; cannot infer vector size"
                )
            # Dict-based vectors (e.g. GloVe): dimensionality of any entry.
            self.vector_size = len(next(iter(word_vectors.values())))
        else:
            # Model-style objects carry their dimensionality explicitly.
            self.vector_size = word_vectors.vector_size

    def preprocess_text(self, text):
        """Return the usable tokens of *text*.

        Runs of non-word characters are replaced by a single space, the text
        is lowercased and split on whitespace, and tokens that are stop words
        or absent from ``word_vectors`` are dropped. Token order and
        duplicates are preserved.
        """
        # getattr also covers instances built without __init__ (e.g. unpickled
        # objects created before this attribute existed).
        stops = getattr(self, '_stop_words', None)
        if stops is None:
            stops = stop_words
        cleaned = self._NON_WORD_RE.sub(' ', text)
        return [
            token
            for token in cleaned.lower().split()
            if token not in stops and token in self.word_vectors
        ]

    def get_document_embedding(self, doc):
        """Return the mean word vector of *doc*.

        Returns a float zero vector of length ``vector_size`` when no token of
        *doc* survives preprocessing.
        """
        tokens = self.preprocess_text(doc)
        if not tokens:
            return np.zeros(self.vector_size)
        # preprocess_text already removed out-of-vocabulary tokens, so every
        # lookup here succeeds without a second membership test.
        return np.mean([self.word_vectors[token] for token in tokens], axis=0)

    def compute_embeddings(self, documents):
        """Stack the embedding of every document into a 2-D array.

        Returns:
            An array with one row per document and ``vector_size`` columns;
            shape ``(0, vector_size)`` when *documents* is empty.
        """
        vectors = [self.get_document_embedding(doc) for doc in documents]
        if not vectors:
            # np.array([]) would be 1-D with shape (0,); keep the 2-D contract.
            return np.empty((0, self.vector_size))
        return np.array(vectors)

    def cluster_documents(self, embeddings, true_labels):
        """Run CPU k-means on *embeddings* and score it against *true_labels*.

        The number of clusters equals the number of distinct values in
        *true_labels*; ``random_state=42`` makes runs reproducible.

        Returns:
            ``(nmi, ari)``: normalized mutual information and adjusted Rand
            index between *true_labels* and the predicted cluster ids.
        """
        num_clusters = len(set(true_labels))  # one cluster per true category
        kmeans = KMeans(n_clusters=num_clusters, random_state=42)
        predicted_labels = kmeans.fit_predict(embeddings)

        nmi = normalized_mutual_info_score(true_labels, predicted_labels)
        ari = adjusted_rand_score(true_labels, predicted_labels)

        return nmi, ari

    def cluster_documents_gpu(self, embeddings, true_labels):
        """GPU variant of :meth:`cluster_documents` using cuML k-means.

        *embeddings* is copied to the GPU as a CuPy array if it is not one
        already. Predicted labels are moved back to host memory before being
        scored with scikit-learn. Progress messages are printed to stdout.

        Returns:
            ``(nmi, ari)`` as in :meth:`cluster_documents`.
        """
        num_clusters = len(set(true_labels))
        print(f"Number of clusters: {num_clusters}")

        # cuML works on device arrays; upload host data if necessary.
        if not isinstance(embeddings, cp.ndarray):
            embeddings = cp.asarray(embeddings)

        kmeans = cuKMeans(n_clusters=num_clusters, random_state=42)
        predicted_labels = kmeans.fit_predict(embeddings)

        # sklearn metrics need host data; CuPy arrays expose .get() for that.
        if hasattr(predicted_labels, "get"):
            predicted_labels_cpu = predicted_labels.get()
        else:
            predicted_labels_cpu = predicted_labels

        print("Clustering complete.")
        nmi = normalized_mutual_info_score(true_labels, predicted_labels_cpu)
        print("NMI Complete.")
        ari = adjusted_rand_score(true_labels, predicted_labels_cpu)
        print("ARI Complete.")

        return nmi, ari