{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparison using a simple dataset and neural network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import traceback\n",
    "import pennylane as qml\n",
    "import torch\n",
    "import quantum_transformers.qmlperfcomp.torch_backend as qpctorch\n",
    "\n",
    "# Seed PyTorch's global RNG (CPU and CUDA) so model weight initialisation, and any\n",
    "# DataLoader shuffling that draws from this RNG, is reproducible on Restart & Run All.\n",
    "# NOTE(review): if get_swiss_roll_dataloaders generates data with NumPy/scikit-learn,\n",
    "# that RNG is not covered by this seed — TODO confirm and seed it there as well.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "train_dataloader, valid_dataloader = qpctorch.data.get_swiss_roll_dataloaders(batch_size=4, num_workers=4, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classical neural network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:01<00:00, 66.24batch/s, Loss = 0.7178, AUC = 41.63%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 283.42batch/s, Loss = 0.6937, AUC = 53.10%]                                                                                                                                          \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 290.22batch/s, Loss = 0.6697, AUC = 67.23%]                                                                                                                                          \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 277.74batch/s, Loss = 0.6433, AUC = 69.32%]                                                                                                                                          \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 293.03batch/s, Loss = 0.6097, AUC = 79.35%]                                                                                                                                          \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 273.20batch/s, Loss = 0.5650, AUC = 86.76%]                                                                                                                                          \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 288.43batch/s, Loss = 0.5254, AUC = 89.29%]                                                                                                                                          \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 296.27batch/s, Loss = 0.4885, AUC = 91.02%]                                                                                                                                          \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 298.95batch/s, Loss = 0.4605, AUC = 92.35%]                                                                                                                                          \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 280.91batch/s, Loss = 0.4378, AUC = 92.75%]                                                                                                                                          \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 272.67batch/s, Loss = 0.4208, AUC = 92.11%]                                                                                                                                          \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 284.98batch/s, Loss = 0.4074, AUC = 93.64%]                                                                                                                                          \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 285.75batch/s, Loss = 0.3942, AUC = 93.64%]                                                                                                                                          \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 284.84batch/s, Loss = 0.3832, AUC = 93.72%]                                                                                                                                          \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 294.68batch/s, Loss = 0.3716, AUC = 94.40%]                                                                                                                                          \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 284.73batch/s, Loss = 0.3618, AUC = 95.53%]                                                                                                                                          \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 285.23batch/s, Loss = 0.3562, AUC = 92.79%]                                                                                                                                          \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 282.94batch/s, Loss = 0.3378, AUC = 94.89%]                                                                                                                                          \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 283.02batch/s, Loss = 0.3299, AUC = 95.73%]                                                                                                                                          \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 274.56batch/s, Loss = 0.3185, AUC = 95.69%]                                                                                                                                          \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 293.90batch/s, Loss = 0.3102, AUC = 97.30%]                                                                                                                                          \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 280.38batch/s, Loss = 0.2975, AUC = 96.94%]                                                                                                                                          \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 281.30batch/s, Loss = 0.2887, AUC = 96.78%]                                                                                                                                          \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 288.38batch/s, Loss = 0.2780, AUC = 98.11%]                                                                                                                                          \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 294.07batch/s, Loss = 0.2675, AUC = 97.95%]                                                                                                                                          \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 267.38batch/s, Loss = 0.2569, AUC = 98.19%]                                                                                                                                          \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 270.73batch/s, Loss = 0.2505, AUC = 97.62%]                                                                                                                                          \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 287.64batch/s, Loss = 0.2369, AUC = 98.59%]                                                                                                                                          \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 290.63batch/s, Loss = 0.2268, AUC = 98.63%]                                                                                                                                          \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 280.91batch/s, Loss = 0.2183, AUC = 98.99%]                                                                                                                                          \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 272.07batch/s, Loss = 0.2071, AUC = 99.03%]                                                                                                                                          \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 281.54batch/s, Loss = 0.2003, AUC = 98.91%]                                                                                                                                          \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 290.38batch/s, Loss = 0.1871, AUC = 99.28%]                                                                                                                                          \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 278.41batch/s, Loss = 0.1773, AUC = 99.28%]                                                                                                                                          \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 278.37batch/s, Loss = 0.1655, AUC = 99.68%]                                                                                                                                          \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 294.52batch/s, Loss = 0.1565, AUC = 99.84%]                                                                                                                                          \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 290.95batch/s, Loss = 0.1473, AUC = 99.84%]                                                                                                                                          \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 291.34batch/s, Loss = 0.1356, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 278.14batch/s, Loss = 0.1280, AUC = 99.96%]                                                                                                                                          \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 278.83batch/s, Loss = 0.1188, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 289.70batch/s, Loss = 0.1098, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 307.97batch/s, Loss = 0.1035, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 300.29batch/s, Loss = 0.0954, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 276.77batch/s, Loss = 0.0922, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 297.09batch/s, Loss = 0.0836, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 282.62batch/s, Loss = 0.0779, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 290.15batch/s, Loss = 0.0717, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 285.16batch/s, Loss = 0.0674, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 284.59batch/s, Loss = 0.0632, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 286.74batch/s, Loss = 0.0584, AUC = 100.00%]                                                                                                                                         "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 18.75s\n",
      "BEST AUC = 100.00% AT EPOCH 38\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Classical baseline: MLP constructed with argument 5 (presumably the hidden-layer width —\n",
    "# TODO confirm in qpctorch.classical.MLP), trained for 50 epochs on the binary Swiss-roll task.\n",
    "model = qpctorch.classical.MLP(5)\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    device=device,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50: 100%|██████████| 100/100 [00:00<00:00, 285.68batch/s, Loss = 0.5284, AUC = 86.96%]                                                                                                                                          \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 294.12batch/s, Loss = 0.4702, AUC = 87.60%]                                                                                                                                          \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 291.14batch/s, Loss = 0.4372, AUC = 89.86%]                                                                                                                                          \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 291.71batch/s, Loss = 0.4148, AUC = 89.61%]                                                                                                                                          \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 297.22batch/s, Loss = 0.4066, AUC = 95.01%]                                                                                                                                          \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 295.18batch/s, Loss = 0.3658, AUC = 93.32%]                                                                                                                                          \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 284.64batch/s, Loss = 0.3433, AUC = 95.41%]                                                                                                                                          \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 275.98batch/s, Loss = 0.3173, AUC = 96.46%]                                                                                                                                          \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 271.22batch/s, Loss = 0.3019, AUC = 97.99%]                                                                                                                                          \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 280.19batch/s, Loss = 0.2617, AUC = 98.55%]                                                                                                                                          \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 286.91batch/s, Loss = 0.2597, AUC = 99.72%]                                                                                                                                          \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 275.62batch/s, Loss = 0.2018, AUC = 99.48%]                                                                                                                                          \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 288.46batch/s, Loss = 0.1796, AUC = 99.92%]                                                                                                                                          \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 286.73batch/s, Loss = 0.1511, AUC = 99.96%]                                                                                                                                          \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 285.51batch/s, Loss = 0.1288, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 284.32batch/s, Loss = 0.1064, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 285.20batch/s, Loss = 0.0858, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 285.23batch/s, Loss = 0.0720, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 283.04batch/s, Loss = 0.0613, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 296.30batch/s, Loss = 0.0515, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 288.34batch/s, Loss = 0.0417, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 285.02batch/s, Loss = 0.0356, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 288.69batch/s, Loss = 0.0300, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 299.89batch/s, Loss = 0.0248, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 299.35batch/s, Loss = 0.0221, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 285.68batch/s, Loss = 0.0202, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 288.04batch/s, Loss = 0.0164, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 278.63batch/s, Loss = 0.0138, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 288.32batch/s, Loss = 0.0123, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 294.42batch/s, Loss = 0.0106, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 267.75batch/s, Loss = 0.0094, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 268.81batch/s, Loss = 0.0085, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 278.90batch/s, Loss = 0.0074, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 287.20batch/s, Loss = 0.0065, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 277.95batch/s, Loss = 0.0058, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 287.41batch/s, Loss = 0.0054, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 283.84batch/s, Loss = 0.0053, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 298.89batch/s, Loss = 0.0042, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 300.18batch/s, Loss = 0.0040, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 297.66batch/s, Loss = 0.0035, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 296.38batch/s, Loss = 0.0031, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 278.80batch/s, Loss = 0.0026, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 293.95batch/s, Loss = 0.0022, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 285.90batch/s, Loss = 0.0021, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 289.23batch/s, Loss = 0.0018, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 291.88batch/s, Loss = 0.0016, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 274.75batch/s, Loss = 0.0016, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 292.19batch/s, Loss = 0.0014, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 275.58batch/s, Loss = 0.0013, AUC = 100.00%]                                                                                                                                         \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 282.79batch/s, Loss = 0.0011, AUC = 100.00%]                                                                                                                                         "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 17.52s\n",
      "BEST AUC = 100.00% AT EPOCH 15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Larger classical baseline: MLP constructed with argument 20 (presumably the hidden-layer width —\n",
    "# TODO confirm in qpctorch.classical.MLP), trained with the same settings as the MLP(5) run.\n",
    "model = qpctorch.classical.MLP(20)\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    device=device,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum neural network with PennyLane"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### With `default.qubit` quantum device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50:   1%|          | 1/100 [00:00<00:10,  9.86batch/s]                                                                                                                                                                          "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:01<00:00, 54.63batch/s, Loss = 0.7204, AUC = 62.44%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:01<00:00, 54.41batch/s, Loss = 0.7078, AUC = 65.00%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:01<00:00, 53.91batch/s, Loss = 0.6989, AUC = 67.39%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:01<00:00, 53.73batch/s, Loss = 0.6920, AUC = 68.36%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:01<00:00, 53.75batch/s, Loss = 0.6858, AUC = 68.72%]                                                                                                                                           \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:01<00:00, 52.96batch/s, Loss = 0.6793, AUC = 69.81%]                                                                                                                                           \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:01<00:00, 52.68batch/s, Loss = 0.6752, AUC = 70.37%]                                                                                                                                           \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:01<00:00, 52.94batch/s, Loss = 0.6700, AUC = 72.16%]                                                                                                                                           \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:01<00:00, 52.46batch/s, Loss = 0.6656, AUC = 71.80%]                                                                                                                                           \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:01<00:00, 52.76batch/s, Loss = 0.6604, AUC = 72.44%]                                                                                                                                           \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:01<00:00, 52.90batch/s, Loss = 0.6566, AUC = 73.17%]                                                                                                                                           \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:01<00:00, 53.01batch/s, Loss = 0.6505, AUC = 73.29%]                                                                                                                                           \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:01<00:00, 52.48batch/s, Loss = 0.6467, AUC = 73.01%]                                                                                                                                           \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:01<00:00, 52.78batch/s, Loss = 0.6435, AUC = 73.65%]                                                                                                                                           \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:01<00:00, 52.27batch/s, Loss = 0.6386, AUC = 74.05%]                                                                                                                                           \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:01<00:00, 52.25batch/s, Loss = 0.6320, AUC = 74.19%]                                                                                                                                           \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:01<00:00, 52.97batch/s, Loss = 0.6269, AUC = 75.46%]                                                                                                                                           \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:01<00:00, 52.33batch/s, Loss = 0.6216, AUC = 75.79%]                                                                                                                                           \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:01<00:00, 52.44batch/s, Loss = 0.6175, AUC = 76.47%]                                                                                                                                           \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:01<00:00, 51.60batch/s, Loss = 0.6126, AUC = 76.71%]                                                                                                                                           \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:01<00:00, 52.41batch/s, Loss = 0.6060, AUC = 77.74%]                                                                                                                                           \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:01<00:00, 52.94batch/s, Loss = 0.6038, AUC = 78.26%]                                                                                                                                           \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:01<00:00, 52.54batch/s, Loss = 0.6000, AUC = 78.74%]                                                                                                                                           \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:01<00:00, 52.38batch/s, Loss = 0.5964, AUC = 79.99%]                                                                                                                                           \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:01<00:00, 52.35batch/s, Loss = 0.5920, AUC = 80.88%]                                                                                                                                           \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:01<00:00, 51.95batch/s, Loss = 0.5887, AUC = 80.92%]                                                                                                                                           \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:01<00:00, 52.43batch/s, Loss = 0.5841, AUC = 81.60%]                                                                                                                                           \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:02<00:00, 47.34batch/s, Loss = 0.5826, AUC = 81.92%]                                                                                                                                           \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:02<00:00, 43.03batch/s, Loss = 0.5791, AUC = 83.17%]                                                                                                                                           \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:02<00:00, 43.79batch/s, Loss = 0.5742, AUC = 83.70%]                                                                                                                                           \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:02<00:00, 44.38batch/s, Loss = 0.5724, AUC = 83.29%]                                                                                                                                           \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:02<00:00, 47.24batch/s, Loss = 0.5694, AUC = 83.49%]                                                                                                                                           \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:02<00:00, 47.00batch/s, Loss = 0.5658, AUC = 83.82%]                                                                                                                                           \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:01<00:00, 50.59batch/s, Loss = 0.5640, AUC = 84.50%]                                                                                                                                           \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:01<00:00, 51.98batch/s, Loss = 0.5578, AUC = 85.23%]                                                                                                                                           \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:01<00:00, 52.53batch/s, Loss = 0.5578, AUC = 85.55%]                                                                                                                                           \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:02<00:00, 46.52batch/s, Loss = 0.5524, AUC = 85.59%]                                                                                                                                           \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:02<00:00, 42.42batch/s, Loss = 0.5512, AUC = 85.10%]                                                                                                                                           \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:02<00:00, 43.22batch/s, Loss = 0.5474, AUC = 85.63%]                                                                                                                                           \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:02<00:00, 43.13batch/s, Loss = 0.5467, AUC = 85.14%]                                                                                                                                           \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:02<00:00, 48.12batch/s, Loss = 0.5429, AUC = 84.62%]                                                                                                                                           \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:01<00:00, 52.83batch/s, Loss = 0.5412, AUC = 84.54%]                                                                                                                                           \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:01<00:00, 52.81batch/s, Loss = 0.5394, AUC = 83.45%]                                                                                                                                           \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:01<00:00, 52.86batch/s, Loss = 0.5383, AUC = 83.70%]                                                                                                                                           \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:01<00:00, 52.90batch/s, Loss = 0.5309, AUC = 84.26%]                                                                                                                                           \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:01<00:00, 53.10batch/s, Loss = 0.5294, AUC = 83.70%]                                                                                                                                           \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:01<00:00, 52.15batch/s, Loss = 0.5260, AUC = 84.02%]                                                                                                                                           \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:02<00:00, 49.13batch/s, Loss = 0.5228, AUC = 84.18%]                                                                                                                                           \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:02<00:00, 45.77batch/s, Loss = 0.5153, AUC = 85.19%]                                                                                                                                           \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:02<00:00, 47.64batch/s, Loss = 0.5084, AUC = 87.72%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 99.11s\n",
      "BEST AUC = 87.72% AT EPOCH 50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Reset torch's global RNG so every training cell starts from the same random state,\n",
    "# making the AUCs of different models comparable regardless of cell execution order\n",
    "# (parameter init and default-generator DataLoader shuffling presumably draw from it).\n",
    "torch.manual_seed(42)\n",
    "model = qpctorch.quantum.MLP(5)\n",
    "qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50:   1%|          | 1/100 [00:00<00:25,  3.86batch/s]                                                                                                                                                                          "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6836, AUC = 69.44%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.6760, AUC = 78.10%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.6710, AUC = 79.19%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.6662, AUC = 80.56%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6641, AUC = 82.53%]                                                                                                                                           \n",
      "Epoch   6/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6569, AUC = 84.58%]                                                                                                                                           \n",
      "Epoch   7/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6519, AUC = 85.14%]                                                                                                                                           \n",
      "Epoch   8/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6467, AUC = 84.78%]                                                                                                                                           \n",
      "Epoch   9/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6393, AUC = 82.25%]                                                                                                                                           \n",
      "Epoch  10/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6323, AUC = 81.64%]                                                                                                                                           \n",
      "Epoch  11/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6266, AUC = 80.48%]                                                                                                                                           \n",
      "Epoch  12/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6227, AUC = 81.60%]                                                                                                                                           \n",
      "Epoch  13/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.6164, AUC = 80.60%]                                                                                                                                           \n",
      "Epoch  14/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.6060, AUC = 80.84%]                                                                                                                                           \n",
      "Epoch  15/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5988, AUC = 80.03%]                                                                                                                                           \n",
      "Epoch  16/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5945, AUC = 80.03%]                                                                                                                                           \n",
      "Epoch  17/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5892, AUC = 80.68%]                                                                                                                                           \n",
      "Epoch  18/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5818, AUC = 80.19%]                                                                                                                                           \n",
      "Epoch  19/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5720, AUC = 82.69%]                                                                                                                                           \n",
      "Epoch  20/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5702, AUC = 81.64%]                                                                                                                                           \n",
      "Epoch  21/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5668, AUC = 81.40%]                                                                                                                                           \n",
      "Epoch  22/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5646, AUC = 81.28%]                                                                                                                                           \n",
      "Epoch  23/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5638, AUC = 79.39%]                                                                                                                                           \n",
      "Epoch  24/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5642, AUC = 78.66%]                                                                                                                                           \n",
      "Epoch  25/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5622, AUC = 79.79%]                                                                                                                                           \n",
      "Epoch  26/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5572, AUC = 79.91%]                                                                                                                                           \n",
      "Epoch  27/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5591, AUC = 79.23%]                                                                                                                                           \n",
      "Epoch  28/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5528, AUC = 79.67%]                                                                                                                                           \n",
      "Epoch  29/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5534, AUC = 79.71%]                                                                                                                                           \n",
      "Epoch  30/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5529, AUC = 78.78%]                                                                                                                                           \n",
      "Epoch  31/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5465, AUC = 80.19%]                                                                                                                                           \n",
      "Epoch  32/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5424, AUC = 80.52%]                                                                                                                                           \n",
      "Epoch  33/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5410, AUC = 80.76%]                                                                                                                                           \n",
      "Epoch  34/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5401, AUC = 81.08%]                                                                                                                                           \n",
      "Epoch  35/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5338, AUC = 81.56%]                                                                                                                                           \n",
      "Epoch  36/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5392, AUC = 79.71%]                                                                                                                                           \n",
      "Epoch  37/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5355, AUC = 80.96%]                                                                                                                                           \n",
      "Epoch  38/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5305, AUC = 81.84%]                                                                                                                                           \n",
      "Epoch  39/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5294, AUC = 81.80%]                                                                                                                                           \n",
      "Epoch  40/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5267, AUC = 82.29%]                                                                                                                                           \n",
      "Epoch  41/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5229, AUC = 83.09%]                                                                                                                                           \n",
      "Epoch  42/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5190, AUC = 82.81%]                                                                                                                                           \n",
      "Epoch  43/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5169, AUC = 83.29%]                                                                                                                                           \n",
      "Epoch  44/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5166, AUC = 82.57%]                                                                                                                                           \n",
      "Epoch  45/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5119, AUC = 83.57%]                                                                                                                                           \n",
      "Epoch  46/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5166, AUC = 82.93%]                                                                                                                                           \n",
      "Epoch  47/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5141, AUC = 82.97%]                                                                                                                                           \n",
      "Epoch  48/50: 100%|██████████| 100/100 [01:08<00:00,  1.46batch/s, Loss = 0.5047, AUC = 83.86%]                                                                                                                                           \n",
      "Epoch  49/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5077, AUC = 83.94%]                                                                                                                                           \n",
      "Epoch  50/50: 100%|██████████| 100/100 [01:08<00:00,  1.47batch/s, Loss = 0.5034, AUC = 84.02%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 3411.80s\n",
      "BEST AUC = 85.14% AT EPOCH 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Reset torch's global RNG so every training cell starts from the same random state,\n",
    "# making the AUCs of different models comparable regardless of cell execution order\n",
    "# (parameter init and default-generator DataLoader shuffling presumably draw from it).\n",
    "torch.manual_seed(42)\n",
    "model = qpctorch.quantum.MLP(20)\n",
    "qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### With `lightning.gpu` quantum device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50:   0%|          | 0/100 [00:00<?, ?batch/s]                                                                                                                                                                                  "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:14<00:00,  6.87batch/s, Loss = 0.6969, AUC = 43.20%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:14<00:00,  6.89batch/s, Loss = 0.6929, AUC = 46.42%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:14<00:00,  6.75batch/s, Loss = 0.6911, AUC = 48.99%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:15<00:00,  6.62batch/s, Loss = 0.6891, AUC = 51.43%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:14<00:00,  6.67batch/s, Loss = 0.6868, AUC = 56.34%]                                                                                                                                           \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:15<00:00,  6.64batch/s, Loss = 0.6839, AUC = 59.76%]                                                                                                                                           \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:15<00:00,  6.63batch/s, Loss = 0.6807, AUC = 63.39%]                                                                                                                                           \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:15<00:00,  6.61batch/s, Loss = 0.6774, AUC = 65.24%]                                                                                                                                           \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:14<00:00,  6.68batch/s, Loss = 0.6722, AUC = 67.35%]                                                                                                                                           \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:15<00:00,  6.58batch/s, Loss = 0.6651, AUC = 69.73%]                                                                                                                                           \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:15<00:00,  6.64batch/s, Loss = 0.6557, AUC = 72.34%]                                                                                                                                           \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:14<00:00,  6.69batch/s, Loss = 0.6420, AUC = 77.03%]                                                                                                                                           \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:14<00:00,  6.69batch/s, Loss = 0.6210, AUC = 79.51%]                                                                                                                                           \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:15<00:00,  6.66batch/s, Loss = 0.5979, AUC = 84.90%]                                                                                                                                           \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:14<00:00,  6.85batch/s, Loss = 0.5651, AUC = 87.60%]                                                                                                                                           \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:14<00:00,  6.88batch/s, Loss = 0.5381, AUC = 89.86%]                                                                                                                                           \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:14<00:00,  6.83batch/s, Loss = 0.5165, AUC = 91.06%]                                                                                                                                           \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:14<00:00,  6.76batch/s, Loss = 0.5135, AUC = 87.16%]                                                                                                                                           \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:14<00:00,  6.77batch/s, Loss = 0.4960, AUC = 84.14%]                                                                                                                                           \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:14<00:00,  6.77batch/s, Loss = 0.4720, AUC = 90.34%]                                                                                                                                           \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:14<00:00,  6.89batch/s, Loss = 0.4386, AUC = 92.83%]                                                                                                                                           \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:14<00:00,  6.90batch/s, Loss = 0.4186, AUC = 93.16%]                                                                                                                                           \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:15<00:00,  6.56batch/s, Loss = 0.4022, AUC = 92.31%]                                                                                                                                           \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:15<00:00,  6.55batch/s, Loss = 0.3756, AUC = 94.69%]                                                                                                                                           \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:15<00:00,  6.60batch/s, Loss = 0.3642, AUC = 95.05%]                                                                                                                                           \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:15<00:00,  6.57batch/s, Loss = 0.3573, AUC = 94.12%]                                                                                                                                           \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:15<00:00,  6.36batch/s, Loss = 0.3317, AUC = 95.85%]                                                                                                                                           \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:15<00:00,  6.45batch/s, Loss = 0.3317, AUC = 95.13%]                                                                                                                                           \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:15<00:00,  6.56batch/s, Loss = 0.3079, AUC = 96.05%]                                                                                                                                           \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:15<00:00,  6.55batch/s, Loss = 0.2991, AUC = 96.90%]                                                                                                                                           \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:15<00:00,  6.42batch/s, Loss = 0.2892, AUC = 97.22%]                                                                                                                                           \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:15<00:00,  6.65batch/s, Loss = 0.2765, AUC = 97.46%]                                                                                                                                           \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:14<00:00,  6.84batch/s, Loss = 0.2624, AUC = 97.71%]                                                                                                                                           \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:14<00:00,  6.92batch/s, Loss = 0.2418, AUC = 98.35%]                                                                                                                                           \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:14<00:00,  6.89batch/s, Loss = 0.2313, AUC = 98.51%]                                                                                                                                           \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:14<00:00,  7.00batch/s, Loss = 0.2214, AUC = 98.63%]                                                                                                                                           \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:14<00:00,  6.91batch/s, Loss = 0.2141, AUC = 98.75%]                                                                                                                                           \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:14<00:00,  7.01batch/s, Loss = 0.2050, AUC = 98.87%]                                                                                                                                           \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:14<00:00,  7.06batch/s, Loss = 0.2015, AUC = 98.95%]                                                                                                                                           \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:14<00:00,  6.91batch/s, Loss = 0.1932, AUC = 98.95%]                                                                                                                                           \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:15<00:00,  6.55batch/s, Loss = 0.1829, AUC = 99.19%]                                                                                                                                           \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:14<00:00,  6.67batch/s, Loss = 0.1768, AUC = 99.32%]                                                                                                                                           \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:14<00:00,  6.77batch/s, Loss = 0.1726, AUC = 99.32%]                                                                                                                                           \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:15<00:00,  6.46batch/s, Loss = 0.1685, AUC = 99.24%]                                                                                                                                           \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:14<00:00,  6.75batch/s, Loss = 0.1625, AUC = 99.44%]                                                                                                                                           \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:15<00:00,  6.57batch/s, Loss = 0.1570, AUC = 99.44%]                                                                                                                                           \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:14<00:00,  6.73batch/s, Loss = 0.1523, AUC = 99.48%]                                                                                                                                           \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:14<00:00,  6.72batch/s, Loss = 0.1443, AUC = 99.56%]                                                                                                                                           \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:14<00:00,  6.68batch/s, Loss = 0.1402, AUC = 99.56%]                                                                                                                                           \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:14<00:00,  6.70batch/s, Loss = 0.1357, AUC = 99.60%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 745.31s\n",
      "BEST AUC = 99.60% AT EPOCH 50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Quantum MLP on PennyLane's lightning.gpu simulator, using the model's default differentiation method\n",
    "model = qpctorch.quantum.MLP(5, qdevice=\"lightning.gpu\")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    device=device,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50:   0%|          | 0/100 [00:00<?, ?batch/s]                                                                                                                                                                                  "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [03:30<00:00,  2.11s/batch, Loss = 0.6915, AUC = 52.29%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [03:30<00:00,  2.10s/batch, Loss = 0.6864, AUC = 60.67%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [03:39<00:00,  2.20s/batch, Loss = 0.6796, AUC = 62.80%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [03:42<00:00,  2.22s/batch, Loss = 0.6730, AUC = 66.06%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [03:39<00:00,  2.20s/batch, Loss = 0.6622, AUC = 69.36%]                                                                                                                                           \n",
      "Epoch   6/50:  47%|████▋     | 47/100 [01:41<01:54,  2.16s/batch]                                                                                                                                                                         \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m model \u001b[39m=\u001b[39m qpctorch\u001b[39m.\u001b[39mquantum\u001b[39m.\u001b[39mMLP(\u001b[39m20\u001b[39m, qdevice\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mlightning.gpu\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m qpctorch\u001b[39m.\u001b[39;49mtraining\u001b[39m.\u001b[39;49mtrain_and_evaluate(model, train_dataloader, valid_dataloader, device\u001b[39m=\u001b[39;49mdevice, num_classes\u001b[39m=\u001b[39;49m\u001b[39m2\u001b[39;49m, num_epochs\u001b[39m=\u001b[39;49m\u001b[39m50\u001b[39;49m)\n",
      "File \u001b[0;32m/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/torch_backend/training.py:56\u001b[0m, in \u001b[0;36mtrain_and_evaluate\u001b[0;34m(model, train_dataloader, valid_dataloader, num_classes, num_epochs, device, learning_rate, verbose)\u001b[0m\n\u001b[1;32m     53\u001b[0m     \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m Loss (\u001b[39m\u001b[39m{\u001b[39;00mtime\u001b[39m.\u001b[39mtime()\u001b[39m-\u001b[39moperation_start_time\u001b[39m:\u001b[39;00m\u001b[39m.2f\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39ms)\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m     54\u001b[0m     operation_start_time \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n\u001b[0;32m---> 56\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m     58\u001b[0m \u001b[39mif\u001b[39;00m verbose:\n\u001b[1;32m     59\u001b[0m     \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m Backward (\u001b[39m\u001b[39m{\u001b[39;00mtime\u001b[39m.\u001b[39mtime()\u001b[39m-\u001b[39moperation_start_time\u001b[39m:\u001b[39;00m\u001b[39m.2f\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39ms)\u001b[39m\u001b[39m\"\u001b[39m)\n",
      "File \u001b[0;32m/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    477\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m    478\u001b[0m     \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    479\u001b[0m         Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m    480\u001b[0m         (\u001b[39mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    485\u001b[0m         inputs\u001b[39m=\u001b[39minputs,\n\u001b[1;32m    486\u001b[0m     )\n\u001b[0;32m--> 487\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[1;32m    488\u001b[0m     \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[1;32m    489\u001b[0m )\n",
      "File \u001b[0;32m/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    195\u001b[0m     retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m    197\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m    198\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    199\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward(  \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    201\u001b[0m     tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m    202\u001b[0m     allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Same setup as the MLP(5, ...) lightning.gpu run, but with MLP(20, ...).\n",
    "# In the recorded run each epoch took ~3.5 min (~2.1 s/batch), so training was\n",
    "# stopped manually during epoch 6 (hence the KeyboardInterrupt in the saved output).\n",
    "model = qpctorch.quantum.MLP(20, qdevice=\"lightning.gpu\")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    device=device,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
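    "### With adjoint differentiation (using `default.qubit`)\n",
    "\n",
    "This configuration is expected to fail: adjoint differentiation does not support parameter broadcasting, so the cell below catches the resulting `QuantumFunctionError` and prints its traceback."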
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50:   0%|          | 0/100 [00:00<?, ?batch/s]                                                                                                                                                                                  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/tmp/ipykernel_1766250/3985056997.py\", line 3, in <module>\n",
      "    qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/torch_backend/training.py\", line 41, in train_and_evaluate\n",
      "    outputs = model(inputs)\n",
      "              ^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/torch_backend/quantum/mlp.py\", line 20, in forward\n",
      "    x = self.fc2(x)\n",
      "        ^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/container.py\", line 217, in forward\n",
      "    input = module(input)\n",
      "            ^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/torch_backend/quantum/pennylane_backend.py\", line 20, in forward\n",
      "    return self.linear(inputs)\n",
      "           ^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1501, in _call_impl\n",
      "    return forward_call(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/qnn/torch.py\", line 408, in forward\n",
      "    results = self._evaluate_qnode(inputs)\n",
      "              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/qnn/torch.py\", line 429, in _evaluate_qnode\n",
      "    res = self.qnode(**kwargs)\n",
      "          ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/qnode.py\", line 950, in __call__\n",
      "    res = qml.execute(\n",
      "          ^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py\", line 642, in execute\n",
      "    res = _execute(\n",
      "          ^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/torch.py\", line 498, in execute\n",
      "    return ExecuteTapes.apply(kwargs, *parameters)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/torch.py\", line 262, in new_apply\n",
      "    flat_out = orig_apply(out_struct_holder, *inp)\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/autograd/function.py\", line 506, in apply\n",
      "    return super().apply(*args, **kwargs)  # type: ignore[misc]\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/torch.py\", line 266, in new_forward\n",
      "    out = orig_fw(ctx, *inp)\n",
      "          ^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/torch.py\", line 343, in forward\n",
      "    res, ctx.jacs = ctx.execute_fn(unwrapped_tapes, **ctx.gradient_kwargs)\n",
      "                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py\", line 81, in inner\n",
      "    return func(*args, **kwds)\n",
      "           ^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_device.py\", line 511, in execute_and_gradients\n",
      "    jacs.append(gradient_method(circuit, **kwargs))\n",
      "                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py\", line 1917, in adjoint_jacobian\n",
      "    raise qml.QuantumFunctionError(\n",
      "pennylane.QuantumFunctionError: Parameter broadcasting is not supported with adjoint differentiation\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Expected to fail: adjoint differentiation does not support parameter broadcasting.\n",
    "# The error is caught and its traceback printed so that Run All continues past this cell.\n",
    "try:\n",
    "    model = qpctorch.quantum.MLP(5, qdiff_method=\"adjoint\")\n",
    "    qpctorch.training.train_and_evaluate(model, train_dataloader, valid_dataloader, device=device, num_classes=2, num_epochs=50)\n",
    "except qml.QuantumFunctionError:\n",
    "    print(traceback.format_exc())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum neural network with TensorCircuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-14 06:37:59.047600: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "Please first ``pip install -U cirq`` to enable related functionality in translation module\n",
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50: 100%|██████████| 100/100 [00:02<00:00, 42.90batch/s, Loss = 0.6891, AUC = 58.74%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 151.17batch/s, Loss = 0.6810, AUC = 65.32%]                                                                                                                                          \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 155.32batch/s, Loss = 0.6779, AUC = 67.21%]                                                                                                                                          \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 147.67batch/s, Loss = 0.6704, AUC = 71.42%]                                                                                                                                          \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 142.86batch/s, Loss = 0.6527, AUC = 80.80%]                                                                                                                                          \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 154.97batch/s, Loss = 0.6350, AUC = 82.29%]                                                                                                                                          \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 147.00batch/s, Loss = 0.6243, AUC = 84.02%]                                                                                                                                          \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 145.51batch/s, Loss = 0.6188, AUC = 85.79%]                                                                                                                                          \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 148.45batch/s, Loss = 0.6081, AUC = 85.99%]                                                                                                                                          \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 148.19batch/s, Loss = 0.6011, AUC = 86.55%]                                                                                                                                          \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 151.42batch/s, Loss = 0.5925, AUC = 86.59%]                                                                                                                                          \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 150.29batch/s, Loss = 0.5875, AUC = 86.11%]                                                                                                                                          \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 150.92batch/s, Loss = 0.5836, AUC = 87.08%]                                                                                                                                          \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 157.99batch/s, Loss = 0.5751, AUC = 87.52%]                                                                                                                                          \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 160.59batch/s, Loss = 0.5713, AUC = 87.88%]                                                                                                                                          \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 150.74batch/s, Loss = 0.5703, AUC = 87.60%]                                                                                                                                          \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 157.43batch/s, Loss = 0.5646, AUC = 87.84%]                                                                                                                                          \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 158.14batch/s, Loss = 0.5567, AUC = 87.84%]                                                                                                                                          \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 150.19batch/s, Loss = 0.5519, AUC = 88.33%]                                                                                                                                          \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 154.39batch/s, Loss = 0.5481, AUC = 88.29%]                                                                                                                                          \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 148.94batch/s, Loss = 0.5473, AUC = 87.96%]                                                                                                                                          \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 148.92batch/s, Loss = 0.5415, AUC = 87.84%]                                                                                                                                          \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 154.60batch/s, Loss = 0.5355, AUC = 87.96%]                                                                                                                                          \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 155.74batch/s, Loss = 0.5313, AUC = 87.88%]                                                                                                                                          \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 159.66batch/s, Loss = 0.5289, AUC = 88.00%]                                                                                                                                          \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 139.80batch/s, Loss = 0.5256, AUC = 88.12%]                                                                                                                                          \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 148.83batch/s, Loss = 0.5174, AUC = 88.33%]                                                                                                                                          \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 153.29batch/s, Loss = 0.5172, AUC = 87.96%]                                                                                                                                          \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 152.14batch/s, Loss = 0.5112, AUC = 87.88%]                                                                                                                                          \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 152.63batch/s, Loss = 0.5044, AUC = 88.77%]                                                                                                                                          \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 147.58batch/s, Loss = 0.4999, AUC = 88.73%]                                                                                                                                          \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 147.92batch/s, Loss = 0.4969, AUC = 89.05%]                                                                                                                                          \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 151.72batch/s, Loss = 0.4958, AUC = 89.25%]                                                                                                                                          \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 147.32batch/s, Loss = 0.4937, AUC = 89.05%]                                                                                                                                          \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 145.02batch/s, Loss = 0.4833, AUC = 89.41%]                                                                                                                                          \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 150.53batch/s, Loss = 0.4817, AUC = 89.81%]                                                                                                                                          \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 144.57batch/s, Loss = 0.4726, AUC = 89.86%]                                                                                                                                          \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 152.08batch/s, Loss = 0.4702, AUC = 89.73%]                                                                                                                                          \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 140.29batch/s, Loss = 0.4687, AUC = 90.10%]                                                                                                                                          \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 154.08batch/s, Loss = 0.4728, AUC = 90.66%]                                                                                                                                          \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 147.76batch/s, Loss = 0.4639, AUC = 90.98%]                                                                                                                                          \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 147.88batch/s, Loss = 0.4717, AUC = 90.90%]                                                                                                                                          \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 143.26batch/s, Loss = 0.4621, AUC = 90.62%]                                                                                                                                          \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 146.81batch/s, Loss = 0.4584, AUC = 91.06%]                                                                                                                                          \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 156.51batch/s, Loss = 0.4517, AUC = 90.86%]                                                                                                                                          \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 145.10batch/s, Loss = 0.4497, AUC = 91.06%]                                                                                                                                          \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 142.77batch/s, Loss = 0.4459, AUC = 91.22%]                                                                                                                                          \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 150.91batch/s, Loss = 0.4430, AUC = 91.47%]                                                                                                                                          \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 143.53batch/s, Loss = 0.4431, AUC = 91.55%]                                                                                                                                          \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 144.54batch/s, Loss = 0.4406, AUC = 91.32%]                                                                                                                                          "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 35.11s\n",
      "BEST AUC = 91.55% AT EPOCH 49\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same quantum MLP, simulated with the TensorCircuit backend instead of the default PennyLane one\n",
    "model = qpctorch.quantum.MLP(5, qml_backend=\"tensorcircuit\")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    device=device,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
      "  warnings.warn('Lazy modules are a new feature under heavy development '\n",
      "Epoch   1/50: 100%|██████████| 100/100 [00:47<00:00,  2.09batch/s, Loss = 0.6735, AUC = 68.08%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.6623, AUC = 73.87%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.6499, AUC = 77.50%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:42<00:00,  2.34batch/s, Loss = 0.6350, AUC = 81.28%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.6156, AUC = 84.90%]                                                                                                                                           \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.6021, AUC = 86.67%]                                                                                                                                           \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.5897, AUC = 89.69%]                                                                                                                                           \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.5771, AUC = 88.89%]                                                                                                                                           \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.5573, AUC = 90.66%]                                                                                                                                           \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.5470, AUC = 91.02%]                                                                                                                                           \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.5343, AUC = 93.12%]                                                                                                                                           \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.5295, AUC = 92.55%]                                                                                                                                           \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.5172, AUC = 93.80%]                                                                                                                                           \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.5064, AUC = 94.08%]                                                                                                                                           \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.5031, AUC = 94.16%]                                                                                                                                           \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4966, AUC = 94.20%]                                                                                                                                           \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4872, AUC = 94.40%]                                                                                                                                           \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4807, AUC = 94.77%]                                                                                                                                           \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4731, AUC = 94.32%]                                                                                                                                           \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4625, AUC = 95.17%]                                                                                                                                           \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4572, AUC = 94.93%]                                                                                                                                           \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4550, AUC = 94.97%]                                                                                                                                           \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4512, AUC = 95.21%]                                                                                                                                           \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4492, AUC = 95.13%]                                                                                                                                           \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.4433, AUC = 95.29%]                                                                                                                                           \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:42<00:00,  2.34batch/s, Loss = 0.4411, AUC = 95.01%]                                                                                                                                           \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.4329, AUC = 95.33%]                                                                                                                                           \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.4282, AUC = 95.25%]                                                                                                                                           \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.4272, AUC = 95.25%]                                                                                                                                           \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4253, AUC = 95.65%]                                                                                                                                           \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4202, AUC = 96.05%]                                                                                                                                           \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4113, AUC = 96.66%]                                                                                                                                           \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.4064, AUC = 97.06%]                                                                                                                                           \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3994, AUC = 96.46%]                                                                                                                                           \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3982, AUC = 96.10%]                                                                                                                                           \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3934, AUC = 96.78%]                                                                                                                                           \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3932, AUC = 96.70%]                                                                                                                                           \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3871, AUC = 96.34%]                                                                                                                                           \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3834, AUC = 96.22%]                                                                                                                                           \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3859, AUC = 96.05%]                                                                                                                                           \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.3804, AUC = 96.46%]                                                                                                                                           \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.3770, AUC = 97.34%]                                                                                                                                           \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3715, AUC = 96.98%]                                                                                                                                           \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3687, AUC = 96.50%]                                                                                                                                           \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3664, AUC = 96.30%]                                                                                                                                           \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3639, AUC = 96.98%]                                                                                                                                           \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.3587, AUC = 97.62%]                                                                                                                                           \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:42<00:00,  2.35batch/s, Loss = 0.3585, AUC = 97.71%]                                                                                                                                           \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3517, AUC = 97.75%]                                                                                                                                           \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:42<00:00,  2.36batch/s, Loss = 0.3483, AUC = 98.03%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 2128.16s\n",
      "BEST AUC = 98.03% AT EPOCH 50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = qpctorch.quantum.MLP(20, qml_backend=\"tensorcircuit\")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    device=device,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[gpu(id=0)]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # See https://github.com/google/jax/issues/12461#issuecomment-1256266598\n",
    "import jaxlib\n",
    "import jax\n",
    "# `from jax.config import config` is deprecated (JAX 0.4.20) and later removed; import the config object from the jax package instead.\n",
    "from jax import config\n",
    "# Enable 64-bit floats/ints globally; JAX defaults to 32-bit otherwise.\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "print(jax.devices())\n",
    "import quantum_transformers.qmlperfcomp.jax_backend as qpcjax\n",
    "train_dataloader, valid_dataloader = qpcjax.data.get_swiss_roll_dataloaders(batch_size=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classical neural network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:01<00:00, 69.44batch/s, Loss = 1.0297, AUC = 39.45%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 1453.08batch/s, Loss = 0.6693, AUC = 62.95%]                                                                                                                                         \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 1442.46batch/s, Loss = 0.5931, AUC = 73.30%]                                                                                                                                         \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 1443.16batch/s, Loss = 0.5594, AUC = 77.72%]                                                                                                                                         \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 1448.97batch/s, Loss = 0.5358, AUC = 79.38%]                                                                                                                                         \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 1458.03batch/s, Loss = 0.5198, AUC = 80.03%]                                                                                                                                         \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 1403.54batch/s, Loss = 0.5063, AUC = 79.95%]                                                                                                                                         \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 1392.64batch/s, Loss = 0.4941, AUC = 80.52%]                                                                                                                                         \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 1405.03batch/s, Loss = 0.4810, AUC = 81.49%]                                                                                                                                         \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 1429.67batch/s, Loss = 0.4677, AUC = 81.86%]                                                                                                                                         \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 1488.03batch/s, Loss = 0.4598, AUC = 85.92%]                                                                                                                                         \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 1488.83batch/s, Loss = 0.4515, AUC = 87.26%]                                                                                                                                         \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 1481.58batch/s, Loss = 0.4356, AUC = 87.42%]                                                                                                                                         \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 1487.22batch/s, Loss = 0.4239, AUC = 88.96%]                                                                                                                                         \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 1488.08batch/s, Loss = 0.4163, AUC = 89.98%]                                                                                                                                         \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 1481.24batch/s, Loss = 0.4036, AUC = 90.83%]                                                                                                                                         \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 1486.38batch/s, Loss = 0.3919, AUC = 91.36%]                                                                                                                                         \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 1481.41batch/s, Loss = 0.3807, AUC = 92.29%]                                                                                                                                         \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 1482.16batch/s, Loss = 0.3761, AUC = 94.16%]                                                                                                                                         \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 1483.34batch/s, Loss = 0.3676, AUC = 94.89%]                                                                                                                                         \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 1474.49batch/s, Loss = 0.3585, AUC = 95.37%]                                                                                                                                         \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 1485.16batch/s, Loss = 0.3452, AUC = 95.50%]                                                                                                                                         \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 1485.69batch/s, Loss = 0.3415, AUC = 95.86%]                                                                                                                                         \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 1410.84batch/s, Loss = 0.3362, AUC = 96.75%]                                                                                                                                         \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 1389.47batch/s, Loss = 0.3224, AUC = 96.75%]                                                                                                                                         \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 1397.65batch/s, Loss = 0.3244, AUC = 97.52%]                                                                                                                                         \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 1413.58batch/s, Loss = 0.3074, AUC = 97.12%]                                                                                                                                         \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 1404.26batch/s, Loss = 0.3011, AUC = 97.36%]                                                                                                                                         \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 1399.88batch/s, Loss = 0.2985, AUC = 98.05%]                                                                                                                                         \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 1394.42batch/s, Loss = 0.2966, AUC = 98.25%]                                                                                                                                         \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 1396.67batch/s, Loss = 0.2849, AUC = 98.17%]                                                                                                                                         \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 1400.62batch/s, Loss = 0.2762, AUC = 98.21%]                                                                                                                                         \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 1404.48batch/s, Loss = 0.2711, AUC = 98.38%]                                                                                                                                         \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 1408.08batch/s, Loss = 0.2665, AUC = 98.50%]                                                                                                                                         \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 1405.93batch/s, Loss = 0.2637, AUC = 98.46%]                                                                                                                                         \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 1378.59batch/s, Loss = 0.2542, AUC = 98.34%]                                                                                                                                         \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 1395.69batch/s, Loss = 0.2522, AUC = 98.58%]                                                                                                                                         \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 1404.48batch/s, Loss = 0.2433, AUC = 98.58%]                                                                                                                                         \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 1418.36batch/s, Loss = 0.2369, AUC = 98.54%]                                                                                                                                         \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 1422.35batch/s, Loss = 0.2366, AUC = 99.03%]                                                                                                                                         \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 1443.05batch/s, Loss = 0.2325, AUC = 99.03%]                                                                                                                                         \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 1435.31batch/s, Loss = 0.2274, AUC = 99.03%]                                                                                                                                         \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 1436.42batch/s, Loss = 0.2253, AUC = 99.19%]                                                                                                                                         \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 1434.25batch/s, Loss = 0.2147, AUC = 99.03%]                                                                                                                                         \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 1417.62batch/s, Loss = 0.2075, AUC = 98.66%]                                                                                                                                         \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 1413.89batch/s, Loss = 0.2040, AUC = 98.78%]                                                                                                                                         \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 1439.29batch/s, Loss = 0.2033, AUC = 99.23%]                                                                                                                                         \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 1445.87batch/s, Loss = 0.1945, AUC = 99.23%]                                                                                                                                         \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 1420.89batch/s, Loss = 0.1896, AUC = 99.23%]                                                                                                                                         \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 1474.60batch/s, Loss = 0.1835, AUC = 99.31%]                                                                                                                                         "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 4.92s\n",
      "BEST AUC = 99.31% AT EPOCH 50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = qpcjax.classical.MLP(5)\n",
    "qpcjax.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:00<00:00, 187.19batch/s, Loss = 0.5377, AUC = 79.30%]                                                                                                                                          \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 1409.61batch/s, Loss = 0.4650, AUC = 91.64%]                                                                                                                                         \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 1395.46batch/s, Loss = 0.4134, AUC = 90.75%]                                                                                                                                         \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 1335.83batch/s, Loss = 0.3940, AUC = 93.99%]                                                                                                                                         \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 1378.23batch/s, Loss = 0.3938, AUC = 96.55%]                                                                                                                                         \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 1382.34batch/s, Loss = 0.3682, AUC = 95.01%]                                                                                                                                         \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 1382.11batch/s, Loss = 0.3650, AUC = 95.33%]                                                                                                                                         \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 1396.39batch/s, Loss = 0.3501, AUC = 96.79%]                                                                                                                                         \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 1396.43batch/s, Loss = 0.3268, AUC = 94.81%]                                                                                                                                         \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 1373.34batch/s, Loss = 0.3238, AUC = 96.83%]                                                                                                                                         \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 1411.97batch/s, Loss = 0.3071, AUC = 96.06%]                                                                                                                                         \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 1382.57batch/s, Loss = 0.2919, AUC = 98.05%]                                                                                                                                         \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 1398.65batch/s, Loss = 0.2796, AUC = 97.73%]                                                                                                                                         \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 1398.80batch/s, Loss = 0.2714, AUC = 97.61%]                                                                                                                                         \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 1390.47batch/s, Loss = 0.2578, AUC = 98.90%]                                                                                                                                         \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 1392.71batch/s, Loss = 0.2430, AUC = 98.78%]                                                                                                                                         \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 1382.40batch/s, Loss = 0.2307, AUC = 98.78%]                                                                                                                                         \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 1401.80batch/s, Loss = 0.2278, AUC = 98.58%]                                                                                                                                         \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 1410.90batch/s, Loss = 0.2134, AUC = 99.15%]                                                                                                                                         \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 1396.10batch/s, Loss = 0.2044, AUC = 99.11%]                                                                                                                                         \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 1382.41batch/s, Loss = 0.1983, AUC = 99.15%]                                                                                                                                         \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 1404.76batch/s, Loss = 0.1982, AUC = 98.99%]                                                                                                                                         \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 1420.77batch/s, Loss = 0.1653, AUC = 99.27%]                                                                                                                                         \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 1404.17batch/s, Loss = 0.1650, AUC = 99.55%]                                                                                                                                         \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 1385.33batch/s, Loss = 0.1433, AUC = 99.76%]                                                                                                                                         \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 1383.47batch/s, Loss = 0.1434, AUC = 99.63%]                                                                                                                                         \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 1380.63batch/s, Loss = 0.1191, AUC = 99.84%]                                                                                                                                         \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 1377.49batch/s, Loss = 0.1123, AUC = 99.88%]                                                                                                                                         \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 1396.84batch/s, Loss = 0.1066, AUC = 99.80%]                                                                                                                                         \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 1383.41batch/s, Loss = 0.1002, AUC = 99.92%]                                                                                                                                         \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 1391.96batch/s, Loss = 0.0888, AUC = 99.88%]                                                                                                                                         \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 1397.25batch/s, Loss = 0.0824, AUC = 99.88%]                                                                                                                                         \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 1379.99batch/s, Loss = 0.0692, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 1385.24batch/s, Loss = 0.0687, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 1380.70batch/s, Loss = 0.0618, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 1392.99batch/s, Loss = 0.0589, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 1390.28batch/s, Loss = 0.0534, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 1388.67batch/s, Loss = 0.0498, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 1385.46batch/s, Loss = 0.0495, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 1399.54batch/s, Loss = 0.0429, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 1393.04batch/s, Loss = 0.0393, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 1404.14batch/s, Loss = 0.0375, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 1444.13batch/s, Loss = 0.0410, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 1399.14batch/s, Loss = 0.0323, AUC = 100.00%]                                                                                                                                        \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 1370.26batch/s, Loss = 0.0308, AUC = 100.00%]                                                                                                                                        \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 1373.25batch/s, Loss = 0.0304, AUC = 99.96%]                                                                                                                                         \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 1378.09batch/s, Loss = 0.0286, AUC = 100.00%]                                                                                                                                        \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 1363.94batch/s, Loss = 0.0263, AUC = 100.00%]                                                                                                                                        \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 1370.24batch/s, Loss = 0.0268, AUC = 100.00%]                                                                                                                                        \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 1373.21batch/s, Loss = 0.0315, AUC = 99.96%]                                                                                                                                         "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 4.13s\n",
      "BEST AUC = 100.00% AT EPOCH 44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = qpcjax.classical.MLP(20)\n",
    "qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum neural network with Pennylane"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### With `default.qubit.jax` quantum device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:02<00:00, 34.53batch/s, Loss = 0.6993, AUC = 47.40%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 964.68batch/s, Loss = 0.6986, AUC = 48.48%]                                                                                                                                          \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 970.59batch/s, Loss = 0.6987, AUC = 48.21%]                                                                                                                                          \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 969.47batch/s, Loss = 0.7005, AUC = 46.98%]                                                                                                                                          \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 948.41batch/s, Loss = 0.7040, AUC = 46.37%]                                                                                                                                          \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 964.89batch/s, Loss = 0.7031, AUC = 47.56%]                                                                                                                                          \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 975.19batch/s, Loss = 0.7007, AUC = 47.12%]                                                                                                                                          \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 971.95batch/s, Loss = 0.7041, AUC = 46.41%]                                                                                                                                          \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 972.38batch/s, Loss = 0.7024, AUC = 47.79%]                                                                                                                                          \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 970.71batch/s, Loss = 0.7021, AUC = 49.15%]                                                                                                                                          \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 970.25batch/s, Loss = 0.7030, AUC = 50.26%]                                                                                                                                          \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 964.10batch/s, Loss = 0.7014, AUC = 49.11%]                                                                                                                                          \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 954.90batch/s, Loss = 0.7026, AUC = 48.99%]                                                                                                                                          \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 959.28batch/s, Loss = 0.7018, AUC = 47.97%]                                                                                                                                          \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 967.52batch/s, Loss = 0.7022, AUC = 47.16%]                                                                                                                                          \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 971.12batch/s, Loss = 0.6997, AUC = 48.70%]                                                                                                                                          \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 965.08batch/s, Loss = 0.7003, AUC = 47.00%]                                                                                                                                          \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 972.71batch/s, Loss = 0.7016, AUC = 46.55%]                                                                                                                                          \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 968.27batch/s, Loss = 0.7027, AUC = 47.10%]                                                                                                                                          \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 957.80batch/s, Loss = 0.7075, AUC = 45.54%]                                                                                                                                          \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 950.21batch/s, Loss = 0.7092, AUC = 45.90%]                                                                                                                                          \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 969.35batch/s, Loss = 0.7099, AUC = 45.62%]                                                                                                                                          \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 960.48batch/s, Loss = 0.7066, AUC = 47.12%]                                                                                                                                          \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 963.48batch/s, Loss = 0.7047, AUC = 48.17%]                                                                                                                                          \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 968.97batch/s, Loss = 0.7089, AUC = 47.73%]                                                                                                                                          \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 959.74batch/s, Loss = 0.7049, AUC = 50.93%]                                                                                                                                          \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 961.50batch/s, Loss = 0.7068, AUC = 50.85%]                                                                                                                                          \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 969.16batch/s, Loss = 0.7050, AUC = 50.85%]                                                                                                                                          \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 998.65batch/s, Loss = 0.7082, AUC = 50.16%]                                                                                                                                          \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 1004.50batch/s, Loss = 0.7073, AUC = 51.14%]                                                                                                                                         \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 983.91batch/s, Loss = 0.7082, AUC = 51.77%]                                                                                                                                          \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 1002.34batch/s, Loss = 0.7032, AUC = 52.76%]                                                                                                                                         \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 984.90batch/s, Loss = 0.7072, AUC = 51.83%]                                                                                                                                          \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 994.66batch/s, Loss = 0.7064, AUC = 52.80%]                                                                                                                                          \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 957.54batch/s, Loss = 0.7015, AUC = 54.06%]                                                                                                                                          \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 955.90batch/s, Loss = 0.7051, AUC = 53.69%]                                                                                                                                          \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 956.47batch/s, Loss = 0.7030, AUC = 53.53%]                                                                                                                                          \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 976.13batch/s, Loss = 0.7049, AUC = 53.12%]                                                                                                                                          \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 976.58batch/s, Loss = 0.7039, AUC = 54.48%]                                                                                                                                          \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 1002.09batch/s, Loss = 0.7045, AUC = 53.59%]                                                                                                                                         \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 979.21batch/s, Loss = 0.7027, AUC = 54.42%]                                                                                                                                          \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 991.37batch/s, Loss = 0.7061, AUC = 52.29%]                                                                                                                                          \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 982.59batch/s, Loss = 0.7047, AUC = 52.64%]                                                                                                                                          \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 985.14batch/s, Loss = 0.7000, AUC = 52.68%]                                                                                                                                          \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 966.13batch/s, Loss = 0.6999, AUC = 54.18%]                                                                                                                                          \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 973.49batch/s, Loss = 0.6995, AUC = 54.63%]                                                                                                                                          \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 969.90batch/s, Loss = 0.7008, AUC = 53.63%]                                                                                                                                          \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 987.31batch/s, Loss = 0.6983, AUC = 55.80%]                                                                                                                                          \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 960.69batch/s, Loss = 0.6993, AUC = 55.19%]                                                                                                                                          \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 1012.18batch/s, Loss = 0.6998, AUC = 55.34%]                                                                                                                                         "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 8.00s\n",
      "BEST AUC = 55.80% AT EPOCH 48\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = qpcjax.quantum.MLP(5)\n",
    "qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [01:04<00:00,  1.56batch/s, Loss = 0.6971, AUC = 47.93%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6967, AUC = 50.37%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6992, AUC = 48.82%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.7004, AUC = 47.56%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6986, AUC = 49.84%]                                                                                                                                           \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6975, AUC = 52.68%]                                                                                                                                           \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6966, AUC = 53.29%]                                                                                                                                           \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6965, AUC = 49.84%]                                                                                                                                           \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6939, AUC = 52.60%]                                                                                                                                           \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6908, AUC = 55.97%]                                                                                                                                           \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6910, AUC = 56.45%]                                                                                                                                           \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6915, AUC = 53.81%]                                                                                                                                           \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6917, AUC = 53.37%]                                                                                                                                           \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6923, AUC = 52.23%]                                                                                                                                           \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6929, AUC = 51.83%]                                                                                                                                           \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6941, AUC = 50.77%]                                                                                                                                           \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6956, AUC = 48.62%]                                                                                                                                           \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6932, AUC = 50.24%]                                                                                                                                           \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6917, AUC = 51.66%]                                                                                                                                           \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6894, AUC = 52.48%]                                                                                                                                           \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6883, AUC = 56.57%]                                                                                                                                           \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6889, AUC = 58.89%]                                                                                                                                           \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6874, AUC = 56.49%]                                                                                                                                           \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6869, AUC = 56.21%]                                                                                                                                           \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6801, AUC = 62.38%]                                                                                                                                           \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6768, AUC = 63.76%]                                                                                                                                           \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6750, AUC = 65.22%]                                                                                                                                           \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6733, AUC = 67.37%]                                                                                                                                           \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6737, AUC = 62.58%]                                                                                                                                           \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6733, AUC = 63.27%]                                                                                                                                           \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6728, AUC = 60.71%]                                                                                                                                           \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6711, AUC = 62.46%]                                                                                                                                           \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6710, AUC = 63.11%]                                                                                                                                           \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6707, AUC = 62.62%]                                                                                                                                           \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6712, AUC = 62.58%]                                                                                                                                           \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6732, AUC = 61.77%]                                                                                                                                           \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6745, AUC = 59.82%]                                                                                                                                           \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6789, AUC = 61.04%]                                                                                                                                           \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6795, AUC = 57.02%]                                                                                                                                           \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6761, AUC = 63.51%]                                                                                                                                           \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6851, AUC = 53.25%]                                                                                                                                           \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6715, AUC = 61.53%]                                                                                                                                           \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6696, AUC = 66.72%]                                                                                                                                           \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6658, AUC = 67.45%]                                                                                                                                           \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6673, AUC = 66.44%]                                                                                                                                           \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6703, AUC = 59.21%]                                                                                                                                           \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6729, AUC = 58.89%]                                                                                                                                           \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6743, AUC = 55.56%]                                                                                                                                           \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6733, AUC = 56.29%]                                                                                                                                           \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:55<00:00,  1.80batch/s, Loss = 0.6805, AUC = 53.12%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 2790.32s\n",
      "BEST AUC = 67.45% AT EPOCH 44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Binary classification (num_classes=2) trained for 50 epochs\n",
    "model = qpcjax.quantum.MLP(20)\n",
    "qpcjax.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    num_classes=2,\n",
    "    num_epochs=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### With `lightning.gpu` quantum device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-14 08:01:19.040233: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:\n",
      "INTERNAL: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128 object at 0x7fe950edb330>, [0], False, [array([0., 0., 0., 0.])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(551): apply_cq\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(572): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/mlp.py(22): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_1766250/2728498697.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "\n",
      "2023-08-14 08:01:19.040267: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128 object at 0x7fe950edb330>, [0], False, [array([0., 0., 0., 0.])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(551): apply_cq\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(572): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/mlp.py(22): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_1766250/2728498697.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "; current tracing scope: custom-call.24; current profiling annotation: XlaModule:#hlo_module=jit_circuit,program_id=1055#.\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    model = qpcjax.quantum.MLP(5, qdevice=\"lightning.gpu\")\n",
    "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)\n",
    "except jaxlib.xla_extension.XlaRuntimeError as e:\n",
    "    # Expected failure with lightning.gpu: XLA has already logged the full CpuCallback\n",
    "    # traceback to stderr, so report a one-line summary instead of re-raising.\n",
    "    first_line = str(e).partition(\"\\n\")[0]\n",
    "    print(f\"lightning.gpu run failed ({type(e).__name__}): {first_line}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
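    "The `lightning.gpu` run above fails already during model initialization: under `jax.vmap`, the rotation angles reach the device as batched arrays (e.g. `[array([0., 0., 0., 0.])]`), whereas its `RX` kernel only accepts a list of scalar floats. See https://discuss.pennylane.ai/t/incompatible-function-arguments-error-on-lightning-qubit-with-jax/2900 for the upstream discussion."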
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum neural network with TensorCircuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:01<00:00, 51.23batch/s, Loss = 0.6696, AUC = 70.19%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:00<00:00, 899.24batch/s, Loss = 0.6668, AUC = 69.50%]                                                                                                                                          \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:00<00:00, 907.86batch/s, Loss = 0.6587, AUC = 70.58%]                                                                                                                                          \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:00<00:00, 899.56batch/s, Loss = 0.6448, AUC = 80.03%]                                                                                                                                          \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:00<00:00, 899.19batch/s, Loss = 0.6339, AUC = 81.84%]                                                                                                                                          \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:00<00:00, 899.46batch/s, Loss = 0.6277, AUC = 80.24%]                                                                                                                                          \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:00<00:00, 902.46batch/s, Loss = 0.6150, AUC = 81.35%]                                                                                                                                          \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:00<00:00, 894.34batch/s, Loss = 0.6079, AUC = 80.40%]                                                                                                                                          \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:00<00:00, 902.83batch/s, Loss = 0.6024, AUC = 80.72%]                                                                                                                                          \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:00<00:00, 925.01batch/s, Loss = 0.5947, AUC = 81.82%]                                                                                                                                          \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:00<00:00, 894.18batch/s, Loss = 0.5876, AUC = 81.59%]                                                                                                                                          \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:00<00:00, 902.02batch/s, Loss = 0.5791, AUC = 82.12%]                                                                                                                                          \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:00<00:00, 895.55batch/s, Loss = 0.5767, AUC = 80.24%]                                                                                                                                          \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:00<00:00, 893.22batch/s, Loss = 0.5673, AUC = 80.88%]                                                                                                                                          \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:00<00:00, 895.28batch/s, Loss = 0.5629, AUC = 80.52%]                                                                                                                                          \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:00<00:00, 893.84batch/s, Loss = 0.5567, AUC = 81.21%]                                                                                                                                          \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:00<00:00, 898.72batch/s, Loss = 0.5509, AUC = 81.25%]                                                                                                                                          \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:00<00:00, 905.44batch/s, Loss = 0.5438, AUC = 81.53%]                                                                                                                                          \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:00<00:00, 908.34batch/s, Loss = 0.5381, AUC = 82.22%]                                                                                                                                          \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:00<00:00, 894.27batch/s, Loss = 0.5325, AUC = 81.33%]                                                                                                                                          \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:00<00:00, 828.95batch/s, Loss = 0.5282, AUC = 81.98%]                                                                                                                                          \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:00<00:00, 882.36batch/s, Loss = 0.5230, AUC = 81.53%]                                                                                                                                          \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:00<00:00, 886.10batch/s, Loss = 0.5185, AUC = 81.86%]                                                                                                                                          \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:00<00:00, 883.71batch/s, Loss = 0.5139, AUC = 82.31%]                                                                                                                                          \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:00<00:00, 891.44batch/s, Loss = 0.5088, AUC = 82.75%]                                                                                                                                          \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:00<00:00, 887.94batch/s, Loss = 0.5018, AUC = 82.69%]                                                                                                                                          \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:00<00:00, 902.50batch/s, Loss = 0.4975, AUC = 82.41%]                                                                                                                                          \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:00<00:00, 898.06batch/s, Loss = 0.4933, AUC = 84.15%]                                                                                                                                          \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:00<00:00, 887.96batch/s, Loss = 0.4888, AUC = 82.69%]                                                                                                                                          \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:00<00:00, 892.61batch/s, Loss = 0.4847, AUC = 83.18%]                                                                                                                                          \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:00<00:00, 875.64batch/s, Loss = 0.4789, AUC = 84.46%]                                                                                                                                          \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:00<00:00, 883.51batch/s, Loss = 0.4762, AUC = 83.18%]                                                                                                                                          \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:00<00:00, 892.86batch/s, Loss = 0.4737, AUC = 83.56%]                                                                                                                                          \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:00<00:00, 893.67batch/s, Loss = 0.4672, AUC = 83.91%]                                                                                                                                          \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:00<00:00, 895.80batch/s, Loss = 0.4650, AUC = 83.56%]                                                                                                                                          \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:00<00:00, 896.60batch/s, Loss = 0.4603, AUC = 83.66%]                                                                                                                                          \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:00<00:00, 890.45batch/s, Loss = 0.4550, AUC = 83.69%]                                                                                                                                          \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:00<00:00, 892.29batch/s, Loss = 0.4518, AUC = 83.56%]                                                                                                                                          \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:00<00:00, 894.02batch/s, Loss = 0.4473, AUC = 85.96%]                                                                                                                                          \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:00<00:00, 908.68batch/s, Loss = 0.4432, AUC = 85.84%]                                                                                                                                          \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:00<00:00, 924.94batch/s, Loss = 0.4418, AUC = 86.10%]                                                                                                                                          \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:00<00:00, 903.73batch/s, Loss = 0.4360, AUC = 86.40%]                                                                                                                                          \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:00<00:00, 890.34batch/s, Loss = 0.4291, AUC = 88.03%]                                                                                                                                          \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:00<00:00, 890.46batch/s, Loss = 0.4254, AUC = 89.12%]                                                                                                                                          \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:00<00:00, 909.18batch/s, Loss = 0.4192, AUC = 89.89%]                                                                                                                                          \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:00<00:00, 918.56batch/s, Loss = 0.4159, AUC = 89.00%]                                                                                                                                          \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:00<00:00, 905.84batch/s, Loss = 0.4137, AUC = 89.69%]                                                                                                                                          \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:00<00:00, 905.60batch/s, Loss = 0.4109, AUC = 88.43%]                                                                                                                                          \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:00<00:00, 915.27batch/s, Loss = 0.4069, AUC = 90.18%]                                                                                                                                          \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:00<00:00, 910.96batch/s, Loss = 0.4002, AUC = 89.85%]                                                                                                                                          "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 7.48s\n",
      "BEST AUC = 90.18% AT EPOCH 49\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = qpcjax.quantum.MLP(5, qml_backend=\"tensorcircuit\")\n",
    "qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/50: 100%|██████████| 100/100 [00:30<00:00,  3.29batch/s, Loss = 0.6885, AUC = 57.06%]                                                                                                                                           \n",
      "Epoch   2/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6812, AUC = 71.96%]                                                                                                                                           \n",
      "Epoch   3/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6701, AUC = 82.67%]                                                                                                                                           \n",
      "Epoch   4/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6618, AUC = 85.96%]                                                                                                                                           \n",
      "Epoch   5/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6535, AUC = 86.40%]                                                                                                                                           \n",
      "Epoch   6/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6479, AUC = 85.06%]                                                                                                                                           \n",
      "Epoch   7/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6401, AUC = 85.75%]                                                                                                                                           \n",
      "Epoch   8/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6312, AUC = 85.96%]                                                                                                                                           \n",
      "Epoch   9/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6206, AUC = 86.65%]                                                                                                                                           \n",
      "Epoch  10/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.6072, AUC = 85.94%]                                                                                                                                           \n",
      "Epoch  11/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5968, AUC = 86.81%]                                                                                                                                           \n",
      "Epoch  12/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5860, AUC = 87.13%]                                                                                                                                           \n",
      "Epoch  13/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5775, AUC = 87.34%]                                                                                                                                           \n",
      "Epoch  14/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5738, AUC = 85.84%]                                                                                                                                           \n",
      "Epoch  15/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5687, AUC = 84.19%]                                                                                                                                           \n",
      "Epoch  16/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5642, AUC = 84.74%]                                                                                                                                           \n",
      "Epoch  17/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5605, AUC = 84.42%]                                                                                                                                           \n",
      "Epoch  18/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5527, AUC = 84.62%]                                                                                                                                           \n",
      "Epoch  19/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5518, AUC = 83.85%]                                                                                                                                           \n",
      "Epoch  20/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5434, AUC = 84.29%]                                                                                                                                           \n",
      "Epoch  21/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5448, AUC = 83.32%]                                                                                                                                           \n",
      "Epoch  22/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5339, AUC = 84.66%]                                                                                                                                           \n",
      "Epoch  23/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5308, AUC = 84.86%]                                                                                                                                           \n",
      "Epoch  24/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5279, AUC = 84.74%]                                                                                                                                           \n",
      "Epoch  25/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5257, AUC = 84.90%]                                                                                                                                           \n",
      "Epoch  26/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5217, AUC = 85.39%]                                                                                                                                           \n",
      "Epoch  27/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5197, AUC = 85.47%]                                                                                                                                           \n",
      "Epoch  28/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5186, AUC = 86.77%]                                                                                                                                           \n",
      "Epoch  29/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5079, AUC = 86.81%]                                                                                                                                           \n",
      "Epoch  30/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5090, AUC = 86.40%]                                                                                                                                           \n",
      "Epoch  31/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.5005, AUC = 87.01%]                                                                                                                                           \n",
      "Epoch  32/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4997, AUC = 86.97%]                                                                                                                                           \n",
      "Epoch  33/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4932, AUC = 87.46%]                                                                                                                                           \n",
      "Epoch  34/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4910, AUC = 87.18%]                                                                                                                                           \n",
      "Epoch  35/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4886, AUC = 87.42%]                                                                                                                                           \n",
      "Epoch  36/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4761, AUC = 88.56%]                                                                                                                                           \n",
      "Epoch  37/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4774, AUC = 87.95%]                                                                                                                                           \n",
      "Epoch  38/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4639, AUC = 88.76%]                                                                                                                                           \n",
      "Epoch  39/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4687, AUC = 87.95%]                                                                                                                                           \n",
      "Epoch  40/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4677, AUC = 87.46%]                                                                                                                                           \n",
      "Epoch  41/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4505, AUC = 89.69%]                                                                                                                                           \n",
      "Epoch  42/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4540, AUC = 88.76%]                                                                                                                                           \n",
      "Epoch  43/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4482, AUC = 90.50%]                                                                                                                                           \n",
      "Epoch  44/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4522, AUC = 89.20%]                                                                                                                                           \n",
      "Epoch  45/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4447, AUC = 89.98%]                                                                                                                                           \n",
      "Epoch  46/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4444, AUC = 89.33%]                                                                                                                                           \n",
      "Epoch  47/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4558, AUC = 88.68%]                                                                                                                                           \n",
      "Epoch  48/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4517, AUC = 88.51%]                                                                                                                                           \n",
      "Epoch  49/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4413, AUC = 88.76%]                                                                                                                                           \n",
      "Epoch  50/50: 100%|██████████| 100/100 [00:24<00:00,  4.05batch/s, Loss = 0.4359, AUC = 89.73%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 1239.42s\n",
      "BEST AUC = 90.50% AT EPOCH 43\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model = qpcjax.quantum.MLP(20, qml_backend=\"tensorcircuit\")\n",
    "qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=2, num_epochs=50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gsoc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
