{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "myfindinterval (generic function with 1 method)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Standard libraries: linear algebra helpers, and Printf for the @printf result lines.\n",
    "using LinearAlgebra, Printf\n",
    "\n",
    "# Project code, kept in its original order in case later files rely on earlier definitions.\n",
    "include(\"QuantBnB-2D.jl\")\n",
    "include(\"QuantBnB-3D.jl\")\n",
    "include(\"gen_data.jl\")\n",
    "include(\"lowerbound_middle.jl\")\n",
    "include(\"Algorithms.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16-element Vector{String}:\n",
       " \"avila\"\n",
       " \"bank\"\n",
       " \"bean\"\n",
       " \"bidding\"\n",
       " \"eeg\"\n",
       " \"fault\"\n",
       " \"HTRU\"\n",
       " \"magic\"\n",
       " \"occupancy\"\n",
       " \"page\"\n",
       " \"raisin\"\n",
       " \"rice\"\n",
       " \"room\"\n",
       " \"segment\"\n",
       " \"skin\"\n",
       " \"wilt\""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Benchmark dataset names. The classification cells load ./dataset/class/<name>.json.\n",
    "regress_data = [\"carbon\", \"casp\", \"concrete\", \"energy\", \"fish\", \"gas\",\n",
    "                \"grid\", \"news\", \"qsar\", \"query1\", \"query2\"]\n",
    "\n",
    "class_data = [\"avila\", \"bank\", \"bean\", \"bidding\", \"eeg\", \"fault\", \"HTRU\", \"magic\",\n",
    "              \"occupancy\", \"page\", \"raisin\", \"rice\", \"room\", \"segment\", \"skin\", \"wilt\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 19688743840\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 9458024641\n",
      "Total number of intervals = 30357\n",
      "Current objective = 4458.0\n",
      "time = 12.472540855407715\n",
      "--------------------------------------\n",
      "Obj = 4458.0\n",
      "Tree is Any"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 0.2164144973582967, Any[9, 0.11948085917578102, Any[1, 0.1310165830815444, "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.13838020022463626, [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[3, 0.06735744379007103, Any[5, 0.9103656541984284, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[5, 0.8805209279172044, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 12.472540855407715\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset: avila,  CART train/test acc: 0.532 / 0.531,  Quant-BnB train/test acc: 0.573 / 0.571"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"avila\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 5.645 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 266603584\n",
      "Total number of intervals = 256\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 163305884\n",
      "Total number of intervals = 1406\n",
      "Current objective = 32.0\n",
      "time = 0.04915308952331543\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 30523655\n",
      "Total number of intervals = 2328\n",
      "Current objective = 31.0\n",
      "time = 1.7883870601654053\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 2814014\n",
      "Total number of intervals = 1941\n",
      "Current objective = 21.0\n",
      "time = 0.4944300651550293\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n",
      "Number of remaining trees = 21068\n",
      "Total number of intervals = 130\n",
      "Current objective = 19.0\n",
      "time = 0.5039830207824707\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 5746\n",
      "Total number of intervals = 320\n",
      "Current objective = 19.0\n",
      "time = 0.06704401969909668\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n",
      "Number of remaining trees = 1216\n",
      "Total number of intervals = 1088\n",
      "Current objective = 19.0\n",
      "time = 0.4007701873779297\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Obj = 19.0\n",
      "Tree is Any[3, 0.2365969434165464, Any[2, 0.7094440839373315, Any[1, 0.6853720579221023, [0.0 1.0], [1.0 0.0]], Any[1, 0.26284590283336573, [0.0 1.0], [1.0 0.0]]], Any[2, 0.5380485690017099, Any[1, 0.47447777059039864, [0.0 1.0], [1.0 0.0]], Any[1, 0.299825863747485, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 3.7923636436462402\n",
      "Dataset: bank,  CART train/test acc: 0.933 / 0.927,  Quant-BnB train/test acc: 0.983 / 0.978"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"bank\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 0.158 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 19688743840\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 16983563641\n",
      "Total number of intervals = 33837\n",
      "Current objective = 4458.0\n",
      "time = 6.470751047134399\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 14852901000\n",
      "Total number of intervals = 242802\n",
      "Current objective = 4458.0\n",
      "time = 69.37562584877014\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 412937953\n",
      "Total number of intervals = 469185\n",
      "Current objective = 4409.0\n",
      "time = 230.6009738445282\n",
      "--------------------------------------\n",
      "Obj = 4409.0\n",
      "Tree is Any[1, 0.22527652290025632, Any[9, 0.1145460276434946, Any[5, 0.8879821094875104, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.1371202428078056, [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[3, 0.06508545324714689, Any[5, 0.9103656541984284, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[4, 0.5644186238036573, [0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 306.44735193252563\n",
      "Dataset: occupancy,  CART train/test acc: 0.532 / 0.531,  Quant-BnB train/test acc: 0.577 / 0.571"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on the avila dataset: greedy CART vs. Quant-BnB.\n",
    "# Same setup as the earlier avila cell, but with a much larger last QuantBnB_3D argument\n",
    "# (300 + 5 instead of 5.645 + 5), which appears to be the time budget in seconds.\n",
    "# The dataset name is bound once and reused for both the data path and the printed label,\n",
    "# so the reported dataset is always the one that was actually loaded.\n",
    "dataset_name = \"avila\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound.\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 300 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 266603584\n",
      "Total number of intervals = 256\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 163305884\n",
      "Total number of intervals = 1406\n",
      "Current objective = 32.0\n",
      "time = 0.08709406852722168\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 30523655\n",
      "Total number of intervals = 2328\n",
      "Current objective = 31.0\n",
      "time = 0.47351694107055664\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 2814014\n",
      "Total number of intervals = 1941\n",
      "Current objective = 21.0\n",
      "time = 0.9503560066223145\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n",
      "Number of remaining trees = 21068\n",
      "Total number of intervals = 130\n",
      "Current objective = 19.0\n",
      "time = 0.9280400276184082\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 5\n",
      "Number of remaining trees = 5746\n",
      "Total number of intervals = 320\n",
      "Current objective = 19.0\n",
      "time = 0.12368011474609375\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n",
      "Number of remaining trees = 1216\n",
      "Total number of intervals = 1088\n",
      "Current objective = 19.0\n",
      "time = 0.5627219676971436\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Obj = 19.0\n",
      "Tree is Any[3, 0.2365969434165464, Any[2, 0.7094440839373315, Any[1, 0.6853720579221023, [0.0 1.0], [1.0 0.0]], Any[1, 0.26284590283336573, [0.0 1.0], [1.0 0.0]]], Any[2, 0.5380485690017099, Any[1, 0.47447777059039864, [0.0 1.0], [1.0 0.0]], Any[1, 0.299825863747485, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 3.270132064819336\n",
      "Dataset: bank,  CART train/test acc: 0.933 / 0.927,  Quant-BnB train/test acc: 0.983 / 0.978"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# NOTE(review): same configuration as the earlier bank cell (0.158 + 5);\n",
    "# presumably a timing re-run. Consider removing it if it is not needed.\n",
    "dataset_name = \"bank\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 0.158 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1860430528576\n",
      "Total number of intervals = 16384\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 1400911750897\n",
      "Total number of intervals = 111045\n",
      "Current objective = 1604.0\n",
      "time = 21.231276035308838\n",
      "--------------------------------------\n",
      "Obj = 1604.0\n",
      "Tree is Any[2, 0.26303354440605925, Any[13, 0.5534573765897823, Any[12, 0.5960115116580839, [0.0 0.0 0.0 0.0 0.0 1.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[14, 0.2881152965985623, [0.0 0.0 0.0 0.0 1.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]]], Any[12, 0.39754863446556105, Any[13, 0.43133676314757663, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0]], Any[1, 0.3659638909289655, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 21.231276988983154\n",
      "Dataset: bean,  CART train/test acc: 0.777 / 0.776,  Quant-BnB train/test acc: 0.853 / 0.856"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"bean\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 16.194 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 3781512036\n",
      "Total number of intervals = 2916\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 2038097934\n",
      "Total number of intervals = 14984\n",
      "Current objective = 64.0\n",
      "time = 1.3418939113616943\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 65585123\n",
      "Total number of intervals = 6421\n",
      "Current objective = 64.0\n",
      "time = 4.2167840003967285\n",
      "--------------------------------------\n",
      "Obj = 64.0\n",
      "Tree is Any[9, 0.5555444444444444, Any[4, 0.7754442691068941, Any[3, 0.74995, [1.0 0.0], [0.0 1.0]], Any[3, 0.25005, [1.0 0.0], [0.0 1.0]]], Any[2, 0.08813407384031359, Any[1, 0.0016102718851999998, [1.0 0.0], [1.0 0.0]], Any[3, 0.25005, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 5.558677911758423\n",
      "Dataset: bidding,  CART train/test acc: 0.981 / 0.986,  Quant-BnB train/test acc: 0.987 / 0.987"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"bidding\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 0.545 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1545264504\n",
      "Total number of intervals = 10976\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 1332824808\n",
      "Total number of intervals = 86061\n",
      "Current objective = 3799.0\n",
      "time = 9.5902578830719\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 28228767\n",
      "Total number of intervals = 11621\n",
      "Current objective = 3799.0\n",
      "time = 5.40713906288147\n",
      "--------------------------------------\n",
      "Obj = 3799.0\n",
      "Tree is Any[1, 0.5121581787115892, Any[2, 0.23414849945213462, Any[6, 0.34654278413252937, [0.0 1.0], [1.0 0.0]], Any[7, 0.4606819566099373, [0.0 1.0], [1.0 0.0]]], Any[6, 0.3480845203884224, Any[2, 0.2393034339596088, [0.0 1.0], [1.0 0.0]], Any[7, 0.4701802805274723, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 14.99739694595337\n",
      "Dataset: eeg,  CART train/test acc: 0.666 / 0.666,  Quant-BnB train/test acc: 0.683 / 0.698"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"eeg\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 8.927 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 28884958128\n",
      "Total number of intervals = 78732\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 13080360712\n",
      "Total number of intervals = 200707\n",
      "Current objective = 549.0\n",
      "time = 7.810006856918335\n",
      "--------------------------------------\n",
      "Obj = 549.0\n",
      "Tree is Any[1, 0.24081431085043986, Any[12, 0.5, Any[18, 0.0467512351326624, [0.0 0.0 0.0 0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]], Any[11, 0.23285767195767193, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0]]], Any[12, 0.5, Any[5, 0.0002408141340163639, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[17, 0.5860975294948069, [0.0 0.0 0.0 0.0 0.0 0.0 1.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0]]]]\n",
      "total time = 7.810007810592651\n",
      "Dataset: fault,  CART train/test acc: 0.553 / 0.548,  Quant-BnB train/test acc: 0.646 / 0.632"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"fault\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 2.46 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 329152524800\n",
      "Total number of intervals = 2048\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 277160952330\n",
      "Total number of intervals = 15561\n",
      "Current objective = 286.0\n",
      "time = 2.6682960987091064\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 101985369682\n",
      "Total number of intervals = 52318\n",
      "Current objective = 284.0\n",
      "time = 14.105509996414185\n",
      "--------------------------------------\n",
      "Obj = 284.0\n",
      "Tree is Any[1, 0.44940579649533685, Any[1, 0.3388093764376229, Any[3, 0.3578980679765418, [1.0 0.0], [0.0 1.0]], Any[3, 0.3165147641018484, [1.0 0.0], [0.0 1.0]]], Any[5, 0.016406746164436145, Any[3, 0.2999132835165079, [1.0 0.0], [0.0 1.0]], Any[3, 0.27999403468969986, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 16.773807287216187\n",
      "Dataset: htru,  CART train/test acc: 0.979 / 0.980,  Quant-BnB train/test acc: 0.980 / 0.981"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "# NOTE(review): class_data spells this \"HTRU\" but the path uses lowercase \"htru\";\n",
    "# confirm the actual file name if running on a case-sensitive filesystem.\n",
    "dataset_name = \"htru\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 11.316 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 580279921000\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 547419390492\n",
      "Total number of intervals = 33772\n",
      "Current objective = 2721.0\n",
      "time = 5.221412897109985\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 85261881174\n",
      "Total number of intervals = 38614\n",
      "Current objective = 2640.0\n",
      "time = 15.24050498008728\n",
      "--------------------------------------\n",
      "Obj = 2640.0\n",
      "Tree is Any[2, 0.0451341301261399, Any[9, 0.25790231811111114, Any[3, 0.19519698698994678, [1.0 0.0], [0.0 1.0]], Any[3, 0.12857814311058546, [1.0 0.0], [0.0 1.0]]], Any[9, 0.25790231811111114, Any[1, 0.33922344620309275, [1.0 0.0], [0.0 1.0]], Any[1, 0.08910819864592663, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 20.461918830871582\n",
      "Dataset: magic,  CART train/test acc: 0.801 / 0.792,  Quant-BnB train/test acc: 0.826 / 0.822"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"magic\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 14.838 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1392446720\n",
      "Total number of intervals = 500\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 771573916\n",
      "Total number of intervals = 2810\n",
      "Current objective = 56.0\n",
      "time = 0.5248799324035645\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 346770440\n",
      "Total number of intervals = 12349\n",
      "Current objective = 56.0\n",
      "time = 3.155334949493408\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 7489364\n",
      "Total number of intervals = 8932\n",
      "Current objective = 56.0\n",
      "time = 2.809142827987671\n",
      "--------------------------------------\n",
      "Obj = 56.0\n",
      "Tree is Any[3, 0.26297273530711446, Any[3, 0.1006837973191928, Any[1, 0.004721571648690387, [0.0 1.0], [0.0 1.0]], Any[4, 0.03317627848735276, [0.0 1.0], [1.0 0.0]]], Any[4, 0.36531619834710743, Any[1, 0.5965138212634861, [1.0 0.0], [0.0 1.0]], Any[1, 0.6763899845916868, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 6.48935866355896\n",
      "Dataset: occupancy,  CART train/test acc: 0.989 / 0.977,  Quant-BnB train/test acc: 0.993 / 0.896"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"occupancy\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 1.458 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 2679769000\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 2029740626\n",
      "Total number of intervals = 27149\n",
      "Current objective = 152.0\n",
      "time = 2.0712649822235107\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 276899467\n",
      "Total number of intervals = 41997\n",
      "Current objective = 145.0\n",
      "time = 6.081144094467163\n",
      "--------------------------------------\n",
      "Obj = 145.0\n",
      "Tree is Any[4, 0.0072057475609551705, Any[5, 0.18097099156118146, Any[1, 0.014418430884184309, [1.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0]], Any[4, 0.0005263634721495438, [0.0 0.0 1.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0]]], Any[5, 0.18097099156118146, Any[7, 0.0012130191764230926, [1.0 0.0 0.0 0.0 0.0], [0.0 1.0 0.0 0.0 0.0]], Any[1, 0.0032127023661270237, [0.0 1.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 8.15241026878357\n",
      "Dataset: page,  CART train/test acc: 0.964 / 0.958,  Quant-BnB train/test acc: 0.967 / 0.959"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART vs. Quant-BnB.\n",
    "# The dataset name is bound once so the data path and the printed label cannot disagree.\n",
    "dataset_name = \"page\"\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(\"./dataset/class/$(dataset_name).json\")\n",
    "n_train, m = size(Y_train)   # m = number of label columns, passed on to tree_eval\n",
    "n_test, _ = size(Y_test)\n",
    "\n",
    "# Greedy depth-3 tree; its training error count (scaled by 1+1e-6) is handed to QuantBnB_3D,\n",
    "# presumably as the initial upper bound. The last QuantBnB_3D argument appears to be a\n",
    "# time budget in seconds (TODO confirm in QuantBnB-3D.jl).\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", 2.859 + 5)\n",
    "\n",
    "# Test misclassification counts (assumes one-hot Y_test and one-hot tree_eval output).\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset_name,\n",
    "        1 - gre_train/n_train, 1 - gre_test/n_test, 1 - opt_train/n_train, 1 - opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 710962588\n",
      "Total number of intervals = 1372\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 626304028\n",
      "Total number of intervals = 10878\n",
      "Current objective = 87.0\n",
      "time = 0.23666095733642578\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 473914668\n",
      "Total number of intervals = 74084\n",
      "Current objective = 84.0\n",
      "time = 2.32940411567688\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 35437927\n",
      "Total number of intervals = 49771\n",
      "Current objective = 82.0\n",
      "time = 3.019047975540161\n",
      "--------------------------------------\n",
      "Obj = 82.0\n",
      "Tree is Any[1, 0.29037389964704763, Any[6, 0.6548782591691547, Any[2, 0.2187553452848061, [1.0 0.0], [0.0 1.0]], Any[2, 0.25564772741067743, [1.0 0.0], [0.0 1.0]]], Any[2, 0.27326125001414414, Any[7, 0.27027430608574005, [1.0 0.0], [0.0 1.0]], Any[7, 0.27660566715688184, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 5.585114002227783\n",
      "Dataset: raisin,  CART train/test acc: 0.869 / 0.883,  Quant-BnB train/test acc: 0.886 / 0.883"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART baseline vs. Quant-BnB.\n",
    "dataset = \"raisin\"\n",
    "# Final QuantBnB_3D argument. The logged per-loop times suggest it is a wall-clock budget\n",
    "# in seconds, checked after each loop (base value + 5 s; origin of the base value: TODO).\n",
    "time_budget = 0.501 + 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "# CART training error count, relaxed by a 1e-6 relative margin; presumably the initial\n",
    "# branch-and-bound upper bound (TODO confirm against QuantBnB-3D.jl).\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_budget)\n",
    "# Test misclassification counts. Assumes labels and tree predictions are one-hot rows,\n",
    "# so each wrong row contributes exactly one positive entry to Y_test - prediction.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "        1-gre_train/n_train, 1-gre_test/n_test, 1-opt_train/n_train, 1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 11187683388\n",
      "Total number of intervals = 1372\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 9609460287\n",
      "Total number of intervals = 10587\n",
      "Current objective = 198.0\n",
      "time = 0.5208170413970947\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 7460611876\n",
      "Total number of intervals = 73585\n",
      "Current objective = 193.0\n",
      "time = 5.375838994979858\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 95543265\n",
      "Total number of intervals = 9756\n",
      "Current objective = 193.0\n",
      "time = 1.2736380100250244\n",
      "--------------------------------------\n",
      "Obj = 193.0\n",
      "Tree is Any[1, 0.41328085724344304, Any[6, 0.4184850123066104, Any[3, 0.4981964921204348, [0.0 1.0], [1.0 0.0]], Any[3, 0.45220154889367065, [0.0 1.0], [1.0 0.0]]], Any[4, 0.6481065480358618, Any[3, 0.5088920709046933, [0.0 1.0], [1.0 0.0]], Any[3, 0.47889078038784194, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 7.1702940464019775\n",
      "Dataset: rice,  CART train/test acc: 0.933 / 0.917,  Quant-BnB train/test acc: 0.937 / 0.920"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART baseline vs. Quant-BnB.\n",
    "dataset = \"rice\"\n",
    "# Final QuantBnB_3D argument. The logged per-loop times suggest it is a wall-clock budget\n",
    "# in seconds, checked after each loop (base value + 5 s; origin of the base value: TODO).\n",
    "time_budget = 2.004 + 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "# CART training error count, relaxed by a 1e-6 relative margin; presumably the initial\n",
    "# branch-and-bound upper bound (TODO confirm against QuantBnB-3D.jl).\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_budget)\n",
    "# Test misclassification counts. Assumes labels and tree predictions are one-hot rows,\n",
    "# so each wrong row contributes exactly one positive entry to Y_test - prediction.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "        1-gre_train/n_train, 1-gre_test/n_test, 1-opt_train/n_train, 1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 536385600\n",
      "Total number of intervals = 16384\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 128634903\n",
      "Total number of intervals = 53080\n",
      "Current objective = 97.0\n",
      "time = 7.9541730880737305\n",
      "--------------------------------------\n",
      "Obj = 97.0\n",
      "Tree is Any[5, 0.6242175757575756, Any[1, 0.6735763888888892, Any[14, 0.43430317220527126, [0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0]], Any[1, 0.9755993055555554, [0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0]]], Any[5, 0.8635636363636363, Any[7, 0.4107321428571429, [1.0 0.0 0.0 0.0], [0.0 1.0 0.0 0.0]], Any[7, 0.6303310714285715, [0.0 1.0 0.0 0.0], [0.0 0.0 1.0 0.0]]]]\n",
      "total time = 7.9541730880737305\n",
      "Dataset: room,  CART train/test acc: 0.968 / 0.967,  Quant-BnB train/test acc: 0.988 / 0.986"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART baseline vs. Quant-BnB.\n",
    "dataset = \"room\"\n",
    "# Final QuantBnB_3D argument. The logged per-loop times suggest it is a wall-clock budget\n",
    "# in seconds, checked after each loop (base value + 5 s; origin of the base value: TODO).\n",
    "time_budget = 2.714 + 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "# CART training error count, relaxed by a 1e-6 relative margin; presumably the initial\n",
    "# branch-and-bound upper bound (TODO confirm against QuantBnB-3D.jl).\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_budget)\n",
    "# Test misclassification counts. Assumes labels and tree predictions are one-hot rows,\n",
    "# so each wrong row contributes exactly one positive entry to Y_test - prediction.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "        1-gre_train/n_train, 1-gre_test/n_test, 1-opt_train/n_train, 1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 12444739848\n",
      "Total number of intervals = 23328\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 5960526016\n",
      "Total number of intervals = 99869\n",
      "Current objective = 226.0\n",
      "time = 5.908473968505859\n",
      "--------------------------------------\n",
      "Obj = 226.0\n",
      "Tree is Any[10, 0.17551061325373268, Any[18, 0.1939733550127248, Any[18, 0.14785660397852451, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[2, 0.5687362499999999, [0.0 0.0 0.0 0.0 1.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0 0.0 0.0]]], Any[10, 0.4145227869864083, Any[2, 0.60622875, [0.0 0.0 0.0 0.0 0.0 1.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.5317519688396428, [0.0 0.0 0.0 0.0 0.0 1.0 0.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 5.908474922180176\n",
      "Dataset: segment,  CART train/test acc: 0.574 / 0.556,  Quant-BnB train/test acc: 0.878 / 0.842"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART baseline vs. Quant-BnB.\n",
    "dataset = \"segment\"\n",
    "# Final QuantBnB_3D argument. The logged per-loop times suggest it is a wall-clock budget\n",
    "# in seconds, checked after each loop (base value + 5 s; origin of the base value: TODO).\n",
    "time_budget = 0.771 + 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "# CART training error count, relaxed by a 1e-6 relative margin; presumably the initial\n",
    "# branch-and-bound upper bound (TODO confirm against QuantBnB-3D.jl).\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_budget)\n",
    "# Test misclassification counts. Assumes labels and tree predictions are one-hot rows,\n",
    "# so each wrong row contributes exactly one positive entry to Y_test - prediction.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "        1-gre_train/n_train, 1-gre_test/n_test, 1-opt_train/n_train, 1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 7077888\n",
      "Total number of intervals = 108\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 5009320\n",
      "Total number of intervals = 687\n",
      "Current objective = 6683.006683\n",
      "time = 2.5524308681488037\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 1705850\n",
      "Total number of intervals = 2103\n",
      "Current objective = 6683.006683\n",
      "time = 21.20090413093567\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 97369\n",
      "Total number of intervals = 1064\n",
      "Current objective = 6496.0\n",
      "time = 31.113924026489258\n",
      "--------------------------------------\n",
      "Obj = 6496.0\n",
      "Tree is Any[1, 0.48039607843137255, Any[3, 0.5940988235294118, Any[3, 0.4686337254901961, [0.0 1.0], [1.0 0.0]], Any[2, 0.3745349019607843, [0.0 1.0], [1.0 0.0]]], Any[3, 0.8175835294117647, Any[1, 0.484316862745098, [0.0 1.0], [0.0 1.0]], Any[1, 0.8215043137254902, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 54.86725902557373\n",
      "Dataset: skin,  CART train/test acc: 0.966 / 0.965,  Quant-BnB train/test acc: 0.967 / 0.966"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART baseline vs. Quant-BnB.\n",
    "dataset = \"skin\"\n",
    "# Final QuantBnB_3D argument. The logged per-loop times suggest it is a wall-clock budget\n",
    "# in seconds, checked after each loop (base value + 5 s; origin of the base value: TODO).\n",
    "time_budget = 48.894 + 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "# CART training error count, relaxed by a 1e-6 relative margin; presumably the initial\n",
    "# branch-and-bound upper bound (TODO confirm against QuantBnB-3D.jl).\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_budget)\n",
    "# Test misclassification counts. Assumes labels and tree predictions are one-hot rows,\n",
    "# so each wrong row contributes exactly one positive entry to Y_test - prediction.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "        1-gre_train/n_train, 1-gre_test/n_test, 1-opt_train/n_train, 1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 8269431120\n",
      "Total number of intervals = 500\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 1\n",
      "Number of remaining trees = 6100918434\n",
      "Total number of intervals = 3339\n",
      "Current objective = 28.0\n",
      "time = 0.2578010559082031\n",
      "--------------------------------------\n",
      "Loop "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "Number of remaining trees = 1856836080\n",
      "Total number of intervals = 9376\n",
      "Current objective = 24.0\n",
      "time = 2.1480469703674316\n",
      "--------------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loop 3\n",
      "Number of remaining trees = 70234966\n",
      "Total number of intervals = 3367\n",
      "Current objective = 24.0\n",
      "time = 3.21099591255188\n",
      "--------------------------------------\n",
      "Obj = 24.0\n",
      "Tree is Any[3, 0.04580060686108907, Any[3, 0.03580042191337607, Any[2, 0.027491453606214036, [1.0 0.0], [0.0 1.0]], Any[2, 0.05385647839607505, [1.0 0.0], [0.0 1.0]]], Any[4, 0.2785741317442127, Any[2, 0.060662268608250365, [1.0 0.0], [0.0 1.0]], Any[2, 0.06832407704514065, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 5.616844892501831\n",
      "Dataset: wilt,  CART train/test acc: 0.993 / 0.768,  Quant-BnB train/test acc: 0.994 / 0.812"
     ]
    }
   ],
   "source": [
    "# Depth-3 trees on a classification dataset: greedy CART baseline vs. Quant-BnB.\n",
    "dataset = \"wilt\"\n",
    "# Final QuantBnB_3D argument. The logged per-loop times suggest it is a wall-clock budget\n",
    "# in seconds, checked after each loop (base value + 5 s; origin of the base value: TODO).\n",
    "time_budget = 0.582 + 5\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "# CART training error count, relaxed by a 1e-6 relative margin; presumably the initial\n",
    "# branch-and-bound upper bound (TODO confirm against QuantBnB-3D.jl).\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_budget)\n",
    "# Test misclassification counts. Assumes labels and tree predictions are one-hot rows,\n",
    "# so each wrong row contributes exactly one positive entry to Y_test - prediction.\n",
    "gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)) .> 0)\n",
    "opt_test = sum((Y_test - tree_eval(opt_tree, X_test, 3, m)) .> 0)\n",
    "@printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", dataset,\n",
    "        1-gre_train/n_train, 1-gre_test/n_test, 1-opt_train/n_train, 1-opt_test/n_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.8.1",
   "language": "julia",
   "name": "julia-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.1"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
