{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737a2e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Find best kernel (with bandwidth) according to Rand Index\n",
    "\n",
    "def get_hyperparam(X, y_true, gammas):\n",
    "    \n",
    "    true_k = len(np.unique(y_true))\n",
    "    best_rand = 0\n",
    "    \n",
    "    for gamma in gammas:\n",
    "        \n",
    "        print('Test Gamma =', gamma)\n",
    "        rands = np.zeros(2)\n",
    "\n",
    "        Kmat = pairwise_kernels(X, metric=rbf, gamma=gamma)\n",
    "        y = kernelkmeans(Kmat, true_k, algo='kernelkmeans', n_init=10, n_iter=100)\n",
    "        rands[0] = adjusted_rand_score(y_true, y)\n",
    "        \n",
    "        Kmat = pairwise_kernels(X, metric=laplace, gamma=gamma)\n",
    "        y = kernelkmeans(Kmat, true_k, algo='kernelkmeans', n_init=10, n_iter=100)\n",
    "        rands[1] = adjusted_rand_score(y_true, y)\n",
    "        \n",
    "        best_i = np.argmax(rands)\n",
    "        \n",
    "        if rands[best_i]>best_rand:\n",
    "            best_gamma = gamma\n",
    "            best_rand = rands[best_i]\n",
    "            best_kernel = best_i\n",
    "            \n",
    "    return(best_gamma, best_kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9a8e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imm_experiments(X, y_true, gammas, n_init = 10):\n",
    "    \n",
    "    true_k = len(np.unique(y_true))\n",
    "    \n",
    "    print('--- First k-means')\n",
    "    Kmat_lin = pairwise_kernels(X, metric=linear)\n",
    "\n",
    "    y_kmeans, centers_kmeans = kernelkmeans(X, true_k, algo='kmeans', n_init = n_init, n_iter=200)\n",
    "    rand_kmeans = adjusted_rand_score(y_kmeans, y_true)\n",
    "    print('---> Rand Score of k-means:', rand_kmeans)\n",
    "\n",
    "    y_kmeans_imm, kmeans_threshold_cuts = imm(X, y_kmeans, centers_kmeans, check_all_cuts = True)\n",
    "    rand_imm = adjusted_rand_score(y_kmeans_imm, y_true)\n",
    "    print('---> Rand Score of linear IMM:', rand_kmeans)\n",
    "    \n",
    "    print('--- Now to kernels. Find kernel and gamma ---')\n",
    "    gamma, best_kernel = get_hyperparam(X, y_true, gammas)\n",
    "    \n",
    "    if best_kernel==0:\n",
    "        kernelfunc = rbf\n",
    "    elif best_kernel==1:\n",
    "        kernelfunc = laplace\n",
    "    \n",
    "    print('We choose the', str(kernelfunc), 'kernel with gamma =', gamma)\n",
    "    print('--- Run kernel k-means ---')\n",
    "    Kmat = pairwise_kernels(X, metric=kernelfunc, gamma=gamma)\n",
    "    y_kkm = kernelkmeans(Kmat, true_k, algo='kernelkmeans', n_init = n_init, n_iter=200)\n",
    "\n",
    "    rand_kkm = adjusted_rand_score(y_true, y_kkm)\n",
    "    print('---> Rand Score:', rand_kkm)\n",
    "    \n",
    "    # For the Gaussian kernel, we try both Taylor as well as kernel matrix features\n",
    "    if best_kernel==0:\n",
    "        \n",
    "        print('Run Gaussian Taylor IMM on kernel k-means')\n",
    "        y_taylor_imm_on_kkm, threshold_cuts_taylor = taylor_imm(X, y_kkm, gamma, 5, check_all_cuts = True)\n",
    "\n",
    "        rand_taylor_imm_on_kkm = adjusted_rand_score(y_taylor_imm_on_kkm, y_true)\n",
    "        price_taylor_imm_on_kkm = np.sum(kernelkmeanscost(Kmat, y_taylor_imm_on_kkm))/np.sum(kernelkmeanscost(Kmat, y_kkm))\n",
    "\n",
    "        print('Run Gaussian kernel matrix IMM on kernel k-means')\n",
    "        y_kmat_imm_on_kkm, threshold_cuts_kmat = kernelmatrix_imm(X, y_kkm, gamma, rbf, check_all_cuts = True)\n",
    "\n",
    "        rand_kmat_imm_on_kkm = adjusted_rand_score(y_kmat_imm_on_kkm, y_true)\n",
    "        price_kmat_imm_on_kkm = np.sum(kernelkmeanscost(Kmat, y_kmat_imm_on_kkm))/np.sum(kernelkmeanscost(Kmat, y_kkm))\n",
    "            \n",
    "        results = {'rand_kmeans': rand_kmeans,\n",
    "                   'rand_imm': rand_imm,\n",
    "                   'best_kernel': best_kernel,\n",
    "                   'best_gamma': gamma,\n",
    "                   'rand_kkm': rand_kkm,\n",
    "                   'rand_taylor_imm_on_kkm': rand_taylor_imm_on_kkm,\n",
    "                   'price_taylor_imm_on_kkm': price_taylor_imm_on_kkm,\n",
    "                   'rand_kmat_imm_on_kkm': rand_kmat_imm_on_kkm,\n",
    "                   'price_kmat_imm_on_kkm': price_kmat_imm_on_kkm,\n",
    "                   'threshold_cuts_taylor': threshold_cuts_taylor,\n",
    "                   'threshold_cuts_kmat': threshold_cuts_kmat\n",
    "              }\n",
    "        \n",
    "        labels = {'y_kmeans': y_kmeans,\n",
    "                  'y_kmeans_imm': y_kmeans_imm,\n",
    "                  'y_kkm': y_kkm,\n",
    "                  'y_taylor_imm_on_kkm': y_taylor_imm_on_kkm,\n",
    "                  'y_kmat_imm_on_kkm': y_kmat_imm_on_kkm\n",
    "                 }\n",
    "        \n",
    "        return(results, labels)\n",
    "    \n",
    "            \n",
    "        \n",
    "    elif best_kernel==1:\n",
    "        \n",
    "        print('Run Laplace kernel matrix IMM on kernel k-means')\n",
    "        y_kmat_imm_on_kkm, threshold_cuts_kmat = kernelmatrix_imm(X, y_kkm, gamma, laplace, check_all_cuts = True)\n",
    "\n",
    "        rand_kmat_imm_on_kkm = adjusted_rand_score(y_kmat_imm_on_kkm, y_true)\n",
    "        price_kmat_imm_on_kkm = np.sum(kernelkmeanscost(Kmat, y_kmat_imm_on_kkm))/np.sum(kernelkmeanscost(Kmat, y_kkm))\n",
    "            \n",
    "        results = {'rand_kmeans': rand_kmeans,\n",
    "                   'rand_imm': rand_imm,\n",
    "                   'best_kernel': best_kernel,\n",
    "                   'best_gamma': gamma,\n",
    "                   'rand_kkm': rand_kkm,\n",
    "                   'rand_kmat_imm_on_kkm': rand_kmat_imm_on_kkm,\n",
    "                   'price_kmat_imm_on_kkm': price_kmat_imm_on_kkm,\n",
    "                   'threshold_cuts_kmat': threshold_cuts_kmat\n",
    "              }\n",
    "        \n",
    "        labels = {'y_kmeans': y_kmeans,\n",
    "                  'y_kmeans_imm': y_kmeans_imm,\n",
    "                  'y_kkm': y_kkm,\n",
    "                  'y_kmat_imm_on_kkm': y_kmat_imm_on_kkm\n",
    "         }\n",
    "        \n",
    "        return(results, labels)\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8aad252",
   "metadata": {},
   "outputs": [],
   "source": [
    "def refine_imm(X, y_true, y_kkm, y_imm, Kmat, max_leaves):\n",
    "    \n",
    "    print('Kernel ExKMC')\n",
    "    y_greedy = exkmc_build_on_imm(X, y_kkm, y_imm, Kmat, max_leaves)\n",
    "    y_exkmc = np.zeros(X.shape[0])\n",
    "    \n",
    "    for cluster in np.unique(y_greedy):\n",
    "        index_u = np.where(y_greedy==cluster)[0]\n",
    "        best_label = exkmc_min_cost_at_node(index_u, Kmat, y_kkm)[1]\n",
    "        y_exkmc[index_u] = best_label\n",
    "        \n",
    "    rand_exkmc = adjusted_rand_score(y_exkmc, y_true)\n",
    "    price_exkmc = np.sum(kernelkmeanscost(Kmat, y_exkmc))/np.sum(kernelkmeanscost(Kmat, y_kkm))\n",
    "   \n",
    "    y_greedy2 = expand_build_on_imm(X, y_kkm, y_imm, max_leaves)\n",
    "    y_expand = np.zeros(X.shape[0])\n",
    "    \n",
    "    for cluster in np.unique(y_greedy2):\n",
    "        index_u = np.where(y_greedy2==cluster)[0]\n",
    "        best_label = expand_min_cost_at_node(index_u, y_kkm)[1]\n",
    "        y_expand[index_u] = best_label\n",
    "\n",
    "    rand_expand = adjusted_rand_score(y_expand, y_true)\n",
    "    price_expand = np.sum(kernelkmeanscost(Kmat, y_expand))/np.sum(kernelkmeanscost(Kmat, y_kkm))\n",
    "    \n",
    "    results = {'rand_exkmc': rand_exkmc,\n",
    "               'price_exkmc': price_exkmc,\n",
    "               'rand_expand': rand_expand,\n",
    "               'price_expand': price_expand\n",
    "              }\n",
    "    \n",
    "    labels = {'y_exkmc': y_exkmc,\n",
    "              'y_expand': y_expand\n",
    "             }\n",
    "    \n",
    "    return(results, labels)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
