{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7ccb9c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tabpfn import TabPFNClassifier\n",
    "from tabpfn.model.loading import load_model_criterion_config\n",
    "\n",
    "# Use CUDA when a GPU is visible to torch, otherwise run on the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\") if use_cuda else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ecb288c",
   "metadata": {},
   "source": [
    "---\n",
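    "### PyTorch interface\n",
    "\n",
    "Choose a checkpoint, then load the TabPFN v2 classifier; for the TabPFN-Wide variants its weights are replaced by the selected checkpoint."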
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3869d955",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints this notebook knows how to load.\n",
    "SUPPORTED_MODELS = (\"TabPFN-Wide-1.5k\", \"TabPFN-Wide-5k\", \"TabPFN-Wide-8k\", \"TabPFNv2\")\n",
    "\n",
    "model_name = \"TabPFN-Wide-8k\"\n",
    "# Explicit check instead of `assert`: asserts are skipped when Python runs with -O.\n",
    "if model_name not in SUPPORTED_MODELS:\n",
    "    raise ValueError(f\"Model name {model_name} not recognized. Choose one of {SUPPORTED_MODELS}.\")\n",
    "\n",
    "# Only read for the TabPFN-Wide variants; \"TabPFNv2\" keeps the weights returned by load_model_criterion_config.\n",
    "checkpoint_path = f\"./models/{model_name}_submission.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6945c61b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the TabPFN v2 classifier; with model_path=None and download=True tabpfn presumably\n",
    "# fetches its default v2 classifier weights when they are not cached locally.\n",
    "model, _, _ = load_model_criterion_config(\n",
    "    model_path=None,\n",
    "    check_bar_distribution_criterion=False,\n",
    "    cache_trainset_representation=False,\n",
    "    which='classifier',\n",
    "    version='v2',\n",
    "    download=True,\n",
    ")\n",
    "\n",
    "if model_name != \"TabPFNv2\":\n",
    "    # TabPFN-Wide variants run with one feature per group; set this before loading their weights.\n",
    "    model.features_per_group = 1\n",
    "    # The checkpoint is passed straight to load_state_dict, i.e. it is a plain state dict of tensors,\n",
    "    # so weights_only=True suffices and avoids unpickling arbitrary Python objects from the file.\n",
    "    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)\n",
    "    model.load_state_dict(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5f9d013",
   "metadata": {},
   "source": [
    "---\n",
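    "### Sklearn interface (requires the PyTorch model loaded above)\n",
    "\n",
    "`TabPFNClassifier.fit` is patched so that the preloaded `model` can be passed to it; the patch is applied to the class, so it affects every `TabPFNClassifier` in the running kernel."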
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ba5248",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tabpfnwide.patches import fit\n",
    "\n",
    "# Patch the fit method so TabPFN-Wide can be fitted through the sklearn interface by passing `model=`.\n",
    "# The method is replaced on the class itself, so every TabPFNClassifier in this kernel is affected.\n",
    "TabPFNClassifier.fit = fit\n",
    "\n",
    "# n_estimators=1 turns off ensembling.\n",
    "tabpfn_classifier = TabPFNClassifier(\n",
    "    n_estimators=1,\n",
    "    device=device,\n",
    "    ignore_pretraining_limits=True,\n",
    ")\n",
    "\n",
    "# Small synthetic binary-classification example with a fixed seed.\n",
    "X, y = make_classification(\n",
    "    n_samples=50, n_features=10, n_informative=2, n_redundant=2, n_classes=2, random_state=42\n",
    ")\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83402b56",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit with the preloaded PyTorch model (via the patched fit), then score on the held-out split.\n",
    "tabpfn_classifier.fit(X_train, y_train, model=model)\n",
    "test_score = tabpfn_classifier.score(X_test, y_test)\n",
    "test_score"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tabpfn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
