{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67ef2618-2761-4722-8215-db6c6eff7635",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pickle\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn_extra.cluster import KMedoids\n",
    "\n",
    "from summaryCentroids import kLLMmeans, kNLPmeans, get_embeddings, summarize_cluster\n",
    "from experiment_utils import load_dataset, cluster_metrics, avg_closest_distance\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d82f2a4-d421-4cf1-8610-90710937b03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_iter = 120\n",
    "N_SEEDS = 10         # random seeds per (dataset, embedding) pair\n",
    "PCA_COMPONENTS = 50  # dimensionality fed to the GMM baseline (capped by the data shape)\n",
    "\n",
    "data_list = ['clinic','bank77','massive_I','massive_D','goemo']\n",
    "\n",
    "\n",
    "def _centroids_from_assignments(embeddings, assignments, num_clusters):\n",
    "    \"\"\"Mean embedding of each cluster id in range(num_clusters).\n",
    "\n",
    "    Returns a list of length num_clusters; a cluster with no members gets None.\n",
    "    An assignment outside range(num_clusters) raises KeyError.\n",
    "    \"\"\"\n",
    "    grouped = {i: [] for i in range(num_clusters)}\n",
    "    for embedding, cluster in zip(embeddings, assignments):\n",
    "        grouped[cluster].append(embedding)\n",
    "    return [np.mean(grouped[i], axis=0) if grouped[i] else None for i in range(num_clusters)]\n",
    "\n",
    "\n",
    "for data in data_list:\n",
    "    results_dict = {}\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.\n",
    "    with open(\"processed_data/data_\" + data + \".pkl\", \"rb\") as f:\n",
    "        data_dict = pickle.load(f)\n",
    "\n",
    "    labels = data_dict['labels']\n",
    "    labels_arr = np.array(labels)\n",
    "    num_clusters = data_dict['num_clusters']\n",
    "    documents = data_dict['documents']\n",
    "    text_features = data_dict['embeddings']\n",
    "    #oracle_summaries = data_dict['summaries']\n",
    "    prompt = data_dict['prompt']\n",
    "    text_type = data_dict['text_type']\n",
    "\n",
    "    with open(\"processed_data/data_sentences_\" + data + \".pkl\", \"rb\") as f:\n",
    "        sentences_dict = pickle.load(f)\n",
    "\n",
    "    oracle_cluster_assignments = labels\n",
    "\n",
    "    for emb_type in ['distilbert', 'openai', 'e5-large', 'sbert']:\n",
    "\n",
    "        results_dict[emb_type] = {}\n",
    "        emb_data = text_features[emb_type]\n",
    "        text_sentences = sentences_dict[emb_type]\n",
    "        results_path = \"results/sims_offline_results_\" + emb_type + '_' + data + \".pkl\"\n",
    "\n",
    "        # Oracle centroids: mean embedding of each ground-truth class.\n",
    "        oracle_centroids = _centroids_from_assignments(emb_data, oracle_cluster_assignments, num_clusters)\n",
    "        #oracle_summary_embeddings = get_embeddings(oracle_summaries, emb_type = emb_type)\n",
    "        oracle_summary_embeddings = oracle_centroids\n",
    "\n",
    "        # Standardized features for the GMM / spectral / agglomerative baselines.\n",
    "        # (X_scaled used to be referenced without ever being defined -> NameError.)\n",
    "        # It does not depend on the seed, so it is computed once per embedding type.\n",
    "        X_scaled = StandardScaler().fit_transform(emb_data)\n",
    "\n",
    "        # AgglomerativeClustering takes no random_state and is deterministic,\n",
    "        # so fit it once here instead of once per seed.\n",
    "        agglo_assignments = AgglomerativeClustering(n_clusters=num_clusters).fit_predict(X_scaled)\n",
    "        agglo_centroids = _centroids_from_assignments(emb_data, agglo_assignments, num_clusters)\n",
    "\n",
    "        # PCA requires n_components <= min(n_samples, n_features).\n",
    "        pca_components = min(PCA_COMPONENTS, *X_scaled.shape)\n",
    "\n",
    "        for seed in range(N_SEEDS):\n",
    "            results_dict[emb_type][seed] = {}\n",
    "\n",
    "            # Every method calls cluster_metrics with the same argument order:\n",
    "            # (true labels, predicted assignments, oracle centroids, predicted centroids, oracle summary embeddings)\n",
    "\n",
    "            #kmeans\n",
    "            kmeans = KMeans(n_clusters=num_clusters, max_iter=max_iter, random_state=seed)\n",
    "            kmeans_assignments = kmeans.fit_predict(emb_data)\n",
    "            kmeans_centroids = kmeans.cluster_centers_\n",
    "            results = cluster_metrics(labels_arr, kmeans_assignments, oracle_centroids, kmeans_centroids, oracle_summary_embeddings)\n",
    "            results_dict[emb_type][seed]['kmeans'] = {'assignments':kmeans_assignments,\n",
    "                                                      'final_centroids':kmeans_centroids,\n",
    "                                                      'results':results}\n",
    "            print([data, emb_type, seed, 'kmeans', results])\n",
    "\n",
    "            #kmedoids\n",
    "            kmedoids = KMedoids(n_clusters=num_clusters, max_iter=max_iter, random_state=seed)\n",
    "            kmedoids.fit(emb_data)\n",
    "            kmedoids_assignments = kmedoids.labels_\n",
    "            kmedoids_centroids = emb_data[kmedoids.medoid_indices_]\n",
    "            results = cluster_metrics(labels_arr, kmedoids_assignments, oracle_centroids, kmedoids_centroids, oracle_summary_embeddings)\n",
    "            results_dict[emb_type][seed]['kmedoids'] = {'assignments':kmedoids_assignments,\n",
    "                                                        'final_centroids':kmedoids_centroids,\n",
    "                                                        'results':results}\n",
    "            print([data, emb_type, seed, 'kmedoids', results])\n",
    "\n",
    "            #gmm on the leading principal components of the standardized features\n",
    "            X_pca = PCA(n_components=pca_components, random_state=seed).fit_transform(X_scaled)\n",
    "            gmm = GaussianMixture(n_components=num_clusters, random_state=seed)\n",
    "            gmm_assignments = gmm.fit_predict(X_pca)\n",
    "            # Centroids are means in the original embedding space, not in PCA space.\n",
    "            gmm_centroids = _centroids_from_assignments(emb_data, gmm_assignments, num_clusters)\n",
    "            # NOTE(review): empty clusters yield None centroids -- confirm cluster_metrics tolerates them.\n",
    "            results = cluster_metrics(labels_arr, gmm_assignments, oracle_centroids, gmm_centroids, oracle_summary_embeddings)\n",
    "            results_dict[emb_type][seed]['gmm'] = {'assignments':gmm_assignments,\n",
    "                                                   'final_centroids':gmm_centroids,\n",
    "                                                   'results':results}\n",
    "            print([data, emb_type, seed, 'gmm', results])\n",
    "\n",
    "            #spectral\n",
    "            spectral = SpectralClustering(n_clusters=num_clusters, random_state=seed, affinity='nearest_neighbors')\n",
    "            spectral_assignments = spectral.fit_predict(X_scaled)\n",
    "            spectral_centroids = _centroids_from_assignments(emb_data, spectral_assignments, num_clusters)\n",
    "            results = cluster_metrics(labels_arr, spectral_assignments, oracle_centroids, spectral_centroids, oracle_summary_embeddings)\n",
    "            results_dict[emb_type][seed]['spectral'] = {'assignments':spectral_assignments,\n",
    "                                                        'final_centroids':spectral_centroids,\n",
    "                                                        'results':results}\n",
    "            print([data, emb_type, seed, 'spectral', results])\n",
    "\n",
    "            #Agglo (assignments fitted once per embedding type, above)\n",
    "            results = cluster_metrics(labels_arr, agglo_assignments, oracle_centroids, agglo_centroids, oracle_summary_embeddings)\n",
    "            results_dict[emb_type][seed]['agglomerative'] = {'assignments':agglo_assignments,\n",
    "                                                             'final_centroids':agglo_centroids,\n",
    "                                                             'results':results}\n",
    "            print([data, emb_type, seed, 'agglomerative', results])\n",
    "\n",
    "\n",
    "            #k-LLMmeans\n",
    "            for llm_type in ['gpt-3.5-turbo','gpt-4o','llama3.3-70b','deepseek-chat','claude-3-7-sonnet-20250219']:\n",
    "\n",
    "                results_dict[emb_type][seed][llm_type] = {}\n",
    "\n",
    "                #Uncomment to replicate only gpt-4o\n",
    "                #if llm_type!='gpt-4o':\n",
    "                #    continue\n",
    "\n",
    "                for force_context_length in [0, 10]:\n",
    "                    results_dict[emb_type][seed][llm_type][force_context_length] = {}\n",
    "\n",
    "                    for max_llm_iter in [1, 5]:\n",
    "\n",
    "                        assignments, final_summaries, final_summary_embeddings, final_centroids, summaries_evolution, centroids_evolution = kLLMmeans(documents,\n",
    "                                                                 prompt = prompt, text_type = text_type,\n",
    "                                                                 num_clusters = num_clusters,\n",
    "                                                                 force_context_length = force_context_length, max_llm_iter = max_llm_iter,\n",
    "                                                                 max_iter = max_iter, tol=1e-4, random_state = seed,\n",
    "                                                                 emb_type = emb_type,\n",
    "                                                                 text_features = emb_data)\n",
    "\n",
    "                        results = cluster_metrics(labels_arr, assignments,\n",
    "                                                  oracle_centroids, final_centroids,\n",
    "                                                  oracle_summary_embeddings, final_summary_embeddings)\n",
    "\n",
    "                        data_results ={'assignments':assignments,\n",
    "                                       'final_summaries':final_summaries,\n",
    "                                       'final_summary_embeddings':final_summary_embeddings,\n",
    "                                       'final_centroids':final_centroids,\n",
    "                                       'summaries_evolution':summaries_evolution,\n",
    "                                       'centroids_evolution':centroids_evolution,\n",
    "                                       'results':results}\n",
    "\n",
    "                        results_dict[emb_type][seed][llm_type][force_context_length][max_llm_iter] = data_results\n",
    "\n",
    "                        print([data, emb_type, llm_type, seed, force_context_length, max_llm_iter, results])\n",
    "\n",
    "                        # Checkpoint after every run so long LLM sweeps can be resumed/inspected.\n",
    "                        with open(results_path, \"wb\") as f:\n",
    "                            pickle.dump(results_dict, f)\n",
    "\n",
    "            #k-NLPmeans\n",
    "            for nlp_type in ['lsa','centroid','textrank']:\n",
    "\n",
    "                results_dict[emb_type][seed][nlp_type] = {}\n",
    "\n",
    "                for top_k in [3, 5, 10, 15]:\n",
    "                    results_dict[emb_type][seed][nlp_type][top_k] = {}\n",
    "\n",
    "                    for max_llm_iter in [1, 5]:\n",
    "\n",
    "                        force_context_length = 0\n",
    "                        assignments, final_summaries, final_summary_embeddings, final_centroids, summaries_evolution, centroids_evolution = kNLPmeans(documents,\n",
    "                                                                 num_clusters = num_clusters,\n",
    "                                                                 force_context_length = force_context_length, max_llm_iter = max_llm_iter,\n",
    "                                                                 max_iter = max_iter, tol=1e-4, random_state = seed,\n",
    "                                                                 emb_type = emb_type,\n",
    "                                                                 text_features = emb_data,\n",
    "                                                                 top_k = top_k,\n",
    "                                                                 text_sentences = text_sentences,\n",
    "                                                                 nlp = nlp_type)\n",
    "\n",
    "                        results = cluster_metrics(labels_arr, assignments,\n",
    "                                                  oracle_centroids, final_centroids,\n",
    "                                                  oracle_summary_embeddings, final_summary_embeddings)\n",
    "\n",
    "                        data_results ={'assignments':assignments,\n",
    "                                       'final_summaries':final_summaries,\n",
    "                                       'final_summary_embeddings':final_summary_embeddings,\n",
    "                                       'final_centroids':final_centroids,\n",
    "                                       'summaries_evolution':summaries_evolution,\n",
    "                                       'centroids_evolution':centroids_evolution,\n",
    "                                       'results':results}\n",
    "\n",
    "                        results_dict[emb_type][seed][nlp_type][top_k][max_llm_iter] = data_results\n",
    "\n",
    "                        print([data, emb_type, nlp_type, seed, top_k, max_llm_iter, results])\n",
    "\n",
    "                        # Checkpoint after every run.\n",
    "                        with open(results_path, \"wb\") as f:\n",
    "                            pickle.dump(results_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3128ff8b-2713-4337-a917-1262785a112b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:base] *",
   "language": "python",
   "name": "conda-base-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
