{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "da345ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "### IMM Function\n",
    "\n",
    "def find_cut(X, y, index_u, clusters_u, centers, check_all_cuts, silent):\n",
    "  \n",
    "    # X,y are the data and the lables\n",
    "    # index_u is the array of indices at this node (n_u)\n",
    "    # clusters_u is a vector of admissible cluster labels at the node (k_u)\n",
    "    # centers is a matrix of all centers (k,d)\n",
    "\n",
    "    # Data at this node\n",
    "    X_u = X[index_u,:]\n",
    "    y_u = y[index_u]\n",
    "    \n",
    "    #print(X_u)\n",
    "    #print('Remaining Cluster Centers:', centers[clusters_u,:])\n",
    "    \n",
    "    # Find index of all points that still have their center available (in clusters)\n",
    "    index_good = np.where(np.isin(y_u, clusters_u))[0]\n",
    "    \n",
    "    # These are these points\n",
    "    X_relevant = X_u[index_good,:]\n",
    "    \n",
    "    #... and these are their clusters\n",
    "    y_relevant = y_u[index_good]\n",
    "    y_relevant = y_relevant.astype(int)\n",
    "    \n",
    "    if silent == False:\n",
    "        print('Remaining relevant data:')\n",
    "        print(X_relevant)\n",
    "        print('Labels:')\n",
    "        print(y_relevant)\n",
    "        \n",
    "  \n",
    "    mistakes = float('inf')\n",
    "\n",
    "    for j in np.arange(np.shape(X_relevant)[1]): # iterate over all coordinates\n",
    "        \n",
    "        z = centers[clusters_u,j] # projected centers\n",
    "        \n",
    "        if len(np.unique(z))==1:\n",
    "            mistakes_temp = mistakes # if all projections are identical, don't cut\n",
    "        else:\n",
    "            sorted_z = np.sort(z)\n",
    "            \n",
    "            if check_all_cuts == False:\n",
    "                \n",
    "                thetas = (sorted_z[:-1] + sorted_z[1:])/2\n",
    "            \n",
    "            elif check_all_cuts == True:\n",
    "                \n",
    "                theta_index = np.where((X_relevant[:,j] > sorted_z[0]) & (X_relevant[:,j] < sorted_z[-1]))[0]\n",
    "                thetas = np.sort(np.unique(X_relevant[theta_index,j]))\n",
    "                thetas = (thetas[:-1] + thetas[1:])/2\n",
    "            \n",
    "            if silent == False:\n",
    "                print('Check Coordinate',j)\n",
    "                print('Projected Centers:',z)\n",
    "                print('Thetas:', thetas)\n",
    "            \n",
    "            for theta in thetas:\n",
    "                \n",
    "                pointL_centerR = np.where((X_relevant[:,j] < theta) & (centers[y_relevant,j] >= theta))[0]\n",
    "                pointR_centerL = np.where((X_relevant[:,j] >= theta) & (centers[y_relevant,j] < theta))[0]\n",
    "                \n",
    "                mistakes1 = len(pointL_centerR)\n",
    "                mistakes2 = len(pointR_centerL)\n",
    "                \n",
    "                mistakes_temp = mistakes1 + mistakes2\n",
    "                \n",
    "                if silent == False:\n",
    "                    print('Check Threshold', theta)\n",
    "                    print((X_relevant[:,j] < theta) & (centers[y_relevant,j] >= theta))\n",
    "                    print('Point left, center right:', pointL_centerR)\n",
    "                    print((X_relevant[:,j] >= theta) & (centers[y_relevant,j] < theta))\n",
    "                    print('Point right, center left:', pointR_centerL)\n",
    "                    print('--->', mistakes_temp, 'Mistakes')\n",
    "                \n",
    "                if mistakes_temp < mistakes:\n",
    "                    mistakes = mistakes_temp\n",
    "                    best_cut = {'coordinate': j, 'threshold': theta, 'mistakes': mistakes}\n",
    "    \n",
    "    index_go_L = X_u[:,best_cut['coordinate']] <= best_cut['threshold']\n",
    "    index_go_R = X_u[:,best_cut['coordinate']] > best_cut['threshold']\n",
    "    \n",
    "    best_cut['index_u_L'] = index_u[index_go_L]\n",
    "    best_cut['index_u_R'] = index_u[index_go_R]\n",
    "    \n",
    "    best_cut['clusters_u_L'] = clusters_u[np.where(centers[clusters_u,best_cut['coordinate']] <= best_cut['threshold'])[0]]\n",
    "    best_cut['clusters_u_R'] = clusters_u[np.where(centers[clusters_u,best_cut['coordinate']] > best_cut['threshold'])[0]]\n",
    "    \n",
    "    print('Cluster Label Partioning:', best_cut['clusters_u_L'], 'and', best_cut['clusters_u_R'])\n",
    "    \n",
    "    return(best_cut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "30c1b2a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_cut(X, y, index_nodes, clusters_nodes, centers, threshold_cuts, check_all_cuts, silent):\n",
    "    \n",
    "    ### X,y are simply the data and the labels\n",
    "    ### index_nodes is a list of arrays, each containing the index of points at a node\n",
    "    \n",
    "    print('+++ New Cut +++')\n",
    "    mistakes = float('inf')\n",
    "    n_nodes = len(index_nodes)\n",
    "    n_leaves = 0\n",
    "    \n",
    "    for u in np.arange(n_nodes):\n",
    "        \n",
    "        print('--- Check node', u, '---')\n",
    "        index_u = index_nodes[u]\n",
    "        clusters_u = clusters_nodes[u]\n",
    "        print('Node', u ,'contains cluster labels', clusters_u)\n",
    "    \n",
    "        if len(clusters_u)==1:\n",
    "            print('... This is already a Leaf')\n",
    "            n_leaves = n_leaves + 1\n",
    "        else:\n",
    "            best_cut_u = find_cut(X, y, index_u, clusters_u, centers, check_all_cuts, silent)\n",
    "            mistakes_u = best_cut_u['mistakes']\n",
    "            print('Mistakes:', mistakes_u)\n",
    "        \n",
    "            if mistakes_u < mistakes:\n",
    "                best_u = u\n",
    "                best_cut = best_cut_u\n",
    "                mistakes = mistakes_u\n",
    "    \n",
    "    if n_leaves == n_nodes:\n",
    "        \n",
    "        print('+++ IMM has finished +++')\n",
    "        \n",
    "        return('onlyleaves')\n",
    "    \n",
    "    elif n_leaves < n_nodes:    \n",
    "        \n",
    "        print('Cut node', best_u, \n",
    "              'at Coordinate', best_cut['coordinate'], \n",
    "              'Threshold', best_cut['threshold'],\n",
    "              '---> Mistakes =', best_cut['mistakes'])\n",
    "\n",
    "        index_update = index_nodes.copy()\n",
    "        index_update[best_u] = best_cut['index_u_L']\n",
    "        index_update.append(best_cut['index_u_R'])\n",
    "\n",
    "        clusters_update = clusters_nodes.copy()\n",
    "        clusters_update[best_u] = best_cut['clusters_u_L']\n",
    "        clusters_update.append(best_cut['clusters_u_R'])\n",
    "        \n",
    "        threshold_cuts_update = threshold_cuts.copy()\n",
    "        newrow = np.array([best_u, best_cut['coordinate'], best_cut['threshold']])\n",
    "        threshold_cuts_update = np.vstack([threshold_cuts_update, newrow])\n",
    "        \n",
    "        return(index_update, clusters_update, threshold_cuts_update)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "0f2eb402",
   "metadata": {},
   "outputs": [],
   "source": [
    "def imm(X, y, centers, check_all_cuts = True, silent = True):\n",
    "    \"\"\"Build an IMM threshold tree that explains a given clustering.\n",
    "\n",
    "    Starting from a single root node holding every point and every cluster\n",
    "    label, do_cut is called repeatedly (each call adds one cut) until every\n",
    "    node holds a single cluster label.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2-D array with one row per point.\n",
    "    y : cluster label of every point (cast to int to build the label set).\n",
    "    centers : 2-D array whose row i is the center of cluster label i.\n",
    "    check_all_cuts, silent : forwarded to do_cut / find_cut.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    y_imm : float array of length n; y_imm[i] is the position, in the final\n",
    "        node list, of the leaf that contains point i.\n",
    "    threshold_cuts : array with one row [node, coordinate, threshold] per cut,\n",
    "        in the order the cuts were made.\n",
    "    \"\"\"\n",
    "\n",
    "    n_data = np.shape(X)[0]\n",
    "    index_nodes = [np.arange(n_data)]\n",
    "    clusters_nodes = [np.unique(y.astype(int))]\n",
    "    threshold_cuts = np.zeros((0,3))\n",
    "\n",
    "    while True:\n",
    "\n",
    "        cut = do_cut(X, y, index_nodes, clusters_nodes, centers, threshold_cuts, check_all_cuts, silent)\n",
    "\n",
    "        # do_cut returns the string 'onlyleaves' once no node can be split\n",
    "        if isinstance(cut, str) and cut == 'onlyleaves':\n",
    "            break\n",
    "        index_nodes, clusters_nodes, threshold_cuts = cut\n",
    "\n",
    "    # The leaves partition the points, so label all points of a leaf at once;\n",
    "    # this is linear, unlike a per-point membership search over every leaf.\n",
    "    y_imm = np.zeros(n_data)\n",
    "    for leaf, index_u in enumerate(index_nodes):\n",
    "        y_imm[index_u] = leaf\n",
    "\n",
    "    return(y_imm, threshold_cuts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "561bbfaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np # needed by every function in this notebook, so imported unconditionally\n",
    "\n",
    "run = False # set to True to run the small demo below\n",
    "SEED = 0 # seeds data generation and KMeans so the demo is reproducible\n",
    "\n",
    "if run:\n",
    "\n",
    "    import matplotlib.pyplot as plt\n",
    "    from sklearn.cluster import KMeans\n",
    "    rng = np.random.default_rng(SEED)\n",
    "\n",
    "    n_clusters = 2\n",
    "    n_data = 6\n",
    "    sigma = 0.2\n",
    "    cov = np.array([[sigma, 0], [0, sigma]])\n",
    "    n_k = n_data // n_clusters # points per cluster\n",
    "    X = np.zeros((n_data,2))\n",
    "\n",
    "    # Cluster k is drawn around the mean (k, (1-k)**2)\n",
    "    for k in range(n_clusters):\n",
    "        mean = np.array([k, (1-k)**2])\n",
    "        X[k*n_k:(k+1)*n_k] = rng.multivariate_normal(mean, cov, n_k)\n",
    "\n",
    "    kmeans = KMeans(n_clusters=n_clusters, n_init=3, random_state=SEED)\n",
    "    kmeans.fit(X)\n",
    "    y_true = kmeans.predict(X)\n",
    "\n",
    "    plt.scatter(X[:, 0], X[:, 1], c=y_true, s=50, cmap='viridis')\n",
    "\n",
    "    centers = kmeans.cluster_centers_\n",
    "\n",
    "    y_imm, threshold_cuts = imm(X, y_true, centers, check_all_cuts = False, silent = False)\n",
    "\n",
    "    print(y_imm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d5f5f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def taylor_imm(X, y, gamma, features_per_dim, check_all_cuts):\n",
    "    \"\"\"Run IMM on a truncated Taylor feature map of the Gaussian kernel.\n",
    "\n",
    "    Every input coordinate x is expanded separately into features_per_dim\n",
    "    features sqrt((2*gamma)**j) / sqrt(j!) * x**j * exp(-gamma*x**2) for\n",
    "    j = 0, ..., features_per_dim-1, giving an (N, features_per_dim*d) matrix\n",
    "    Phi. The center of each cluster label is the mean row of Phi over its\n",
    "    points, and the IMM tree is built on Phi.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : (N, d) data matrix.\n",
    "    y : cluster label of every point; assumed to be non-negative integers\n",
    "        (they are used as row indices into the centers).\n",
    "    gamma : bandwidth of the Gaussian kernel exp(-gamma*(x-x')**2).\n",
    "    features_per_dim : number of Taylor terms per coordinate.\n",
    "    check_all_cuts : forwarded to imm.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (y_imm, threshold_cuts) as returned by imm; cut coordinates index the\n",
    "    columns of Phi (column i*features_per_dim + j is term j of coordinate i).\n",
    "    \"\"\"\n",
    "    import math # np.math was only an alias of math and is gone in NumPy 2.0\n",
    "\n",
    "    N,d = X.shape\n",
    "    Phi = np.zeros((N, features_per_dim*d))\n",
    "\n",
    "    for i in range(d):\n",
    "        X_i = X[:,i] # project data to ith dimension\n",
    "        damping = np.exp(-gamma*X_i**2) # shared by all terms of this coordinate\n",
    "        for j in range(features_per_dim):\n",
    "            coef = np.sqrt((2*gamma)**j)/np.sqrt(math.factorial(j))\n",
    "            Phi[:,i*features_per_dim+j] = coef*(X_i**j)*damping\n",
    "\n",
    "    # Row l of the centers belongs to label l, matching how imm/find_cut index\n",
    "    # centers by label value (so labels need not be contiguous).\n",
    "    labels_int = y.astype(int)\n",
    "    labels = np.unique(labels_int)\n",
    "    n_rows = labels.max() + 1 if labels.size else 0\n",
    "    taylor_centers = np.zeros((n_rows, features_per_dim*d))\n",
    "\n",
    "    for label in labels:\n",
    "        taylor_centers[label,:] = np.mean(Phi[labels_int == label,:], axis=0)\n",
    "\n",
    "    y_imm, threshold_cuts = imm(Phi, y, taylor_centers, check_all_cuts = check_all_cuts)\n",
    "\n",
    "    return(y_imm, threshold_cuts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7c1fc489",
   "metadata": {},
   "outputs": [],
   "source": [
    "def kernelmatrix_imm(X, y, gamma, kernel, check_all_cuts):\n",
    "    \"\"\"Run IMM on features made of the d univariate kernel matrices.\n",
    "\n",
    "    For every coordinate i, the N x N kernel matrix of the 1-D data X[:, i]\n",
    "    fills columns i*N to (i+1)*N-1 of the (N, N*d) feature matrix Phi. The\n",
    "    center of each cluster label is the mean row of Phi over its points, and\n",
    "    the IMM tree is built on Phi. Memory grows as N**2 * d.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : (N, d) data matrix.\n",
    "    y : cluster label of every point; assumed to be non-negative integers\n",
    "        (they are used as row indices into the centers).\n",
    "    gamma : kernel parameter passed to pairwise_kernels.\n",
    "    kernel : metric for sklearn's pairwise_kernels (a name such as 'rbf' or a\n",
    "        callable); None selects 'rbf'.\n",
    "    check_all_cuts : forwarded to imm.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (y_imm, threshold_cuts) as returned by imm; cut coordinates index the\n",
    "    columns of Phi.\n",
    "    \"\"\"\n",
    "    # Not imported anywhere else in this notebook, so bring it in locally.\n",
    "    from sklearn.metrics.pairwise import pairwise_kernels\n",
    "\n",
    "    N,d = X.shape\n",
    "    metric = 'rbf' if kernel is None else kernel\n",
    "    Phi = np.zeros((N, N*d))\n",
    "\n",
    "    for i in range(d):\n",
    "        X_i = np.reshape(X[:,i], (-1, 1)) # ith coordinate as an (N, 1) matrix\n",
    "        # filter_params drops gamma for built-in kernels that do not accept it\n",
    "        Phi[:,i*N:(i+1)*N] = pairwise_kernels(X_i, metric=metric, filter_params=True, gamma=gamma)\n",
    "\n",
    "    # Row l of the centers belongs to label l, matching how imm/find_cut index\n",
    "    # centers by label value (so labels need not be contiguous).\n",
    "    labels_int = y.astype(int)\n",
    "    labels = np.unique(labels_int)\n",
    "    n_rows = labels.max() + 1 if labels.size else 0\n",
    "    Kmat_centers = np.zeros((n_rows, N*d))\n",
    "\n",
    "    for label in labels:\n",
    "        Kmat_centers[label,:] = np.mean(Phi[labels_int == label,:], axis=0)\n",
    "\n",
    "    y_imm, threshold_cuts = imm(Phi, y, Kmat_centers, check_all_cuts = check_all_cuts)\n",
    "\n",
    "    return(y_imm, threshold_cuts)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
