{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "163d3d2c-c982-4edb-8dac-14b35da72885",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.69 $\\pm$ 0.06\n",
      "0.68 $\\pm$ 0.16\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import f1_score, adjusted_rand_score, normalized_mutual_info_score\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "import numpy as np\n",
    "\n",
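    "# Features: full *_recon.npy (NOTE(review): this cell was labelled 'latent' but loads the _recon file; confirm which is intended)\n",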
    "\n",
    "# NOTE(review): an identical copy of hungarian_match is defined in each cell of this notebook;\n",
    "# consider defining it once in a shared cell or module.\n",
    "def hungarian_match(true_labels, pred_labels):\n",
    "    \"\"\"Relabel cluster ids with the class they overlap most (Hungarian matching).\n",
    "\n",
    "    Builds a class-by-cluster contingency table and solves the one-to-one\n",
    "    assignment that maximises the number of samples whose cluster is mapped\n",
    "    to their own class.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    true_labels, pred_labels : array-like, 1-D, same length\n",
    "        Ground-truth labels and predicted cluster ids. Any sortable values are\n",
    "        accepted (negative, non-contiguous, float-typed, strings), not only 0..K-1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        ``pred_labels`` translated into the label space of ``true_labels``.\n",
    "        If there are more clusters than classes, the unmatched clusters get\n",
    "        distinct fresh labels ``max(true_labels) + 1, + 2, ...`` (this case\n",
    "        requires numeric ``true_labels``) so they never count as correct.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the two inputs have different shapes.\n",
    "    \"\"\"\n",
    "    true_labels = np.asarray(true_labels)\n",
    "    pred_labels = np.asarray(pred_labels)\n",
    "    if true_labels.shape != pred_labels.shape:\n",
    "        raise ValueError(\n",
    "            f\"true_labels and pred_labels differ in shape: \"\n",
    "            f\"{true_labels.shape} vs {pred_labels.shape}\"\n",
    "        )\n",
    "    if true_labels.size == 0:\n",
    "        return np.empty(0, dtype=np.int64)\n",
    "\n",
    "    # Encode label values as dense indices 0..k-1 before using them as table\n",
    "    # indices: raw negative labels would silently wrap around, and float labels\n",
    "    # cannot index an array at all.\n",
    "    true_vals, true_idx = np.unique(true_labels, return_inverse=True)\n",
    "    pred_vals, pred_idx = np.unique(pred_labels, return_inverse=True)\n",
    "    n_true = len(true_vals)\n",
    "    n = max(n_true, len(pred_vals))\n",
    "\n",
    "    # Square table (rows = classes padded with empty rows, cols = clusters) so\n",
    "    # that every cluster is assigned some row by the solver.\n",
    "    cm = np.zeros((n, n), dtype=np.int64)\n",
    "    np.add.at(cm, (true_idx, pred_idx), 1)\n",
    "    row_ind, col_ind = linear_sum_assignment(-cm)  # solver minimises cost, so negate\n",
    "\n",
    "    # Output label of each row: its class value, or a fresh value for padding rows.\n",
    "    row_labels = true_vals\n",
    "    if n > n_true:\n",
    "        fresh = true_vals.max() + 1 + np.arange(n - n_true)\n",
    "        row_labels = np.concatenate([true_vals, fresh])\n",
    "\n",
    "    cluster_to_label = np.empty(n, dtype=row_labels.dtype)\n",
    "    cluster_to_label[col_ind] = row_labels[row_ind]\n",
    "    return cluster_to_label[pred_idx]\n",
    "\n",
    "# Metrics to report for each file, in table-column order; any of 'f1', 'ari', 'nmi'.\n",
    "REPORT_METRICS = (\"f1\",)\n",
    "\n",
    "file_list = ['aeqint8']\n",
    "\n",
    "for choice in range(0, 2):\n",
    "    row_parts = []\n",
    "    for file_tag in file_list:\n",
    "        scores = {\"f1\": [], \"ari\": [], \"nmi\": []}\n",
    "\n",
    "        for seed in range(0, 5):\n",
    "            # Inputs live at ./<choice><seed><file_tag>_<suffix>.npy\n",
    "            prefix = f\"./{choice}{seed}{file_tag}\"\n",
    "            features = np.load(prefix + \"_recon.npy\")\n",
    "            gt_class = np.load(prefix + \"_cl.npy\")\n",
    "            assert len(features) == len(gt_class), f\"{prefix}: feature and label counts differ\"\n",
    "\n",
    "            kmeans = KMeans(n_clusters=len(np.unique(gt_class)), random_state=42)\n",
    "            clusters = kmeans.fit_predict(features)\n",
    "\n",
    "            # Map cluster ids onto class ids so weighted F1 is meaningful;\n",
    "            # ARI and NMI do not depend on how clusters are numbered.\n",
    "            aligned = hungarian_match(gt_class, clusters)\n",
    "\n",
    "            scores[\"f1\"].append(f1_score(gt_class, aligned, average=\"weighted\"))\n",
    "            scores[\"ari\"].append(adjusted_rand_score(gt_class, aligned))\n",
    "            scores[\"nmi\"].append(normalized_mutual_info_score(gt_class, aligned))\n",
    "\n",
    "        for metric in REPORT_METRICS:\n",
    "            values = np.array(scores[metric])\n",
    "            # LaTeX table cell: mean +/- std over seeds (np.std, i.e. ddof=0).\n",
    "            row_parts.append(f\"{values.mean():.2f} $\\\\pm$ {values.std():.2f}\")\n",
    "\n",
    "    print(\" & \".join(row_parts))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87819e4f-762c-4f16-bc51-405df3016453",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.55 $\\pm$ 0.12\n",
      "0.38 $\\pm$ 0.13\n"
     ]
    }
   ],
   "source": [
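    "# Features: *_recon.npy restricted to columns 32:96 (NOTE(review): was labelled 'recon full wave', but the column slice suggests the ROI variant; confirm)\n",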
    "\n",
    "\n",
    "# NOTE(review): an identical copy of hungarian_match is defined in each cell of this notebook;\n",
    "# consider defining it once in a shared cell or module.\n",
    "def hungarian_match(true_labels, pred_labels):\n",
    "    \"\"\"Relabel cluster ids with the class they overlap most (Hungarian matching).\n",
    "\n",
    "    Builds a class-by-cluster contingency table and solves the one-to-one\n",
    "    assignment that maximises the number of samples whose cluster is mapped\n",
    "    to their own class.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    true_labels, pred_labels : array-like, 1-D, same length\n",
    "        Ground-truth labels and predicted cluster ids. Any sortable values are\n",
    "        accepted (negative, non-contiguous, float-typed, strings), not only 0..K-1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        ``pred_labels`` translated into the label space of ``true_labels``.\n",
    "        If there are more clusters than classes, the unmatched clusters get\n",
    "        distinct fresh labels ``max(true_labels) + 1, + 2, ...`` (this case\n",
    "        requires numeric ``true_labels``) so they never count as correct.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the two inputs have different shapes.\n",
    "    \"\"\"\n",
    "    true_labels = np.asarray(true_labels)\n",
    "    pred_labels = np.asarray(pred_labels)\n",
    "    if true_labels.shape != pred_labels.shape:\n",
    "        raise ValueError(\n",
    "            f\"true_labels and pred_labels differ in shape: \"\n",
    "            f\"{true_labels.shape} vs {pred_labels.shape}\"\n",
    "        )\n",
    "    if true_labels.size == 0:\n",
    "        return np.empty(0, dtype=np.int64)\n",
    "\n",
    "    # Encode label values as dense indices 0..k-1 before using them as table\n",
    "    # indices: raw negative labels would silently wrap around, and float labels\n",
    "    # cannot index an array at all.\n",
    "    true_vals, true_idx = np.unique(true_labels, return_inverse=True)\n",
    "    pred_vals, pred_idx = np.unique(pred_labels, return_inverse=True)\n",
    "    n_true = len(true_vals)\n",
    "    n = max(n_true, len(pred_vals))\n",
    "\n",
    "    # Square table (rows = classes padded with empty rows, cols = clusters) so\n",
    "    # that every cluster is assigned some row by the solver.\n",
    "    cm = np.zeros((n, n), dtype=np.int64)\n",
    "    np.add.at(cm, (true_idx, pred_idx), 1)\n",
    "    row_ind, col_ind = linear_sum_assignment(-cm)  # solver minimises cost, so negate\n",
    "\n",
    "    # Output label of each row: its class value, or a fresh value for padding rows.\n",
    "    row_labels = true_vals\n",
    "    if n > n_true:\n",
    "        fresh = true_vals.max() + 1 + np.arange(n - n_true)\n",
    "        row_labels = np.concatenate([true_vals, fresh])\n",
    "\n",
    "    cluster_to_label = np.empty(n, dtype=row_labels.dtype)\n",
    "    cluster_to_label[col_ind] = row_labels[row_ind]\n",
    "    return cluster_to_label[pred_idx]\n",
    "\n",
    "# Metrics to report for each file, in table-column order; any of 'f1', 'ari', 'nmi'.\n",
    "REPORT_METRICS = (\"nmi\",)\n",
    "# Column window kept from each reconstruction.\n",
    "# NOTE(review): presumably the ROI part of the wave; confirm.\n",
    "FEATURE_COLUMNS = slice(32, 96)\n",
    "\n",
    "file_list = ['aeqint8']\n",
    "\n",
    "for choice in range(0, 2):\n",
    "    row_parts = []\n",
    "    for file_tag in file_list:\n",
    "        scores = {\"f1\": [], \"ari\": [], \"nmi\": []}\n",
    "\n",
    "        for seed in range(0, 5):\n",
    "            # Inputs live at ./<choice><seed><file_tag>_<suffix>.npy\n",
    "            prefix = f\"./{choice}{seed}{file_tag}\"\n",
    "            features = np.load(prefix + \"_recon.npy\")[:, FEATURE_COLUMNS]\n",
    "            gt_class = np.load(prefix + \"_cl.npy\")\n",
    "            assert len(features) == len(gt_class), f\"{prefix}: feature and label counts differ\"\n",
    "\n",
    "            kmeans = KMeans(n_clusters=len(np.unique(gt_class)), random_state=42)\n",
    "            clusters = kmeans.fit_predict(features)\n",
    "\n",
    "            # Map cluster ids onto class ids so weighted F1 is meaningful;\n",
    "            # ARI and NMI do not depend on how clusters are numbered.\n",
    "            aligned = hungarian_match(gt_class, clusters)\n",
    "\n",
    "            scores[\"f1\"].append(f1_score(gt_class, aligned, average=\"weighted\"))\n",
    "            scores[\"ari\"].append(adjusted_rand_score(gt_class, aligned))\n",
    "            scores[\"nmi\"].append(normalized_mutual_info_score(gt_class, aligned))\n",
    "\n",
    "        for metric in REPORT_METRICS:\n",
    "            values = np.array(scores[metric])\n",
    "            # LaTeX table cell: mean +/- std over seeds (np.std, i.e. ddof=0).\n",
    "            row_parts.append(f\"{values.mean():.2f} $\\\\pm$ {values.std():.2f}\")\n",
    "\n",
    "    print(\" & \".join(row_parts))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a10c2d6-bb7c-43ef-a279-59c077ef42c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.72 $\\pm$ 0.09\n",
      "0.74 $\\pm$ 0.07\n"
     ]
    }
   ],
   "source": [
    "# Features: *_latent.npy (NOTE(review): was labelled 'recon roi wave' but this cell loads the latent file; confirm)\n",
    "\n",
    "# NOTE(review): an identical copy of hungarian_match is defined in each cell of this notebook;\n",
    "# consider defining it once in a shared cell or module.\n",
    "def hungarian_match(true_labels, pred_labels):\n",
    "    \"\"\"Relabel cluster ids with the class they overlap most (Hungarian matching).\n",
    "\n",
    "    Builds a class-by-cluster contingency table and solves the one-to-one\n",
    "    assignment that maximises the number of samples whose cluster is mapped\n",
    "    to their own class.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    true_labels, pred_labels : array-like, 1-D, same length\n",
    "        Ground-truth labels and predicted cluster ids. Any sortable values are\n",
    "        accepted (negative, non-contiguous, float-typed, strings), not only 0..K-1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        ``pred_labels`` translated into the label space of ``true_labels``.\n",
    "        If there are more clusters than classes, the unmatched clusters get\n",
    "        distinct fresh labels ``max(true_labels) + 1, + 2, ...`` (this case\n",
    "        requires numeric ``true_labels``) so they never count as correct.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the two inputs have different shapes.\n",
    "    \"\"\"\n",
    "    true_labels = np.asarray(true_labels)\n",
    "    pred_labels = np.asarray(pred_labels)\n",
    "    if true_labels.shape != pred_labels.shape:\n",
    "        raise ValueError(\n",
    "            f\"true_labels and pred_labels differ in shape: \"\n",
    "            f\"{true_labels.shape} vs {pred_labels.shape}\"\n",
    "        )\n",
    "    if true_labels.size == 0:\n",
    "        return np.empty(0, dtype=np.int64)\n",
    "\n",
    "    # Encode label values as dense indices 0..k-1 before using them as table\n",
    "    # indices: raw negative labels would silently wrap around, and float labels\n",
    "    # cannot index an array at all.\n",
    "    true_vals, true_idx = np.unique(true_labels, return_inverse=True)\n",
    "    pred_vals, pred_idx = np.unique(pred_labels, return_inverse=True)\n",
    "    n_true = len(true_vals)\n",
    "    n = max(n_true, len(pred_vals))\n",
    "\n",
    "    # Square table (rows = classes padded with empty rows, cols = clusters) so\n",
    "    # that every cluster is assigned some row by the solver.\n",
    "    cm = np.zeros((n, n), dtype=np.int64)\n",
    "    np.add.at(cm, (true_idx, pred_idx), 1)\n",
    "    row_ind, col_ind = linear_sum_assignment(-cm)  # solver minimises cost, so negate\n",
    "\n",
    "    # Output label of each row: its class value, or a fresh value for padding rows.\n",
    "    row_labels = true_vals\n",
    "    if n > n_true:\n",
    "        fresh = true_vals.max() + 1 + np.arange(n - n_true)\n",
    "        row_labels = np.concatenate([true_vals, fresh])\n",
    "\n",
    "    cluster_to_label = np.empty(n, dtype=row_labels.dtype)\n",
    "    cluster_to_label[col_ind] = row_labels[row_ind]\n",
    "    return cluster_to_label[pred_idx]\n",
    "\n",
    "# Metrics to report for each file, in table-column order; any of 'f1', 'ari', 'nmi'.\n",
    "REPORT_METRICS = (\"f1\",)\n",
    "\n",
    "file_list = ['aeqint8']\n",
    "\n",
    "for choice in range(0, 2):\n",
    "    row_parts = []\n",
    "    for file_tag in file_list:\n",
    "        scores = {\"f1\": [], \"ari\": [], \"nmi\": []}\n",
    "\n",
    "        for seed in range(0, 5):\n",
    "            # Inputs live at ./<choice><seed><file_tag>_<suffix>.npy\n",
    "            prefix = f\"./{choice}{seed}{file_tag}\"\n",
    "            features = np.load(prefix + \"_latent.npy\")\n",
    "            gt_class = np.load(prefix + \"_cl.npy\")\n",
    "            assert len(features) == len(gt_class), f\"{prefix}: feature and label counts differ\"\n",
    "\n",
    "            kmeans = KMeans(n_clusters=len(np.unique(gt_class)), random_state=42)\n",
    "            clusters = kmeans.fit_predict(features)\n",
    "\n",
    "            # Map cluster ids onto class ids so weighted F1 is meaningful;\n",
    "            # ARI and NMI do not depend on how clusters are numbered.\n",
    "            aligned = hungarian_match(gt_class, clusters)\n",
    "\n",
    "            scores[\"f1\"].append(f1_score(gt_class, aligned, average=\"weighted\"))\n",
    "            scores[\"ari\"].append(adjusted_rand_score(gt_class, aligned))\n",
    "            scores[\"nmi\"].append(normalized_mutual_info_score(gt_class, aligned))\n",
    "\n",
    "        for metric in REPORT_METRICS:\n",
    "            values = np.array(scores[metric])\n",
    "            # LaTeX table cell: mean +/- std over seeds (np.std, i.e. ddof=0).\n",
    "            row_parts.append(f\"{values.mean():.2f} $\\\\pm$ {values.std():.2f}\")\n",
    "\n",
    "    print(\" & \".join(row_parts))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86f1834-6d80-4a16-8b37-c6af24203c47",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
