{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a01d7de2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from time import time\n",
    "from sklearn import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import random\n",
    "from sklearn.metrics.pairwise import pairwise_kernels\n",
    "from sklearn.metrics.cluster import adjusted_rand_score\n",
    "from sklearn.cluster import KMeans\n",
    "import os\n",
    "\n",
    "# Load the experiment helpers from the sibling notebooks; imm_experiments and\n",
    "# refine_imm used below are presumably defined in these (confirm).\n",
    "%run KernelkmeansFunctions.ipynb\n",
    "%run ExplainabilityFunctions.ipynb\n",
    "%run ExpandingIMM.ipynb\n",
    "%run KernelExKMC.ipynb\n",
    "%run RunExperiments.ipynb\n",
    "\n",
    "# Seed the global generator so Restart & Run All is repeatable.\n",
    "# NOTE(review): whether the helper notebooks consume this `rng` is not visible here.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "### all clustering benchmark datasets available under https://cs.joensuu.fi/sipu/datasets/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c218f0a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Define kernel functions\n",
    "# Each kernel maps a pair of samples to a scalar. They are handed to\n",
    "# sklearn's pairwise_kernels as callable metrics, which forwards extra\n",
    "# keyword arguments such as gamma.\n",
    "\n",
    "def rbf(x, y, gamma):\n",
    "    \"\"\"Gaussian (RBF) kernel: exp(-gamma * ||x - y||_2^2).\"\"\"\n",
    "    sq_dist = np.sum(np.square(x - y))\n",
    "    return np.exp(-gamma * sq_dist)\n",
    "\n",
    "\n",
    "def laplace(x, y, gamma):\n",
    "    \"\"\"Laplace kernel: exp(-gamma * ||x - y||_1).\"\"\"\n",
    "    l1_dist = np.sum(np.abs(x - y))\n",
    "    return np.exp(-gamma * l1_dist)\n",
    "\n",
    "\n",
    "def linear(x, y):\n",
    "    \"\"\"Linear kernel: the inner product of x and y.\"\"\"\n",
    "    return np.dot(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f448ad66",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Pathbased\n",
    "# Columns 0-1 are the 2-D features; column 2 holds the class label, shifted\n",
    "# to 0-based below. 'your path' is a placeholder for the local dataset file.\n",
    "# The global rng from the setup cell is reused (not re-created unseeded here).\n",
    "\n",
    "df = pd.read_csv('your path', sep=\";\", header=None)\n",
    "\n",
    "# Full gamma grid for a complete search (overridden below for a quick run).\n",
    "gammas = np.array([0.01, 0.05, 0.1, 0.5, 1, 5, 10])\n",
    "\n",
    "X = np.array(df)[:,0:3]\n",
    "y_true = X[:,2]\n",
    "true_k = len(np.unique(y_true))\n",
    "X = X[:,[0,1]]\n",
    "y_true = y_true.astype(int) - 1\n",
    "\n",
    "gammas = [0.05] # set optimum for a quick run\n",
    "imm_path1, imm_path2 = imm_experiments(X, y_true, gammas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0c8597",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the kernel matrix for the kernel/gamma chosen by imm_experiments and\n",
    "# pick the IMM labels (best_kernel == 0 is the Gaussian kernel) to refine.\n",
    "y_kkm = imm_path2['y_kkm']\n",
    "gamma = imm_path1['best_gamma']\n",
    "\n",
    "if imm_path1['best_kernel'] != 0:\n",
    "    # Laplace kernel: the kernel-matrix IMM labels are used.\n",
    "    print('Laplace Kernel Matrix')\n",
    "    Kmat = pairwise_kernels(X, metric=laplace, gamma=gamma)\n",
    "    y_imm = imm_path2['y_kmat_imm_on_kkm']\n",
    "else:\n",
    "    # Gaussian kernel: keep whichever IMM variant has the lower price.\n",
    "    Kmat = pairwise_kernels(X, metric=rbf, gamma=gamma)\n",
    "    price_taylor = imm_path1['price_taylor_imm_on_kkm']\n",
    "    price_kmat = imm_path1['price_kmat_imm_on_kkm']\n",
    "    if price_taylor < price_kmat:\n",
    "        print('Gaussian Taylor')\n",
    "        y_imm = imm_path2['y_taylor_imm_on_kkm']\n",
    "    else:\n",
    "        print('Gaussian Kernel Matrix')\n",
    "        y_imm = imm_path2['y_kmat_imm_on_kkm']\n",
    "\n",
    "refine_path1, refine_path2 = refine_imm(X, y_true, y_kkm, y_imm, Kmat, max_leaves=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f5c520",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Make a plot\n",
    "\n",
    "y_kmeans = imm_path2['y_kmeans']\n",
    "y_kmeans_imm = imm_path2['y_kmeans_imm']\n",
    "y_exkmc = refine_path2['y_exkmc']\n",
    "y_expand = refine_path2['y_expand']\n",
    "\n",
    "# One panel per method. Only labels 0, 1 and 2 are drawn, each in a fixed colour;\n",
    "# the K-means panel uses red for label 1 and blue for label 2, the others the reverse.\n",
    "panels = [\n",
    "    (y_kmeans, ('green', 'red', 'blue'), 'K-means'),\n",
    "    (y_kkm, ('green', 'blue', 'red'), 'Kernel k-means'),\n",
    "    (y_expand, ('green', 'blue', 'red'), 'Kernel IMM expanded'),\n",
    "    (y_kmeans_imm, ('green', 'blue', 'red'), 'IMM on k-means'),\n",
    "    (y_imm, ('green', 'blue', 'red'), 'Kernel IMM'),\n",
    "    (y_exkmc, ('green', 'blue', 'red'), 'Kernel ExKMC'),\n",
    "]\n",
    "for panel_no, (panel_labels, panel_colors, panel_title) in enumerate(panels, start=1):\n",
    "    plt.subplot(2, 3, panel_no)\n",
    "    for cluster_id, cluster_color in enumerate(panel_colors):\n",
    "        in_cluster = panel_labels == cluster_id\n",
    "        plt.scatter(X[in_cluster, 0], X[in_cluster, 1], s=50, c=cluster_color)\n",
    "    plt.title(panel_title, fontsize=10)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4219477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first result dicts returned by imm_experiments and refine_imm.\n",
    "print(imm_path1, refine_path1, sep=' ', end='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9275daee",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Aggregation\n",
    "# Same layout as Pathbased: two feature columns plus a label column.\n",
    "\n",
    "df = pd.read_csv('your path', sep=\";\", header=None)\n",
    "\n",
    "X = np.array(df)[:, 0:3]\n",
    "y_true = X[:, 2].astype(int) - 1  # labels shifted to 0-based\n",
    "X = X[:, [0, 1]]\n",
    "\n",
    "gammas = [0.1]  # single gamma for a quick run\n",
    "imm_agg1, imm_agg2 = imm_experiments(X, y_true, gammas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "722f0e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the kernel matrix for the kernel/gamma chosen by imm_experiments and\n",
    "# pick the IMM labels (best_kernel == 0 is the Gaussian kernel) to refine.\n",
    "y_kkm = imm_agg2['y_kkm']\n",
    "gamma = imm_agg1['best_gamma']\n",
    "\n",
    "if imm_agg1['best_kernel'] != 0:\n",
    "    # Laplace kernel: the kernel-matrix IMM labels are used.\n",
    "    print('Laplace Kernel Matrix')\n",
    "    Kmat = pairwise_kernels(X, metric=laplace, gamma=gamma)\n",
    "    y_imm = imm_agg2['y_kmat_imm_on_kkm']\n",
    "else:\n",
    "    # Gaussian kernel: keep whichever IMM variant has the lower price.\n",
    "    Kmat = pairwise_kernels(X, metric=rbf, gamma=gamma)\n",
    "    price_taylor = imm_agg1['price_taylor_imm_on_kkm']\n",
    "    price_kmat = imm_agg1['price_kmat_imm_on_kkm']\n",
    "    if price_taylor < price_kmat:\n",
    "        print('Gaussian Taylor')\n",
    "        y_imm = imm_agg2['y_taylor_imm_on_kkm']\n",
    "    else:\n",
    "        print('Gaussian Kernel Matrix')\n",
    "        y_imm = imm_agg2['y_kmat_imm_on_kkm']\n",
    "\n",
    "# Allow three leaves beyond the number of true classes.\n",
    "refine_agg1, refine_agg2 = refine_imm(X, y_true, y_kkm, y_imm, Kmat, max_leaves=np.unique(y_true).size + 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c3e28e",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Make a plot\n",
    "# One panel per method, points coloured by their cluster label.\n",
    "\n",
    "y_kmeans = imm_agg2['y_kmeans']\n",
    "y_kmeans_imm = imm_agg2['y_kmeans_imm']\n",
    "y_exkmc = refine_agg2['y_exkmc']\n",
    "y_expand = refine_agg2['y_expand']\n",
    "\n",
    "panels = [\n",
    "    (y_kmeans, 'K-means'),\n",
    "    (y_kkm, 'Kernel k-means'),\n",
    "    (y_expand, 'Kernel IMM expanded'),\n",
    "    (y_kmeans_imm, 'IMM on k-means'),\n",
    "    (y_imm, 'Kernel IMM'),\n",
    "    (y_exkmc, 'Kernel ExKMC'),\n",
    "]\n",
    "for panel_no, (panel_labels, panel_title) in enumerate(panels, start=1):\n",
    "    plt.subplot(2, 3, panel_no)\n",
    "    plt.scatter(X[:, 0], X[:, 1], s=50, c=panel_labels)\n",
    "    plt.title(panel_title, fontsize=10)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11f8681",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first result dicts returned by imm_experiments and refine_imm.\n",
    "print(imm_agg1, refine_agg1, sep=' ', end='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "940fcda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Flame\n",
    "# Same layout as Pathbased: two feature columns plus a label column.\n",
    "\n",
    "df = pd.read_csv('your path', sep=\";\", header=None)\n",
    "\n",
    "X = np.array(df)[:,0:3]\n",
    "y_true = X[:,2]\n",
    "X = X[:,[0,1]]\n",
    "y_true = y_true.astype(int) - 1\n",
    "\n",
    "# Must be `gammas`: a misspelled name here silently reuses the list left over\n",
    "# from the previous dataset's cell instead of the intended value.\n",
    "gammas = [0.05]\n",
    "imm_flame1, imm_flame2 = imm_experiments(X, y_true, gammas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8bd9976",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the kernel matrix for the kernel/gamma chosen by imm_experiments and\n",
    "# pick the IMM labels (best_kernel == 0 is the Gaussian kernel) to refine.\n",
    "y_kkm = imm_flame2['y_kkm']\n",
    "gamma = imm_flame1['best_gamma']\n",
    "\n",
    "if imm_flame1['best_kernel'] == 0:\n",
    "    # Gaussian kernel: keep whichever IMM variant has the lower price.\n",
    "    Kmat = pairwise_kernels(X, metric=rbf, gamma=gamma)\n",
    "    if imm_flame1['price_taylor_imm_on_kkm'] < imm_flame1['price_kmat_imm_on_kkm']:\n",
    "        print('Gaussian Taylor')\n",
    "        y_imm = imm_flame2['y_taylor_imm_on_kkm']\n",
    "    else:\n",
    "        print('Gaussian Kernel Matrix')\n",
    "        y_imm = imm_flame2['y_kmat_imm_on_kkm']\n",
    "else:\n",
    "    print('Laplace Kernel Matrix')\n",
    "    Kmat = pairwise_kernels(X, metric=laplace, gamma=gamma)\n",
    "    y_imm = imm_flame2['y_kmat_imm_on_kkm']\n",
    "\n",
    "# y_imm is used exactly as selected above (no unconditional override), so the\n",
    "# printed variant matches the labels that are refined.\n",
    "refine_flame1, refine_flame2 = refine_imm(X, y_true, y_kkm, y_imm, Kmat, max_leaves=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc96bc78",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Make a plot\n",
    "# One panel per method, points coloured by their cluster label.\n",
    "\n",
    "y_kmeans = imm_flame2['y_kmeans']\n",
    "y_kmeans_imm = imm_flame2['y_kmeans_imm']\n",
    "y_exkmc = refine_flame2['y_exkmc']\n",
    "y_expand = refine_flame2['y_expand']\n",
    "\n",
    "panels = [\n",
    "    (y_kmeans, 'K-means'),\n",
    "    (y_kkm, 'Kernel k-means'),\n",
    "    (y_expand, 'Kernel IMM expanded'),\n",
    "    (y_kmeans_imm, 'IMM on k-means'),\n",
    "    (y_imm, 'Kernel IMM'),\n",
    "    (y_exkmc, 'Kernel ExKMC'),\n",
    "]\n",
    "for panel_no, (panel_labels, panel_title) in enumerate(panels, start=1):\n",
    "    plt.subplot(2, 3, panel_no)\n",
    "    plt.scatter(X[:, 0], X[:, 1], s=50, c=panel_labels)\n",
    "    plt.title(panel_title, fontsize=10)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d70157",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first result dicts returned by imm_experiments and refine_imm.\n",
    "print(imm_flame1, refine_flame1, sep=' ', end='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c15ae29a",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Iris\n",
    "# Uses the `datasets` module imported in the first cell (no re-import here).\n",
    "X, y_true = datasets.load_iris(return_X_y=True)\n",
    "\n",
    "gammas = [1]\n",
    "imm_iris1, imm_iris2 = imm_experiments(X, y_true, gammas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "259018bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the kernel matrix for the kernel/gamma chosen by imm_experiments and\n",
    "# pick the IMM labels (best_kernel == 0 is the Gaussian kernel) to refine.\n",
    "y_kkm = imm_iris2['y_kkm']\n",
    "gamma = imm_iris1['best_gamma']\n",
    "\n",
    "if imm_iris1['best_kernel'] == 0:\n",
    "    # Gaussian kernel: keep whichever IMM variant has the lower price.\n",
    "    Kmat = pairwise_kernels(X, metric=rbf, gamma=gamma)\n",
    "    if imm_iris1['price_taylor_imm_on_kkm'] < imm_iris1['price_kmat_imm_on_kkm']:\n",
    "        print('Gaussian Taylor')\n",
    "        y_imm = imm_iris2['y_taylor_imm_on_kkm']\n",
    "    else:\n",
    "        print('Gaussian Kernel Matrix')\n",
    "        y_imm = imm_iris2['y_kmat_imm_on_kkm']\n",
    "else:\n",
    "    print('Laplace Kernel Matrix')\n",
    "    Kmat = pairwise_kernels(X, metric=laplace, gamma=gamma)\n",
    "    y_imm = imm_iris2['y_kmat_imm_on_kkm']\n",
    "\n",
    "# y_imm is used exactly as selected above (no unconditional override), so the\n",
    "# printed variant matches the labels that are refined.\n",
    "refine_iris1, refine_iris2 = refine_imm(X, y_true, y_kkm, y_imm, Kmat, max_leaves=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cec4888",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first result dicts returned by imm_experiments and refine_imm.\n",
    "print(imm_iris1, refine_iris1, sep=' ', end='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcca0487",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Breast cancer (Wisconsin)\n",
    "# Uses the `datasets` module imported in the first cell (no re-import here).\n",
    "X, y_true = datasets.load_breast_cancer(return_X_y=True)\n",
    "\n",
    "# X is used unscaled; the very small gamma grid presumably compensates for the\n",
    "# large feature magnitudes (verify).\n",
    "gammas = 10**(-6)*np.array([1, 5, 10])\n",
    "imm_wisc1, imm_wisc2 = imm_experiments(X, y_true, gammas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d16528d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the kernel matrix for the kernel/gamma chosen by imm_experiments and\n",
    "# pick the IMM labels (best_kernel == 0 is the Gaussian kernel) to refine.\n",
    "y_kkm = imm_wisc2['y_kkm']\n",
    "gamma = imm_wisc1['best_gamma']\n",
    "\n",
    "if imm_wisc1['best_kernel'] == 0:\n",
    "    # Gaussian kernel: keep whichever IMM variant has the lower price.\n",
    "    Kmat = pairwise_kernels(X, metric=rbf, gamma=gamma)\n",
    "    if imm_wisc1['price_taylor_imm_on_kkm'] < imm_wisc1['price_kmat_imm_on_kkm']:\n",
    "        print('Gaussian Taylor')\n",
    "        y_imm = imm_wisc2['y_taylor_imm_on_kkm']\n",
    "    else:\n",
    "        print('Gaussian Kernel Matrix')\n",
    "        y_imm = imm_wisc2['y_kmat_imm_on_kkm']\n",
    "else:\n",
    "    print('Laplace Kernel Matrix')\n",
    "    Kmat = pairwise_kernels(X, metric=laplace, gamma=gamma)\n",
    "    y_imm = imm_wisc2['y_kmat_imm_on_kkm']\n",
    "\n",
    "# y_imm is used exactly as selected above (no unconditional override), so the\n",
    "# printed variant matches the labels that are refined.\n",
    "refine_wisc1, refine_wisc2 = refine_imm(X, y_true, y_kkm, y_imm, Kmat, max_leaves=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0b2550",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first result dicts returned by imm_experiments and refine_imm.\n",
    "print(imm_wisc1, refine_wisc1, sep=' ', end='\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
