{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "myfindinterval (generic function with 1 method)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "using LinearAlgebra\n",
    "using Printf\n",
    "include(\"QuantBnB-2D.jl\")\n",
    "include(\"QuantBnB-3D.jl\")\n",
    "include(\"gen_data.jl\")\n",
    "include(\"lowerbound_middle.jl\")\n",
    "include(\"Algorithms.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16-element Vector{String}:\n",
       " \"avila\"\n",
       " \"bank\"\n",
       " \"bean\"\n",
       " \"bidding\"\n",
       " \"eeg\"\n",
       " \"fault\"\n",
       " \"HTRU\"\n",
       " \"magic\"\n",
       " \"occupancy\"\n",
       " \"page\"\n",
       " \"raisin\"\n",
       " \"rice\"\n",
       " \"room\"\n",
       " \"segment\"\n",
       " \"skin\"\n",
       " \"wilt\""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "regress_data = [\"carbon\",\"casp\",\"concrete\",\"energy\",\"fish\",\"gas\",\"grid\",\"news\",\"qsar\",\"query1\",\"query2\"]\n",
    "\n",
    "class_data = [\"avila\", \"bank\", \"bean\", \"bidding\", \"eeg\", \"fault\", \"HTRU\",\n",
    "\"magic\", \"occupancy\", \"page\",\"raisin\", \"rice\", \"room\", \"segment\",\"skin\",\"wilt\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 19688743840\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 16983563641\n",
      "Total number of intervals = 33837\n",
      "Current objective = 4458.0\n",
      "time = 13.816772937774658\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 14852901000\n",
      "Total number of intervals = 242802\n",
      "Current objective = 4458.0\n",
      "time = 89.71591997146606\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 267114821\n",
      "Total number of intervals = 406508\n",
      "Current objective = 4409.0\n",
      "time = 197.53381204605103\n",
      "--------------------------------------\n",
      "Obj = 4409.0\n",
      "Tree is Any[1, 0.22527652290025632, Any[9, 0.1145460276434946, Any[5, 0.8879821094875104, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.1371202428078056, [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[3, 0.06508545324714689, Any[5, 0.9103656541984284, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[4, 0.5644186238036573, [0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 301.06650614738464\n",
      "opt evl0.5727969348659003"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"avila\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)\n",
    "# @printf(\"Dataset: %s,  CART train/test acc: %.3f / %.3f,  Quant-BnB train/test acc: %.3f / %.3f\", \"avila\", \n",
    "#             1-gre_train/n_train,1-gre_test/n_test, 1-opt_train/n_train,1-opt_test/n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 266603584\n",
      "Total number of intervals = 256\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 163305884\n",
      "Total number of intervals = 1406\n",
      "Current objective = 32.0\n",
      "time = 0.06306815147399902\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 30523655\n",
      "Total number of intervals = 2328\n",
      "Current objective = 31.0\n",
      "time = 0.3926258087158203\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 2814014\n",
      "Total number of intervals = 1941\n",
      "Current objective = 21.0\n",
      "time = 0.7641439437866211\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 21068\n",
      "Total number of intervals = 130\n",
      "Current objective = 19.0\n",
      "time = 0.7548868656158447\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 5746\n",
      "Total number of intervals = 320\n",
      "Current objective = 19.0\n",
      "time = 0.10826683044433594\n",
      "--------------------------------------\n",
      "Loop 6\n",
      "Number of remaining trees = 1216\n",
      "Total number of intervals = 1088\n",
      "Current objective = 19.0\n",
      "time = 0.6266069412231445\n",
      "--------------------------------------\n",
      "Obj = 19.0\n",
      "Tree is Any[3, 0.2365969434165464, Any[2, 0.7094440839373315, Any[1, 0.6853720579221023, [0.0 1.0], [1.0 0.0]], Any[1, 0.26284590283336573, [0.0 1.0], [1.0 0.0]]], Any[2, 0.5380485690017099, Any[1, 0.47447777059039864, [0.0 1.0], [1.0 0.0]], Any[1, 0.299825863747485, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 3.53100848197937\n",
      "opt evl0.9784172661870504"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"bank\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 19688743840\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 16983563641\n",
      "Total number of intervals = 33837\n",
      "Current objective = 4458.0\n",
      "time = 7.495133876800537\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 14852901000\n",
      "Total number of intervals = 242802\n",
      "Current objective = 4458.0\n",
      "time = 97.5443959236145\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 302619269\n",
      "Total number of intervals = 419937\n",
      "Current objective = 4409.0\n",
      "time = 198.0163938999176\n",
      "--------------------------------------\n",
      "Obj = 4409.0\n",
      "Tree is Any[1, 0.22527652290025632, Any[9, 0.1145460276434946, Any[5, 0.8879821094875104, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[9, 0.1371202428078056, [0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0]]], Any[3, 0.06508545324714689, Any[5, 0.9103656541984284, [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]], Any[4, 0.5644186238036573, [0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 303.05592465400696\n",
      "opt evl0.5727969348659003"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"avila\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 266603584\n",
      "Total number of intervals = 256\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 163305884\n",
      "Total number of intervals = 1406\n",
      "Current objective = 32.0\n",
      "time = 0.07892513275146484\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 30523655\n",
      "Total number of intervals = 2328\n",
      "Current objective = 31.0\n",
      "time = 0.36770200729370117\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 2814014\n",
      "Total number of intervals = 1941\n",
      "Current objective = 21.0\n",
      "time = 0.7749729156494141\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 21068\n",
      "Total number of intervals = 130\n",
      "Current objective = 19.0\n",
      "time = 0.7311749458312988\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 5746\n",
      "Total number of intervals = 320\n",
      "Current objective = 19.0\n",
      "time = 0.10645484924316406\n",
      "--------------------------------------\n",
      "Loop 6\n",
      "Number of remaining trees = 1216\n",
      "Total number of intervals = 1088\n",
      "Current objective = 19.0\n",
      "time = 0.42316317558288574\n",
      "--------------------------------------\n",
      "Obj = 19.0\n",
      "Tree is Any[3, 0.2365969434165464, Any[2, 0.7094440839373315, Any[1, 0.6853720579221023, [0.0 1.0], [1.0 0.0]], Any[1, 0.26284590283336573, [0.0 1.0], [1.0 0.0]]], Any[2, 0.5380485690017099, Any[1, 0.47447777059039864, [0.0 1.0], [1.0 0.0]], Any[1, 0.299825863747485, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 2.6018760204315186\n",
      "opt evl0.9784172661870504"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"bank\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1860430528576\n",
      "Total number of intervals = 16384\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 1400911750897\n",
      "Total number of intervals = 111045\n",
      "Current objective = 1604.0\n",
      "time = 20.04662799835205\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 372032618900\n",
      "Total number of intervals = 266289\n",
      "Current objective = 1604.0\n",
      "time = 188.92322993278503\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 5718293475\n",
      "Total number of intervals = 37025\n",
      "Current objective = 1592.0\n",
      "time = 91.58445906639099\n",
      "--------------------------------------\n",
      "Obj = 1592.0\n",
      "Tree is Any[2, 0.2394556933496002, Any[3, 0.16883747536501187, Any[13, 0.5255806237209568, [1.0 0.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[12, 0.2891225881002969, [0.0 0.0 0.0 0.0 1.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0]]], Any[12, 0.37878851120698664, Any[13, 0.43133676314757663, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 1.0 0.0 0.0]], Any[1, 0.3659638909289655, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 300.5543169975281\n",
      "opt evl0.8473954512105649"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"bean\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 3781512036\n",
      "Total number of intervals = 2916\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 2038097934\n",
      "Total number of intervals = 14984\n",
      "Current objective = 64.0\n",
      "time = 1.277008056640625\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 175114545\n",
      "Total number of intervals = 15858\n",
      "Current objective = 37.0\n",
      "time = 7.481611013412476\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 7571\n",
      "Total number of intervals = 171\n",
      "Current objective = 37.0\n",
      "time = 11.971817016601562\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 135\n",
      "Total number of intervals = 49\n",
      "Current objective = 37.0\n",
      "time = 0.15485715866088867\n",
      "--------------------------------------\n",
      "Obj = 37.0\n",
      "Tree is Any[9, 0.5555444444444444, Any[9, 0.11118888888888888, Any[3, 0.25005, [1.0 0.0], [0.0 1.0]], Any[3, 0.74995, [1.0 0.0], [0.0 1.0]]], Any[3, 0.25005, Any[1, 0.0016102718851999998, [1.0 0.0], [1.0 0.0]], Any[2, 0.08903313378412894, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 20.901289224624634\n",
      "opt evl0.9858044164037855"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"bidding\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1545264504\n",
      "Total number of intervals = 10976\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 1332824808\n",
      "Total number of intervals = 86061\n",
      "Current objective = 3799.0\n",
      "time = 9.30630898475647\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 875385098\n",
      "Total number of intervals = 505815\n",
      "Current objective = 3725.0\n",
      "time = 112.60021615028381\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 144647801\n",
      "Total number of intervals = 691083\n",
      "Current objective = 3553.0\n",
      "time = 178.61239981651306\n",
      "--------------------------------------\n",
      "Obj = 3553.0\n",
      "Tree is Any[2, 0.23744709470530878, Any[6, 0.34692774872748444, Any[12, 0.4227888661027988, [1.0 0.0], [0.0 1.0]], Any[14, 0.5320190256374911, [1.0 0.0], [0.0 1.0]]], Any[7, 0.4729458379357005, Any[7, 0.4609211146053924, [0.0 1.0], [1.0 0.0]], Any[6, 0.3646483262898894, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 300.51892495155334\n",
      "opt evl0.7304869913275517"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"eeg\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 28884958128\n",
      "Total number of intervals = 78732\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 26854241960\n",
      "Total number of intervals = 544599\n",
      "Current objective = 513.0\n",
      "time = 20.705874919891357\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 18658862367\n",
      "Total number of intervals = 3120029\n",
      "Current objective = 501.0\n",
      "time = 212.19703197479248\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 371626582\n",
      "Total number of intervals = 466318\n",
      "Current objective = 501.0\n",
      "time = 67.32541704177856\n",
      "--------------------------------------\n",
      "Obj = 501.0\n",
      "Tree is Any[12, 0.5, Any[24, 0.38934282057904995, Any[5, 0.0002735616070434253, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[25, 0.6637842537953296, [0.0 0.0 1.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]]], Any[11, 0.23991093474426806, Any[14, 0.13468846153846153, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[17, 0.5646739639003732, [0.0 0.0 0.0 0.0 0.0 0.0 1.0], [0.0 0.0 0.0 0.0 0.0 1.0 0.0]]]]\n",
      "total time = 300.2283251285553\n",
      "opt evl0.6122448979591837"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"fault\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 329152524800\n",
      "Total number of intervals = 2048\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 277160952330\n",
      "Total number of intervals = 15561\n",
      "Current objective = 286.0\n",
      "time = 2.932851791381836\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 219737011931\n",
      "Total number of intervals = 110799\n",
      "Current objective = 284.0\n",
      "time = 32.508466958999634\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 148171493791\n",
      "Total number of intervals = 675158\n",
      "Current objective = 283.0\n",
      "time = 243.026624917984\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 731263785\n",
      "Total number of intervals = 48380\n",
      "Current objective = 283.0\n",
      "time = 24.8520610332489\n",
      "--------------------------------------\n",
      "Obj = 283.0\n",
      "Tree is Any[5, 0.20207928718498297, Any[6, 0.1416340005579037, Any[3, 0.32580287969539157, [1.0 0.0], [0.0 1.0]], Any[3, 0.2989133599909796, [1.0 0.0], [0.0 1.0]]], Any[2, 0.29571653696583877, Any[4, 0.08163401209235668, [1.0 0.0], [0.0 1.0]], Any[1, 0.4264711513529338, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 303.3200056552887\n",
      "opt evl0.9798994974874372"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"htru\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 580279921000\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 547419390492\n",
      "Total number of intervals = 33772\n",
      "Current objective = 2721.0\n",
      "time = 6.346981048583984\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 348766540510\n",
      "Total number of intervals = 191247\n",
      "Current objective = 2636.0\n",
      "time = 73.48655295372009\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 41774681078\n",
      "Total number of intervals = 226621\n",
      "Current objective = 2636.0\n",
      "time = 222.55029296875\n",
      "--------------------------------------\n",
      "Obj = 2636.0\n",
      "Tree is Any[9, 0.25790231811111114, Any[2, 0.03876990174817266, Any[3, 0.16669884387936135, [1.0 0.0], [0.0 1.0]], Any[1, 0.33922344620309275, [1.0 0.0], [0.0 1.0]]], Any[2, 0.0451341301261399, Any[3, 0.12857814311058546, [1.0 0.0], [0.0 1.0]], Any[1, 0.08910819864592663, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 302.3838269710541\n",
      "opt evl0.8213347346295323"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"magic\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 1392446720\n",
      "Total number of intervals = 500\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 771573916\n",
      "Total number of intervals = 2810\n",
      "Current objective = 56.0\n",
      "time = 0.5466358661651611\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 346770440\n",
      "Total number of intervals = 12349\n",
      "Current objective = 56.0\n",
      "time = 3.6835200786590576\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 143968444\n",
      "Total number of intervals = 46985\n",
      "Current objective = 56.0\n",
      "time = 18.35895299911499\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 1532985\n",
      "Total number of intervals = 8851\n",
      "Current objective = 49.0\n",
      "time = 74.87002682685852\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 25748\n",
      "Total number of intervals = 1246\n",
      "Current objective = 47.0\n",
      "time = 23.580113172531128\n",
      "--------------------------------------\n",
      "Loop 6\n",
      "Number of remaining trees = 258\n",
      "Total number of intervals = 258\n",
      "Current objective = 47.0\n",
      "time = 7.331634998321533\n",
      "--------------------------------------\n",
      "Obj = 47.0\n",
      "Tree is Any[3, 0.2636305307605439, Any[3, 0.18700717827858782, Any[2, 0.9225698308064161, [0.0 1.0], [1.0 0.0]], Any[2, 0.09061144803339921, [0.0 1.0], [1.0 0.0]]], Any[1, 0.5937408782742718, Any[1, 0.3408638828967697, [0.0 1.0], [1.0 0.0]], Any[2, 0.42241352816230787, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 128.52546286582947\n",
      "opt evl0.9628019323671497"
     ]
    }
   ],
   "source": [
    "# test depth-3 trees on a classification problem\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\",\"occupancy\",\".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "X_eval, Y_eval = X_test[n_test÷2: end, : ], Y_test[n_test÷2: end, : ]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, 3, \"C\")\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\",300)\n",
    "# gre_test = sum((Y_test - tree_eval(gre_tree, X_test, 3, m)).>0)\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, 3, m)).>0)\n",
    "print(\"opt evl\", 1-opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 2679769000\n",
      "Total number of intervals = 4000\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 2029740626\n",
      "Total number of intervals = 27149\n",
      "Current objective = 152.0\n",
      "time = 2.4439449310302734\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 960260607\n",
      "Total number of intervals = 118879\n",
      "Current objective = 145.0\n",
      "time = 20.52394413948059\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 320300935\n",
      "Total number of intervals = 371323\n",
      "Current objective = 130.0\n",
      "time = 86.75394201278687\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 36278874\n",
      "Total number of intervals = 397259\n",
      "Current objective = 126.0\n",
      "time = 190.6682391166687\n",
      "--------------------------------------\n",
      "Obj = 126.0\n",
      "Tree is Any[4, 0.006218036547962449, Any[5, 0.14616782700421943, Any[1, 0.0131733499377335, [1.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0]], Any[4, 0.0005263634721495438, [0.0 0.0 1.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0]]], Any[1, 0.0032127023661270237, Any[7, 0.00015852684699232944, [1.0 0.0 0.0 0.0 0.0], [0.0 1.0 0.0 0.0 0.0]], Any[7, 0.004195870205894228, [1.0 0.0 0.0 0.0 0.0], [0.0 1.0 0.0 0.0 0.0]]]]\n",
      "total time = 300.39007019996643\n",
      "opt evl0.9581056466302368"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"page\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 710962588\n",
      "Total number of intervals = 1372\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 626304028\n",
      "Total number of intervals = 10878\n",
      "Current objective = 87.0\n",
      "time = 0.22206377983093262\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 473914668\n",
      "Total number of intervals = 74084\n",
      "Current objective = 84.0\n",
      "time = 2.58707594871521\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 201098838\n",
      "Total number of intervals = 282561\n",
      "Current objective = 79.0\n",
      "time = 19.21909999847412\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 5009312\n",
      "Total number of intervals = 63598\n",
      "Current objective = 79.0\n",
      "time = 75.36871695518494\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 47354\n",
      "Total number of intervals = 5356\n",
      "Current objective = 77.0\n",
      "time = 22.235552072525024\n",
      "--------------------------------------\n",
      "Obj = 76.0\n",
      "Tree is Any[7, 0.2530427687488064, Any[7, 0.24223494036356744, Any[4, 0.8514364324941668, [1.0 0.0], [0.0 1.0]], Any[6, 0.8073726717978372, [0.0 1.0], [1.0 0.0]]], Any[5, 0.25429386896119455, Any[2, 0.25828582805546096, [1.0 0.0], [0.0 1.0]], Any[7, 0.2718644211540117, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 123.00444197654724\n",
      "opt evl0.8901098901098901"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"raisin\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 11187683388\n",
      "Total number of intervals = 1372\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 9609460287\n",
      "Total number of intervals = 10587\n",
      "Current objective = 198.0\n",
      "time = 0.5812828540802002\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 7460611876\n",
      "Total number of intervals = 73585\n",
      "Current objective = 193.0\n",
      "time = 5.343680143356323\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 5185873418\n",
      "Total number of intervals = 457791\n",
      "Current objective = 193.0\n",
      "time = 39.89653205871582\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 1951650124\n",
      "Total number of intervals = 1556682\n",
      "Current objective = 191.0\n",
      "time = 243.29432487487793\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 0\n",
      "Total number of intervals = 0\n",
      "Current objective = 191.0\n",
      "time = 11.47425889968872\n",
      "--------------------------------------\n",
      "Obj = 191.0\n",
      "Tree is Any[2, 0.5038414496665238, Any[1, 0.3764109223728217, Any[3, 0.5002382436756764, [0.0 1.0], [1.0 0.0]], Any[4, 0.44861905919052375, [1.0 0.0], [0.0 1.0]]], Any[7, 0.27293364983770185, Any[3, 0.5088920709046933, [0.0 1.0], [1.0 0.0]], Any[3, 0.4851318208290064, [0.0 1.0], [1.0 0.0]]]]\n",
      "total time = 300.590078830719\n",
      "opt evl0.9397905759162304"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"rice\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 536385600\n",
      "Total number of intervals = 16384\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 282294811\n",
      "Total number of intervals = 67927\n",
      "Current objective = 97.0\n",
      "time = 10.774322986602783\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 14886689\n",
      "Total number of intervals = 54718\n",
      "Current objective = 84.0\n",
      "time = 50.902772188186646\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 170522\n",
      "Total number of intervals = 8118\n",
      "Current objective = 66.0\n",
      "time = 60.09825897216797\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 1897\n",
      "Total number of intervals = 184\n",
      "Current objective = 62.0\n",
      "time = 11.761933088302612\n",
      "--------------------------------------\n",
      "Obj = 62.0\n",
      "Tree is Any[6, 0.43218410852713185, Any[5, 0.3666933333333333, Any[14, 0.4791079053370805, [0.0 0.0 0.0 1.0], [0.0 0.0 1.0 0.0]], Any[7, 0.4446539285714285, [1.0 0.0 0.0 0.0], [0.0 1.0 0.0 0.0]]], Any[3, 0.8027965714285701, Any[7, 0.6410432142857143, [0.0 1.0 0.0 0.0], [0.0 0.0 1.0 0.0]], Any[1, 0.7568930555555554, [0.0 1.0 0.0 0.0], [0.0 0.0 1.0 0.0]]]]\n",
      "total time = 134.10303616523743\n",
      "opt evl0.9861932938856016"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"room\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 12444739848\n",
      "Total number of intervals = 23328\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 8444881666\n",
      "Total number of intervals = 127774\n",
      "Current objective = 226.0\n",
      "time = 7.064289808273315\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 410246229\n",
      "Total number of intervals = 52545\n",
      "Current objective = 226.0\n",
      "time = 43.95894813537598\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 7264341\n",
      "Total number of intervals = 10711\n",
      "Current objective = 224.0\n",
      "time = 23.477982997894287\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 916707\n",
      "Total number of intervals = 13045\n",
      "Current objective = 214.0\n",
      "time = 5.49505090713501\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 107440\n",
      "Total number of intervals = 13918\n",
      "Current objective = 209.0\n",
      "time = 7.195729970932007\n",
      "--------------------------------------\n",
      "Obj = 208.0\n",
      "Tree is Any[12, 0.15556382173308425, Any[18, 0.19890776330778234, Any[18, 0.14785660397852451, [0.0 1.0 0.0 0.0 0.0 0.0 0.0], [0.0 0.0 0.0 0.0 0.0 0.0 1.0]], Any[2, 0.5687362499999999, [0.0 0.0 0.0 0.0 1.0 0.0 0.0], [0.0 0.0 0.0 1.0 0.0 0.0 0.0]]], Any[2, 0.5978970833333332, Any[9, 0.5317519688396428, [0.0 0.0 0.0 0.0 0.0 1.0 0.0], [0.0 0.0 1.0 0.0 0.0 0.0 0.0]], Any[14, 0.21417482679865393, [0.0 0.0 0.0 1.0 0.0 0.0 0.0], [1.0 0.0 0.0 0.0 0.0 0.0 0.0]]]]\n",
      "total time = 95.8139967918396\n",
      "opt evl0.8275862068965517"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"segment\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 7077888\n",
      "Total number of intervals = 108\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 5009320\n",
      "Total number of intervals = 687\n",
      "Current objective = 6683.006683\n",
      "time = 2.6158978939056396\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 1705850\n",
      "Total number of intervals = 2103\n",
      "Current objective = 6683.006683\n",
      "time = 24.051903009414673\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 258661\n",
      "Total number of intervals = 2842\n",
      "Current objective = 6496.0\n",
      "time = 71.68850898742676\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 11309\n",
      "Total number of intervals = 1079\n",
      "Current objective = 6183.0\n",
      "time = 133.1830849647522\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 10\n",
      "Total number of intervals = 10\n",
      "Current objective = 6167.0\n",
      "time = 69.74823904037476\n",
      "--------------------------------------\n",
      "Obj = 6167.0\n",
      "Tree is Any[1, 0.48823764705882355, Any[3, 0.5705741176470588, Any[3, 0.4686337254901961, [0.0 1.0], [1.0 0.0]], Any[2, 0.36277254901960787, [0.0 1.0], [1.0 0.0]]], Any[1, 0.8215043137254902, Any[3, 0.8254250980392157, [0.0 1.0], [1.0 0.0]], Any[1, 0.8254250980392157, [0.0 1.0], [0.0 1.0]]]]\n",
      "total time = 301.287633895874\n",
      "opt evl0.9679683355775901"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"skin\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trees = 8269431120\n",
      "Total number of intervals = 500\n",
      "--------------------------------------\n",
      "Loop 1\n",
      "Number of remaining trees = 6100918434\n",
      "Total number of intervals = 3339\n",
      "Current objective = 28.0\n",
      "time = 0.2998080253601074\n",
      "--------------------------------------\n",
      "Loop 2\n",
      "Number of remaining trees = 1856836080\n",
      "Total number of intervals = 9376\n",
      "Current objective = 24.0\n",
      "time = 2.3733630180358887\n",
      "--------------------------------------\n",
      "Loop 3\n",
      "Number of remaining trees = 120529549\n",
      "Total number of intervals = 5817\n",
      "Current objective = 22.0\n",
      "time = 6.406371831893921\n",
      "--------------------------------------\n",
      "Loop 4\n",
      "Number of remaining trees = 6576905\n",
      "Total number of intervals = 2984\n",
      "Current objective = 20.0\n",
      "time = 5.624925136566162\n",
      "--------------------------------------\n",
      "Loop 5\n",
      "Number of remaining trees = 411154\n",
      "Total number of intervals = 1676\n",
      "Current objective = 19.0\n",
      "time = 4.012331962585449\n",
      "--------------------------------------\n",
      "Loop 6\n",
      "Number of remaining trees = 54458\n",
      "Total number of intervals = 2010\n",
      "Current objective = 19.0\n",
      "time = 2.6736040115356445\n",
      "--------------------------------------\n",
      "Loop 7\n",
      "Number of remaining trees = 35662\n",
      "Total number of intervals = 11866\n",
      "Current objective = 18.0\n",
      "time = 4.204135179519653\n",
      "--------------------------------------\n",
      "Obj = 18.0\n",
      "Tree is Any[3, 0.04856618949450736, Any[3, 0.03747631092286498, Any[2, 0.027491453606214036, [1.0 0.0], [0.0 1.0]], Any[2, 0.05385647839607505, [1.0 0.0], [0.0 1.0]]], Any[4, 0.2296376600258726, Any[2, 0.060662268608250365, [1.0 0.0], [0.0 1.0]], Any[2, 0.06899101128405441, [1.0 0.0], [0.0 1.0]]]]\n",
      "total time = 45.36223316192627\n",
      "opt evl0.8127490039840637"
     ]
    }
   ],
   "source": [
    "# Test depth-3 trees on a classification problem: fit a greedy baseline and QuantBnB-3D on the\n",
    "# training split, then report the optimal tree's accuracy on the second half of the test split.\n",
    "dataset_name = \"wilt\"\n",
    "tree_depth = 3\n",
    "time_limit = 300  # presumably QuantBnB's time budget in seconds (logged runs stop near 300 s)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = generate_realdata(string(\"./dataset/class/\", dataset_name, \".json\"))\n",
    "n_train, m = size(Y_train)\n",
    "n_test, _ = size(Y_test)\n",
    "# Evaluation set = second half of the test split. Starting at n_test÷2+1 keeps it disjoint from the\n",
    "# first half 1:n_test÷2 and in bounds when n_test == 1.\n",
    "X_eval, Y_eval = X_test[(n_test÷2 + 1):end, :], Y_test[(n_test÷2 + 1):end, :]\n",
    "n_eval, _ = size(Y_eval)\n",
    "\n",
    "gre_train, gre_tree = greedy_tree(X_train, Y_train, tree_depth, \"C\")\n",
    "# The greedy objective, relaxed by a 1e-6 relative slack, is presumably the initial upper bound of the search.\n",
    "# NOTE(review): the other positional arguments are defined in QuantBnB-3D.jl; confirm before changing.\n",
    "opt_train, opt_tree = QuantBnB_3D(X_train, Y_train, 3, 3, gre_train*(1+1e-6), 0, 0, nothing, \"C\", time_limit)\n",
    "\n",
    "# Assumes rows of Y are one-hot labels (the fitted leaves are one-hot): a misclassified row has exactly\n",
    "# one positive entry in Y_eval - prediction, so counting positive entries counts errors.\n",
    "opt_eval = sum((Y_eval - tree_eval(opt_tree, X_eval, tree_depth, m)) .> 0)\n",
    "# gre_eval = sum((Y_eval - tree_eval(gre_tree, X_eval, tree_depth, m)) .> 0)  # greedy baseline, same split\n",
    "println(\"opt eval accuracy = \", 1 - opt_eval/n_eval)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.8.1",
   "language": "julia",
   "name": "julia-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.1"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
