{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "44d1bdad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd  \n",
    "from sklearn.datasets import load_diabetes\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7cb24f38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "col_names: Index(['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6'], dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Load sklearn's diabetes regression dataset and wrap its feature matrix\n",
    "# in a DataFrame whose columns are the dataset's feature names.\n",
    "diab_dataset = load_diabetes()\n",
    "diab = pd.DataFrame(\n",
    "    diab_dataset.data,\n",
    "    columns=diab_dataset.feature_names,\n",
    ")\n",
    "col_names = diab.columns\n",
    "print('col_names:', col_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c7553f9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(442, 10)\n",
      "[151.  75. 141. 206. 135.  97. 138.  63. 110. 310.]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([38., 80., 68., 62., 50., 41., 38., 42., 17.,  6.]),\n",
       " array([ 25. ,  57.1,  89.2, 121.3, 153.4, 185.5, 217.6, 249.7, 281.8,\n",
       "        313.9, 346. ]),\n",
       " <BarContainer object of 10 artists>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAARM0lEQVR4nO3df6zddX3H8edrLT/8NaFw1zRgduskGmJmZXcMozEZFYdssV1CCGbZmqVJk003nVtmncnUZEtk2WQuMZpOnN3mEERIicmcXcWYJVv1AgUKyFqRKk1pr0r9tURF3/vjfAuXy7nc03vPj37o85GcnO+vw/d1vxxefM/3nO/3m6pCktSen5t0AEnS8ljgktQoC1ySGmWBS1KjLHBJatTqca7s/PPPr+np6XGuUpKad+edd36rqqYWTh9rgU9PTzM7OzvOVUpS85Ic6jfdQyiS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpUQMVeJI/SXJ/kv1JbkxydpL1SfYmOZjkpiRnjjqsJOkpSxZ4kguAPwZmquqVwCrgWuA64PqqehnwOLB1lEElSU836CGU1cDzkqwGng8cAS4Hbunm7wQ2Dz2dJGlRSxZ4VR0G/hb4Br3i/i5wJ3C8qp7oFnsUuKDf65NsSzKbZHZubm44qcdoenqaJBN5eNkBSc9myVPpk5wLbALWA8eBTwNXDrqCqtoB7ACYmZlp7vY/hw4dYlJ3LUoykfVKasMgh1DeAHy9quaq6ifArcBrgXO6QyoAFwKHR5RRktTHIAX+DeCyJM9Pb5dwI/AAcAdwdbfMFmDXaCJKkvoZ5Bj4XnpfVt4F3Ne9ZgfwLuCdSQ4C5wE3jDCnJGmBgS4nW1XvBd67YPLDwKVDTyRJGohnYkpSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGrVkgSd5eZJ98x7fS/KOJGuS7E5yoHs+dxyBJUk9g9xS7aGq2lBVG4BfAf4PuA3YDuypqouAPd24JGlMTvYQykbga1V1CNgE7Oym7wQ2DzGXJGkJJ1vg1wI3dsNrq+pIN/wYsHZoqSRJSxq4wJOcCbwZ+PTCeVVVQC3yum1JZpPMzs3NLTuoJOnpTmYP/E3AXVV1tBs/mmQdQPd8rN+LqmpHVc1U1czU1NTK0kqSnnQyBf4Wnjp8AnA7sKUb3gLsGlYoSdLSBirwJC8ArgBunTf5A8AVSQ4Ab+jGJUljsnqQharqh8B5C6Z9m96vUiRJE+CZmKe4JGN/TE9PT/rPljSAgfbANTm9H/iMV5Kxr1PSyXMPXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuPqaxGVsvZStdHK8nKz6msRlbMFL2UonY9Bbqp2T5JYkX03yYJLXJFmTZHeSA93zuaMOK0l6yqCHUD4EfK6qXgG8CngQ2A7sqaqLgD3duCRpTJYs8CQvBl4P3ABQVT+uquPAJmBnt9hOYPNoIkqS+hlkD3w9MAf8U5K7k3ysu0v92qo60i3zGLC234uTbEsym2R2bm5uOKklSQMV+GrgEuAjVfVq4IcsOFxSvW+8+n7rVVU7qmqmqmampqZWmleS1BmkwB8FHq2qvd34LfQK/WiSdQDd87HRRJQk9bNkgVfVY8A3k7y8m7QReAC4HdjSTdsC7BpJQklSX4P+DvyPgE8mORN4GPh9euV/c5KtwCHgmtFElCT1M1CBV9U+YKbPrI1DTSNJGpin0ktSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjBrojT5JHgO8DPwWeqKqZJGuAm4Bp4BHgmqp6fDQxJUkLncwe
+K9X1YaqOnFrte3Anqq6CNjTjUuSxmQlh1A2ATu74Z3A5hWnkSQNbNACL+DzSe5Msq2btraqjnTDjwFr+70wybYks0lm5+bmVhhXp4MkY39MT09P+s+WTtpAx8CB11XV4SS/AOxO8tX5M6uqklS/F1bVDmAHwMzMTN9lpPmqxv82STL2dUorNdAeeFUd7p6PAbcBlwJHk6wD6J6PjSqkJOmZlizwJC9I8qITw8Abgf3A7cCWbrEtwK5RhZQkPdMgh1DWArd1HzFXA/9WVZ9L8hXg5iRbgUPANaOLKUlaaMkCr6qHgVf1mf5tYOMoQkmSluaZmJLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApdOU9PT0xO5gbQ3kR6eQW9qLOk55tChQxO5gTR4E+lhGXgPPMmqJHcn+Ww3vj7J3iQHk9yU5MzRxZQkLXQyh1DeDjw4b/w64PqqehnwOLB1mMGkcfNwglozUIEnuRD4TeBj3XiAy4FbukV2AptHkE8am6qayOPQoUOT/tPVqEH3wP8e+HPgZ934ecDxqnqiG38UuKDfC5NsSzKbZHZubm4lWSVJ8yxZ4El+CzhWVXcuZwVVtaOqZqpqZmpqajn/CElSH4P8CuW1wJuTXAWcDfw88CHgnCSru73wC4HDo4spSVpoyT3wqnp3VV1YVdPAtcAXqup3gDuAq7vFtgC7RpZSkvQMKzmR513AO5McpHdM/IbhRJIkDeKkTuSpqi8CX+yGHwYuHX6k/qanp/22Xs9Zntii5WjmTMxJnTXmf1gaB9/bWg6vhSJJjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJatQgd6U/O8mXk9yT5P4k7++mr0+yN8nBJDclOXP0cSVJJwyyB/4j4PKqehWwAbgyyWXAdcD1VfUy4HFg68hSSpKeYZC70ldV/aAbPaN7FHA5cEs3fSeweRQBJUn9DXQMPMmqJPuAY8Bu4GvA8ap6olvkUeCCRV67Lclsktm5ubkhRJYkwYAFXlU/raoNwIX07kT/ikFXUFU7qmqmqmampqaWl1KS9Awn9SuUqjoO3AG8BjgnyYm72l8IHB5uNEnSsxnkVyhTSc7php8HXAE8SK/Ir+4W2wLsGlFGSVIfq5dehHXAziSr6BX+zVX12SQPAJ9K8lfA3cANI8wpSVpgyQKvqnuBV/eZ/jC94+GSpAnwTExJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElq1CD3xHxJkjuSPJDk/iRv76avSbI7yYHu+dzRx5UknTDIHvgTwJ9W1cXAZcBbk1wMbAf2VNVFwJ5uXJI0JksWeFUdqaq7uuHv07sj/QXAJmBnt9hOYPOIMkqS+jipY+BJpund4HgvsLaqjnSzHgPWLvKabUlmk8zOzc2tJKskaZ6BCzzJC4HPAO+oqu/Nn1dVBVS/11XVjqqaqaqZqampFYWVJD1loAJPcga98v5kVd3aTT6aZF03fx1wbDQRJUn9DPIrlAA3AA9W1Qfnzbod2NINbwF2DT+eJGkxqwdY5rXA7wL3JdnXTfsL4APAzUm2AoeAa0aSUJLU15IFXlX/BWSR2RuHG0eSNCjPxJSkRlngktQoC1zSRCQZ+2N6enrSf/ZQDfIlpiQNXe/0kfHq/ajuucM9cElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEYNck/Mjyc5lmT/vGlrkuxOcqB7Pne0MSVJCw2yB/4J4MoF07YDe6rqImBPNy5JGqMlC7yqvgR8Z8HkTcDObngnsHm4sSRJ
S1nuMfC1VXWkG34MWLvYgkm2JZlNMjs3N7fM1UmSFlrxl5jVu63GorfWqKodVTVTVTNTU1MrXZ0kqbPcAj+aZB1A93xseJEkSYNYboHfDmzphrcAu4YTR5I0qEF+Rngj8N/Ay5M8mmQr8AHgiiQHgDd045KkMVryrvRV9ZZFZm0cchZJ0knwTExJapQFLkmNssAlqVEWuCQ1ygKXdFpJMvbH9PT0SP6WJX+FIknPJb2Tx8cryUj+ue6BS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjVpRgSe5MslDSQ4m2T6sUJKkpS27wJOsAj4MvAm4GHhLkouHFUyS9OxWsgd+KXCwqh6uqh8DnwI2DSeWJGkpK7mc7AXAN+eNPwr82sKFkmwDtnWjP0jy0HJXOKpLMnbOB741gfU+qwHXvWj2Ea93WJ6Wf1Lbe5nrHcq2n9DffH6Sob5vTsYQ/uZlbftT5P11stl/sd/EkV8PvKp2ADtGvZ6VSjJbVTOTzrEcLWeHtvObfXJazj+s7Cs5hHIYeMm88Qu7aZKkMVhJgX8FuCjJ+iRnAtcCtw8nliRpKcs+hFJVTyR5G/AfwCrg41V1/9CSjd8pf5jnWbScHdrOb/bJaTn/ULJnEveHkyStnGdiSlKjLHBJatRpWeBJHklyX5J9SWa7aWuS7E5yoHs+d9I5T0jy8STHkuyfN61v3vT8Q3d5g3uTXDK55Itmf1+Sw93235fkqnnz3t1lfyjJb0wm9ZNZXpLkjiQPJLk/ydu76a1s+8Xyn/LbP8nZSb6c5J4u+/u76euT7O0y3tT9gIIkZ3XjB7v506dg9k8k+fq87b6hm778901VnXYP4BHg/AXT/gbY3g1vB66bdM552V4PXALsXyovcBXw70CAy4C9p2D29wF/1mfZi4F7gLOA9cDXgFUTzL4OuKQbfhHwv13GVrb9YvlP+e3fbcMXdsNnAHu7bXozcG03/aPAH3TDfwh8tBu+Frhpgtt9seyfAK7us/yy3zen5R74IjYBO7vhncDmyUV5uqr6EvCdBZMXy7sJ+Ofq+R/gnCTrxhK0j0WyL2YT8Kmq+lFVfR04SO+SDRNRVUeq6q5u+PvAg/TOQG5l2y+WfzGnzPbvtuEPutEzukcBlwO3dNMXbvsT/05uATZmQqdcPkv2xSz7fXO6FngBn09yZ3qn+gOsraoj3fBjwNrJRBvYYnn7XeLg2f6jnZS3dR8XPz7vcNUpm737SP5qentTzW37Bfmhge2fZFWSfcAxYDe9TwTHq+qJbpH5+Z7M3s3/LnDeWAPPszB7VZ3Y7n/dbffrk5zVTVv2dj9dC/x1VXUJvSspvjXJ6+fPrN7nmmZ+X9laXuAjwC8BG4AjwN9NNM0SkrwQ+Azwjqr63vx5LWz7Pvmb2P5V9dOq2kDvLO9LgVdMNtHgFmZP8krg3fT+hl8F1gDvWul6TssCr6rD3fMx4DZ6b46jJz62dM/HJpdwIIvlPeUvcVBVR7s3+M+Af+Spj+mnXPYkZ9Arv09W1a3d5Ga2fb/8LW1/gKo6DtwBvIbe4YUTJyDOz/dk9m7+i4FvjzfpM83LfmV3SKuq6kfAPzGE7X7aFXiSFyR50Ylh4I3AfnqXAdjSLbYF2DWZhANbLO/twO9132xfBnx33sf9U8KC43u/TW/7Qy/7td0vCtYDFwFfHne+E7pjqDcAD1bVB+fNamLbL5a/he2fZCrJOd3w84Ar6B3DvwO4ults4bY/8e/kauAL3aejsVsk+1fn/U8/9I7dz9/uy3vfTOqb2kk9gJfS+6b9HuB+4D3d9POAPcAB4D+BNZPOOi/zjfQ+6v6E3vGxrYvlpfdN9ofpHS+8D5g5BbP/S5ft3u7Nu27e8u/psj8EvGnC2V9H7/DIvcC+7nFVQ9t+sfyn/PYHfhm4u8u4H/jLbvpL6f1P5SDwaeCsbvrZ3fjBbv5LT8HsX+i2+37gX3nqlyrL
ft94Kr0kNeq0O4QiSc8VFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElq1P8D1x/TceWR0PEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Feature matrix X and regression target Y.\n",
    "# copy=True: to_numpy() may hand back a view of diab's own data (read-only under\n",
    "# pandas Copy-on-Write), and the following cell rescales X in place. An explicit\n",
    "# copy keeps diab untouched and guarantees X is writable.\n",
    "X = diab.to_numpy(copy=True)\n",
    "Y = diab_dataset.target\n",
    "print(X.shape)\n",
    "print(Y[:10])\n",
    "plt.figure()\n",
    "plt.hist(Y, fill=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "89461a65",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max: 0.111, min: -0.107->  (1.0  , 0.0  )\n",
      "max: 0.051, min: -0.045->  (1.0  , 0.0  )\n",
      "max: 0.171, min: -0.09->  (1.0  , 0.0  )\n",
      "max: 0.132, min: -0.112->  (1.0  , 0.0  )\n",
      "max: 0.154, min: -0.127->  (1.0  , 0.0  )\n",
      "max: 0.199, min: -0.116->  (1.0  , 0.0  )\n",
      "max: 0.181, min: -0.102->  (1.0  , 0.0  )\n",
      "max: 0.185, min: -0.076->  (1.0  , 0.0  )\n",
      "max: 0.134, min: -0.126->  (1.0  , 0.0  )\n",
      "max: 0.136, min: -0.138->  (1.0  , 0.0  )\n"
     ]
    }
   ],
   "source": [
    "# Min-max scale each feature column of X to [0, 1] in place, printing the\n",
    "# column's range before and after scaling.\n",
    "# NOTE(review): min/max are computed over ALL rows of X, including rows that are\n",
    "# later used for testing, so test-set information leaks into the scaling.\n",
    "for j in range(X.shape[1]):\n",
    "    x_max, x_min = np.max(X[:,j]), np.min(X[:,j])\n",
    "    print('max: %-5s, min: %-5s'%(np.round(x_max,3), np.round(x_min,3)), end = '->')\n",
    "    span = x_max - x_min\n",
    "    # A constant column would otherwise compute 0/0 and become all NaN.\n",
    "    X[:,j] = (X[:,j]-x_min)/span if span > 0 else 0.0\n",
    "    x_max, x_min = np.max(X[:,j]), np.min(X[:,j])\n",
    "    print('  (%-5s, %-5s)'%(np.round(x_max,3), np.round(x_min,3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a4eb0b3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequential (unshuffled) split: the first n_train rows train, the rest test.\n",
    "# Alternative value: n_train = int(len(X)/2) for a half/half split.\n",
    "n_train = 100\n",
    "X_train, X_test = X[:n_train, :], X[n_train:, :]\n",
    "Y_train, Y_test = Y[:n_train], Y[n_train:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "30647ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "MSE = {}  # model name -> test-set error, filled in by the regression cells\n",
    "def compute_mse(y_pred,y0):\n",
    "    \"\"\"Root-mean-squared error between predictions and targets.\n",
    "\n",
    "    Despite the name, this returns sqrt(mean((y_pred - y0)**2)), i.e. the RMSE,\n",
    "    which is in the same units as the target.\n",
    "\n",
    "    Args:\n",
    "        y_pred: predicted values (array-like).\n",
    "        y0: ground-truth values (array-like, broadcast-compatible with y_pred).\n",
    "\n",
    "    Returns:\n",
    "        The RMSE as a numpy float.\n",
    "    \"\"\"\n",
    "    # asarray lets plain lists work too (list - list would raise TypeError).\n",
    "    y_pred, y0 = np.asarray(y_pred), np.asarray(y0)\n",
    "    return np.mean((y_pred-y0)**2)**0.5\n",
    "\n",
    "# Sanity check: sqrt(((1-0)**2 + (2-2)**2 + (3-5)**2) / 3) == sqrt(5/3)\n",
    "assert np.isclose(compute_mse(np.array([1,2,3]),np.array([0,2,5])), (5/3)**0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "09da9dc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear models...\n",
      "\n",
      "Support vector models...\n",
      "\n",
      "Decision Tree models...\n",
      "\n",
      "KNeighbors models...\n",
      "70.4799 n:2\n",
      "68.5511 n:3\n",
      "68.5127 n:4\n",
      "66.9468 n:5\n",
      "\n",
      "MLP models...\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LinearRegression, Ridge,  Lasso\n",
    "def do_regress(regression, X_train, Y_train, MSE, name, X_eval=None, Y_eval=None):\n",
    "    \"\"\"Fit `regression` on the training data and record its evaluation error.\n",
    "\n",
    "    Args:\n",
    "        regression: an unfitted scikit-learn style regressor (fit/predict).\n",
    "        X_train, Y_train: training features and targets.\n",
    "        MSE: dict the error is written into (mutated in place).\n",
    "        name: key to store the error under; None skips recording.\n",
    "        X_eval, Y_eval: evaluation data; default to the notebook's X_test / Y_test.\n",
    "\n",
    "    Returns:\n",
    "        The same MSE dict, with MSE[name] = compute_mse(predictions, Y_eval)\n",
    "        when name is not None.\n",
    "    \"\"\"\n",
    "    # Fall back to the notebook-level test split so existing calls behave as before.\n",
    "    if X_eval is None:\n",
    "        X_eval = X_test\n",
    "    if Y_eval is None:\n",
    "        Y_eval = Y_test\n",
    "    model = regression.fit(X_train, Y_train)\n",
    "    if name is not None:\n",
    "        MSE[name] = compute_mse(model.predict(X_eval), Y_eval)\n",
    "    return MSE\n",
    "\n",
    "print('linear models...')\n",
    "MSE = do_regress( LinearRegression(), X_train, Y_train, MSE, 'linear')\n",
    "MSE = do_regress(Ridge(), X_train, Y_train, MSE,'Ridge')\n",
    "MSE = do_regress(Lasso(), X_train, Y_train, MSE, 'Lasso')\n",
    "\n",
    "print('\\nSupport vector models...')\n",
    "from sklearn.svm import LinearSVR, NuSVR, SVR\n",
    "# LinearSVR's random_state controls data shuffling in its solver; pin it so reruns match.\n",
    "MSE = do_regress(LinearSVR(random_state=0), X_train, Y_train, MSE,'LinSVR')\n",
    "MSE = do_regress(NuSVR(), X_train, Y_train, MSE,'NuSVR')\n",
    "MSE = do_regress(SVR(), X_train, Y_train, MSE,'SVR')\n",
    "\n",
    "print('\\nDecision Tree models...')\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "MSE = do_regress(DecisionTreeRegressor(random_state=0), X_train, Y_train, MSE,'DecTree')\n",
    "\n",
    "print('\\nKNeighbors models...')\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "# Sweep n_neighbors = 2..16, printing each new best error, and keep the best one.\n",
    "# NOTE(review): n is chosen on the test set itself, so MSE['kneigh'] is an\n",
    "# optimistic estimate compared with the other (untuned) models.\n",
    "min_mse,n_min = np.inf,-1\n",
    "for n in range(2,16+1):\n",
    "    temp_name = 'KNeighborsRegressor(n_neighbors=%s)'%(str(n))\n",
    "    dummy = do_regress(KNeighborsRegressor(n_neighbors=n), X_train, Y_train, {}, temp_name)\n",
    "    this_mse = dummy[temp_name]\n",
    "    if this_mse<min_mse:\n",
    "        print(np.round(this_mse,4), 'n:%s'%(str(n)))\n",
    "        min_mse = this_mse\n",
    "        n_min = n\n",
    "MSE['kneigh'] = min_mse\n",
    "\n",
    "print('\\nMLP models...')\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "MSE = do_regress(MLPRegressor(hidden_layer_sizes=(64,64),random_state=1, max_iter=12000), X_train, Y_train, MSE, 'MLP')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7d6d51df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Exiting layer_k_sample_collection() because all data have been used.\n",
      "Final positions of indices in the layers:\n",
      "  [1] [0, 1, 3, 5, 7, 11, 15, 20, 23, 25, 28, 32, 35, 38, 40, 43, 48, 53, 58, 62, 72, 76, 80, 85, 91, 99]\n",
      "  [2] [2, 4, 9, 12, 16, 18, 21, 26, 31, 37, 42, 52, 59, 70, 73, 82, 47, 49]\n",
      "  [3] [17, 27, 34, 44, 55, 57, 61, 65, 69, 79, 83, 86, 93, 6, 14, 22, 67]\n",
      "  [4] [50, 54, 68, 75, 89, 94, 24, 10, 74, 88, 19]\n",
      "  [5] [60, 63, 77, 90, 97, 29, 45, 56]\n",
      "  [6] [64, 71, 81, 95, 13, 46, 36, 92]\n",
      "  [7] [96, 39, 30, 41, 8]\n",
      "  [8] [33, 51, 78]\n",
      "  [9] [84, 98]\n",
      "  [10] [87]\n",
      "  [11] [66]\n",
      " i     Layer     |y-y0|                           abs error\n",
      "[0]    L=1         |(151.0) - (151.0))|                 0.0      HIT\n",
      "[1]    L=1         |(75.0 ) - (75.0 ))|                 0.0      HIT\n",
      "[2]    L=2         |(141.0) - (141.0))|                 0.0      HIT\n",
      "[3]    L=1         |(206.0) - (206.0))|                 0.0      HIT\n",
      "[4]    L=2         |(135.0) - (135.0))|                 0.0      HIT\n",
      "[5]    L=1         |(97.0 ) - (97.0 ))|                 0.0      HIT\n",
      "[6]    L=3         |(138.0) - (138.0))|                 0.0      HIT\n",
      "[7]    L=1         |(63.0 ) - (63.0 ))|                 0.0      HIT\n",
      "[8]    L=7         |(110.0) - (110.0))|                 0.0      HIT\n",
      "[9]    L=2         |(310.0) - (310.0))|                 0.0      HIT\n",
      "[10]   L=4         |(101.0) - (101.0))|                 0.0      HIT\n",
      "[11]   L=1         |(69.0 ) - (69.0 ))|                 0.0      HIT\n",
      "[12]   L=2         |(179.0) - (179.0))|                 0.0      HIT\n",
      "[13]   L=6         |(185.0) - (185.0))|                 0.0      HIT\n",
      "[14]   L=3         |(118.0) - (118.0))|                 0.0      HIT\n",
      "[15]   L=1         |(171.0) - (171.0))|                 0.0      HIT\n",
      "[16]   L=2         |(166.0) - (166.0))|                 0.0      HIT\n",
      "[17]   L=3         |(144.0) - (144.0))|                 0.0      HIT\n",
      "[18]   L=2         |(97.0 ) - (97.0 ))|                 0.0      HIT\n",
      "[19]   L=4         |(168.0) - (168.0))|                 0.0      HIT\n",
      "[20]   L=1         |(68.0 ) - (68.0 ))|                 0.0      HIT\n",
      "[21]   L=2         |(49.0 ) - (49.0 ))|                 0.0      HIT\n",
      "[22]   L=3         |(68.0 ) - (68.0 ))|                 0.0      HIT\n",
      "[23]   L=1         |(245.0) - (245.0))|                 0.0      HIT\n",
      "[24]   L=4         |(184.0) - (184.0))|                 0.0      HIT\n",
      "[25]   L=1         |(202.0) - (202.0))|                 0.0      HIT\n",
      "[26]   L=2         |(137.0) - (137.0))|                 0.0      HIT\n",
      "[27]   L=3         |(85.0 ) - (85.0 ))|                 0.0      HIT\n",
      "[28]   L=1         |(131.0) - (131.0))|                 0.0      HIT\n",
      "[29]   L=5         |(283.0) - (283.0))|                 0.0      HIT\n",
      "[30]   L=7         |(129.0) - (129.0))|                 0.0      HIT\n",
      "[31]   L=2         |(59.0 ) - (59.0 ))|                 0.0      HIT\n",
      "[32]   L=1         |(341.0) - (341.0))|                 0.0      HIT\n",
      "[33]   L=8         |(87.0 ) - (87.0 ))|                 0.0      HIT\n",
      "[34]   L=3         |(65.0 ) - (65.0 ))|                 0.0      HIT\n",
      "[35]   L=1         |(102.0) - (102.0))|                 0.0      HIT\n",
      "[36]   L=6         |(265.0) - (265.0))|                 0.0      HIT\n",
      "[37]   L=2         |(276.0) - (276.0))|                 0.0      HIT\n",
      "[38]   L=1         |(252.0) - (252.0))|                 0.0      HIT\n",
      "[39]   L=7         |(90.0 ) - (90.0 ))|                 0.0      HIT\n",
      "[40]   L=1         |(100.0) - (100.0))|                 0.0      HIT\n",
      "[41]   L=7         |(55.0 ) - (55.0 ))|                 0.0      HIT\n",
      "[42]   L=2         |(61.0 ) - (61.0 ))|                 0.0      HIT\n",
      "[43]   L=1         |(92.0 ) - (92.0 ))|                 0.0      HIT\n",
      "[44]   L=3         |(259.0) - (259.0))|                 0.0      HIT\n",
      "[45]   L=5         |(53.0 ) - (53.0 ))|                 0.0      HIT\n",
      "[46]   L=6         |(190.0) - (190.0))|                 0.0      HIT\n",
      "[47]   L=2         |(142.0) - (142.0))|                 0.0      HIT\n",
      "[48]   L=1         |(75.0 ) - (75.0 ))|                 0.0      HIT\n",
      "[49]   L=2         |(142.0) - (142.0))|                 0.0      HIT\n",
      "[50]   L=4         |(155.0) - (155.0))|                 0.0      HIT\n",
      "[51]   L=8         |(225.0) - (225.0))|                 0.0      HIT\n",
      "[52]   L=2         |(59.0 ) - (59.0 ))|                 0.0      HIT\n",
      "[53]   L=1         |(104.0) - (104.0))|                 0.0      HIT\n",
      "[54]   L=4         |(182.0) - (182.0))|                 0.0      HIT\n",
      "[55]   L=3         |(128.0) - (128.0))|                 0.0      HIT\n",
      "[56]   L=5         |(52.0 ) - (52.0 ))|                 0.0      HIT\n",
      "[57]   L=3         |(37.0 ) - (37.0 ))|                 0.0      HIT\n",
      "[58]   L=1         |(170.0) - (170.0))|                 0.0      HIT\n",
      "[59]   L=2         |(170.0) - (170.0))|                 0.0      HIT\n",
      "[60]   L=5         |(61.0 ) - (61.0 ))|                 0.0      HIT\n",
      "[61]   L=3         |(144.0) - (144.0))|                 0.0      HIT\n",
      "[62]   L=1         |(52.0 ) - (52.0 ))|                 0.0      HIT\n",
      "[63]   L=5         |(128.0) - (128.0))|                 0.0      HIT\n",
      "[64]   L=6         |(71.0 ) - (71.0 ))|                 0.0      HIT\n",
      "[65]   L=3         |(163.0) - (163.0))|                 0.0      HIT\n",
      "[66]   L=11        |(150.0) - (150.0))|                 0.0      HIT\n",
      "[67]   L=3         |(97.0 ) - (97.0 ))|                 0.0      HIT\n",
      "[68]   L=4         |(160.0) - (160.0))|                 0.0      HIT\n",
      "[69]   L=3         |(178.0) - (178.0))|                 0.0      HIT\n",
      "[70]   L=2         |(48.0 ) - (48.0 ))|                 0.0      HIT\n",
      "[71]   L=6         |(270.0) - (270.0))|                 0.0      HIT\n",
      "[72]   L=1         |(202.0) - (202.0))|                 0.0      HIT\n",
      "[73]   L=2         |(111.0) - (111.0))|                 0.0      HIT\n",
      "[74]   L=4         |(85.0 ) - (85.0 ))|                 0.0      HIT\n",
      "[75]   L=4         |(42.0 ) - (42.0 ))|                 0.0      HIT\n",
      "[76]   L=1         |(170.0) - (170.0))|                 0.0      HIT\n",
      "[77]   L=5         |(200.0) - (200.0))|                 0.0      HIT\n",
      "[78]   L=8         |(252.0) - (252.0))|                 0.0      HIT\n",
      "[79]   L=3         |(113.0) - (113.0))|                 0.0      HIT\n",
      "[80]   L=1         |(143.0) - (143.0))|                 0.0      HIT\n",
      "[81]   L=6         |(51.0 ) - (51.0 ))|                 0.0      HIT\n",
      "[82]   L=2         |(52.0 ) - (52.0 ))|                 0.0      HIT\n",
      "[83]   L=3         |(210.0) - (210.0))|                 0.0      HIT\n",
      "[84]   L=9         |(65.0 ) - (65.0 ))|                 0.0      HIT\n",
      "[85]   L=1         |(141.0) - (141.0))|                 0.0      HIT\n",
      "[86]   L=3         |(55.0 ) - (55.0 ))|                 0.0      HIT\n",
      "[87]   L=10        |(134.0) - (134.0))|                 0.0      HIT\n",
      "[88]   L=4         |(42.0 ) - (42.0 ))|                 0.0      HIT\n",
      "[89]   L=4         |(111.0) - (111.0))|                 0.0      HIT\n",
      "[90]   L=5         |(98.0 ) - (98.0 ))|                 0.0      HIT\n",
      "[91]   L=1         |(164.0) - (164.0))|                 0.0      HIT\n",
      "[92]   L=6         |(48.0 ) - (48.0 ))|                 0.0      HIT\n",
      "[93]   L=3         |(96.0 ) - (96.0 ))|                 0.0      HIT\n",
      "[94]   L=4         |(90.0 ) - (90.0 ))|                 0.0      HIT\n",
      "[95]   L=6         |(162.0) - (162.0))|                 0.0      HIT\n",
      "[96]   L=7         |(150.0) - (150.0))|                 0.0      HIT\n",
      "[97]   L=5         |(279.0) - (279.0))|                 0.0      HIT\n",
      "[98]   L=9         |(92.0 ) - (92.0 ))|                 0.0      HIT\n",
      "[99]   L=1         |(83.0 ) - (83.0 ))|                 0.0      HIT\n",
      "N_INTERPOLATED:0, N_large_error (>0.1):0\n",
      "avg error          :     0.0, avg_frac_error          :     0.0 \n",
      "avg exclusive error:     0.0, avg exclusive frac error:     0.0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'N_INTERPOLATED': 0,\n",
       " 'mean_error': 0.0,\n",
       " 'mean_frac_error': 0.0,\n",
       " 'mean_ex_error': 0.0,\n",
       " 'mean_ex_frac_error': 0.0}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from SQANN.model import SQANN, double_selective_activation\n",
    "from SQANN.utils import make_layer_setting, simple_evaluation, standard_evaluation\n",
    "\n",
    "# Values shared by every layer's make_layer_setting call below.\n",
    "a1 = 0.01\n",
    "a2 = 0.5\n",
    "t_admission = 0.2\n",
    "t_threshold = 0.9\n",
    "\n",
    "MAX_LAYER = 16\n",
    "# make_layer_setting(a1, a2, admission_threshold, activation_threshold, max_n);\n",
    "# max_n is not passed, so it keeps its default. Layer k (1..MAX_LAYER) uses\n",
    "# admission threshold t_admission - 0.01*k, i.e. it drops by 0.01 per layer.\n",
    "layer_settings = {\n",
    "    k: make_layer_setting(a1, a2, t_admission - 0.01*k, t_threshold)\n",
    "    for k in range(1, MAX_LAYER + 1)\n",
    "}\n",
    "\n",
    "net = SQANN(layer_settings, N=len(X_train))\n",
    "net.fit_data(X_train, Y_train, verbose=20)\n",
    "\n",
    "# Show zero errors on training dataset! (Theorem 2 in the paper)\n",
    "standard_evaluation(X_train, Y_train, net, get_interp_indices=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3808b7fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " i     Layer     |y-y0|                           abs error\n",
      "[0]    L=8         |(225.0) - (128.0))|                97.0      HIT\n",
      "[1]    L=7         |(150.0) - (102.0))|                48.0      HIT\n",
      "[2]    L=5         |(128.0) - (302.0))|               174.0      HIT\n",
      "[3]    L=9         |(92.0 ) - (198.0))|               106.0      HIT\n",
      "[4]    L=[3, 6]    |(141.538) - (95.0 ))|             46.538      INTERPOLATE\n",
      "[5]    L=11        |(150.0) - (53.0 ))|                97.0      HIT\n",
      "[6]    L=[2, 5]    |(100.358) - (134.0))|             33.642      INTERPOLATE\n",
      "[7]    L=8         |(252.0) - (144.0))|               108.0      HIT\n",
      "[8]    L=8         |(225.0) - (232.0))|                 7.0      HIT\n",
      "[9]    L=7         |(90.0 ) - (81.0 ))|                 9.0      HIT\n",
      "[10]   L=9         |(92.0 ) - (104.0))|                12.0      HIT\n",
      "[11]   L=[9, 9]    |(78.58) - (59.0 ))|               19.58      INTERPOLATE\n",
      "[12]   L=5         |(52.0 ) - (246.0))|               194.0      HIT\n",
      "[13]   L=7         |(90.0 ) - (297.0))|               207.0      HIT\n",
      "[14]   L=5         |(283.0) - (258.0))|                25.0      HIT\n",
      "[15]   L=[8, 8]    |(238.501) - (229.0))|             9.501      INTERPOLATE\n",
      "[16]   L=8         |(225.0) - (275.0))|                50.0      HIT\n",
      "[17]   L=5         |(283.0) - (281.0))|                 2.0      HIT\n",
      "[18]   L=7         |(90.0 ) - (179.0))|                89.0      HIT\n",
      "[19]   L=11        |(150.0) - (200.0))|                50.0      HIT\n",
      "[20]   L=[7, 6]    |(78.726) - (200.0))|              121.274      INTERPOLATE\n",
      "[21]   L=11        |(150.0) - (173.0))|                23.0      HIT\n",
      "[22]   L=[3, 7]    |(111.448) - (180.0))|             68.552      INTERPOLATE\n",
      "[23]   L=5         |(283.0) - (84.0 ))|               199.0      HIT\n",
      "[24]   L=11        |(150.0) - (121.0))|                29.0      HIT\n",
      "[25]   L=[6, 7]    |(82.106) - (161.0))|              78.894      INTERPOLATE\n",
      "[26]   L=[4, 5]    |(80.297) - (99.0 ))|              18.703      INTERPOLATE\n",
      "[27]   L=8         |(252.0) - (109.0))|               143.0      HIT\n",
      "[28]   L=7         |(90.0 ) - (115.0))|                25.0      HIT\n",
      "[29]   L=[7, 2]    |(130.589) - (268.0))|             137.411      INTERPOLATE\n",
      "[30]   L=5         |(283.0) - (274.0))|                 9.0      HIT\n",
      "[31]   L=[2, 6]    |(168.007) - (158.0))|             10.007      INTERPOLATE\n",
      "[32]   L=7         |(90.0 ) - (107.0))|                17.0      HIT\n",
      "[33]   L=[5, 2]    |(165.784) - (83.0 ))|             82.784      INTERPOLATE\n",
      "[34]   L=[2, 7]    |(126.402) - (103.0))|             23.402      INTERPOLATE\n",
      "[35]   L=5         |(283.0) - (272.0))|                11.0      HIT\n",
      "[36]   L=2         |(137.0) - (85.0 ))|                52.0      HIT\n",
      "[37]   L=8         |(252.0) - (280.0))|                28.0      HIT\n",
      "[38]   L=[7, 5]    |(120.099) - (336.0))|             215.901      INTERPOLATE\n",
      "[39]   L=[2, 5]    |(235.671) - (281.0))|             45.329      INTERPOLATE\n",
      "[40]   L=6         |(185.0) - (118.0))|                67.0      HIT\n",
      "[41]   L=5         |(283.0) - (317.0))|                34.0      HIT\n",
      "[42]   L=5         |(128.0) - (235.0))|               107.0      HIT\n",
      "[43]   L=9         |(92.0 ) - (60.0 ))|                32.0      HIT\n",
      "[44]   L=7         |(90.0 ) - (174.0))|                84.0      HIT\n",
      "[45]   L=[2, 5]    |(232.46) - (259.0))|              26.54      INTERPOLATE\n",
      "[46]   L=6         |(51.0 ) - (178.0))|               127.0      HIT\n",
      "[47]   L=[7, 5]    |(119.397) - (128.0))|             8.603      INTERPOLATE\n",
      "[48]   L=[6, 2]    |(50.0 ) - (96.0 ))|                46.0      INTERPOLATE\n",
      "[49]   L=2         |(61.0 ) - (126.0))|                65.0      HIT\n",
      "[50]   L=10        |(134.0) - (288.0))|               154.0      HIT\n",
      "[51]   L=10        |(134.0) - (88.0 ))|                46.0      HIT\n",
      "[52]   L=[9, 9]    |(78.53) - (292.0))|               213.47      INTERPOLATE\n",
      "[53]   L=9         |(92.0 ) - (71.0 ))|                21.0      HIT\n",
      "[54]   L=3         |(259.0) - (197.0))|                62.0      HIT\n",
      "[55]   L=8         |(225.0) - (186.0))|                39.0      HIT\n",
      "[56]   L=4         |(168.0) - (25.0 ))|               143.0      HIT\n",
      "[57]   L=7         |(90.0 ) - (84.0 ))|                 6.0      HIT\n",
      "[58]   L=[3, 7]    |(77.792) - (96.0 ))|              18.208      INTERPOLATE\n",
      "[59]   L=8         |(252.0) - (195.0))|                57.0      HIT\n",
      "[60]   L=7         |(90.0 ) - (53.0 ))|                37.0      HIT\n",
      "[61]   L=6         |(71.0 ) - (217.0))|               146.0      HIT\n",
      "[62]   L=[3, 5]    |(75.557) - (172.0))|              96.443      INTERPOLATE\n",
      "[63]   L=5         |(128.0) - (131.0))|                 3.0      HIT\n",
      "[64]   L=[8, 8]    |(238.464) - (214.0))|             24.464      INTERPOLATE\n",
      "[65]   L=7         |(90.0 ) - (59.0 ))|                31.0      HIT\n",
      "[66]   L=[2, 5]    |(232.179) - (70.0 ))|             162.179      INTERPOLATE\n",
      "[67]   L=5         |(283.0) - (220.0))|                63.0      HIT\n",
      "[68]   L=4         |(85.0 ) - (268.0))|               183.0      HIT\n",
      "[69]   L=5         |(283.0) - (152.0))|               131.0      HIT\n",
      "[70]   L=[8, 8]    |(238.498) - (47.0 ))|             191.498      INTERPOLATE\n",
      "[71]   L=9         |(65.0 ) - (74.0 ))|                 9.0      HIT\n",
      "[72]   L=[11, 8]    |(118.187) - (295.0))|             176.813      INTERPOLATE\n",
      "[73]   L=[3, 5]    |(180.372) - (101.0))|             79.372      INTERPOLATE\n",
      "[74]   L=5         |(128.0) - (151.0))|                23.0      HIT\n",
      "[75]   L=7         |(90.0 ) - (127.0))|                37.0      HIT\n",
      "[76]   L=9         |(92.0 ) - (237.0))|               145.0      HIT\n",
      "[77]   L=[6, 11]    |(167.271) - (225.0))|             57.729      INTERPOLATE\n",
      "[78]   L=[3, 7]    |(102.252) - (81.0 ))|             21.252      INTERPOLATE\n",
      "[79]   L=5         |(128.0) - (151.0))|                23.0      HIT\n",
      "[80]   L=[10, 11]    |(142.107) - (107.0))|             35.107      INTERPOLATE\n",
      "[81]   L=11        |(150.0) - (64.0 ))|                86.0      HIT\n",
      "[82]   L=9         |(92.0 ) - (138.0))|                46.0      HIT\n",
      "[83]   L=4         |(85.0 ) - (185.0))|               100.0      HIT\n",
      "[84]   L=8         |(252.0) - (265.0))|                13.0      HIT\n",
      "[85]   L=9         |(92.0 ) - (101.0))|                 9.0      HIT\n",
      "[86]   L=7         |(90.0 ) - (137.0))|                47.0      HIT\n",
      "[87]   L=[7, 1]    |(93.58) - (143.0))|               49.42      INTERPOLATE\n",
      "[88]   L=8         |(225.0) - (141.0))|                84.0      HIT\n",
      "[89]   L=[7, 5]    |(119.918) - (79.0 ))|             40.918      INTERPOLATE\n",
      "[90]   L=5         |(128.0) - (292.0))|               164.0      HIT\n",
      "[91]   L=[2, 5]    |(233.011) - (178.0))|             55.011      INTERPOLATE\n",
      "[92]   L=8         |(225.0) - (91.0 ))|               134.0      HIT\n",
      "[93]   L=7         |(90.0 ) - (116.0))|                26.0      HIT\n",
      "[94]   L=9         |(65.0 ) - (86.0 ))|                21.0      HIT\n",
      "[95]   L=7         |(90.0 ) - (122.0))|                32.0      HIT\n",
      "[96]   L=5         |(128.0) - (72.0 ))|                56.0      HIT\n",
      "[97]   L=8         |(225.0) - (129.0))|                96.0      HIT\n",
      "[98]   L=8         |(252.0) - (142.0))|               110.0      HIT\n",
      "[99]   L=6         |(270.0) - (90.0 ))|               180.0      HIT\n",
      "[100]  L=8         |(252.0) - (158.0))|                94.0      HIT\n",
      "[101]  L=9         |(92.0 ) - (39.0 ))|                53.0      HIT\n",
      "[102]  L=[5, 7]    |(118.731) - (196.0))|             77.269      INTERPOLATE\n",
      "[103]  L=8         |(225.0) - (222.0))|                 3.0      HIT\n",
      "[104]  L=5         |(283.0) - (277.0))|                 6.0      HIT\n",
      "[105]  L=[7, 2]    |(128.537) - (99.0 ))|             29.537      INTERPOLATE\n",
      "[106]  L=11        |(150.0) - (196.0))|                46.0      HIT\n",
      "[107]  L=8         |(252.0) - (202.0))|                50.0      HIT\n",
      "[108]  L=8         |(252.0) - (155.0))|                97.0      HIT\n",
      "[109]  L=11        |(150.0) - (77.0 ))|                73.0      HIT\n",
      "[110]  L=[9, 9]    |(78.517) - (191.0))|              112.483      INTERPOLATE\n",
      "[111]  L=[3, 6]    |(142.084) - (70.0 ))|             72.084      INTERPOLATE\n",
      "[112]  L=9         |(65.0 ) - (73.0 ))|                 8.0      HIT\n",
      "[113]  L=9         |(92.0 ) - (49.0 ))|                43.0      HIT\n",
      "[114]  L=9         |(65.0 ) - (65.0 ))|                 0.0      HIT\n",
      "[115]  L=5         |(283.0) - (263.0))|                20.0      HIT\n",
      "[116]  L=8         |(252.0) - (248.0))|                 4.0      HIT\n",
      "[117]  L=8         |(252.0) - (296.0))|                44.0      HIT\n",
      "[118]  L=5         |(128.0) - (214.0))|                86.0      HIT\n",
      "[119]  L=4         |(101.0) - (185.0))|                84.0      HIT\n",
      "[120]  L=9         |(92.0 ) - (78.0 ))|                14.0      HIT\n",
      "[121]  L=5         |(128.0) - (93.0 ))|                35.0      HIT\n",
      "[122]  L=8         |(225.0) - (252.0))|                27.0      HIT\n",
      "[123]  L=[2, 4]    |(117.792) - (150.0))|             32.208      INTERPOLATE\n",
      "[124]  L=9         |(92.0 ) - (77.0 ))|                15.0      HIT\n",
      "[125]  L=[6, 2]    |(96.454) - (208.0))|              111.546      INTERPOLATE\n",
      "[126]  L=[5, 2]    |(165.845) - (77.0 ))|             88.845      INTERPOLATE\n",
      "[127]  L=8         |(225.0) - (108.0))|               117.0      HIT\n",
      "[128]  L=[5, 7]    |(157.907) - (160.0))|             2.093      INTERPOLATE\n",
      "[129]  L=9         |(92.0 ) - (53.0 ))|                39.0      HIT\n",
      "[130]  L=5         |(283.0) - (220.0))|                63.0      HIT\n",
      "[131]  L=[8, 8]    |(238.5) - (154.0))|                84.5      INTERPOLATE\n",
      "[132]  L=7         |(90.0 ) - (259.0))|               169.0      HIT\n",
      "[133]  L=8         |(225.0) - (90.0 ))|               135.0      HIT\n",
      "[134]  L=8         |(252.0) - (246.0))|                 6.0      HIT\n",
      "[135]  L=[7, 5]    |(119.599) - (124.0))|             4.401      INTERPOLATE\n",
      "[136]  L=8         |(225.0) - (67.0 ))|               158.0      HIT\n",
      "[137]  L=[7, 9]    |(120.997) - (72.0 ))|             48.997      INTERPOLATE\n",
      "[138]  L=7         |(110.0) - (257.0))|               147.0      HIT\n",
      "[139]  L=[8, 8]    |(238.46) - (262.0))|              23.54      INTERPOLATE\n",
      "[140]  L=8         |(225.0) - (275.0))|                50.0      HIT\n",
      "[141]  L=7         |(90.0 ) - (177.0))|                87.0      HIT\n",
      "[142]  L=[2, 5]    |(232.661) - (71.0 ))|             161.661      INTERPOLATE\n",
      "[143]  L=8         |(225.0) - (47.0 ))|               178.0      HIT\n",
      "[144]  L=2         |(49.0 ) - (187.0))|               138.0      HIT\n",
      "[145]  L=7         |(90.0 ) - (125.0))|                35.0      HIT\n",
      "[146]  L=8         |(252.0) - (78.0 ))|               174.0      HIT\n",
      "[147]  L=[4, 5]    |(80.108) - (51.0 ))|              29.108      INTERPOLATE\n",
      "[148]  L=5         |(283.0) - (258.0))|                25.0      HIT\n",
      "[149]  L=[5, 2]    |(223.487) - (215.0))|             8.487      INTERPOLATE\n",
      "[150]  L=5         |(283.0) - (303.0))|                20.0      HIT\n",
      "[151]  L=[8, 8]    |(238.462) - (243.0))|             4.538      INTERPOLATE\n",
      "[152]  L=[8, 8]    |(238.477) - (91.0 ))|             147.477      INTERPOLATE\n",
      "[153]  L=8         |(225.0) - (150.0))|                75.0      HIT\n",
      "[154]  L=8         |(225.0) - (310.0))|                85.0      HIT\n",
      "[155]  L=9         |(92.0 ) - (153.0))|                61.0      HIT\n",
      "[156]  L=5         |(283.0) - (346.0))|                63.0      HIT\n",
      "[157]  L=[4, 5]    |(167.592) - (63.0 ))|             104.592      INTERPOLATE\n",
      "[158]  L=5         |(283.0) - (89.0 ))|               194.0      HIT\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[159]  L=11        |(150.0) - (50.0 ))|               100.0      HIT\n",
      "[160]  L=[7, 7]    |(100.475) - (39.0 ))|             61.475      INTERPOLATE\n",
      "[161]  L=11        |(150.0) - (103.0))|                47.0      HIT\n",
      "[162]  L=8         |(252.0) - (308.0))|                56.0      HIT\n",
      "[163]  L=[5, 2]    |(161.135) - (116.0))|             45.135      INTERPOLATE\n",
      "[164]  L=5         |(53.0 ) - (145.0))|                92.0      HIT\n",
      "[165]  L=7         |(90.0 ) - (74.0 ))|                16.0      HIT\n",
      "[166]  L=[2, 5]    |(166.045) - (45.0 ))|             121.045      INTERPOLATE\n",
      "[167]  L=8         |(225.0) - (115.0))|               110.0      HIT\n",
      "[168]  L=8         |(252.0) - (264.0))|                12.0      HIT\n",
      "[169]  L=[8, 6]    |(204.93) - (87.0 ))|              117.93      INTERPOLATE\n",
      "[170]  L=11        |(150.0) - (202.0))|                52.0      HIT\n",
      "[171]  L=8         |(225.0) - (127.0))|                98.0      HIT\n",
      "[172]  L=7         |(90.0 ) - (182.0))|                92.0      HIT\n",
      "[173]  L=5         |(283.0) - (241.0))|                42.0      HIT\n",
      "[174]  L=11        |(150.0) - (66.0 ))|                84.0      HIT\n",
      "[175]  L=7         |(90.0 ) - (94.0 ))|                 4.0      HIT\n",
      "[176]  L=8         |(252.0) - (283.0))|                31.0      HIT\n",
      "[177]  L=9         |(92.0 ) - (64.0 ))|                28.0      HIT\n",
      "[178]  L=5         |(283.0) - (102.0))|               181.0      HIT\n",
      "[179]  L=6         |(185.0) - (200.0))|                15.0      HIT\n",
      "[180]  L=6         |(185.0) - (265.0))|                80.0      HIT\n",
      "[181]  L=7         |(90.0 ) - (94.0 ))|                 4.0      HIT\n",
      "[182]  L=8         |(225.0) - (230.0))|                 5.0      HIT\n",
      "[183]  L=7         |(90.0 ) - (181.0))|                91.0      HIT\n",
      "[184]  L=[2, 5]    |(134.317) - (156.0))|             21.683      INTERPOLATE\n",
      "[185]  L=8         |(252.0) - (233.0))|                19.0      HIT\n",
      "[186]  L=8         |(252.0) - (60.0 ))|               192.0      HIT\n",
      "[187]  L=11        |(150.0) - (219.0))|                69.0      HIT\n",
      "[188]  L=8         |(225.0) - (80.0 ))|               145.0      HIT\n",
      "[189]  L=7         |(90.0 ) - (68.0 ))|                22.0      HIT\n",
      "[190]  L=8         |(225.0) - (332.0))|               107.0      HIT\n",
      "[191]  L=9         |(65.0 ) - (248.0))|               183.0      HIT\n",
      "[192]  L=9         |(65.0 ) - (84.0 ))|                19.0      HIT\n",
      "[193]  L=5         |(283.0) - (200.0))|                83.0      HIT\n",
      "[194]  L=8         |(225.0) - (55.0 ))|               170.0      HIT\n",
      "[195]  L=5         |(128.0) - (85.0 ))|                43.0      HIT\n",
      "[196]  L=[3, 7]    |(103.51) - (89.0 ))|              14.51      INTERPOLATE\n",
      "[197]  L=7         |(90.0 ) - (31.0 ))|                59.0      HIT\n",
      "[198]  L=[2, 5]    |(232.432) - (129.0))|             103.432      INTERPOLATE\n",
      "[199]  L=3         |(259.0) - (83.0 ))|               176.0      HIT\n",
      "[200]  L=11        |(150.0) - (275.0))|               125.0      HIT\n",
      "[201]  L=7         |(90.0 ) - (65.0 ))|                25.0      HIT\n",
      "[202]  L=8         |(225.0) - (198.0))|                27.0      HIT\n",
      "[203]  L=8         |(252.0) - (236.0))|                16.0      HIT\n",
      "[204]  L=9         |(92.0 ) - (253.0))|               161.0      HIT\n",
      "[205]  L=10        |(134.0) - (124.0))|                10.0      HIT\n",
      "[206]  L=8         |(225.0) - (44.0 ))|               181.0      HIT\n",
      "[207]  L=7         |(90.0 ) - (172.0))|                82.0      HIT\n",
      "[208]  L=9         |(92.0 ) - (114.0))|                22.0      HIT\n",
      "[209]  L=7         |(90.0 ) - (142.0))|                52.0      HIT\n",
      "[210]  L=8         |(252.0) - (109.0))|               143.0      HIT\n",
      "[211]  L=8         |(225.0) - (180.0))|                45.0      HIT\n",
      "[212]  L=6         |(190.0) - (144.0))|                46.0      HIT\n",
      "[213]  L=8         |(225.0) - (163.0))|                62.0      HIT\n",
      "[214]  L=[8, 8]    |(238.499) - (147.0))|             91.499      INTERPOLATE\n",
      "[215]  L=5         |(98.0 ) - (97.0 ))|                 1.0      HIT\n",
      "[216]  L=6         |(185.0) - (220.0))|                35.0      HIT\n",
      "[217]  L=8         |(252.0) - (190.0))|                62.0      HIT\n",
      "[218]  L=8         |(252.0) - (109.0))|               143.0      HIT\n",
      "[219]  L=8         |(252.0) - (191.0))|                61.0      HIT\n",
      "[220]  L=11        |(150.0) - (122.0))|                28.0      HIT\n",
      "[221]  L=5         |(283.0) - (230.0))|                53.0      HIT\n",
      "[222]  L=9         |(65.0 ) - (242.0))|               177.0      HIT\n",
      "[223]  L=8         |(252.0) - (248.0))|                 4.0      HIT\n",
      "[224]  L=8         |(252.0) - (249.0))|                 3.0      HIT\n",
      "[225]  L=[3, 6]    |(142.143) - (192.0))|             49.857      INTERPOLATE\n",
      "[226]  L=[11, 9]    |(107.226) - (131.0))|             23.774      INTERPOLATE\n",
      "[227]  L=11        |(150.0) - (237.0))|                87.0      HIT\n",
      "[228]  L=[8, 8]    |(238.502) - (78.0 ))|             160.502      INTERPOLATE\n",
      "[229]  L=[7, 4]    |(61.299) - (135.0))|              73.701      INTERPOLATE\n",
      "[230]  L=9         |(92.0 ) - (244.0))|               152.0      HIT\n",
      "[231]  L=7         |(110.0) - (199.0))|                89.0      HIT\n",
      "[232]  L=5         |(283.0) - (270.0))|                13.0      HIT\n",
      "[233]  L=8         |(225.0) - (164.0))|                61.0      HIT\n",
      "[234]  L=[5, 2]    |(165.537) - (72.0 ))|             93.537      INTERPOLATE\n",
      "[235]  L=9         |(65.0 ) - (96.0 ))|                31.0      HIT\n",
      "[236]  L=5         |(283.0) - (306.0))|                23.0      HIT\n",
      "[237]  L=11        |(150.0) - (91.0 ))|                59.0      HIT\n",
      "[238]  L=5         |(128.0) - (214.0))|                86.0      HIT\n",
      "[239]  L=8         |(225.0) - (95.0 ))|               130.0      HIT\n",
      "[240]  L=8         |(252.0) - (216.0))|                36.0      HIT\n",
      "[241]  L=2         |(141.0) - (263.0))|               122.0      HIT\n",
      "[242]  L=8         |(225.0) - (178.0))|                47.0      HIT\n",
      "[243]  L=8         |(252.0) - (113.0))|               139.0      HIT\n",
      "[244]  L=[7, 2]    |(212.116) - (200.0))|             12.116      INTERPOLATE\n",
      "[245]  L=7         |(90.0 ) - (139.0))|                49.0      HIT\n",
      "[246]  L=8         |(225.0) - (139.0))|                86.0      HIT\n",
      "[247]  L=[7, 5]    |(200.495) - (88.0 ))|             112.495      INTERPOLATE\n",
      "[248]  L=[9, 9]    |(78.501) - (148.0))|              69.499      INTERPOLATE\n",
      "[249]  L=5         |(283.0) - (88.0 ))|               195.0      HIT\n",
      "[250]  L=5         |(283.0) - (243.0))|                40.0      HIT\n",
      "[251]  L=[3, 5]    |(181.031) - (71.0 ))|             110.031      INTERPOLATE\n",
      "[252]  L=[3, 5]    |(183.905) - (77.0 ))|             106.905      INTERPOLATE\n",
      "[253]  L=5         |(283.0) - (109.0))|               174.0      HIT\n",
      "[254]  L=11        |(150.0) - (272.0))|               122.0      HIT\n",
      "[255]  L=6         |(51.0 ) - (60.0 ))|                 9.0      HIT\n",
      "[256]  L=9         |(92.0 ) - (54.0 ))|                38.0      HIT\n",
      "[257]  L=7         |(90.0 ) - (221.0))|               131.0      HIT\n",
      "[258]  L=9         |(65.0 ) - (90.0 ))|                25.0      HIT\n",
      "[259]  L=[7, 3]    |(146.839) - (311.0))|             164.161      INTERPOLATE\n",
      "[260]  L=6         |(48.0 ) - (281.0))|               233.0      HIT\n",
      "[261]  L=[5, 2]    |(223.548) - (182.0))|             41.548      INTERPOLATE\n",
      "[262]  L=9         |(65.0 ) - (321.0))|               256.0      HIT\n",
      "[263]  L=2         |(142.0) - (58.0 ))|                84.0      HIT\n",
      "[264]  L=7         |(90.0 ) - (262.0))|               172.0      HIT\n",
      "[265]  L=[7, 2]    |(131.188) - (206.0))|             74.812      INTERPOLATE\n",
      "[266]  L=5         |(283.0) - (233.0))|                50.0      HIT\n",
      "[267]  L=5         |(283.0) - (242.0))|                41.0      HIT\n",
      "[268]  L=7         |(90.0 ) - (123.0))|                33.0      HIT\n",
      "[269]  L=[6, 5]    |(102.15) - (167.0))|              64.85      INTERPOLATE\n",
      "[270]  L=9         |(65.0 ) - (63.0 ))|                 2.0      HIT\n",
      "[271]  L=8         |(252.0) - (197.0))|                55.0      HIT\n",
      "[272]  L=11        |(150.0) - (71.0 ))|                79.0      HIT\n",
      "[273]  L=[4, 5]    |(80.502) - (168.0))|              87.498      INTERPOLATE\n",
      "[274]  L=[2, 6]    |(93.111) - (140.0))|              46.889      INTERPOLATE\n",
      "[275]  L=7         |(90.0 ) - (217.0))|               127.0      HIT\n",
      "[276]  L=8         |(252.0) - (121.0))|               131.0      HIT\n",
      "[277]  L=[10, 6]    |(101.741) - (235.0))|             133.259      INTERPOLATE\n",
      "[278]  L=11        |(150.0) - (245.0))|                95.0      HIT\n",
      "[279]  L=8         |(252.0) - (40.0 ))|               212.0      HIT\n",
      "[280]  L=8         |(252.0) - (52.0 ))|               200.0      HIT\n",
      "[281]  L=[2, 7]    |(112.141) - (104.0))|             8.141      INTERPOLATE\n",
      "[282]  L=5         |(283.0) - (132.0))|               151.0      HIT\n",
      "[283]  L=11        |(150.0) - (88.0 ))|                62.0      HIT\n",
      "[284]  L=5         |(128.0) - (69.0 ))|                59.0      HIT\n",
      "[285]  L=5         |(53.0 ) - (219.0))|               166.0      HIT\n",
      "[286]  L=[7, 5]    |(94.144) - (72.0 ))|              22.144      INTERPOLATE\n",
      "[287]  L=[2, 5]    |(236.293) - (201.0))|             35.293      INTERPOLATE\n",
      "[288]  L=8         |(225.0) - (110.0))|               115.0      HIT\n",
      "[289]  L=2         |(49.0 ) - (51.0 ))|                 2.0      HIT\n",
      "[290]  L=[8, 8]    |(238.47) - (277.0))|              38.53      INTERPOLATE\n",
      "[291]  L=9         |(65.0 ) - (63.0 ))|                 2.0      HIT\n",
      "[292]  L=7         |(90.0 ) - (118.0))|                28.0      HIT\n",
      "[293]  L=7         |(55.0 ) - (69.0 ))|                14.0      HIT\n",
      "[294]  L=5         |(283.0) - (273.0))|                10.0      HIT\n",
      "[295]  L=[6, 5]    |(102.983) - (258.0))|             155.017      INTERPOLATE\n",
      "[296]  L=2         |(49.0 ) - (43.0 ))|                 6.0      HIT\n",
      "[297]  L=[9, 9]    |(78.515) - (198.0))|              119.485      INTERPOLATE\n",
      "[298]  L=8         |(225.0) - (242.0))|                17.0      HIT\n",
      "[299]  L=8         |(225.0) - (232.0))|                 7.0      HIT\n",
      "[300]  L=8         |(252.0) - (175.0))|                77.0      HIT\n",
      "[301]  L=[7, 5]    |(94.069) - (93.0 ))|              1.069      INTERPOLATE\n",
      "[302]  L=5         |(283.0) - (168.0))|               115.0      HIT\n",
      "[303]  L=5         |(283.0) - (275.0))|                 8.0      HIT\n",
      "[304]  L=6         |(48.0 ) - (293.0))|               245.0      HIT\n",
      "[305]  L=5         |(128.0) - (281.0))|               153.0      HIT\n",
      "[306]  L=9         |(65.0 ) - (72.0 ))|                 7.0      HIT\n",
      "[307]  L=11        |(150.0) - (140.0))|                10.0      HIT\n",
      "[308]  L=[7, 2]    |(132.001) - (189.0))|             56.999      INTERPOLATE\n",
      "[309]  L=8         |(225.0) - (181.0))|                44.0      HIT\n",
      "[310]  L=7         |(90.0 ) - (209.0))|               119.0      HIT\n",
      "[311]  L=8         |(252.0) - (136.0))|               116.0      HIT\n",
      "[312]  L=[7, 2]    |(130.605) - (261.0))|             130.395      INTERPOLATE\n",
      "[313]  L=[7, 6]    |(76.767) - (113.0))|              36.233      INTERPOLATE\n",
      "[314]  L=5         |(283.0) - (131.0))|               152.0      HIT\n",
      "[315]  L=8         |(225.0) - (174.0))|                51.0      HIT\n",
      "[316]  L=8         |(252.0) - (257.0))|                 5.0      HIT\n",
      "[317]  L=[7, 2]    |(128.501) - (55.0 ))|             73.501      INTERPOLATE\n",
      "[318]  L=5         |(279.0) - (84.0 ))|               195.0      HIT\n",
      "[319]  L=9         |(92.0 ) - (42.0 ))|                50.0      HIT\n",
      "[320]  L=8         |(252.0) - (146.0))|               106.0      HIT\n",
      "[321]  L=7         |(90.0 ) - (212.0))|               122.0      HIT\n",
      "[322]  L=5         |(283.0) - (233.0))|                50.0      HIT\n",
      "[323]  L=8         |(252.0) - (91.0 ))|               161.0      HIT\n",
      "[324]  L=9         |(65.0 ) - (111.0))|                46.0      HIT\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[325]  L=[5, 2]    |(135.206) - (152.0))|             16.794      INTERPOLATE\n",
      "[326]  L=7         |(90.0 ) - (120.0))|                30.0      HIT\n",
      "[327]  L=11        |(150.0) - (67.0 ))|                83.0      HIT\n",
      "[328]  L=8         |(252.0) - (310.0))|                58.0      HIT\n",
      "[329]  L=9         |(65.0 ) - (94.0 ))|                29.0      HIT\n",
      "[330]  L=5         |(128.0) - (183.0))|                55.0      HIT\n",
      "[331]  L=8         |(225.0) - (66.0 ))|               159.0      HIT\n",
      "[332]  L=11        |(150.0) - (173.0))|                23.0      HIT\n",
      "[333]  L=5         |(128.0) - (72.0 ))|                56.0      HIT\n",
      "[334]  L=11        |(150.0) - (49.0 ))|               101.0      HIT\n",
      "[335]  L=7         |(90.0 ) - (64.0 ))|                26.0      HIT\n",
      "[336]  L=6         |(190.0) - (48.0 ))|               142.0      HIT\n",
      "[337]  L=[1, 2]    |(145.072) - (178.0))|             32.928      INTERPOLATE\n",
      "[338]  L=7         |(90.0 ) - (104.0))|                14.0      HIT\n",
      "[339]  L=3         |(259.0) - (132.0))|               127.0      HIT\n",
      "[340]  L=5         |(128.0) - (220.0))|                92.0      HIT\n",
      "[341]  L=[2, 5]    |(235.897) - (57.0 ))|             178.897      INTERPOLATE\n",
      "N_INTERPOLATED:88, N_large_error (>0.1):341\n",
      "avg error          : 73.82305, avg_frac_error          : 0.64979 \n",
      "avg exclusive error: 74.73228, avg exclusive frac error: 0.64979\n"
     ]
    }
   ],
   "source": [
    "standard_evaluation(X_test, Y_test, net, get_interp_indices=False)\n",
    "\n",
    "# Predict every test row with the original network (interpolation enabled);\n",
    "# only the prediction is kept, the activation/status/info outputs are unused.\n",
    "Y_pred = []\n",
    "for x_row in X_test:\n",
    "    y_hat, _act, _status, _info = net.SQANN_propagation(x_row, ALLOW_INTERPOLATION=True)\n",
    "    Y_pred.append(y_hat)\n",
    "MSE['SQANN'] = compute_mse(np.array(Y_pred),Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d6ccbe6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear                   : 58.574\n",
      "Ridge                    : 57.6595\n",
      "Lasso                    : 59.2133\n",
      "LinSVR                   : 94.6494\n",
      "NuSVR                    : 84.3368\n",
      "SVR                      : 81.0384\n",
      "DecTree                  : 76.7292\n",
      "kneigh                   : 66.9468\n",
      "MLP                      : 108.8577\n",
      "SQANN                    : 93.7947\n"
     ]
    }
   ],
   "source": [
    "# One line per model: name left-aligned in 24 columns, score rounded to 4 decimals.\n",
    "for model_key, model_score in MSE.items():\n",
    "    print(f'{str(model_key):<24} : {str(np.round(model_score, 4))}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "55f3e046",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n total: 442\n",
      "n original training: 100\n",
      "n integrated: 218\n",
      "n training new: 318\n",
      "n test: 342\n",
      "new training fraction: 0.719\n"
     ]
    }
   ],
   "source": [
    "from SQANN.utils import ood_searcher\n",
    "\n",
    "# Error threshold handed to ood_searcher; the comparison table labels the\n",
    "# resulting row 'eth:40', so keep the two in sync.\n",
    "ERROR_TH = 40\n",
    "OOD_INDICES_COLLECTION = ood_searcher(X_test,Y_test,net, len(X_test),error_th=ERROR_TH)\n",
    "\n",
    "# Append the flagged test rows to the original training set in one vectorised\n",
    "# step (same rows, same order as appending them one at a time).\n",
    "# NOTE(review): these rows come from X_test, and later cells score against\n",
    "# X_test/Y_test too (e.g. MSE_new['SQANN']), so post-integration scores include\n",
    "# samples that were seen during training.\n",
    "ood_idx = np.asarray(list(OOD_INDICES_COLLECTION), dtype=int)\n",
    "X_train_new = np.concatenate([X_train, X_test[ood_idx]])\n",
    "Y_train_new = np.concatenate([Y_train, Y_test[ood_idx]])\n",
    "assert len(X_train_new) == len(Y_train_new) == len(X_train) + len(ood_idx)\n",
    "\n",
    "print('n total:', len(X))\n",
    "print('n original training:', len(X_train))\n",
    "print('n integrated:',len(OOD_INDICES_COLLECTION))\n",
    "print('n training new:', len(X_train_new))\n",
    "print('n test:', len(X_test))\n",
    "print('new training fraction:',np.round(len(X_train_new)/len(X),3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "99a989b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear models...\n",
      "\n",
      "Support vector models...\n",
      "\n",
      "Decision Tree models...\n",
      "\n",
      "KNeighbors models...\n",
      "53.0588 n:2\n",
      "\n",
      "MLP models...\n"
     ]
    }
   ],
   "source": [
    "from sklearn.svm import LinearSVR, NuSVR, SVR\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "from sklearn.neural_network import MLPRegressor\n",
    "\n",
    "# Re-fit every baseline on the enlarged training set. Labels (and the order in\n",
    "# which they are added to MSE_new) mirror the keys of MSE so the comparison\n",
    "# table lines up model by model.\n",
    "MSE_new = {}\n",
    "simple_sweep = [\n",
    "    ('linear models...', [('linear', LinearRegression()), ('Ridge', Ridge()), ('Lasso', Lasso())]),\n",
    "    ('\\nSupport vector models...', [('LinSVR', LinearSVR()), ('NuSVR', NuSVR()), ('SVR', SVR())]),\n",
    "    ('\\nDecision Tree models...', [('DecTree', DecisionTreeRegressor(random_state=0))]),\n",
    "]\n",
    "for group_title, group_models in simple_sweep:\n",
    "    print(group_title)\n",
    "    for model_label, estimator in group_models:\n",
    "        MSE_new = do_regress(estimator, X_train_new, Y_train_new, MSE_new, model_label)\n",
    "\n",
    "# k-NN: score k = 2..16 into a throw-away dict and keep only the lowest value;\n",
    "# every new best is printed together with its k.\n",
    "print('\\nKNeighbors models...')\n",
    "min_mse_new, n_min = np.inf, -1\n",
    "for k_neighbors in range(2, 17):\n",
    "    knn_label = 'KNeighborsRegressor(n_neighbors=%s)' % str(k_neighbors)\n",
    "    knn_scores = do_regress(KNeighborsRegressor(n_neighbors=k_neighbors), X_train_new, Y_train_new, {}, knn_label)\n",
    "    knn_mse = knn_scores[knn_label]\n",
    "    if knn_mse < min_mse_new:\n",
    "        print(np.round(knn_mse, 4), 'n:%s' % str(k_neighbors))\n",
    "        min_mse_new, n_min = knn_mse, k_neighbors\n",
    "MSE_new['kneigh'] = min_mse_new\n",
    "\n",
    "print('\\nMLP models...')\n",
    "MSE_new = do_regress(MLPRegressor(hidden_layer_sizes=(64,64),random_state=1, max_iter=12000), X_train_new, Y_train_new, MSE_new, 'MLP')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9f737ee5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Exiting layer_k_sample_collection() because all data have been used.\n",
      "Final positions of indices in the layers:\n",
      "  [1] [0, 1, 3, 5, 7, 11, 15, 20, 23, 25, 28, 32, 35, 38, 40, 43, 48, 53, 58, 62, 72, 76, 80, 85, 91, 99, 114, 117, 121, 136, 139, 153, 157, 175, 187, 192, 194, 207, 212, 217, 236, 256, 258, 264, 269, 293, 297, 305, 317]\n",
      "  [2] [2, 4, 9, 12, 14, 17, 19, 22, 26, 31, 36, 39, 46, 50, 52, 59, 63, 67, 69, 71, 74, 86, 90, 94, 97, 100, 104, 118, 125, 134, 137, 140, 143, 145, 150, 152, 158, 167, 169, 174, 180, 186, 189, 191, 197, 202, 209, 218, 221, 226, 228, 234, 237, 260, 263, 270, 283, 294, 299, 303, 308, 312, 249, 47, 290]\n",
      "  [3] [37, 75, 105, 110, 120, 135, 154, 171, 185, 199, 205, 230, 238, 253, 271, 276, 309, 314, 6, 45, 57, 92, 106, 113, 138, 155, 168, 213, 243, 29, 65, 124, 133, 160, 172, 216, 268, 279, 287, 16, 88, 123, 151, 184, 42, 112, 24, 93, 30]\n",
      "  [4] [132, 162, 273, 68, 109, 198, 248, 61, 79, 119, 142, 161, 201, 229, 250, 259, 274, 81, 115, 170, 235, 298, 55, 144, 13, 95, 173, 240, 295, 159, 265, 51, 149, 277, 315, 166, 176, 179, 196, 203, 206, 255, 280, 288, 304, 70, 190, 34, 220, 82, 178, 225, 275, 244, 214, 231, 232, 296, 98, 239]\n",
      "  [5] [301, 147, 96, 211, 122, 262, 127, 311, 246, 183, 182, 87, 64, 107, 102, 128, 77, 261, 310, 195, 224, 21, 44, 148, 284, 233, 54, 126, 188, 245, 73, 108, 242, 33, 27]\n",
      "  [6] [306, 89, 272, 267, 208, 292, 56, 210, 247, 103, 66, 204, 219, 49, 300, 41, 215, 286, 266]\n",
      "  [7] [289, 111, 146, 8, 307, 227, 181, 131, 163, 281, 241, 116, 129, 78, 177, 285, 156, 84]\n",
      "  [8] [60, 141, 254, 200, 302, 101, 223, 316, 282, 130, 252, 164, 83, 313]\n",
      "  [9] [251, 18, 10, 278]\n",
      "  [10] [257, 193, 165]\n",
      "  [11] [291, 222]\n"
     ]
    }
   ],
   "source": [
    "net_new = SQANN(layer_settings, N=len(X_train_new))\n",
    "net_new.fit_data(X_train_new,Y_train_new,verbose=20)\n",
    "\n",
    "# Combine the original network (net) with the refitted one (net_new): for each\n",
    "# test row keep the prediction of whichever network reports the larger maximum\n",
    "# activation; on a tie the refitted network's prediction is used.\n",
    "Y_pred_new = []\n",
    "for x_row in X_test:\n",
    "    y_orig, act_orig, _status_orig, _info_orig = net.SQANN_propagation(x_row, ALLOW_INTERPOLATION=True)\n",
    "    y_refit, act_refit, _status_refit, _info_refit = net_new.SQANN_propagation(x_row, ALLOW_INTERPOLATION=True)\n",
    "    keep_orig = np.max(act_orig) > np.max(act_refit)\n",
    "    Y_pred_new.append(y_orig if keep_orig else y_refit)\n",
    "MSE_new['SQANN'] = compute_mse(np.array(Y_pred_new),Y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "66196e6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear                   : 54.1496\n",
      "Ridge                    : 54.1838\n",
      "Lasso                    : 55.6517\n",
      "LinSVR                   : 70.8149\n",
      "NuSVR                    : 74.3897\n",
      "SVR                      : 73.5676\n",
      "DecTree                  : 40.5181\n",
      "kneigh                   : 53.0588\n",
      "MLP                      : 54.1694\n",
      "SQANN                    : 53.9622\n"
     ]
    }
   ],
   "source": [
    "# One line per model after integration: name padded to 24 columns, score rounded to 4 decimals.\n",
    "for model_key, model_score in MSE_new.items():\n",
    "    print(f'{str(model_key):<24} : {str(np.round(model_score, 4))}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1e91a65a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        linear   Ridge   Lasso  LinSVR   NuSVR     SVR  DecTree  kneigh      MLP   SQANN\n",
      "orig.   58.574  57.660  59.213  94.649  84.337  81.038   76.729  66.947  108.858  93.795\n",
      "eth:40  54.150  54.184  55.652  70.815  74.390  73.568   40.518  53.059   54.169  53.962\n"
     ]
    }
   ],
   "source": [
    "# Side-by-side table: one column per model, one row per training set\n",
    "# ('orig.' = original split, 'eth:40' = after integrating the samples found with\n",
    "# error threshold 40). pandas is already imported as pd in the first cell.\n",
    "mse_comparison = pd.DataFrame(\n",
    "    {model_name: [MSE[model_name], MSE_new[model_name]] for model_name in MSE},\n",
    "    index=['orig.', 'eth:40'],  # eth: error threshold\n",
    ")\n",
    "# Keep the whole table on one line without changing the global display option.\n",
    "with pd.option_context('display.expand_frame_repr', False):\n",
    "    print(mse_comparison.round(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c7ef61",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
