{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-18T17:08:51.199533Z",
     "start_time": "2021-01-18T17:08:51.192962Z"
    }
   },
   "source": [
    "# Pairwise Adjusted Mutual Information\n",
    "\n",
    "# Experiments on real data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
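    "This notebook compares the **Full Adjusted Mutual Information** (mutual information minus its expected value) with the **Pairwise Adjusted Mutual Information** on real datasets. For each dataset, several scikit-learn clustering algorithms are scored against the true labels with both measures; we report the Spearman correlation between the two resulting rankings of the algorithms, and the speed-up of the pairwise measure over the full one."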
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-08T18:28:17.670665Z",
     "start_time": "2021-01-08T18:28:17.435233Z"
    }
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:45:44.309788Z",
     "start_time": "2021-05-27T10:45:44.298140Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import glob\n",
    "from natsort import natsorted\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats, sparse\n",
    "\n",
    "from sklearn.metrics.cluster import contingency_matrix\n",
    "from sklearn.metrics import mutual_info_score\n",
    "from sklearn.metrics.cluster._expected_mutual_info_fast import expected_mutual_information\n",
    "from sklearn import cluster, datasets, mixture\n",
    "from sklearn.neighbors import kneighbors_graph\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "from openml.datasets import edit_dataset, fork_dataset, get_dataset, list_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.248095Z",
     "start_time": "2021-05-27T10:42:44.243178Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed NumPy's global random state first so that the stochastic steps below are reproducible\n",
    "# (scikit-learn estimators left with random_state=None draw from this global state).\n",
    "np.random.seed(seed=0)\n",
    "# Silence all library warnings (e.g. convergence warnings) to keep the cell outputs readable.\n",
    "warnings.filterwarnings(action=\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pairwise Adjusted Mutual Information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.262175Z",
     "start_time": "2021-05-27T10:42:44.251564Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_adjusted_mutual_info_pair(contingency, n_samples):\n",
    "    \"\"\"Return the pairwise adjusted mutual information of a contingency matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    contingency : np.ndarray\n",
    "        Contingency matrix of shape (k, l) between two labelings.\n",
    "    n_samples : int\n",
    "        Number of samples; expected to equal ``contingency.sum()``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Sum of two terms, each a weighted sum over the cells of the\n",
    "        contingency matrix of differences of ``p * log(p)`` values\n",
    "        (natural logarithm), divided by ``n_samples ** 2``.\n",
    "    \"\"\"\n",
    "    n = n_samples\n",
    "    n_rows, n_cols = contingency.shape\n",
    "    cells = contingency.ravel()\n",
    "    # row / column totals spread over the full (k, l) grid, as floats\n",
    "    row_grid = np.outer(contingency.sum(axis=1), np.ones(n_cols))\n",
    "    col_grid = np.outer(np.ones(n_rows), contingency.sum(axis=0))\n",
    "\n",
    "    # p log p of each cell frequency (0 for empty cells)\n",
    "    plogp = np.zeros(len(cells))\n",
    "    nonzero = cells > 0\n",
    "    plogp[nonzero] = cells[nonzero] / n * np.log(cells[nonzero] / n)\n",
    "    # same with every cell count decreased by one (0 when the count is <= 1)\n",
    "    plogp_minus = np.zeros(len(cells))\n",
    "    several = cells > 1\n",
    "    plogp_minus[several] = (cells[several] - 1) / n * np.log((cells[several] - 1) / n)\n",
    "    # same with every cell count increased by one\n",
    "    plogp_plus = (cells + 1) / n * np.log((cells + 1) / n)\n",
    "\n",
    "    # cell weights of the two terms\n",
    "    weight_minus = cells * (contingency - row_grid - col_grid + n).ravel()\n",
    "    weight_plus = ((row_grid - contingency) * (col_grid - contingency)).ravel()\n",
    "\n",
    "    first_term = np.sum(weight_minus * (plogp - plogp_minus)) / n ** 2\n",
    "    second_term = np.sum(weight_plus * (plogp - plogp_plus)) / n ** 2\n",
    "    return first_term + second_term"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Full Adjusted Mutual Information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.268598Z",
     "start_time": "2021-05-27T10:42:44.264721Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_adjusted_mutual_info_exact(contingency, n_samples):\n",
    "    \"\"\"Return adjusted mutual information (without normalization).\n",
    "\n",
    "    The result is ``MI - E[MI]``: the mutual information of the contingency\n",
    "    matrix minus its expected value, both as computed by scikit-learn.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    contingency : np.ndarray\n",
    "        Contingency matrix\n",
    "    n_samples : int\n",
    "        Number of samples\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Adjusted mutual information, not normalized.\n",
    "    \"\"\"\n",
    "    # The label arguments are ignored when a contingency matrix is given, so pass\n",
    "    # None explicitly rather than IPython's `_` (\"last output\") variable, which would\n",
    "    # tie this call to hidden kernel state and raise NameError outside IPython.\n",
    "    mi = mutual_info_score(None, None, contingency=contingency)\n",
    "    # NOTE: private scikit-learn helper; its location may change between versions\n",
    "    emi = expected_mutual_information(contingency, n_samples)\n",
    "    return mi - emi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clustering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-01-18T18:59:05.500224Z",
     "start_time": "2021-01-18T18:59:05.497119Z"
    }
   },
   "source": [
    "We consider the clustering algorithms of scikit-learn.\n",
    "\n",
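    "To evaluate the similarity between the results obtained with **Full Adjusted Mutual Information** and **Pairwise Adjusted Mutual Information**, we compute, for each dataset, the **Spearman correlation** between the rankings of the clustering algorithms induced by the two measures. We also report the speed-up, i.e. the average ratio of the computation time of the full measure to that of the pairwise one."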
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.272305Z",
     "start_time": "2021-05-27T10:42:44.269943Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters read by prepare_algorithms ('quantile' is not used there).\n",
    "default_base = dict(\n",
    "    quantile=0.3,\n",
    "    eps=0.3,  # DBSCAN\n",
    "    damping=0.9,  # AffinityPropagation\n",
    "    preference=-200,  # AffinityPropagation\n",
    "    n_neighbors=10,  # k-nearest-neighbors connectivity graph\n",
    "    n_clusters=3,  # algorithms with a fixed number of clusters\n",
    "    min_samples=20,  # OPTICS\n",
    "    xi=0.05,  # OPTICS\n",
    "    min_cluster_size=0.1,  # OPTICS\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.279510Z",
     "start_time": "2021-05-27T10:42:44.273599Z"
    }
   },
   "outputs": [],
   "source": [
    "def prepare_algorithms(X):\n",
    "    \"\"\"Build the scikit-learn clustering estimators to compare on data X.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like of shape (n_samples, n_features)\n",
    "        Data, only used here to build the k-nearest-neighbors connectivity\n",
    "        graph of the two connectivity-constrained agglomerative methods.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of (str, estimator)\n",
    "        Unfitted estimators, each paired with a display name.\n",
    "    \"\"\"\n",
    "    params = default_base.copy()\n",
    "    # connectivity matrix for structured Ward\n",
    "    connectivity = kneighbors_graph(\n",
    "        X, n_neighbors=params['n_neighbors'], include_self=False)\n",
    "    # make connectivity symmetric\n",
    "    connectivity = 0.5 * (connectivity + connectivity.T)\n",
    "    two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])\n",
    "    ms = cluster.MeanShift()\n",
    "    ward = cluster.AgglomerativeClustering(\n",
    "        n_clusters=params['n_clusters'], linkage='ward',\n",
    "        connectivity=connectivity)\n",
    "    spectral = cluster.SpectralClustering(\n",
    "        n_clusters=params['n_clusters'], eigen_solver='arpack',\n",
    "        affinity=\"nearest_neighbors\")\n",
    "    dbscan = cluster.DBSCAN(eps=params['eps'])\n",
    "    optics = cluster.OPTICS(min_samples=params['min_samples'],\n",
    "                            xi=params['xi'],\n",
    "                            min_cluster_size=params['min_cluster_size'])\n",
    "    affinity_propagation = cluster.AffinityPropagation(\n",
    "        damping=params['damping'], preference=params['preference'])\n",
    "    average_linkage_params = dict(\n",
    "        linkage=\"average\", n_clusters=params['n_clusters'],\n",
    "        connectivity=connectivity)\n",
    "    try:\n",
    "        # scikit-learn >= 1.2 names the distance `metric`; `affinity` was\n",
    "        # deprecated in 1.2 and removed in 1.4\n",
    "        average_linkage = cluster.AgglomerativeClustering(\n",
    "            metric=\"cityblock\", **average_linkage_params)\n",
    "    except TypeError:\n",
    "        # older scikit-learn only knows the `affinity` keyword\n",
    "        average_linkage = cluster.AgglomerativeClustering(\n",
    "            affinity=\"cityblock\", **average_linkage_params)\n",
    "    birch = cluster.Birch(n_clusters=params['n_clusters'])\n",
    "    gmm = mixture.GaussianMixture(\n",
    "        n_components=params['n_clusters'], covariance_type='full')\n",
    "    return (\n",
    "        ('MiniBatchKMeans', two_means),\n",
    "        ('MeanShift', ms),\n",
    "        ('AffinityPropagation', affinity_propagation),\n",
    "        ('SpectralClustering', spectral),\n",
    "        ('Ward', ward),\n",
    "        ('AgglomerativeClustering', average_linkage),\n",
    "        ('DBSCAN', dbscan),\n",
    "        ('OPTICS', optics),\n",
    "        ('Birch', birch),\n",
    "        ('GaussianMixture', gmm)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.289051Z",
     "start_time": "2021-05-27T10:42:44.280975Z"
    }
   },
   "outputs": [],
   "source": [
    "def run_experiment(dataset_iterator, n_datasets):\n",
    "    \"\"\"Score every clustering algorithm with full and pairwise AMI on each dataset.\n",
    "\n",
    "    For each dataset, the algorithms of ``prepare_algorithms`` are fitted and\n",
    "    their predicted labels are compared with the true labels using both\n",
    "    ``get_adjusted_mutual_info_exact`` and ``get_adjusted_mutual_info_pair``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_iterator : iterable of (str, array, array)\n",
    "        Yields ``(name, X, y)`` with ``X`` of shape (n_samples, n_features)\n",
    "        and ``y`` the true labels.\n",
    "    n_datasets : int\n",
    "        Number of datasets, only used in the progress message.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dataset_names, nb_samples, nb_labels, nb_features : list\n",
    "        Per-dataset name, number of samples, of distinct labels and of\n",
    "        features (before any dimension reduction).\n",
    "    results : list of float\n",
    "        Spearman correlation between the two scores across algorithms\n",
    "        (nan when fewer than two algorithms succeeded).\n",
    "    gains : list of float\n",
    "        Mean ratio of full-AMI time to pairwise-AMI time (nan when no\n",
    "        algorithm succeeded).\n",
    "    \"\"\"\n",
    "    dataset_names = []\n",
    "    nb_samples = []\n",
    "    nb_labels = []\n",
    "    nb_features = []\n",
    "    results = []\n",
    "    gains = []\n",
    "\n",
    "    for i, (name, X, y) in enumerate(dataset_iterator):\n",
    "        n_samples, n_features = X.shape\n",
    "        n_labels = len(set(y))\n",
    "        print(i + 1, \"/\", n_datasets, n_features, n_labels, n_samples)\n",
    "        dataset_names.append(name)\n",
    "        nb_samples.append(n_samples)\n",
    "        nb_labels.append(n_labels)\n",
    "        nb_features.append(n_features)\n",
    "\n",
    "        if n_features > 100:\n",
    "            # dimension reduction\n",
    "            X = TruncatedSVD(n_components=10).fit_transform(X)\n",
    "        X = StandardScaler(with_mean=False).fit_transform(X)\n",
    "\n",
    "        sim_full, sim_pair, time_full, time_pair = [], [], [], []\n",
    "        for algo_name, algorithm in prepare_algorithms(X):\n",
    "            try:\n",
    "                algorithm.fit(X)\n",
    "                if hasattr(algorithm, 'labels_'):\n",
    "                    y_pred = algorithm.labels_.astype(int)\n",
    "                else:\n",
    "                    y_pred = algorithm.predict(X)\n",
    "                contingency = contingency_matrix(y, y_pred)\n",
    "                # perf_counter: time.time() can be too coarse for the fast pairwise score\n",
    "                t0 = time.perf_counter()\n",
    "                score_full = get_adjusted_mutual_info_exact(contingency, len(y))\n",
    "                t1 = time.perf_counter()\n",
    "                score_pair = get_adjusted_mutual_info_pair(contingency, len(y))\n",
    "                t2 = time.perf_counter()\n",
    "            except Exception as exc:\n",
    "                # best effort: an algorithm may fail on some dataset; report it and move on\n",
    "                print(f\"    {algo_name} skipped: {exc!r}\")\n",
    "                continue\n",
    "            # append only once both scores exist, so that the lists stay aligned\n",
    "            sim_full.append(score_full)\n",
    "            sim_pair.append(score_pair)\n",
    "            time_full.append(t1 - t0)\n",
    "            time_pair.append(t2 - t1)\n",
    "\n",
    "        if len(sim_full) >= 2:\n",
    "            result = stats.spearmanr(sim_full, sim_pair).correlation\n",
    "        else:\n",
    "            result = np.nan\n",
    "        if time_full:\n",
    "            gain = np.mean(np.array(time_full) / np.array(time_pair))\n",
    "        else:\n",
    "            gain = np.nan\n",
    "        results.append(result)\n",
    "        gains.append(gain)\n",
    "    return dataset_names, nb_samples, nb_labels, nb_features, results, gains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:48:28.953452Z",
     "start_time": "2021-05-27T10:48:28.940498Z"
    }
   },
   "outputs": [],
   "source": [
    "def evaluate_results(results, gains, nb_samples, name=\"\"):\n",
    "    \"\"\"Plot per-dataset Spearman correlations and speed-ups, and save them as PDF.\n",
    "\n",
    "    Datasets are placed on the x-axis by increasing number of samples.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results : list of float\n",
    "        Spearman correlation per dataset.\n",
    "    gains : list of float\n",
    "        Speed-up per dataset.\n",
    "    nb_samples : list of int\n",
    "        Number of samples per dataset, used to order the datasets.\n",
    "    name : str\n",
    "        Prefix of the saved files ``<name>_spearman.pdf`` and ``<name>_speedup.pdf``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    int\n",
    "        Number of datasets with a Spearman correlation above 0.95.\n",
    "    \"\"\"\n",
    "    results = np.asarray(results, dtype=float)\n",
    "    gains = np.asarray(gains, dtype=float)\n",
    "    # nb of datasets with correlation > 0.95 (nan correlations do not count)\n",
    "    n_high = int(np.sum(results > 0.95))\n",
    "    print(f\"{n_high} / {len(results)} datasets with Spearman correlation > 0.95\")\n",
    "    index = np.argsort(nb_samples)\n",
    "    # Spearman correlation visualisation\n",
    "    plt.plot(1 + np.arange(len(results)), results[index], c='b', lw=3)\n",
    "    plt.ylim(0, 1.01)\n",
    "    plt.xlabel('Dataset')\n",
    "    plt.ylabel('Spearman correlation')\n",
    "    plt.savefig(name + '_spearman.pdf', bbox_inches='tight', transparent=True)\n",
    "    plt.show()\n",
    "    # Speedup visualisation\n",
    "    plt.plot(1 + np.arange(len(gains)), gains[index], c='b', lw=3)\n",
    "    # ignore nan/inf gains when choosing the upper limit: matplotlib rejects\n",
    "    # non-finite axis limits (and np.max of an empty array raises)\n",
    "    finite_gains = gains[np.isfinite(gains)]\n",
    "    plt.ylim(0, (finite_gains.max() if finite_gains.size else 0) + 5)\n",
    "    plt.xlabel('Dataset')\n",
    "    plt.ylabel('Speed-up')\n",
    "    plt.savefig(name + '_speedup.pdf', bbox_inches='tight', transparent=True)\n",
    "    plt.show()\n",
    "    return n_high"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.300549Z",
     "start_time": "2021-05-27T10:42:44.297561Z"
    }
   },
   "outputs": [],
   "source": [
    "def save_results_csv(dataset_names, nb_samples, nb_labels, nb_features, results, gains, name=\"\",\n",
    "                     output_dir=\".\"):\n",
    "    \"\"\"Gather the per-dataset results in a DataFrame and save it as CSV.\n",
    "\n",
    "    The file is written to ``<output_dir>/<name>_results.csv`` (with the\n",
    "    DataFrame index as first column).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_names, nb_samples, nb_labels, nb_features, results, gains : list\n",
    "        Per-dataset values, as returned by run_experiment.\n",
    "    name : str\n",
    "        Prefix of the output file name.\n",
    "    output_dir : str\n",
    "        Folder of the output file (current directory by default).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        The saved table.\n",
    "    \"\"\"\n",
    "    output_results = pd.DataFrame({\n",
    "        'dataset': dataset_names,\n",
    "        'nb_samples': nb_samples,\n",
    "        'nb_labels': nb_labels,\n",
    "        'nb_feature': nb_features,\n",
    "        # run_experiment fills `results` with scipy.stats.spearmanr (rank) correlations\n",
    "        'spearman': results,\n",
    "        'gain': gains,\n",
    "    })\n",
    "    output_results.to_csv(f'{output_dir}/{name}_results.csv')\n",
    "    return output_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We consider the following [benchmark](https://github.com/gagolews/clustering_benchmarks_v1), consisting of **79** different datasets:\n",
    "\n",
    "M. Gagolewski and others (Eds.), Benchmark Suite for Clustering Algorithms -- Version 1, 2020\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.415538Z",
     "start_time": "2021-05-27T10:42:44.301901Z"
    }
   },
   "outputs": [],
   "source": [
    "# Download the benchmark suite into ./clustering_benchmarks_v1 (the folder read below).\n",
    "# NOTE(review): once cloned, git refuses to clone into the existing folder again; the `!`\n",
    "# command does not raise, so re-running this cell only prints that error.\n",
    "!git clone https://github.com/gagolews/clustering_benchmarks_v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.429050Z",
     "start_time": "2021-05-27T10:42:44.417162Z"
    }
   },
   "outputs": [],
   "source": [
    "path = \"./clustering_benchmarks_v1/\"\n",
    "data_suffix = \".data.gz\"\n",
    "# natural sort, so that e.g. \"a2\" comes before \"a10\"\n",
    "data_files = natsorted(glob.glob(f\"{path}*/*{data_suffix}\"))\n",
    "# \"<path><subdir>/<dataset>.data.gz\" -> \"<subdir>/<dataset>\"; slicing (rather than splitting\n",
    "# on `path` or on \".\") also works with OS-specific separators and with names containing dots\n",
    "dataset_names = [file[len(path):-len(data_suffix)] for file in data_files]\n",
    "# exclude the g2mg and h2mg batteries\n",
    "dataset_names = [name for name in dataset_names if not (('g2mg' in name) or ('h2mg' in name))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.432763Z",
     "start_time": "2021-05-27T10:42:44.430593Z"
    }
   },
   "outputs": [],
   "source": [
    "n_datasets = len(dataset_names)\n",
    "# fail fast if the clone or the glob above found no dataset\n",
    "assert n_datasets > 0, f\"no dataset found under {path}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:42:44.437703Z",
     "start_time": "2021-05-27T10:42:44.434294Z"
    }
   },
   "outputs": [],
   "source": [
    "def iterate_gagolews_datasets(dataset_names, data_path=None):\n",
    "    \"\"\"Yield ``(name, X, y)`` for each benchmark dataset, loading files lazily.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_names : iterable of str\n",
    "        Dataset names relative to the data folder, without file extension.\n",
    "    data_path : str, optional\n",
    "        Prefix prepended to every name (must end with a separator);\n",
    "        defaults to the notebook-level ``path``.\n",
    "\n",
    "    Yields\n",
    "    ------\n",
    "    name : str\n",
    "    X : np.ndarray\n",
    "        2-D data read from ``<name>.data.gz`` (``ndmin=2`` keeps 1-feature data 2-D).\n",
    "    y : np.ndarray of np.intc\n",
    "        Labels read from ``<name>.labels0.gz``.\n",
    "    \"\"\"\n",
    "    prefix = path if data_path is None else data_path\n",
    "    for name in dataset_names:\n",
    "        X = np.loadtxt(prefix + name + \".data.gz\", ndmin=2)\n",
    "        y = np.loadtxt(prefix + name + \".labels0.gz\", dtype=np.intc)\n",
    "        yield name, X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:43:52.213781Z",
     "start_time": "2021-05-27T10:42:44.439635Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run every clustering algorithm on every benchmark dataset, plot the Spearman correlations\n",
    "# and speed-ups (gagolews_*.pdf) and save the table (./gagolews_results.csv).\n",
    "gagolews_dataset_iterator = iterate_gagolews_datasets(dataset_names)\n",
    "(dataset_names, nb_samples, nb_labels, nb_features,\n",
    " results, gains) = run_experiment(gagolews_dataset_iterator, n_datasets)\n",
    "evaluate_results(results, gains, nb_samples, \"gagolews\")\n",
    "save_results_csv(dataset_names, nb_samples, nb_labels, nb_features,\n",
    "                 results, gains, \"gagolews\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T08:47:57.669788Z",
     "start_time": "2021-05-27T08:47:57.663856Z"
    }
   },
   "source": [
    "# Experiment 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:22:43.581588Z",
     "start_time": "2021-05-27T10:22:43.575145Z"
    }
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
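    "We consider data from [OpenML](https://www.openml.org): datasets with strictly between 1,000 and 50,000 instances, fewer than 100 features, only numeric attributes and no missing values, that have a target whose number of distinct values is below 20% of the number of instances."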
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:45:50.169958Z",
     "start_time": "2021-05-27T10:45:48.720778Z"
    }
   },
   "outputs": [],
   "source": [
    "datalist = list_datasets(output_format=\"dataframe\")\n",
    "# Keep medium-sized datasets (1000 < instances < 50000) with fewer than 100 features,\n",
    "# all of them numeric, and no missing values.\n",
    "# NOTE(review): OpenML's NumberOfFeatures appears to count the target column as well, in\n",
    "# which case the equality below also drops every dataset with a nominal target -- confirm.\n",
    "dataset_ids = datalist.loc[\n",
    "    (datalist.NumberOfInstances > 1000)\n",
    "    & (datalist.NumberOfInstances < 50000)\n",
    "    & (datalist.NumberOfFeatures < 100)\n",
    "    & (datalist.NumberOfFeatures == datalist.NumberOfNumericFeatures)\n",
    "    & (datalist.NumberOfMissingValues == 0),\n",
    "    'did'\n",
    "].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:46:11.263907Z",
     "start_time": "2021-05-27T10:45:54.377054Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "filtered_dataset_ids = []\n",
    "for dataset_id in dataset_ids:\n",
    "    try:\n",
    "        # download inside the try, so that one unavailable dataset does not stop the loop\n",
    "        dataset = get_dataset(int(dataset_id))\n",
    "        X, y = dataset.get_data(dataset_format=\"array\",\n",
    "                                target=dataset.default_target_attribute)[:2]\n",
    "        n_samples = X.shape[0]\n",
    "        # we filter dataset without target or with too many labels\n",
    "        # (at least one distinct label per 5 samples)\n",
    "        if (y is not None) and (len(np.unique(y)) / n_samples < 0.2):\n",
    "            filtered_dataset_ids.append(dataset_id)\n",
    "    except Exception as exc:\n",
    "        # report which dataset was skipped and why\n",
    "        print(dataset_id, repr(exc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:46:35.450778Z",
     "start_time": "2021-05-27T10:46:35.443488Z"
    }
   },
   "outputs": [],
   "source": [
    "n_datasets = len(filtered_dataset_ids)\n",
    "# fail fast if no OpenML dataset passed the filters above\n",
    "assert n_datasets > 0, \"no OpenML dataset passed the filters\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:46:57.571154Z",
     "start_time": "2021-05-27T10:46:57.566120Z"
    }
   },
   "outputs": [],
   "source": [
    "def iterate_openml_datasets(filtered_dataset_ids):\n",
    "    \"\"\"Yield ``(name, X, y)`` for each OpenML dataset id, downloading lazily.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    filtered_dataset_ids : iterable\n",
    "        OpenML dataset ids (converted with ``int``).\n",
    "\n",
    "    Yields\n",
    "    ------\n",
    "    name : str\n",
    "        OpenML dataset name.\n",
    "    X : np.ndarray\n",
    "        Feature matrix, densified if OpenML returns a sparse matrix.\n",
    "    y : np.ndarray\n",
    "        Values of the dataset's default target attribute.\n",
    "    \"\"\"\n",
    "    for dataset_id in filtered_dataset_ids:\n",
    "        dataset = get_dataset(int(dataset_id))\n",
    "        # get_data also returns the categorical indicator and attribute names (unused here)\n",
    "        X, y = dataset.get_data(dataset_format=\"array\",\n",
    "                                target=dataset.default_target_attribute)[:2]\n",
    "        if sparse.issparse(X):\n",
    "            X = X.toarray()\n",
    "        yield dataset.name, X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-05-27T10:47:12.110029Z",
     "start_time": "2021-05-27T10:46:58.462104Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run every clustering algorithm on every selected OpenML dataset, plot the Spearman\n",
    "# correlations and speed-ups (openml_*.pdf) and save the table (./openml_results.csv).\n",
    "openml_dataset_iterator = iterate_openml_datasets(filtered_dataset_ids)\n",
    "(dataset_names, nb_samples, nb_labels, nb_features,\n",
    " results, gains) = run_experiment(openml_dataset_iterator, n_datasets)\n",
    "evaluate_results(results, gains, nb_samples, \"openml\")\n",
    "save_results_csv(dataset_names, nb_samples, nb_labels, nb_features,\n",
    "                 results, gains, \"openml\")"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
