{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f0cb4ad7-7e3d-4b3b-a783-7df6d98a4abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import pyarrow.parquet as pq\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.cluster import AgglomerativeClustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aa13cbfe-4ae0-4841-899f-0d20bd113850",
   "metadata": {},
   "outputs": [],
   "source": [
    "infile = \"./dataset/finekb_cases_train.parquet\"\n",
    "outfile = \"./dataset/finekb_cases_train_clustered.parquet\"\n",
    "\n",
    "# --- config: tune these if you want ---\n",
    "CASE_EMB_COL = \"embed_summary\"\n",
    "CLUSTER_COL = \"cluster_id\"\n",
    "MIN_CLUSTER_SIZE = 10      # target points per cluster\n",
    "MAX_CLUSTERS_PER_KB = 50  # upper bound on clusters per kb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0780f84d-babd-4dd3-a9be-547eb2e6e915",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load parquet file\n",
    "table = pq.read_table(infile)\n",
    "\n",
    "# Convert to pandas DataFrame\n",
    "df_cases = table.to_pandas()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e04d4a5d-3d58-4c22-815f-91bf19f54609",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>case_id</th>\n",
       "      <th>issue_type</th>\n",
       "      <th>kb_id</th>\n",
       "      <th>embed_summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>remote_access</td>\n",
       "      <td>11</td>\n",
       "      <td>[0.005332942120730877, -0.003665205556899309, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>info</td>\n",
       "      <td>102</td>\n",
       "      <td>[0.00189633306581527, 0.012942066416144371, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>memory</td>\n",
       "      <td>None</td>\n",
       "      <td>[0.00967031717300415, 0.013244512490928173, 0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>contract</td>\n",
       "      <td>35</td>\n",
       "      <td>[0.0068122390657663345, 0.0030071106739342213,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>fan</td>\n",
       "      <td>158</td>\n",
       "      <td>[0.013492287136614323, -0.013519312255084515, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   case_id     issue_type kb_id  \\\n",
       "0        0  remote_access    11   \n",
       "1        1           info   102   \n",
       "2        2         memory  None   \n",
       "3        3       contract    35   \n",
       "4        4            fan   158   \n",
       "\n",
       "                                       embed_summary  \n",
       "0  [0.005332942120730877, -0.003665205556899309, ...  \n",
       "1  [0.00189633306581527, 0.012942066416144371, 0....  \n",
       "2  [0.00967031717300415, 0.013244512490928173, 0....  \n",
       "3  [0.0068122390657663345, 0.0030071106739342213,...  \n",
       "4  [0.013492287136614323, -0.013519312255084515, ...  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load parquet file\n",
    "table = pq.read_table(infile)\n",
    "\n",
    "# Convert to pandas DataFrame\n",
    "df_cases = table.to_pandas()\n",
    "df_cases.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1af8eb6c-232e-4e13-9b2d-ed1ea1e09a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cluster_one_kb(group: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    group: rows of df for a single kb_id\n",
    "    returns: same rows with a new 'cluster_id' column\n",
    "    \"\"\"\n",
    "    kb_id = group.name\n",
    "    # turn list-of-floats embeddings into a 2D array\n",
    "    X = np.vstack(group[CASE_EMB_COL].values)  # (n_samples, dim)\n",
    "    n_samples = X.shape[0]\n",
    "\n",
    "    # if not enough samples, just put them all into one cluster\n",
    "    if n_samples <= MIN_CLUSTER_SIZE:\n",
    "        group[CLUSTER_COL] = f\"{kb_id}_c0\"\n",
    "        return group\n",
    "\n",
    "    # heuristic: about MIN_CLUSTER_SIZE points per cluster, capped by MAX_CLUSTERS_PER_KB\n",
    "    n_clusters = max(1, min(MAX_CLUSTERS_PER_KB, n_samples // MIN_CLUSTER_SIZE))\n",
    "\n",
    "    # Agglomerative clustering with cosine distance\n",
    "    agglom = AgglomerativeClustering(\n",
    "        n_clusters=n_clusters,\n",
    "        metric=\"cosine\",      # for sklearn >= 1.2; use affinity=\"cosine\" on older versions\n",
    "        linkage=\"average\",\n",
    "    )\n",
    "    labels = agglom.fit_predict(X)\n",
    "\n",
    "    # build cluster ids like \"12345_c0\", \"12345_c1\", ...\n",
    "    group[CLUSTER_COL] = [f\"{kb_id}_c{int(lbl)}\" for lbl in labels]\n",
    "\n",
    "    return group"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6a9d2534-928b-42b1-b0c4-a4087313d131",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   case_id  kb_id cluster_id\n",
      "0        0     11      11_c0\n",
      "1        1    102     102_c5\n",
      "3        3     35      35_c3\n",
      "4        4    158     158_c1\n",
      "6        6     63      63_c1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_5503/3408193478.py:13: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
      "  .apply(cluster_one_kb)\n"
     ]
    }
   ],
   "source": [
    "# ---- main call ----\n",
    "\n",
    "# (optional) drop rows with missing kb_id before clustering\n",
    "df_cases = df_cases.dropna(subset=[\"kb_id\"])\n",
    "\n",
    "# ensure kb_id is a simple type for building 'kb_id_cX' strings\n",
    "df_cases[\"kb_id\"] = df_cases[\"kb_id\"].astype(int)\n",
    "\n",
    "# apply clustering per kb_id\n",
    "df_cases = (\n",
    "    df_cases\n",
    "    .groupby(\"kb_id\", group_keys=False)\n",
    "    .apply(cluster_one_kb)\n",
    ")\n",
    "\n",
    "print(df_cases[[\"case_id\", \"kb_id\", CLUSTER_COL]].head())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dd9cb059-0ef1-44fb-aecc-13027cadd2a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved case clusters to: ./dataset/finekb_cases_train_clustered.parquet\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------------------------------\n",
    "# Save to parquet\n",
    "# ----------------------------------------------------\n",
    "\n",
    "df_cases.to_parquet(\n",
    "    outfile,\n",
    "    engine=\"pyarrow\",\n",
    "    compression=\"snappy\",\n",
    "    index=False\n",
    ")\n",
    "\n",
    "print(\"Saved case clusters to:\", outfile)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "(gpu-env) Python",
   "language": "python",
   "name": "conda-env-gpu-env-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
