{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79bc4303",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_min_cost_at_node(index_u, y, silent = True):\n",
    "    \"\"\"Minimal misclassification cost of labelling a node with one cluster.\n",
    "\n",
    "    The node is assigned the most frequent reference label among its points;\n",
    "    the cost is the number of points whose reference label differs from it.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    index_u : array-like of int\n",
    "        Row indices of the points at this node.\n",
    "    y : array-like\n",
    "        Reference cluster assignment of every data point. Labels need not be\n",
    "        the contiguous integers 0..k-1.\n",
    "    silent : bool\n",
    "        If False, print the per-label counts, the chosen label and the cost.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (error_best, cluster_best)\n",
    "        Number of points not carrying the majority label, and that label\n",
    "        (the smallest one on ties). An empty node returns (0, 0).\n",
    "    \"\"\"\n",
    "    if len(index_u) == 0:\n",
    "        return(0, 0)\n",
    "\n",
    "    y = np.asarray(y)\n",
    "    labels = np.unique(y)  # sorted, so argmax below picks the smallest tied label\n",
    "\n",
    "    # Map each node point's label to its position in `labels` and count them\n",
    "    # in one pass. Unlike indexing counts by the raw label value, this is\n",
    "    # correct for arbitrary labels (e.g. {1, 2, 5}).\n",
    "    positions = np.searchsorted(labels, y[np.asarray(index_u)])\n",
    "    counts = np.bincount(positions, minlength=len(labels))\n",
    "\n",
    "    best = np.argmax(counts)\n",
    "    cluster_best = labels[best]\n",
    "    error_best = len(index_u) - counts[best]\n",
    "\n",
    "    if not silent:\n",
    "        print(counts)\n",
    "        print(cluster_best)\n",
    "        print(error_best)\n",
    "\n",
    "    return(error_best, cluster_best)\n",
    "\n",
    "#points = np.where((y_true==1) | (y_true==2))[0]\n",
    "#points = np.random.choice(len(y_true), 20)\n",
    "#print(points)\n",
    "#expand_min_cost_at_node(points, y_true, silent = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bfd72f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_cost_delta_of_split(i, theta1, theta2, index_u, X, y, silent = True):\n",
    "    \"\"\"Cost reduction obtained by splitting a node with an interval cut.\n",
    "\n",
    "    Points whose coordinate `i` lies in the closed interval [theta1, theta2]\n",
    "    form one child; points strictly below theta1 or strictly above theta2\n",
    "    form the other. The unsplit node and both children are each charged the\n",
    "    cost returned by `expand_min_cost_at_node`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    i : int\n",
    "        Axis (column of X) used to split.\n",
    "    theta1, theta2 : scalar\n",
    "        Lower and upper thresholds of the interval.\n",
    "    index_u : array-like of int\n",
    "        Row indices of the points at this node.\n",
    "    X : 2-D ndarray\n",
    "        Data, one row per point.\n",
    "    y : array-like\n",
    "        Reference cluster labels.\n",
    "    silent : bool\n",
    "        If False, print the old/new/delta costs; also forwarded to\n",
    "        `expand_min_cost_at_node`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    cost_delta\n",
    "        Cost before the split minus cost after it (the larger the better).\n",
    "    \"\"\"\n",
    "    # Accept lists too: index_u is fancy-indexed by the masks below.\n",
    "    index_u = np.asarray(index_u, dtype=np.intp)\n",
    "\n",
    "    cost_old = expand_min_cost_at_node(index_u, y, silent)[0]\n",
    "\n",
    "    # Extract the split coordinate once instead of re-indexing X four times.\n",
    "    x_u = X[index_u, i]\n",
    "\n",
    "    # Row IDs of the interval\n",
    "    index_L = index_u[(x_u >= theta1) & (x_u <= theta2)]\n",
    "\n",
    "    # Row IDs outside the interval. Written out rather than as ~mask so that\n",
    "    # NaN coordinates land in neither child.\n",
    "    index_R = index_u[(x_u < theta1) | (x_u > theta2)]\n",
    "\n",
    "    cost_new = (expand_min_cost_at_node(index_L, y, silent)[0]\n",
    "                + expand_min_cost_at_node(index_R, y, silent)[0])\n",
    "\n",
    "    cost_delta = cost_old - cost_new  # the larger the better\n",
    "\n",
    "    if not silent:\n",
    "        print('Cost old:', cost_old)\n",
    "        print('Cost new:', cost_new)\n",
    "        print('Cost delta:', cost_delta)\n",
    "\n",
    "    return(cost_delta)\n",
    "\n",
    "#print(expand_cost_delta_of_split(0, 0, 10, np.arange(X.shape[0]), X, y_true, silent = False)) # bad\n",
    "#print(expand_cost_delta_of_split(0, 0, 1.5, np.arange(X.shape[0]), X, y_true, silent = False)) # good"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493a41ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_split_node(index_u, X, y, silent = True):\n",
    "    \"\"\"Find the best interval cut for the points at one node.\n",
    "\n",
    "    For every coordinate i and every pair of distinct observed values\n",
    "    theta1 < theta2 on that coordinate, the interval [theta1, theta2] is\n",
    "    scored with `expand_cost_delta_of_split`; the largest cost reduction\n",
    "    wins (the first one found on ties).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    index_u : array-like of int\n",
    "        Row indices of the points at this node.\n",
    "    X : 2-D ndarray\n",
    "        Data, one row per point.\n",
    "    y : array-like\n",
    "        Reference cluster labels.\n",
    "    silent : bool\n",
    "        If False, report every improvement found during the search.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Keys 'best_delta', 'best_i', 'best_theta1', 'best_theta2'. When no\n",
    "        candidate cut exists (each coordinate takes a single value at this\n",
    "        node, e.g. a single point), best_delta is -inf and the other three\n",
    "        entries are None.\n",
    "    \"\"\"\n",
    "    best_delta = (-1)*float('inf')\n",
    "    # Initialise the arg-best so a node without any candidate cut returns the\n",
    "    # -inf sentinel instead of raising UnboundLocalError when info_u is built.\n",
    "    best_i = None\n",
    "    best_theta1 = None\n",
    "    best_theta2 = None\n",
    "\n",
    "    for i in np.arange(np.shape(X)[1]):\n",
    "\n",
    "        print('Check Coordinate',i)\n",
    "        thetas = np.unique(X[np.ix_(index_u,[i])])  # np.unique already returns sorted values\n",
    "\n",
    "        for theta1 in thetas[:-1]:\n",
    "            for theta2 in thetas[thetas>theta1]:\n",
    "\n",
    "                delta_new = expand_cost_delta_of_split(i, theta1, theta2, index_u, X, y)\n",
    "\n",
    "                if delta_new > best_delta:\n",
    "\n",
    "                    if not silent:\n",
    "                        print('Improvement at theta1 =', theta1, 'and theta2=', theta2)\n",
    "\n",
    "                    best_delta = delta_new\n",
    "                    best_i = i\n",
    "                    best_theta1 = theta1\n",
    "                    best_theta2 = theta2\n",
    "\n",
    "    info_u = {'best_delta': best_delta,\n",
    "              'best_i': best_i,\n",
    "              'best_theta1': best_theta1,\n",
    "              'best_theta2': best_theta2\n",
    "             }\n",
    "\n",
    "    return(info_u)\n",
    "\n",
    "#points = np.arange(X.shape[0])\n",
    "#points = np.where((y_true==1) | (y_true==2))[0]\n",
    "#print(points)\n",
    "#plt.scatter(X[points, 0], X[points, 1], c=y_true[points], s=50, cmap='viridis')\n",
    "#expand_split_node(points, X, y_true, silent = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f931e8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_new_cut(X, y, index_nodes, info_nodes, silent=True):\n",
    "    \"\"\"Apply the single best interval cut across all current leaves.\n",
    "\n",
    "    Every leaf with more than 3 points is a candidate. Its best cut is read\n",
    "    from `info_nodes` when cached, otherwise computed with\n",
    "    `expand_split_node` and cached. The leaf with the largest cost reduction\n",
    "    is split: its points inside [theta1, theta2] on the chosen coordinate\n",
    "    stay at that leaf's position, the points outside are appended as a new\n",
    "    leaf. This adds a new cut no matter how small the gain, so only call it\n",
    "    as long as you want more leaves.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X, y\n",
    "        The data and the reference labels.\n",
    "    index_nodes : list of ndarray\n",
    "        Row indices of the points at each leaf.\n",
    "    info_nodes : list\n",
    "        Per-leaf split info: a dict from `expand_split_node`, or the string\n",
    "        'Empty' if not computed yet. Newly computed info is written back into\n",
    "        this list in place (a cache shared with the caller).\n",
    "    silent : bool\n",
    "        Not used by this function.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (index_update, info_update)\n",
    "        Shallow copies of the two lists with the split applied; both children\n",
    "        of the split leaf get info 'Empty'.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no leaf admits a cut (every leaf has at most 3 points or no\n",
    "        candidate thresholds).\n",
    "    \"\"\"\n",
    "    print('+++ Currently, we have', len(index_nodes), 'node(s). Let us find a new cut! +++')\n",
    "\n",
    "    best_delta = (-1)*float('inf')\n",
    "    best_u = None  # stays None if no leaf can be cut\n",
    "\n",
    "    for u, index_u in enumerate(index_nodes):\n",
    "\n",
    "        print('--- Check node', u, '---')\n",
    "        print('Number of points at this node:', len(index_u))\n",
    "\n",
    "        info_u = info_nodes[u]\n",
    "        print('Available info at this node:', info_u)\n",
    "\n",
    "        if len(index_u)<=3:\n",
    "            continue # don't further cut small leafs\n",
    "\n",
    "        if info_u == 'Empty':\n",
    "            print('Fetch info on this new node')\n",
    "            info_u = expand_split_node(index_u, X, y)\n",
    "            info_nodes[u] = info_u # write this info to the big list\n",
    "\n",
    "        best_delta_u = info_u['best_delta']\n",
    "        print('---> Best delta = ', best_delta_u)\n",
    "\n",
    "        if best_delta_u > best_delta:\n",
    "            best_delta = best_delta_u\n",
    "            best_u = u\n",
    "            best_cut_i = info_u['best_i']\n",
    "            best_cut_theta1 = info_u['best_theta1']\n",
    "            best_cut_theta2 = info_u['best_theta2']\n",
    "\n",
    "    if best_u is None:\n",
    "        # No candidate leaf: fail with a clear message rather than a NameError below.\n",
    "        raise ValueError('expand_new_cut: no node can be cut further (every node has '\n",
    "                         'at most 3 points or no candidate thresholds)')\n",
    "\n",
    "    print('--- NEW CUT --- ... at node', best_u,\n",
    "          '... at Coordinate', best_cut_i,\n",
    "          '... at Thresholds', best_cut_theta1, best_cut_theta2,\n",
    "          '... for delta =', best_delta)\n",
    "\n",
    "    # Split the chosen leaf; its coordinate column is extracted once. Both\n",
    "    # conditions are written out, so NaN coordinates go to neither child.\n",
    "    index_best = index_nodes[best_u]\n",
    "    x_best = X[np.ix_(index_best,[best_cut_i])]\n",
    "    go_left_u = np.where((x_best>=best_cut_theta1) & (x_best<=best_cut_theta2))[0]\n",
    "    go_right_u = np.where((x_best<best_cut_theta1) | (x_best>best_cut_theta2))[0]\n",
    "\n",
    "    index_update = index_nodes.copy()\n",
    "    index_update[best_u] = index_best[go_left_u]\n",
    "    index_update.append(index_best[go_right_u])\n",
    "\n",
    "    info_update = info_nodes.copy()\n",
    "    info_update[best_u] = 'Empty'\n",
    "    info_update.append('Empty')\n",
    "\n",
    "    return(index_update, info_update)\n",
    "\n",
    "#points = np.where((y_true==1) | (y_true==2))[0]\n",
    "#expand_new_cut(X, y_true, [points], ['Empty'], silent = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4f0f207",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_build_on_imm(X, y, y_imm, max_leaves, silent = True):\n",
    "    \"\"\"Greedily refine an IMM clustering until it has `max_leaves` leaves.\n",
    "\n",
    "    Every distinct label of `y_imm` seeds one leaf; `expand_new_cut` then adds\n",
    "    one interval cut per round until there are at least `max_leaves` leaves.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : ndarray\n",
    "        Data, one row per point.\n",
    "    y : array-like\n",
    "        Reference cluster labels used to score the cuts.\n",
    "    y_imm : ndarray\n",
    "        Initial (IMM) cluster assignment of every point.\n",
    "    max_leaves : int\n",
    "        Target number of leaves.\n",
    "    silent : bool\n",
    "        Not used by this function.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray of float, length n_points\n",
    "        Leaf index of each point.\n",
    "    \"\"\"\n",
    "    n_points = np.shape(X)[0]\n",
    "\n",
    "    # The IMM clusters form the starting partition; no split info cached yet.\n",
    "    index_nodes = [np.where(y_imm == label)[0] for label in np.unique(y_imm)]\n",
    "    info_nodes = ['Empty'] * len(index_nodes)\n",
    "\n",
    "    # Keep cutting until the requested number of leaves is reached.\n",
    "    while len(index_nodes) < max_leaves:\n",
    "        index_nodes, info_nodes = expand_new_cut(X, y, index_nodes, info_nodes)\n",
    "    print('Converged')\n",
    "\n",
    "    y_greedy = np.zeros(n_points)\n",
    "    for leaf, members in enumerate(index_nodes):\n",
    "        y_greedy[members] = leaf\n",
    "\n",
    "    return(y_greedy)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
