{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d27024-cde1-4150-b42e-ec915c92c28e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import tqdm\n",
    "from tabpfn import TabPFNRegressor\n",
    "\n",
    "\n",
    "def build_next_state_matrix(next_states: np.ndarray, n_actions: int) -> np.ndarray:\n",
    "    \"\"\"Pair every next state with every discrete action.\n",
    "\n",
    "    Args:\n",
    "        next_states: 2-D array of shape ``[B, obs_dim]``.\n",
    "        n_actions: Number of discrete actions ``A``.\n",
    "\n",
    "    Returns:\n",
    "        float32 array of shape ``[B * A, obs_dim + A]``. Rows ``i*A`` to\n",
    "        ``i*A + A - 1`` hold ``next_states[i]`` followed by the one-hot codes\n",
    "        of actions ``0 .. A-1`` in order, so predictions on it can be\n",
    "        reshaped to ``[B, A]``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If *next_states* is not 2-D.\n",
    "    \"\"\"\n",
    "    if next_states.ndim != 2:\n",
    "        raise ValueError(f\"next_states must be 2-D, got shape {next_states.shape}\")\n",
    "    n_rows   = next_states.shape[0]\n",
    "    repeated = np.repeat(next_states, n_actions, axis=0)                  # [B*A, obs]\n",
    "    onehots  = np.tile(np.eye(n_actions, dtype=np.float32), (n_rows, 1))  # [B*A, A]\n",
    "    return np.hstack([repeated, onehots]).astype(np.float32)\n",
    "\n",
    "\n",
    "def make_row(state: np.ndarray, action: int, n_actions: int) -> np.ndarray:\n",
    "    \"\"\"Encode one (state, action) pair as a float32 feature row.\n",
    "\n",
    "    The row is the state's features followed by a length-*n_actions* one-hot\n",
    "    vector holding 1.0 at index *action*.\n",
    "    \"\"\"\n",
    "    action_code = np.eye(n_actions, dtype=np.float32)[action]\n",
    "    return np.concatenate((state.astype(np.float32), action_code))\n",
    "\n",
    "\n",
    "def act_greedy(state: np.ndarray,\n",
    "               model: TabPFNRegressor,\n",
    "               n_actions: int) -> Tuple[int, float]:\n",
    "    \"\"\"Pick the action whose Q-value, as predicted by *model*, is highest.\n",
    "\n",
    "    All *n_actions* (state, action) rows are scored in a single ``predict``\n",
    "    call; ties go to the lowest action index (``argmax`` semantics).\n",
    "\n",
    "    Returns:\n",
    "        ``(best_action, predicted_q_of_best_action)``.\n",
    "    \"\"\"\n",
    "    candidates  = np.stack([make_row(state, a, n_actions) for a in range(n_actions)])\n",
    "    predictions = model.predict(candidates)\n",
    "    best        = int(np.argmax(predictions))\n",
    "    return best, float(predictions[best])\n",
    "\n",
    "\n",
    "def q_value(state: np.ndarray,\n",
    "            action: int,\n",
    "            model: TabPFNRegressor,\n",
    "            n_actions: int) -> float:\n",
    "    \"\"\"Predicted Q(state, action), or 0.0 while no model exists (``model is None``).\"\"\"\n",
    "    if model is not None:\n",
    "        single_row = make_row(state, action, n_actions)[np.newaxis, :]\n",
    "        return float(model.predict(single_row)[0])\n",
    "    return 0.0\n",
    "\n",
    "\n",
    "def get_action(env,\n",
    "               state,\n",
    "               model: TabPFNRegressor,\n",
    "               epsilon: float,\n",
    "               n_actions: int) -> Tuple[int, float]:\n",
    "    \"\"\"Choose an action ε-greedily and return it with its estimated Q-value.\n",
    "\n",
    "    While *model* is ``None`` every action is sampled from ``env.action_space``\n",
    "    and ``random.random`` is not consulted. Otherwise the greedy action\n",
    "    (``act_greedy``) is taken unless a uniform draw falls below *epsilon*, in\n",
    "    which case a random action is sampled and valued with ``q_value``.\n",
    "    \"\"\"\n",
    "    use_model = model is not None and random.random() >= epsilon\n",
    "    if use_model:\n",
    "        return act_greedy(state, model, n_actions)\n",
    "    sampled = env.action_space.sample()\n",
    "    return sampled, q_value(state, sampled, model, n_actions)\n",
    "\n",
    "\n",
    "def fit_model(data: List[Dict],\n",
    "              env,\n",
    "              gamma: float = 0.99,\n",
    "              n_iter: int = 60,\n",
    "              device: str = \"cuda\") -> TabPFNRegressor:\n",
    "    \"\"\"Fitted-Q iteration with TabPFN as the Q-function approximator.\n",
    "\n",
    "    Every iteration builds Bellman targets ``r + gamma * max_a' Q(s', a')``\n",
    "    from the previous iteration's model (the bootstrap term is zeroed for\n",
    "    transitions flagged ``done``) and refits TabPFN on\n",
    "    ``[state, one_hot(action)] -> target``. The first iteration uses Q = 0,\n",
    "    so its targets are just the rewards.\n",
    "\n",
    "    Args:\n",
    "        data: Transition dicts with keys \"state\", \"action\", \"reward\",\n",
    "            \"next_state\" and \"done\".\n",
    "        env: Environment; only ``env.action_space.n`` is read.\n",
    "        gamma: Discount factor.\n",
    "        n_iter: Number of fitted-Q iterations (>= 1).\n",
    "        device: Device string passed to ``TabPFNRegressor``.\n",
    "\n",
    "    Returns:\n",
    "        The regressor fitted in the last iteration.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If *data* is empty or *n_iter* < 1.\n",
    "    \"\"\"\n",
    "    if not data:\n",
    "        raise ValueError(\"fit_model needs at least one transition\")\n",
    "    if n_iter < 1:\n",
    "        raise ValueError(f\"n_iter must be >= 1, got {n_iter}\")\n",
    "    n_actions = env.action_space.n\n",
    "\n",
    "    states      = np.vstack([d[\"state\"]      for d in data]).astype(np.float32)\n",
    "    actions     = np.array([d[\"action\"]     for d in data], dtype=np.int64)\n",
    "    rewards     = np.array([d[\"reward\"]     for d in data], dtype=np.float32)\n",
    "    next_states = np.vstack([d[\"next_state\"] for d in data]).astype(np.float32)\n",
    "    dones       = np.array([d[\"done\"]       for d in data], dtype=np.float32)\n",
    "\n",
    "    X = np.hstack([states, np.eye(n_actions, dtype=np.float32)[actions]])\n",
    "    # All (s', a') pairs; the next states never change, so build this once.\n",
    "    sa_next  = build_next_state_matrix(next_states, n_actions)\n",
    "    terminal = dones == 1.0\n",
    "\n",
    "    model = TabPFNRegressor(device=device, n_estimators=1, fit_mode=\"fit_with_cache\")\n",
    "\n",
    "    for it in range(n_iter):\n",
    "        if it == 0:\n",
    "            # The model has not been fitted yet: start from Q = 0.\n",
    "            q_next = np.zeros(len(states))\n",
    "        else:\n",
    "            q_next = model.predict(sa_next).reshape(len(states), n_actions).max(axis=1)\n",
    "        q_next[terminal] = 0.0  # no bootstrapping past a done transition\n",
    "        targets = rewards + gamma * q_next\n",
    "        if targets.std() == 0:\n",
    "            # Constant targets (e.g. identical rewards on the first pass):\n",
    "            # add tiny noise so the regressor is not given a zero-variance target.\n",
    "            targets = targets + np.random.uniform(0, 1e-4, len(targets))\n",
    "        model.fit(X, targets)\n",
    "\n",
    "    return model\n",
    "\n",
    "def _run_episode(env, model, epsilon, n_actions, max_steps, reward_shaping, ep_idx):\n",
    "    \"\"\"Roll out one ε-greedy episode of at most *max_steps* steps.\n",
    "\n",
    "    Returns ``(transitions, total_reward)``: one logging dict per step (the\n",
    "    fields ``fit_model`` reads plus bookkeeping fields) and the summed,\n",
    "    possibly shaped, reward as a Python float so it stays JSON-serialisable.\n",
    "    \"\"\"\n",
    "    state, _info = env.reset()\n",
    "    transitions  = []\n",
    "    total_r      = 0.0\n",
    "\n",
    "    for _step in range(max_steps):\n",
    "        action, q_est = get_action(env, state, model, epsilon, n_actions)\n",
    "        nxt, r, term, trunc, _info = env.step(action)\n",
    "        if reward_shaping:\n",
    "            r = reward_shaping(nxt)\n",
    "\n",
    "        done = term or trunc\n",
    "        transitions.append({\n",
    "            \"episode\": ep_idx,\n",
    "            \"epsilon\": epsilon,\n",
    "            \"reward_until_now\": float(total_r),\n",
    "            \"expected_reward\":  float(q_est),\n",
    "            \"state\":           state.astype(float).tolist(),\n",
    "            \"action\":          float(action),\n",
    "            \"reward\":          float(r),\n",
    "            \"next_state\":      nxt.astype(float).tolist(),\n",
    "            \"done\":            float(done)\n",
    "        })\n",
    "\n",
    "        total_r += float(r)\n",
    "        state    = nxt\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    return transitions, total_r\n",
    "\n",
    "\n",
    "def run_tabpfn_fqi(env_id: str = \"CartPole-v1\",\n",
    "                   n_episodes: int = 10_000,\n",
    "                   max_steps: int = 200,\n",
    "                   initial_size: int = 200,\n",
    "                   max_context: int = 2048,\n",
    "                   eps_offline: float = 0.95,\n",
    "                   eps_online: float = 0.8,\n",
    "                   eps_decay: float = 0.99,\n",
    "                   eps_min: float = 0.1,\n",
    "                   percentile_refit: float = 0.95,\n",
    "                   reward_shaping=None,\n",
    "                   seed: int = 42,\n",
    "                   title: str = \"\") -> Dict:\n",
    "    \"\"\"Bootstrap with random data, then run fitted-Q with TabPFN and selective refits.\n",
    "\n",
    "    Phases:\n",
    "      1. Offline: roll out purely random episodes (no model yet) until at\n",
    "         least *initial_size* transitions are collected, then fit a model.\n",
    "      2. Online: run *n_episodes* ε-greedy episodes. An episode's transitions\n",
    "         are added to the training set, followed by a full refit, only when its\n",
    "         return exceeds the *percentile_refit* quantile of the online returns\n",
    "         so far and the training set is still below *max_context* rows (it can\n",
    "         overshoot by one episode). ε decays by *eps_decay* down to *eps_min*.\n",
    "\n",
    "    Args:\n",
    "        reward_shaping: Optional callable ``f(next_state) -> reward`` that\n",
    "            replaces the environment reward.\n",
    "        seed: Seeds the env, its action space and the global ``random`` and\n",
    "            ``numpy`` RNGs.\n",
    "        title: Suffix inserted into the results file name.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"rewards\": [...], \"training_data\": [...]}``; the same dict is also\n",
    "        written (best-effort) to ``./{env_id}{title}_{seed}_results.json``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If *max_steps* < 1 (the bootstrap loop could never finish).\n",
    "    \"\"\"\n",
    "    if max_steps < 1:\n",
    "        raise ValueError(f\"max_steps must be >= 1, got {max_steps}\")\n",
    "\n",
    "    # Seed every RNG the loop touches: env dynamics, action_space.sample(),\n",
    "    # random.random() in get_action and np.random (target jitter in fit_model).\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    env = gym.make(env_id)\n",
    "    env.reset(seed=seed)\n",
    "    env.action_space.seed(seed)\n",
    "    n_actions = env.action_space.n\n",
    "\n",
    "    training_data = []\n",
    "    rewards_all   = []\n",
    "    pbar          = tqdm.tqdm(range(n_episodes))\n",
    "\n",
    "    try:\n",
    "        # Offline data collection (model=None, so every action is random)\n",
    "        epsilon = eps_offline\n",
    "        ep_idx  = 0\n",
    "        while len(training_data) < initial_size:\n",
    "            episode, total_r = _run_episode(env, None, epsilon, n_actions,\n",
    "                                            max_steps, reward_shaping, ep_idx)\n",
    "            training_data.extend(episode)\n",
    "            rewards_all.append(total_r)\n",
    "            ep_idx += 1\n",
    "            pbar.set_description(f\"bootstrap r={total_r:.1f}\")\n",
    "\n",
    "        print(f\"Collected {len(training_data)} transitions.\")\n",
    "\n",
    "        # Initial fit\n",
    "        model   = fit_model(training_data, env)\n",
    "        epsilon = eps_online\n",
    "        refits  = 0\n",
    "        rewards_history = []\n",
    "\n",
    "        # Online phase (ep_idx restarts at 0 here, reusing the bootstrap indices)\n",
    "        for ep_idx in pbar:\n",
    "            episode, total_r = _run_episode(env, model, epsilon, n_actions,\n",
    "                                            max_steps, reward_shaping, ep_idx)\n",
    "            rewards_history.append(total_r)\n",
    "            rewards_all.append(total_r)\n",
    "\n",
    "            # Refitting criterion: only unusually good episodes enter the context.\n",
    "            if (total_r > np.quantile(rewards_history, percentile_refit) and\n",
    "                    len(training_data) < max_context):\n",
    "                training_data.extend(episode)\n",
    "                model  = fit_model(training_data, env)\n",
    "                refits += 1\n",
    "\n",
    "            epsilon = max(epsilon * eps_decay, eps_min)\n",
    "            pbar.set_description(f\"r={total_r:.1f} σ={np.std(rewards_history):.2f} refits={refits}\")\n",
    "    finally:\n",
    "        # Release the progress bar and the environment even on KeyboardInterrupt.\n",
    "        pbar.close()\n",
    "        env.close()\n",
    "\n",
    "    results = {\"rewards\": rewards_all, \"training_data\": training_data}\n",
    "    fname = f\"./{env_id}{title}_{seed}_results.json\"\n",
    "    try:\n",
    "        with open(fname, \"w\") as f:\n",
    "            json.dump(results, f)\n",
    "    except Exception as e:  # best-effort: a failed save must not lose the returned results\n",
    "        print(f\"Error saving: {e}\")\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b81e67ca-7257-4de9-b724-524c99515644",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[A%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=20.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=37.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=11.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=16.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=20.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=30.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=12.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=29.0:   0%|          | 0/10000 [00:00<?, ?it/s]\n",
      "\u001b[Atstrap r=29.0:   0%|          | 0/10000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collected 204 transitions.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[A1.0 σ=0.00 refits=0:   0%|          | 0/10000 [00:13<?, ?it/s]\n",
      "\u001b[A1.0 σ=0.00 refits=0:   0%|          | 1/10000 [00:13<38:24:59, 13.83s/it]\n",
      "\u001b[A6.0 σ=12.50 refits=0:   0%|          | 1/10000 [00:14<38:24:59, 13.83s/it]\n",
      "\u001b[A6.0 σ=12.50 refits=0:   0%|          | 2/10000 [00:14<16:39:21,  6.00s/it]\n",
      "\u001b[A6.0 σ=11.79 refits=0:   0%|          | 2/10000 [00:14<16:39:21,  6.00s/it]\n",
      "\u001b[A6.0 σ=11.79 refits=0:   0%|          | 3/10000 [00:14<9:41:59,  3.49s/it] \n",
      "\u001b[A7.0 σ=10.27 refits=0:   0%|          | 3/10000 [00:15<9:41:59,  3.49s/it]\n",
      "\u001b[A7.0 σ=10.27 refits=0:   0%|          | 4/10000 [00:15<6:49:10,  2.46s/it]\n",
      "\u001b[A3.0 σ=10.37 refits=0:   0%|          | 4/10000 [00:16<6:49:10,  2.46s/it]\n",
      "\u001b[A3.0 σ=10.37 refits=0:   0%|          | 5/10000 [00:16<4:46:52,  1.72s/it]\n",
      "\u001b[A7.0 σ=9.69 refits=0:   0%|          | 5/10000 [00:16<4:46:52,  1.72s/it] \n",
      "\u001b[A7.0 σ=9.69 refits=0:   0%|          | 6/10000 [00:16<3:39:32,  1.32s/it]\n",
      "\u001b[A9.0 σ=9.33 refits=0:   0%|          | 6/10000 [00:17<3:39:32,  1.32s/it]\n",
      "\u001b[A9.0 σ=9.33 refits=0:   0%|          | 7/10000 [00:17<3:17:36,  1.19s/it]\n",
      "\u001b[A5.0 σ=9.09 refits=0:   0%|          | 7/10000 [00:18<3:17:36,  1.19s/it]\n",
      "\u001b[A5.0 σ=9.09 refits=0:   0%|          | 8/10000 [00:18<2:40:06,  1.04it/s]\n",
      "\u001b[A3.0 σ=8.58 refits=0:   0%|          | 8/10000 [00:18<2:40:06,  1.04it/s]\n",
      "\u001b[A3.0 σ=8.58 refits=0:   0%|          | 9/10000 [00:18<2:27:56,  1.13it/s]\n",
      "\u001b[A0.0 σ=8.16 refits=0:   0%|          | 9/10000 [00:19<2:27:56,  1.13it/s]\n",
      "\u001b[A0.0 σ=8.16 refits=0:   0%|          | 10/10000 [00:19<2:14:34,  1.24it/s]\n",
      "\u001b[A4.0 σ=14.44 refits=1:   0%|          | 10/10000 [00:31<2:14:34,  1.24it/s]\n",
      "\u001b[A4.0 σ=14.44 refits=1:   0%|          | 11/10000 [00:31<11:59:48,  4.32s/it]\n",
      "\u001b[A8.0 σ=16.48 refits=1:   0%|          | 11/10000 [00:33<11:59:48,  4.32s/it]\n",
      "\u001b[A8.0 σ=16.48 refits=1:   0%|          | 12/10000 [00:33<9:57:40,  3.59s/it] \n",
      "\u001b[A5.0 σ=15.93 refits=1:   0%|          | 12/10000 [00:34<9:57:40,  3.59s/it]\n",
      "\u001b[A5.0 σ=15.93 refits=1:   0%|          | 13/10000 [00:34<7:55:20,  2.86s/it]\n",
      "\u001b[A3.0 σ=15.39 refits=1:   0%|          | 13/10000 [00:35<7:55:20,  2.86s/it]\n",
      "\u001b[A3.0 σ=15.39 refits=1:   0%|          | 14/10000 [00:35<6:26:49,  2.32s/it]\n",
      "\u001b[A4.0 σ=14.92 refits=1:   0%|          | 14/10000 [00:37<6:26:49,  2.32s/it]\n",
      "\u001b[A4.0 σ=14.92 refits=1:   0%|          | 15/10000 [00:37<5:26:50,  1.96s/it]\n",
      "\u001b[A6.0 σ=15.00 refits=1:   0%|          | 15/10000 [00:38<5:26:50,  1.96s/it]\n",
      "\u001b[A6.0 σ=15.00 refits=1:   0%|          | 16/10000 [00:38<5:04:42,  1.83s/it]\n",
      "\u001b[A5.0 σ=14.99 refits=1:   0%|          | 16/10000 [00:39<5:04:42,  1.83s/it]\n",
      "\u001b[A5.0 σ=14.99 refits=1:   0%|          | 17/10000 [00:39<3:57:59,  1.43s/it]\n",
      "\u001b[A2.0 σ=17.52 refits=2:   0%|          | 17/10000 [00:51<3:57:59,  1.43s/it]\n",
      "r=72.0 σ=17.52 refits=2:   0%|          | 18/10000 [00:51<7:58:14,  2.87s/it] \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m      1\u001b[39m \u001b[38;5;66;03m# Execute:\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m results = run_tabpfn_fqi()\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 161\u001b[39m, in \u001b[36mrun_tabpfn_fqi\u001b[39m\u001b[34m(env_id, n_episodes, max_steps, initial_size, max_context, eps_offline, eps_online, eps_decay, eps_min, percentile_refit, reward_shaping, seed, title)\u001b[39m\n\u001b[32m    158\u001b[39m episode  = []\n\u001b[32m    160\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_steps):\n\u001b[32m--> \u001b[39m\u001b[32m161\u001b[39m     action, q_est = get_action(env, state, model, epsilon, n_actions)\n\u001b[32m    162\u001b[39m     nxt, r, term, trunc, _ = env.step(action)\n\u001b[32m    163\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m reward_shaping:\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 54\u001b[39m, in \u001b[36mget_action\u001b[39m\u001b[34m(env, state, model, epsilon, n_actions)\u001b[39m\n\u001b[32m     52\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m act_greedy(state, model, n_actions)\n\u001b[32m     53\u001b[39m action = env.action_space.sample()\n\u001b[32m---> \u001b[39m\u001b[32m54\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m action, q_value(state, action, model, n_actions)\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 42\u001b[39m, in \u001b[36mq_value\u001b[39m\u001b[34m(state, action, model, n_actions)\u001b[39m\n\u001b[32m     40\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m model \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m     41\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[32m0.0\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m42\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mfloat\u001b[39m(model.predict(make_row(state, action, n_actions)[\u001b[38;5;28;01mNone\u001b[39;00m])[\u001b[32m0\u001b[39m])\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/contextlib.py:85\u001b[39m, in \u001b[36mContextDecorator.__call__.<locals>.inner\u001b[39m\u001b[34m(*args, **kwds)\u001b[39m\n\u001b[32m     82\u001b[39m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[32m     83\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34minner\u001b[39m(*args, **kwds):\n\u001b[32m     84\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m._recreate_cm():\n\u001b[32m---> \u001b[39m\u001b[32m85\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m func(*args, **kwds)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/tabpfn/regressor.py:649\u001b[39m, in \u001b[36mTabPFNRegressor.predict\u001b[39m\u001b[34m(self, X, output_type, quantiles)\u001b[39m\n\u001b[32m    646\u001b[39m outputs: \u001b[38;5;28mlist\u001b[39m[torch.Tensor] = []\n\u001b[32m    647\u001b[39m borders: \u001b[38;5;28mlist\u001b[39m[np.ndarray] = []\n\u001b[32m--> \u001b[39m\u001b[32m649\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m output, config \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.executor_.iter_outputs(\n\u001b[32m    650\u001b[39m     X,\n\u001b[32m    651\u001b[39m     device=\u001b[38;5;28mself\u001b[39m.device_,\n\u001b[32m    652\u001b[39m     autocast=\u001b[38;5;28mself\u001b[39m.use_autocast_,\n\u001b[32m    653\u001b[39m ):\n\u001b[32m    654\u001b[39m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, RegressorEnsembleConfig)\n\u001b[32m    656\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.softmax_temperature != \u001b[32m1\u001b[39m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/tabpfn/inference.py:486\u001b[39m, in \u001b[36mInferenceEngineCacheKV.iter_outputs\u001b[39m\u001b[34m(self, X, device, autocast, only_return_standard_out)\u001b[39m\n\u001b[32m    473\u001b[39m X_test = X_test.unsqueeze(\u001b[32m1\u001b[39m)\n\u001b[32m    475\u001b[39m MemoryUsageEstimator.reset_peak_memory_if_required(\n\u001b[32m    476\u001b[39m     save_peak_mem=\u001b[38;5;28mself\u001b[39m.save_peak_mem,\n\u001b[32m    477\u001b[39m     model=model,\n\u001b[32m   (...)\u001b[39m\u001b[32m    483\u001b[39m     n_train_samples=X_train_len,\n\u001b[32m    484\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m486\u001b[39m model = model.to(device)  \u001b[38;5;66;03m# noqa: PLW2901\u001b[39;00m\n\u001b[32m    487\u001b[39m style = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m    489\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.force_inference_dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/torch/nn/modules/module.py:1355\u001b[39m, in \u001b[36mModule.to\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1352\u001b[39m         \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   1353\u001b[39m             \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1355\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._apply(convert)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/torch/nn/modules/module.py:915\u001b[39m, in \u001b[36mModule._apply\u001b[39m\u001b[34m(self, fn, recurse)\u001b[39m\n\u001b[32m    913\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[32m    914\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.children():\n\u001b[32m--> \u001b[39m\u001b[32m915\u001b[39m         module._apply(fn)\n\u001b[32m    917\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[32m    918\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m torch._has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[32m    919\u001b[39m         \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[32m    920\u001b[39m         \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m    925\u001b[39m         \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[32m    926\u001b[39m         \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/torch/nn/modules/module.py:915\u001b[39m, in \u001b[36mModule._apply\u001b[39m\u001b[34m(self, fn, recurse)\u001b[39m\n\u001b[32m    913\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[32m    914\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.children():\n\u001b[32m--> \u001b[39m\u001b[32m915\u001b[39m         module._apply(fn)\n\u001b[32m    917\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[32m    918\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m torch._has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[32m    919\u001b[39m         \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[32m    920\u001b[39m         \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m    925\u001b[39m         \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[32m    926\u001b[39m         \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
      "    \u001b[31m[... skipping similar frames: Module._apply at line 915 (1 times)]\u001b[39m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/torch/nn/modules/module.py:915\u001b[39m, in \u001b[36mModule._apply\u001b[39m\u001b[34m(self, fn, recurse)\u001b[39m\n\u001b[32m    913\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[32m    914\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.children():\n\u001b[32m--> \u001b[39m\u001b[32m915\u001b[39m         module._apply(fn)\n\u001b[32m    917\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[32m    918\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m torch._has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[32m    919\u001b[39m         \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[32m    920\u001b[39m         \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m    925\u001b[39m         \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[32m    926\u001b[39m         \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/torch/nn/modules/module.py:942\u001b[39m, in \u001b[36mModule._apply\u001b[39m\u001b[34m(self, fn, recurse)\u001b[39m\n\u001b[32m    938\u001b[39m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[32m    939\u001b[39m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[32m    940\u001b[39m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[32m    941\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m--> \u001b[39m\u001b[32m942\u001b[39m     param_applied = fn(param)\n\u001b[32m    943\u001b[39m p_should_use_set_data = compute_should_use_set_data(param, param_applied)\n\u001b[32m    945\u001b[39m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/torch/nn/modules/module.py:1341\u001b[39m, in \u001b[36mModule.to.<locals>.convert\u001b[39m\u001b[34m(t)\u001b[39m\n\u001b[32m   1334\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t.dim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[32m4\u001b[39m, \u001b[32m5\u001b[39m):\n\u001b[32m   1335\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m t.to(\n\u001b[32m   1336\u001b[39m             device,\n\u001b[32m   1337\u001b[39m             dtype \u001b[38;5;28;01mif\u001b[39;00m t.is_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t.is_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m   1338\u001b[39m             non_blocking,\n\u001b[32m   1339\u001b[39m             memory_format=convert_to_format,\n\u001b[32m   1340\u001b[39m         )\n\u001b[32m-> \u001b[39m\u001b[32m1341\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m t.to(\n\u001b[32m   1342\u001b[39m         device,\n\u001b[32m   1343\u001b[39m         dtype \u001b[38;5;28;01mif\u001b[39;00m t.is_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t.is_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m   1344\u001b[39m         non_blocking,\n\u001b[32m   1345\u001b[39m     )\n\u001b[32m   1346\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m   1347\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) == \u001b[33m\"\u001b[39m\u001b[33mCannot copy out of meta tensor; no data!\u001b[39m\u001b[33m\"\u001b[39m:\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/async_helpers.py:128\u001b[39m, in \u001b[36m_pseudo_sync_runner\u001b[39m\u001b[34m(coro)\u001b[39m\n\u001b[32m    120\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m    121\u001b[39m \u001b[33;03mA runner that does not really allow async execution, and just advance the coroutine.\u001b[39;00m\n\u001b[32m    122\u001b[39m \n\u001b[32m   (...)\u001b[39m\u001b[32m    125\u001b[39m \u001b[33;03mCredit to Nathaniel Smith\u001b[39;00m\n\u001b[32m    126\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m    127\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m128\u001b[39m     coro.send(\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m    129\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[32m    130\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m exc.value\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3381\u001b[39m, in \u001b[36mInteractiveShell.run_cell_async\u001b[39m\u001b[34m(self, raw_cell, store_history, silent, shell_futures, transformed_cell, preprocessing_exc_tuple, cell_id)\u001b[39m\n\u001b[32m   3377\u001b[39m exec_count = \u001b[38;5;28mself\u001b[39m.execution_count\n\u001b[32m   3378\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m result.error_in_exec:\n\u001b[32m   3379\u001b[39m     \u001b[38;5;66;03m# Store formatted traceback and error details\u001b[39;00m\n\u001b[32m   3380\u001b[39m     \u001b[38;5;28mself\u001b[39m.history_manager.exceptions[exec_count] = (\n\u001b[32m-> \u001b[39m\u001b[32m3381\u001b[39m         \u001b[38;5;28mself\u001b[39m._format_exception_for_storage(result.error_in_exec)\n\u001b[32m   3382\u001b[39m     )\n\u001b[32m   3384\u001b[39m \u001b[38;5;66;03m# Each cell is a *single* input, regardless of how many lines it has\u001b[39;00m\n\u001b[32m   3385\u001b[39m \u001b[38;5;28mself\u001b[39m.execution_count += \u001b[32m1\u001b[39m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3435\u001b[39m, in \u001b[36mInteractiveShell._format_exception_for_storage\u001b[39m\u001b[34m(self, exception, filename, running_compiled_code)\u001b[39m\n\u001b[32m   3432\u001b[39m         stb = evalue._render_traceback_()\n\u001b[32m   3433\u001b[39m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   3434\u001b[39m         \u001b[38;5;66;03m# Otherwise, use InteractiveTB to format the traceback.\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m3435\u001b[39m         stb = \u001b[38;5;28mself\u001b[39m.InteractiveTB.structured_traceback(\n\u001b[32m   3436\u001b[39m             etype, evalue, tb, tb_offset=\u001b[32m1\u001b[39m\n\u001b[32m   3437\u001b[39m         )\n\u001b[32m   3438\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[32m   3439\u001b[39m     \u001b[38;5;66;03m# In case formatting fails, fallback to Python's built-in formatting.\u001b[39;00m\n\u001b[32m   3440\u001b[39m     stb = traceback.format_exception(etype, evalue, tb)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/ultratb.py:1182\u001b[39m, in \u001b[36mAutoFormattedTB.structured_traceback\u001b[39m\u001b[34m(self, etype, evalue, etb, tb_offset, context)\u001b[39m\n\u001b[32m   1180\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   1181\u001b[39m     \u001b[38;5;28mself\u001b[39m.tb = etb\n\u001b[32m-> \u001b[39m\u001b[32m1182\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m FormattedTB.structured_traceback(\n\u001b[32m   1183\u001b[39m     \u001b[38;5;28mself\u001b[39m, etype, evalue, etb, tb_offset, context\n\u001b[32m   1184\u001b[39m )\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/ultratb.py:1053\u001b[39m, in \u001b[36mFormattedTB.structured_traceback\u001b[39m\u001b[34m(self, etype, evalue, etb, tb_offset, context)\u001b[39m\n\u001b[32m   1050\u001b[39m mode = \u001b[38;5;28mself\u001b[39m.mode\n\u001b[32m   1051\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m mode \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.verbose_modes:\n\u001b[32m   1052\u001b[39m     \u001b[38;5;66;03m# Verbose modes need a full traceback\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1053\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m VerboseTB.structured_traceback(\n\u001b[32m   1054\u001b[39m         \u001b[38;5;28mself\u001b[39m, etype, evalue, etb, tb_offset, context\n\u001b[32m   1055\u001b[39m     )\n\u001b[32m   1056\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m mode == \u001b[33m\"\u001b[39m\u001b[33mDocs\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m   1057\u001b[39m     \u001b[38;5;66;03m# return DocTB\u001b[39;00m\n\u001b[32m   1058\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m DocTB(\n\u001b[32m   1059\u001b[39m         theme_name=\u001b[38;5;28mself\u001b[39m._theme_name,\n\u001b[32m   1060\u001b[39m         call_pdb=\u001b[38;5;28mself\u001b[39m.call_pdb,\n\u001b[32m   (...)\u001b[39m\u001b[32m   1068\u001b[39m         etype, evalue, etb, tb_offset, \u001b[32m1\u001b[39m\n\u001b[32m   1069\u001b[39m     )  \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/ultratb.py:861\u001b[39m, in \u001b[36mVerboseTB.structured_traceback\u001b[39m\u001b[34m(self, etype, evalue, etb, tb_offset, context)\u001b[39m\n\u001b[32m    852\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstructured_traceback\u001b[39m(\n\u001b[32m    853\u001b[39m     \u001b[38;5;28mself\u001b[39m,\n\u001b[32m    854\u001b[39m     etype: \u001b[38;5;28mtype\u001b[39m,\n\u001b[32m   (...)\u001b[39m\u001b[32m    858\u001b[39m     context: \u001b[38;5;28mint\u001b[39m = \u001b[32m5\u001b[39m,\n\u001b[32m    859\u001b[39m ) -> \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mstr\u001b[39m]:\n\u001b[32m    860\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"Return a nice text document describing the traceback.\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m861\u001b[39m     formatted_exceptions: \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mstr\u001b[39m]] = \u001b[38;5;28mself\u001b[39m.format_exception_as_a_whole(\n\u001b[32m    862\u001b[39m         etype, evalue, etb, context, tb_offset\n\u001b[32m    863\u001b[39m     )\n\u001b[32m    865\u001b[39m     termsize = \u001b[38;5;28mmin\u001b[39m(\u001b[32m75\u001b[39m, get_terminal_size()[\u001b[32m0\u001b[39m])\n\u001b[32m    866\u001b[39m     theme = theme_table[\u001b[38;5;28mself\u001b[39m._theme_name]\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/ultratb.py:773\u001b[39m, in \u001b[36mVerboseTB.format_exception_as_a_whole\u001b[39m\u001b[34m(self, etype, evalue, etb, context, tb_offset)\u001b[39m\n\u001b[32m    763\u001b[39m         frames.append(\n\u001b[32m    764\u001b[39m             theme_table[\u001b[38;5;28mself\u001b[39m._theme_name].format(\n\u001b[32m    765\u001b[39m                 [\n\u001b[32m   (...)\u001b[39m\u001b[32m    770\u001b[39m             )\n\u001b[32m    771\u001b[39m         )\n\u001b[32m    772\u001b[39m         skipped = \u001b[32m0\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m773\u001b[39m     frames.append(\u001b[38;5;28mself\u001b[39m.format_record(record))\n\u001b[32m    774\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m skipped:\n\u001b[32m    775\u001b[39m     frames.append(\n\u001b[32m    776\u001b[39m         theme_table[\u001b[38;5;28mself\u001b[39m._theme_name].format(\n\u001b[32m    777\u001b[39m             [\n\u001b[32m   (...)\u001b[39m\u001b[32m    782\u001b[39m         )\n\u001b[32m    783\u001b[39m     )\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/ultratb.py:651\u001b[39m, in \u001b[36mVerboseTB.format_record\u001b[39m\u001b[34m(self, frame_info)\u001b[39m\n\u001b[32m    648\u001b[39m result += \u001b[33m\"\u001b[39m\u001b[33m, \u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m call \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    649\u001b[39m result += \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcall\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m    650\u001b[39m result += theme_table[\u001b[38;5;28mself\u001b[39m._theme_name].format(\n\u001b[32m--> \u001b[39m\u001b[32m651\u001b[39m     _format_traceback_lines(\n\u001b[32m    652\u001b[39m         frame_info.lines,\n\u001b[32m    653\u001b[39m         theme_table[\u001b[38;5;28mself\u001b[39m._theme_name],\n\u001b[32m    654\u001b[39m         \u001b[38;5;28mself\u001b[39m.has_colors,\n\u001b[32m    655\u001b[39m         lvals_toks,\n\u001b[32m    656\u001b[39m     )\n\u001b[32m    657\u001b[39m )\n\u001b[32m    658\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/IPython/core/tbtools.py:99\u001b[39m, in \u001b[36m_format_traceback_lines\u001b[39m\u001b[34m(lines, theme, has_colors, lvals_toks)\u001b[39m\n\u001b[32m     96\u001b[39m     \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[32m     98\u001b[39m lineno = stack_line.lineno\n\u001b[32m---> \u001b[39m\u001b[32m99\u001b[39m line = stack_line.render(pygmented=has_colors).rstrip(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m) + \u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m    100\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m stack_line.is_current:\n\u001b[32m    101\u001b[39m     \u001b[38;5;66;03m# This is the line with the error\u001b[39;00m\n\u001b[32m    102\u001b[39m     pad = numbers_width - \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mstr\u001b[39m(lineno))\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/stack_data/core.py:360\u001b[39m, in \u001b[36mLine.render\u001b[39m\u001b[34m(self, markers, strip_leading_indent, pygmented, escape_html)\u001b[39m\n\u001b[32m    358\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m pygmented \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.frame_info.scope:\n\u001b[32m    359\u001b[39m     assert_(\u001b[38;5;129;01mnot\u001b[39;00m markers, \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mCannot use pygmented with markers\u001b[39m\u001b[33m\"\u001b[39m))\n\u001b[32m--> \u001b[39m\u001b[32m360\u001b[39m     start_line, lines = \u001b[38;5;28mself\u001b[39m.frame_info._pygmented_scope_lines\n\u001b[32m    361\u001b[39m     result = lines[\u001b[38;5;28mself\u001b[39m.lineno - start_line]\n\u001b[32m    362\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m strip_leading_indent:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/stack_data/utils.py:145\u001b[39m, in \u001b[36mcached_property.cached_property_wrapper\u001b[39m\u001b[34m(self, obj, _cls)\u001b[39m\n\u001b[32m    142\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m obj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    143\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m145\u001b[39m value = obj.\u001b[34m__dict__\u001b[39m[\u001b[38;5;28mself\u001b[39m.func.\u001b[34m__name__\u001b[39m] = \u001b[38;5;28mself\u001b[39m.func(obj)\n\u001b[32m    146\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m value\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/stack_data/core.py:780\u001b[39m, in \u001b[36mFrameInfo._pygmented_scope_lines\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m    777\u001b[39m     ranges = []\n\u001b[32m    779\u001b[39m code = atok.get_text(scope)\n\u001b[32m--> \u001b[39m\u001b[32m780\u001b[39m lines = _pygmented_with_ranges(formatter, code, ranges)\n\u001b[32m    782\u001b[39m start_line = line_range(scope)[\u001b[32m0\u001b[39m]\n\u001b[32m    784\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m start_line, lines\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/stack_data/utils.py:165\u001b[39m, in \u001b[36m_pygmented_with_ranges\u001b[39m\u001b[34m(formatter, code, ranges)\u001b[39m\n\u001b[32m    162\u001b[39m             \u001b[38;5;28;01myield\u001b[39;00m ttype, value\n\u001b[32m    164\u001b[39m lexer = MyLexer(stripnl=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m165\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m pygments.highlight(code, lexer, formatter).splitlines()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/__init__.py:82\u001b[39m, in \u001b[36mhighlight\u001b[39m\u001b[34m(code, lexer, formatter, outfile)\u001b[39m\n\u001b[32m     77\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mhighlight\u001b[39m(code, lexer, formatter, outfile=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m     78\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m     79\u001b[39m \u001b[33;03m    This is the most high-level highlighting function. It combines `lex` and\u001b[39;00m\n\u001b[32m     80\u001b[39m \u001b[33;03m    `format` in one function.\u001b[39;00m\n\u001b[32m     81\u001b[39m \u001b[33;03m    \"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m82\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mformat\u001b[39m(lex(code, lexer), formatter, outfile)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/__init__.py:64\u001b[39m, in \u001b[36mformat\u001b[39m\u001b[34m(tokens, formatter, outfile)\u001b[39m\n\u001b[32m     62\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m outfile:\n\u001b[32m     63\u001b[39m     realoutfile = \u001b[38;5;28mgetattr\u001b[39m(formatter, \u001b[33m'\u001b[39m\u001b[33mencoding\u001b[39m\u001b[33m'\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m BytesIO() \u001b[38;5;129;01mor\u001b[39;00m StringIO()\n\u001b[32m---> \u001b[39m\u001b[32m64\u001b[39m     formatter.format(tokens, realoutfile)\n\u001b[32m     65\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m realoutfile.getvalue()\n\u001b[32m     66\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/formatters/terminal256.py:250\u001b[39m, in \u001b[36mTerminal256Formatter.format\u001b[39m\u001b[34m(self, tokensource, outfile)\u001b[39m\n\u001b[32m    249\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mformat\u001b[39m(\u001b[38;5;28mself\u001b[39m, tokensource, outfile):\n\u001b[32m--> \u001b[39m\u001b[32m250\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m Formatter.format(\u001b[38;5;28mself\u001b[39m, tokensource, outfile)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/formatter.py:124\u001b[39m, in \u001b[36mFormatter.format\u001b[39m\u001b[34m(self, tokensource, outfile)\u001b[39m\n\u001b[32m    121\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.encoding:\n\u001b[32m    122\u001b[39m     \u001b[38;5;66;03m# wrap the outfile in a StreamWriter\u001b[39;00m\n\u001b[32m    123\u001b[39m     outfile = codecs.lookup(\u001b[38;5;28mself\u001b[39m.encoding)[\u001b[32m3\u001b[39m](outfile)\n\u001b[32m--> \u001b[39m\u001b[32m124\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.format_unencoded(tokensource, outfile)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/formatters/terminal256.py:256\u001b[39m, in \u001b[36mTerminal256Formatter.format_unencoded\u001b[39m\u001b[34m(self, tokensource, outfile)\u001b[39m\n\u001b[32m    253\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.linenos:\n\u001b[32m    254\u001b[39m     \u001b[38;5;28mself\u001b[39m._write_lineno(outfile)\n\u001b[32m--> \u001b[39m\u001b[32m256\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m ttype, value \u001b[38;5;129;01min\u001b[39;00m tokensource:\n\u001b[32m    257\u001b[39m     not_found = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m    258\u001b[39m     \u001b[38;5;28;01mwhile\u001b[39;00m ttype \u001b[38;5;129;01mand\u001b[39;00m not_found:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/stack_data/utils.py:158\u001b[39m, in \u001b[36m_pygmented_with_ranges.<locals>.MyLexer.get_tokens\u001b[39m\u001b[34m(self, text)\u001b[39m\n\u001b[32m    156\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mget_tokens\u001b[39m(\u001b[38;5;28mself\u001b[39m, text):\n\u001b[32m    157\u001b[39m     length = \u001b[32m0\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m158\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m ttype, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28msuper\u001b[39m().get_tokens(text):\n\u001b[32m    159\u001b[39m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(start <= length < end \u001b[38;5;28;01mfor\u001b[39;00m start, end \u001b[38;5;129;01min\u001b[39;00m ranges):\n\u001b[32m    160\u001b[39m             ttype = ttype.ExecutingNode\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/lexer.py:270\u001b[39m, in \u001b[36mLexer.get_tokens.<locals>.streamer\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m    269\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstreamer\u001b[39m():\n\u001b[32m--> \u001b[39m\u001b[32m270\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m _, t, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.get_tokens_unprocessed(text):\n\u001b[32m    271\u001b[39m         \u001b[38;5;28;01myield\u001b[39;00m t, v\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.conda/envs/tabpfnrl/lib/python3.13/site-packages/pygments/lexer.py:712\u001b[39m, in \u001b[36mRegexLexer.get_tokens_unprocessed\u001b[39m\u001b[34m(self, text, stack)\u001b[39m\n\u001b[32m    710\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[32m1\u001b[39m:\n\u001b[32m    711\u001b[39m     \u001b[38;5;28;01mfor\u001b[39;00m rexmatch, action, new_state \u001b[38;5;129;01min\u001b[39;00m statetokens:\n\u001b[32m--> \u001b[39m\u001b[32m712\u001b[39m         m = rexmatch(text, pos)\n\u001b[32m    713\u001b[39m         \u001b[38;5;28;01mif\u001b[39;00m m:\n\u001b[32m    714\u001b[39m             \u001b[38;5;28;01mif\u001b[39;00m action \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "# Execute:\n",
    "results = run_tabpfn_fqi()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee7584b9-5f25-4c87-9240-50adf06f30b1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
