import os
import re
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from nltk.corpus import stopwords
import nltk
import tqdm
import pickle
import sys
from cuml.cluster import KMeans as cuKMeans
import cupy as cp

# Project root = parent of the current working directory; it is appended to the
# import path so project-local modules can be imported from this script.
env_path = os.getcwd()
BASE_PATH = os.path.abspath(os.path.join(env_path, '../'))
print(f"Base path: {BASE_PATH}")
sys.path.extend([BASE_PATH])

# Download the NLTK stopword corpus, then build the English stopword set that
# TextClusteringTM.preprocess_text uses to drop tokens.
nltk.download('stopwords')
stop_words = {word for word in stopwords.words('english')}

class TextClusteringTM:
    """Cluster documents using precomputed per-token embeddings.

    A document's embedding is the mean of the embeddings of its
    in-vocabulary, non-stopword tokens. Clustering quality is reported as
    (NMI, ARI) against ground-truth labels, using either scikit-learn KMeans
    (CPU) or cuML KMeans (GPU).
    """

    # Runs of non-word characters are treated as token separators; compiled once.
    _NON_WORD = re.compile(r'\W+')

    def __init__(self, model_path):
        """Load the token-embedding table and the fitted vectorizer next to it.

        Args:
            model_path: Path to a pickle holding the token embeddings. A file
                named ``vectorizer_X.pickle`` must sit in the same directory.

        Raises:
            FileNotFoundError: If the model pickle or the vectorizer pickle is missing.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        # NOTE(review): presumably a mapping {vocabulary index -> embedding vector};
        # get_word_embedding relies on `in` / `[]` with that index -- confirm type.
        self.token_embeddings = self._load_pickle(model_path)

        parent_dir = os.path.dirname(model_path)
        vectorize_path = os.path.join(parent_dir, "vectorizer_X.pickle")
        if not os.path.exists(vectorize_path):
            raise FileNotFoundError(f"Vectorizer file not found: {vectorize_path}")
        # Only the vectorizer's `vocabulary_` (word -> index dict) is used by this class.
        self.vectorizer_X = self._load_pickle(vectorize_path)
        self.vocabulary_size = len(self.vectorizer_X.vocabulary_)

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at *path*, closing the file afterwards."""
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    def preprocess_text(self, text):
        """Tokenize *text* into lowercase words kept for embedding.

        Non-word characters act as separators. Tokens that are English
        stopwords or absent from the vectorizer vocabulary are dropped.
        """
        vocabulary = self.vectorizer_X.vocabulary_  # hoisted: looked up once per call
        tokens = self._NON_WORD.sub(' ', text).lower().split()
        return [word for word in tokens if word not in stop_words and word in vocabulary]

    def get_word_embedding(self, word):
        """Retrieve the precomputed embedding for a given word.

        Returns a zero vector of length ``vocabulary_size`` when the word is
        not in the vocabulary or has no entry in ``token_embeddings``.
        """
        idx = self.vectorizer_X.vocabulary_.get(word)
        # NOTE(review): `idx not in ...` is a key test only if token_embeddings is a
        # dict; for a list/ndarray it would test values instead -- confirm the type.
        if idx is None or idx not in self.token_embeddings:
            # NOTE(review): fallback length is vocabulary_size; it only stacks with real
            # embeddings if those are also vocabulary_size long -- confirm.
            return np.zeros(self.vocabulary_size)
        return self.token_embeddings[idx]

    def get_document_embedding_with_cache(self, doc):
        """Return the mean token embedding of *doc*.

        The "cache" is the precomputed embedding table loaded in ``__init__``.
        No further caching happens here. Documents with no surviving tokens
        map to a zero vector of length ``vocabulary_size``.
        """
        tokens = self.preprocess_text(doc)
        if not tokens:
            return np.zeros(self.vocabulary_size)
        return np.mean([self.get_word_embedding(token) for token in tokens], axis=0)

    def compute_embeddings(self, documents):
        """Embed every document and stack the results into one numpy array (one row per document)."""
        print("Loading precomputed token embeddings from cache...")
        embeddings = [
            self.get_document_embedding_with_cache(doc)
            for doc in tqdm.tqdm(documents, desc="Computing document embeddings")
        ]
        return np.array(embeddings)

    @staticmethod
    def _score_clustering(true_labels, predicted_labels):
        """Print progress messages and return ``(nmi, ari)`` of predicted vs. true labels."""
        print("Clustering complete.")
        nmi = normalized_mutual_info_score(true_labels, predicted_labels)
        print("NMI Complete.")
        ari = adjusted_rand_score(true_labels, predicted_labels)
        print("ARI Complete.")
        return nmi, ari

    def cluster_documents(self, embeddings, true_labels, *, random_state=42):
        """Cluster on CPU with scikit-learn KMeans and score against *true_labels*.

        The number of clusters equals the number of distinct true labels.

        Returns:
            Tuple ``(nmi, ari)``.
        """
        num_clusters = len(set(true_labels))  # Number of true categories
        print(f"Number of clusters: {num_clusters}")
        kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
        predicted_labels = kmeans.fit_predict(embeddings)
        return self._score_clustering(true_labels, predicted_labels)

    def cluster_documents_gpu(self, embeddings, true_labels, *, random_state=42):
        """Cluster on GPU with cuML KMeans and score against *true_labels*.

        Embeddings are copied to a cupy array if needed. The number of
        clusters equals the number of distinct true labels.

        Returns:
            Tuple ``(nmi, ari)``.
        """
        num_clusters = len(set(true_labels))
        print(f"Number of clusters: {num_clusters}")

        if not isinstance(embeddings, cp.ndarray):
            embeddings = cp.asarray(embeddings)

        kmeans = cuKMeans(n_clusters=num_clusters, random_state=random_state)
        predicted_labels = kmeans.fit_predict(embeddings)

        # Move labels back to host memory for the sklearn metrics.
        if hasattr(predicted_labels, "get"):
            predicted_labels = predicted_labels.get()
        return self._score_clustering(true_labels, predicted_labels)