"""
Cluster the generated_text, selected_features_with_interpretations, and input_text columns
of toxic_training_evaluation.csv using sentence embeddings and K-means, add a cluster-id
column for each, and write the result to a new CSV.
"""

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import warnings
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import itertools
warnings.filterwarnings('ignore')


def _scoring_rule(predicted_rate, label):
    """Return whether thresholding *predicted_rate* at 0.5 reproduces *label*.

    Works element-wise on aligned pandas Series as well as on scalars.
    """
    return (predicted_rate > 0.5) == label


def _check_overfitting(cluster_label_df, tolerance=5e-3):
    """Estimate whether a cluster assignment overfits a binary label.

    Rows are split 70/30 with a fixed seed. Each cluster's predicted rate is the
    mean label of its training rows. Test rows whose cluster never appears in the
    training split fall back to the overall training mean (the prior).

    Args:
        cluster_label_df: DataFrame with ``cluster_id`` and ``label`` columns.
        tolerance: Margin by which train performance may exceed test performance
            before the assignment is flagged as overfitting.

    Returns:
        Tuple ``(overfitting, train_performance, test_performance)``. The two
        performances are the fraction of rows judged correct by ``_scoring_rule``.
    """
    # The split depends only on the row count and seed, so selecting the columns
    # first yields the same partition as splitting and then selecting.
    train_df, test_df = train_test_split(
        cluster_label_df[["cluster_id", "label"]], test_size=0.3, random_state=42
    )
    prior = train_df["label"].mean()
    posterior = train_df.groupby("cluster_id")["label"].mean()

    # Vectorized lookup replaces a row-wise apply with positional x[0]/x[1]
    # indexing, which pandas deprecates on labelled Series.
    train_pred = train_df["cluster_id"].map(posterior)
    test_pred = test_df["cluster_id"].map(posterior).fillna(prior)

    train_performance = _scoring_rule(train_pred, train_df["label"]).mean()
    test_performance = _scoring_rule(test_pred, test_df["label"]).mean()
    return train_performance > test_performance + tolerance, train_performance, test_performance


def _find_optimal_n_clusters(embeddings, label, n_clusters_range=(2, 10, 1)):
    """Choose a K-means cluster count that predicts *label* without overfitting.

    Candidates come from ``range(*n_clusters_range)`` (the end is exclusive).
    Candidates larger than the number of samples are skipped, because K-means
    cannot fit them. Among candidates not flagged by ``_check_overfitting``, the
    one with the highest train performance wins; ties keep the smaller count.
    If none qualifies, the smallest admissible count is returned with a train
    performance of 0.

    Returns:
        Tuple ``(best_n_clusters, best_train_performance)``.
    """
    n_samples = len(embeddings)
    start, stop, step = n_clusters_range
    best_n_clusters = max(1, min(start, n_samples))
    best_train_performance = 0

    # The train/test split needs at least two rows, and K-means needs at least
    # as many samples as clusters.
    if n_samples < 2:
        candidates = []
    else:
        candidates = [k for k in range(start, stop, step) if k <= n_samples]

    for n_clusters in tqdm(candidates, desc="Searching for optimal n clusters"):
        cluster_ids = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(embeddings)
        overfitting, train_performance, test_performance = _check_overfitting(
            pd.DataFrame({"cluster_id": cluster_ids, "label": label})
        )
        print(f"N clusters: {n_clusters}, Overfitting: {overfitting}, Train performance: {train_performance:.4f}, Test performance: {test_performance:.4f}")
        if not overfitting and train_performance > best_train_performance:
            best_n_clusters = n_clusters
            best_train_performance = train_performance
    return best_n_clusters, best_train_performance


def cluster_toxic_training(csv_path="toxic_training_evaluation.csv", output_path="toxic_training_evaluation_clustered.csv", model_name="sentence-transformers/all-MiniLM-L6-v2", n_clusters_range=(2, 10, 1)):
    """
    Cluster generated_text, selected_features_with_interpretations, and input_text columns
    from the toxic training evaluation CSV using a similar pipeline as cluster_texts_beer.py

    The binary target ``toxic_model`` is True where ``assigned_model`` equals
    ``"transformer_toxic"``. The target drives the choice of cluster count for
    ``generated_text`` and for ``selected_features_with_interpretations``. The
    latter is clustered separately for each ``number_of_features`` value from 1
    to 5. ``input_text`` is then clustered within each combination of (feature
    count, feature cluster, generated-text cluster,
    ``is_toxic_training_data_without_sae``).

    Args:
        csv_path: Path to input CSV file
        output_path: Path to save output CSV file
        model_name: Sentence transformer model to use for embeddings
        n_clusters_range: Range tuple (start, end, step) for finding optimal n_clusters
            (end is exclusive, as in ``range``)

    Returns:
        The full DataFrame with added ``toxic_model``, ``generated_text_cluster_id``,
        ``features_with_interpretations_cluster_id`` and ``input_text_cluster_id``
        columns. Rows that are never assigned a cluster keep NaN. A subset of the
        columns is also written to *output_path*.
    """
    df = pd.read_csv(csv_path)
    model = SentenceTransformer(model_name)

    # Binary target: was the row produced by the toxic model?
    target_column = "toxic_model"
    df[target_column] = df["assigned_model"] == "transformer_toxic"
    label = df[target_column].astype(int).values

    # Cluster generated_text
    print("Clustering generated_text...")
    generated_texts = np.array(df["generated_text"].fillna(""))
    generated_embeddings = model.encode(generated_texts, show_progress_bar=True)
    best_n_clusters_gen, best_train_performance_gen = _find_optimal_n_clusters(
        generated_embeddings, label, n_clusters_range
    )
    print(f"Best n clusters for generated_text: {best_n_clusters_gen}, Best train performance: {best_train_performance_gen:.4f}")
    kmeans_gen = KMeans(n_clusters=best_n_clusters_gen, random_state=42)
    generated_cluster_ids = kmeans_gen.fit_predict(generated_embeddings)
    df["generated_text_cluster_id"] = generated_cluster_ids

    # Cluster selected_features_with_interpretations separately per feature count.
    print("Clustering selected_features_with_interpretations...")
    # Pre-create the column so it exists even if every group below is empty.
    # Rows with a feature count outside 1-5 stay NaN.
    df["features_with_interpretations_cluster_id"] = np.nan
    for number_of_features in range(1, 6):
        df_d = df.loc[df["number_of_features"] == number_of_features, :]
        if df_d.empty:
            # Nothing to embed; train_test_split/KMeans would raise on zero rows.
            continue
        features_texts = np.array(df_d["selected_features_with_interpretations"].fillna(""))
        label_d = df_d[target_column].astype(int).values
        features_embeddings = model.encode(features_texts, show_progress_bar=True)
        best_n_clusters_feat, best_train_performance_feat = _find_optimal_n_clusters(
            features_embeddings, label_d, n_clusters_range
        )
        print(f"Best n clusters for top-{number_of_features} features: {best_n_clusters_feat}, Best train performance: {best_train_performance_feat:.4f}")
        kmeans_feat = KMeans(n_clusters=best_n_clusters_feat, random_state=42)
        features_cluster_ids = kmeans_feat.fit_predict(features_embeddings)
        df.loc[df_d.index, "features_with_interpretations_cluster_id"] = features_cluster_ids

    # Cluster input_text within each group. Texts are embedded per group only,
    # so no full-dataset encode is needed here.
    print("Clustering input_text...")
    unique_number_of_features = np.unique(df["number_of_features"])
    unique_features_cluster_ids = np.unique(df["features_with_interpretations_cluster_id"])
    unique_generated_cluster_ids = np.unique(generated_cluster_ids)
    unique_human_decision_labels = np.unique(df["is_toxic_training_data_without_sae"])
    groups = list(itertools.product(
        unique_number_of_features,
        unique_features_cluster_ids,
        unique_generated_cluster_ids,
        unique_human_decision_labels,
    ))

    n_clusters_group = 50
    count_group_with_enough_data = 0
    # Rows whose group never matches (e.g. NaN feature cluster ids) stay NaN.
    df["input_text_cluster_id"] = np.nan

    for i, (number_of_features, features_cluster_id, generated_cluster_id, human_decision_label) in tqdm(
        enumerate(groups), total=len(groups), desc="Clustering input_text"
    ):
        mask = (
            (df["number_of_features"] == number_of_features)
            & (df["features_with_interpretations_cluster_id"] == features_cluster_id)
            & (df["generated_text_cluster_id"] == generated_cluster_id)
            & (df["is_toxic_training_data_without_sae"] == human_decision_label)
        )
        df_d = df.loc[mask, :]
        # Offsetting by i * n_clusters_group keeps ids distinct across groups.
        if len(df_d) > n_clusters_group:
            texts_d = np.array(df_d["input_text"].fillna(""))
            embeddings_d = model.encode(texts_d, show_progress_bar=False)
            kmeans = KMeans(n_clusters=n_clusters_group, random_state=42)
            cluster_ids_d = kmeans.fit_predict(embeddings_d)
            df.loc[df_d.index, "input_text_cluster_id"] = cluster_ids_d + i * n_clusters_group
            count_group_with_enough_data += 1
        else:
            # Too few rows to split further: the whole group shares one id.
            df.loc[df_d.index, "input_text_cluster_id"] = i * n_clusters_group
    print(f"Count of groups with enough data: {count_group_with_enough_data}")

    # Save the clustered dataframe
    df[["input_text", "generated_text_cluster_id", "features_with_interpretations_cluster_id", "input_text_cluster_id", "id",
    "toxic_model", "assigned_model", "is_toxic_training_data_with_sae", "is_toxic_training_data_without_sae",
    "number_of_features"]].to_csv(output_path, index=False)
    print(f"Clustered data saved to {output_path}")

    return df


if __name__ == "__main__":
    # Script entry point: cluster the default CSV, sweeping the cluster count
    # over 2, 4, ..., 18 (range end 20 is exclusive).
    run_kwargs = {
        "csv_path": "toxic_training_evaluation.csv",
        "output_path": "toxic_training_evaluation_clustered.csv",
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        "n_clusters_range": (2, 20, 2),
    }
    df = cluster_toxic_training(**run_kwargs)
