{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import math\n",
    "import pickle\n",
    "import numpy as np\n",
    "import rfm\n",
    "import wagop_rfm_laplace\n",
    "import fact_rfm_laplace\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import utils\n",
    "import classic_kernel\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "from tabular_datasets_utils import get_tabular_datasets, print_dataset_info, \\\n",
    "    hyperparam_select_dataset, test_dataset, get_dataset_train_val_data\n",
    "\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('-dir', default = \"data\", type = str, help = \"data directory\")\n",
    "parser.add_argument('-file', default = \"result.log\", type = str, help = \"Output File\")\n",
    "\n",
    "args = parser.parse_args([])\n",
    "datadir = args.dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "if not os.path.exists('data'):\n",
    "    os.makedirs('data')\n",
    "    import tarfile\n",
    "    import urllib.request\n",
    "    url = 'http://persoal.citius.usc.es/manuel.fernandez.delgado/papers/jmlr/data.tar.gz'\n",
    "    urllib.request.urlretrieve(url, 'data.tar.gz')\n",
    "    with tarfile.open('data.tar.gz', 'r:gz') as tar:\n",
    "        tar.extractall(path='data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_laplacian_M_sol(X_train, y_train, c, M, iters=5, reg=0, L=10, normalize=False, device='cpu'):\n",
    "    \n",
    "    y_train = utils.convert_one_hot(y_train, c)\n",
    "    if normalize:\n",
    "        X_train /= np.linalg.norm(X_train, axis=-1).reshape(-1, 1)\n",
    "\n",
    "    X_train = torch.from_numpy(X_train).float().to(device)\n",
    "    y_train = torch.from_numpy(y_train).float().to(device)\n",
    "\n",
    "    K_train = classic_kernel.clamped_laplacian_M(X_train, X_train, L, M)\n",
    "    sol = torch.linalg.solve(K_train + reg * torch.eye(K_train.shape[0]).to(device), y_train).T\n",
    "    return sol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gaussian_M_sol(X_train, y_train, c, M, iters=5, reg=0, L=10, normalize=False, device='cpu'):\n",
    "    \n",
    "    y_train = utils.convert_one_hot(y_train, c)\n",
    "    if normalize:\n",
    "        X_train /= np.linalg.norm(X_train, axis=-1).reshape(-1, 1)\n",
    "\n",
    "    X_train = torch.from_numpy(X_train).float().to(device)\n",
    "    y_train = torch.from_numpy(y_train).float().to(device)\n",
    "\n",
    "    K_train = classic_kernel.gaussian_M(X_train, X_train, L, M)\n",
    "    sol = torch.linalg.solve(K_train + reg * torch.eye(K_train.shape[0]).to(device), y_train).T\n",
    "    return sol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_closeness_to_id(X_train, sol):\n",
    "    # sol is c x n\n",
    "    # X_train is n x d\n",
    "    # print(sol.shape)\n",
    "    # print(X_train.shape)\n",
    "    K = torch.from_numpy(X_train @ X_train.T).to(device)\n",
    "    out_mat = torch.einsum('ij,ci,cj->cij', K, sol,sol)\n",
    "    for i in range(out_mat.shape[0]):\n",
    "        plt.imshow(out_mat[i,:,:].cpu().numpy())\n",
    "        plt.colorbar()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_matrix_alignment(A,B):\n",
    "    Anorm = torch.linalg.norm(A, ord='fro')\n",
    "    Bnorm = torch.linalg.norm(B, ord='fro')\n",
    "    return torch.sum(A*B)/(Anorm*Bnorm)\n",
    "    \n",
    "def get_alignments_laplacian_M(dataset, hyperparams):\n",
    "    M = hyperparams['M']\n",
    "    L = hyperparams['L']\n",
    "    normalize = hyperparams['normalize']\n",
    "    reg = hyperparams['reg']\n",
    "    train_val_data = get_dataset_train_val_data(dataset)\n",
    "    X = train_val_data['X']\n",
    "    y = train_val_data['y']\n",
    "    train_fold = train_val_data['train_fold']\n",
    "    val_fold = train_val_data['val_fold']\n",
    "    c = train_val_data['c']\n",
    "    \n",
    "    sol = get_laplacian_M_sol(X[train_fold], y[train_fold], c, M, reg=reg, L=L, normalize=normalize, device=device)\n",
    "    agop = classic_kernel.laplacian_M_agop(torch.from_numpy(X[train_fold]).float().to(device), L, sol, M)\n",
    "    wagop = classic_kernel.clamped_laplacian_M_wagop(torch.from_numpy(X[train_fold]).float().to(device), L, sol, M)\n",
    "    print('AGOP alignment', get_matrix_alignment(M, agop))\n",
    "    print('sqrt(AGOP) alignment', get_matrix_alignment(M, utils.matrix_sqrt(agop)))\n",
    "    print('WAGOP alignment', get_matrix_alignment(M, wagop))\n",
    "    print('sqrt(WAGOP) alignment', get_matrix_alignment(M, utils.matrix_sqrt(wagop)))\n",
    "    plot_closeness_to_id(X[train_fold], sol)\n",
    "\n",
    "def get_alignments_laplacian_M_train_traj(dataset, hyperparams, train_traj):\n",
    "    M = hyperparams['M']\n",
    "    L = hyperparams['L']\n",
    "    normalize = hyperparams['normalize']\n",
    "    reg = hyperparams['reg']\n",
    "    train_val_data = get_dataset_train_val_data(dataset)\n",
    "    X = train_val_data['X']\n",
    "    y = train_val_data['y']\n",
    "    train_fold = train_val_data['train_fold']\n",
    "    val_fold = train_val_data['val_fold']\n",
    "    c = train_val_data['c']\n",
    "    \n",
    "    sol = get_laplacian_M_sol(X[train_fold], y[train_fold], c, M, reg=reg, L=L, normalize=normalize, device=device)\n",
    "    agop = classic_kernel.laplacian_M_agop(torch.from_numpy(X[train_fold]).float().to(device), L, sol, M)\n",
    "    wagop = classic_kernel.clamped_laplacian_M_wagop(torch.from_numpy(X[train_fold]).float().to(device), L, sol, M)\n",
    "    print('AGOP alignment', get_matrix_alignment(M, agop))\n",
    "    print('sqrt(AGOP) alignment', get_matrix_alignment(M, utils.matrix_sqrt(agop)))\n",
    "    print('WAGOP alignment', get_matrix_alignment(M, wagop))\n",
    "    print('sqrt(WAGOP) alignment', get_matrix_alignment(M, utils.matrix_sqrt(wagop)))\n",
    "    print()\n",
    "\n",
    "    alignments = {'agop_M' : [], 'sqrtagop_M' : [], 'wagop_M' : [], 'sqrtwagop_M' : [], 'agop_wagop' : [], 'sqrtagop_sqrtwagop' : []}\n",
    "    for sol, M in train_traj:\n",
    "        agop = classic_kernel.laplacian_M_agop(torch.from_numpy(X[train_fold]).float().to(device), L, sol, M)\n",
    "        wagop = classic_kernel.clamped_laplacian_M_wagop(torch.from_numpy(X[train_fold]).float().to(device), L, sol, M)\n",
    "        alignments['agop_M'].append(get_matrix_alignment(M, agop).item())\n",
    "        alignments['sqrtagop_M'].append(get_matrix_alignment(M, utils.matrix_sqrt(agop)).item())\n",
    "        alignments['wagop_M'].append(get_matrix_alignment(M, wagop).item())\n",
    "        alignments['sqrtwagop_M'].append(get_matrix_alignment(M, utils.matrix_sqrt(wagop)).item())\n",
    "        alignments['agop_wagop'].append(get_matrix_alignment(agop, wagop).item())\n",
    "        alignments['sqrtagop_sqrtwagop'].append(get_matrix_alignment(utils.matrix_sqrt(agop), utils.matrix_sqrt(wagop)).item())\n",
    "    for k in alignments.keys():\n",
    "        plt.plot(alignments[k], label=k)\n",
    "        plt.legend()\n",
    "    plt.show()\n",
    "        \n",
    "    # plot_closeness_to_id(X[train_fold], sol)\n",
    "\n",
    "\n",
    "# Compute sol on train\n",
    "\n",
    "# laplacian_M_agop(X_train, L, sol, M)\n",
    "\n",
    "import signal\n",
    "\n",
    "class timeout:\n",
    "    def __init__(self, seconds=1, error_message='Timeout'):\n",
    "        self.seconds = seconds\n",
    "        self.error_message = error_message\n",
    "    def handle_timeout(self, signum, frame):\n",
    "        raise TimeoutError(self.error_message)\n",
    "    def __enter__(self):\n",
    "        signal.signal(signal.SIGALRM, self.handle_timeout)\n",
    "        signal.alarm(self.seconds)\n",
    "    def __exit__(self, type, value, traceback):\n",
    "        signal.alarm(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting datasets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████| 122/122 [00:00<00:00, 4461.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found datasets: 120\n",
      "No cache found, creating new one\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "print('Collecting datasets')\n",
    "datasets = get_tabular_datasets() # n_tot_threshold=10000\n",
    "print('Found datasets:', len(datasets))\n",
    "\n",
    "\n",
    "cache_file = 'tabular_benchmarking_cache.pkl'\n",
    "if os.path.exists(cache_file):\n",
    "    print('Loading cache')\n",
    "    cached_results = pickle.load(open(cache_file, 'rb'))\n",
    "else:\n",
    "    print('No cache found, creating new one')\n",
    "    cached_results = {}\n",
    "    cached_results['fact_no_geom'] = {}\n",
    "    cached_results['fact_geom'] = {}\n",
    "    cached_results['agop'] = {}\n",
    "    cached_results['kernel'] = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Laplace sqrt(FACT FACT^T) updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wagop_rfm_laplace_hyperparam_select(X, y, train_fold, val_fold, c, device='cpu'):\n",
    "    \"\"\" Select hyperparameters for WAGOP_RFM via cross validation.\n",
    "    Note: pass train_fold and val_fold instead of X_train, y_train, X_val, y_val because do not want to overwrite accidentally\"\"\"\n",
    "    max_iter = 5\n",
    "    regs = [10, 1, .1, 1e-2, 1e-3]\n",
    "    normalize = [True, False]\n",
    "    L = 10\n",
    "\n",
    "    best_acc, best_reg, best_iter, best_M = -1, 0, 0, 0\n",
    "    best_normalize = False\n",
    "    best_train_traj = None\n",
    "    for reg in regs:\n",
    "        for n in normalize:\n",
    "            if dataset == 'balance-scale':\n",
    "                n = False\n",
    "            try:\n",
    "                acc, iter_v, M, train_traj = wagop_rfm_laplace.hyperparam_train(X[train_fold], y[train_fold], X[val_fold], y[val_fold], c,\n",
    "                                                                                iters=max_iter, reg=reg, L=L, normalize=n, device=device, return_train_traj=True)\n",
    "                if acc > best_acc:\n",
    "                    best_acc = acc\n",
    "                    best_reg = reg\n",
    "                    best_iter = iter_v\n",
    "                    best_M = M\n",
    "                    best_normalize = n\n",
    "                    best_train_traj = train_traj\n",
    "            except Exception as e:\n",
    "                print(e,'with reg', reg, 'normalize', n, 'dataset', dataset)\n",
    "                continue\n",
    "    return {'reg' : best_reg, 'iters' : best_iter, 'normalize' : best_normalize, 'M' : best_M, 'L' : L}, best_train_traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing datasets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                      | 0/120 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "abalone \tN: 4177 \td: 8 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|█                                                             | 2/120 [00:01<01:02,  1.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.6597222222222222\n",
      "acute-inflammation \tN: 120 \td: 6 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 1.0\n",
      "acute-nephritis \tN: 120 \td: 6 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|█▌                                                            | 3/120 [00:01<00:39,  2.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 1.0\n",
      "adult \tN: 48842 \td: 14 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|██▌                                                           | 5/120 [00:31<14:55,  7.78s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.8524262899262899\n",
      "annealing \tN: 898 \td: 31 \tc: 5\n",
      "Training\n",
      "iters 1\n",
      "acc 0.9396984924623115\n",
      "arrhythmia \tN: 452 \td: 262 \tc: 13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|███▌                                                          | 7/120 [00:31<06:39,  3.54s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 2\n",
      "acc 0.7676991150442478\n",
      "audiology-std \tN: 196 \td: 59 \tc: 18\n",
      "Training\n",
      "iters 4\n",
      "acc 0.9186046511627907\n",
      "balance-scale \tN: 625 \td: 4 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|████▋                                                         | 9/120 [00:32<03:11,  1.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.953525641025641\n",
      "balloons \tN: 16 \td: 4 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8125\n",
      "bank \tN: 4521 \td: 16 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|█████                                                        | 10/120 [00:32<02:38,  1.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 2\n",
      "acc 0.893141592920354\n",
      "blood \tN: 748 \td: 4 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|██████                                                       | 12/120 [00:33<01:22,  1.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7847593582887701\n",
      "breast-cancer \tN: 286 \td: 9 \tc: 2\n",
      "Training\n",
      "iters 1\n",
      "acc 0.7535211267605635\n",
      "breast-cancer-wisc \tN: 699 \td: 9 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|███████                                                      | 14/120 [00:33<00:47,  2.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9757142857142858\n",
      "breast-cancer-wisc-diag \tN: 569 \td: 30 \tc: 2\n",
      "Training\n",
      "iters 3\n",
      "acc 0.9841549295774648\n",
      "breast-cancer-wisc-prog \tN: 198 \td: 33 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████████▏                                                    | 16/120 [00:33<00:28,  3.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.8163265306122449\n",
      "breast-tissue \tN: 106 \td: 9 \tc: 6\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7307692307692308\n",
      "car \tN: 1728 \td: 6 \tc: 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|████████▋                                                    | 17/120 [00:34<00:28,  3.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 2\n",
      "acc 0.9878472222222221\n",
      "cardiotocography-10clases \tN: 2126 \td: 21 \tc: 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████████▏                                                   | 18/120 [00:34<00:29,  3.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.8540489642184558\n",
      "cardiotocography-3clases \tN: 2126 \td: 21 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█████████▋                                                   | 19/120 [00:34<00:30,  3.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 2\n",
      "acc 0.9430320150659134\n",
      "chess-krvk \tN: 28056 \td: 6 \tc: 18\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████████▏                                                  | 20/120 [00:56<11:18,  6.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.7918448816652409\n",
      "chess-krvkp \tN: 3196 \td: 36 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|███████████▏                                                 | 22/120 [00:57<05:40,  3.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.9893617021276596\n",
      "congressional-voting \tN: 435 \td: 16 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.6146788990825688\n",
      "conn-bench-sonar-mines-rocks \tN: 208 \td: 60 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|████████████▏                                                | 24/120 [00:57<02:49,  1.77s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.8413461538461539\n",
      "conn-bench-vowel-deterding \tN: 990 \td: 11 \tc: 11\n",
      "Training\n",
      "iters 0\n",
      "acc 0.9791666666666666\n",
      "connect-4 \tN: 67557 \td: 42 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 21%|████████████▎                                              | 25/120 [03:50<1:24:22, 53.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.8725057729883356\n",
      "contrac \tN: 1473 \td: 9 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|█████████████▋                                               | 27/120 [03:51<40:37, 26.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.563858695652174\n",
      "credit-approval \tN: 690 \td: 15 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8735465116279069\n",
      "cylinder-bands \tN: 512 \td: 35 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██████████████▋                                              | 29/120 [03:51<19:34, 12.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7734375\n",
      "dermatology \tN: 366 \td: 34 \tc: 6\n",
      "Training\n",
      "iters 2\n",
      "acc 0.9835164835164836\n",
      "echocardiogram \tN: 131 \td: 10 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|███████████████▊                                             | 31/120 [03:51<09:28,  6.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.8484848484848485\n",
      "ecoli \tN: 336 \td: 7 \tc: 8\n",
      "Training\n",
      "iters 0\n",
      "acc 0.875\n",
      "energy-y1 \tN: 768 \td: 8 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|████████████████▊                                            | 33/120 [03:52<04:39,  3.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9609375\n",
      "energy-y2 \tN: 768 \td: 8 \tc: 3\n",
      "Training\n",
      "iters 1\n",
      "acc 0.9388020833333334\n",
      "fertility \tN: 100 \td: 9 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 29%|█████████████████▊                                           | 35/120 [03:52<02:18,  1.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9199999999999999\n",
      "flags \tN: 194 \td: 28 \tc: 8\n",
      "Training\n",
      "iters 1\n",
      "acc 0.6145833333333334\n",
      "glass \tN: 214 \td: 9 \tc: 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 31%|██████████████████▊                                          | 37/120 [03:52<01:11,  1.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7122641509433962\n",
      "haberman-survival \tN: 306 \td: 3 \tc: 2\n",
      "Training\n",
      "iters 1\n",
      "acc 0.7434210526315789\n",
      "hayes-roth \tN: 160 \td: 3 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███████████████████▊                                         | 39/120 [03:52<00:38,  2.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 2\n",
      "acc 0.8181818181818182\n",
      "heart-cleveland \tN: 303 \td: 13 \tc: 5\n",
      "Training\n",
      "iters 3\n",
      "acc 0.6348684210526316\n",
      "heart-hungarian \tN: 294 \td: 12 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 34%|████████████████████▊                                        | 41/120 [03:53<00:22,  3.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.8493150684931506\n",
      "heart-switzerland \tN: 123 \td: 12 \tc: 5\n",
      "Training\n",
      "iters 4\n",
      "acc 0.5\n",
      "heart-va \tN: 200 \td: 12 \tc: 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|█████████████████████▊                                       | 43/120 [03:53<00:15,  5.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 4\n",
      "acc 0.37000000000000005\n",
      "hepatitis \tN: 155 \td: 19 \tc: 2\n",
      "Training\n",
      "iters 1\n",
      "acc 0.8397435897435898\n",
      "hill-valley \tN: 1212 \td: 100 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|██████████████████████▉                                      | 45/120 [03:53<00:12,  5.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.8642384105960265\n",
      "horse-colic \tN: 368 \td: 25 \tc: 2\n",
      "Training\n",
      "iters 2\n",
      "acc 0.8266666666666667\n",
      "ilpd-indian-liver \tN: 583 \td: 9 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 39%|███████████████████████▉                                     | 47/120 [03:53<00:10,  6.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7363013698630136\n",
      "image-segmentation \tN: 2310 \td: 18 \tc: 7\n",
      "Training\n",
      "iters 2\n",
      "acc 0.9230769230769231\n",
      "ionosphere \tN: 351 \td: 33 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 41%|████████████████████████▉                                    | 49/120 [03:54<00:09,  7.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9517045454545455\n",
      "iris \tN: 150 \td: 4 \tc: 3\n",
      "Training\n",
      "iters 1\n",
      "acc 0.9729729729729729\n",
      "led-display \tN: 1000 \td: 7 \tc: 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|█████████████████████████▍                                   | 50/120 [03:54<00:10,  6.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7469999999999999\n",
      "lenses \tN: 24 \td: 4 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8333333333333333\n",
      "letter \tN: 20000 \td: 16 \tc: 26\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|██████████████████████████▉                                  | 53/120 [04:05<02:17,  2.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.97445\n",
      "libras \tN: 360 \td: 90 \tc: 15\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8305555555555556\n",
      "low-res-spect \tN: 531 \td: 100 \tc: 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|███████████████████████████▉                                 | 55/120 [04:05<01:15,  1.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 4\n",
      "acc 0.9191729323308271\n",
      "lung-cancer \tN: 32 \td: 56 \tc: 3\n",
      "Training\n",
      "iters 1\n",
      "acc 0.84375\n",
      "lymphography \tN: 148 \td: 18 \tc: 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|████████████████████████████▍                                | 56/120 [04:06<00:55,  1.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9121621621621622\n",
      "magic \tN: 19020 \td: 10 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|█████████████████████████████▍                               | 58/120 [04:16<02:36,  2.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.8668243953732913\n",
      "mammographic \tN: 961 \td: 5 \tc: 2\n",
      "Training\n",
      "iters 1\n",
      "acc 0.83125\n",
      "molec-biol-promoter \tN: 106 \td: 57 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 49%|█████████████████████████████▉                               | 59/120 [04:16<01:50,  1.81s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9807692307692308\n",
      "molec-biol-splice \tN: 3190 \td: 60 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 51%|███████████████████████████████                              | 61/120 [04:17<01:02,  1.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 4\n",
      "acc 0.8826850690087829\n",
      "monks-1 \tN: 556 \td: 6 \tc: 2\n",
      "Training\n",
      "iters 2\n",
      "acc 0.9435483870967742\n",
      "monks-2 \tN: 601 \td: 6 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 52%|████████████████████████████████                             | 63/120 [04:17<00:32,  1.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.7976190476190476\n",
      "monks-3 \tN: 554 \td: 6 \tc: 2\n",
      "Training\n",
      "iters 3\n",
      "acc 0.9416666666666667\n",
      "mushroom \tN: 8124 \td: 21 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████████████████████████████████                            | 65/120 [04:19<00:43,  1.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9997538158542589\n",
      "musk-1 \tN: 476 \td: 166 \tc: 2\n",
      "Training\n",
      "iters 1\n",
      "acc 0.9180672268907563\n",
      "musk-2 \tN: 6598 \td: 166 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 55%|█████████████████████████████████▌                           | 66/120 [04:21<01:02,  1.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 4\n",
      "acc 0.992116434202547\n",
      "nursery \tN: 12960 \td: 8 \tc: 5\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|██████████████████████████████████                           | 67/120 [04:26<02:00,  2.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9914351851851851\n",
      "oocytes_merluccius_nucleus_4d \tN: 1022 \td: 41 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 57%|███████████████████████████████████                          | 69/120 [04:26<01:02,  1.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.8333333333333333\n",
      "oocytes_merluccius_states_2f \tN: 1022 \td: 25 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.9284313725490195\n",
      "oocytes_trisopterus_nucleus_2f \tN: 912 \td: 25 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 59%|████████████████████████████████████                         | 71/120 [04:27<00:33,  1.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.8344298245614035\n",
      "oocytes_trisopterus_states_5b \tN: 912 \td: 32 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.9385964912280702\n",
      "optical \tN: 5620 \td: 62 \tc: 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|████████████████████████████████████▌                        | 72/120 [04:28<00:34,  1.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9877092050209205\n",
      "ozone \tN: 2536 \td: 72 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 61%|█████████████████████████████████████                        | 73/120 [04:28<00:29,  1.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9716088328075709\n",
      "page-blocks \tN: 5473 \td: 10 \tc: 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|█████████████████████████████████████▌                       | 74/120 [04:29<00:35,  1.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9649122807017545\n",
      "parkinsons \tN: 195 \td: 22 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████████████████████████████████████▏                      | 75/120 [04:29<00:26,  1.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.9387755102040816\n",
      "pendigits \tN: 10992 \td: 16 \tc: 10\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|███████████████████████████████████████▏                     | 77/120 [04:31<00:31,  1.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9966631073144688\n",
      "pima \tN: 768 \td: 8 \tc: 2\n",
      "Training\n",
      "iters 2\n",
      "acc 0.7708333333333334\n",
      "pittsburg-bridges-MATERIAL \tN: 106 \td: 7 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|████████████████████████████████████████▏                    | 79/120 [04:32<00:16,  2.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9326923076923077\n",
      "pittsburg-bridges-REL-L \tN: 103 \td: 7 \tc: 3\n",
      "Training\n",
      "iters 1\n",
      "acc 0.7788461538461537\n",
      "pittsburg-bridges-SPAN \tN: 92 \td: 7 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 68%|█████████████████████████████████████████▏                   | 81/120 [04:32<00:09,  3.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.6739130434782608\n",
      "pittsburg-bridges-T-OR-D \tN: 102 \td: 7 \tc: 2\n",
      "Training\n",
      "iters 3\n",
      "acc 0.89\n",
      "pittsburg-bridges-TYPE \tN: 105 \td: 7 \tc: 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████████████████████████████████████████▏                  | 83/120 [04:32<00:06,  5.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7019230769230769\n",
      "planning \tN: 182 \td: 12 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7111111111111111\n",
      "plant-margin \tN: 1600 \td: 64 \tc: 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|██████████████████████████████████████████▋                  | 84/120 [04:32<00:07,  4.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.85\n",
      "plant-shape \tN: 1600 \td: 64 \tc: 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 71%|███████████████████████████████████████████▏                 | 85/120 [04:33<00:08,  4.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7256250000000001\n",
      "plant-texture \tN: 1599 \td: 64 \tc: 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 72%|████████████████████████████████████████████▏                | 87/120 [04:33<00:07,  4.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.8443750000000001\n",
      "post-operative \tN: 90 \td: 8 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7272727272727273\n",
      "primary-tumor \tN: 330 \td: 17 \tc: 15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 73%|████████████████████████████████████████████▋                | 88/120 [04:33<00:05,  5.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.5457317073170732\n",
      "ringnorm \tN: 7400 \td: 20 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 75%|█████████████████████████████████████████████▊               | 90/120 [04:35<00:15,  1.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9843243243243243\n",
      "seeds \tN: 210 \td: 7 \tc: 3\n",
      "Training\n",
      "iters 2\n",
      "acc 0.9278846153846154\n",
      "semeion \tN: 1593 \td: 256 \tc: 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 77%|██████████████████████████████████████████████▊              | 92/120 [04:36<00:10,  2.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9667085427135678\n",
      "soybean \tN: 683 \td: 35 \tc: 18\n",
      "Training\n",
      "iters 2\n",
      "acc 0.9253246753246753\n",
      "spambase \tN: 4601 \td: 57 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████████████████████████████████████████████▊             | 94/120 [04:37<00:10,  2.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9478260869565217\n",
      "spect \tN: 265 \td: 22 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7250000000000001\n",
      "spectf \tN: 267 \td: 44 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████████████████████████████████████████████▊            | 96/120 [04:37<00:06,  3.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.85\n",
      "statlog-australian-credit \tN: 690 \td: 14 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.6802325581395349\n",
      "statlog-german-credit \tN: 1000 \td: 24 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|█████████████████████████████████████████████████▊           | 98/120 [04:37<00:04,  4.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.804\n",
      "statlog-heart \tN: 270 \td: 13 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8917910447761194\n",
      "statlog-image \tN: 2310 \td: 18 \tc: 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|██████████████████████████████████████████████████▎          | 99/120 [04:38<00:05,  3.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 3\n",
      "acc 0.983102253032929\n",
      "statlog-landsat \tN: 6435 \td: 36 \tc: 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 83%|██████████████████████████████████████████████████          | 100/120 [04:38<00:08,  2.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9206492335437331\n",
      "statlog-shuttle \tN: 58000 \td: 9 \tc: 7\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 85%|███████████████████████████████████████████████████         | 102/120 [05:37<03:46, 12.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 1\n",
      "acc 0.9990574712643678\n",
      "statlog-vehicle \tN: 846 \td: 18 \tc: 4\n",
      "Training\n",
      "iters 4\n",
      "acc 0.8009478672985781\n",
      "steel-plates \tN: 1941 \td: 27 \tc: 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 87%|████████████████████████████████████████████████████        | 104/120 [05:38<01:40,  6.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.7860824742268041\n",
      "synthetic-control \tN: 600 \td: 60 \tc: 6\n",
      "Training\n",
      "iters 1\n",
      "acc 0.9966666666666666\n",
      "teaching \tN: 151 \td: 5 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|████████████████████████████████████████████████████▌       | 105/120 [05:38<01:06,  4.43s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.611842105263158\n",
      "thyroid \tN: 7200 \td: 21 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|█████████████████████████████████████████████████████       | 106/120 [05:38<00:46,  3.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9790562036055144\n",
      "tic-tac-toe \tN: 958 \td: 9 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|█████████████████████████████████████████████████████▌      | 107/120 [05:39<00:30,  2.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9769874476987448\n",
      "titanic \tN: 2201 \td: 3 \tc: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|██████████████████████████████████████████████████████      | 108/120 [05:39<00:20,  1.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.7881818181818182\n",
      "trains \tN: 10 \td: 29 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.875\n",
      "twonorm \tN: 7400 \td: 20 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|███████████████████████████████████████████████████████▌    | 111/120 [05:41<00:09,  1.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9787837837837838\n",
      "vertebral-column-2clases \tN: 310 \td: 6 \tc: 2\n",
      "Training\n",
      "iters 1\n",
      "acc 0.8441558441558441\n",
      "vertebral-column-3clases \tN: 310 \td: 6 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 93%|████████████████████████████████████████████████████████    | 112/120 [05:41<00:06,  1.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 2\n",
      "acc 0.8441558441558442\n",
      "wall-following \tN: 5456 \td: 24 \tc: 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|████████████████████████████████████████████████████████▌   | 113/120 [05:42<00:06,  1.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.9285190615835777\n",
      "waveform \tN: 5000 \td: 21 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████████████████████████████████████████████████████   | 114/120 [05:43<00:05,  1.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.8696\n",
      "waveform-noise \tN: 5000 \td: 40 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████████████████████████████████████████████████████▌  | 115/120 [05:44<00:04,  1.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 1\n",
      "acc 0.8734000000000001\n",
      "wine \tN: 178 \td: 13 \tc: 3\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 97%|██████████████████████████████████████████████████████████  | 116/120 [05:44<00:02,  1.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9886363636363636\n",
      "wine-quality-red \tN: 1599 \td: 11 \tc: 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|██████████████████████████████████████████████████████████▌ | 117/120 [05:45<00:01,  1.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.6775\n",
      "wine-quality-white \tN: 4898 \td: 11 \tc: 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|███████████████████████████████████████████████████████████ | 118/120 [05:46<00:01,  1.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.6817810457516339\n",
      "yeast \tN: 1484 \td: 8 \tc: 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████| 120/120 [05:46<00:00,  2.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.6138814016172507\n",
      "zoo \tN: 101 \td: 16 \tc: 7\n",
      "Training\n",
      "iters 0\n",
      "acc 0.98\n",
      "avg_acc: 85.22111606985257\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate this method on every dataset, reusing cached accuracies so that\n",
    "# re-running the cell only trains on datasets that have no stored result yet.\n",
    "# NOTE(review): the 'fact_no_geom' run uses the wagop_rfm_laplace selector and\n",
    "# trainer rather than fact_rfm_laplace -- confirm this pairing is intended.\n",
    "method_type = 'fact_no_geom'\n",
    "acc_list = []\n",
    "print('Testing datasets')\n",
    "for dataset in tqdm(datasets):\n",
    "    method_cache = cached_results[method_type]\n",
    "    if dataset not in method_cache:\n",
    "        print_dataset_info(dataset)\n",
    "        hyperparams, train_traj = hyperparam_select_dataset(\n",
    "            dataset, wagop_rfm_laplace_hyperparam_select, device=device)\n",
    "        acc = test_dataset(dataset, hyperparams, wagop_rfm_laplace.train, device=device)\n",
    "        method_cache[dataset] = acc\n",
    "        # Persist after every newly trained dataset so an interrupted run keeps earlier results.\n",
    "        with open(cache_file, 'wb') as f:\n",
    "            pickle.dump(cached_results, f)\n",
    "        print('iters', hyperparams['iters'])\n",
    "        # Not used in this cell; left as a notebook global (like hyperparams/train_traj).\n",
    "        M = hyperparams['M']\n",
    "    acc = method_cache[dataset]\n",
    "    acc_list.append(acc)\n",
    "    print('acc', acc)\n",
    "\n",
    "    # Optional alignment analysis (needs hyperparams/train_traj from a fresh, uncached run):\n",
    "    # get_alignments_laplacian_M_train_traj(dataset, hyperparams, train_traj)\n",
    "    # get_alignments_laplacian_M(dataset, hyperparams)\n",
    "\n",
    "# Mean accuracy over all datasets, as a percentage.\n",
    "print(\"avg_acc:\", np.mean(acc_list) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Laplace $\\left((\\mathrm{FACT}\\, M^\\top)(M\\, \\mathrm{FACT}^\\top)\\right)^{1/4}$ updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fact_rfm_laplace_hyperparam_select(X, y, train_fold, val_fold, c, device='cpu'):\n",
    "    \"\"\"Select hyperparameters for the Laplace FACT-update RFM (`fact_rfm_laplace`) on a validation split.\n",
    "\n",
    "    Grid-searches the regularization `reg` and whether to normalize, training with\n",
    "    `fact_rfm_laplace.hyperparam_train` on X[train_fold] and scoring on X[val_fold].\n",
    "    Note: pass train_fold and val_fold instead of X_train, y_train, X_val, y_val because do not want to overwrite accidentally.\n",
    "\n",
    "    Returns:\n",
    "        (hyperparams, train_traj): `hyperparams` is a dict with keys 'reg', 'iters',\n",
    "        'normalize', 'M', 'L' for the configuration with the highest validation\n",
    "        accuracy; `train_traj` is that run's training trajectory. If every\n",
    "        configuration fails, reg/iters/M are 0, normalize is False and train_traj is None.\n",
    "\n",
    "    NOTE(review): reads the notebook-global `dataset` (the caller's loop variable)\n",
    "    to special-case 'balance-scale'; `timeout` is also defined elsewhere in the notebook.\n",
    "    \"\"\"\n",
    "    max_iter = 5\n",
    "    regs = [10, 1, .1, 1e-2, 1e-3]\n",
    "    normalize = [True, False]\n",
    "    L = 10\n",
    "    # balance-scale is only evaluated without normalization. Filtering the options up\n",
    "    # front avoids training the identical (reg, normalize=False) configuration twice.\n",
    "    normalize_options = [False] if dataset == 'balance-scale' else normalize\n",
    "\n",
    "    best_acc, best_reg, best_iter, best_M = -1, 0, 0, 0\n",
    "    best_normalize = False\n",
    "    best_train_traj = None\n",
    "    for reg in regs:\n",
    "        for n in normalize_options:\n",
    "            print('reg', reg, 'normalize', n)\n",
    "            try:\n",
    "                # Each configuration runs under the notebook's timeout helper (240 s).\n",
    "                with timeout(seconds=240):\n",
    "                    acc, iter_v, M, train_traj = fact_rfm_laplace.hyperparam_train(X[train_fold], y[train_fold], X[val_fold], y[val_fold], c,\n",
    "                                                                                   iters=max_iter, reg=reg, L=L, normalize=n, device=device, return_train_traj=True)\n",
    "                # Strict '>' keeps the first configuration tried on ties.\n",
    "                if acc > best_acc:\n",
    "                    best_acc = acc\n",
    "                    best_reg = reg\n",
    "                    best_iter = iter_v\n",
    "                    best_M = M\n",
    "                    best_normalize = n\n",
    "                    best_train_traj = train_traj\n",
    "            except Exception as e:\n",
    "                # A failing configuration (presumably including timeouts) is logged and skipped.\n",
    "                print(e,'with reg', reg, 'normalize', n, 'dataset', dataset)\n",
    "    return {'reg' : best_reg, 'iters' : best_iter, 'normalize' : best_normalize, 'M' : best_M, 'L' : L}, best_train_traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing datasets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████| 120/120 [00:00<00:00, 111897.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc 0.6597222222222222\n",
      "acc 1.0\n",
      "acc 1.0\n",
      "acc 0.8527334152334153\n",
      "acc 0.9447236180904522\n",
      "acc 0.7654867256637168\n",
      "acc 0.8953488372093023\n",
      "acc 0.9374999999999999\n",
      "acc 0.8125\n",
      "acc 0.8995575221238938\n",
      "acc 0.7847593582887701\n",
      "acc 0.7253521126760563\n",
      "acc 0.9728571428571429\n",
      "acc 0.9806338028169015\n",
      "acc 0.8163265306122449\n",
      "acc 0.7307692307692308\n",
      "acc 0.9907407407407407\n",
      "acc 0.8648775894538607\n",
      "acc 0.9416195856873824\n",
      "acc 0.8218562874251497\n",
      "acc 0.9943679599499373\n",
      "acc 0.6146788990825688\n",
      "acc 0.8894230769230769\n",
      "acc 0.9791666666666666\n",
      "acc 0.8881816566996269\n",
      "acc 0.5081521739130435\n",
      "acc 0.8735465116279069\n",
      "acc 0.7734375\n",
      "acc 0.9780219780219781\n",
      "acc 0.8484848484848485\n",
      "acc 0.875\n",
      "acc 0.9609375\n",
      "acc 0.9296875\n",
      "acc 0.89\n",
      "acc 0.5729166666666666\n",
      "acc 0.7122641509433962\n",
      "acc 0.7368421052631579\n",
      "acc 0.8409090909090909\n",
      "acc 0.5723684210526315\n",
      "acc 0.8561643835616438\n",
      "acc 0.4596774193548387\n",
      "acc 0.37\n",
      "acc 0.8333333333333334\n",
      "acc 0.8642384105960265\n",
      "acc 0.83\n",
      "acc 0.7123287671232876\n",
      "acc 0.9326923076923077\n",
      "acc 0.9517045454545455\n",
      "acc 0.9527027027027027\n",
      "acc 0.742\n",
      "acc 0.8333333333333333\n",
      "acc 0.97445\n",
      "acc 0.8305555555555556\n",
      "acc 0.9210526315789473\n",
      "acc 0.75\n",
      "acc 0.8851351351351352\n",
      "acc 0.8696109358569927\n",
      "acc 0.8333333333333333\n",
      "acc 0.9230769230769231\n",
      "acc 0.9058971141781682\n",
      "acc 0.9596774193548387\n",
      "acc 0.7738095238095237\n",
      "acc 0.925\n",
      "acc 0.9997538158542589\n",
      "acc 0.9390756302521008\n",
      "acc 0.9946937537901759\n",
      "acc 0.9962962962962962\n",
      "acc 0.8333333333333333\n",
      "acc 0.934313725490196\n",
      "acc 0.8366228070175439\n",
      "acc 0.9375\n",
      "acc 0.9895397489539749\n",
      "acc 0.9716088328075709\n",
      "acc 0.9685672514619883\n",
      "acc 0.9489795918367346\n",
      "acc 0.9961292044847838\n",
      "acc 0.7721354166666666\n",
      "acc 0.9326923076923077\n",
      "acc 0.7307692307692308\n",
      "acc 0.6956521739130436\n",
      "acc 0.89\n",
      "acc 0.7019230769230769\n",
      "acc 0.7111111111111111\n",
      "acc 0.8724999999999999\n",
      "acc 0.75125\n",
      "acc 0.8525\n",
      "acc 0.7272727272727273\n",
      "acc 0.5548780487804879\n",
      "acc 0.9843243243243243\n",
      "acc 0.9086538461538461\n",
      "acc 0.9736180904522613\n",
      "acc 0.9318181818181819\n",
      "acc 0.9539130434782609\n",
      "acc 0.7250000000000001\n",
      "acc 0.725\n",
      "acc 0.6802325581395349\n",
      "acc 0.8019999999999999\n",
      "acc 0.8917910447761194\n",
      "acc 0.9748700173310225\n",
      "acc 0.9206492335437331\n",
      "acc 0.9991264367816092\n",
      "acc 0.8151658767772512\n",
      "acc 0.788659793814433\n",
      "acc 1.0\n",
      "acc 0.618421052631579\n",
      "acc 0.9779957582184517\n",
      "acc 0.9759414225941423\n",
      "acc 0.7881818181818182\n",
      "acc 1.0\n",
      "acc 0.9787837837837838\n",
      "acc 0.8538961038961039\n",
      "acc 0.8376623376623377\n",
      "acc 0.9332844574780058\n",
      "acc 0.8654000000000001\n",
      "acc 0.8672\n",
      "acc 0.9886363636363636\n",
      "acc 0.6775\n",
      "acc 0.6817810457516339\n",
      "acc 0.6206199460916442\n",
      "acc 0.98\n",
      "avg_acc: 84.98729152094214\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the FACT-update Laplace RFM ('fact_geom') on every dataset, reusing cached\n",
    "# accuracies so that re-running the cell only trains on datasets without a stored result.\n",
    "method_type = 'fact_geom'\n",
    "acc_list = []\n",
    "print('Testing datasets')\n",
    "for dataset in tqdm(datasets):\n",
    "    method_cache = cached_results[method_type]\n",
    "    if dataset not in method_cache:\n",
    "        print_dataset_info(dataset)\n",
    "        hyperparams, train_traj = hyperparam_select_dataset(\n",
    "            dataset, fact_rfm_laplace_hyperparam_select, device=device)\n",
    "        acc = test_dataset(dataset, hyperparams, fact_rfm_laplace.train, device=device)\n",
    "        method_cache[dataset] = acc\n",
    "        # Persist after every newly trained dataset so an interrupted run keeps earlier results.\n",
    "        with open(cache_file, 'wb') as f:\n",
    "            pickle.dump(cached_results, f)\n",
    "        print('iters', hyperparams['iters'])\n",
    "        # Not used in this cell; left as a notebook global (like hyperparams/train_traj).\n",
    "        M = hyperparams['M']\n",
    "    acc = method_cache[dataset]\n",
    "    acc_list.append(acc)\n",
    "    print('acc', acc)\n",
    "\n",
    "# Mean accuracy over all datasets, as a percentage.\n",
    "print(\"avg_acc:\", np.mean(acc_list) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Laplace AGOP updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rfm_hyperparam_select(X, y, train_fold, val_fold, c, device='cpu'):\n",
    "    \"\"\"Select hyperparameters for the Laplace AGOP RFM (`rfm.hyperparam_train_gpu`) on a validation split.\n",
    "\n",
    "    Grid-searches the regularization `reg` and whether to normalize, training on\n",
    "    X[train_fold] and scoring on X[val_fold].\n",
    "    Note: pass train_fold and val_fold instead of X_train, y_train, X_val, y_val because do not want to overwrite accidentally.\n",
    "\n",
    "    Returns:\n",
    "        (hyperparams, train_traj): `hyperparams` is a dict with keys 'reg', 'iters',\n",
    "        'normalize', 'M', 'L' for the configuration with the highest validation\n",
    "        accuracy; `train_traj` is that run's training trajectory. If every\n",
    "        configuration fails, reg/iters/M are 0, normalize is False and train_traj is None.\n",
    "\n",
    "    NOTE(review): reads the notebook-global `dataset` (the caller's loop variable)\n",
    "    to special-case 'balance-scale'; it is not a parameter.\n",
    "    \"\"\"\n",
    "    max_iter = 5\n",
    "    regs = [10, 1, .1, 1e-2, 1e-3]\n",
    "    normalize = [True, False]\n",
    "    L = 10\n",
    "    # balance-scale is only evaluated without normalization. Filtering the options up\n",
    "    # front avoids training the identical (reg, normalize=False) configuration twice.\n",
    "    normalize_options = [False] if dataset == 'balance-scale' else normalize\n",
    "\n",
    "    best_acc, best_reg, best_iter, best_M = -1, 0, 0, 0\n",
    "    best_train_traj = None\n",
    "    best_normalize = False\n",
    "    for reg in regs:\n",
    "        for n in normalize_options:\n",
    "            try:\n",
    "                acc, iter_v, M, train_traj = rfm.hyperparam_train_gpu(X[train_fold], y[train_fold], X[val_fold], y[val_fold], c,\n",
    "                                                                       iters=max_iter, reg=reg, L=L, normalize=n, device=device, return_train_traj=True)\n",
    "                # Strict '>' keeps the first configuration tried on ties.\n",
    "                if acc > best_acc:\n",
    "                    best_acc = acc\n",
    "                    best_reg = reg\n",
    "                    best_iter = iter_v\n",
    "                    best_M = M\n",
    "                    best_normalize = n\n",
    "                    best_train_traj = train_traj\n",
    "            except Exception as e:\n",
    "                # A failing configuration is logged and skipped.\n",
    "                print(e,'with reg', reg, 'normalize', n, 'dataset', dataset)\n",
    "    return {'reg' : best_reg, 'iters' : best_iter, 'normalize' : best_normalize, 'M' : best_M, 'L' : L}, best_train_traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing datasets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████| 120/120 [00:00<00:00, 105142.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc 0.6597222222222222\n",
      "acc 1.0\n",
      "acc 1.0\n",
      "acc 0.8516891891891891\n",
      "acc 0.9472361809045227\n",
      "acc 0.7676991150442478\n",
      "acc 0.9011627906976744\n",
      "acc 0.9727564102564104\n",
      "acc 0.8125\n",
      "acc 0.9019911504424779\n",
      "acc 0.7941176470588236\n",
      "acc 0.7253521126760563\n",
      "acc 0.9742857142857143\n",
      "acc 0.9788732394366197\n",
      "acc 0.8163265306122449\n",
      "acc 0.7307692307692308\n",
      "acc 0.9855324074074074\n",
      "acc 0.8742937853107345\n",
      "acc 0.9364406779661018\n",
      "acc 0.8089891645280867\n",
      "acc 0.9931163954943679\n",
      "acc 0.6146788990825688\n",
      "acc 0.8605769230769231\n",
      "acc 0.9791666666666666\n",
      "acc 0.8789744804310498\n",
      "acc 0.5550271739130435\n",
      "acc 0.8880813953488371\n",
      "acc 0.7734375\n",
      "acc 1.0\n",
      "acc 0.8484848484848485\n",
      "acc 0.875\n",
      "acc 0.9661458333333334\n",
      "acc 0.9270833333333334\n",
      "acc 0.89\n",
      "acc 0.6197916666666666\n",
      "acc 0.7122641509433962\n",
      "acc 0.7368421052631579\n",
      "acc 0.8409090909090909\n",
      "acc 0.6151315789473684\n",
      "acc 0.8561643835616438\n",
      "acc 0.45161290322580644\n",
      "acc 0.385\n",
      "acc 0.8333333333333334\n",
      "acc 0.8013245033112583\n",
      "acc 0.8266666666666667\n",
      "acc 0.714041095890411\n",
      "acc 0.9182692307692308\n",
      "acc 0.9517045454545455\n",
      "acc 0.9797297297297297\n",
      "acc 0.7469999999999999\n",
      "acc 0.8333333333333333\n",
      "acc 0.98175\n",
      "acc 0.8305555555555556\n",
      "acc 0.9248120300751879\n",
      "acc 0.65625\n",
      "acc 0.8783783783783785\n",
      "acc 0.87602523659306\n",
      "acc 0.8104166666666667\n",
      "acc 0.9807692307692308\n",
      "acc 0.8801756587202008\n",
      "acc 0.9838709677419355\n",
      "acc 0.7678571428571429\n",
      "acc 0.9333333333333333\n",
      "acc 0.9997538158542589\n",
      "acc 0.8886554621848739\n",
      "acc 0.9934808975136447\n",
      "acc 0.9939043209876542\n",
      "acc 0.8235294117647058\n",
      "acc 0.9333333333333332\n",
      "acc 0.8245614035087719\n",
      "acc 0.9385964912280702\n",
      "acc 0.9877092050209205\n",
      "acc 0.9716088328075709\n",
      "acc 0.970577485380117\n",
      "acc 0.9285714285714285\n",
      "acc 0.9974639615589962\n",
      "acc 0.7799479166666666\n",
      "acc 0.9326923076923077\n",
      "acc 0.7403846153846154\n",
      "acc 0.6847826086956522\n",
      "acc 0.9\n",
      "acc 0.7019230769230769\n",
      "acc 0.7111111111111111\n",
      "acc 0.890625\n",
      "acc 0.74625\n",
      "acc 0.884375\n",
      "acc 0.7272727272727273\n",
      "acc 0.5457317073170732\n",
      "acc 0.9843243243243243\n",
      "acc 0.9471153846153847\n",
      "acc 0.9673366834170855\n",
      "acc 0.9318181818181819\n",
      "acc 0.9515217391304348\n",
      "acc 0.7250000000000001\n",
      "acc 0.875\n",
      "acc 0.6802325581395349\n",
      "acc 0.7890000000000001\n",
      "acc 0.8917910447761194\n",
      "acc 0.9826689774696707\n",
      "acc 0.9206492335437331\n",
      "acc 0.9995402298850574\n",
      "acc 0.8518957345971564\n",
      "acc 0.790721649484536\n",
      "acc 0.9966666666666666\n",
      "acc 0.611842105263158\n",
      "acc 0.9848886532343585\n",
      "acc 0.9769874476987448\n",
      "acc 0.7881818181818182\n",
      "acc 0.875\n",
      "acc 0.9787837837837838\n",
      "acc 0.8474025974025974\n",
      "acc 0.8376623376623378\n",
      "acc 0.9375\n",
      "acc 0.8712\n",
      "acc 0.8728\n",
      "acc 0.9886363636363636\n",
      "acc 0.6775\n",
      "acc 0.6817810457516339\n",
      "acc 0.6138814016172507\n",
      "acc 0.98\n",
      "avg_acc: 85.10249634962214\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the Laplace AGOP RFM ('agop') on every dataset, reusing cached accuracies\n",
    "# so that re-running the cell only trains on datasets without a stored result.\n",
    "method_type = 'agop'\n",
    "acc_list = []\n",
    "print('Testing datasets')\n",
    "for dataset in tqdm(datasets):\n",
    "    method_cache = cached_results[method_type]\n",
    "    if dataset not in method_cache:\n",
    "        print_dataset_info(dataset)\n",
    "        hyperparams, train_traj = hyperparam_select_dataset(\n",
    "            dataset, rfm_hyperparam_select, device=device)\n",
    "        acc = test_dataset(dataset, hyperparams, rfm.train_gpu, device=device)\n",
    "        method_cache[dataset] = acc\n",
    "        # Persist after every newly trained dataset so an interrupted run keeps earlier results.\n",
    "        with open(cache_file, 'wb') as f:\n",
    "            pickle.dump(cached_results, f)\n",
    "        print('iters', hyperparams['iters'])\n",
    "        # Not used in this cell; left as a notebook global (like hyperparams/train_traj).\n",
    "        M = hyperparams['M']\n",
    "    acc = method_cache[dataset]\n",
    "    acc_list.append(acc)\n",
    "    print('acc', acc)\n",
    "\n",
    "# Mean accuracy over all datasets, as a percentage.\n",
    "print(\"avg_acc:\", np.mean(acc_list) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Laplace (no adaptivity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nonadaptive_rfm_laplace_hyperparam_select(X, y, train_fold, val_fold, c, device='cpu', dataset_name=None):\n",
    "    \"\"\" Select hyperparameters for the non-adaptive Laplace kernel baseline via validation.\n",
    "\n",
    "    Calls wagop_rfm_laplace.hyperparam_train with max_iter = 1 for every combination of\n",
    "    regularization strength and normalization, and keeps the setting with the highest\n",
    "    validation accuracy.\n",
    "    Note: takes train_fold and val_fold instead of X_train, y_train, X_val, y_val so that\n",
    "    the underlying data is not overwritten accidentally.\n",
    "\n",
    "    Args:\n",
    "        X, y: full feature / label arrays; rows are selected with train_fold and val_fold.\n",
    "        train_fold, val_fold: indexers selecting the training and validation rows.\n",
    "        c: forwarded unchanged to wagop_rfm_laplace.hyperparam_train.\n",
    "        device: device forwarded to training.\n",
    "        dataset_name: dataset name used for the 'balance-scale' special case and in error\n",
    "            messages. Defaults to the notebook-global loop variable `dataset`.\n",
    "\n",
    "    Returns:\n",
    "        (hyperparams, best_train_traj): hyperparams is a dict with keys 'reg', 'iters',\n",
    "        'normalize', 'M' and 'L'. If every configuration raises, the initial defaults\n",
    "        (reg=0, iters=0, normalize=False, M=0) and None are returned.\n",
    "    \"\"\"\n",
    "    if dataset_name is None:\n",
    "        # Backward compatibility: callers that do not pass the name rely on the\n",
    "        # global `dataset` set by the testing loop (hidden notebook state).\n",
    "        dataset_name = dataset\n",
    "    max_iter = 1  # single pass, matching the 'no adaptivity' baseline\n",
    "    regs = [10, 1, .1, 1e-2, 1e-3]\n",
    "    # balance-scale is only evaluated without normalization; listing False once avoids\n",
    "    # training the identical (reg, normalize=False) configuration twice.\n",
    "    normalize_options = [False] if dataset_name == 'balance-scale' else [True, False]\n",
    "    L = 10\n",
    "\n",
    "    best_acc, best_reg, best_iter, best_M = -1, 0, 0, 0\n",
    "    best_normalize = False\n",
    "    best_train_traj = None\n",
    "    for reg in regs:\n",
    "        for n in normalize_options:\n",
    "            try:\n",
    "                # Re-index X / y for every configuration instead of reusing one slice,\n",
    "                # in line with the docstring note about accidental overwrites.\n",
    "                acc, iter_v, M, train_traj = wagop_rfm_laplace.hyperparam_train(X[train_fold], y[train_fold], X[val_fold], y[val_fold], c,\n",
    "                                                                                iters=max_iter, reg=reg, L=L, normalize=n, device=device, return_train_traj=True)\n",
    "                if acc > best_acc:\n",
    "                    best_acc = acc\n",
    "                    best_reg = reg\n",
    "                    best_iter = iter_v\n",
    "                    best_M = M\n",
    "                    best_normalize = n\n",
    "                    best_train_traj = train_traj\n",
    "            except Exception as e:\n",
    "                # Best-effort search: report the failing configuration and move on to the next one.\n",
    "                print(e, 'with reg', reg, 'normalize', n, 'dataset', dataset_name)\n",
    "    return {'reg' : best_reg, 'iters' : best_iter, 'normalize' : best_normalize, 'M' : best_M, 'L' : L}, best_train_traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing datasets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                      | 0/120 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc 0.6597222222222222\n",
      "acc 1.0\n",
      "acc 1.0\n",
      "acc 0.8524262899262899\n",
      "acc 0.928391959798995\n",
      "acc 0.7123893805309734\n",
      "acc 0.8023255813953489\n",
      "acc 0.9150641025641025\n",
      "acc 0.8125\n",
      "acc 0.8988938053097345\n",
      "acc 0.7847593582887701\n",
      "acc 0.7253521126760563\n",
      "acc 0.9757142857142858\n",
      "acc 0.9700704225352113\n",
      "acc 0.8163265306122449\n",
      "acc 0.7307692307692308\n",
      "acc 0.9803240740740742\n",
      "acc 0.8465160075329566\n",
      "acc 0.9293785310734464\n",
      "acc 0.7918448816652409\n",
      "acc 0.992177722152691\n",
      "acc 0.6146788990825688\n",
      "acc 0.8605769230769231\n",
      "acc 0.9791666666666666\n",
      "acc 0.8725057729883356\n",
      "acc 0.5108695652173914\n",
      "acc 0.8735465116279069\n",
      "acc 0.7734375\n",
      "acc 0.978021978021978\n",
      "acc 0.8484848484848485\n",
      "acc 0.875\n",
      "acc 0.9296875\n",
      "acc 0.8893229166666666\n",
      "acc 0.89\n",
      "acc 0.5052083333333334\n",
      "acc 0.7122641509433962\n",
      "acc 0.7368421052631579\n",
      "acc 0.8409090909090909\n",
      "acc 0.5723684210526315\n",
      "acc 0.8561643835616438\n",
      "acc 0.4838709677419355\n",
      "acc 0.37\n",
      "acc 0.8333333333333334\n",
      "acc 0.794701986754967\n",
      "acc 0.8200000000000001\n",
      "acc 0.7363013698630136\n",
      "acc 0.875\n",
      "acc 0.9517045454545455\n",
      "acc 0.9527027027027027\n",
      "acc 0.7469999999999999\n",
      "acc 0.8333333333333333\n",
      "acc 0.97445\n",
      "acc 0.8305555555555556\n",
      "acc 0.9116541353383458\n",
      "acc 0.46875\n",
      "acc 0.8783783783783785\n",
      "acc 0.8668243953732913\n",
      "acc 0.8\n",
      "acc 0.8942307692307693\n",
      "acc 0.8695106649937264\n",
      "acc 0.8467741935483871\n",
      "acc 0.6964285714285714\n",
      "acc 0.9083333333333334\n",
      "acc 0.9997538158542589\n",
      "acc 0.8886554621848739\n",
      "acc 0.9865069739235901\n",
      "acc 0.9914351851851851\n",
      "acc 0.8333333333333333\n",
      "acc 0.9284313725490195\n",
      "acc 0.8344298245614035\n",
      "acc 0.9385964912280702\n",
      "acc 0.9877092050209205\n",
      "acc 0.9716088328075709\n",
      "acc 0.9649122807017545\n",
      "acc 0.9387755102040817\n",
      "acc 0.9966631073144688\n",
      "acc 0.7630208333333334\n",
      "acc 0.9326923076923077\n",
      "acc 0.7403846153846154\n",
      "acc 0.6739130434782608\n",
      "acc 0.88\n",
      "acc 0.7019230769230769\n",
      "acc 0.7111111111111111\n",
      "acc 0.85\n",
      "acc 0.7256250000000001\n",
      "acc 0.8443750000000001\n",
      "acc 0.7272727272727273\n",
      "acc 0.5457317073170732\n",
      "acc 0.9843243243243243\n",
      "acc 0.9086538461538461\n",
      "acc 0.9528894472361809\n",
      "acc 0.9253246753246753\n",
      "acc 0.9454347826086956\n",
      "acc 0.7250000000000001\n",
      "acc 0.8\n",
      "acc 0.6802325581395349\n",
      "acc 0.774\n",
      "acc 0.8917910447761194\n",
      "acc 0.9705372616984402\n",
      "acc 0.9206492335437331\n",
      "statlog-shuttle \tN: 58000 \td: 9 \tc: 7\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 86%|███████████████████████████████████████████████████▌        | 103/120 [00:19<00:03,  5.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9990804597701151\n",
      "statlog-vehicle \tN: 846 \td: 18 \tc: 4\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7843601895734598\n",
      "steel-plates \tN: 1941 \td: 27 \tc: 7\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7716494845360825\n",
      "synthetic-control \tN: 600 \td: 60 \tc: 6\n",
      "Training\n",
      "iters 0\n",
      "acc 0.9966666666666666\n",
      "teaching \tN: 151 \td: 5 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.611842105263158\n",
      "thyroid \tN: 7200 \td: 21 \tc: 3\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|██████████████████████████████████████████████████████      | 108/120 [00:20<00:02,  5.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.965270413573701\n",
      "tic-tac-toe \tN: 958 \td: 9 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.9769874476987448\n",
      "titanic \tN: 2201 \td: 3 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.7881818181818182\n",
      "trains \tN: 10 \td: 29 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.875\n",
      "twonorm \tN: 7400 \td: 20 \tc: 2\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|███████████████████████████████████████████████████████     | 110/120 [00:21<00:01,  5.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.9787837837837838\n",
      "vertebral-column-2clases \tN: 310 \td: 6 \tc: 2\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8474025974025974\n",
      "vertebral-column-3clases \tN: 310 \td: 6 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.8409090909090908\n",
      "wall-following \tN: 5456 \td: 24 \tc: 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 94%|████████████████████████████████████████████████████████▌   | 113/120 [00:21<00:01,  5.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.9178885630498533\n",
      "waveform \tN: 5000 \td: 21 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████████████████████████████████████████████████████   | 114/120 [00:21<00:01,  5.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.8682000000000001\n",
      "waveform-noise \tN: 5000 \td: 40 \tc: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████████████████████████████████████████████████████▌  | 115/120 [00:22<00:01,  4.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n",
      "iters 0\n",
      "acc 0.866\n",
      "wine \tN: 178 \td: 13 \tc: 3\n",
      "Training\n",
      "iters 0\n",
      "acc 0.9886363636363636\n",
      "wine-quality-red \tN: 1599 \td: 11 \tc: 6\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|██████████████████████████████████████████████████████████▌ | 117/120 [00:22<00:00,  5.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.6775\n",
      "wine-quality-white \tN: 4898 \td: 11 \tc: 7\n",
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████| 120/120 [00:22<00:00,  5.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iters 0\n",
      "acc 0.6817810457516339\n",
      "yeast \tN: 1484 \td: 8 \tc: 10\n",
      "Training\n",
      "iters 0\n",
      "acc 0.6138814016172507\n",
      "zoo \tN: 101 \td: 16 \tc: 7\n",
      "Training\n",
      "iters 0\n",
      "acc 0.98\n",
      "avg_acc: 83.71129977058625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "print('Testing datasets')\n",
    "acc_list = []\n",
    "method_type = 'kernel'\n",
    "# Create this method's cache entry on first use, so a cache file written before the\n",
    "# method existed does not raise KeyError.\n",
    "method_cache = cached_results.setdefault(method_type, {})\n",
    "for dataset in tqdm(datasets):\n",
    "    if dataset in method_cache:\n",
    "        acc = method_cache[dataset]\n",
    "    else:\n",
    "        print_dataset_info(dataset)\n",
    "        hyperparams, train_traj = hyperparam_select_dataset(dataset, nonadaptive_rfm_laplace_hyperparam_select, device=device)\n",
    "        acc = test_dataset(dataset, hyperparams, wagop_rfm_laplace.train, device=device)\n",
    "        method_cache[dataset] = acc\n",
    "        # Persist after every dataset so an interrupted run can reuse finished results.\n",
    "        with open(cache_file, 'wb') as f:\n",
    "            pickle.dump(cached_results, f)\n",
    "        tqdm.write('iters ' + str(hyperparams['iters']))\n",
    "        M = hyperparams['M']  # last selected M, left in the notebook namespace\n",
    "    acc_list.append(acc)\n",
    "    # tqdm.write prints above the progress bar instead of breaking it apart.\n",
    "    tqdm.write('acc ' + str(acc))\n",
    "\n",
    "print(\"avg_acc:\", np.mean(acc_list) * 100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
