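"""
Cluster the examples in task_amazon.json (Amazon review texts, e.g. from
McAuley-Lab/Amazon-Reviews-2023) using sentence-transformer embeddings of their
texts and explanations, and write the resulting cluster ids (x_cluster_id plus
one cluster-id column per explanation variant) to save_df_amazon.csv.
"""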

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import warnings
from datasets import load_dataset
import itertools
import pdb
import re
from tqdm import tqdm
from sklearn.feature_selection import mutual_info_classif
from sklearn.model_selection import train_test_split
warnings.filterwarnings('ignore')
import json


# Highlight spans of either class, in the unquoted attribute format of the
# "system" explanations and the quoted format of the "expert" explanations.
_TOP2_SPAN = re.compile(r"<span class=class[01]>(.*?)</span>", re.DOTALL)
_EXPERT_SPAN = re.compile(r"<span class='class[01]'>(.*?)</span>", re.DOTALL)


def _extract_explanation(explanation_text, pred, conf, method="top1", median_conf=None):
    """Return the contents of the highlighted ``<span>`` elements in an explanation.

    Args:
        explanation_text: HTML explanation string.
        pred: Predicted class; ``"top1"`` keeps only ``<span class=class{pred}>`` spans.
        conf: Model confidence; used only by ``"adaptive"``.
        method: One of
            ``"top1"``   - spans of the predicted class only;
            ``"top2"``   - spans ``<span class=class0>`` / ``<span class=class1>``;
            ``"expert"`` - like ``"top2"`` but with the quoted ``class='classN'`` attribute;
            ``"adaptive"`` - ``"top1"`` when ``conf > median_conf``, otherwise ``"top2"``.
        median_conf: Confidence threshold, required for ``"adaptive"``.

    Returns:
        List of captured span contents, in document order.

    Raises:
        ValueError: If ``method`` is unknown, or ``"adaptive"`` is used without ``median_conf``.
    """
    if method == "adaptive":
        if median_conf is None:
            raise ValueError("method='adaptive' requires median_conf")
        method = "top1" if conf > median_conf else "top2"

    if method == "top2":
        pattern = _TOP2_SPAN
    elif method == "expert":
        pattern = _EXPERT_SPAN
    elif method == "top1":
        pattern = re.compile(f"<span class=class{re.escape(str(pred))}>(.*?)</span>", re.DOTALL)
    else:
        raise ValueError(f"Invalid method: {method}")

    return [m.group(1) for m in pattern.finditer(explanation_text)]


def _scoring_rule(prob, label):
    """Return whether thresholding ``prob`` at 0.5 reproduces ``label``.

    Works on scalars and element-wise on pandas Series / NumPy arrays.
    """
    return (prob > 0.5) == label


def _overfitting_check(cluster_label_df, tolerance=5e-2):
    """Check whether per-cluster mean labels overfit a 70/30 train/test split.

    Each cluster "predicts" the mean label of its training members; clusters that
    never appear in the training split fall back to the overall training mean.

    Args:
        cluster_label_df: DataFrame with ``cluster_id`` and ``label`` columns.
        tolerance: Allowed excess of train accuracy over test accuracy.

    Returns:
        ``(overfitting, train_performance, test_performance)``.
    """
    train_df, test_df = train_test_split(cluster_label_df, test_size=0.3, random_state=42)
    train_df = train_df[["cluster_id", "label"]]
    test_df = test_df[["cluster_id", "label"]]

    prior = train_df["label"].mean()
    cluster_means = train_df.groupby("cluster_id")["label"].mean()

    # Vectorised lookup (label-based, unlike the deprecated positional row[0]).
    train_prob = train_df["cluster_id"].map(cluster_means)
    test_prob = test_df["cluster_id"].map(cluster_means).fillna(prior)

    train_performance = _scoring_rule(train_prob, train_df["label"]).mean()
    test_performance = _scoring_rule(test_prob, test_df["label"]).mean()
    return train_performance > test_performance + tolerance, train_performance, test_performance


def _find_optimal_n_clusters(embeddings, label, n_clusters_range=(2, 30, 2)):
    """Search KMeans cluster counts for the best non-overfitting clustering.

    Tries every count in ``range(*n_clusters_range)`` that does not exceed the
    number of samples, and keeps the one with the highest train accuracy among
    those ``_overfitting_check`` does not flag.

    Returns:
        ``(best_n_clusters, best_train_performance)``; if every candidate is
        flagged, the first candidate and a performance of 0.
    """
    start, stop, step = n_clusters_range
    best_n_clusters, best_train_performance = start, 0
    for n_clusters in tqdm(range(start, stop, step), desc="Searching for optimal n clusters"):
        if n_clusters > len(embeddings):
            continue  # KMeans needs at least n_clusters samples.
        cluster_ids = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(embeddings)
        overfitting, train_performance, test_performance = _overfitting_check(
            pd.DataFrame({"cluster_id": cluster_ids, "label": label})
        )
        print(f"N clusters: {n_clusters}, Overfitting: {overfitting}, Train performance: {train_performance}, Test performance: {test_performance}")
        if not overfitting and train_performance > best_train_performance:
            best_n_clusters = n_clusters
            best_train_performance = train_performance
    return best_n_clusters, best_train_performance


def _subcluster_within_groups(embeddings, group_keys, n_sub_clusters):
    """Sub-cluster rows with KMeans inside each group of equal key values.

    Each group gets the number ``g`` it would have in
    ``enumerate(itertools.product(*map(np.unique, group_keys)))`` (last key
    varying fastest). Groups with more than ``2 * n_sub_clusters`` rows are split
    by KMeans into ids ``g * n_sub_clusters + k``; smaller groups all get
    ``g * n_sub_clusters``. Only groups that actually occur are visited.

    Args:
        embeddings: Array of embeddings, one row per sample.
        group_keys: Sequence of 1-D keys, each with one value per sample.
        n_sub_clusters: KMeans cluster count used inside each large-enough group.

    Returns:
        ``(sub_cluster_ids, n_groups_clustered)``: an int array of ids, and how
        many groups were large enough to run KMeans on.
    """
    embeddings = np.asarray(embeddings)
    # Code each key by its index among the sorted unique values, then combine
    # the codes mixed-radix style into a single group number.
    uniques_and_codes = [np.unique(np.asarray(key), return_inverse=True) for key in group_keys]
    dims = tuple(len(uniq) for uniq, _ in uniques_and_codes)
    codes = tuple(np.ravel(code) for _, code in uniques_and_codes)
    group_index = np.ravel_multi_index(codes, dims)

    sub_cluster_ids = np.zeros(len(group_index), dtype=int)
    kmeans = KMeans(n_clusters=n_sub_clusters, random_state=42)
    n_groups_clustered = 0
    for g in tqdm(np.unique(group_index), desc="Fitting KMeans for each cluster"):
        rows = np.flatnonzero(group_index == g)
        base = g * n_sub_clusters
        if len(rows) > 2 * n_sub_clusters:
            sub_cluster_ids[rows] = kmeans.fit_predict(embeddings[rows]) + base
            n_groups_clustered += 1
        else:
            sub_cluster_ids[rows] = base
    return sub_cluster_ids, n_groups_clustered


def cluster_texts(json_path="task_amazon.json", output_path="save_df_amazon.csv", model_name="sentence-transformers/all-MiniLM-L6-v2", n_clusters=5, sub_n_clusters=2):
    """Cluster task examples by explanations and texts, and save the result as CSV.

    Pipeline:
      1. Load records with keys ``X`` (text), ``Y`` (label), ``pred``, ``conf``,
         ``system`` and ``expert`` (HTML explanations) from ``json_path``.
      2. Extract highlighted spans four ways (``expert`` from the ``expert``
         field; ``top1``/``top2``/``adaptive`` from the ``system`` field).
      3. For each variant, pick a KMeans cluster count with
         ``_find_optimal_n_clusters`` and cluster the span embeddings.
      4. Group rows by (pred, confidence truncated to 0.1 bins, the four
         explanation cluster ids) and sub-cluster the text embeddings within
         each group to obtain ``x_cluster_id``.

    Args:
        json_path: Path to the input JSON list of task records.
        output_path: Path of the CSV file to write.
        model_name: Sentence-transformers model used for all embeddings.
        n_clusters: Unused; kept for backward compatibility (cluster counts are
            chosen by the validation search instead).
        sub_n_clusters: KMeans cluster count used within each group in step 4.

    Returns:
        The DataFrame written to ``output_path``.

    Raises:
        ValueError: If the JSON file contains no records.
    """
    with open(json_path, encoding="utf-8") as f:
        task_json = json.load(f)
    if not task_json:
        raise ValueError(f"No records found in {json_path}")

    texts = [record["X"] for record in task_json]
    label = [record["Y"] for record in task_json]
    pred = [record["pred"] for record in task_json]
    conf = [record["conf"] for record in task_json]
    conf_bin = [int(c * 10) / 10 for c in conf]  # truncate to 0.1-wide bins
    median_conf = np.quantile(conf, 0.5)

    # Explanation variant -> record field it is extracted from (order fixes column order).
    explanation_sources = {"expert": "expert", "top1": "system", "top2": "system", "adaptive": "system"}
    explanations = {
        method: [
            " ".join(_extract_explanation(record[source], record["pred"], record["conf"], method, median_conf))
            for record in task_json
        ]
        for method, source in explanation_sources.items()
    }

    model = SentenceTransformer(model_name)
    text_embeddings = model.encode(texts, show_progress_bar=True)

    explanation_cluster_ids = {}
    for method, explanation_texts in explanations.items():
        explanation_embeddings = model.encode(explanation_texts, show_progress_bar=True)
        print(f"Finding optimal n clusters for explanation {method}...")
        best_n, best_performance = _find_optimal_n_clusters(explanation_embeddings, label)
        kmeans = KMeans(n_clusters=best_n, random_state=42)
        explanation_cluster_ids[method] = kmeans.fit_predict(explanation_embeddings)
        print(f"Best n clusters for explanation {method}: {best_n}, Best train performance for explanation {method}: {best_performance}")

    # Reuse the text embeddings computed above instead of re-encoding per group.
    x_cluster_ids, count_group_with_enough_data = _subcluster_within_groups(
        text_embeddings,
        [pred, conf_bin] + [explanation_cluster_ids[method] for method in explanations],
        sub_n_clusters,
    )
    print(f"Count of groups with enough data: {count_group_with_enough_data}")

    df = pd.DataFrame({
        "text": texts,
        "x_cluster_id": x_cluster_ids,
        "pred": pred,
        "conf": conf,
        **{f"explanation_{method}": explanations[method] for method in explanations},
        **{f"explanation_{method}_cluster_id": explanation_cluster_ids[method] for method in explanations},
    })
    df.to_csv(output_path, index=False)

    return df


if __name__ == "__main__":
    # Script entry point: run the whole pipeline on the default Amazon task files.
    result_df = cluster_texts(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        output_path="save_df_amazon.csv",
        json_path="task_amazon.json",
    )
