{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89272c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cost_points2cluster(index_u, cluster, Kmat, y):\n",
    "    \n",
    "    ### cost of assigning points to a given cluster\n",
    "    cluster_index = np.where(y == cluster)[0] # get the points in cluster\n",
    "    #print(cluster_index)\n",
    "    \n",
    "    normcluster2 = np.sum(Kmat[np.ix_(cluster_index, cluster_index)])/len(cluster_index)**2\n",
    "    \n",
    "    index_dot_cluster = np.sum(Kmat[np.ix_(index_u, cluster_index)])/len(cluster_index)\n",
    "    #print(Kmat[np.ix_(index_u, cluster_index)].shape)\n",
    "    \n",
    "    sum_xx = np.sum(np.diagonal(Kmat[np.ix_(index_u, index_u)]))\n",
    "    \n",
    "    return(sum_xx + normcluster2 - 2*index_dot_cluster)\n",
    "\n",
    "#points = np.where(y_true==1)[0]\n",
    "#cost_points2cluster(points, 2, Kmat, y_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e8da89",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exkmc_min_cost_at_node(index_u, Kmat, y, silent = True):\n",
    "    \n",
    "    ### minimal cost of a set of points\n",
    "    \n",
    "    cost_best = float('inf')\n",
    "    \n",
    "    for cluster in np.unique(y):\n",
    "        \n",
    "        cost_u = cost_points2cluster(index_u, cluster, Kmat, y)\n",
    "    \n",
    "        if cost_u < cost_best:\n",
    "            cluster_best = cluster\n",
    "            cost_best = cost_u\n",
    "    \n",
    "    return(cost_best, cluster_best)\n",
    "\n",
    "#points = np.where(y_true==1)[0]\n",
    "#exkmc_min_cost_at_node(points, Kmat, y_true, silent = False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bfd72f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exkmc_cost_delta_of_split(i, theta1, theta2, index_u, X, y, Kmat, silent = True):\n",
    "    \"\"\"Cost reduction obtained by splitting a node with an interval cut.\n",
    "\n",
    "    The node's points are divided into those whose coordinate `i` lies in\n",
    "    the closed interval [theta1, theta2] and those lying strictly outside it.\n",
    "    Each side is then charged its own minimal single-cluster cost.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    i              : axis (column of X) used for the split.\n",
    "    theta1, theta2 : lower and upper threshold of the interval.\n",
    "    index_u        : index array of the points at this node.\n",
    "    X              : data.\n",
    "    y              : cluster labels.\n",
    "    Kmat           : kernel matrix.\n",
    "    silent         : when False, print the old cost, new cost and their difference.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    cost before the split minus cost after it (larger is better).\n",
    "    \"\"\"\n",
    "    # Cost of leaving the node as it is\n",
    "    cost_u = exkmc_min_cost_at_node(index_u, Kmat, y)[0]\n",
    "\n",
    "    # Coordinate i of the node's points, kept as an (n, 1) column\n",
    "    column = X[np.ix_(index_u, [i])]\n",
    "    in_interval = (column >= theta1) & (column <= theta2)\n",
    "    off_interval = (column < theta1) | (column > theta2)\n",
    "\n",
    "    # Translate row positions within the node back to global row IDs\n",
    "    index_L = index_u[np.where(in_interval)[0]]\n",
    "    index_R = index_u[np.where(off_interval)[0]]\n",
    "\n",
    "    cost_inside = exkmc_min_cost_at_node(index_L, Kmat, y)[0]\n",
    "    cost_outside = exkmc_min_cost_at_node(index_R, Kmat, y)[0]\n",
    "    cost_new = cost_inside + cost_outside\n",
    "\n",
    "    # the caller keeps the largest delta, i.e. the lowest cost_new\n",
    "    cost_delta = cost_u - cost_new\n",
    "\n",
    "    if silent == False:\n",
    "        print('Old cost:', cost_u)\n",
    "        print('New cost:', cost_new)\n",
    "        print('Delta:', cost_delta)\n",
    "\n",
    "    return cost_delta\n",
    "\n",
    "#exkmc_cost_delta_of_split(0, 0, 10, np.arange(X.shape[0]), X, y_true, Kmat, silent = False) # bad\n",
    "#exkmc_cost_delta_of_split(0, 0, 1, np.arange(X.shape[0]), X, y_true, Kmat, silent = False) # good"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493a41ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exkmc_split_node(index_u, X, y, Kmat, silent = True):\n",
    "    \n",
    "    # Splits a given node at the best possible threshold cut\n",
    "    \n",
    "    # index_u: points at this node\n",
    "    \n",
    "    best_delta = (-1)*float('inf')\n",
    "    \n",
    "    for i in np.arange(np.shape(X)[1]):\n",
    "        \n",
    "        print('Check Coordinate',i)\n",
    "        thetas = np.unique(np.sort(X[np.ix_(index_u,[i])]))\n",
    "        \n",
    "        #print('Thresholds:', thetas)\n",
    "        \n",
    "        for theta1 in thetas[:-1]:\n",
    "            for theta2 in thetas[thetas>theta1]:\n",
    "                \n",
    "                #print(theta1, theta2)\n",
    "                cost_delta = exkmc_cost_delta_of_split(i, theta1, theta2, index_u, X, y, Kmat)\n",
    "            \n",
    "                if cost_delta > best_delta:\n",
    "\n",
    "                    if silent == False:\n",
    "                        print('Improvement at theta1 =', theta1, 'and theta2=', theta2)\n",
    "                    \n",
    "                    best_delta = cost_delta\n",
    "                    best_i = i\n",
    "                    best_theta1 = theta1\n",
    "                    best_theta2 = theta2\n",
    "                    \n",
    "    info_u = {'best_delta': best_delta,\n",
    "              'best_i': best_i,\n",
    "              'best_theta1': best_theta1,\n",
    "              'best_theta2': best_theta2\n",
    "             }\n",
    "    \n",
    "    return(info_u)\n",
    "\n",
    "#points = np.arange(X.shape[0])\n",
    "#points = np.where((y_true==1) | (y_true==2))[0]\n",
    "#plt.scatter(X[points, 0], X[points, 1], c=y_true[points], s=50, cmap='viridis')\n",
    "#exkmc_split_node(points, X, y_true, Kmat, silent = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec526368",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exkmc_new_cut(X, y, Kmat, index_nodes, info_nodes, silent=True):\n",
    "    \n",
    "    ### This function uses ExKMC to add a new cut no matter what\n",
    "    ### Make sure to only call it as long as you want more leaves\n",
    "    \n",
    "    ### X,y are simply the data and the labels\n",
    "    ### index_nodes is a list of arrays, each containing the index of points at a node\n",
    "    \n",
    "    print('+++ Currently, we have', len(index_nodes), 'node(s). Let us find a new cut! +++')\n",
    "    \n",
    "    best_delta = (-1)*float('inf')\n",
    "    n_nodes = len(index_nodes)\n",
    "    \n",
    "    for u in np.arange(n_nodes):\n",
    "        \n",
    "        print('--- Check node', u, '---')\n",
    "        \n",
    "        index_u = index_nodes[u]\n",
    "        print('Number of points at this node:', len(index_u))\n",
    "        \n",
    "        info_u = info_nodes[u]\n",
    "        print('Available info at this node:', info_u)\n",
    "        \n",
    "        if len(index_u)<=3:\n",
    "            best_delta_u = (-1)*float('inf') # don't further cut small leafs\n",
    "            \n",
    "        else:\n",
    "            \n",
    "            if info_u == 'Empty':\n",
    "                print('Fetch info on this new node')\n",
    "                info_u = expand_split_node(index_u, X, y)\n",
    "                info_nodes[u] = info_u # write this info to the big list\n",
    "            \n",
    "            # anyways this is what we are interested in!\n",
    "            best_delta_u = info_u['best_delta']\n",
    "            best_i_u = info_u['best_i']\n",
    "            best_theta1_u = info_u['best_theta1']\n",
    "            best_theta2_u = info_u['best_theta2']\n",
    "        \n",
    "            print('---> Delta = ', best_delta_u)\n",
    "\n",
    "            if best_delta_u > best_delta:\n",
    "\n",
    "                best_delta = best_delta_u\n",
    "                best_u = u\n",
    "                best_cut_i = best_i_u\n",
    "                best_cut_theta1 = best_theta1_u\n",
    "                best_cut_theta2 = best_theta2_u\n",
    "    \n",
    "    print('--- NEW CUT --- ... at node', best_u, \n",
    "          '... at Coordinate', best_cut_i, \n",
    "          '... at Thresholds', best_cut_theta1, best_cut_theta2)\n",
    "\n",
    "    #go_left_u = np.where(X[np.ix_(index_nodes[best_u],[best_cut_i])]<=best_cut_theta)[0]\n",
    "    go_left_u = np.where((X[np.ix_(index_nodes[best_u],[best_cut_i])]>=best_cut_theta1) & \n",
    "                         (X[np.ix_(index_nodes[best_u],[best_cut_i])]<=best_cut_theta2))[0]\n",
    "    #print('Send Left:', (index_nodes[best_u])[go_left_u])\n",
    "    \n",
    "    #go_right_u = np.where(X[np.ix_(index_nodes[best_u],[best_cut_i])]>best_cut_theta)[0]\n",
    "    go_right_u = np.where((X[np.ix_(index_nodes[best_u],[best_cut_i])]<best_cut_theta1) |\n",
    "                          (X[np.ix_(index_nodes[best_u],[best_cut_i])]>best_cut_theta2))[0]\n",
    "    #print('Send Right:', (index_nodes[best_u])[go_right_u])\n",
    "\n",
    "    index_update = index_nodes.copy()\n",
    "    index_update[best_u] = (index_nodes[best_u])[go_left_u]\n",
    "    index_update.append((index_nodes[best_u])[go_right_u])\n",
    "    \n",
    "    print('Node left:', len(go_left_u))\n",
    "    print('Node right:', len(go_right_u))\n",
    "    \n",
    "    info_update = info_nodes.copy()\n",
    "    info_update[best_u] = 'Empty' \n",
    "    info_update.append('Empty')\n",
    "\n",
    "    return(index_update, info_update)\n",
    "\n",
    "#points = np.where((y_true==1) | (y_true==2))[0]\n",
    "#exkmc_new_cut(X, y_true, Kmat, [points], ['Empty'], silent = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88db4f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exkmc_build_on_imm(X, y, y_imm, Kmat, max_leaves, silent = True):\n",
    "    \n",
    "    n_data = np.shape(X)[0]\n",
    "    index_nodes = []\n",
    "    info_nodes = []\n",
    "    \n",
    "    for i in np.unique(y_imm):\n",
    "        \n",
    "        indices_i = np.where(y_imm == i)[0]\n",
    "        index_nodes.append(indices_i)\n",
    "        info_nodes.append('Empty')\n",
    "    \n",
    "    # index_nodes is a partition that contains the IMM clusters\n",
    "    # we will further partition it using 'exkmc_new_cut'\n",
    "    \n",
    "    converged = False\n",
    "    \n",
    "    if(len(index_nodes)>=max_leaves):\n",
    "        converged = True\n",
    "        print('Converged')\n",
    "    \n",
    "    while converged == False:\n",
    "        \n",
    "        index_nodes, info_nodes = exkmc_new_cut(X, y, Kmat, index_nodes, info_nodes)\n",
    "        \n",
    "        if(len(index_nodes)>=max_leaves):\n",
    "            converged = True\n",
    "            print('Converged')\n",
    "\n",
    "    y_greedy = np.zeros(n_data)\n",
    "    \n",
    "    for i in np.arange(len(index_nodes)):\n",
    "        index_u = index_nodes[i]\n",
    "        y_greedy[index_u] = i\n",
    "    \n",
    "    return(y_greedy)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
