{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We have to download the TabPFN, as there is no checkpoint at  /home/felix/miniforge3/envs/tab/lib/python3.11/site-packages/tabpfn/models_diff/prior_diff_real_checkpoint_n_0_epoch_100.cpkt\n",
      "It has about 100MB, so this might take a moment.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/felix/miniforge3/envs/tab/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 0.9840425531914894\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from tabpfn import TabPFNClassifier\n",
    "\n",
    "# Experiment configuration, kept in one place so the split and ensemble size are easy to tune.\n",
    "TEST_SIZE = 0.33  # fraction of samples held out for evaluation\n",
    "SEED = 42  # fixes the train/test split so the reported accuracy is reproducible\n",
    "N_ENSEMBLE = 32  # passed to TabPFNClassifier as N_ensemble_configurations\n",
    "\n",
    "X, y = load_breast_cancer(return_X_y=True)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=SEED)\n",
    "\n",
    "# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).\n",
    "# When N_ensemble_configurations > #features * #classes, no further averaging is applied.\n",
    "\n",
    "classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=N_ENSEMBLE)\n",
    "\n",
    "classifier.fit(X_train, y_train)\n",
    "# Accuracy only needs the predicted labels, so the per-sample winning probability is not requested.\n",
    "y_eval = classifier.predict(X_test)\n",
    "\n",
    "print('Accuracy', accuracy_score(y_test, y_eval))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tab",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
