{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43bd564-ff37-4e89-a2c5-88efc27d36fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from kLLMmeans import kLLMmeans, get_embeddings, summarize_cluster, sequentialMiniBatchKmeans, miniBatchKLLMeans, miniBatchNLPeans\n",
    "from experiment_utils import load_dataset, cluster_metrics, avg_closest_distance\n",
    "from sklearn.cluster import KMeans, MiniBatchKMeans\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import numpy as np\n",
    "import json, pickle\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73289721-39f2-4eb3-bb40-a5ba9acd6f5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"The following is a cluster of questions from the same community. Write a summary that represents the cluster:\"\n",
    "text_type = \"Summary:\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebb6dd27-eacd-4727-b361-8df2121fc976",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_iter = 120\n",
    "emb_type = 'openai'\n",
    "with open(\"processed_data/data_stackexchange_openai_2023.pkl\", \"rb\") as f:\n",
    "    data_dict = pickle.load(f) \n",
    "\n",
    "groups_df = data_dict['data'].groupby('Label').count().reset_index()\n",
    "selected = list(groups_df[groups_df.Year>500].Label)\n",
    "data_dict['data'] = data_dict['data'][data_dict['data']['Label'].isin(selected)]\n",
    "encoder = LabelEncoder()\n",
    "numeric_labels = encoder.fit_transform(list(data_dict['data'].Label))\n",
    "max_batch_size = 10000\n",
    "\n",
    "for year in [2020,2021,2022,2023]:\n",
    "\n",
    "    with open(\"processed_data/data_stackexchange_openai_\" + str(year) + \".pkl\", \"rb\") as f:\n",
    "        data_dict = pickle.load(f) \n",
    "    \n",
    "    try:\n",
    "        with open(\"results/sims_stackexchange_results_\" + str(year) + \".pkl\", \"rb\") as f:\n",
    "            results_dict = pickle.load(f)\n",
    "        print('Old results_dict loaded')\n",
    "    except:\n",
    "        print('No previous results')\n",
    "        results_dict = {}\n",
    "\n",
    "    data_dict['data']['embeddings'] = list(data_dict['embeddings'])\n",
    "    data_dict['data'] = data_dict['data'][data_dict['data']['Label'].isin(selected)]\n",
    "    #total = total + data_dict['data'].shape[0]\n",
    "\n",
    "    data_dict['data'] = data_dict['data'].sort_values('CreationDate')\n",
    "    \n",
    "    text_data = list(data_dict['data']['Text'])\n",
    "    labels = list(encoder.transform(list(data_dict['data'].Label)))\n",
    "    num_clusters = len(np.unique(labels))\n",
    "    text_features = list(data_dict['data']['embeddings'])\n",
    "\n",
    "    del data_dict\n",
    "\n",
    "    with open(\"processed_data/data_sentences_stackexchange_openai_\" + str(year) + \".pkl\", \"rb\") as f:\n",
    "        sentences_dict = pickle.load(f) \n",
    "    text_sentences = sentences_dict[emb_type]\n",
    "    \n",
    "    del sentences_dict\n",
    "\n",
    "    oracle_clustered_embeddings = {i: [] for i in range(num_clusters)}\n",
    "    for embedding, cluster in zip(text_features, labels):\n",
    "        oracle_clustered_embeddings[cluster].append(embedding)\n",
    "    oracle_centroids = [np.mean(oracle_clustered_embeddings[i], axis=0) if oracle_clustered_embeddings[i] else None for i in range(num_clusters)]\n",
    "\n",
    "    for seed in range(5):\n",
    "        if results_dict.get(seed) is None:\n",
    "            results_dict[seed] = {}\n",
    "\n",
    "        #minibatch\n",
    "        if results_dict[seed].get('minibatchkmeans') is None:\n",
    "            minibatchkmeans = MiniBatchKMeans(n_clusters=num_clusters,\n",
    "                                     random_state=seed,\n",
    "                                     batch_size=max_batch_size,\n",
    "                                     init=\"k-means++\")\n",
    "            minibatch_assignments = minibatchkmeans.fit_predict(text_features)\n",
    "            minibatch_centroids = minibatchkmeans.cluster_centers_\n",
    "            results = cluster_metrics(np.array(labels), minibatch_assignments, oracle_centroids, minibatch_centroids, minibatch_centroids)\n",
    "            \n",
    "            data_results ={'assignments':minibatch_assignments,\n",
    "                           'final_centroids':minibatch_centroids,\n",
    "                           'results':results}\n",
    "    \n",
    "            \n",
    "            results_dict[seed]['minibatchkmeans'] = data_results\n",
    "            print([year, seed, 'minibatchkmeans', results])\n",
    "        \n",
    "        #seq minibatch\n",
    "        if results_dict[seed].get('seqminibatchkmeans') is None:\n",
    "            seqminibatchKMeans = sequentialMiniBatchKmeans(text_features, \n",
    "                                                   num_clusters, \n",
    "                                                   random_state=seed, \n",
    "                                                   max_batch_size=max_batch_size)\n",
    "            seqminibatch_assignments = seqminibatchKMeans.predict(text_features)\n",
    "            seqminibatch_centroids = seqminibatchKMeans.cluster_centers_\n",
    "            results = cluster_metrics(np.array(labels), seqminibatch_assignments, oracle_centroids, seqminibatch_centroids, seqminibatch_centroids)\n",
    "     \n",
    "            data_results ={'assignments':seqminibatch_assignments,\n",
    "                           'final_centroids':seqminibatch_centroids,\n",
    "                           'results':results}\n",
    "            \n",
    "            results_dict[seed]['seqminibatchkmeans'] = data_results\n",
    "            print([year, seed, 'seqminibatchkmeans', results])\n",
    "        \n",
    "        #kmeans\n",
    "        if results_dict[seed].get('kmeans') is None:\n",
    "            kmeans = KMeans(n_clusters=num_clusters, random_state=seed)\n",
    "            kmeans_assignments = kmeans.fit_predict(text_features)\n",
    "            kmeans_centroids = kmeans.cluster_centers_\n",
    "            results = cluster_metrics(np.array(labels), kmeans_assignments, oracle_centroids, kmeans_centroids, kmeans_centroids)\n",
    "            \n",
    "            data_results ={'assignments':seqminibatch_assignments,\n",
    "                           'final_centroids':seqminibatch_centroids,\n",
    "                           'results':results}\n",
    "            \n",
    "            results_dict[seed]['kmeans'] = data_results\n",
    "            print([year, seed, 'kmeans', results])\n",
    "\n",
    "        #kLLMmeans\n",
    "        for force_context_length in [10, 50]:\n",
    "            if results_dict[seed].get(force_context_length) is None:\n",
    "                results_dict[seed][force_context_length] = {}\n",
    "                \n",
    "            for max_llm_iter in [1, 5]:\n",
    "\n",
    "                if results_dict[seed][force_context_length].get(max_llm_iter) is None:\n",
    "                    summaries, centroids = miniBatchKLLMmeans(text_data, \n",
    "                                                            num_clusters,\n",
    "                                                            max_batch_size = max_batch_size, \n",
    "                                                            init = 'k-means++',\n",
    "                                                            prompt = prompt, text_type = text_type,\n",
    "                                                            force_context_length = force_context_length, max_llm_iter = max_llm_iter, \n",
    "                                                            max_iter = 120, tol=1e-4, random_state = seed, \n",
    "                                                            emb_type = 'openai', text_features = text_features)\n",
    "\n",
    "                    kmeans2 = KMeans(n_clusters=num_clusters, init=centroids, max_iter=1)\n",
    "                    cluster_assignments = kmeans2.fit_predict(text_features)\n",
    "                    results = cluster_metrics(np.array(labels), cluster_assignments, oracle_centroids, centroids, centroids)\n",
    "                    \n",
    "                    data_results ={'assignments':cluster_assignments,\n",
    "                                   'summaries':summaries,\n",
    "                                   'final_centroids':centroids,\n",
    "                                   'results':results}\n",
    "                    \n",
    "                    results_dict[seed][force_context_length][max_llm_iter] = data_results\n",
    "                    print([year, seed, force_context_length, max_llm_iter, results])\n",
    "\n",
    "                    # Save as pkl file\n",
    "                    with open(\"results/sims_stackexchange_results_\" + str(year) + \".pkl\", \"wb\") as f:\n",
    "                        pickle.dump(results_dict, f)\n",
    "                        \n",
    "                else:\n",
    "                    results = results_dict[seed][force_context_length][max_llm_iter]['results']\n",
    "                    print([year, seed, force_context_length, max_llm_iter, results])\n",
    "        #kNLPmeans\n",
    "        for nlp_type in ['lsa', 'centroid', 'textrank']:\n",
    "            if results_dict[seed].get(nlp_type) is None:\n",
    "                results_dict[seed][nlp_type] = {}\n",
    "                \n",
    "            for max_llm_iter in [1, 5]:\n",
    "\n",
    "                if results_dict[seed][nlp_type].get(max_llm_iter) is None:\n",
    "                    summaries, centroids = miniBatchKNLPmeans(text_data, \n",
    "                                                            num_clusters,\n",
    "                                                            max_batch_size = max_batch_size, \n",
    "                                                            init = 'k-means++',\n",
    "                                                            force_context_length = 0, max_llm_iter = max_llm_iter, \n",
    "                                                            max_iter = 120, tol=1e-4, random_state = seed, \n",
    "                                                            emb_type = 'openai', text_features = text_features,\n",
    "                                                            top_k = top_k, text_sentences = text_sentences,\n",
    "                                                            nlp = nlp_type)\n",
    "                    \n",
    "                    kmeans2 = KMeans(n_clusters=num_clusters, init=centroids, max_iter=1)\n",
    "                    cluster_assignments = kmeans2.fit_predict(text_features)\n",
    "                    results = cluster_metrics(np.array(labels), cluster_assignments, oracle_centroids, centroids, centroids)\n",
    "                    \n",
    "                    data_results ={'assignments':cluster_assignments,\n",
    "                                   'summaries':summaries,\n",
    "                                   'final_centroids':centroids,\n",
    "                                   'results':results}\n",
    "                    \n",
    "                    results_dict[seed][nlp_type][max_llm_iter] = data_results\n",
    "                    print([year, seed, nlp_type, max_llm_iter, results])\n",
    "\n",
    "                    # Save as pkl file\n",
    "                    with open(\"results/sims_stackexchange_results_\" + str(year) + \".pkl\", \"wb\") as f:\n",
    "                        pickle.dump(results_dict, f)\n",
    "                        \n",
    "                else:\n",
    "                    results = results_dict[seed][nlp_type][max_llm_iter]['results']\n",
    "                    print([year, seed, nlp_type, max_llm_iter, results])\n",
    "\n",
    "                "
   ]
  },
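  {
   "cell_type": "markdown",
   "id": "b7a1c0d2-1111-4222-8333-000000000002",
   "metadata": {},
   "source": [
    "A minimal sketch for inspecting a saved checkpoint: load one year's results file and print the stored metrics for every method and seed. The `seed -> method (-> sub-key) -> 'results'` layout mirrors the dictionaries built above; the metric values themselves come from `cluster_metrics` and are not assumed here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a1c0d2-1111-4222-8333-000000000003",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: assumes the loop above has written results/sims_stackexchange_results_2023.pkl;\n",
    "# change the year to inspect a different checkpoint.\n",
    "year = 2023\n",
    "with open(\"results/sims_stackexchange_results_\" + str(year) + \".pkl\", \"rb\") as f:\n",
    "    results_dict = pickle.load(f)\n",
    "\n",
    "for seed, methods in results_dict.items():\n",
    "    for method, entry in methods.items():\n",
    "        if 'results' in entry:  # flat entries: the k-means baselines\n",
    "            print([year, seed, method, entry['results']])\n",
    "        else:  # nested entries: force_context_length or nlp_type -> max_llm_iter\n",
    "            for inner_key, inner in entry.items():\n",
    "                print([year, seed, method, inner_key, inner['results']])"
   ]
  }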
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:base] *",
   "language": "python",
   "name": "conda-base-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
