{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pairwise Adjusted Mutual Information\n",
    "\n",
    "# Experiments on synthetic data\n",
    "\n",
    "This notebook presents the experiments on synthetic data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy import sparse\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.cluster import contingency_matrix\n",
    "from sklearn.metrics import mutual_info_score\n",
    "from sklearn.metrics.cluster._expected_mutual_info_fast import expected_mutual_information"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pairwise adjustement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_adjusted_mutual_info_pair(contingency, n_samples):\n",
    "    \"\"\"Return pairwise adjusted mutual information.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    contingency: np.ndarray\n",
    "        Contingency matrix\n",
    "    n_samples : int\n",
    "        Number of samples\n",
    "    \"\"\"\n",
    "    k, l = contingency.shape\n",
    "    a = contingency.sum(axis=1)\n",
    "    b = contingency.sum(axis=0)\n",
    "    c = contingency.ravel()\n",
    "    # first term\n",
    "    factor = c * (contingency - np.outer(a, np.ones(l)) - np.outer(np.ones(k), b) + n_samples).ravel()\n",
    "    entropy = np.zeros(len(c))\n",
    "    entropy[c > 0] = c[c > 0] / n_samples * np.log(c[c > 0] / n_samples)\n",
    "    entropy_ = np.zeros(len(c))\n",
    "    entropy_[c > 1] = (c[c > 1] - 1) / n_samples * np.log((c[c > 1] - 1) / n_samples)\n",
    "    result = np.sum(factor * (entropy - entropy_)) / n_samples ** 2\n",
    "    # second term\n",
    "    factor = ((np.outer(a, np.ones(l)) - contingency) * (np.outer(np.ones(k), b) - contingency)).ravel()\n",
    "    entropy_ = (c + 1) / n_samples * np.log((c + 1) / n_samples)\n",
    "    result += np.sum(factor * (entropy - entropy_)) / n_samples ** 2\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Full adjustement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_adjusted_mutual_info_exact(contingency, n_samples):\n",
    "    \"\"\"Return adjusted mutual information (without normalization).\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    contingency: np.ndarray\n",
    "        Contingency matrix\n",
    "    n_samples : int\n",
    "        Number of samples\n",
    "    \"\"\"\n",
    "    mi = mutual_info_score(_, _, contingency=contingency)\n",
    "    emi = expected_mutual_information(contingency, n_samples)\n",
    "    result = mi - emi\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Symmetric clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 100\n",
    "cluster_size = 10\n",
    "labels = np.arange(n) // cluster_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster_size_range = np.arange(1, 100, 1).astype(int)\n",
    "adjusted_mutual_info_pair = []\n",
    "adjusted_mutual_info = []\n",
    "\n",
    "for k in cluster_size_range:\n",
    "    labels_ = np.arange(n) // k\n",
    "    contingency = contingency_matrix(labels, labels_)\n",
    "    adjusted_mutual_info_pair.append(get_adjusted_mutual_info_pair(contingency, len(labels)))\n",
    "    adjusted_mutual_info.append(get_adjusted_mutual_info_exact(contingency, len(labels)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.plot(cluster_size_range, adjusted_mutual_info, lw=2, c='b')\n",
    "plt.vlines(cluster_size, ymin=1.05 *np.max(adjusted_mutual_info), ymax=1.1*np.max(adjusted_mutual_info), lw=2)\n",
    "plt.xlabel('Cluster size')\n",
    "plt.ylabel('Mutual information')\n",
    "plt.yticks([0, 1, 2])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.plot(cluster_size_range, adjusted_mutual_info_pair, lw=2, c='b')\n",
    "plt.vlines(cluster_size, ymin=1.05 *np.max(adjusted_mutual_info_pair), ymax=1.1*np.max(adjusted_mutual_info_pair), lw=2)\n",
    "plt.xlabel('Cluster size')\n",
    "plt.ylabel('Mutual information')\n",
    "plt.yticks([0, 0.01, 0.02, 0.03])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_random_labels(n, n_clusters):\n",
    "    p = np.random.rand(n_clusters)\n",
    "    p /= np.sum(p)\n",
    "    return np.random.choice(n_clusters, p=p, size=n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "n = 100\n",
    "n_clusters = 2\n",
    "n_exp = 10\n",
    "n_tests = 1000\n",
    "\n",
    "results = []\n",
    "\n",
    "for t in range(n_exp):\n",
    "    count = 0\n",
    "    for i in range(n_tests):\n",
    "        labels = get_random_labels(n, n_clusters)\n",
    "        labels1 = get_random_labels(n, n_clusters)\n",
    "        labels2 = get_random_labels(n, n_clusters)\n",
    "        contingency1 = contingency_matrix(labels, labels1)\n",
    "        contingency2 = contingency_matrix(labels, labels2)\n",
    "        order = get_adjusted_mutual_info_exact(contingency1, n) > get_adjusted_mutual_info_exact(contingency2, n)\n",
    "        order_pair = get_adjusted_mutual_info_pair(contingency1, n) > get_adjusted_mutual_info_pair(contingency2, n)\n",
    "        count += int(order == order_pair)\n",
    "    results.append(count / n_tests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.std(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computation times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_range = [100, 300, 1000, 3000, 10000, 30000, 100000, 300000, 1000000, 3000000, 10000000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_clusters = 10\n",
    "n_runs = 100000\n",
    "\n",
    "mean_exact = []\n",
    "mean_pair = []\n",
    "std_exact = []\n",
    "std_pair = []\n",
    "\n",
    "for n in n_range:\n",
    "    print(n)\n",
    "    \n",
    "    times_exact = []\n",
    "    times_pair = []\n",
    "    \n",
    "    for t in range(min(int(n_runs / n) + 1, 5)):\n",
    "        labels = np.arange(n) % n_clusters\n",
    "        labels_ = get_random_labels(n, n_clusters)\n",
    "        contingency = contingency_matrix(labels, labels_)\n",
    "        t0 = time.time()\n",
    "        get_adjusted_mutual_info_exact(contingency, n)\n",
    "        t1 = time.time()\n",
    "        times_exact.append(t1 - t0)\n",
    "        t0 = time.time()\n",
    "        get_adjusted_mutual_info_pair(contingency, n)\n",
    "        t1 = time.time()\n",
    "        times_pair.append(t1 - t0)\n",
    "        \n",
    "    mean_exact.append(np.mean(times_exact))\n",
    "    mean_pair.append(np.mean(times_pair))\n",
    "    std_exact.append(np.std(times_exact))\n",
    "    std_pair.append(np.std(times_pair))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "plt.errorbar(n_range, mean_exact, yerr=std_exact, label='Full adjustement', linestyle='none', marker='.', c='b', lw=3)\n",
    "plt.errorbar(n_range, mean_pair, yerr=std_pair, label='Pairwise adjustement', linestyle='none', marker='.', c='r', lw=3)\n",
    "plt.legend()\n",
    "plt.xlabel('Number of samples')\n",
    "plt.ylabel('Computation time (s)')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
