{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "91e698ed-63ed-4c73-8580-3c1bd3657576",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 acute-inflammation \tN: 120 \td: 6 \tc: 2\n",
      "1/1\n",
      "NTK_acc: [1.0]\n",
      "Ours_acc: [1.0]\n",
      "2 acute-nephritis \tN: 120 \td: 6 \tc: 2\n",
      "2/2\n",
      "NTK_acc: [1.0, 1.0]\n",
      "Ours_acc: [1.0, 1.0]\n",
      "8 balloons \tN: 16 \td: 4 \tc: 2\n",
      "3/3\n",
      "NTK_acc: [1.0, 1.0, 0.75]\n",
      "Ours_acc: [1.0, 1.0, 0.75]\n",
      "10 blood \tN: 748 \td: 4 \tc: 2\n",
      "4/4\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583]\n",
      "11 breast-cancer \tN: 286 \td: 9 \tc: 2\n",
      "5/5\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625]\n",
      "12 breast-cancer-wisc \tN: 699 \td: 9 \tc: 2\n",
      "6/6\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966]\n",
      "13 breast-cancer-wisc-diag \tN: 569 \td: 30 \tc: 2\n",
      "6/7\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915]\n",
      "14 breast-cancer-wisc-prog \tN: 198 \td: 33 \tc: 2\n",
      "6/8\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66]\n",
      "21 congressional-voting \tN: 435 \td: 16 \tc: 2\n",
      "7/9\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266]\n",
      "22 conn-bench-sonar-mines-rocks \tN: 208 \td: 60 \tc: 2\n",
      "8/10\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712]\n",
      "26 credit-approval \tN: 690 \td: 15 \tc: 2\n",
      "9/11\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844]\n",
      "27 cylinder-bands \tN: 512 \td: 35 \tc: 2\n",
      "10/12\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82]\n",
      "29 echocardiogram \tN: 131 \td: 10 \tc: 2\n",
      "11/13\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788]\n",
      "33 fertility \tN: 100 \td: 9 \tc: 2\n",
      "12/14\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76]\n",
      "36 haberman-survival \tN: 306 \td: 3 \tc: 2\n",
      "13/15\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532]\n",
      "39 heart-hungarian \tN: 294 \td: 12 \tc: 2\n",
      "14/16\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878]\n",
      "42 hepatitis \tN: 155 \td: 19 \tc: 2\n",
      "14/17\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897]\n",
      "45 ilpd-indian-liver \tN: 583 \td: 9 \tc: 2\n",
      "15/18\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555]\n",
      "47 ionosphere \tN: 351 \td: 33 \tc: 2\n",
      "16/19\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966]\n",
      "57 mammographic \tN: 961 \td: 5 \tc: 2\n",
      "17/20\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792]\n",
      "59 molec-biol-promoter \tN: 106 \td: 57 \tc: 2\n",
      "18/21\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815]\n",
      "66 musk-1 \tN: 476 \td: 166 \tc: 2\n",
      "19/22\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866]\n",
      "71 oocytes_trisopterus_nucleus_2f \tN: 912 \td: 25 \tc: 2\n",
      "19/23\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759]\n",
      "76 parkinsons \tN: 195 \td: 22 \tc: 2\n",
      "20/24\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959]\n",
      "78 pima \tN: 768 \td: 8 \tc: 2\n",
      "21/25\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599]\n",
      "82 pittsburg-bridges-T-OR-D \tN: 102 \td: 7 \tc: 2\n",
      "22/26\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846]\n",
      "84 planning \tN: 182 \td: 12 \tc: 2\n",
      "23/27\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543]\n",
      "97 statlog-australian-credit \tN: 690 \td: 14 \tc: 2\n",
      "23/28\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435, 1.0]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543, 0.74]\n",
      "98 statlog-german-credit \tN: 1000 \td: 24 \tc: 2\n",
      "24/29\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435, 1.0, 0.512]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543, 0.74, 0.576]\n",
      "99 statlog-heart \tN: 270 \td: 13 \tc: 2\n",
      "24/30\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435, 1.0, 0.512, 0.779]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543, 0.74, 0.576, 0.765]\n",
      "108 tic-tac-toe \tN: 958 \td: 9 \tc: 2\n",
      "25/31\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435, 1.0, 0.512, 0.779, 1.0]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543, 0.74, 0.576, 0.765, 1.0]\n",
      "110 trains \tN: 10 \td: 29 \tc: 2\n",
      "26/32\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435, 1.0, 0.512, 0.779, 1.0, 0.667]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543, 0.74, 0.576, 0.765, 1.0, 0.667]\n",
      "112 vertebral-column-2clases \tN: 310 \td: 6 \tc: 2\n",
      "26/33\n",
      "NTK_acc: [1.0, 1.0, 0.75, 0.524, 0.417, 0.96, 0.965, 0.7, 0.266, 0.635, 0.838, 0.773, 0.758, 0.76, 0.481, 0.743, 0.923, 0.432, 0.955, 0.783, 0.63, 0.782, 0.781, 0.939, 0.552, 0.731, 0.435, 1.0, 0.512, 0.779, 1.0, 0.667, 0.821]\n",
      "Ours_acc: [1.0, 1.0, 0.75, 0.583, 0.625, 0.966, 0.915, 0.66, 0.266, 0.712, 0.844, 0.82, 0.788, 0.76, 0.532, 0.878, 0.897, 0.555, 0.966, 0.792, 0.815, 0.866, 0.759, 0.959, 0.599, 0.846, 0.543, 0.74, 0.576, 0.765, 1.0, 0.667, 0.731]\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "import math\n",
    "import numpy as np\n",
    "import NTK\n",
    "import tools_submission\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "seed=42\n",
    "np.random.seed(seed=seed)\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('-dir', default = \"data\", type = str, help = \"data directory\")\n",
    "parser.add_argument('-file', default = \"result.log\", type = str, help = \"Output File\")\n",
    "parser.add_argument('-max_tot', default = 1000, type = int, help = \"Maximum number of data samples\")\n",
    "parser.add_argument('-max_dep', default = 1, type = int, help = \"Maximum number of depth\")\n",
    "\n",
    "MIN_N_TOT=0\n",
    "\n",
    "args = parser.parse_args()\n",
    "\n",
    "MAX_N_TOT = args.max_tot\n",
    "MAX_DEP = args.max_dep\n",
    "DEP_LIST = list(range(MAX_DEP))\n",
    "C_LIST = [10.0 ** i for i in range(-2, 4)]\n",
    "datadir = args.dir\n",
    "\n",
    "alg = tools_submission.svm\n",
    "our_kernel = tools_submission.our_kernel\n",
    "\n",
    "\n",
    "best_acc_list = []\n",
    "best_acc_list2 = []\n",
    "\n",
    "dataset_names=[]\n",
    "dataset_sizes=[]\n",
    "\n",
    "\n",
    "outf = open(args.file, \"w\")\n",
    "print (\"Dataset\\tValidation Acc\\tTest Acc\", file = outf)\n",
    "for idx, dataset in enumerate(sorted(os.listdir(datadir))):\n",
    "    if not os.path.isdir(datadir + \"/\" + dataset):\n",
    "        continue\n",
    "    if not os.path.isfile(datadir + \"/\" + dataset + \"/\" + dataset + \".txt\"):\n",
    "        continue\n",
    "    dic = dict()\n",
    "    for k, v in map(lambda x : x.split(), open(datadir + \"/\" + dataset + \"/\" + dataset + \".txt\", \"r\").readlines()):\n",
    "        dic[k] = v\n",
    "    c = int(dic[\"n_clases=\"])\n",
    "    d = int(dic[\"n_entradas=\"])\n",
    "    n_train = int(dic[\"n_patrons_entrena=\"])\n",
    "    n_val = int(dic[\"n_patrons_valida=\"])\n",
    "    n_train_val = int(dic[\"n_patrons1=\"])\n",
    "    n_test = 0\n",
    "    if \"n_patrons2=\" in dic:\n",
    "        n_test = int(dic[\"n_patrons2=\"])\n",
    "    n_tot = n_train_val + n_test\n",
    "    \n",
    "    if (n_tot > MAX_N_TOT or n_tot < MIN_N_TOT) or n_test > 0 or c>2:\n",
    "        print (str(dataset) + '\\t0\\t0', file = outf)\n",
    "        continue\n",
    "    \n",
    "    print (idx, dataset, \"\\tN:\", n_tot, \"\\td:\", d, \"\\tc:\", c)\n",
    "    dataset_names.append(dataset)\n",
    "    dataset_sizes.append([n_tot,d])\n",
    "    # load data\n",
    "    f = open(\"data/\" + dataset + \"/\" + dic[\"fich1=\"], \"r\").readlines()[1:]\n",
    "    X = np.asarray(list(map(lambda x: list(map(float, x.split()[1:-1])), f)))\n",
    "    y = np.asarray(list(map(lambda x: int(x.split()[-1]), f)))\n",
    "    \n",
    "    # calculate NTK\n",
    "    Ks = NTK.kernel_value_batch(X, MAX_DEP)\n",
    "        \n",
    "    # load training and validation set\n",
    "    fold = list(map(lambda x: list(map(int, x.split())), open(datadir + \"/\" + dataset + \"/\" + \"conxuntos.dat\", \"r\").readlines()))\n",
    "    a=fold[0].copy()\n",
    "    a.extend(fold[1][:len(fold[1])//2])\n",
    "    b=fold[1][len(fold[1])//2:]\n",
    "    train_fold = a.copy()\n",
    "    val_fold= b.copy()\n",
    "    best_acc = 0.0\n",
    "    best_value = 0\n",
    "    best_dep = 0\n",
    "    best_ker = 0\n",
    "    # enumerate kenerls and cost values to find the best hyperparameters\n",
    "    for dep in DEP_LIST:\n",
    "        for fix_dep in range(dep + 1):\n",
    "            K = Ks[dep][fix_dep]\n",
    "            for value in C_LIST:\n",
    "                acc = alg(K[train_fold][:, train_fold], K[val_fold][:, train_fold], y[train_fold], y[val_fold], value, c)\n",
    "\n",
    "                if acc > best_acc:\n",
    "                    best_acc = acc\n",
    "                    best_value = value\n",
    "                    best_dep = dep\n",
    "                    best_fix = fix_dep\n",
    "    \n",
    "    betavec=[10.0 ** i for i in range(-6,1)]\n",
    "    K = Ks[best_dep][best_fix]\n",
    "\n",
    "    best_acc2 = 0.0\n",
    "    for beta in betavec:\n",
    "        acc2 = our_kernel(X[train_fold,:],  y[train_fold],X[val_fold,:], y[val_fold],beta)\n",
    "        if acc2 > best_acc2:\n",
    "            best_acc2 = acc2\n",
    "\n",
    "            \n",
    "    best_acc_list.append(np.round(best_acc,3))\n",
    "    best_acc_list2.append(np.round(best_acc2,3))\n",
    "\n",
    "    print(str(np.sum((np.array(best_acc_list2)-np.array(best_acc_list))>=0))+\"/\"+str(len(best_acc_list2)))\n",
    "        \n",
    "    print (\"NTK_acc:\", best_acc_list)\n",
    "    print (\"Ours_acc:\", best_acc_list2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "da7a9ff8-22dd-4b9f-8291-5c3f3072accc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done\n"
     ]
    }
   ],
   "source": [
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8a14838f-1e17-41d4-9b31-a80b4ffe8d66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.   , 1.   ],\n",
       "       [1.   , 1.   ],\n",
       "       [0.75 , 0.75 ],\n",
       "       [0.524, 0.583],\n",
       "       [0.417, 0.625],\n",
       "       [0.96 , 0.966],\n",
       "       [0.965, 0.915],\n",
       "       [0.7  , 0.66 ],\n",
       "       [0.266, 0.266],\n",
       "       [0.635, 0.712],\n",
       "       [0.838, 0.844],\n",
       "       [0.773, 0.82 ],\n",
       "       [0.758, 0.788],\n",
       "       [0.76 , 0.76 ],\n",
       "       [0.481, 0.532],\n",
       "       [0.743, 0.878],\n",
       "       [0.923, 0.897],\n",
       "       [0.432, 0.555],\n",
       "       [0.955, 0.966],\n",
       "       [0.783, 0.792],\n",
       "       [0.63 , 0.815],\n",
       "       [0.782, 0.866],\n",
       "       [0.781, 0.759],\n",
       "       [0.939, 0.959],\n",
       "       [0.552, 0.599],\n",
       "       [0.731, 0.846],\n",
       "       [0.435, 0.543],\n",
       "       [1.   , 0.74 ],\n",
       "       [0.512, 0.576],\n",
       "       [0.779, 0.765],\n",
       "       [1.   , 1.   ],\n",
       "       [0.667, 0.667],\n",
       "       [0.821, 0.731]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array([np.array(best_acc_list),np.array(best_acc_list2)]).reshape(2,-1).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "bd3618f8-cf6d-44e8-b6fc-8d2190f068b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(np.array(best_acc_list)-np.array(best_acc_list2)>=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "470351e7-95df-49cf-9384-92a387e8cebf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
