{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the models\n",
    "\n",
    "This notebook contains the procedure that we used to train the models for the starting example. Given the small size of the models, it can be run without GPUs. Please note that running it to the end takes a while (the training cell alone needs around 30 minutes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the libraries\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "#\n",
    "# The stored tensor is reshaped to 400 rows of 2 values and cast to float32.\n",
    "# Labels follow the row order: 0 for the first 200 rows, 1 for the last 200.\n",
    "# NOTE(review): torch.load unpickles the file, so only load a data.pt you trust.\n",
    "\n",
    "data = torch.load(\"data.pt\").reshape(2 * 200, 2).float()\n",
    "labels = torch.cat((torch.zeros(200), torch.ones(200)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "\n",
    "# Seed the global PyTorch RNG: both the weight initialisation and the\n",
    "# DataLoader shuffling draw from it, so without a seed every\n",
    "# \"Restart & Run All\" would train (and keep) a different set of models.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Binary cross-entropy computed directly on the raw logits of the network\n",
    "loss = nn.BCEWithLogitsLoss()\n",
    "dataset = TensorDataset(data, labels)\n",
    "# Mini-batches of 10 samples, reshuffled at every epoch\n",
    "dataloader = DataLoader(dataset, batch_size=10, shuffle=True)\n",
    "num_epochs = 10  # passes over the dataset for each model\n",
    "num_models = 10000  # number of candidate models to train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the models (takes around 30 minutes)\n",
    "#\n",
    "# Every candidate is a freshly initialised 2-2-1 ReLU network trained with SGD.\n",
    "# A candidate's weights are kept only if its average batch loss over the\n",
    "# final epoch is at most 0.01.\n",
    "\n",
    "models = []\n",
    "for _ in tqdm(range(num_models)):\n",
    "    model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)).float()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=2)\n",
    "    for _epoch in range(num_epochs):\n",
    "        # Reset at every epoch so that only the last epoch counts below\n",
    "        running_loss = 0\n",
    "        for x, y in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Labels are 1-D; add a trailing axis to match the (batch, 1) logits\n",
    "            batch_loss = loss(model(x), y.unsqueeze(1))\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Arbitrary metric used to check whether the model learned correctly\n",
    "            running_loss += batch_loss.item()\n",
    "\n",
    "    # Keep the weights only when the last epoch's mean loss is low enough\n",
    "    if running_loss / len(dataloader) <= 0.01:\n",
    "        models.append(model.state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the models to disk if you want to replace the pre-trained models used in the other notebooks (the line is commented out to avoid writing to your disk without your consent)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(models, \"your_models.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "uncertainty",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
