{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8d59d756",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited modules (e.g. kmeans_jax) automatically before each cell runs\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "628c26f8-8376-4d91-9c54-79aef318d3a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from sklearn import metrics as sk_metrics\n",
    "from sklearn.datasets import fetch_olivetti_faces\n",
    "\n",
    "import kmeans_jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "72358413",
   "metadata": {},
   "outputs": [],
   "source": [
    "faces = fetch_olivetti_faces()\n",
    "\n",
    "# Scale every face vector (one row of `data`) to unit L2 norm.\n",
    "data = jnp.asarray(faces.data)\n",
    "data = data / jnp.linalg.norm(data, axis=-1, keepdims=True)\n",
    "\n",
    "# Ground-truth subject ids come from the dataset itself instead of being\n",
    "# rebuilt from an assumed ordering (40 subjects x 10 consecutive images).\n",
    "true_labels = jnp.asarray(faces.target)\n",
    "assert true_labels.shape == (data.shape[0],), \"one label per image expected\"\n",
    "assert int(true_labels.max()) + 1 == 40, \"expected 40 subjects\"\n",
    "\n",
    "true_centroids = kmeans_jax.kmeans.compute_centroids(data, true_labels, 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "06d872e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Independent PRNG keys for the two KMeans fits, both derived from seed 0\n",
    "key = jax.random.key(0)\n",
    "key_lloyd, key_hartigan = jax.random.split(key, num=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2134e30c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Lloyd\n"
     ]
    }
   ],
   "source": [
    "# batch_size is the number of initializations run in parallel: it batches\n",
    "# the n_init restarts, not the data points. Use a smaller value if you run\n",
    "# out of memory.\n",
    "lloyd = kmeans_jax.KMeans(\n",
    "    n_clusters=40, n_init=500, max_iter=100, init=\"kmeans++\", algorithm=\"Lloyd\"\n",
    ")\n",
    "\n",
    "print(\"Running Lloyd\")\n",
    "# Request output=\"best\" explicitly, exactly as the Hartigan fit does, so both\n",
    "# result dicts have the same form when compared below.\n",
    "lloyd_results = lloyd.fit(key_lloyd, data, batch_size=100, output=\"best\")\n",
    "\n",
    "nmi_lloyd = sk_metrics.normalized_mutual_info_score(\n",
    "    true_labels, lloyd_results[\"labels\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e03b8c6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Hartigan\n"
     ]
    }
   ],
   "source": [
    "# Warning: this will take a couple of minutes to run.\n",
    "# Hartigan is implemented in numba rather than jax, so no batch_size is\n",
    "# passed here (it would have no effect).\n",
    "hartigan = kmeans_jax.KMeans(\n",
    "    algorithm=\"Hartigan\",\n",
    "    init=\"random partition\",\n",
    "    n_clusters=40,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")\n",
    "\n",
    "print(\"Running Hartigan\")\n",
    "hartigan_results = hartigan.fit(key_hartigan, data, output=\"best\")\n",
    "\n",
    "nmi_hartigan = sk_metrics.normalized_mutual_info_score(\n",
    "    labels_true=true_labels, labels_pred=hartigan_results[\"labels\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "90af8e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, labels_spectral, loss_spectral, _ = kmeans_jax.run_spectral_clustering(\n",
    "    data, n_clusters=40, n_init=100, random_state=0\n",
    ")\n",
    "# A plain NMI score, like nmi_lloyd / nmi_hartigan; the loss is already\n",
    "# available separately as loss_spectral.\n",
    "nmi_spectral = sk_metrics.normalized_mutual_info_score(true_labels, labels_spectral)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5bf03eea",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, labels_sdp, loss_sdp, _ = kmeans_jax.run_sdp_clustering(\n",
    "    data, n_clusters=40, max_iters=2000\n",
    ")\n",
    "# A plain NMI score, like nmi_lloyd / nmi_hartigan; the loss is already\n",
    "# available separately as loss_sdp.\n",
    "nmi_sdp = sk_metrics.normalized_mutual_info_score(true_labels, labels_sdp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c41d1c53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Lloyd: Loss = 8.53, NMI = 0.74\n",
      "Hartigan: Loss = 8.11, NMI = 0.77\n",
      "SDP: Loss = 8.85, NMI = 0.72\n",
      "Spectral: Loss = 8.95, NMI = 0.83\n"
     ]
    }
   ],
   "source": [
    "# One row per method: (name, loss, NMI against true_labels).\n",
    "results_summary = [\n",
    "    (\"Lloyd\", lloyd_results[\"loss\"], nmi_lloyd),\n",
    "    (\"Hartigan\", hartigan_results[\"loss\"], nmi_hartigan),\n",
    "    (\"SDP\", loss_sdp, sk_metrics.normalized_mutual_info_score(true_labels, labels_sdp)),\n",
    "    (\n",
    "        \"Spectral\",\n",
    "        loss_spectral,\n",
    "        sk_metrics.normalized_mutual_info_score(true_labels, labels_spectral),\n",
    "    ),\n",
    "]\n",
    "\n",
    "for name, loss, nmi in results_summary:\n",
    "    print(f\"{name}: Loss = {loss:.2f}, NMI = {nmi:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "icml_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
