{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "60d50d30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple, Optional\n",
    "\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "import numpy as np\n",
    "import jax\n",
    "import scipy.io as sio\n",
    "from sklearn import metrics as sk_metrics\n",
    "from sklearn.datasets import (\n",
    "    fetch_20newsgroups,\n",
    ")\n",
    "\n",
    "import kmeans_jax\n",
    "\n",
    "# Each experiment draws n_classes * DATA_PER_CLUSTER documents uniformly at\n",
    "# random (not stratified), so per-class counts only average DATA_PER_CLUSTER.\n",
    "DATA_PER_CLUSTER: int = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "172ba6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_20newsgroups(\n",
    "    n_samples: Optional[int] = None,\n",
    "    n_classes: Optional[int] = None,\n",
    "    max_features: int = 5000,\n",
    "    random_state: int = 42,\n",
    ") -> Tuple[np.ndarray, np.ndarray]:\n",
    "    \"\"\"\n",
    "    Load 20 Newsgroups text dataset with TF-IDF features.\n",
    "\n",
    "    The vocabulary and IDF weights are fit on every document of the selected\n",
    "    categories, before any subsampling.\n",
    "\n",
    "    Args:\n",
    "        n_samples: If given and smaller than the number of usable documents,\n",
    "            draw this many documents uniformly at random without replacement\n",
    "            (not stratified by class). ``None`` keeps every document.\n",
    "        n_classes: Number of categories to keep, taken in order from a fixed\n",
    "            list of 12 categories. ``None`` or a value >= 20 loads all 20.\n",
    "        max_features: Vocabulary size passed to ``TfidfVectorizer``.\n",
    "        random_state: Seed for the document-subsampling RNG.\n",
    "\n",
    "    Returns:\n",
    "        ``(X, y)``: dense float64 TF-IDF matrix with one row per kept document,\n",
    "        and the matching ``newsgroups.target`` labels.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``n_classes`` is below 1 or between 13 and 19; the\n",
    "            fixed category list only provides up to 12 categories.\n",
    "    \"\"\"\n",
    "    print(\"Loading 20 Newsgroups...\")\n",
    "\n",
    "    # Select subset of categories if needed\n",
    "    categories = None\n",
    "    if n_classes is not None and n_classes < 20:\n",
    "        all_categories = [\n",
    "            \"alt.atheism\",\n",
    "            \"comp.graphics\",\n",
    "            \"comp.windows.x\",\n",
    "            \"misc.forsale\",\n",
    "            \"rec.autos\",\n",
    "            \"rec.motorcycles\",\n",
    "            \"rec.sport.baseball\",\n",
    "            \"sci.crypt\",\n",
    "            \"sci.med\",\n",
    "            \"sci.space\",\n",
    "            \"soc.religion.christian\",\n",
    "            \"talk.politics.guns\",\n",
    "        ]\n",
    "        # Slicing would silently return only 12 categories for 13..19 and a\n",
    "        # truncated list for negative values, so reject those explicitly.\n",
    "        if not 1 <= n_classes <= len(all_categories):\n",
    "            raise ValueError(\n",
    "                f\"n_classes must be between 1 and {len(all_categories)} \"\n",
    "                f\"(or None / >= 20 for all categories), got {n_classes}\"\n",
    "            )\n",
    "        categories = all_categories[:n_classes]\n",
    "\n",
    "    newsgroups = fetch_20newsgroups(\n",
    "        subset=\"all\", categories=categories, remove=(\"headers\", \"footers\", \"quotes\")\n",
    "    )\n",
    "\n",
    "    # TF-IDF vectorization. Keep the matrix sparse until after filtering and\n",
    "    # subsampling: the dense form is n_documents x max_features float64, which\n",
    "    # is wasteful to materialise for the whole corpus when only a few hundred\n",
    "    # rows are returned.\n",
    "    vectorizer = TfidfVectorizer(max_features=max_features, stop_words=\"english\")\n",
    "    docs_sparse = vectorizer.fit_transform(newsgroups.data).tocsr()\n",
    "    y = newsgroups.target\n",
    "\n",
    "    # Filter out documents with a zero TF-IDF row (no in-vocabulary terms left\n",
    "    # after header/footer/quote and stop-word removal).\n",
    "    norms = np.sqrt(np.asarray(docs_sparse.multiply(docs_sparse).sum(axis=1)).ravel())\n",
    "    nonzero_indices = np.flatnonzero(norms > 0)\n",
    "    docs_sparse = docs_sparse[nonzero_indices]\n",
    "    y = y[nonzero_indices]\n",
    "\n",
    "    rng = np.random.default_rng(random_state)\n",
    "\n",
    "    # len() is ambiguous on scipy sparse matrices, so use the row count.\n",
    "    n_docs = docs_sparse.shape[0]\n",
    "    if n_samples is not None and n_samples < n_docs:\n",
    "        indices = rng.choice(n_docs, size=n_samples, replace=False)\n",
    "        docs_sparse, y = docs_sparse[indices], y[indices]\n",
    "\n",
    "    return docs_sparse.toarray().astype(np.float64), y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "543774e3",
   "metadata": {},
   "source": [
    "# 20NG-A, K = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "abb3e8b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 20 Newsgroups...\n"
     ]
    }
   ],
   "source": [
    "# 2 categories, 2 * DATA_PER_CLUSTER documents sampled in total\n",
    "ng20_a_data, ng20_a_labels = load_20newsgroups(\n",
    "    n_classes=2,\n",
    "    n_samples=2 * DATA_PER_CLUSTER,\n",
    "    random_state=0,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "892ac62b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lloyd is seeded with k-means++; Hartigan starts from a random partition.\n",
    "lloyd = kmeans_jax.KMeans(\n",
    "    algorithm=\"Lloyd\",\n",
    "    init=\"kmeans++\",\n",
    "    n_clusters=2,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")\n",
    "\n",
    "hartigan = kmeans_jax.KMeans(\n",
    "    algorithm=\"Hartigan\",\n",
    "    init=\"random partition\",\n",
    "    n_clusters=2,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "110e0148",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Lloyd\n",
      "Running Hartigan\n"
     ]
    }
   ],
   "source": [
    "# batch_size = how many of the n_init initializations run in parallel;\n",
    "# it batches restarts, not data points.\n",
    "key_init = jax.random.key(0)\n",
    "\n",
    "print(\"Running Lloyd\")\n",
    "lloyd_results = lloyd.fit(key_init, ng20_a_data, batch_size=100)\n",
    "\n",
    "# Hartigan is implemented in numba rather than JAX, so batch_size has no\n",
    "# effect on it.\n",
    "print(\"Running Hartigan\")\n",
    "hartigan_results = hartigan.fit(key_init, np.asarray(ng20_a_data), batch_size=100)\n",
    "\n",
    "nmi_lloyd, nmi_hartigan = (\n",
    "    sk_metrics.normalized_mutual_info_score(ng20_a_labels, fit_result[\"labels\"])\n",
    "    for fit_result in (lloyd_results, hartigan_results)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f37b2813",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Spectral Clustering\n"
     ]
    }
   ],
   "source": [
    "print(\"Running Spectral Clustering\")\n",
    "# Only the labels (2nd) and loss (3rd) of the returned 4-tuple are used.\n",
    "_, labels_spectral, loss_spectral, _ = kmeans_jax.run_spectral_clustering(\n",
    "    ng20_a_data,\n",
    "    n_clusters=2,\n",
    "    n_init=500,\n",
    "    normalizes_data=False,\n",
    "    random_state=0,\n",
    ")\n",
    "nmi_spectral = sk_metrics.normalized_mutual_info_score(\n",
    "    ng20_a_labels, labels_spectral\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1bd82d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The solver's own loss (3rd return value) is discarded; the k-means loss is\n",
    "# recomputed below from the returned centroids and labels.\n",
    "centroids_sdp, labels_sdp, _, _ = kmeans_jax.run_sdp_clustering(\n",
    "    ng20_a_data, n_clusters=2, max_iters=2000\n",
    ")\n",
    "\n",
    "# NOTE(review): the K=10 experiment derives centroids with\n",
    "# kmeans_jax.kmeans.compute_centroids(data, labels, k); confirm the solver's\n",
    "# centroids are the per-cluster means so the reported losses are comparable.\n",
    "loss_sdp = kmeans_jax.kmeans.compute_loss(ng20_a_data, centroids_sdp, labels_sdp)\n",
    "nmi_sdp = sk_metrics.normalized_mutual_info_score(ng20_a_labels, labels_sdp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2ca778cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Lloyd: Loss = 193.72, NMI = 0.27\n",
      "Hartigan: Loss = 193.46, NMI = 0.54\n",
      "SDP: Loss = 193.54, NMI = 0.48\n",
      "Spectral: Loss = 193.48, NMI = 0.52\n"
     ]
    }
   ],
   "source": [
    "# One line per method: final k-means loss and NMI against the true labels.\n",
    "print(\n",
    "    f\"Lloyd: Loss = {lloyd_results['loss']:.2f}, NMI = {nmi_lloyd:.2f}\",\n",
    "    f\"Hartigan: Loss = {hartigan_results['loss']:.2f}, NMI = {nmi_hartigan:.2f}\",\n",
    "    f\"SDP: Loss = {loss_sdp:.2f}, NMI = {nmi_sdp:.2f}\",\n",
    "    f\"Spectral: Loss = {loss_spectral:.2f}, NMI = {nmi_spectral:.2f}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2242fd64",
   "metadata": {},
   "source": [
    "# 20NG-B, K = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "acf0da7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 20 Newsgroups...\n"
     ]
    }
   ],
   "source": [
    "# 5 categories, 5 * DATA_PER_CLUSTER documents sampled in total\n",
    "ng20_b_data, ng20_b_labels = load_20newsgroups(\n",
    "    n_classes=5,\n",
    "    n_samples=5 * DATA_PER_CLUSTER,\n",
    "    random_state=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7e30315a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lloyd is seeded with k-means++; Hartigan starts from a random partition.\n",
    "lloyd = kmeans_jax.KMeans(\n",
    "    algorithm=\"Lloyd\",\n",
    "    init=\"kmeans++\",\n",
    "    n_clusters=5,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")\n",
    "\n",
    "hartigan = kmeans_jax.KMeans(\n",
    "    algorithm=\"Hartigan\",\n",
    "    init=\"random partition\",\n",
    "    n_clusters=5,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "98dc97fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Lloyd\n",
      "Running Hartigan\n"
     ]
    }
   ],
   "source": [
    "# batch_size = how many of the n_init initializations run in parallel (it\n",
    "# batches restarts, not data points); the numba-based Hartigan ignores it.\n",
    "key_init = jax.random.key(0)\n",
    "\n",
    "print(\"Running Lloyd\")\n",
    "lloyd_results = lloyd.fit(key_init, ng20_b_data, batch_size=100)\n",
    "\n",
    "print(\"Running Hartigan\")\n",
    "hartigan_results = hartigan.fit(key_init, np.asarray(ng20_b_data), batch_size=100)\n",
    "\n",
    "nmi_lloyd, nmi_hartigan = (\n",
    "    sk_metrics.normalized_mutual_info_score(ng20_b_labels, fit_result[\"labels\"])\n",
    "    for fit_result in (lloyd_results, hartigan_results)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ba950d56",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Spectral Clustering\n"
     ]
    }
   ],
   "source": [
    "print(\"Running Spectral Clustering\")\n",
    "# Only the labels (2nd) and loss (3rd) of the returned 4-tuple are used.\n",
    "_, labels_spectral, loss_spectral, _ = kmeans_jax.run_spectral_clustering(\n",
    "    ng20_b_data,\n",
    "    n_clusters=5,\n",
    "    n_init=500,\n",
    "    normalizes_data=False,\n",
    "    random_state=0,\n",
    ")\n",
    "nmi_spectral = sk_metrics.normalized_mutual_info_score(\n",
    "    ng20_b_labels, labels_spectral\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fecf3e50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The solver's own loss (3rd return value) is discarded; the k-means loss is\n",
    "# recomputed below from the returned centroids and labels.\n",
    "centroids_sdp, labels_sdp, _, _ = kmeans_jax.run_sdp_clustering(\n",
    "    ng20_b_data, n_clusters=5, max_iters=2000\n",
    ")\n",
    "\n",
    "# NOTE(review): the K=10 experiment derives centroids with\n",
    "# kmeans_jax.kmeans.compute_centroids(data, labels, k); confirm the solver's\n",
    "# centroids are the per-cluster means so the reported losses are comparable.\n",
    "loss_sdp = kmeans_jax.kmeans.compute_loss(ng20_b_data, centroids_sdp, labels_sdp)\n",
    "nmi_sdp = sk_metrics.normalized_mutual_info_score(ng20_b_labels, labels_sdp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1940c6b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Lloyd: Loss = 484.04, NMI = 0.24\n",
      "Hartigan: Loss = 481.72, NMI = 0.44\n",
      "SDP: Loss = 484.29, NMI = 0.34\n",
      "Spectral: Loss = 482.89, NMI = 0.37\n"
     ]
    }
   ],
   "source": [
    "# One line per method: final k-means loss and NMI against the true labels.\n",
    "print(\n",
    "    f\"Lloyd: Loss = {lloyd_results['loss']:.2f}, NMI = {nmi_lloyd:.2f}\",\n",
    "    f\"Hartigan: Loss = {hartigan_results['loss']:.2f}, NMI = {nmi_hartigan:.2f}\",\n",
    "    f\"SDP: Loss = {loss_sdp:.2f}, NMI = {nmi_sdp:.2f}\",\n",
    "    f\"Spectral: Loss = {loss_spectral:.2f}, NMI = {nmi_spectral:.2f}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a56e7aaf",
   "metadata": {},
   "source": [
    "# 20NG-C, K = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "79a21591",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 20 Newsgroups...\n"
     ]
    }
   ],
   "source": [
    "# 10 categories, 10 * DATA_PER_CLUSTER documents sampled in total\n",
    "ng20_c_data, ng20_c_labels = load_20newsgroups(\n",
    "    n_classes=10,\n",
    "    n_samples=10 * DATA_PER_CLUSTER,\n",
    "    random_state=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8b49b8ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lloyd is seeded with k-means++; Hartigan starts from a random partition.\n",
    "lloyd = kmeans_jax.KMeans(\n",
    "    algorithm=\"Lloyd\",\n",
    "    init=\"kmeans++\",\n",
    "    n_clusters=10,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")\n",
    "\n",
    "hartigan = kmeans_jax.KMeans(\n",
    "    algorithm=\"Hartigan\",\n",
    "    init=\"random partition\",\n",
    "    n_clusters=10,\n",
    "    n_init=500,\n",
    "    max_iter=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "836550ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Lloyd\n",
      "Running Hartigan\n"
     ]
    }
   ],
   "source": [
    "# batch_size = how many of the n_init initializations run in parallel (it\n",
    "# batches restarts, not data points); the numba-based Hartigan ignores it.\n",
    "key_init = jax.random.key(0)\n",
    "\n",
    "print(\"Running Lloyd\")\n",
    "lloyd_results = lloyd.fit(key_init, ng20_c_data, batch_size=100)\n",
    "\n",
    "print(\"Running Hartigan\")\n",
    "hartigan_results = hartigan.fit(key_init, np.asarray(ng20_c_data), batch_size=100)\n",
    "\n",
    "nmi_lloyd, nmi_hartigan = (\n",
    "    sk_metrics.normalized_mutual_info_score(ng20_c_labels, fit_result[\"labels\"])\n",
    "    for fit_result in (lloyd_results, hartigan_results)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "6fd7d692",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running Spectral Clustering\n"
     ]
    }
   ],
   "source": [
    "print(\"Running Spectral Clustering\")\n",
    "# Only the labels (2nd) and loss (3rd) of the returned 4-tuple are used.\n",
    "_, labels_spectral, loss_spectral, _ = kmeans_jax.run_spectral_clustering(\n",
    "    ng20_c_data,\n",
    "    n_clusters=10,\n",
    "    n_init=500,\n",
    "    normalizes_data=False,\n",
    "    random_state=0,\n",
    ")\n",
    "nmi_spectral = sk_metrics.normalized_mutual_info_score(\n",
    "    ng20_c_labels, labels_spectral\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ada0ab78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The CVXPY-based SDP solver fails on this dataset because it is too large.\n",
    "# These labels were instead computed in MATLAB with the implementation at\n",
    "# https://github.com/solevillar/kmeans_sdp. Our CVXPY implementation is based\n",
    "# on that repository, but SDPNAL formulates the SDP with sparse matrices,\n",
    "# which makes it more robust on larger datasets.\n",
    "#\n",
    "# The .mat file must come from exactly this subsample (n_classes=10,\n",
    "# n_samples=10 * DATA_PER_CLUSTER, random_state=2), otherwise the labels are\n",
    "# misaligned with the rows of ng20_c_data; the assert below checks the size.\n",
    "\n",
    "labels_sdp = (\n",
    "    sio.loadmat(\n",
    "        \"./newsgroups_sdp_results_k10.mat\"\n",
    "    )[\"sdp_labels\"].flatten()\n",
    "    - 1  # MATLAB labels are 1-based\n",
    ").astype(np.int64)  # loadmat may return MATLAB doubles; use integer cluster ids\n",
    "assert labels_sdp.shape[0] == ng20_c_data.shape[0], (\n",
    "    f\"SDP labels ({labels_sdp.shape[0]}) do not match the number of documents \"\n",
    "    f\"({ng20_c_data.shape[0]}); regenerate newsgroups_sdp_results_k10.mat\"\n",
    ")\n",
    "\n",
    "centroids_sdp = kmeans_jax.kmeans.compute_centroids(ng20_c_data, labels_sdp, 10)\n",
    "loss_sdp = kmeans_jax.kmeans.compute_loss(ng20_c_data, centroids_sdp, labels_sdp)\n",
    "nmi_sdp = sk_metrics.normalized_mutual_info_score(ng20_c_labels, labels_sdp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "462d9533",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Lloyd: Loss = 957.43, NMI = 0.23\n",
      "Hartigan: Loss = 951.96, NMI = 0.31\n",
      "SDP: Loss = 956.88, NMI = 0.27\n",
      "Spectral: Loss = 953.67, NMI = 0.27\n"
     ]
    }
   ],
   "source": [
    "# One line per method: final k-means loss and NMI against the true labels.\n",
    "print(\n",
    "    f\"Lloyd: Loss = {lloyd_results['loss']:.2f}, NMI = {nmi_lloyd:.2f}\",\n",
    "    f\"Hartigan: Loss = {hartigan_results['loss']:.2f}, NMI = {nmi_hartigan:.2f}\",\n",
    "    f\"SDP: Loss = {loss_sdp:.2f}, NMI = {nmi_sdp:.2f}\",\n",
    "    f\"Spectral: Loss = {loss_spectral:.2f}, NMI = {nmi_spectral:.2f}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "obs-on-kmeans-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
