{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "900c14f2-5241-4edd-9279-c412740709ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np  # imported explicitly because the cells below call np directly\n",
    "from branches import *\n",
    "from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, LabelEncoder\n",
    "from time import time\n",
    "import svgling.utils"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91a6a89f-2d67-4a36-8ab8-3f6882b29149",
   "metadata": {},
   "source": [
    "Branches is an Optimal Decision Tree classifier that handles datasets with categorical features. In special cases, such as binary classification with binary features, Branches employs specialised micro-optimisation techniques. This tutorial illustrates how to use Branches both in the general case and in the special cases where these techniques yield significant computational gains."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0897102b-2515-43f9-9ac5-d3d0ccf56253",
   "metadata": {},
   "source": [
    "# MONK 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77f45ee5-519a-4072-af88-41b06d1ff395",
   "metadata": {},
   "source": [
    "## General case: Classification with categorical features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "833de397-63e2-43ef-86bc-2019a2e09ec2",
   "metadata": {},
   "source": [
    "In this first example, we consider the original Monk 1 dataset from the UCI repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c67a1d26-21eb-4d18-b1f2-c0982d9403e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(124, 7)\n"
     ]
    }
   ],
   "source": [
    "data = np.genfromtxt('../data/monks-1.train', delimiter=' ', dtype=int)\n",
    "data = data[:, :-1]  # Drop the last column, which contains only ids.\n",
    "data = data[:, ::-1]  # Reverse the column order so that the class label is the last column.\n",
    "\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2cae477-e21b-4beb-a900-cd1399057801",
   "metadata": {},
   "source": [
    "We use an oridnal encoding of the dataset since it is the closest to the original dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d292e993-8768-4609-ad8f-72faa43ca198",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = OrdinalEncoder()\n",
    "encoder.fit(data)\n",
    "data = encoder.transform(data).astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "350e4cbd-56f7-44cc-b2a7-e0804d9cef3a",
   "metadata": {},
   "source": [
    "In this general case, we use Branches without specifying the encoding parameter, it is set to 'ordinal' by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0b38b624-0e14-4007-abd5-9a4aaba5e68e",
   "metadata": {},
   "outputs": [],
   "source": [
    "lambd = 0.01\n",
    "alg = Branches(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53ecd194-6da3-42aa-bcea-e70550544a3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The search finished after 64 iterations.\n",
      "Execution time : 0.0182\n"
     ]
    }
   ],
   "source": [
    "start_time = time()\n",
    "alg.solve(lambd)\n",
    "print('Execution time : %.4f' % (time() - start_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51175614-7ce3-4472-92a4-5b3909d7ba74",
   "metadata": {},
   "source": [
    "The next cell allows the visualisation of the Optimal Decision Tree solution. Please note that while using JupyterLab for instance, you should enable a Light theme instead of a Dark one to visualise the solution in a better contrast. This can be set via Settings -> Theme -> JupyterLab Light. Note that the default renderer in Jupyter is svging, thus if the code below does not render the Decision Tree, please try running ```svgling.draw_tree(tree)``` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "85a8924e-c647-41a3-88bc-28e6805be32f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg baseProfile=\"full\" height=\"200px\" preserveAspectRatio=\"xMidYMid meet\" style=\"font-family: times, serif; font-weight: normal; font-style: normal; font-size: 16px;\" version=\"1.1\" viewBox=\"0,0,1512.0,200.0\" width=\"1512px\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /></svg><svg width=\"33.3333%\" x=\"0%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_4=0</text></svg><svg width=\"11.1111%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"5.55556%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"11.1111%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs 
/><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"33.3333%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"55.5556%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=2</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" 
x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"77.7778%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"16.6667%\" y1=\"0px\" y2=\"32px\" /><svg width=\"33.3333%\" x=\"33.3333%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_4=1</text></svg><svg width=\"44.4444%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text 
text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"22.2222%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"11.1111%\" x=\"44.4444%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" 
y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"55.5556%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=2</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" 
y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"77.7778%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"0px\" y2=\"32px\" /><svg width=\"33.3333%\" x=\"66.6667%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_4=2</text></svg><svg width=\"44.4444%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" 
y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"22.2222%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"44.4444%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" 
y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"66.6667%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"11.1111%\" x=\"88.8889%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=2</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"94.4444%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"83.3333%\" y1=\"0px\" y2=\"32px\" /></svg>"
      ],
      "text/plain": [
       "Tree('', [Tree('X_4=0', [Tree('X_5=0', ['Y=1']), Tree('X_5=1', [Tree('X_1=0', ['Y=1']), Tree('X_1=1', ['Y=0']), Tree('X_1=2', ['Y=0']), Tree('X_1=3', ['Y=0'])]), Tree('X_5=2', [Tree('X_1=0', ['Y=1']), Tree('X_1=1', ['Y=0']), Tree('X_1=2', ['Y=0']), Tree('X_1=3', ['Y=0'])])]), Tree('X_4=1', [Tree('X_5=0', [Tree('X_1=0', ['Y=1']), Tree('X_1=1', ['Y=0']), Tree('X_1=2', ['Y=0']), Tree('X_1=3', ['Y=0'])]), Tree('X_5=1', ['Y=1']), Tree('X_5=2', [Tree('X_1=0', ['Y=1']), Tree('X_1=1', ['Y=0']), Tree('X_1=2', ['Y=0']), Tree('X_1=3', ['Y=0'])])]), Tree('X_4=2', [Tree('X_5=0', [Tree('X_1=0', ['Y=1']), Tree('X_1=1', ['Y=0']), Tree('X_1=2', ['Y=0']), Tree('X_1=3', ['Y=0'])]), Tree('X_5=1', [Tree('X_1=0', ['Y=1']), Tree('X_1=1', ['Y=0']), Tree('X_1=2', ['Y=0']), Tree('X_1=3', ['Y=0'])]), Tree('X_5=2', ['Y=1'])])])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = alg.plot_tree()\n",
    "tree"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0405ad8-43b3-4404-bfcf-fa525577e89f",
   "metadata": {},
   "source": [
    "To save this tree figure, we use the svgling package as shown below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53a642b6-2a58-46cd-ba8e-ea4b48031a52",
   "metadata": {},
   "source": [
    "The figure above includes the predicted classes at the level of each branch, if we are only interested in the Tree structure, we can set show_classes=False in the plot_tree method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "672ced8d-432b-4b45-8407-7aee0068b696",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg baseProfile=\"full\" height=\"152px\" preserveAspectRatio=\"xMidYMid meet\" style=\"font-family: times, serif; font-weight: normal; font-style: normal; font-size: 16px;\" version=\"1.1\" viewBox=\"0,0,1512.0,152.0\" width=\"1512px\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /></svg><svg width=\"33.3333%\" x=\"0%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_4=0</text></svg><svg width=\"11.1111%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"5.55556%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"11.1111%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" 
/></svg><line stroke=\"black\" x1=\"50%\" x2=\"33.3333%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"55.5556%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=2</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"77.7778%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"16.6667%\" y1=\"0px\" y2=\"32px\" /><svg width=\"33.3333%\" x=\"33.3333%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_4=1</text></svg><svg width=\"44.4444%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" 
x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"22.2222%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"11.1111%\" x=\"44.4444%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"55.5556%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=2</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg 
width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"77.7778%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"0px\" y2=\"32px\" /><svg width=\"33.3333%\" x=\"66.6667%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_4=2</text></svg><svg width=\"44.4444%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"22.2222%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"44.4444%\" x=\"44.4444%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"25%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" 
y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"12.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"25%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"37.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"62.5%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"25%\" x=\"75%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"87.5%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"66.6667%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"11.1111%\" x=\"88.8889%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"94.4444%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"83.3333%\" y1=\"0px\" y2=\"32px\" /></svg>"
      ],
      "text/plain": [
       "Tree('', [Tree('X_4=0', [Tree('X_5=0', []), Tree('X_5=1', [Tree('X_1=0', []), Tree('X_1=1', []), Tree('X_1=2', []), Tree('X_1=3', [])]), Tree('X_5=2', [Tree('X_1=0', []), Tree('X_1=1', []), Tree('X_1=2', []), Tree('X_1=3', [])])]), Tree('X_4=1', [Tree('X_5=0', [Tree('X_1=0', []), Tree('X_1=1', []), Tree('X_1=2', []), Tree('X_1=3', [])]), Tree('X_5=1', []), Tree('X_5=2', [Tree('X_1=0', []), Tree('X_1=1', []), Tree('X_1=2', []), Tree('X_1=3', [])])]), Tree('X_4=2', [Tree('X_5=0', [Tree('X_1=0', []), Tree('X_1=1', []), Tree('X_1=2', []), Tree('X_1=3', [])]), Tree('X_5=1', [Tree('X_1=0', []), Tree('X_1=1', []), Tree('X_1=2', []), Tree('X_1=3', [])]), Tree('X_5=2', [])])])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = alg.plot_tree(show_classes=False)\n",
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9b00681a-7e8b-45e4-9b88-9b9873f2edb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of branches : 27\n",
      "Number of splits : 10\n",
      "Accuracy : 1.0\n"
     ]
    }
   ],
   "source": [
    "branches, splits = alg.lattice.infer()\n",
    "print('Number of branches :', len(branches))\n",
    "print('Number of splits :', splits)\n",
    "print('Accuracy :', ((alg.predict(data[:, :-1]) == data[:, -1]).sum())/alg.n_total)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62d83690-3103-4943-9dd6-8d2d2f6dd633",
   "metadata": {},
   "source": [
    "The following cell clears the memo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1b740922-b2bd-45cf-8b88-82d5591092a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "alg.reinitialise()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bec22450-2881-4437-b478-67e4a64f1f5c",
   "metadata": {},
   "source": [
    "## Special case: Binary Classification with binary featrues."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e7e6d1e-6b06-46b7-a9da-00de77a6ce99",
   "metadata": {},
   "source": [
    "In this second example, we consider a binary encoding of Monk 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "63e0ff76-4eb7-4ecc-8e3e-b2a6face5d2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(124, 7)\n"
     ]
    }
   ],
   "source": [
    "data = np.genfromtxt('../data/monks-1.train', delimiter=' ', dtype=int)\n",
    "data = data[:, :-1] # Getting rid of the last column, it contains only ids.\n",
    "data = data[:, ::-1]\n",
    "\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9281128-7de2-460a-ad88-467329a23730",
   "metadata": {},
   "source": [
    "We use scikit-learn's OneHotEncoder, and we drop the last category of each feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5178bc3f-950a-40fb-84e6-f59750342f96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(124, 12)\n"
     ]
    }
   ],
   "source": [
    "encoder = OneHotEncoder(drop=np.array([3, 3, 2, 3, 4, 2][::-1]+[0]), sparse_output=False)\n",
    "encoder.fit(data)\n",
    "data = encoder.transform(data).astype(int)\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1856a3c8-8051-4bbc-a1f3-a75bba259d25",
   "metadata": {},
   "outputs": [],
   "source": [
    "lambd = 0.01\n",
    "alg = Branches(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7e9d352d-e210-4d0f-9212-3c791454398d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 100\n",
      "Iteration 200\n",
      "Iteration 300\n",
      "Iteration 400\n",
      "Iteration 500\n",
      "Iteration 600\n",
      "The search finished after 617 iterations.\n",
      "Execution time : 0.2948\n"
     ]
    }
   ],
   "source": [
    "start_time = time()\n",
    "alg.solve(lambd)\n",
    "print('Execution time : %.4f' %(time() -start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7441c913-5834-4a25-beae-8a76a73a3c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "alg.reinitialise()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1774177c-70ae-4b09-8e16-e71dee59cad2",
   "metadata": {},
   "source": [
    "In this first attempt, Branches converges in around $0.29s$. However, since we are dealing with a binary classification problem with binary features, we can benefit from a computational speedup due to specialised micro-optimisation techniques for this special case. We do this by specifying parameter ```encoding='binary'``` . We notice in the following cell that, indeed, the execution time of Branches drops from $0.29$ to around $0.07s$. This speedup can be even more significant in other cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ba68194a-e5f0-41bb-a283-efa1901b048c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 100\n",
      "Iteration 200\n",
      "Iteration 300\n",
      "Iteration 400\n",
      "Iteration 500\n",
      "Iteration 600\n",
      "The search finished after 617 iterations.\n",
      "Execution time : 0.0758\n"
     ]
    }
   ],
   "source": [
    "alg = Branches(data, encoding='binary')\n",
    "start_time = time()\n",
    "alg.solve(lambd)\n",
    "print('Execution time : %.4f' %(time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2256020f-b5c8-4df2-bef1-aa41d27d1439",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg baseProfile=\"full\" height=\"296px\" preserveAspectRatio=\"xMidYMid meet\" style=\"font-family: times, serif; font-weight: normal; font-style: normal; font-size: 16px;\" version=\"1.1\" viewBox=\"0,0,448.0,296.0\" width=\"448px\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /></svg><svg width=\"87.5%\" x=\"0%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"71.4286%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_7=0</text></svg><svg width=\"60%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_10=0</text></svg><svg width=\"66.6667%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_8=0</text></svg><svg width=\"50%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_9=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"25%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"50%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_9=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" 
x1=\"50%\" x2=\"75%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"33.3333%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"33.3333%\" x=\"66.6667%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_8=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"83.3333%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"30%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"40%\" x=\"60%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_10=1</text></svg><svg width=\"50%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_8=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"25%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"50%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_8=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"75%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"80%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"35.7143%\" y1=\"19.2px\" y2=\"48px\" /><svg 
width=\"28.5714%\" x=\"71.4286%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_7=1</text></svg><svg width=\"50%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_9=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"25%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"50%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_9=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"75%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"85.7143%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"43.75%\" y1=\"0px\" y2=\"32px\" /><svg width=\"12.5%\" x=\"87.5%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"93.75%\" y1=\"0px\" y2=\"32px\" /></svg>"
      ],
      "text/plain": [
       "Tree('', [Tree('X_1=0', [Tree('X_7=0', [Tree('X_10=0', [Tree('X_8=0', [Tree('X_9=0', ['Y=1']), Tree('X_9=1', ['Y=0'])]), Tree('X_8=1', ['Y=0'])]), Tree('X_10=1', [Tree('X_8=0', ['Y=0']), Tree('X_8=1', ['Y=1'])])]), Tree('X_7=1', [Tree('X_9=0', ['Y=0']), Tree('X_9=1', ['Y=1'])])]), Tree('X_1=1', ['Y=1'])])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = alg.plot_tree()\n",
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "63ea6ef4-a12a-4269-9343-edf8312f6e6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of branches : 8\n",
      "Number of splits : 7\n",
      "Accuracy : 1.0\n"
     ]
    }
   ],
   "source": [
    "branches, splits = alg.lattice.infer()\n",
    "print('Number of branches :', len(branches))\n",
    "print('Number of splits :', splits)\n",
    "print('Accuracy :', ((alg.predict(data[:, :-1]) == data[:, -1]).sum())/alg.n_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2c5f79f3-1990-4e84-a16d-1b22bc68f3f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "alg.reinitialise()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f26d82a8-17b6-40d1-b381-1e02f267fa55",
   "metadata": {},
   "source": [
    "# Car Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e96d808-e095-4668-9225-cd16565ed5af",
   "metadata": {},
   "source": [
    "## Special case: Classification with Binary features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af5c11ba-72f0-4cf6-93d9-0fe5a82cdaec",
   "metadata": {},
   "source": [
    "In this last example, we consider a multiclass classification problem with binary features. For this, we use the Car Evaluation dataset from the UCI repository that we encode with scikit-learn's OneHoteEncoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "85ae6926-62c7-4951-a48f-cf6fd096db61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1728, 7)\n"
     ]
    }
   ],
   "source": [
    "data = np.genfromtxt('../data/car.data', delimiter=',', dtype=str)\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "df605dae-563c-426b-8059-1a04a85011e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of classes : 4\n"
     ]
    }
   ],
   "source": [
    "X, y = data[:, :-1], data[:, -1]\n",
    "print('Number of classes :', len(set(y)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3d69af3c-1209-4f27-800c-ba564f04af2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1728, 16)\n"
     ]
    }
   ],
   "source": [
    "encoder_X, encoder_y = OneHotEncoder(drop='first', sparse_output=False), LabelEncoder()\n",
    "encoder_X.fit(X)\n",
    "encoder_y.fit(y)\n",
    "X = encoder_X.transform(X).astype(int)\n",
    "y = encoder_y.transform(y)\n",
    "\n",
    "data = np.hstack((X, y.reshape(-1, 1)))\n",
    "print(data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4659d90-567f-4fd6-9f03-7c07151f1ab3",
   "metadata": {},
   "source": [
    "In similar fashion to our previous example, we first consider the general case where we do not employ specialised techniques for handling binary features. Here Branches takes around $68s$ to converge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5af27635-b6cb-4355-805b-af70f8093aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "lambd = 0.005\n",
    "time_limit = 300\n",
    "alg = Branches(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9b738312-80f9-4a1f-997c-c23a284b9e4f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 100000\n",
      "Iteration 200000\n",
      "Iteration 300000\n",
      "The search finished after 302811 iterations.\n",
      "Execution time : 68.1480\n"
     ]
    }
   ],
   "source": [
    "start_time = time()\n",
    "alg.solve(lambd, n=5000000, print_iter=100000, time_limit=time_limit)\n",
    "print('Execution time : %.4f' %(time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "16ba104c-f224-4e3a-bdc8-1d2eed19ec43",
   "metadata": {},
   "outputs": [],
   "source": [
    "alg.reinitialise()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb193d31-edbe-418e-b7f8-75a62bd627f5",
   "metadata": {},
   "source": [
    "When dealing with a multiclassification problem with binary features, set paramter ```encoding='multi'``` to benefit from micro-optimisation. We observe that the execution time drops from $68s$ to around $51s$ ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d368f044-1f98-4fcc-9141-fe4a5d38faa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration 100000\n",
      "Iteration 200000\n",
      "Iteration 300000\n",
      "The search finished after 302807 iterations.\n",
      "Execution time : 51.1470\n"
     ]
    }
   ],
   "source": [
    "alg = Branches(data, encoding='multi')\n",
    "start_time = time()\n",
    "alg.solve(lambd, n=5000000, print_iter=100000, time_limit=time_limit)\n",
    "print('Execution time : %.4f' %(time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "f906417e-72d7-43a6-bb01-ff87dc05ab00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg baseProfile=\"full\" height=\"440px\" preserveAspectRatio=\"xMidYMid meet\" style=\"font-family: times, serif; font-weight: normal; font-style: normal; font-size: 16px;\" version=\"1.1\" viewBox=\"0,0,904.0,440.0\" width=\"904px\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /></svg><svg width=\"92.9204%\" x=\"0%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_13=0</text></svg><svg width=\"57.1429%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_9=0</text></svg><svg width=\"13.3333%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_10=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"6.66667%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"86.6667%\" x=\"13.3333%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_10=1</text></svg><svg width=\"86.5385%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"82.2222%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_12=0</text></svg><svg width=\"62.1622%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"30.4348%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" 
x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_0=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"15.2174%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"69.5652%\" x=\"30.4348%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_0=1</text></svg><svg width=\"50%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_14=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=3</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"25%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"50%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_14=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=1</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"75%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"65.2174%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"31.0811%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"37.8378%\" x=\"62.1622%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"50%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text 
text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_0=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"25%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"50%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_0=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"75%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"81.0811%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"41.1111%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"17.7778%\" x=\"82.2222%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_12=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"91.1111%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"43.2692%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"13.4615%\" x=\"86.5385%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" 
y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"93.2692%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"56.6667%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"28.5714%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"42.8571%\" x=\"57.1429%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_9=1</text></svg><svg width=\"84.4444%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_0=0</text></svg><svg width=\"81.5789%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=0</text></svg><svg width=\"77.4194%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=0</text></svg><svg width=\"33.3333%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_12=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"16.6667%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"66.6667%\" x=\"33.3333%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_12=1</text></svg><svg width=\"50%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_14=0</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" 
y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"25%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"50%\" x=\"50%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_14=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"75%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"66.6667%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"38.7097%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"22.5806%\" x=\"77.4194%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_5=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"88.7097%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"40.7895%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"18.4211%\" x=\"81.5789%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_1=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"90.7895%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"42.2222%\" y1=\"19.2px\" y2=\"48px\" /><svg width=\"15.5556%\" x=\"84.4444%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" 
y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_0=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=0</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"92.2222%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"78.5714%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"46.4602%\" y1=\"0px\" y2=\"32px\" /><svg width=\"7.07965%\" x=\"92.9204%\" y=\"32px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">X_13=1</text></svg><svg width=\"100%\" x=\"0%\" y=\"48px\"><defs /><svg width=\"100%\" x=\"0\" y=\"0px\"><defs /><text text-anchor=\"middle\" x=\"50%\" y=\"16px\">Y=2</text></svg></svg><line stroke=\"black\" x1=\"50%\" x2=\"50%\" y1=\"19.2px\" y2=\"48px\" /></svg><line stroke=\"black\" x1=\"50%\" x2=\"96.4602%\" y1=\"0px\" y2=\"32px\" /></svg>"
      ],
      "text/plain": [
       "Tree('', [Tree('X_13=0', [Tree('X_9=0', [Tree('X_10=0', ['Y=2']), Tree('X_10=1', [Tree('X_1=0', [Tree('X_12=0', [Tree('X_5=0', [Tree('X_0=0', ['Y=0']), Tree('X_0=1', [Tree('X_14=0', ['Y=3']), Tree('X_14=1', ['Y=1'])])]), Tree('X_5=1', [Tree('X_0=0', ['Y=2']), Tree('X_0=1', ['Y=0'])])]), Tree('X_12=1', ['Y=2'])]), Tree('X_1=1', ['Y=0'])])]), Tree('X_9=1', [Tree('X_0=0', [Tree('X_1=0', [Tree('X_5=0', [Tree('X_12=0', ['Y=0']), Tree('X_12=1', [Tree('X_14=0', ['Y=0']), Tree('X_14=1', ['Y=2'])])]), Tree('X_5=1', ['Y=2'])]), Tree('X_1=1', ['Y=0'])]), Tree('X_0=1', ['Y=0'])])]), Tree('X_13=1', ['Y=2'])])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = alg.plot_tree()\n",
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d554ec28-c3ab-4998-8820-51dac77044cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of branches : 15\n",
      "Number of splits : 14\n",
      "Accuracy : 0.8692129629629629\n"
     ]
    }
   ],
   "source": [
    "branches, splits = alg.lattice.infer()\n",
    "print('Number of branches :', len(branches))\n",
    "print('Number of splits :', splits)\n",
    "print('Accuracy :', ((alg.predict(data[:, :-1]) == data[:, -1]).sum())/alg.n_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "17fcd245-6c12-4249-be16-c2bc653f1c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "alg.reinitialise()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
