# 1. Setup and Imports
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
import numpy as np
import itertools
import pdb
from tqdm import tqdm
import warnings
# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings('ignore')

# Fetch the NLTK resources preprocess_text relies on (stopword list and the
# WordNet data behind WordNetLemmatizer).
for _nltk_resource in ('stopwords', 'wordnet'):
    nltk.download(_nltk_resource)

# ---

import argparse

# No options are declared, so parsing only provides --help and rejects any
# unexpected command-line arguments.
parser = argparse.ArgumentParser()
args = parser.parse_args()


# 2. Data Loading and Preprocessing
# Built once at import time: constructing the stopword set and the lemmatizer
# for every document dominated the per-row cost. The nltk.download calls at
# the top of the script run before these lines.
_STOP_WORDS = frozenset(stopwords.words('english'))
_LEMMATIZER = WordNetLemmatizer()
# Anything that is not an ASCII letter or whitespace is dropped. Unicode
# whitespace is kept on purpose so str.split() still separates the words.
_NON_ALPHA_RE = re.compile(r'[^a-zA-Z\s]')


def preprocess_text(text):
    """Clean and prepare a single text document.

    Steps: remove every character that is not an ASCII letter or whitespace,
    lowercase, split on whitespace, drop NLTK English stopwords and
    lemmatize the remaining tokens with WordNet.

    Args:
        text: The raw document. Non-string values (e.g. ``NaN`` for an empty
            CSV cell read by pandas) are treated as an empty document.

    Returns:
        The cleaned tokens joined by single spaces ("" if nothing remains).
    """
    if not isinstance(text, str):
        return ""
    # Use a compiled pattern rather than re.sub(pattern, repl, text, flags):
    # the 4th positional parameter of re.sub is ``count``, which capped the
    # number of removed characters.
    tokens = _NON_ALPHA_RE.sub('', text).lower().split()
    return " ".join(_LEMMATIZER.lemmatize(token) for token in tokens if token not in _STOP_WORDS)

# Load predictions/explanations and add a cleaned copy of the review text
# (new column appended at the end, as a plain assignment would).
df = pd.read_csv('review_model_pred_and_explanation.csv').assign(
    processed_text=lambda frame: frame['review_text'].apply(preprocess_text)
)

print("--- Preprocessed Text ---")
print(df[['review_text', 'processed_text']].head())

# ---
# 3. Sentence embeddings of the cleaned reviews (all-MiniLM-L6-v2).
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
texts = np.array(df['processed_text'])
embeddings = model.encode(texts)
# ---

# 4. Clustering


def scoring_rule(a, s, threshold=0.5):
    """Score a probabilistic prediction against an observed binary label.

    The prediction counts as positive when ``a`` is strictly above
    ``threshold``; the score is whether that matches ``s`` (0/1 accuracy).
    A proper-scoring alternative the author considered is the Brier-style
    ``1 - (a - s)**2``.

    Args:
        a: Predicted probability of the positive class (in this script, a
            per-cluster mean label).
        s: Observed label (0/1 or bool).
        threshold: Decision threshold; defaults to the previously
            hard-coded 0.5.

    Returns:
        True if the thresholded prediction equals ``s``, else False
        (element-wise when ``a`` and ``s`` are arrays/Series).
    """
    return (a > threshold) == s

def test_overfitting(cluster_label_df, tolerance=5e-2, test_size=0.3, random_state=42):
    """Check whether cluster ids overfit when used as a label predictor.

    The rows are split into train/test. Each cluster predicts the mean
    train label of its members; a cluster absent from the train split falls
    back to the overall train mean label (the prior). Every row is scored
    with ``scoring_rule`` and the scores are averaged per split.

    Args:
        cluster_label_df: DataFrame with ``cluster_id`` and ``label`` columns.
        tolerance: Margin by which train performance may exceed test
            performance before the clustering is flagged as overfitting.
        test_size: Fraction of rows held out (passed to ``train_test_split``).
        random_state: Seed for the train/test split.

    Returns:
        ``(overfitting, train_performance, test_performance)``.
    """
    train_df, test_df = train_test_split(cluster_label_df, test_size=test_size, random_state=random_state)
    prior = train_df["label"].mean()
    posterior = train_df.groupby("cluster_id")["label"].mean().to_dict()

    # Columns are read by name. The previous row-wise ``x[0]``/``x[1]`` relied
    # on positional fallback in Series.__getitem__, which pandas 2.x
    # deprecates (the warning was hidden by the global warnings filter).
    train_scores = [
        scoring_rule(posterior[cluster_id], label)
        for cluster_id, label in zip(train_df["cluster_id"], train_df["label"])
    ]
    test_scores = [
        scoring_rule(posterior.get(cluster_id, prior), label)
        for cluster_id, label in zip(test_df["cluster_id"], test_df["label"])
    ]
    train_performance = np.mean(train_scores)
    test_performance = np.mean(test_scores)
    return train_performance > test_performance + tolerance, train_performance, test_performance

def validation_find_optimal_n_clusters(embeddings, label, n_clusters_range=(2, 30, 2)):
    """Pick the KMeans cluster count whose clusters best predict ``label``.

    Each ``n_clusters`` in ``range(n_clusters_range[0], n_clusters_range[1],
    n_clusters_range[2])`` is tried. KMeans (``random_state=42``) is fitted
    on ``embeddings``, and the cluster ids are evaluated by
    ``test_overfitting``. Candidates flagged as overfitting are skipped.
    Among the rest, the highest train performance wins, and the earliest
    candidate wins ties.

    Returns:
        ``(best_n_clusters, best_train_performance)``. If no candidate
        qualifies, this is ``(n_clusters_range[0], 0)``.
    """
    start, stop, step = n_clusters_range[0], n_clusters_range[1], n_clusters_range[2]
    best = (start, 0)
    for k in tqdm(range(start, stop, step), desc="Searching for optimal n clusters"):
        assignments = KMeans(n_clusters=k, random_state=42).fit_predict(embeddings)
        scored = pd.DataFrame({"cluster_id": assignments, "label": label})
        overfitting, train_performance, test_performance = test_overfitting(scored)
        print(f"N clusters: {k}, Overfitting: {overfitting}, Train performance: {train_performance}, Test performance: {test_performance}")
        # Written as "not >" (not "<=") so a NaN performance is skipped.
        if overfitting or not train_performance > best[1]:
            continue
        best = (k, train_performance)
    return best



# Cluster each representation (explanation words, random words, NN
# positive/negative review embeddings) with a label-validated cluster count.
import ast


def _fit_label_validated_clusters(features, labels, name):
    """Fit KMeans on ``features`` with the validated cluster count; return per-row ids.

    The count comes from ``validation_find_optimal_n_clusters`` against
    ``labels``. ``name`` only labels the progress message.
    """
    best_n_clusters, best_train_performance = validation_find_optimal_n_clusters(features, labels)
    print(f"Best n clusters for {name}: {best_n_clusters}, Best train performance: {best_train_performance}")
    km = KMeans(n_clusters=best_n_clusters, random_state=42)
    km.fit(features)
    return km.predict(features)


def _literal_word_list_to_sentence(cell):
    """Join a CSV cell that holds a Python list literal of words into one string."""
    # literal_eval accepts only Python literals. eval would execute any
    # expression stored in the CSV.
    return " ".join(ast.literal_eval(cell))


df["word_in_one_sentence"] = df['word'].apply(_literal_word_list_to_sentence)

embeddings_explanation = model.encode(df["word_in_one_sentence"].tolist())
df['cluster_heatmap'] = _fit_label_validated_clusters(embeddings_explanation, df["actual_label"], "explanation")

embeddings_random = model.encode(df["random_10_words"].apply(_literal_word_list_to_sentence).tolist())
df["cluster_random"] = _fit_label_validated_clusters(embeddings_random, df["actual_label"], "random")

df["nn_pos_cluster"] = _fit_label_validated_clusters(embeddings[df["nn_pos_idx"]], df["actual_label"], "nn pos")
df["nn_neg_cluster"] = _fit_label_validated_clusters(embeddings[df["nn_neg_idx"]], df["actual_label"], "nn neg")

# Sub-cluster the reviews within each combination of (predicted label,
# heatmap cluster, random-word cluster, NN-positive cluster, NN-negative
# cluster).
group_columns = ["predicted_label", "cluster_heatmap", "cluster_random", "nn_pos_cluster", "nn_neg_cluster"]
unique_heatmap_clusters = np.unique(df["cluster_heatmap"])
unique_random_clusters = np.unique(df["cluster_random"])
unique_nn_pos_clusters = np.unique(df["nn_pos_cluster"])
unique_nn_neg_clusters = np.unique(df["nn_neg_cluster"])
unique_yai = np.unique(df["predicted_label"])
group_uniques = [unique_yai, unique_heatmap_clusters, unique_random_clusters, unique_nn_pos_clusters, unique_nn_neg_clusters]


def _product_rank(key, uniques):
    """Return the position of tuple ``key`` in ``itertools.product(*uniques)``.

    ``uniques`` must be sorted arrays (as ``np.unique`` returns them), and
    each element of ``key`` must occur in the matching array. The rank is the
    mixed-radix number whose rightmost digit varies fastest, like product().
    """
    rank = 0
    for value, values in zip(key, uniques):
        rank = rank * len(values) + int(np.searchsorted(values, value))
    return rank


n_cluster_each_group = 2
km_group = KMeans(n_clusters=n_cluster_each_group, random_state=42)
count_group_w_enough_data = 0
# Nullable integer ids ("12", not "12.0" in the CSV). Rows whose group key
# contains NaN match no group, as before, and stay <NA>.
x_cluster_id = pd.Series(pd.NA, index=df.index, dtype="Int64")

# Iterate only non-empty groups. The full Cartesian product of cluster ids
# can be very large, and filtering df once per combination rescanned every
# row each time. Ids are still derived from the combination's position in
# that product, so they match the previous enumeration exactly.
grouped = df.groupby(group_columns, sort=True)
for key, df_group in tqdm(grouped, total=grouped.ngroups, desc="Clustering groups"):
    base_id = _product_rank(key, group_uniques) * n_cluster_each_group
    if len(df_group) > 2 * n_cluster_each_group:
        texts_d = np.array(df_group["processed_text"])
        embeddings_d = model.encode(texts_d, show_progress_bar=False)
        cluster_ids_d = km_group.fit_predict(embeddings_d)
        x_cluster_id.loc[df_group.index] = cluster_ids_d + base_id
        count_group_w_enough_data += 1
    else:
        # Too few rows to split: the whole group shares one id.
        x_cluster_id.loc[df_group.index] = base_id
df["x_cluster_id"] = x_cluster_id
print(f"Number of groups with enough data: {count_group_w_enough_data}")

# Persist all input columns plus the derived cluster assignments (plain string:
# the f-prefix had no placeholders).
df.to_csv('review_model_pred_and_explanation_cluster.csv', index=False)
