{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "\n",
    "from train import train\n",
    "import priors\n",
    "import utils\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from datasets import load_openml_list, valid_dids_classification, test_dids_classification\n",
    "from tabular import evaluate, get_model, get_default_spec\n",
    "\n",
    "from tabular import bayes_net_metric, gp_metric, knn_metric, ridge_metric, catboost_metric, xgb_metric, logistic_metric, tabnet_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading test datasets...\n",
      "wine 973\n",
      "covertype 1596\n",
      "\n",
      " Loading valid datasets...\n",
      "ionosphere 59\n"
     ]
    }
   ],
   "source": [
    "### Loads small list of datasets\n",
    "print('Loading test datasets...')\n",
    "test_datasets, test_datasets_df = load_openml_list(test_dids_classification[0:2], filter_for_nan=True)\n",
    "ds = test_datasets\n",
    "\n",
    "print('\\n Loading valid datasets...')\n",
    "valid_datasets, valid_datasets_df = load_openml_list(valid_dids_classification[0:2], filter_for_nan=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading test datasets...\n",
      "kr-vs-kp 3\n",
      "credit-g 31\n",
      "vehicle 54\n",
      "wine 973\n",
      "kc1 1067\n",
      "airlines 1169\n",
      "bank-marketing 1461\n",
      "blood-transfusion-service-center 1464\n",
      "phoneme 1489\n",
      "covertype 1596\n",
      "numerai28.6 23517\n",
      "connect-4 40668\n",
      "car 40975\n",
      "Australian 40981\n",
      "segment 40984\n",
      "jungle_chess_2pcs_raw_endgame_complete 41027\n",
      "sylvine 41146\n",
      "MiniBooNE 41150\n",
      "dionis 41167\n",
      "jannis 41168\n",
      "helena 41169\n",
      "\n",
      " Loading valid datasets...\n",
      "haberman 43\n",
      "ionosphere 59\n",
      "sa-heart 1498\n",
      "cleve 40710\n"
     ]
    }
   ],
   "source": [
    "### Loads all datasets\n",
    "print('Loading test datasets...')\n",
    "test_datasets, test_datasets_df = load_openml_list(test_dids_classification, filter_for_nan=True)\n",
    "ds = test_datasets\n",
    "\n",
    "print('\\n Loading valid datasets...')\n",
    "valid_datasets, valid_datasets_df = load_openml_list(valid_dids_classification, filter_for_nan=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# After how many training samples should evaluatuion be done?\n",
    "# Trained models have not been trained to evaluate after 30 samples\n",
    "# so performance will drop\n",
    "eval_positions = [30]\n",
    "\n",
    "# What is the maximum number of features?\n",
    "# Pretrained models have to use 60\n",
    "max_features = 60\n",
    "\n",
    "# How many samples should be loaded for one dataset?\n",
    "# Samples after the training sequence are used for evaluation\n",
    "seq_len = 100\n",
    "\n",
    "# How many subsamples of datasets should be drawn for each dataset\n",
    "max_samples = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "gp_model_checkpoint_dir = \"results/tabular_model_gp.ckpt\"\n",
    "gp_model_config = {'batch_size': 512,\n",
    " 'bptt': 100,\n",
    " 'dropout': 0.5,\n",
    " 'emsize': 512,\n",
    " 'epochs': 100,\n",
    " 'eval_positions': [10, 20, 40, 80],\n",
    " 'lr': 6.271726842985807e-05,\n",
    " 'nhead': 4,\n",
    " 'nhid_factor': 2,\n",
    " 'nlayers': 5,\n",
    " 'num_features': 60,\n",
    " 'prior_lengthscale': 0.00014803074521613278,\n",
    " 'prior_noise': 0.001,\n",
    " 'prior_normalize_by_used_features': True,\n",
    " 'prior_num_features_used_sampler': {'uniform_int_sampler_f(1,max_features)': '<function <lambda>.<locals>.<lambda> at 0x7f21e832e550>'},\n",
    " 'prior_order_y': False,\n",
    " 'prior_outputscale': 2.3163584733185836,\n",
    " 'prior_type': 'gp'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "bnn_model_checkpoint_dir = \"results/tabular_model_bnn.ckpt\"\n",
    "bnn_model_config = {'batch_size': 512,\n",
    " 'bptt': 50,\n",
    " 'dropout': 0.5,\n",
    " 'emsize': 512,\n",
    " 'epochs': 100,\n",
    " 'eval_positions': [10, 20, 40],\n",
    " 'lr': 1.6421403128751275e-05,\n",
    " 'nhead': 4,\n",
    " 'nhid_factor': 2,\n",
    " 'nlayers': 5,\n",
    " 'num_features': 60,\n",
    " 'prior_activations': \"<class 'torch.nn.modules.activation.Tanh'>\",\n",
    " 'prior_dropout_sampler': {'lambda: 0.0': '<function <lambda> at 0x7f613c1364c0>'},\n",
    " 'prior_emsize_sampler': {'scaled_beta_sampler_f(2.0, 4.0, 150, 2)': '<function <lambda>.<locals>.<lambda> at 0x7f613c136310>'},\n",
    " 'prior_is_causal': False,\n",
    " 'prior_nlayers_sampler': {'lambda: 3': '<function <lambda> at 0x7f613c136790>'},\n",
    " 'prior_noise_std_gamma_k': 1.8663049257557085,\n",
    " 'prior_noise_std_gamma_theta': 0.05275478076173361,\n",
    " 'prior_normalize_by_used_features': False,\n",
    " 'prior_num_features_used_sampler': {'scaled_beta_sampler_f(1.0, 1.6, max_features, 2)': '<function <lambda>.<locals>.<lambda> at 0x7f613c136550>'},\n",
    " 'prior_order_y': True,\n",
    " 'prior_sigma_gamma_k': 3.6187797729244253,\n",
    " 'prior_sigma_gamma_theta': 0.06773738681062867,\n",
    " 'prior_type': 'mlp'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading PFN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 512, 'seq_len': 50, 'num_features': 60, 'hyperparameters': ('<function <lambda> at 0x7f613c136790>', '<function <lambda>.<locals>.<lambda> at 0x7f613c136310>', \"<class 'torch.nn.modules.activation.Tanh'>\", <function <lambda>.<locals>.<lambda> at 0x7f56719e5940>, <function <lambda>.<locals>.<lambda> at 0x7f56719e5b80>, '<function <lambda> at 0x7f613c1364c0>', True, '<function <lambda>.<locals>.<lambda> at 0x7f613c136550>', None, False, None, None, None, True, False, None, 0.0), 'batch_size_per_gp_sample': 8}, 'num_features': 60, 'num_outputs': 1}\n"
     ]
    }
   ],
   "source": [
    "model_type = 'bnn'\n",
    "if model_type == 'gp':\n",
    "    raise Exception(\"Not Implemented\")\n",
    "    config = gp_model_config\n",
    "    checkpoint_dir = gp_model_checkpoint_dir\n",
    "elif model_type == 'bnn':\n",
    "    config = bnn_model_config\n",
    "    checkpoint_dir = bnn_model_checkpoint_dir\n",
    "\n",
    "model = get_model(config, device, eval_positions, should_train=False)\n",
    "model_state, _ = torch.load(checkpoint_dir)\n",
    "model[2].load_state_dict(model_state)\n",
    "model = model[2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation of PFN and Baselines on all datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating wine\n",
      "\t Eval position 30 done..\n",
      "Evaluating covertype\n",
      "\t Eval position 30 done..\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'metric': 'auc',\n",
       " 'wine_mean_metric_at_30': 0.9587367346938775,\n",
       " 'wine_time': 0.45734596252441406,\n",
       " 'covertype_mean_metric_at_30': 0.9624857142857144,\n",
       " 'covertype_time': 0.4606165885925293,\n",
       " 'mean_metric_at_30': 0.960611224489796,\n",
       " 'mean_metric': 0.960611224489796}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = 'cuda'\n",
    "result = evaluate(ds, model.to(device), 'transformer'\n",
    "                  , max_features = max_features\n",
    "                  , bptt=seq_len\n",
    "                  , eval_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , rescale_features=config[\"prior_normalize_by_used_features\"]\n",
    "                  , extend_features=True, plot=False, overwrite=True, save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Bayesian NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "result = evaluate(ds, bayes_net_metric, 'bayes_net'\n",
    "                  , bptt=seq_len\n",
    "                  , evaal_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Gaussian Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating wine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 1/5 [00:05<00:22,  5.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9857142857142858\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 2/5 [00:10<00:14,  4.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9714285714285714\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 3/5 [00:14<00:09,  4.61s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 4/5 [00:18<00:04,  4.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:23<00:00,  4.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7571428571428571\n",
      "\t Eval position 30 done..\n",
      "Evaluating covertype\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 1/5 [00:06<00:24,  6.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9285714285714286\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 2/5 [00:11<00:17,  5.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 3/5 [00:17<00:11,  5.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7142857142857143\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 4/5 [00:22<00:05,  5.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8142857142857143\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:27<00:00,  5.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8714285714285714\n",
      "\t Eval position 30 done..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "result = evaluate(ds, gp_metric, 'gp'\n",
    "                  , bptt=seq_len\n",
    "                  , evaal_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Tabnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1403,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating blood-transfusion-service-center\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device used : cuda\n",
      "No early stopping will be performed, last training weights will be used.\n",
      "epoch 0  | loss: 0.89694 |  0:00:00s\n",
      "epoch 1  | loss: 0.77102 |  0:00:00s\n",
      "epoch 2  | loss: 0.85065 |  0:00:00s\n",
      "epoch 3  | loss: 0.55303 |  0:00:00s\n",
      "epoch 4  | loss: 0.54864 |  0:00:00s\n",
      "epoch 5  | loss: 0.57434 |  0:00:00s\n",
      "epoch 6  | loss: 0.53758 |  0:00:00s\n",
      "epoch 7  | loss: 0.50961 |  0:00:00s\n",
      "epoch 8  | loss: 0.49373 |  0:00:00s\n",
      "epoch 9  | loss: 0.48734 |  0:00:00s\n",
      "epoch 10 | loss: 0.46865 |  0:00:00s\n",
      "epoch 11 | loss: 0.45546 |  0:00:00s\n",
      "epoch 12 | loss: 0.44154 |  0:00:00s\n",
      "epoch 13 | loss: 0.42482 |  0:00:00s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:01<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_73533/3239863466.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cpu'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m7\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtabnet_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'tabnet'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/anon/prior-fitting/results_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'_tabnet.npy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m    191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    192\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m         ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m    194\u001b[0m                                rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m    195\u001b[0m         \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    222\u001b[0m     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m         r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m    225\u001b[0m                               max_samples=max_samples)\n\u001b[1;32m    226\u001b[0m         \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    310\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    311\u001b[0m         \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m         \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    314\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m    325\u001b[0m         \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m         \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    328\u001b[0m         \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    329\u001b[0m         \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mtabnet_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m    565\u001b[0m             \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTabNetClassifier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcat_idxs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcat_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_a\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'n_d'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    566\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 567\u001b[0;31m             clf.fit(\n\u001b[0m\u001b[1;32m    568\u001b[0m                 \u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    569\u001b[0m                 \u001b[0;31m#eval_set=[(X_valid, y_valid)], patience=15\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_train, y_train, eval_set, eval_name, eval_metric, loss_fn, weights, max_epochs, patience, batch_size, virtual_batch_size, num_workers, drop_last, callbacks, pin_memory, from_unsupervised)\u001b[0m\n\u001b[1;32m    221\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 223\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    225\u001b[0m             \u001b[0;31m# Apply predict epoch to all eval sets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py\u001b[0m in \u001b[0;36m_train_epoch\u001b[0;34m(self, train_loader)\u001b[0m\n\u001b[1;32m    432\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    433\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 434\u001b[0;31m             \u001b[0mbatch_logs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    435\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    436\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback_container\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_logs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pytorch_tabnet/abstract_model.py\u001b[0m in \u001b[0;36m_train_batch\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m    461\u001b[0m         \u001b[0mbatch_logs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"batch_size\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m         \u001b[0mX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    464\u001b[0m         \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "result = evaluate(ds, tabnet_metric, 'tabnet'\n",
    "                  , bptt=seq_len\n",
    "                  , eval_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### KNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating wine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:01<00:00, 11.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t Eval position 30 done..\n",
      "Evaluating covertype\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:01<00:00, 16.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t Eval position 30 done..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'metric': 'auc',\n",
       " 'wine_mean_metric_at_30': 0.8925102040816327,\n",
       " 'wine_time': 1.7465898990631104,\n",
       " 'covertype_mean_metric_at_30': 0.792734693877551,\n",
       " 'covertype_time': 1.247727394104004,\n",
       " 'mean_metric_at_30': 0.8426224489795919,\n",
       " 'mean_metric': 0.8426224489795919}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result = evaluate(ds, knn_metric, 'knn'\n",
    "                  , bptt=seq_len\n",
    "                  , eval_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating wine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 1/20 [00:01<00:34,  1.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n",
      "[4.12326081e-04 9.19611425e-01 3.97752626e-02 9.74932808e-01\n",
      " 2.92177993e-03 9.98121211e-01 4.40905019e-03 9.33577461e-01\n",
      " 1.50555390e-01 9.94204349e-01 3.87256524e-04 9.75964168e-01\n",
      " 6.52074221e-03 5.44187545e-01 2.76466189e-04 7.75604223e-01\n",
      " 6.95195131e-04 8.77684487e-01 3.72080751e-04 9.98190223e-01\n",
      " 1.30797581e-02 9.98647814e-01 2.64468869e-01 9.99135067e-01\n",
      " 2.63580359e-05 6.23341075e-01 1.93786767e-03 9.08927121e-01\n",
      " 3.61541354e-03 9.98969131e-01 1.74570882e-02 9.08004442e-01\n",
      " 1.55259675e-02 9.57181495e-01 2.78544674e-03 9.99731021e-01\n",
      " 7.09553284e-02 2.29517949e-01 3.19418012e-02 6.85265538e-01\n",
      " 5.43416196e-03 9.67286609e-01 3.30096232e-02 8.83952036e-01\n",
      " 9.05628129e-04 9.25079429e-01 1.30549388e-03 9.87857426e-01\n",
      " 1.93246531e-02 7.14752885e-01 6.55393417e-03 9.91341790e-01\n",
      " 2.71002767e-01 9.72008984e-01 1.95572800e-02 9.70650138e-01\n",
      " 4.02880801e-04 9.36476310e-01 1.71522394e-03 9.87288208e-01\n",
      " 5.34800821e-03 9.87928080e-01 2.29681197e-03 9.12132603e-01\n",
      " 1.32190621e-01 9.99164477e-01 2.50227955e-04 9.78312822e-01\n",
      " 8.01561154e-02 9.71095930e-01]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 2/20 [00:03<00:35,  1.99s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': False, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n",
      "[2.34966376e-03 9.97811502e-01 4.42977805e-03 9.99565614e-01\n",
      " 2.59753692e-03 9.61724168e-01 2.61071324e-03 9.99933616e-01\n",
      " 1.46586638e-03 9.85141705e-01 1.99387649e-04 9.14871930e-01\n",
      " 2.31569875e-02 9.77123157e-01 4.55479338e-04 9.98882992e-01\n",
      " 2.01825925e-03 9.12376267e-01 2.19126373e-01 9.95130404e-01\n",
      " 5.17808246e-04 9.41544348e-01 5.26481881e-03 5.45981587e-01\n",
      " 1.29234306e-04 7.59193071e-01 5.28180762e-04 8.64066925e-01\n",
      " 5.49872741e-04 9.99114519e-01 3.82477700e-03 9.97710659e-01\n",
      " 3.32963599e-01 9.98419850e-01 2.61277431e-05 5.27080955e-01\n",
      " 6.53393899e-03 7.76305159e-01 6.04063168e-03 9.98567996e-01\n",
      " 3.01862055e-02 8.02008437e-01 1.81304029e-02 9.54083311e-01\n",
      " 7.13274150e-03 9.99326682e-01 1.00578841e-01 2.94002156e-01\n",
      " 2.58978601e-02 7.49132883e-01 4.50584482e-03 9.77984332e-01\n",
      " 8.14691986e-02 8.71252875e-01 9.75149695e-04 9.42233545e-01\n",
      " 9.85537736e-04 9.96837487e-01 1.44663880e-02 7.96780562e-01\n",
      " 7.87276898e-03 9.94137661e-01 4.49145912e-01 9.83950313e-01\n",
      " 1.61118424e-02 9.59959963e-01 4.32656352e-04 9.19760156e-01\n",
      " 1.82114427e-03 9.92688659e-01]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 3/20 [00:05<00:32,  1.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1.0, 'fit_intercept': False, 'max_iter': 500, 'penalty': 'l1', 'solver': 'saga', 'tol': 0.0001}\n",
      "[0.06974362 0.42336164 0.01703207 0.97388489 0.22809311 0.864173\n",
      " 0.03105408 0.83391439 0.01444774 0.91769675 0.12392376 0.80023415\n",
      " 0.09652814 0.85240131 0.08174257 0.74563011 0.32383395 0.87731599\n",
      " 0.68395194 0.94467399 0.25368484 0.99224602 0.09027845 0.88636919\n",
      " 0.09235925 0.99655498 0.9651366  0.97910275 0.5482745  0.93189756\n",
      " 0.01159998 0.96038068 0.0267969  0.9155232  0.05236642 0.96382213\n",
      " 0.43594785 0.81424108 0.10509157 0.98534026 0.01147495 0.09514501\n",
      " 0.29550364 0.96793223 0.06416272 0.02983371 0.03576493 0.09970959\n",
      " 0.06283597 0.06050325 0.04146495 0.10931557 0.00956397 0.03881107\n",
      " 0.07353775 0.23500227 0.02755604 0.28342607 0.1337751  0.08567451\n",
      " 0.01716663 0.0149318  0.01978608 0.23753092 0.01930092 0.09254563\n",
      " 0.04065118 0.13216921 0.0324691  0.09542866]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 4/20 [00:07<00:29,  1.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1.0, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'l1', 'solver': 'saga', 'tol': 0.01}\n",
      "[0.70131459 0.01678711 0.60492504 0.01069737 0.71376805 0.02748887\n",
      " 0.98449608 0.05130859 0.97420844 0.40381504 0.96752984 0.02519502\n",
      " 0.82707777 0.10708214 0.69000753 0.2040882  0.97747093 0.40330462\n",
      " 0.67280445 0.19142666 0.96526066 0.08645192 0.99053094 0.35871556\n",
      " 0.65781758 0.44227517 0.68420729 0.08220921 0.89025336 0.45462276\n",
      " 0.84425192 0.54211743 0.88263528 0.08692686 0.98430602 0.17645702\n",
      " 0.8538925  0.16913417 0.97691351 0.81127546 0.93985646 0.16094987\n",
      " 0.96264753 0.04228152 0.90056355 0.19237197 0.93328752 0.07411261\n",
      " 0.89027042 0.58041776 0.7955457  0.18435328 0.99249684 0.03125832\n",
      " 0.70961402 0.2673135  0.97250892 0.01717529 0.70037093 0.12556494\n",
      " 0.70165636 0.09616296 0.2272858  0.08929946 0.32233496 0.00673966\n",
      " 0.5958783  0.02302669 0.15701894 0.0500672 ]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▌       | 5/20 [00:09<00:27,  1.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n",
      "[6.52714586e-01 9.98635531e-01 6.00991816e-03 9.55565635e-01\n",
      " 7.09951283e-03 8.14360044e-01 6.81791679e-03 9.99801606e-01\n",
      " 4.36909361e-01 8.91309699e-01 3.48875884e-01 9.96234204e-01\n",
      " 2.69088658e-02 9.99950326e-01 2.39151361e-01 9.33560167e-01\n",
      " 6.36510214e-01 9.11335046e-01 6.19797405e-02 9.89353304e-01\n",
      " 1.05815283e-01 9.75894050e-01 9.19643764e-01 9.86137300e-01\n",
      " 2.66815067e-03 9.99040711e-01 3.01042737e-02 9.62981963e-01\n",
      " 2.01891457e-01 9.99483583e-01 9.72869351e-01 9.92534311e-01\n",
      " 3.48164720e-01 9.98274099e-01 9.57299314e-03 9.95362359e-01\n",
      " 1.34150611e-01 9.96665681e-01 8.69588961e-03 9.88952714e-01\n",
      " 4.11556516e-01 8.06033390e-01 8.89376753e-02 9.99933419e-01\n",
      " 1.79285952e-02 2.46006114e-02 4.19293642e-01 9.99020227e-01\n",
      " 1.39752377e-03 8.27256545e-02 7.09215549e-04 1.88990783e-02\n",
      " 2.49824281e-02 7.11231223e-04 9.35730222e-03 2.34447294e-03\n",
      " 9.25927278e-04 2.82376339e-02 2.25198163e-03 3.71604751e-04\n",
      " 4.15443651e-03 2.59381661e-01 3.13232724e-02 2.91566784e-02\n",
      " 3.94860122e-04 7.75759927e-04 5.23700697e-05 1.50473088e-01\n",
      " 3.39148707e-03 5.03053356e-03]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 6/20 [00:11<00:25,  1.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n",
      "[9.83583576e-01 5.31016587e-04 9.29274459e-01 3.85459019e-02\n",
      " 9.71639017e-01 1.18775105e-03 9.98134567e-01 5.21240611e-03\n",
      " 8.86056657e-01 2.38149719e-01 9.92077944e-01 7.38237983e-04\n",
      " 9.58905441e-01 1.24559396e-02 5.60988697e-01 3.11511168e-04\n",
      " 7.85096097e-01 9.90275798e-04 8.82864038e-01 5.31628645e-04\n",
      " 9.98564605e-01 1.48625315e-02 9.97921985e-01 2.86616976e-01\n",
      " 9.98944054e-01 4.72361925e-05 6.31113590e-01 6.26047465e-03\n",
      " 8.84982042e-01 9.00147613e-03 9.98913206e-01 3.01658308e-02\n",
      " 8.72072354e-01 1.88641325e-02 9.61498458e-01 5.51973166e-03\n",
      " 9.99559769e-01 8.98409294e-02 2.73225895e-01 3.36784201e-02\n",
      " 7.20221648e-01 7.07506732e-03 9.77283000e-01 1.00082860e-01\n",
      " 8.83214682e-01 6.98692395e-04 9.33891209e-01 1.52199602e-03\n",
      " 9.93985713e-01 2.50527583e-02 7.55022990e-01 9.41495432e-03\n",
      " 9.90193157e-01 3.03874037e-01 9.70953038e-01 2.07865013e-02\n",
      " 9.67583165e-01 5.22901980e-04 9.28441639e-01 2.16943826e-03\n",
      " 9.90123427e-01 7.18394185e-03 9.90211745e-01 2.42021104e-03\n",
      " 9.38244511e-01 1.23170538e-01 9.99344646e-01 3.81413969e-04\n",
      " 9.78230408e-01 7.60275466e-02]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███▌      | 7/20 [00:13<00:24,  1.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.0001}\n",
      "[3.31313562e-05 9.99999464e-01 7.95159296e-01 9.99997804e-01\n",
      " 3.13777454e-04 9.95388564e-01 3.10136131e-04 9.75628399e-01\n",
      " 3.62288555e-04 9.99999677e-01 5.60118984e-01 9.94055581e-01\n",
      " 2.82618459e-01 9.99898183e-01 2.05644609e-03 9.99999978e-01\n",
      " 7.96792726e-02 9.92154817e-01 5.86060898e-01 9.88876051e-01\n",
      " 1.15897862e-02 9.99597733e-01 5.01256743e-02 9.98633899e-01\n",
      " 9.73034582e-01 9.99549765e-01 5.99983839e-05 9.99995143e-01\n",
      " 6.18802761e-03 9.95558768e-01 6.38389742e-02 9.99998705e-01\n",
      " 9.99182384e-01 9.99817107e-01 4.12569504e-01 9.99983173e-01\n",
      " 4.46397918e-04 9.99892199e-01 3.40415387e-02 9.99945259e-01\n",
      " 6.76523004e-04 9.99672836e-01 1.59490335e-01 9.25574947e-01\n",
      " 4.39227379e-02 9.99999945e-01 7.88443823e-04 1.14832115e-02\n",
      " 3.73727289e-01 9.99995097e-01 1.43936707e-05 6.01971612e-02\n",
      " 1.11791732e-05 2.47806691e-03 8.31325116e-04 5.91754942e-06\n",
      " 3.79556417e-04 2.44141540e-04 3.43753601e-06 9.67005991e-03\n",
      " 5.44429224e-05 6.99912685e-06 1.40831013e-04 1.64769326e-01\n",
      " 4.10106657e-03 3.14697787e-03 1.65498495e-06 3.17770130e-05\n",
      " 6.87332367e-08 1.85409476e-01]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 8/20 [00:14<00:22,  1.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'none', 'solver': 'saga', 'tol': 0.01}\n",
      "[9.99664268e-01 7.62876099e-01 9.98386486e-01 6.27311096e-03\n",
      " 9.37337218e-01 6.21016618e-03 8.18091840e-01 5.71440730e-03\n",
      " 9.99733629e-01 4.49018505e-01 9.11202742e-01 4.15045796e-01\n",
      " 9.93656045e-01 2.90506359e-02 9.99924234e-01 2.44542240e-01\n",
      " 9.26934168e-01 6.76509634e-01 9.09437589e-01 7.14948195e-02\n",
      " 9.85587705e-01 7.66412054e-02 9.72836935e-01 8.36812757e-01\n",
      " 9.82542331e-01 2.44985077e-03 9.98048341e-01 3.79889220e-02\n",
      " 9.58489559e-01 1.95237163e-01 9.99317980e-01 9.59677506e-01\n",
      " 9.92652818e-01 3.36662240e-01 9.97738733e-01 1.03584140e-02\n",
      " 9.94715794e-01 1.52240798e-01 9.96180878e-01 7.25841492e-03\n",
      " 9.86733845e-01 2.56595318e-01 7.50204506e-01 1.22199744e-01\n",
      " 9.99891118e-01 2.30040057e-02 3.16352395e-02 4.69433536e-01\n",
      " 9.98253315e-01 2.25872669e-03 6.35848289e-02 1.08197788e-03\n",
      " 1.72929146e-02 2.31652202e-02 5.00682185e-04 8.54084689e-03\n",
      " 3.33725153e-03 1.54900694e-03 3.73832650e-02 3.21548385e-03\n",
      " 3.95950758e-04 4.80783210e-03 2.40004379e-01 3.00124757e-02\n",
      " 2.95706753e-02 4.96784512e-04 1.01762749e-03 6.84382194e-05\n",
      " 1.55359082e-01 3.53251038e-03]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 9/20 [00:16<00:20,  1.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'C': 1e-05, 'fit_intercept': True, 'max_iter': 500, 'penalty': 'l2', 'solver': 'saga', 'tol': 1e-10}\n",
      "[0.4998356  0.50015137 0.49987409 0.50029966 0.49983398 0.50031188\n",
      " 0.50009655 0.50023837 0.49986525 0.5002024  0.49988854 0.50012421\n",
      " 0.49990913 0.50033222 0.50001204 0.50011976 0.50011477 0.50028586\n",
      " 0.49994748 0.50036626 0.50007451 0.50013711 0.50013507 0.50012796\n",
      " 0.49997282 0.50019476 0.49997082 0.50019112 0.5001566  0.50017832\n",
      " 0.49986975 0.50027281 0.49993339 0.50016413 0.50001914 0.50030764\n",
      " 0.50022348 0.50024722 0.50009243 0.500288   0.49992337 0.50020947\n",
      " 0.50000242 0.50023711 0.49991418 0.50019936 0.50010521 0.50013699\n",
      " 0.50003493 0.50039762 0.49996634 0.49993259 0.50006249 0.50027367\n",
      " 0.49986551 0.49999889 0.49985123 0.4999571  0.50000606 0.49985354\n",
      " 0.49994856 0.49987396 0.4999105  0.49993825 0.49989787 0.49983024\n",
      " 0.49994391 0.50008364 0.50003337 0.50000305]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 9/20 [00:17<00:21,  1.92s/it]\n",
      "ERROR:root:Internal Python error in the inspect module.\n",
      "Below is the traceback from this internal error.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3441, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"/tmp/ipykernel_62303/3352290182.py\", line 1, in <module>\n",
      "    result = evaluate(ds, logistic_acc, 'logistic'\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 195, in evaluate\n",
      "    ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 227, in evaluate_dataset\n",
      "    r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 314, in evaluate_position\n",
      "    acc_eval_pos, outputs = batch_pred(model, eval_xs, eval_ys, categorical_feats, start=eval_position)\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 329, in batch_pred\n",
      "    acc, output = acc_function(eval_x[:start], eval_y[:start], eval_x[start:], eval_y[start:], categorical_feats)\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 380, in logistic_acc\n",
      "    clf.fit(x, y.long())\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 841, in fit\n",
      "    self._run_search(evaluate_candidates)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 1296, in _run_search\n",
      "    evaluate_candidates(ParameterGrid(self.param_grid))\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 795, in evaluate_candidates\n",
      "    out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 1044, in __call__\n",
      "    while self.dispatch_one_batch(iterator):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 859, in dispatch_one_batch\n",
      "    self._dispatch(tasks)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 777, in _dispatch\n",
      "    job = self._backend.apply_async(batch, callback=cb)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\", line 208, in apply_async\n",
      "    result = ImmediateResult(func)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\", line 572, in __init__\n",
      "    self.results = batch()\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 262, in __call__\n",
      "    return [func(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\", line 262, in <listcomp>\n",
      "    return [func(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\", line 222, in __call__\n",
      "    return self.function(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\", line 625, in _fit_and_score\n",
      "    test_scores = _score(estimator, X_test, y_test, scorer, error_score)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\", line 687, in _score\n",
      "    scores = scorer(estimator, X_test, y_test)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_scorer.py\", line 397, in _passthrough_scorer\n",
      "    return estimator.score(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/base.py\", line 500, in score\n",
      "    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\", line 210, in accuracy_score\n",
      "    return _weighted_sum(score, sample_weight, normalize)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\", line 133, in _weighted_sum\n",
      "    return np.average(sample_score, weights=sample_weight)\n",
      "  File \"<__array_function__ internals>\", line 5, in average\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/lib/function_base.py\", line 380, in average\n",
      "    avg = a.mean(axis)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\", line 167, in _mean\n",
      "    rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\", line 71, in _count_reduce_items\n",
      "    axis = tuple(range(arr.ndim))\n",
      "KeyboardInterrupt\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 2061, in showtraceback\n",
      "    stb = value._render_traceback_()\n",
      "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n",
      "    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 248, in wrapped\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 281, in _fixed_getinnerframes\n",
      "    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1541, in getinnerframes\n",
      "    frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1503, in getframeinfo\n",
      "    lines, lnum = findsource(frame)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 182, in findsource\n",
      "    lines = linecache.getlines(file, globals_dict)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/linecache.py\", line 46, in getlines\n",
      "    return updatecache(filename, module_globals)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/linecache.py\", line 136, in updatecache\n",
      "    with tokenize.open(fullname) as fp:\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/tokenize.py\", line 394, in open\n",
      "    encoding, lines = detect_encoding(buffer.readline)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/tokenize.py\", line 363, in detect_encoding\n",
      "    first = read_or_stop()\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/tokenize.py\", line 321, in read_or_stop\n",
      "    return readline()\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "object of type 'NoneType' has no len()",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_62303/3352290182.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m result = evaluate(ds, logistic_acc, 'logistic'\n\u001b[0m\u001b[1;32m      2\u001b[0m                   \u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mseq_len\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m                   \u001b[0;34m,\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_positions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m    194\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 195\u001b[0;31m         ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m    196\u001b[0m                                rescale_features=rescale_features_factor, max_samples=max_samples)\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    226\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 227\u001b[0;31m         r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m    228\u001b[0m                               max_samples=max_samples)\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    313\u001b[0m         \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m         \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m    328\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 329\u001b[0;31m         \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    330\u001b[0m         \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mlogistic_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m    379\u001b[0m     \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 380\u001b[0;31m     \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    381\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m    840\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    842\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36m_run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m   1295\u001b[0m         \u001b[0;34m\"\"\"Search all candidates in param_grid\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1296\u001b[0;31m         \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mParameterGrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1297\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mevaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m    794\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 795\u001b[0;31m                 out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n\u001b[0m\u001b[1;32m    796\u001b[0m                                                        \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m             \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1045\u001b[0m                 \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    858\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 859\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    860\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    776\u001b[0m             \u001b[0mjob_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m             \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    778\u001b[0m             \u001b[0;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m    207\u001b[0m         \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    209\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    571\u001b[0m         \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m             return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m    263\u001b[0m                     for func, args, kwargs in self.items]\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m             return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m    263\u001b[0m                     for func, args, kwargs in self.items]\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    221\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m    624\u001b[0m         \u001b[0mfit_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 625\u001b[0;31m         \u001b[0mtest_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscorer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merror_score\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    626\u001b[0m         \u001b[0mscore_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mfit_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_score\u001b[0;34m(estimator, X_test, y_test, scorer, error_score)\u001b[0m\n\u001b[1;32m    686\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 687\u001b[0;31m             \u001b[0mscores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscorer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    688\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_scorer.py\u001b[0m in \u001b[0;36m_passthrough_scorer\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m    396\u001b[0m     \u001b[0;34m\"\"\"Function that wraps estimator.score\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 397\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    398\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/base.py\u001b[0m in \u001b[0;36mscore\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    499\u001b[0m         \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0maccuracy_score\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 500\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0maccuracy_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    501\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36maccuracy_score\u001b[0;34m(y_true, y_pred, normalize, sample_weight)\u001b[0m\n\u001b[1;32m    209\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_weighted_sum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscore\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnormalize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/metrics/_classification.py\u001b[0m in \u001b[0;36m_weighted_sum\u001b[0;34m(sample_score, sample_weight, normalize)\u001b[0m\n\u001b[1;32m    132\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mnormalize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maverage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_score\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    134\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0msample_weight\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36maverage\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/lib/function_base.py\u001b[0m in \u001b[0;36maverage\u001b[0;34m(a, axis, weights, returned)\u001b[0m\n\u001b[1;32m    379\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 380\u001b[0;31m         \u001b[0mavg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    381\u001b[0m         \u001b[0mscl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mavg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m    166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m     \u001b[0mrcount\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_count_reduce_items\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkeepdims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwhere\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mrcount\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwhere\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mumr_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcount\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_count_reduce_items\u001b[0;34m(arr, axis, keepdims, where)\u001b[0m\n\u001b[1;32m     70\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m             \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     72\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m   2060\u001b[0m                         \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2061\u001b[0;31m                         \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2062\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'KeyboardInterrupt' object has no attribute '_render_traceback_'",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m   2061\u001b[0m                         \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2062\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2063\u001b[0;31m                         stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m   2064\u001b[0m                                             value, tb, tb_offset=tb_offset)\n\u001b[1;32m   2065\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1365\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1366\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1367\u001b[0;31m         return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m   1368\u001b[0m             self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m   1369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1265\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1266\u001b[0m             \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1267\u001b[0;31m             return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m   1268\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1269\u001b[0m             )\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1122\u001b[0m         \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1124\u001b[0;31m         formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m   1125\u001b[0m                                                                tb_offset)\n\u001b[1;32m   1126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m   1080\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1081\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1082\u001b[0;31m         \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1083\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1084\u001b[0m         \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m    380\u001b[0m     \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    381\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    383\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    384\u001b[0m     \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()"
     ]
    }
   ],
   "source": [
    "# Run `evaluate` with the logistic metric on the datasets in `ds`;\n",
    "# the bare `result` on the last line displays the returned value.\n",
    "result = evaluate(\n",
    "    ds,\n",
    "    logistic_metric,\n",
    "    'logistic',\n",
    "    bptt=seq_len,\n",
    "    eval_position_range=eval_positions,\n",
    "    device=device,\n",
    "    max_samples=20,\n",
    "    overwrite=True,\n",
    "    save=False,\n",
    ")\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Ridge Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1303,
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating kr-vs-kp\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 1/20 [00:00<00:02,  6.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8428571428571429\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 2/20 [00:00<00:02,  6.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█▌        | 3/20 [00:00<00:02,  6.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7857142857142857\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 4/20 [00:00<00:02,  6.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9428571428571428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██▌       | 5/20 [00:00<00:02,  6.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 6/20 [00:00<00:02,  6.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███▌      | 7/20 [00:01<00:01,  6.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8714285714285714\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 8/20 [00:01<00:01,  6.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9571428571428572\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 45%|████▌     | 9/20 [00:01<00:01,  6.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8285714285714286\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py:187: LinAlgWarning: Ill-conditioned matrix (rcond=8.92231e-11): result may not be accurate.\n",
      "  dual_coef = linalg.solve(K, y, sym_pos=True,\n",
      " 50%|█████     | 10/20 [00:01<00:01,  6.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 55%|█████▌    | 11/20 [00:01<00:01,  6.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6714285714285714\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 12/20 [00:01<00:01,  6.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8285714285714286\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|██████▌   | 13/20 [00:01<00:01,  6.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8857142857142857\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 14/20 [00:02<00:00,  6.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 75%|███████▌  | 15/20 [00:02<00:00,  6.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 75%|███████▌  | 15/20 [00:02<00:00,  6.37it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_71014/1206641945.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_ds_with_selector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mridge_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'ridge'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/anon/prior-fitting/results/tabular/results_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'_ridge.npy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m    188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m         ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m    191\u001b[0m                                rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m    192\u001b[0m         \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    219\u001b[0m     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m         r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m    222\u001b[0m                               max_samples=max_samples)\n\u001b[1;32m    223\u001b[0m         \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    307\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    308\u001b[0m         \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m         \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    311\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m    322\u001b[0m         \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 324\u001b[0;31m         \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    325\u001b[0m         \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    326\u001b[0m         \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mridge_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m    348\u001b[0m     \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGridSearchCV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'ridge'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCV\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    349\u001b[0m     \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 350\u001b[0;31m     \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    352\u001b[0m     \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecision_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     61\u001b[0m             \u001b[0mextra_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m             \u001b[0;31m# extra_args > 0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m    839\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    840\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    842\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    843\u001b[0m             \u001b[0;31m# multimetric is determined here because in the case of a callable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36m_run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m   1294\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1295\u001b[0m         \u001b[0;34m\"\"\"Search all candidates in param_grid\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1296\u001b[0;31m         \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mParameterGrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1297\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mevaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m    793\u001b[0m                               n_splits, n_candidates, n_candidates * n_splits))\n\u001b[1;32m    794\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 795\u001b[0;31m                 out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n\u001b[0m\u001b[1;32m    796\u001b[0m                                                        \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    797\u001b[0m                                                        \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1042\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterating\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_iterator\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m             \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1045\u001b[0m                 \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1046\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    857\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    858\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 859\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    860\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    861\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    775\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    776\u001b[0m             \u001b[0mjob_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m             \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    778\u001b[0m             \u001b[0;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    779\u001b[0m             \u001b[0;31m# called before we get here, causing self._jobs to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m    206\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    207\u001b[0m         \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    209\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    210\u001b[0m             \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    570\u001b[0m         \u001b[0;31m# Don't delay the application, to avoid keeping the input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    571\u001b[0m         \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    574\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    260\u001b[0m         \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m             return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m    263\u001b[0m                     for func, args, kwargs in self.items]\n\u001b[1;32m    264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    260\u001b[0m         \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m             return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m    263\u001b[0m                     for func, args, kwargs in self.items]\n\u001b[1;32m    264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    220\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    221\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m    596\u001b[0m             \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    597\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 598\u001b[0;31m             \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    599\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    600\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    943\u001b[0m                              compute_sample_weight(self.class_weight, y))\n\u001b[1;32m    944\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 945\u001b[0;31m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    946\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    591\u001b[0m                 \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    592\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 593\u001b[0;31m             self.coef_, self.n_iter_ = _ridge_regression(\n\u001b[0m\u001b[1;32m    594\u001b[0m                 \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    595\u001b[0m                 \u001b[0mmax_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msolver\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36m_ridge_regression\u001b[0;34m(X, y, alpha, sample_weight, solver, max_iter, tol, verbose, random_state, return_n_iter, return_intercept, X_scale, X_offset, check_input)\u001b[0m\n\u001b[1;32m    461\u001b[0m             \u001b[0mK\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msafe_sparse_dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdense_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    462\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m                 \u001b[0mdual_coef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_solve_cholesky_kernel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mK\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malpha\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    464\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    465\u001b[0m                 \u001b[0mcoef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msafe_sparse_dot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdual_coef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdense_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/linear_model/_ridge.py\u001b[0m in \u001b[0;36m_solve_cholesky_kernel\u001b[0;34m(K, y, alpha, sample_weight, copy)\u001b[0m\n\u001b[1;32m    185\u001b[0m             \u001b[0;31m#       use the fall-back solution below in case a LinAlgError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m             \u001b[0;31m#       is raised\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m             dual_coef = linalg.solve(K, y, sym_pos=True,\n\u001b[0m\u001b[1;32m    188\u001b[0m                                      overwrite_a=False)\n\u001b[1;32m    189\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinAlgError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/scipy/linalg/basic.py\u001b[0m in \u001b[0;36msolve\u001b[0;34m(a, b, sym_pos, lower, overwrite_a, overwrite_b, debug, check_finite, assume_a, transposed)\u001b[0m\n\u001b[1;32m    252\u001b[0m                            overwrite_b=overwrite_b)\n\u001b[1;32m    253\u001b[0m         \u001b[0m_solve_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m         \u001b[0mrcond\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpocon\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0manorm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    256\u001b[0m     \u001b[0m_solve_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlamch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrcond\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Call evaluate() on ds with ridge_metric and display the returned value.\n",
    "result = evaluate(\n",
    "    ds,\n",
    "    ridge_metric,\n",
    "    'ridge',\n",
    "    bptt=seq_len,\n",
    "    eval_position_range=eval_positions,\n",
    "    device=device,\n",
    "    max_samples=20,\n",
    "    overwrite=True,\n",
    "    save=False,\n",
    ")\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### XGBoost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1346,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating kr-vs-kp\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n",
      "[17:30:32] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:04<00:04,  4.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "1.0\n",
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n",
      "[17:30:36] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:08<00:00,  4.20s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "1.0\n",
      "\t Eval position 30 done..\n",
      "Evaluating credit-g\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n",
      "[17:30:39] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:02<00:02,  2.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "0.6142857142857143\n",
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n",
      "[17:30:41] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:04<00:00,  2.41s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "0.5285714285714286\n",
      "\t Eval position 30 done..\n",
      "Evaluating vehicle\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n",
      "[17:30:44] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:03<00:03,  3.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "0.5714285714285714\n",
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n",
      "[17:30:47] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:06<00:00,  3.25s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}\n",
      "0.5571428571428572\n",
      "\t Eval position 30 done..\n",
      "Evaluating wine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 2 folds for each of 1 candidates, totalling 2 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:01<?, ?it/s]\n",
      "ERROR:root:Internal Python error in the inspect module.\n",
      "Below is the traceback from this internal error.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[17:30:50] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[17:30:35] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.1s\n",
      "[17:30:39] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.2s\n",
      "[17:30:46] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.4s\n",
      "[17:30:31] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.1s\n",
      "[17:30:37] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.7s\n",
      "[17:30:42] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.9s\n",
      "[17:30:48] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   0.7s\n",
      "[17:30:35] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   0.4s\n",
      "[17:30:39] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   0.8s\n",
      "[17:30:46] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.2s\n",
      "[17:30:31] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   0.6s\n",
      "[17:30:37] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.4s\n",
      "[17:30:42] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.8s\n",
      "[17:30:48] WARNING: ../src/learner.cc:1095: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.\n",
      "[CV] END .................................................... total time=   1.1s\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3441, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"/tmp/ipykernel_71014/762088562.py\", line 4, in <module>\n",
      "    result = evaluate(ds, xgb_acc, 'xgb', bptt, eval_positions, device='cpu', max_samples=2, overwrite=True)\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 190, in evaluate\n",
      "    ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 221, in evaluate_dataset\n",
      "    r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 309, in evaluate_position\n",
      "    acc_eval_pos, outputs = batch_pred(model, eval_xs, eval_ys, categorical_feats, start=eval_position)\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 324, in batch_pred\n",
      "    acc, output = acc_function(eval_x[:start], eval_y[:start], eval_x[start:], eval_y[start:], categorical_feats)\n",
      "  File \"/home/anon/prior-fitting/tabular.py\", line 628, in xgb_acc\n",
      "    clf.fit(x, y.astype(int))\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\", line 63, in inner_f\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\", line 880, in fit\n",
      "    self.best_estimator_.fit(X, y, **fit_params)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\", line 433, in inner_f\n",
      "    return f(**kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/sklearn.py\", line 1176, in fit\n",
      "    self._Booster = train(\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\", line 189, in train\n",
      "    bst = _train_internal(params, dtrain,\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\", line 81, in _train_internal\n",
      "    bst.update(dtrain, i, obj)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\", line 1496, in update\n",
      "    _check_call(_LIB.XGBoosterUpdateOneIter(self.handle,\n",
      "KeyboardInterrupt\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 2061, in showtraceback\n",
      "    stb = value._render_traceback_()\n",
      "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n",
      "    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 248, in wrapped\n",
      "    return f(*args, **kwargs)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\", line 281, in _fixed_getinnerframes\n",
      "    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1541, in getinnerframes\n",
      "    frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 1499, in getframeinfo\n",
      "    filename = getsourcefile(frame) or getfile(frame)\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 709, in getsourcefile\n",
      "    if getattr(getmodule(object, filename), '__loader__', None) is not None:\n",
      "  File \"/home/anon/miniconda3/envs/prior-fitting/lib/python3.9/inspect.py\", line 745, in getmodule\n",
      "    for modname, module in sys.modules.copy().items():\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "object of type 'NoneType' has no len()",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_71014/762088562.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxgb_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'xgb'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m         ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m    191\u001b[0m                                rescale_features=rescale_features_factor, max_samples=max_samples)\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    220\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m         r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m    222\u001b[0m                               max_samples=max_samples)\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    308\u001b[0m         \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m         \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m    323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 324\u001b[0;31m         \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    325\u001b[0m         \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mxgb_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m    627\u001b[0m     \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 628\u001b[0;31m     \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    629\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m    879\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0my\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 880\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_estimator_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    881\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    432\u001b[0m             \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/sklearn.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight, base_margin, eval_set, eval_metric, early_stopping_rounds, verbose, xgb_model, sample_weight_eval_set, base_margin_eval_set, feature_weights, callbacks)\u001b[0m\n\u001b[1;32m   1175\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1176\u001b[0;31m         self._Booster = train(\n\u001b[0m\u001b[1;32m   1177\u001b[0m             \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(params, dtrain, num_boost_round, evals, obj, feval, maximize, early_stopping_rounds, evals_result, verbose_eval, xgb_model, callbacks)\u001b[0m\n\u001b[1;32m    188\u001b[0m     \"\"\"\n\u001b[0;32m--> 189\u001b[0;31m     bst = _train_internal(params, dtrain,\n\u001b[0m\u001b[1;32m    190\u001b[0m                           \u001b[0mnum_boost_round\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_boost_round\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/training.py\u001b[0m in \u001b[0;36m_train_internal\u001b[0;34m(params, dtrain, num_boost_round, evals, obj, feval, xgb_model, callbacks, evals_result, maximize, verbose_eval, early_stopping_rounds)\u001b[0m\n\u001b[1;32m     80\u001b[0m             \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 81\u001b[0;31m         \u001b[0mbst\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     82\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mafter_iteration\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbst\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/xgboost/core.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, dtrain, iteration, fobj)\u001b[0m\n\u001b[1;32m   1495\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mfobj\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1496\u001b[0;31m             _check_call(_LIB.XGBoosterUpdateOneIter(self.handle,\n\u001b[0m\u001b[1;32m   1497\u001b[0m                                                     \u001b[0mctypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miteration\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m   2060\u001b[0m                         \u001b[0;31m# in the engines. This should return a list of strings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2061\u001b[0;31m                         \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2062\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'KeyboardInterrupt' object has no attribute '_render_traceback_'",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "    \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mshowtraceback\u001b[0;34m(self, exc_tuple, filename, tb_offset, exception_only, running_compiled_code)\u001b[0m\n\u001b[1;32m   2061\u001b[0m                         \u001b[0mstb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render_traceback_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2062\u001b[0m                     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2063\u001b[0;31m                         stb = self.InteractiveTB.structured_traceback(etype,\n\u001b[0m\u001b[1;32m   2064\u001b[0m                                             value, tb, tb_offset=tb_offset)\n\u001b[1;32m   2065\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1365\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1366\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1367\u001b[0;31m         return FormattedTB.structured_traceback(\n\u001b[0m\u001b[1;32m   1368\u001b[0m             self, etype, value, tb, tb_offset, number_of_lines_of_context)\n\u001b[1;32m   1369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, value, tb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1265\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1266\u001b[0m             \u001b[0;31m# Verbose modes need a full traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1267\u001b[0;31m             return VerboseTB.structured_traceback(\n\u001b[0m\u001b[1;32m   1268\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtb_offset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumber_of_lines_of_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1269\u001b[0m             )\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mstructured_traceback\u001b[0;34m(self, etype, evalue, etb, tb_offset, number_of_lines_of_context)\u001b[0m\n\u001b[1;32m   1122\u001b[0m         \u001b[0;34m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1124\u001b[0;31m         formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n\u001b[0m\u001b[1;32m   1125\u001b[0m                                                                tb_offset)\n\u001b[1;32m   1126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mformat_exception_as_a_whole\u001b[0;34m(self, etype, evalue, etb, number_of_lines_of_context, tb_offset)\u001b[0m\n\u001b[1;32m   1080\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1081\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1082\u001b[0;31m         \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfind_recursion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_etype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1083\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1084\u001b[0m         \u001b[0mframes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_records\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlast_unique\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecursion_repeat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/IPython/core/ultratb.py\u001b[0m in \u001b[0;36mfind_recursion\u001b[0;34m(etype, value, records)\u001b[0m\n\u001b[1;32m    380\u001b[0m     \u001b[0;31m# first frame (from in to out) that looks different.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    381\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_recursion_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0metype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrecords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    383\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    384\u001b[0m     \u001b[0;31m# Select filename, lineno, func_name to track frames with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()"
     ]
    }
   ],
   "source": [
    "result = evaluate(ds, ridge_metric, 'ridge'\n",
    "                  , bptt=seq_len\n",
    "                  , eval_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Catboost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1338,
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating kr-vs-kp\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  3.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6428571428571429\n",
      "\t Eval position 30 done..\n",
      "Evaluating credit-g\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5571428571428572\n",
      "\t Eval position 30 done..\n",
      "Evaluating vehicle\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1/1 [00:00<00:00,  6.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5285714285714286\n",
      "\t Eval position 30 done..\n",
      "Evaluating wine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8714285714285714\n",
      "\t Eval position 30 done..\n",
      "Evaluating kc1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6571428571428571\n",
      "\t Eval position 30 done..\n",
      "Evaluating airlines\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7714285714285715\n",
      "\t Eval position 30 done..\n",
      "Evaluating bank-marketing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6571428571428571\n",
      "\t Eval position 30 done..\n",
      "Evaluating blood-transfusion-service-center\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  7.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5285714285714286\n",
      "\t Eval position 30 done..\n",
      "Evaluating phoneme\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7428571428571429\n",
      "\t Eval position 30 done..\n",
      "Evaluating covertype\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  2.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6\n",
      "\t Eval position 30 done..\n",
      "Evaluating numerai28.6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5142857142857142\n",
      "\t Eval position 30 done..\n",
      "Evaluating connect-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  4.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9428571428571428\n",
      "\t Eval position 30 done..\n",
      "Evaluating car\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7\n",
      "\t Eval position 30 done..\n",
      "Evaluating Australian\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8857142857142857\n",
      "\t Eval position 30 done..\n",
      "Evaluating segment\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "\t Eval position 30 done..\n",
      "Evaluating jungle_chess_2pcs_raw_endgame_complete\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7142857142857143\n",
      "\t Eval position 30 done..\n",
      "Evaluating sylvine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  6.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8142857142857143\n",
      "\t Eval position 30 done..\n",
      "Evaluating MiniBooNE\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8142857142857143\n",
      "\t Eval position 30 done..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating dionis\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9142857142857143\n",
      "\t Eval position 30 done..\n",
      "Evaluating jannis\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  4.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5571428571428572\n",
      "\t Eval position 30 done..\n",
      "Evaluating helena\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/catboost/core.py\u001b[0m in \u001b[0;36m_tune_hyperparams\u001b[0;34m(self, param_grid, X, y, cv, n_iter, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, log_cout, log_cerr)\u001b[0m\n\u001b[1;32m   3628\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mlog_fixup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_cout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_cerr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0m_get_train_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3629\u001b[0;31m             cv_result = self._object._tune_hyperparams(\n\u001b[0m\u001b[1;32m   3630\u001b[0m                 \u001b[0mparam_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train_pool\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m_catboost.pyx\u001b[0m in \u001b[0;36m_catboost._CatBoost._tune_hyperparams\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m_catboost.pyx\u001b[0m in \u001b[0;36m_catboost._CatBoost._tune_hyperparams\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_71014/2224404849.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalid_datasets\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mselector\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'valid'\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtest_datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcatboost_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'catboost'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m    188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m         ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m    191\u001b[0m                                rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m    192\u001b[0m         \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    219\u001b[0m     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m         r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m    222\u001b[0m                               max_samples=max_samples)\n\u001b[1;32m    223\u001b[0m         \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    307\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    308\u001b[0m         \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 309\u001b[0;31m         \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    311\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m    322\u001b[0m         \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    323\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 324\u001b[0;31m         \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    325\u001b[0m         \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    326\u001b[0m         \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mcatboost_acc\u001b[0;34m(x, y, test_x, test_y, categorical_feats)\u001b[0m\n\u001b[1;32m    593\u001b[0m                                logging_level='Silent')\n\u001b[1;32m    594\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 595\u001b[0;31m     grid_search_result = model.grid_search(param_grid['catboost'],\n\u001b[0m\u001b[1;32m    596\u001b[0m                                            \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    597\u001b[0m                                            \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/catboost/core.py\u001b[0m in \u001b[0;36mgrid_search\u001b[0;34m(self, param_grid, X, y, cv, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, log_cout, log_cerr)\u001b[0m\n\u001b[1;32m   3730\u001b[0m                     \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Parameter grid value is not iterable (key={!r}, value={!r})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3731\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3732\u001b[0;31m         return self._tune_hyperparams(\n\u001b[0m\u001b[1;32m   3733\u001b[0m             \u001b[0mparam_grid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3734\u001b[0m             \u001b[0mpartition_random_seed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpartition_random_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcalc_cv_statistics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcalc_cv_statistics\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/catboost/core.py\u001b[0m in \u001b[0;36m_tune_hyperparams\u001b[0;34m(self, param_grid, X, y, cv, n_iter, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, log_cout, log_cerr)\u001b[0m\n\u001b[1;32m   3627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3628\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mlog_fixup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_cout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_cerr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0m_get_train_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3629\u001b[0;31m             cv_result = self._object._tune_hyperparams(\n\u001b[0m\u001b[1;32m   3630\u001b[0m                 \u001b[0mparam_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train_pool\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3631\u001b[0m                 \u001b[0mfold_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpartition_random_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstratified\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "result = evaluate(ds, catboost_metric, 'catboost'\n",
    "                  , bptt=seq_len\n",
    "                  , eval_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Running bayesian inference on one dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['wine',\n",
       " tensor([[1.1640e+01, 2.0600e+00, 2.4600e+00,  ..., 1.0000e+00, 2.7500e+00,\n",
       "          6.8000e+02],\n",
       "         [1.3860e+01, 1.3500e+00, 2.2700e+00,  ..., 1.0100e+00, 3.5500e+00,\n",
       "          1.0450e+03],\n",
       "         [1.2340e+01, 2.4500e+00, 2.4600e+00,  ..., 8.0000e-01, 3.3800e+00,\n",
       "          4.3800e+02],\n",
       "         ...,\n",
       "         [1.3450e+01, 3.7000e+00, 2.6000e+00,  ..., 8.5000e-01, 1.5600e+00,\n",
       "          6.9500e+02],\n",
       "         [1.1560e+01, 2.0500e+00, 3.2300e+00,  ..., 9.3000e-01, 3.6900e+00,\n",
       "          4.6500e+02],\n",
       "         [1.2820e+01, 3.3700e+00, 2.3000e+00,  ..., 7.2000e-01, 1.7500e+00,\n",
       "          6.8500e+02]]),\n",
       " tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "         0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.]),\n",
       " []]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: Dataset must be ordered 1,0,1,0,...\n",
    "one_ds = ds[0]\n",
    "one_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.07891185]\n",
      " [0.9558054 ]\n",
      " [0.05336431]\n",
      " [0.98093307]\n",
      " [0.07384931]\n",
      " [0.8429412 ]\n",
      " [0.03347415]\n",
      " [0.973169  ]\n",
      " [0.04434767]\n",
      " [0.9206543 ]\n",
      " [0.03534983]\n",
      " [0.7528224 ]\n",
      " [0.23227222]\n",
      " [0.96119374]\n",
      " [0.20475379]\n",
      " [0.97402555]\n",
      " [0.04753578]\n",
      " [0.8418446 ]\n",
      " [0.20400934]\n",
      " [0.9496556 ]]\n"
     ]
    }
   ],
   "source": [
    "result_list = []\n",
    "eval_position = eval_positions[0]\n",
    "rescale_features = max_features / one_ds[1].shape[1] if config['prior_normalize_by_used_features'] else 1\n",
    "\n",
    "model = model.to(device)\n",
    "model = model.eval()\n",
    "\n",
    "# Data to run inference on\n",
    "eval_xs = one_ds[1][0:50]\n",
    "# Extending dataset to standardized feature length\n",
    "eval_xs = torch.cat([eval_xs, torch.zeros((eval_xs.shape[0], max_features - eval_xs.shape[1]))], -1).to(device)\n",
    "eval_ys = one_ds[2][0:50].to(device)\n",
    "\n",
    "for i, pos in enumerate(range(eval_position, eval_xs.shape[0])):\n",
    "    eval_x = torch.cat([eval_xs[:eval_position], eval_xs[pos].unsqueeze(0)])\n",
    "    eval_y = eval_ys[:eval_position]\n",
    "\n",
    "    # Center data using training positions so that it matches priors\n",
    "    mean = eval_x.mean(0)\n",
    "    std = eval_x.std(0) + .000001\n",
    "    eval_x = (eval_x - mean) / std\n",
    "    eval_x = eval_x / rescale_features\n",
    "\n",
    "    result_list += [torch.sigmoid(model((eval_x.unsqueeze(1), eval_y.unsqueeze(1).float()), single_eval_pos=eval_position)).squeeze(-1).detach().cpu().numpy()[0]]\n",
    "\n",
    "print(np.array(result_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1229,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating kr-vs-kp\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 2 folds for each of 8 candidates, totalling 16 fits\n",
      "5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/300 [00:00<?, ?it/s]\u001b[A\n",
      "  0%|          | 1/300 [00:02<11:40,  2.34s/it]\u001b[A\n",
      "  6%|▌         | 17/300 [00:02<00:29,  9.58it/s]\u001b[A\n",
      " 11%|█         | 33/300 [00:02<00:12, 20.89it/s]\u001b[A\n",
      " 16%|█▋        | 49/300 [00:02<00:07, 34.35it/s]\u001b[A\n",
      " 22%|██▏       | 65/300 [00:02<00:04, 49.62it/s]\u001b[A\n",
      " 28%|██▊       | 83/300 [00:02<00:03, 68.38it/s]\u001b[A\n",
      " 33%|███▎      | 100/300 [00:02<00:02, 85.84it/s]\u001b[A\n",
      " 39%|███▉      | 117/300 [00:03<00:01, 102.31it/s]\u001b[A\n",
      " 45%|████▍     | 134/300 [00:03<00:01, 117.03it/s]\u001b[A\n",
      " 51%|█████     | 152/300 [00:03<00:01, 130.51it/s]\u001b[A\n",
      " 56%|█████▋    | 169/300 [00:03<00:00, 139.07it/s]\u001b[A\n",
      " 62%|██████▏   | 187/300 [00:03<00:00, 148.43it/s]\u001b[A\n",
      " 68%|██████▊   | 204/300 [00:03<00:00, 154.00it/s]\u001b[A\n",
      " 74%|███████▍  | 222/300 [00:03<00:00, 159.37it/s]\u001b[A\n",
      " 80%|███████▉  | 239/300 [00:03<00:00, 162.31it/s]\u001b[A\n",
      " 86%|████████▌ | 257/300 [00:03<00:00, 165.46it/s]\u001b[A\n",
      " 92%|█████████▏| 275/300 [00:03<00:00, 168.22it/s]\u001b[A\n",
      "100%|██████████| 300/300 [00:04<00:00, 72.61it/s] \u001b[A\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/2] END .................embed=5, lr=0.001;, score=1.000 total time=   4.8s\n",
      "5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/300 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▌         | 17/300 [00:00<00:01, 169.72it/s]\u001b[A\n",
      " 12%|█▏        | 35/300 [00:00<00:01, 170.91it/s]\u001b[A\n",
      " 18%|█▊        | 53/300 [00:00<00:01, 171.54it/s]\u001b[A\n",
      " 24%|██▎       | 71/300 [00:00<00:01, 170.70it/s]\u001b[A\n",
      " 30%|██▉       | 89/300 [00:00<00:01, 166.52it/s]\u001b[A\n",
      " 36%|███▌      | 107/300 [00:00<00:01, 169.07it/s]\u001b[A\n",
      " 42%|████▏     | 125/300 [00:00<00:01, 171.75it/s]\u001b[A\n",
      " 48%|████▊     | 143/300 [00:00<00:00, 173.60it/s]\u001b[A\n",
      " 54%|█████▎    | 161/300 [00:00<00:00, 173.66it/s]\u001b[A\n",
      " 60%|█████▉    | 179/300 [00:01<00:00, 174.24it/s]\u001b[A\n",
      " 66%|██████▌   | 197/300 [00:01<00:00, 173.49it/s]\u001b[A\n",
      " 72%|███████▏  | 215/300 [00:01<00:00, 173.84it/s]\u001b[A\n",
      " 78%|███████▊  | 233/300 [00:01<00:00, 172.16it/s]\u001b[A\n",
      " 84%|████████▎ | 251/300 [00:01<00:00, 173.07it/s]\u001b[A\n",
      " 90%|████████▉ | 269/300 [00:01<00:00, 173.59it/s]\u001b[A\n",
      "100%|██████████| 300/300 [00:01<00:00, 172.72it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 2/2] END .................embed=5, lr=0.001;, score=0.933 total time=   2.1s\n",
      "5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/300 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▌         | 18/300 [00:00<00:01, 172.36it/s]\u001b[A\n",
      " 12%|█▏        | 36/300 [00:00<00:01, 174.56it/s]\u001b[A\n",
      " 18%|█▊        | 54/300 [00:00<00:01, 174.34it/s]\u001b[A\n",
      " 24%|██▍       | 72/300 [00:00<00:01, 172.24it/s]\u001b[A\n",
      " 30%|███       | 90/300 [00:00<00:01, 170.46it/s]\u001b[A\n",
      " 36%|███▌      | 108/300 [00:00<00:01, 170.08it/s]\u001b[A\n",
      " 42%|████▏     | 126/300 [00:00<00:01, 170.11it/s]\u001b[A\n",
      " 48%|████▊     | 144/300 [00:00<00:00, 169.69it/s]\u001b[A\n",
      " 54%|█████▎    | 161/300 [00:00<00:00, 165.75it/s]\u001b[A\n",
      " 60%|█████▉    | 179/300 [00:01<00:00, 167.09it/s]\u001b[A\n",
      " 65%|██████▌   | 196/300 [00:01<00:00, 167.40it/s]\u001b[A\n",
      " 71%|███████▏  | 214/300 [00:01<00:00, 168.37it/s]\u001b[A\n",
      " 77%|███████▋  | 232/300 [00:01<00:00, 170.04it/s]\u001b[A\n",
      " 83%|████████▎ | 250/300 [00:01<00:00, 171.48it/s]\u001b[A\n",
      " 89%|████████▉ | 268/300 [00:01<00:00, 171.37it/s]\u001b[A\n",
      "100%|██████████| 300/300 [00:01<00:00, 170.42it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/2] END ................embed=5, lr=0.0001;, score=0.667 total time=   2.1s\n",
      "5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/300 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▌         | 18/300 [00:00<00:01, 176.66it/s]\u001b[A\n",
      " 12%|█▏        | 36/300 [00:00<00:01, 175.66it/s]\u001b[A\n",
      " 18%|█▊        | 54/300 [00:00<00:01, 176.82it/s]\u001b[A\n",
      " 24%|██▍       | 72/300 [00:00<00:01, 176.85it/s]\u001b[A\n",
      " 30%|███       | 90/300 [00:00<00:01, 174.94it/s]\u001b[A\n",
      " 36%|███▌      | 108/300 [00:00<00:01, 173.28it/s]\u001b[A\n",
      " 42%|████▏     | 126/300 [00:00<00:00, 174.48it/s]\u001b[A\n",
      " 48%|████▊     | 144/300 [00:00<00:00, 175.75it/s]\u001b[A\n",
      " 54%|█████▍    | 163/300 [00:00<00:00, 177.11it/s]\u001b[A\n",
      " 60%|██████    | 181/300 [00:01<00:00, 177.97it/s]\u001b[A\n",
      " 66%|██████▋   | 199/300 [00:01<00:00, 178.49it/s]\u001b[A\n",
      " 72%|███████▏  | 217/300 [00:01<00:00, 178.84it/s]\u001b[A\n",
      " 79%|███████▊  | 236/300 [00:01<00:00, 179.42it/s]\u001b[A\n",
      " 85%|████████▍ | 254/300 [00:01<00:00, 178.96it/s]\u001b[A\n",
      " 91%|█████████ | 272/300 [00:01<00:00, 178.79it/s]\u001b[A\n",
      "100%|██████████| 300/300 [00:01<00:00, 177.22it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 2/2] END ................embed=5, lr=0.0001;, score=1.000 total time=   2.0s\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/300 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▌         | 17/300 [00:00<00:01, 167.33it/s]\u001b[A\n",
      " 12%|█▏        | 35/300 [00:00<00:01, 171.62it/s]\u001b[A\n",
      " 18%|█▊        | 53/300 [00:00<00:01, 172.08it/s]\u001b[A\n",
      " 24%|██▎       | 71/300 [00:00<00:01, 172.85it/s]\u001b[A\n",
      " 30%|██▉       | 89/300 [00:00<00:01, 174.53it/s]\u001b[A\n",
      " 36%|███▌      | 108/300 [00:00<00:01, 176.36it/s]\u001b[A\n",
      " 42%|████▏     | 126/300 [00:00<00:00, 175.69it/s]\u001b[A\n",
      " 48%|████▊     | 144/300 [00:00<00:00, 163.70it/s]\u001b[A\n",
      " 54%|█████▎    | 161/300 [00:00<00:00, 156.47it/s]\u001b[A\n",
      " 59%|█████▉    | 177/300 [00:01<00:00, 145.27it/s]\u001b[A\n",
      " 64%|██████▍   | 192/300 [00:01<00:00, 144.14it/s]\u001b[A\n",
      " 69%|██████▉   | 207/300 [00:01<00:00, 143.70it/s]\u001b[A\n",
      " 74%|███████▍  | 222/300 [00:01<00:00, 142.87it/s]\u001b[A\n",
      " 79%|███████▉  | 237/300 [00:01<00:00, 141.92it/s]\u001b[A\n",
      " 85%|████████▌ | 255/300 [00:01<00:00, 151.16it/s]\u001b[A\n",
      " 91%|█████████▏| 274/300 [00:01<00:00, 160.83it/s]\u001b[A\n",
      "100%|██████████| 300/300 [00:01<00:00, 160.18it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/2] END ................embed=10, lr=0.001;, score=1.000 total time=   2.2s\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/300 [00:00<?, ?it/s]\u001b[A\n",
      "  6%|▌         | 18/300 [00:00<00:01, 178.84it/s]\u001b[A\n",
      " 12%|█▏        | 37/300 [00:00<00:01, 181.55it/s]\u001b[A\n",
      " 19%|█▊        | 56/300 [00:00<00:01, 182.60it/s]\u001b[A\n",
      " 25%|██▌       | 75/300 [00:00<00:01, 182.45it/s]\u001b[A\n",
      " 31%|███▏      | 94/300 [00:00<00:01, 182.45it/s]\u001b[A\n",
      " 40%|████      | 121/300 [00:00<00:01, 177.93it/s]\u001b[A\n",
      "  0%|          | 0/5 [00:14<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_73533/28407073.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalid_datasets\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mselector\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'valid'\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtest_datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbayes_net_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'bayes_net'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbptt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_positions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/home/anon/prior-fitting/results_'\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mselector\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;34m'_bayes_net.npy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(datasets, model, method, bptt, eval_position_range, device, max_features, plot, extend_features, save, rescale_features, overwrite, max_samples, path_interfix)\u001b[0m\n\u001b[1;32m    191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    192\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m         ds_result = evaluate_dataset(X.to(device), y.to(device), categorical_feats, model, bptt, eval_position_range,\n\u001b[0m\u001b[1;32m    194\u001b[0m                                rescale_features=rescale_features_factor, max_samples=max_samples)\n\u001b[1;32m    195\u001b[0m         \u001b[0melapsed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_dataset\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position_range, plot, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    222\u001b[0m     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0meval_position\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval_position_range\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m         r = evaluate_position(X, y, categorical_feats, model, bptt, eval_position, rescale_features=rescale_features,\n\u001b[0m\u001b[1;32m    225\u001b[0m                               max_samples=max_samples)\n\u001b[1;32m    226\u001b[0m         \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mevaluate_position\u001b[0;34m(X, y, categorical_feats, model, bptt, eval_position, rescale_features, max_samples)\u001b[0m\n\u001b[1;32m    310\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    311\u001b[0m         \u001b[0;31m# acc_per_position = torch.tensor(batch_pred(model,eval_xs,eval_ys, start=2))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m         \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_pred\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_xs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    314\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0macc_eval_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_ys\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0meval_position\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbatch_pred\u001b[0;34m(acc_function, eval_xs, eval_ys, categorical_feats, start)\u001b[0m\n\u001b[1;32m    325\u001b[0m         \u001b[0meval_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mstd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m         \u001b[0macc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meval_y\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcategorical_feats\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    328\u001b[0m         \u001b[0maccs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0macc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    329\u001b[0m         \u001b[0moutputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mbayes_net_acc\u001b[0;34m(x, y, test_x, test_y, cat_features)\u001b[0m\n\u001b[1;32m    514\u001b[0m     \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGridSearchCV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'bayes'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    515\u001b[0m     \u001b[0;31m# fit model to data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 516\u001b[0;31m     \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    517\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    518\u001b[0m     \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     61\u001b[0m             \u001b[0mextra_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mextra_args\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m             \u001b[0;31m# extra_args > 0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m    839\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    840\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    842\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    843\u001b[0m             \u001b[0;31m# multimetric is determined here because in the case of a callable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36m_run_search\u001b[0;34m(self, evaluate_candidates)\u001b[0m\n\u001b[1;32m   1294\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_run_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1295\u001b[0m         \u001b[0;34m\"\"\"Search all candidates in param_grid\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1296\u001b[0;31m         \u001b[0mevaluate_candidates\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mParameterGrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1297\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mevaluate_candidates\u001b[0;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[1;32m    793\u001b[0m                               n_splits, n_candidates, n_candidates * n_splits))\n\u001b[1;32m    794\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 795\u001b[0;31m                 out = parallel(delayed(_fit_and_score)(clone(base_estimator),\n\u001b[0m\u001b[1;32m    796\u001b[0m                                                        \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    797\u001b[0m                                                        \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1042\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterating\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_iterator\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m             \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1045\u001b[0m                 \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1046\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m    857\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    858\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 859\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    860\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    861\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    775\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    776\u001b[0m             \u001b[0mjob_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 777\u001b[0;31m             \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    778\u001b[0m             \u001b[0;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    779\u001b[0m             \u001b[0;31m# called before we get here, causing self._jobs to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m    206\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    207\u001b[0m         \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    209\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    210\u001b[0m             \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m    570\u001b[0m         \u001b[0;31m# Don't delay the application, to avoid keeping the input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    571\u001b[0m         \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    574\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    260\u001b[0m         \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m             return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m    263\u001b[0m                     for func, args, kwargs in self.items]\n\u001b[1;32m    264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    260\u001b[0m         \u001b[0;31m# change the default number of processes to -1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    261\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m             return [func(*args, **kwargs)\n\u001b[0m\u001b[1;32m    263\u001b[0m                     for func, args, kwargs in self.items]\n\u001b[1;32m    264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/utils/fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    220\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    221\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, split_progress, candidate_progress, error_score)\u001b[0m\n\u001b[1;32m    596\u001b[0m             \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    597\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 598\u001b[0;31m             \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    599\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    600\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/prior-fitting/tabular.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m    479\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    480\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 481\u001b[0;31m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    482\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    483\u001b[0m         \u001b[0;31m# Return the classifier\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/svi.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    143\u001b[0m         \u001b[0;31m# get loss and compute gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    144\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mpoutine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mparam_capture\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_and_grads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    147\u001b[0m         params = set(\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/trace_elbo.py\u001b[0m in \u001b[0;36mloss_and_grads\u001b[0;34m(self, model, guide, *args, **kwargs)\u001b[0m\n\u001b[1;32m    138\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    139\u001b[0m         \u001b[0;31m# grab a trace from the generator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m         \u001b[0;32mfor\u001b[0m \u001b[0mmodel_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide_trace\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_traces\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    141\u001b[0m             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(\n\u001b[1;32m    142\u001b[0m                 \u001b[0mmodel_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide_trace\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/elbo.py\u001b[0m in \u001b[0;36m_get_traces\u001b[0;34m(self, model, guide, args, kwargs)\u001b[0m\n\u001b[1;32m    184\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    185\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_particles\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m                 \u001b[0;32myield\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/trace_elbo.py\u001b[0m in \u001b[0;36m_get_trace\u001b[0;34m(self, model, guide, args, kwargs)\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0magainst\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     56\u001b[0m         \"\"\"\n\u001b[0;32m---> 57\u001b[0;31m         model_trace, guide_trace = get_importance_trace(\n\u001b[0m\u001b[1;32m     58\u001b[0m             \u001b[0;34m\"flat\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_plate_nesting\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         )\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/infer/enum.py\u001b[0m in \u001b[0;36mget_importance_trace\u001b[0;34m(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)\u001b[0m\n\u001b[1;32m     62\u001b[0m     \u001b[0mmodel_trace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprune_subsample_sites\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m     \u001b[0mmodel_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_log_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     65\u001b[0m     \u001b[0mguide_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_score_parts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mis_validation_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/poutine/trace_struct.py\u001b[0m in \u001b[0;36mcompute_log_prob\u001b[0;34m(self, site_filter)\u001b[0m\n\u001b[1;32m    248\u001b[0m                             \u001b[0;34m\"log_prob_sum at site '{}'\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    249\u001b[0m                         )\n\u001b[0;32m--> 250\u001b[0;31m                         warn_if_inf(\n\u001b[0m\u001b[1;32m    251\u001b[0m                             \u001b[0msite\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"log_prob_sum\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    252\u001b[0m                             \u001b[0;34m\"log_prob_sum at site '{}'\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/prior-fitting/lib/python3.9/site-packages/pyro/util.py\u001b[0m in \u001b[0;36mwarn_if_inf\u001b[0;34m(value, msg, allow_posinf, allow_neginf, filename, lineno)\u001b[0m\n\u001b[1;32m    130\u001b[0m         \u001b[0mvalue\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    131\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumbers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m         \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    133\u001b[0m     ):\n\u001b[1;32m    134\u001b[0m         warnings.warn_explicit(\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "result = evaluate(ds, bayes_net_acc, 'bayes_net'\n",
    "                  , bptt=seq_len\n",
    "                  , eval_position_range=eval_positions\n",
    "                  , device=device\n",
    "                  , max_samples=20\n",
    "                  , overwrite=True\n",
    "                  , save=False)\n",
    "result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:prior-fitting]",
   "language": "python",
   "name": "conda-env-prior-fitting-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
