{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bea2186b-5e28-4fc8-b8f3-f6146f7606f5",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a8e04f5-2f23-42ee-9682-349679ec7eea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "base = \"/research/2025_mip/\"\n",
    "sys.path.append(base)\n",
    "sys.path.append(os.path.join(base, 'forge'))\n",
    "\n",
    "import pacmap\n",
    "\n",
    "from forge.forge import Forge\n",
    "from forge.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8ef7cb5-20db-4c8d-b732-a1209f2b06cd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "5762c7c2-aa48-4473-90f2-c92adde59e05",
   "metadata": {},
   "source": [
    "# Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a97528cc-0f1e-46a1-821f-130b439c54b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('MIPEmbed/data/large_files/d_mip_processed_no_graph.pkl', 'rb') as file: \n",
    "    mip_to_dgl = pkl.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dedb5e9-c1b8-49a9-a671-d50b848a4c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Forge(prob_head=False, cut_head=False)\n",
    "model.load_model(os.path.join(base, 'models/unsupervised_model.pkl'), model_type='unsupervised')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56bff5d5-3782-46a5-a0b8-20c369d6b434",
   "metadata": {},
   "outputs": [],
   "source": [
    "mip_mat = []\n",
    "label_prob_mat = []\n",
    "mean_embed_mat = []\n",
    "color_vec = [\"-\".join(x.split('-')[:-1]) for x in mip_to_dgl]\n",
    "\n",
    "for inst in tqdm(mip_to_dgl):\n",
    "\n",
    "\n",
    "    # Process LP or MPS files or a Gurobi object into DGL format\n",
    "    g, features, num_cons, num_vars = mip_to_dgl[inst]\n",
    "\n",
    "    # Forward Pass Through GNN\n",
    "    h_list, logits, loss, distances, codebook_ = model.forward(g.to(device), features.to(device), num_cons, num_vars)\n",
    "\n",
    "    # Compute a Vector for Each MIP Instance\n",
    "    # This Vector is a Distribution of the Codes that Constraints and Variables\n",
    "    # in the MIP Instance Belong to.\n",
    "    assigned_codes = torch.argmin(distances, axis=1).detach().cpu().numpy()\n",
    "    mip_vec = np.zeros(model.codebook_size,)\n",
    "    for c in assigned_codes:\n",
    "        mip_vec[c] += 1\n",
    "\n",
    "    mip_mat.append(mip_vec)\n",
    "\n",
    "    # Optionally, Also Run Simple Label Propagation for Comparison\n",
    "    label_propagation = LabelPropagation(k = 2, alpha = 0.5, clamp = False, normalize = True)\n",
    "    label_prop_vec = label_propagation(g, features).mean(0)\n",
    "    label_prob_mat.append(label_prop_vec)\n",
    "\n",
    "    # Optionally, Also Compute a Simple Mean Readout for Comparison\n",
    "    mean_embed_mat.append(h_list[0].mean(0).detach().cpu().numpy())\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c42968-17ea-4955-890a-58031c2f4b92",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "131078db-de76-4ac5-a240-46515847ba40",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "base_dir = '../data/'\n",
    "sub_dir = 'miplib_hard_filtered'\n",
    "instances = os.listdir(os.path.join(base_dir, sub_dir))\n",
    "\n",
    "mip_mat = []\n",
    "cons_mat = []\n",
    "var_mat = []\n",
    "mip_dist_mat = []\n",
    "label_prob_mat = []\n",
    "mean_embed_mat = []\n",
    "\n",
    "color_vec = []\n",
    "\n",
    "for inst in tqdm(instances): \n",
    "    \n",
    "    if 'ipynb' in inst:\n",
    "        continue\n",
    "\n",
    "    color_vec.append(os.path.join(sub_dir, inst))\n",
    "    \n",
    "    # Process LP or MPS file into DGL format\n",
    "    g, features, num_cons, num_vars = generate_mip_graph(os.path.join(base_dir, sub_dir, inst), graph_features = False, gurobi_object = False)\n",
    "\n",
    "    # Forward Pass Through GNN\n",
    "    h_list, logits, loss, distances, codebook_ = model.forward(g.to(device), features.to(device), num_cons, num_vars)\n",
    "\n",
    "    # Compute a Vector for Each MIP Instance\n",
    "    # This Vector is a Distribution of the Codes that Constraints and Variables \n",
    "    # in the MIP Instance Belong to. \n",
    "    mip_vec = torch.argmin(distances, axis = 1).detach().cpu().numpy()\n",
    "    mip_mat.append(mip_vec)\n",
    "\n",
    "    # Optionally, One Can Also Generate a Distribution of Distributions by Considering\n",
    "    # the Top k Minimum Distances Instead of Just the Minimum\n",
    "    arg_matrix = np.argpartition(distances.detach().cpu().numpy(), kth = 2, axis = 1)\n",
    "    a1 = arg_matrix[:, 0]\n",
    "    a2 = arg_matrix[:, 1]\n",
    "    a3 = arg_matrix[:, 2]\n",
    "\n",
    "    # Optionally, One Can Also Generate Two Vectors\n",
    "    # For Each MIP Instance. One for Variable Distribution\n",
    "    # And One for Constraint Distribution\n",
    "    cons_mat.append(mip_vec[:num_cons])\n",
    "    var_mat.append(mip_vec[num_cons:])\n",
    "\n",
    "    # Generate the Distributions\n",
    "    h = np.zeros(5000,)\n",
    "    for (c1, c2, c3) in zip(a1, a2, a3):\n",
    "        h[c1] += 1\n",
    "        h[c2] += 1\n",
    "        h[c3] += 1\n",
    "    mip_dist_mat.append(h)\n",
    "\n",
    "    # Optionally, Also Run Simple Label Propagation for Comparison\n",
    "    label_propagation = LabelPropagation(k = 2, alpha = 0.5, clamp = False, normalize = True)\n",
    "    label_prop_vec = label_propagation(g, features).mean(0)\n",
    "    label_prob_mat.append(label_prop_vec)\n",
    "\n",
    "    # Optionally, Also Compute a Simple Mean Readout for Comparison\n",
    "    mean_embed_mat.append(h_list[0].mean(0).detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c92b14b-6b29-4f0c-b216-bb087f2ed500",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69a6fb5-d405-440a-a7a1-fe9dd7d32f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute Distribution from Lists Computed In Previous Cell\n",
    "hist_mat = np.zeros((len(mip_mat), 5000))\n",
    "hist_cons_mat = np.zeros((len(cons_mat), 5000))\n",
    "hist_var_mat = np.zeros((len(var_mat), 5000))\n",
    "\n",
    "for idx in range(len(mip_mat)):\n",
    "    vec = mip_mat[idx]\n",
    "    for col in vec:\n",
    "        hist_mat[idx][col] += 1\n",
    "\n",
    "for idx in range(len(cons_mat)):\n",
    "    vec = cons_mat[idx]\n",
    "    for col in vec:\n",
    "        hist_cons_mat[idx][col] += 1\n",
    "\n",
    "for idx in range(len(var_mat)):\n",
    "    vec = var_mat[idx]\n",
    "    for col in vec:\n",
    "        hist_var_mat[idx][col] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d04486f-82fa-4a99-81a1-121bf79040dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82d164c8-7643-4726-9630-15a80edf440a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualization\n",
    "matrix_to_visualize = mean_embed_mat\n",
    "\n",
    "# Compute PCA to visualize in 2D\n",
    "pca = pacmap.PaCMAP(n_components = 2, n_neighbors=10, MN_ratio=0.5, FP_ratio=2.0).fit_transform(matrix_to_visualize, init = 'pca')\n",
    "# pca = PCA(n_components = 2).fit_transform(matrix_to_visualize)\n",
    "pca = (pca - np.min(pca)) / np.ptp(pca)\n",
    "\n",
    "# Compute an index dict to make it easier \n",
    "# to plot different labels and colors\n",
    "index_dict = {}\n",
    "for idx, c in enumerate(color_vec):\n",
    "    try:\n",
    "        index_dict[c].append(idx)\n",
    "    except:\n",
    "        index_dict[c] = [idx]\n",
    "\n",
    "colors = [cm(1.*i/21) for i in range(21)]\n",
    "cmap = plt.get_cmap('tab20')\n",
    "# Plot Instances \n",
    "plt.figure(figsize = (5.5, 5.5), dpi = 300)\n",
    "for idx, c in enumerate(sorted(index_dict)):\n",
    "    plt.scatter(pca[index_dict[c], 0] + np.random.rand()/1000, pca[index_dict[c], 1] + np.random.rand()/1000, s = 10, label = c, color = cmap(idx))\n",
    "plt.legend(loc = 'right', ncols = 1, bbox_to_anchor = (1.45, .5))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d59da07e-535a-4389-bcc8-4000305b8efa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5d7d26e-ab86-4c9b-9f8a-654844d4b119",
   "metadata": {},
   "outputs": [],
   "source": [
    "distance = np.zeros((18, 18))\n",
    "\n",
    "for i in range(18):\n",
    "    for j in range(18):\n",
    "        distance[i][j] = np.linalg.norm(matrix_to_visualize[i] - matrix_to_visualize[j])\n",
    "\n",
    "plt.imshow(distance)\n",
    "plt.xticks(list(range(18)), color_vec, rotation = 90)\n",
    "plt.yticks(list(range(18)), color_vec, rotation = 0)\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cf6fac6-7d89-4d8d-ad93-7960017814e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score\n",
    "from sklearn.metrics.cluster import contingency_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6878a916-83c0-4414-b489-9f10294f185b",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_clusters = 21\n",
    "res_nmi = {'dist' : [], 'label' : [], 'mean' : []}\n",
    "res_acc = {'dist' : [], 'label' : [], 'mean' : []}\n",
    "\n",
    "for _ in tqdm(range(10)):\n",
    "    \n",
    "    km_dist = KMeans(n_clusters = num_clusters).fit(mip_mat)\n",
    "    km_label = KMeans(n_clusters = num_clusters).fit(label_prob_mat)\n",
    "    km_mean = KMeans(n_clusters = num_clusters).fit(mean_embed_mat)\n",
    "\n",
    "    km_type = {'dist' : km_dist, 'label' : km_label, 'mean' : km_mean}\n",
    "\n",
    "    for t in ['dist', 'label', 'mean']:\n",
    "        true = color_vec\n",
    "        pred = km_type[t].labels_\n",
    "        \n",
    "        acc = {c : 0 for c in range(num_clusters)}\n",
    "        names = {c : None for c in range(num_clusters)}\n",
    "        for c in range(num_clusters):\n",
    "            indices = np.where(pred == c)[0]\n",
    "            acc[c] = purity_score(np.array(true)[indices], pred[indices])\n",
    "            names[c] = set(np.array(true)[indices])\n",
    "        \n",
    "        nmi_score = normalized_mutual_info_score(true, pred)\n",
    "        mean_acc = np.mean(list(acc.values()))\n",
    "        res_acc[t].append(mean_acc)\n",
    "        res_nmi[t].append(nmi_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c14416a3-c32f-454a-b85f-72090708bda2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpu",
   "language": "python",
   "name": "gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
