{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "myfindinterval (generic function with 1 method)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "using LinearAlgebra\n",
    "using Printf\n",
    "include(\"QuantBnB-2D.jl\")\n",
    "include(\"QuantBnB-3D.jl\")\n",
    "include(\"gen_data.jl\")\n",
    "include(\"lowerbound_middle.jl\")\n",
    "include(\"Algorithms.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16-element Vector{String}:\n",
       " \"avila\"\n",
       " \"bank\"\n",
       " \"bean\"\n",
       " \"bidding\"\n",
       " \"eeg\"\n",
       " \"fault\"\n",
       " \"HTRU\"\n",
       " \"magic\"\n",
       " \"occupancy\"\n",
       " \"page\"\n",
       " \"raisin\"\n",
       " \"rice\"\n",
       " \"room\"\n",
       " \"segment\"\n",
       " \"skin\"\n",
       " \"wilt\""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Names of the benchmark datasets, grouped by task type.\n",
    "regress_data = [\n",
    "    \"carbon\", \"casp\", \"concrete\", \"energy\", \"fish\", \"gas\",\n",
    "    \"grid\", \"news\", \"qsar\", \"query1\", \"query2\",\n",
    "]\n",
    "\n",
    "# Classification names; later cells load them from ./dataset/class/<name>.json.\n",
    "# NOTE(review): \"HTRU\" is upper-case here but a later cell loads \"htru\" -- confirm the on-disk file name.\n",
    "class_data = [\n",
    "    \"avila\", \"bank\", \"bean\", \"bidding\", \"eeg\", \"fault\", \"HTRU\", \"magic\",\n",
    "    \"occupancy\", \"page\", \"raisin\", \"rice\", \"room\", \"segment\", \"skin\", \"wilt\",\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 19688743840\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 42504960\n",
      "Total number of intervals = 1160\n",
      "Current objective = 4458.0\n",
      "time = 7.3914618492126465\n",
      "--------------------------------------\n",
      "Obj = 4458.0\n",
      "Tree is Any"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 0.2164144973582967, Any[9, 0.11948085917578102, Any[1, 0.1310165830815444, "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.13838020022463626, [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[3, 0.06735744379007103, Any[5, 0.9103656541984284, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[5, 0.8805209279172044, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 7.391462802886963\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset: avila,  CART train/test acc: 0.532 / 0.531,  Quant-BnB train/test acc: 0.573 / 0.571"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the avila dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"avila\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    5.645,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"avila\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 266603584\n",
      "Total number of intervals = 256\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 163305884\n",
      "Total number of intervals = 1406\n",
      "Current objective = 32.0\n",
      "time = 0.04653596878051758\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 10608418\n",
      "Total number of intervals = 794\n",
      "Current objective = 32.0\n",
      "time = 0.1176910400390625\n",
      "--------------------------------------\n",
      "Obj = 32.0\n",
      "Tree is Any[3, 0.19969326094729362, Any[1, 0.657224316177372, Any[2, 0.7300744262798087, [0.0 1.0], [1.0 0.0]], Any[3, 0.014472380726732303, [0.0 1.0], [1.0 0.0]]], Any[2, 0.6725941660710877, Any[1, 0.5279098125752693, [0.0 1.0], [1.0 0.0]], Any[1, 0.3058173528330052, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 0.16422700881958008\n",
      "Dataset: bank,  CART train/test acc: 0.933 / 0.927,  Quant-BnB train/test acc: 0.971 / 0.978"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the bank dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"bank\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    0.158,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"bank\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 19688743840\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 16983563641\n",
      "Total number of intervals = 33837\n",
      "Current objective = 4458.0\n",
      "time = 4.957025051116943\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 14852901000\n",
      "Total number of intervals = 242802\n",
      "Current objective = 4458.0\n",
      "time = 65.5099310874939\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 612231217\n",
      "Total number of intervals = 687844\n",
      "Current objective = 4409.0\n",
      "time = 229.96273398399353\n",
      "--------------------------------------\n",
      "Obj = 4409.0\n",
      "Tree is Any[1, 0.22527652290025632, Any[9, 0.1145460276434946, Any[5, 0.8879821094875104, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.1371202428078056, [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[3, 0.06508545324714689, Any[5, 0.9103656541984284, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[4, 0.5644186238036573, [0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 300.42969131469727\n",
      "Dataset: avila,  CART train/test acc: 0.532 / 0.531,  Quant-BnB train/test acc: 0.577 / 0.571"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the avila dataset: greedy tree (printed as CART) vs. Quant-BnB,\n",
    "# with the final QuantBnB_3D argument set to 300.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"avila\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "# NOTE(review): the last argument (300) is presumably a time limit in seconds -- confirm in QuantBnB-3D.jl.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 300)\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)).>0)\n",
    "\n",
    "# The printed label must name the dataset loaded above (avila).\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", \"avila\", \n",
    "            1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 266603584\n",
      "Total number of intervals = 256\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 163305884\n",
      "Total number of intervals = 1406\n",
      "Current objective = 32.0\n",
      "time = 0.1144568920135498\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 2248029\n",
      "Total number of intervals = 168\n",
      "Current objective = 32.0\n",
      "time = 1.4677660465240479\n",
      "--------------------------------------\n",
      "Obj = 32.0\n",
      "Tree is Any[3, 0.19969326094729362, Any[1, 0.657224316177372, Any[2, 0.7300744262798087, [0.0 1.0], [1.0 0.0]], Any[3, 0.014472380726732303, [0.0 1.0], [1.0 0.0]]], Any[2, 0.6725941660710877, Any[1, 0.5279098125752693, [0.0 1.0], [1.0 0.0]], Any[1, 0.3058173528330052, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 1.5822229385375977\n",
      "Dataset: bank,  CART train/test acc: 0.933 / 0.927,  Quant-BnB train/test acc: 0.971 / 0.978"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the bank dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"bank\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    0.158,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"bank\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1860430528576\n",
      "Total number of intervals = 16384\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 1400911750897\n",
      "Total number of intervals = 111045\n",
      "Current objective = 1604.0\n",
      "time = 15.901433944702148\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 0\n",
      "Total number of intervals = 0\n",
      "Current objective = 1604.0\n",
      "time = 0.6354100704193115\n",
      "--------------------------------------\n",
      "Obj = 1604.0\n",
      "Tree is Any[2, 0.26303354440605925, Any[13, 0.5534573765897823, Any[12, 0.5960115116580839, [0.0 0.0 0.0 0.0 0.0 1.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[14, 0.2881152965985623, [0.0 0.0 0.0 0.0 1.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]]], Any[12, 0.39754863446556105, Any[13, 0.43133676314757663, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0]], Any[1, 0.3659638909289655, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 16.536844968795776\n",
      "Dataset: bean,  CART train/test acc: 0.777 / 0.776,  Quant-BnB train/test acc: 0.853 / 0.856"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the bean dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"bean\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    16.194,                  # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"bean\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 3781512036\n",
      "Total number of intervals = 2916\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 1128574067\n",
      "Total number of intervals = 8265\n",
      "Current objective = 71.0\n",
      "time = 0.6934599876403809\n",
      "--------------------------------------\n",
      "Obj = 71.0\n",
      "Tree is Any[4, 0.7754442691068941, Any[9, 0.5555444444444444, Any[3, 0.74995, [1.0 0.0], [0.0 1.0]], Any[3, 0.25005, [1.0 0.0], [0.0 1.0]]], Any[2, 0.08813407384031359, Any[1, 0.0016102718851999998, [1.0 0.0], [1.0 0.0]], Any[3, 0.25005, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 0.6934609413146973\n",
      "Dataset: bidding,  CART train/test acc: 0.981 / 0.986,  Quant-BnB train/test acc: 0.986 / 0.986"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the bidding dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"bidding\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    0.545,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"bidding\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1545264504\n",
      "Total number of intervals = 10976\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 1332824808\n",
      "Total number of intervals = 86061\n",
      "Current objective = 3799.0\n",
      "time = 7.299355983734131\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 0\n",
      "Total number of intervals = 0\n",
      "Current objective = 3799.0\n",
      "time = 1.7830519676208496\n",
      "--------------------------------------\n",
      "Obj = 3799.0\n",
      "Tree is Any[1, 0.5121581787115892, Any[2, 0.23414849945213462, Any[6, 0.34654278413252937, [0.0 1.0], [1.0 0.0]], Any[7, 0.4606819566099373, [0.0 1.0], [1.0 0.0]]], Any[6, 0.3480845203884224, Any[2, 0.2393034339596088, [0.0 1.0], [1.0 0.0]], Any[7, 0.4701802805274723, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 9.082408905029297\n",
      "Dataset: eeg,  CART train/test acc: 0.666 / 0.666,  Quant-BnB train/test acc: 0.683 / 0.698"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the eeg dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"eeg\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    8.927,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"eeg\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 28884958128\n",
      "Total number of intervals = 78732\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 8017358949\n",
      "Total number of intervals = 94955\n",
      "Current objective = 549.0\n",
      "time = 2.601346015930176\n",
      "--------------------------------------\n",
      "Obj = 549.0\n",
      "Tree is Any[1, 0.24081431085043986, Any[12, 0.5, Any[18, 0.0467512351326624, [0.0 0.0 0.0 0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]], Any[11, 0.23285767195767193, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0]]], Any[12, 0.5, Any[5, 0.0002408141340163639, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[17, 0.5860975294948069, [0.0 0.0 0.0 0.0 0.0 0.0 1.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0]]]]\n",
      "total time = 2.601346969604492\n",
      "Dataset: fault,  CART train/test acc: 0.553 / 0.548,  Quant-BnB train/test acc: 0.646 / 0.632"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the fault dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"fault\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    2.46,                    # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"fault\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 329152524800\n",
      "Total number of intervals = 2048\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 277160952330\n",
      "Total number of intervals = 15561\n",
      "Current objective = 286.0\n",
      "time = 2.2196130752563477\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 90744084693\n",
      "Total number of intervals = 47258\n",
      "Current objective = 284.0\n",
      "time = 9.344864130020142\n",
      "--------------------------------------\n",
      "Obj = 284.0\n",
      "Tree is Any[1, 0.44940579649533685, Any[1, 0.3388093764376229, Any[3, 0.3578980679765418, [1.0 0.0], [0.0 1.0]], Any[3, 0.3165147641018484, [1.0 0.0], [0.0 1.0]]], Any[5, 0.016406746164436145, Any[3, 0.2999132835165079, [1.0 0.0], [0.0 1.0]], Any[3, 0.27999403468969986, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 11.56447720527649\n",
      "Dataset: htru,  CART train/test acc: 0.979 / 0.980,  Quant-BnB train/test acc: 0.980 / 0.981"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the htru dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "# NOTE(review): class_data spells this dataset \"HTRU\"; the lowercase name used here only resolves\n",
    "# on a case-sensitive filesystem if the file is really htru.json -- confirm.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"htru\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    11.316,                  # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"htru\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 580279921000\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 547419390492\n",
      "Total number of intervals = 33772\n",
      "Current objective = 2721.0\n",
      "time = 3.78177809715271\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 85261881174\n",
      "Total number of intervals = 38614\n",
      "Current objective = 2640.0\n",
      "time = 11.369930028915405\n",
      "--------------------------------------\n",
      "Obj = 2640.0\n",
      "Tree is Any[2, 0.0451341301261399, Any[9, 0.25790231811111114, Any[3, 0.19519698698994678, [1.0 0.0], [0.0 1.0]], Any[3, 0.12857814311058546, [1.0 0.0], [0.0 1.0]]], Any[9, 0.25790231811111114, Any[1, 0.33922344620309275, [1.0 0.0], [0.0 1.0]], Any[1, 0.08910819864592663, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 15.151708126068115\n",
      "Dataset: magic,  CART train/test acc: 0.801 / 0.792,  Quant-BnB train/test acc: 0.826 / 0.822"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the magic dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"magic\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    14.838,                  # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"magic\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1392446720\n",
      "Total number of intervals = 500\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 771573916\n",
      "Total number of intervals = 2810\n",
      "Current objective = 56.0\n",
      "time = 0.26821184158325195\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 105936085\n",
      "Total number of intervals = 6639\n",
      "Current objective = 56.0\n",
      "time = 1.2394092082977295\n",
      "--------------------------------------\n",
      "Obj = 56.0\n",
      "Tree is Any[3, 0.26297273530711446, Any[3, 0.1006837973191928, Any[1, 0.004721571648690387, [0.0 1.0], [0.0 1.0]], Any[4, 0.03317627848735276, [0.0 1.0], [1.0 0.0]]], Any[4, 0.36531619834710743, Any[1, 0.5965138212634861, [1.0 0.0], [0.0 1.0]], Any[1, 0.6763899845916868, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 1.5076210498809814\n",
      "Dataset: occupancy,  CART train/test acc: 0.989 / 0.977,  Quant-BnB train/test acc: 0.993 / 0.896"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the occupancy dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"occupancy\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    1.458,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"occupancy\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 2679769000\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 2029740626\n",
      "Total number of intervals = 27149\n",
      "Current objective = 152.0\n",
      "time = 1.4649300575256348\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 12369703\n",
      "Total number of intervals = 12906\n",
      "Current objective = 152.0\n",
      "time = 1.5338411331176758\n",
      "--------------------------------------\n",
      "Obj = 152.0\n",
      "Tree is Any[4, 0.01238541210034395, Any[5, 0.2600690928270042, Any[1, 0.034339726027397266, [1.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0]], Any[4, 0.0005263634721495438, [0.0 0.0 1.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0]]], Any[7, 0.0006297688736374647, Any[3, 0.00027706513133221285, [0.0 1.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0]], Any[10, 0.014267206477732794, [0.0 1.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 2.9987711906433105\n",
      "Dataset: page,  CART train/test acc: 0.964 / 0.958,  Quant-BnB train/test acc: 0.965 / 0.967"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the page dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"page\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    2.859,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"page\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 710962588\n",
      "Total number of intervals = 1372\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 626304028\n",
      "Total number of intervals = 10878\n",
      "Current objective = 87.0\n",
      "time = 0.16036391258239746\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 111193920\n",
      "Total number of intervals = 17378\n",
      "Current objective = 86.0\n",
      "time = 0.35332489013671875\n",
      "--------------------------------------\n",
      "Obj = 86.0\n",
      "Tree is Any[1, 0.2759748912525041, Any[3, 0.20832866328560215, Any[4, 0.8514364324941668, [1.0 0.0], [0.0 1.0]], Any[7, 0.2421449973757372, [1.0 0.0], [0.0 1.0]]], Any[3, 0.31656429018273025, Any[6, 0.8018273594230778, [0.0 1.0], [1.0 0.0]], Any[7, 0.27088683226222043, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 0.5136888027191162\n",
      "Dataset: raisin,  CART train/test acc: 0.869 / 0.883,  Quant-BnB train/test acc: 0.881 / 0.889"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification run on the raisin dataset: greedy tree (printed as CART) vs. Quant-BnB.\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/\" * \"raisin\" * \".json\")\n",
    "n_train, m = size(Y_train, 1), size(Y_train, 2)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# Greedy baseline; its training objective, inflated by a relative 1e-6, is handed to Quant-BnB.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(\n",
    "    X_train, Y_train, 3, 3,\n",
    "    gre_train * (1 + 1e-6),  # NOTE(review): presumably the initial upper bound -- confirm\n",
    "    0, 0, nothing, \"C\",\n",
    "    0.501,                   # NOTE(review): presumably a time limit in seconds -- confirm\n",
    ")\n",
    "\n",
    "# Test errors = number of positive entries in (labels - predictions).\n",
    "# NOTE(review): assumes one-hot label/prediction rows, so each miss adds exactly one -- confirm.\n",
    "gre_test = count(>(0), Y_test - tree_eval(gre_tree, X_test, 3, m))\n",
    "opt_test = count(>(0), Y_test - tree_eval(opt_tree, X_test, 3, m))\n",
    "\n",
    "# Accuracy is reported as 1 - errors / sample count.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\",\n",
    "        \"raisin\",\n",
    "        1 - gre_train / n_train, 1 - gre_test / n_test,\n",
    "        1 - opt_train / n_train, 1 - opt_test / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 11187683388\n",
      "Total number of intervals = 1372\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 9609460287\n",
      "Total number of intervals = 10587\n",
      "Current objective = 198.0\n",
      "time = 0.33312416076660156\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 2653824644\n",
      "Total number of intervals = 26715\n",
      "Current objective = 193.0\n",
      "time = 1.7608180046081543\n",
      "--------------------------------------\n",
      "Obj = 193.0\n",
      "Tree is Any[1, 0.41328085724344304, Any[6, 0.4184850123066104, Any[3, 0.4981964921204348, [0.0 1.0], [1.0 0.0]], Any[3, 0.45220154889367065, [0.0 1.0], [1.0 0.0]]], Any[4, 0.6481065480358618, Any[3, 0.5088920709046933, [0.0 1.0], [1.0 0.0]], Any[3, 0.47889078038784194, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 2.093942165374756\n",
      "Dataset: rice,  CART train/test acc: 0.933 / 0.917,  Quant-BnB train/test acc: 0.937 / 0.920"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification benchmark on one dataset: greedy (CART-style) tree vs. Quant-BnB.\n",
    "# `dataset` is the single place to edit when pointing this cell at another file under\n",
    "# ./dataset/class/; it feeds both the load path and the report line, so they cannot drift apart.\n",
    "dataset = \"rice\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# The greedy tree is fitted first; its training objective, inflated by 1e-6, is passed on to QuantBnB_3D.\n",
    "# NOTE(review): that fifth argument looks like an initial upper bound and the trailing 2.004 looks like\n",
    "# a time budget in seconds -- confirm both against QuantBnB_3D's signature.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 2.004)\n",
    "\n",
    "# Test error = number of strictly positive entries of (labels - predictions).\n",
    "# NOTE(review): this is the misclassification count only if both sides are one-hot rows -- confirm.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)).>0)\n",
    "\n",
    "# The report treats the four values above as error counts: accuracy = 1 - errors / n.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "            1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 536385600\n",
      "Total number of intervals = 16384\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 24224978\n",
      "Total number of intervals = 24528\n",
      "Current objective = 97.0\n",
      "time = 2.8179709911346436\n",
      "--------------------------------------\n",
      "Obj = 97.0\n",
      "Tree is Any[5, 0.6242175757575756, Any[1, 0.6735763888888892, Any[14, 0.43430317220527126, [0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0]], Any[1, 0.9755993055555554, [0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0]]], Any[5, 0.8635636363636363, Any[7, 0.4107321428571429, [1.0 0.0 0.0 0.0], [0.0 1.0 0.0 0.0]], Any[7, 0.6303310714285715, [0.0 1.0 0.0 0.0], [0.0 0.0 1.0 0.0]]]]\n",
      "total time = 2.8179709911346436\n",
      "Dataset: room,  CART train/test acc: 0.968 / 0.967,  Quant-BnB train/test acc: 0.988 / 0.986"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification benchmark on one dataset: greedy (CART-style) tree vs. Quant-BnB.\n",
    "# `dataset` is the single place to edit when pointing this cell at another file under\n",
    "# ./dataset/class/; it feeds both the load path and the report line, so they cannot drift apart.\n",
    "dataset = \"room\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# The greedy tree is fitted first; its training objective, inflated by 1e-6, is passed on to QuantBnB_3D.\n",
    "# NOTE(review): that fifth argument looks like an initial upper bound and the trailing 2.714 looks like\n",
    "# a time budget in seconds -- confirm both against QuantBnB_3D's signature.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 2.714)\n",
    "\n",
    "# Test error = number of strictly positive entries of (labels - predictions).\n",
    "# NOTE(review): this is the misclassification count only if both sides are one-hot rows -- confirm.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)).>0)\n",
    "\n",
    "# The report treats the four values above as error counts: accuracy = 1 - errors / n.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "            1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 12444739848\n",
      "Total number of intervals = 23328\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 405952654\n",
      "Total number of intervals = 19080\n",
      "Current objective = 386.0\n",
      "time = 0.8659160137176514\n",
      "--------------------------------------\n",
      "Obj = 386.0\n",
      "Tree is Any[2, 0.6541358333333334, Any[12, 0.17816258641963875, Any[18, 0.16948923376884895, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0]], Any[9, 0.5317519688396428, [0.0 0.0 0.0 0.0 0.0 1.0 0.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[1, 0.33007351778656124, Any[14, 0.21829165306800202, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.20145400313877626, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 0.8659169673919678\n",
      "Dataset: segment,  CART train/test acc: 0.574 / 0.556,  Quant-BnB train/test acc: 0.791 / 0.768"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification benchmark on one dataset: greedy (CART-style) tree vs. Quant-BnB.\n",
    "# `dataset` is the single place to edit when pointing this cell at another file under\n",
    "# ./dataset/class/; it feeds both the load path and the report line, so they cannot drift apart.\n",
    "dataset = \"segment\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# The greedy tree is fitted first; its training objective, inflated by 1e-6, is passed on to QuantBnB_3D.\n",
    "# NOTE(review): that fifth argument looks like an initial upper bound and the trailing 0.771 looks like\n",
    "# a time budget in seconds -- confirm both against QuantBnB_3D's signature.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 0.771)\n",
    "\n",
    "# Test error = number of strictly positive entries of (labels - predictions).\n",
    "# NOTE(review): this is the misclassification count only if both sides are one-hot rows -- confirm.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)).>0)\n",
    "\n",
    "# The report treats the four values above as error counts: accuracy = 1 - errors / n.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "            1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 7077888\n",
      "Total number of intervals = 108\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 5009320\n",
      "Total number of intervals = 687\n",
      "Current objective = 6683.006683\n",
      "time = 1.9062819480895996\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 1705850\n",
      "Total number of intervals = 2103\n",
      "Current objective = 6683.006683\n",
      "time = 16.248024940490723\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 154239\n",
      "Total number of intervals = 1695\n",
      "Current objective = 6496.0\n",
      "time = 31.789759874343872\n",
      "--------------------------------------\n",
      "Obj = 6496.0\n",
      "Tree is Any[1, 0.48039607843137255, Any[3, 0.5940988235294118, Any[3, 0.4686337254901961, [0.0 1.0], [1.0 0.0]], Any[2, 0.3745349019607843, [0.0 1.0], [1.0 0.0]]], Any[3, 0.8175835294117647, Any[1, 0.484316862745098, [0.0 1.0], [0.0 1.0]], Any[1, 0.8215043137254902, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 49.94406771659851\n",
      "Dataset: skin,  CART train/test acc: 0.966 / 0.965,  Quant-BnB train/test acc: 0.967 / 0.966"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification benchmark on one dataset: greedy (CART-style) tree vs. Quant-BnB.\n",
    "# `dataset` is the single place to edit when pointing this cell at another file under\n",
    "# ./dataset/class/; it feeds both the load path and the report line, so they cannot drift apart.\n",
    "dataset = \"skin\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# The greedy tree is fitted first; its training objective, inflated by 1e-6, is passed on to QuantBnB_3D.\n",
    "# NOTE(review): that fifth argument looks like an initial upper bound and the trailing 48.894 looks like\n",
    "# a time budget in seconds -- confirm both against QuantBnB_3D's signature.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 48.894)\n",
    "\n",
    "# Test error = number of strictly positive entries of (labels - predictions).\n",
    "# NOTE(review): this is the misclassification count only if both sides are one-hot rows -- confirm.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)).>0)\n",
    "\n",
    "# The report treats the four values above as error counts: accuracy = 1 - errors / n.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "            1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 8269431120\n",
      "Total number of intervals = 500\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 6100918434\n",
      "Total number of intervals = 3339\n",
      "Current objective = 28.0\n",
      "time = 0.1693120002746582\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 527945707\n",
      "Total number of intervals = 2596\n",
      "Current objective = 28.0\n",
      "time = 0.43268394470214844\n",
      "--------------------------------------\n",
      "Obj = 28.0\n",
      "Tree is Any[3, 0.040439758076541966, Any[3, 0.02915658119231141, Any[2, 0.027491453606214036, [1.0 0.0], [0.0 1.0]], Any[2, 0.042869054061459674, [1.0 0.0], [0.0 1.0]]], Any[4, 0.24530530872990078, Any[2, 0.054556523870534085, [1.0 0.0], [0.0 1.0]], Any[2, 0.062044492200904744, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 0.601996898651123\n",
      "Dataset: wilt,  CART train/test acc: 0.993 / 0.768,  Quant-BnB train/test acc: 0.994 / 0.776"
     ]
    }
   ],
   "source": [
    "# Depth-3 classification benchmark on one dataset: greedy (CART-style) tree vs. Quant-BnB.\n",
    "# `dataset` is the single place to edit when pointing this cell at another file under\n",
    "# ./dataset/class/; it feeds both the load path and the report line, so they cannot drift apart.\n",
    "dataset = \"wilt\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test = size(Y_test, 1)\n",
    "\n",
    "# The greedy tree is fitted first; its training objective, inflated by 1e-6, is passed on to QuantBnB_3D.\n",
    "# NOTE(review): that fifth argument looks like an initial upper bound and the trailing 0.582 looks like\n",
    "# a time budget in seconds -- confirm both against QuantBnB_3D's signature.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 0.582)\n",
    "\n",
    "# Test error = number of strictly positive entries of (labels - predictions).\n",
    "# NOTE(review): this is the misclassification count only if both sides are one-hot rows -- confirm.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)).>0)\n",
    "\n",
    "# The report treats the four values above as error counts: accuracy = 1 - errors / n.\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "            1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.8.1",
   "language": "julia",
   "name": "julia-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.1"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
