{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28e25199",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "import numpy as np\n",
    "from numba import jit\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib import rcParams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ccdc4de",
   "metadata": {},
   "outputs": [],
   "source": [
    "color1 = \"#1f77b4\"\n",
    "color2 = \"#fe6100\"\n",
    "\n",
    "sns.set_context(\"paper\")\n",
    "sns.set_style(\"white\")\n",
    "\n",
    "rcParams[\"font.family\"] = \"serif\"\n",
    "# set font size\n",
    "rcParams[\"font.size\"] = 25\n",
    "rcParams[\"axes.labelsize\"] = 25\n",
    "rcParams[\"axes.titlesize\"] = 25\n",
    "rcParams[\"xtick.labelsize\"] = 25\n",
    "rcParams[\"ytick.labelsize\"] = 25\n",
    "rcParams[\"legend.fontsize\"] = 23\n",
    "rcParams[\"legend.title_fontsize\"] = 23"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a9d6e0",
   "metadata": {},
   "source": [
    "# Clustering Algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67508aba",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def _compute_loss_np(data, centroids, labels):\n",
    "    return np.sum(np.abs(data - centroids[labels]) ** 2)\n",
    "\n",
    "\n",
    "@jit\n",
    "def compute_mean(data, mask):\n",
    "    n_elements = np.sum(mask)\n",
    "    if n_elements > 0:\n",
    "        mean = np.sum(data[mask], axis=0)\n",
    "        mean /= n_elements\n",
    "    else:\n",
    "        mean = np.zeros_like(data[0])\n",
    "    return mean\n",
    "\n",
    "\n",
    "@jit\n",
    "def compute_centroids_np(data, labels, centroids):\n",
    "    for i in range(centroids.shape[0]):\n",
    "        idx_mask = labels == i\n",
    "        centroids[i] = compute_mean(data, idx_mask)\n",
    "    return centroids\n",
    "\n",
    "\n",
    "@jit\n",
    "def assign_labels_lloyd_all(centroids, data):\n",
    "    distances = np.zeros((data.shape[0], centroids.shape[0]))\n",
    "    for n in range(data.shape[0]):\n",
    "        for k in range(centroids.shape[0]):\n",
    "            distances[n, k] = np.sum((data[n] - centroids[k]) ** 2)\n",
    "    return np.argmin(distances, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61398fd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def run_lloyd_kmeans(data, init_centroids, max_iters, checks_convergence=True):\n",
    "    # Initial quantities\n",
    "    labels = assign_labels_lloyd_all(init_centroids, data)\n",
    "    centroids = compute_centroids_np(data, labels, init_centroids.copy())\n",
    "\n",
    "    # Variables to update\n",
    "    old_labels = labels.copy()\n",
    "    n_iters = 1 # one iteration already done\n",
    "    for _ in range(max_iters):\n",
    "        labels = assign_labels_lloyd_all(centroids, data)\n",
    "        centroids = compute_centroids_np(data, labels, centroids)\n",
    "        loss = _compute_loss_np(data, centroids, labels)\n",
    "        break_condition = np.array_equal(labels, old_labels)\n",
    "        if checks_convergence and break_condition:\n",
    "            break\n",
    "        else:\n",
    "            old_labels = labels.copy()\n",
    "            n_iters += 1\n",
    "\n",
    "    loss = _compute_loss_np(data, centroids, labels)\n",
    "    return centroids, labels, loss, n_iters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34f2d061",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def _assign_label_hartigan_np(centroids, cluster_populations, data_point, label_point):\n",
    "    distances = np.sum((data_point[None, ...] - centroids) ** 2, axis=-1)\n",
    "\n",
    "    for i in range(centroids.shape[0]):\n",
    "        if label_point == i:\n",
    "            if cluster_populations[i] <= 1:\n",
    "                distances[i] = -1.0  # prevent empty clusters\n",
    "            else:\n",
    "                scale_factor = cluster_populations[i] / (cluster_populations[i] - 1)\n",
    "                distances[i] *= scale_factor\n",
    "\n",
    "        else:\n",
    "            scale_factor = cluster_populations[i] / (cluster_populations[i] + 1)\n",
    "            distances[i] *= scale_factor\n",
    "    return np.argmin(distances)\n",
    "\n",
    "\n",
    "@jit\n",
    "def run_hartigan_kmeans(data, init_centroids, max_iters, checks_convergence=True):\n",
    "    # Initial quantities\n",
    "    labels = assign_labels_lloyd_all(init_centroids, data)\n",
    "    centroids = compute_centroids_np(data, labels, init_centroids.copy())\n",
    "    cluster_populations = np.bincount(labels, minlength=init_centroids.shape[0])\n",
    "\n",
    "    # Variables to update\n",
    "    old_labels = labels.copy()\n",
    "    old_old_labels = labels.copy()\n",
    "    n_iters = 1 # one iteration already done\n",
    "    for _ in range(max_iters):\n",
    "        for j in range(data.shape[0]):\n",
    "            \n",
    "            new_label = _assign_label_hartigan_np(\n",
    "                centroids, cluster_populations, data[j], labels[j]\n",
    "            )\n",
    "            if new_label != labels[j]:\n",
    "                n_clust = cluster_populations[labels[j]]\n",
    "                centroids[labels[j]] = (centroids[labels[j]] * n_clust - data[j]) / (\n",
    "                    n_clust - 1.0\n",
    "                )\n",
    "\n",
    "                n_clust = cluster_populations[new_label]\n",
    "                centroids[new_label] = (centroids[new_label] * n_clust + data[j]) / (\n",
    "                    n_clust + 1.0\n",
    "                )\n",
    "\n",
    "                cluster_populations[labels[j]] -= 1\n",
    "                cluster_populations[new_label] += 1\n",
    "                labels[j] = new_label\n",
    "\n",
    "        break_condition = np.array_equal(labels, old_labels) or (\n",
    "            np.array_equal(labels, old_old_labels) and n_iters > 0\n",
    "        )\n",
    "\n",
    "        if checks_convergence and break_condition:\n",
    "            break\n",
    "        else:\n",
    "            old_old_labels = old_labels.copy()\n",
    "            old_labels = labels.copy()\n",
    "            n_iters += 1\n",
    "    loss = _compute_loss_np(data, centroids, labels)\n",
    "    return centroids, labels, loss, n_iters\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a1b772f",
   "metadata": {},
   "source": [
    "# Compute time experiments utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c5e992d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dataset(n, d, noise_variance, prior_variance, K, rng):\n",
    "    \"\"\"\n",
    "    Generate a synthetic dataset from a Gaussian mixture model.\n",
    "\n",
    "    **Arguments**:\n",
    "        n: Number of data points to generate.\n",
    "        d: Dimensionality of each data point.\n",
    "        noise_variance: Variance of the Gaussian noise added to each data point.\n",
    "        prior_variance: Variance of the Gaussian prior for the cluster centroids.\n",
    "        K: Number of clusters.\n",
    "        rng: A numpy random Generator instance for reproducibility.\n",
    "    **Returns**:\n",
    "        A numpy array of shape (n, d) containing the generated data points.\n",
    "    \"\"\"\n",
    "\n",
    "    # Generate cluster centroids\n",
    "    rng_centers, rng_labels, rng_noise = rng.spawn(3)\n",
    "    centroids = rng_centers.normal(loc=0.0, scale=np.sqrt(prior_variance), size=(K, d))\n",
    "\n",
    "    # Assign data points to clusters\n",
    "    labels = rng_labels.integers(low=0, high=K, size=n)\n",
    "    # Generate data points\n",
    "    data = centroids[labels] + rng_noise.normal(\n",
    "        loc=0.0, scale=np.sqrt(noise_variance), size=(n, d)\n",
    "    )\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "470b2e50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_experiment(\n",
    "    rng,\n",
    "    n_samples,\n",
    "    dimension,\n",
    "    n_clusters,\n",
    "    max_iterations,\n",
    "    checks_convergence,\n",
    "    *,\n",
    "    noise_variance=10.0,\n",
    "    prior_variance=1.0,\n",
    "):\n",
    "    rng_data, rng_init = rng.spawn(2)\n",
    "    data = generate_dataset(\n",
    "        n=n_samples,\n",
    "        d=dimension,\n",
    "        noise_variance=noise_variance,\n",
    "        prior_variance=prior_variance,\n",
    "        K=n_clusters,\n",
    "        rng=rng_data,\n",
    "    )\n",
    "\n",
    "    init_centroids = data[\n",
    "        rng_init.choice(data.shape[0], size=n_clusters, replace=False)\n",
    "    ]\n",
    "\n",
    "    # Time Lloyd's algorithm\n",
    "    start_time = time.time()\n",
    "    _, _, _, n_iters_lloyd = run_lloyd_kmeans(\n",
    "        data,\n",
    "        init_centroids,\n",
    "        max_iters=max_iterations,\n",
    "        checks_convergence=checks_convergence,\n",
    "    )\n",
    "    end_time = time.time()\n",
    "    time_lloyd = end_time - start_time\n",
    "\n",
    "    # Time Hartigan's algorithm\n",
    "    start_time = time.time()\n",
    "    _, _, _, n_iters_hartigan = run_hartigan_kmeans(\n",
    "        data,\n",
    "        init_centroids,\n",
    "        max_iters=max_iterations,\n",
    "        checks_convergence=checks_convergence,\n",
    "    )\n",
    "    end_time = time.time()\n",
    "    time_hartigan = end_time - start_time\n",
    "\n",
    "    return time_lloyd, n_iters_lloyd, time_hartigan, n_iters_hartigan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "446f0878",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_timing_experiments(\n",
    "    N_to_test,\n",
    "    d_to_test,\n",
    "    K_to_test,\n",
    "    n_tries_per_setting,\n",
    "    max_iterations,\n",
    "    checks_convergence,\n",
    "    *,\n",
    "    random_seed=0,\n",
    "    noise_variance=10.0,\n",
    "    prior_variance=1.0,\n",
    "):\n",
    "    lloyd_times = np.zeros(\n",
    "        (len(N_to_test), len(d_to_test), len(K_to_test), n_tries_per_setting)\n",
    "    )\n",
    "    n_iterations_lloyd = np.zeros(\n",
    "        (len(N_to_test), len(d_to_test), len(K_to_test), n_tries_per_setting)\n",
    "    )\n",
    "    hartigan_times = np.zeros(\n",
    "        (len(N_to_test), len(d_to_test), len(K_to_test), n_tries_per_setting)\n",
    "    )\n",
    "    n_iterations_hartigan = np.zeros(\n",
    "        (len(N_to_test), len(d_to_test), len(K_to_test), n_tries_per_setting)\n",
    "    )\n",
    "\n",
    "    rng = np.random.default_rng(random_seed)\n",
    "\n",
    "    # run warmup\n",
    "    print(\"Running warmup...\")\n",
    "    rng, subrng = rng.spawn(2)\n",
    "    run_experiment(\n",
    "        subrng,\n",
    "        n_samples=10,\n",
    "        dimension=10,\n",
    "        n_clusters=2,\n",
    "        max_iterations=2,\n",
    "        checks_convergence=checks_convergence,\n",
    "    )\n",
    "    print(\"Warmup done. Running experiments...\")\n",
    "\n",
    "    with tqdm(\n",
    "        total=len(N_to_test) * len(d_to_test) * len(K_to_test) * n_tries_per_setting\n",
    "    ) as pbar:\n",
    "        for i, dimension in enumerate(d_to_test):\n",
    "            for j, n_samples in enumerate(N_to_test):\n",
    "                for k, n_clusters in enumerate(K_to_test):\n",
    "                    for t in range(n_tries_per_setting):\n",
    "                        rng, subrng = rng.spawn(2)\n",
    "\n",
    "                        time_lloyd, n_iters_l, time_hartigan, n_iters_h = (\n",
    "                            run_experiment(\n",
    "                                subrng,\n",
    "                                n_samples,\n",
    "                                dimension,\n",
    "                                n_clusters,\n",
    "                                max_iterations,\n",
    "                                checks_convergence,\n",
    "                                noise_variance=noise_variance,\n",
    "                                prior_variance=prior_variance,\n",
    "                            )\n",
    "                        )\n",
    "\n",
    "                        lloyd_times[j, i, k, t] = time_lloyd\n",
    "                        n_iterations_lloyd[j, i, k, t] = n_iters_l\n",
    "\n",
    "                        hartigan_times[j, i, k, t] = time_hartigan\n",
    "                        n_iterations_hartigan[j, i, k, t] = n_iters_h\n",
    "\n",
    "                        pbar.update(1)\n",
    "\n",
    "    return lloyd_times, n_iterations_lloyd, hartigan_times, n_iterations_hartigan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ecd4ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_timing_results(\n",
    "    N_to_test,\n",
    "    d_to_test,\n",
    "    K_to_test,\n",
    "    lloyd_times,\n",
    "    hartigan_times,\n",
    "    save_path=None,\n",
    "    plots_legend=True,\n",
    "):\n",
    "    markers = [\"o\", \"^\", \"X\"]\n",
    "\n",
    "    # Ensure that we have enough markers\n",
    "    if len(d_to_test) > len(markers):\n",
    "        raise ValueError(\"Not enough markers defined for the number of dimensions\")\n",
    "\n",
    "    fig, ax = plt.subplots(\n",
    "        1, 3, figsize=(18, 8), sharex=True, sharey=True, layout=\"compressed\"\n",
    "    )\n",
    "\n",
    "    lw = 3\n",
    "    ms = 200\n",
    "    # Plotting with lines to distinguish algorithms and scatter to distinguish dimensions\n",
    "    for k in range(len(K_to_test)):\n",
    "        for i in range(len(d_to_test)):\n",
    "            ax[k].plot(\n",
    "                N_to_test,\n",
    "                lloyd_times.mean(axis=-1)[:, i, k],\n",
    "                label=\"Lloyd\"\n",
    "                if i == 0\n",
    "                else None,  # Add \"Lloyd\" label only once for legend\n",
    "                color=color1,\n",
    "                linewidth=lw,\n",
    "            )\n",
    "            ax[k].plot(\n",
    "                N_to_test,\n",
    "                hartigan_times.mean(axis=-1)[:, i, k],\n",
    "                label=\"Hartigan\"\n",
    "                if i == 0\n",
    "                else None,  # Add \"Hartigan\" label only once for legend\n",
    "                color=color2,\n",
    "                linewidth=lw,\n",
    "            )\n",
    "\n",
    "            ax[k].scatter(\n",
    "                N_to_test,\n",
    "                lloyd_times.mean(axis=-1)[:, i, k],\n",
    "                label=f\"d={d_to_test[i]}\"\n",
    "                if k == 0\n",
    "                else None,  # Add dimension label only once for legend\n",
    "                color=color1,\n",
    "                marker=markers[i],\n",
    "                s=ms,\n",
    "            )\n",
    "            ax[k].scatter(\n",
    "                N_to_test,\n",
    "                hartigan_times.mean(axis=-1)[:, i, k],\n",
    "                color=color2,\n",
    "                marker=markers[i],\n",
    "                s=ms,\n",
    "            )\n",
    "\n",
    "        ax[k].set_yscale(\"log\")\n",
    "        ax[k].set_xscale(\"log\", base=10)\n",
    "        ax[k].set_title(f\"K={K_to_test[k]}\")\n",
    "        ax[k].set_xlabel(\"Number of data points [n]\")\n",
    "\n",
    "    ax[0].set_ylabel(\"Average compute time (seconds)\")\n",
    "\n",
    "    if plots_legend:\n",
    "        # Create algorithm legend\n",
    "        thick_line_width = 7\n",
    "        algorithm_handles = [\n",
    "            plt.Line2D(\n",
    "                [0], [0], color=color1, linewidth=thick_line_width, label=\"Lloyd\"\n",
    "            ),\n",
    "            plt.Line2D(\n",
    "                [0], [0], color=color2, linewidth=thick_line_width, label=\"Hartigan\"\n",
    "            ),\n",
    "        ]\n",
    "\n",
    "        algorithm_legend = fig.legend(\n",
    "            algorithm_handles,\n",
    "            [\"Lloyd\", \"Hartigan\"],\n",
    "            loc=\"upper right\",\n",
    "            title=\"Algorithm\",\n",
    "            bbox_to_anchor=(1.17, 0.97),\n",
    "        )\n",
    "\n",
    "        # Create custom black markers for dimension legend\n",
    "        dimension_handles = [\n",
    "            plt.Line2D(\n",
    "                [0],\n",
    "                [0],\n",
    "                marker=markers[i],\n",
    "                color=\"w\",\n",
    "                markerfacecolor=\"black\",\n",
    "                markersize=15,\n",
    "                linestyle=\"\",\n",
    "            )\n",
    "            for i in range(len(d_to_test))\n",
    "        ]\n",
    "        dimension_labels = [f\"d={d}\" for d in d_to_test]\n",
    "        dimension_legend = fig.legend(\n",
    "            dimension_handles,\n",
    "            dimension_labels,\n",
    "            loc=\"lower right\",\n",
    "            title=\"Dimension\",\n",
    "            bbox_to_anchor=(1.17, 0.4),\n",
    "        )\n",
    "\n",
    "        # Add legends to the plot\n",
    "        ax[0].add_artist(algorithm_legend)\n",
    "        ax[0].add_artist(dimension_legend)\n",
    "\n",
    "    if save_path is not None:\n",
    "        fig.savefig(save_path)\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9f9ca93",
   "metadata": {},
   "source": [
    "# Setting experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad17cba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_to_test = np.logspace(np.log10(32), np.log10(10000), num=7, base=10, dtype=int)\n",
    "d_to_test = [100, 1000, 10000]\n",
    "K_to_test = [4, 8, 16]\n",
    "n_tries_per_setting = 10\n",
    "noise_variance = 10.0\n",
    "prior_variance = 1.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589bd376",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_to_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ebde7ba",
   "metadata": {},
   "source": [
    "# Experiments single iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "386b5ac8",
   "metadata": {},
   "outputs": [],
   "source": [
    "lloyd_it_times, _, hartigan_it_times, _ = run_timing_experiments(\n",
    "    N_to_test,\n",
    "    d_to_test,\n",
    "    K_to_test,\n",
    "    n_tries_per_setting,\n",
    "    max_iterations=1,\n",
    "    checks_convergence=False,\n",
    "    noise_variance=noise_variance,\n",
    "    prior_variance=prior_variance,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b12e4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_timing_results(\n",
    "    N_to_test,\n",
    "    d_to_test,\n",
    "    K_to_test,\n",
    "    lloyd_it_times,\n",
    "    hartigan_it_times,\n",
    "    plots_legend=True,\n",
    "    save_path=\"figure_runtime_one_iteration.svg\",\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7686496",
   "metadata": {},
   "source": [
    "# Running until convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3207c3cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "lloyd_conv_times, lloyd_n_iter, hartigan_conv_times, hartigan_n_iter = (\n",
    "    run_timing_experiments(\n",
    "        N_to_test,\n",
    "        d_to_test,\n",
    "        K_to_test,\n",
    "        n_tries_per_setting,\n",
    "        max_iterations=10000,\n",
    "        checks_convergence=True,\n",
    "        noise_variance=noise_variance,\n",
    "        prior_variance=prior_variance,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d83c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_timing_results(\n",
    "    N_to_test,\n",
    "    d_to_test,\n",
    "    K_to_test,\n",
    "    lloyd_conv_times,\n",
    "    hartigan_conv_times,\n",
    "    plots_legend=True,\n",
    "    save_path=\"figure_runtime_convergence.svg\",\n",
    ");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "obs-on-kmeans-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
