{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9f614e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import jax.random as jr\n",
    "import jax\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import equinox as eqx\n",
    "from cryojax.jax_util import filter_bmap\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams\n",
    "import seaborn as sns\n",
    "from scipy import stats\n",
    "\n",
    "from kmeans_jax.kmeans import compute_centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913b216f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour palette and global plotting style.\n",
    "color1, color2 = \"#1f77b4\", \"#fe6100\"\n",
    "\n",
    "sns.set_context(\"paper\")\n",
    "sns.set_style(\"white\")\n",
    "\n",
    "# Serif font with large text so labels stay legible in the multi-panel figure.\n",
    "rcParams.update(\n",
    "    {\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.size\": 25,\n",
    "        \"axes.labelsize\": 25,\n",
    "        \"axes.titlesize\": 25,\n",
    "        \"xtick.labelsize\": 25,\n",
    "        \"ytick.labelsize\": 25,\n",
    "        \"legend.fontsize\": 23,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cb2f134",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_conf_interval(trials, alpha: float = 0.05):\n",
    "    \"\"\"\n",
    "    Compute the two-sided Wilson score confidence interval for a Binomial\n",
    "    proportion.\n",
    "\n",
    "    **Arguments:**\n",
    "        trials: Outcomes (boolean or 0/1) of the trials to compute the\n",
    "            confidence interval for. Shape (..., n_trials), where ... is any\n",
    "            number of batch dimensions and n_trials is the number of trials\n",
    "            performed.\n",
    "        alpha: The significance level. The interval has nominal coverage\n",
    "            1 - alpha, e.g. alpha=0.05 gives a 95% interval.\n",
    "    **Returns:**\n",
    "        A tuple (lower, upper) with the bounds of the confidence interval for\n",
    "        each batch element (each of shape (...)).\n",
    "    \"\"\"\n",
    "    n_s = trials.sum(axis=-1)  # number of successes\n",
    "    n = trials.shape[-1]  # number of trials\n",
    "\n",
    "    # Two-sided interval: alpha is split evenly between both tails\n",
    "    # (crit ~= 1.96 for alpha = 0.05).\n",
    "    crit = stats.norm.isf(alpha / 2.0)\n",
    "    crit2 = crit**2\n",
    "    q = n_s / n  # observed success proportion\n",
    "    denom = 1 + crit2 / n\n",
    "    center = (q + crit2 / (2 * n)) / denom\n",
    "    # Half-width uses the Binomial variance estimate q * (1 - q) / n.\n",
    "    width = crit * np.sqrt(q * (1.0 - q) / n + crit2 / (4.0 * n**2))\n",
    "\n",
    "    width /= denom\n",
    "\n",
    "    return center - width, center + width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4fd0615",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Purity coefficients (see paper for definition)\n",
    "def compute_Rij(i, j, labels, true_labels):\n",
    "    \"\"\"Fraction of the points currently in cluster j whose true label is i.\"\"\"\n",
    "    in_cluster_j = labels == j\n",
    "    has_true_label_i = true_labels == i\n",
    "    n_matching = jnp.logical_and(in_cluster_j, has_true_label_i).sum()\n",
    "    return n_matching / in_cluster_j.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "048811fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_cluster_statistics(point_idx, curr_labels, true_labels):\n",
    "    \"\"\"\n",
    "    Cluster sizes and purities seen from the point at `point_idx`.\n",
    "\n",
    "    Two-cluster setting: j is the point's current cluster, k = 1 - j is the\n",
    "    other cluster and l is the point's true label.\n",
    "\n",
    "    **Returns:**\n",
    "        (C_j, C_k, R_l^j, R_l^k): the sizes of clusters j and k and the\n",
    "        purity coefficients compute_Rij(l, j, ...) and compute_Rij(l, k, ...).\n",
    "    \"\"\"\n",
    "    label_l = true_labels[point_idx]\n",
    "    j = curr_labels[point_idx]\n",
    "    k = 1 - j  # \\line{j} -> k\n",
    "\n",
    "    sizes = jnp.bincount(curr_labels, length=2)\n",
    "    R_lj = compute_Rij(label_l, j, curr_labels, true_labels)\n",
    "    R_lk = compute_Rij(label_l, k, curr_labels, true_labels)\n",
    "    return sizes[j], sizes[k], R_lj, R_lk\n",
    "\n",
    "def compute_rho_lloyd(point_idx, sigma2, tau2, curr_labels, true_labels):\n",
    "    \"\"\"\n",
    "    Base quantity rho of Lloyd's theoretical upper bound (see paper) for the\n",
    "    point at `point_idx`.\n",
    "\n",
    "    Only the sizes C_j (the point's current cluster) and C_k (the other\n",
    "    cluster) enter the expression; the purity coefficients are not used.\n",
    "    Elsewhere in this notebook the bound is formed as rho ** (d / 4).\n",
    "\n",
    "    **Arguments:**\n",
    "        point_idx: Index of the data point under consideration.\n",
    "        sigma2: Noise variance. May be an array, in which case the result is\n",
    "            computed elementwise (the notebook passes the whole sigma2 grid).\n",
    "        tau2: Variance of the true cluster centers.\n",
    "        curr_labels: Current cluster assignment (0 or 1) of every point.\n",
    "        true_labels: Ground-truth label of every point.\n",
    "    **Returns:**\n",
    "        rho_lloyd_term1 / rho_lloyd_term2, broadcast against sigma2 and tau2.\n",
    "    \"\"\"\n",
    "    C_j, C_k, _, _ = compute_cluster_statistics(point_idx, curr_labels, true_labels)\n",
    "    # Lloyd's Theoretical Upper Bound\n",
    "    # Numerator of rho.\n",
    "    rho_lloyd_term1 = (\n",
    "        4\n",
    "        * sigma2\n",
    "        * (C_j - 1)\n",
    "        * C_j**2\n",
    "        * C_k\n",
    "        * (C_k + 1)\n",
    "        * (C_j * (sigma2 + 2 * tau2) - 2 * tau2)\n",
    "    )\n",
    "    # Denominator of rho (a squared quantity).\n",
    "    rho_lloyd_term2 = (\n",
    "        -C_j * (sigma2 + 4 * tau2) * C_k\n",
    "        + C_j**2 * (sigma2 + 2 * (sigma2 + tau2) * C_k)\n",
    "        + 2 * tau2 * C_k\n",
    "    ) ** 2\n",
    "\n",
    "    return rho_lloyd_term1 / rho_lloyd_term2\n",
    "\n",
    "def compute_rho_hartigan(point_idx, sigma2, tau2, curr_labels, true_labels):\n",
    "    \"\"\"\n",
    "    Base quantity rho of Hartigan's theoretical upper bound (see paper) for\n",
    "    the point at `point_idx`; the notebook raises it to the power d / 4.\n",
    "\n",
    "    Unlike the Lloyd bound it depends on the purity coefficients R_l^j and\n",
    "    R_l^k as well as on the cluster sizes. `sigma2` may be an array, in which\n",
    "    case the result is computed elementwise.\n",
    "    \"\"\"\n",
    "    C_j, C_k, R_lj, R_lk = compute_cluster_statistics(point_idx, curr_labels, true_labels)\n",
    "\n",
    "    # Size-weighted (1 - R)^2 terms of the current (j) and the other (k) cluster.\n",
    "    a_j = C_j / (C_j - 1) * (1 - R_lj) ** 2\n",
    "    a_k = C_k / (C_k + 1) * (1 - R_lk) ** 2\n",
    "\n",
    "    numerator = (tau2 * (a_j - a_k)) ** 2\n",
    "    denominator = (tau2 * (a_j + a_k) + sigma2) ** 2\n",
    "    return 1 - numerator / denominator\n",
    "\n",
    "@eqx.filter_vmap(in_axes=(0, None, None, None, None, None, None))\n",
    "def _run_numerical_experiment(key, n, d, tau2, sigma2, true_labels, curr_labels):\n",
    "    \"\"\"\n",
    "    One Monte-Carlo draw per key (vmapped over `key` only) of the distances\n",
    "    from point 0 to the centroids (as returned by compute_centroids) of its\n",
    "    current cluster j and of the other cluster k.\n",
    "\n",
    "    Data: two true centers ~ N(0, tau2 I_d) and n samples\n",
    "    true_centers[true_labels] + N(0, sigma2 I_d).\n",
    "\n",
    "    **Returns:**\n",
    "        (dist_Cj, dist_Ck, hartigan_dist_Cj, hartigan_dist_Ck): squared\n",
    "        Euclidean distances, and the same distances rescaled by\n",
    "        C_j / (C_j - 1) and C_k / (C_k + 1) respectively.\n",
    "    \"\"\"\n",
    "    key_centers, key_noise = jr.split(key, 2)\n",
    "\n",
    "    centers = jr.normal(key_centers, (2, d)) * jnp.sqrt(tau2)\n",
    "    samples = centers[true_labels] + jr.normal(key_noise, (n, d)) * jnp.sqrt(sigma2)\n",
    "\n",
    "    j = curr_labels[0]\n",
    "    x0 = samples[0]\n",
    "    centroids = compute_centroids(samples, curr_labels, 2)\n",
    "    C_j, C_k, _, _ = compute_cluster_statistics(0, curr_labels, true_labels)\n",
    "\n",
    "    dist_Cj = jnp.sum((x0 - centroids[j]) ** 2)\n",
    "    dist_Ck = jnp.sum((x0 - centroids[1 - j]) ** 2)\n",
    "\n",
    "    # Last two entries: Hartigan-weighted versions of the same distances.\n",
    "    return (\n",
    "        dist_Cj,\n",
    "        dist_Ck,\n",
    "        dist_Cj * (C_j / (C_j - 1)),\n",
    "        dist_Ck * (C_k / (C_k + 1)),\n",
    "    )\n",
    "\n",
    "def run_numerical_experiments(keys, n, d, tau2, sigma2, true_labels, curr_labels, batch_size=None):\n",
    "    \"\"\"\n",
    "    Evaluate `_run_numerical_experiment` for every PRNG key in `keys`.\n",
    "\n",
    "    The work is dispatched through cryojax's `filter_bmap` with the given\n",
    "    `batch_size` (presumably processing `keys` in chunks to limit memory -\n",
    "    confirm against the cryojax docs).\n",
    "    \"\"\"\n",
    "\n",
    "    def _experiment_for(key_batch):\n",
    "        return _run_numerical_experiment(key_batch, n, d, tau2, sigma2, true_labels, curr_labels)\n",
    "\n",
    "    return filter_bmap(_experiment_for, keys, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c86d254",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root PRNG key for the Monte-Carlo experiments.\n",
    "key = jr.key(0)\n",
    "\n",
    "# Fixed parameters\n",
    "# tau2: variance of the true cluster centers, n: number of data points,\n",
    "# true_prop: fraction of points with true label 0.\n",
    "tau2 = 1.0\n",
    "n = 40\n",
    "true_prop = 0.5\n",
    "\n",
    "# Cluster labels\n",
    "# Ground truth: the first n_c1 points have label 0, the remaining n_c2 label 1.\n",
    "n_c1 = int(n * true_prop)\n",
    "n_c2 = n - n_c1\n",
    "true_labels = jnp.concatenate([jnp.zeros(n_c1, dtype=int), jnp.ones(n_c2, dtype=int)])\n",
    "\n",
    "# Current (imperfect) partition. The points are split into a first block of\n",
    "# Cj and a second block of Ck points; these blocks coincide with the true\n",
    "# classes when n_c1 == Cj, as with the values here. In the first block a\n",
    "# fraction curr_prop is labelled 0 and the rest 1; in the second block a\n",
    "# fraction curr_prop is labelled 1 and the rest 0.\n",
    "curr_prop = 0.25\n",
    "Cj = n // 2\n",
    "Ck = n - Cj\n",
    "\n",
    "Cj_0 = int(Cj * curr_prop)\n",
    "Cj_1 = Cj - Cj_0\n",
    "Ck_1 = int(Ck * curr_prop)\n",
    "Ck_0 = Ck - Ck_1\n",
    "curr_labels = jnp.concatenate([\n",
    "    jnp.zeros(Cj_0, dtype=int),\n",
    "    jnp.ones(Cj_1, dtype=int),\n",
    "    jnp.ones(Ck_1, dtype=int),\n",
    "    jnp.zeros(Ck_0, dtype=int),\n",
    "    \n",
    "])\n",
    "\n",
    "# Validating purity coefficients\n",
    "# With the values above: R_0^0 = 5/20 = 0.25 and R_0^1 = 15/20 = 0.75.\n",
    "R00 = compute_Rij(0, 0, curr_labels, true_labels)\n",
    "R01 = compute_Rij(0, 1, curr_labels, true_labels)\n",
    "print(\"Purity R_0^0:\", R00)\n",
    "print(\"Purity R_0^1:\", R01)\n",
    "\n",
    "# Varying parameters\n",
    "\n",
    "# dimension\n",
    "# 20 log-spaced dimensions from 10**0.8 to 10**5, cast to integers.\n",
    "d_values = jnp.logspace(0.8, 5, 20, dtype=int)\n",
    "\n",
    "# Noise variance\n",
    "# base_sigma is a reference *variance* (it scales sigma2 directly); the grid\n",
    "# below is expressed as multiples beta of it (beta is shown in the panel\n",
    "# titles). Presumably the formula comes from the paper - TODO confirm.\n",
    "base_sigma = 2 * tau2 * (n_c1 - 1) ** 2 / (n_c1 ** 2 / n_c2 + n_c1)\n",
    "print(\"Base sigma:\", base_sigma)\n",
    "sigma2_multiplier = jnp.array([0.01, 0.1, 0.5, 0.9, 1.0, 1.1, 1.5, 2.0])\n",
    "\n",
    "sigma2_values = base_sigma * sigma2_multiplier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e14db41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte-Carlo estimate, for each (d, sigma2) pair, of how often point 0 is\n",
    "# closer to its current cluster than to the other one under Lloyd's and\n",
    "# Hartigan's distances. The full run is slow, so by default the saved results\n",
    "# are loaded; set RUN_EXPERIMENTS = True to regenerate (and overwrite) them.\n",
    "RESULTS_FILE = \"numerical_experiments_results.npz\"\n",
    "RUN_EXPERIMENTS = False\n",
    "n_tries = 10000\n",
    "\n",
    "if RUN_EXPERIMENTS:\n",
    "    results_shape = (len(d_values), len(sigma2_values), n_tries)\n",
    "    dist_Cj = np.zeros(results_shape)\n",
    "    dist_Ck = np.zeros(results_shape)\n",
    "    dist_hart_Cj = np.zeros(results_shape)\n",
    "    dist_hart_Ck = np.zeros(results_shape)\n",
    "\n",
    "    for i, d in enumerate(tqdm(d_values)):\n",
    "        # Plain Python int so the sample shapes (2, d) and (n, d) are static.\n",
    "        d = int(d)\n",
    "        for j, sigma2 in enumerate(sigma2_values):\n",
    "            key, subkey = jr.split(key)\n",
    "            keys = jr.split(subkey, n_tries)\n",
    "            (\n",
    "                dist_Cj[i, j],\n",
    "                dist_Ck[i, j],\n",
    "                dist_hart_Cj[i, j],\n",
    "                dist_hart_Ck[i, j],\n",
    "            ) = run_numerical_experiments(\n",
    "                keys, n, d, tau2, sigma2, true_labels, curr_labels, batch_size=20\n",
    "            )\n",
    "        # Presumably frees the functions compiled for this d (new d => new shapes).\n",
    "        jax.clear_caches()\n",
    "\n",
    "    # Boolean arrays of shape (len(d_values), len(sigma2_values), n_tries).\n",
    "    num_exps_lloyd = dist_Cj < dist_Ck\n",
    "    num_exps_hartigan = dist_hart_Cj < dist_hart_Ck\n",
    "\n",
    "    np.savez(\n",
    "        RESULTS_FILE,\n",
    "        # Save values to guarantee reproducibility\n",
    "        d_values=d_values,\n",
    "        sigma2_values=sigma2_values,\n",
    "        true_labels=true_labels,\n",
    "        curr_labels=curr_labels,\n",
    "        tau2=tau2,\n",
    "        # Save numerical experiment results\n",
    "        num_exps_lloyd=num_exps_lloyd,\n",
    "        num_exps_hartigan=num_exps_hartigan,\n",
    "    )\n",
    "else:\n",
    "    # The context manager closes the .npz file handle once the arrays are read.\n",
    "    with np.load(RESULTS_FILE) as num_experiments_results:\n",
    "        num_exps_lloyd = num_experiments_results[\"num_exps_lloyd\"]\n",
    "        num_exps_hartigan = num_experiments_results[\"num_exps_hartigan\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "163c605f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Theoretical bounds for point i = 0 on the (d, sigma2) grid; rows index\n",
    "# d_values and columns index sigma2_values.\n",
    "rho_lloyd = compute_rho_lloyd(0, sigma2_values, tau2, curr_labels, true_labels)\n",
    "rho_hartigan = compute_rho_hartigan(0, sigma2_values, tau2, curr_labels, true_labels)\n",
    "\n",
    "d_exponent = d_values[:, None] / 4  # column vector: broadcasts against sigma2\n",
    "th_bound_hartigan = rho_hartigan**d_exponent\n",
    "# Lloyd's bound is flipped (1 - rho^(d/4)) so both bounds are comparable.\n",
    "th_bound_lloyd = 1 - rho_lloyd**d_exponent\n",
    "\n",
    "# Empirical frequencies: average the boolean outcomes over the trials axis.\n",
    "ratios_lloyd = num_exps_lloyd.mean(axis=-1)\n",
    "ratios_hartigan = num_exps_hartigan.mean(axis=-1)\n",
    "\n",
    "# Wilson confidence intervals at significance level alpha = 0.05.\n",
    "ci_lloyd = compute_conf_interval(num_exps_lloyd, alpha=0.05)\n",
    "ci_hart = compute_conf_interval(num_exps_hartigan, alpha=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a79660c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2, 4, figsize=(20, 10), sharex=True, sharey=True)\n",
    "\n",
    "ms = 10  # marker size\n",
    "lw = 2  # line width of the empirical curves\n",
    "\n",
    "axes_flat = ax.flatten()\n",
    "for j in range(len(sigma2_values)):\n",
    "    panel = axes_flat[j]\n",
    "\n",
    "    # Theoretical bounds (solid); Lloyd's is only drawn when beta > 1.\n",
    "    panel.plot(\n",
    "        d_values, th_bound_hartigan[:, j], label=\"Th. Bound Hartigan\", ms=ms, linewidth=3, color=\"orange\"\n",
    "    )\n",
    "    if sigma2_multiplier[j] > 1.0:\n",
    "        panel.plot(\n",
    "            d_values, th_bound_lloyd[:, j], label=\"Th. Bound Lloyd\", ms=ms, linewidth=3, color=\"blue\"\n",
    "        )\n",
    "\n",
    "    # Empirical fixed-point frequencies (dashed, with markers).\n",
    "    panel.plot(\n",
    "        d_values, ratios_lloyd[:, j], label=\"Lloyd\", ls=\"--\", marker=\"o\", ms=ms, linewidth=lw, color=\"blue\"\n",
    "    )\n",
    "    panel.plot(\n",
    "        d_values, ratios_hartigan[:, j], label=\"Hartigan\", ls=\"--\", marker=\"X\", ms=ms, linewidth=lw, color=\"orange\"\n",
    "    )\n",
    "\n",
    "    # Shaded confidence bands around the empirical frequencies.\n",
    "    panel.fill_between(d_values, y1=ci_hart[0][:, j], y2=ci_hart[1][:, j], alpha=0.3, color=\"orange\")\n",
    "    panel.fill_between(d_values, y1=ci_lloyd[0][:, j], y2=ci_lloyd[1][:, j], alpha=0.3, color=\"blue\")\n",
    "\n",
    "    panel.set_title(r\"$\\sigma^2=$\" + f\"{sigma2_values[j]:.2f} \" + r\"($\\beta =$\" + f\"{sigma2_multiplier[j]})\")\n",
    "    panel.set_xscale(\"log\")\n",
    "    panel.set_xticks([1e1, 1e2, 1e3, 1e4, 1e5])\n",
    "\n",
    "# A single legend, anchored to the second-to-last panel and pushed outside the\n",
    "# grid (that panel has beta > 1, so it carries all four labels).\n",
    "axes_flat[-2].legend(bbox_to_anchor=(2.1, 2.35), loc=\"upper left\")\n",
    "\n",
    "# Extra vertical space between rows (for the titles), little horizontal space.\n",
    "plt.subplots_adjust(hspace=0.3, wspace=0.1)\n",
    "\n",
    "# Optional title: fig.suptitle(\"Probability that partition is a fixed point\")\n",
    "fig.savefig(\"figure_numerical_experiments_raw.svg\", dpi=600, bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "obs-on-kmeans-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
