{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
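    "# Comparison of classical and quantum Vision Transformers on MNIST\n",
    "\n",
    "This notebook trains the same small Vision Transformer (a classical version and quantum versions simulated with different libraries) on MNIST with both the PyTorch and JAX backends, and compares validation AUC and training time."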
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import quantum_transformers.qmlperfcomp.torch_backend as qpctorch\n",
    "\n",
    "# Seed PyTorch's global RNG so that model weight initialisation (and any dataloader\n",
    "# shuffling that relies on the default generator) is reproducible across runs.\n",
    "torch.manual_seed(42)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "train_dataloader, valid_dataloader = qpctorch.data.get_mnist_dataloaders(batch_size=64, num_workers=4, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10: 100%|██████████| 938/938 [00:12<00:00, 73.64batch/s, Loss = 1.3147, AUC = 92.17%]                                                                                                                                           \n",
      "Epoch   2/10: 100%|██████████| 938/938 [00:08<00:00, 104.42batch/s, Loss = 0.9392, AUC = 95.47%]                                                                                                                                          \n",
      "Epoch   3/10: 100%|██████████| 938/938 [00:09<00:00, 100.25batch/s, Loss = 0.7299, AUC = 97.09%]                                                                                                                                          \n",
      "Epoch   4/10: 100%|██████████| 938/938 [00:08<00:00, 104.74batch/s, Loss = 0.6336, AUC = 97.56%]                                                                                                                                          \n",
      "Epoch   5/10: 100%|██████████| 938/938 [00:08<00:00, 110.71batch/s, Loss = 0.5654, AUC = 97.99%]                                                                                                                                          \n",
      "Epoch   6/10: 100%|██████████| 938/938 [00:08<00:00, 106.40batch/s, Loss = 0.5303, AUC = 98.17%]                                                                                                                                          \n",
      "Epoch   7/10: 100%|██████████| 938/938 [00:08<00:00, 105.37batch/s, Loss = 0.5146, AUC = 98.27%]                                                                                                                                          \n",
      "Epoch   8/10: 100%|██████████| 938/938 [00:09<00:00, 101.45batch/s, Loss = 0.4847, AUC = 98.44%]                                                                                                                                          \n",
      "Epoch   9/10: 100%|██████████| 938/938 [00:08<00:00, 107.93batch/s, Loss = 0.4677, AUC = 98.53%]                                                                                                                                          \n",
      "Epoch  10/10: 100%|██████████| 938/938 [00:09<00:00, 102.43batch/s, Loss = 0.4590, AUC = 98.57%]                                                                                                                                          "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 93.33s\n",
      "BEST AUC = 98.57% AT EPOCH 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Classical ViT baseline; the quantum variants below use the same hyperparameters.\n",
    "model = qpctorch.classical.VisionTransformer(\n",
    "    img_size=28, num_channels=1, num_classes=10,\n",
    "    patch_size=14, hidden_size=6, num_heads=2,\n",
    "    num_transformer_blocks=4, mlp_hidden_size=3,\n",
    ")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model, train_dataloader, valid_dataloader,\n",
    "    num_classes=10, learning_rate=0.0003, num_epochs=10, device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with PennyLane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10: 100%|██████████| 938/938 [06:39<00:00,  2.35batch/s, Loss = 1.7853, AUC = 85.27%]                                                                                                                                           \n",
      "Epoch   2/10: 100%|██████████| 938/938 [06:45<00:00,  2.32batch/s, Loss = 1.4544, AUC = 90.11%]                                                                                                                                           \n",
      "Epoch   3/10: 100%|██████████| 938/938 [06:30<00:00,  2.40batch/s, Loss = 1.2237, AUC = 93.31%]                                                                                                                                           \n",
      "Epoch   4/10: 100%|██████████| 938/938 [06:37<00:00,  2.36batch/s, Loss = 1.0775, AUC = 94.67%]                                                                                                                                           \n",
      "Epoch   5/10: 100%|██████████| 938/938 [06:35<00:00,  2.37batch/s, Loss = 1.0005, AUC = 95.30%]                                                                                                                                           \n",
      "Epoch   6/10: 100%|██████████| 938/938 [06:37<00:00,  2.36batch/s, Loss = 0.9508, AUC = 95.51%]                                                                                                                                           \n",
      "Epoch   7/10: 100%|██████████| 938/938 [06:41<00:00,  2.33batch/s, Loss = 0.9105, AUC = 95.68%]                                                                                                                                           \n",
      "Epoch   8/10: 100%|██████████| 938/938 [06:41<00:00,  2.34batch/s, Loss = 0.8817, AUC = 95.81%]                                                                                                                                           \n",
      "Epoch   9/10: 100%|██████████| 938/938 [06:37<00:00,  2.36batch/s, Loss = 0.8573, AUC = 95.98%]                                                                                                                                           \n",
      "Epoch  10/10: 100%|██████████| 938/938 [06:36<00:00,  2.36batch/s, Loss = 0.8303, AUC = 96.13%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 3982.81s\n",
      "BEST AUC = 96.13% AT EPOCH 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Quantum ViT with the same hyperparameters; no `qdevice` is passed, so the library's default device is used.\n",
    "model = qpctorch.quantum.VisionTransformer(\n",
    "    img_size=28, num_channels=1, num_classes=10,\n",
    "    patch_size=14, hidden_size=6, num_heads=2,\n",
    "    num_transformer_blocks=4, mlp_hidden_size=3,\n",
    ")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model, train_dataloader, valid_dataloader,\n",
    "    num_classes=10, learning_rate=0.0003, num_epochs=10, device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with PennyLane (Lightning-GPU device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10:   0%|          | 0/938 [00:00<?, ?batch/s]                                                                                                                                                                                  "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10:   0%|          | 2/938 [08:41<67:48:54, 260.83s/batch]                                                                                                                                                                      "
     ]
    }
   ],
   "source": [
    "# Disabled by default so that Restart & Run All completes: the interrupted run recorded\n",
    "# ~260 s per batch (~68 h for a single epoch). Set to True to try it anyway.\n",
    "RUN_LIGHTNING_GPU = False\n",
    "\n",
    "if RUN_LIGHTNING_GPU:\n",
    "    model = qpctorch.quantum.VisionTransformer(\n",
    "        img_size=28, num_channels=1, num_classes=10,\n",
    "        patch_size=14, hidden_size=6, num_heads=2,\n",
    "        num_transformer_blocks=4, mlp_hidden_size=3,\n",
    "        qdevice=\"lightning.gpu\",\n",
    "    )\n",
    "    qpctorch.training.train_and_evaluate(\n",
    "        model, train_dataloader, valid_dataloader,\n",
    "        num_classes=10, learning_rate=0.0003, num_epochs=10, device=device,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The execution is extremely slow (about 260 s per batch, i.e. an estimated ~68 hours for the first epoch alone), so I interrupted it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with TensorCircuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-14 06:25:53.883071: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "Please first ``pip install -U cirq`` to enable related functionality in translation module\n",
      "Epoch   1/10: 100%|██████████| 938/938 [01:48<00:00,  8.63batch/s, Loss = 1.8936, AUC = 85.94%]                                                                                                                                           \n",
      "Epoch   2/10: 100%|██████████| 938/938 [00:43<00:00, 21.46batch/s, Loss = 1.5321, AUC = 87.37%]                                                                                                                                           \n",
      "Epoch   3/10: 100%|██████████| 938/938 [00:44<00:00, 21.26batch/s, Loss = 1.3957, AUC = 88.96%]                                                                                                                                           \n",
      "Epoch   4/10: 100%|██████████| 938/938 [00:44<00:00, 21.11batch/s, Loss = 1.3201, AUC = 90.09%]                                                                                                                                           \n",
      "Epoch   5/10: 100%|██████████| 938/938 [00:44<00:00, 21.01batch/s, Loss = 1.2480, AUC = 91.38%]                                                                                                                                           \n",
      "Epoch   6/10: 100%|██████████| 938/938 [00:44<00:00, 21.03batch/s, Loss = 1.1897, AUC = 91.92%]                                                                                                                                           \n",
      "Epoch   7/10: 100%|██████████| 938/938 [00:44<00:00, 21.26batch/s, Loss = 1.1504, AUC = 92.19%]                                                                                                                                           \n",
      "Epoch   8/10: 100%|██████████| 938/938 [00:43<00:00, 21.66batch/s, Loss = 1.1145, AUC = 92.58%]                                                                                                                                           \n",
      "Epoch   9/10: 100%|██████████| 938/938 [00:42<00:00, 22.12batch/s, Loss = 1.0758, AUC = 92.98%]                                                                                                                                           \n",
      "Epoch  10/10: 100%|██████████| 938/938 [00:44<00:00, 21.07batch/s, Loss = 1.0440, AUC = 93.51%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 504.57s\n",
      "BEST AUC = 93.51% AT EPOCH 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same quantum ViT and hyperparameters, with the circuits simulated by TensorCircuit instead of PennyLane.\n",
    "model = qpctorch.quantum.VisionTransformer(\n",
    "    img_size=28, num_channels=1, num_classes=10,\n",
    "    patch_size=14, hidden_size=6, num_heads=2,\n",
    "    num_transformer_blocks=4, mlp_hidden_size=3,\n",
    "    qml_backend=\"tensorcircuit\",\n",
    ")\n",
    "qpctorch.training.train_and_evaluate(\n",
    "    model, train_dataloader, valid_dataloader,\n",
    "    num_classes=10, learning_rate=0.0003, num_epochs=10, device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import traceback\n",
    "import os\n",
    "\n",
    "# Must be set before JAX initialises its GPU backend; otherwise XLA preallocates a large\n",
    "# fraction of GPU memory up front. See https://github.com/google/jax/issues/12461#issuecomment-1256266598\n",
    "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
    "import jaxlib\n",
    "# `from jax.config import config` relies on the deprecated `jax.config` submodule (no longer\n",
    "# available in recent JAX releases); `from jax import config` yields the same config object.\n",
    "from jax import config\n",
    "config.update(\"jax_enable_x64\", True)  # enable 64-bit dtypes; only reliable at startup, before any JAX arrays exist\n",
    "import catalyst\n",
    "import quantum_transformers.qmlperfcomp.jax_backend as qpcjax\n",
    "train_dataloader, valid_dataloader = qpcjax.data.get_mnist_dataloaders(batch_size=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10: 100%|██████████| 937/937 [00:12<00:00, 74.94batch/s, Loss = 1.7705, AUC = 80.97%]                                                                                                                                           \n",
      "Epoch   2/10: 100%|██████████| 937/937 [00:05<00:00, 175.54batch/s, Loss = 1.4029, AUC = 89.91%]                                                                                                                                          \n",
      "Epoch   3/10: 100%|██████████| 937/937 [00:05<00:00, 176.86batch/s, Loss = 1.1311, AUC = 93.58%]                                                                                                                                          \n",
      "Epoch   4/10: 100%|██████████| 937/937 [00:05<00:00, 172.44batch/s, Loss = 0.9299, AUC = 94.91%]                                                                                                                                          \n",
      "Epoch   5/10: 100%|██████████| 937/937 [00:05<00:00, 174.38batch/s, Loss = 0.8137, AUC = 95.90%]                                                                                                                                          \n",
      "Epoch   6/10: 100%|██████████| 937/937 [00:05<00:00, 176.41batch/s, Loss = 0.7418, AUC = 96.52%]                                                                                                                                          \n",
      "Epoch   7/10: 100%|██████████| 937/937 [00:05<00:00, 175.02batch/s, Loss = 0.6925, AUC = 96.96%]                                                                                                                                          \n",
      "Epoch   8/10: 100%|██████████| 937/937 [00:05<00:00, 179.89batch/s, Loss = 0.6509, AUC = 97.31%]                                                                                                                                          \n",
      "Epoch   9/10: 100%|██████████| 937/937 [00:05<00:00, 181.20batch/s, Loss = 0.6116, AUC = 97.62%]                                                                                                                                          \n",
      "Epoch  10/10: 100%|██████████| 937/937 [00:05<00:00, 170.18batch/s, Loss = 0.5837, AUC = 97.83%]                                                                                                                                          "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 60.51s\n",
      "BEST AUC = 97.83% AT EPOCH 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# JAX classical baseline with the same hyperparameters as the PyTorch models\n",
    "# (the JAX constructor takes no img_size/num_channels arguments).\n",
    "model = qpcjax.classical.VisionTransformer(\n",
    "    num_classes=10, patch_size=14, hidden_size=6,\n",
    "    num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3,\n",
    ")\n",
    "qpcjax.training.train_and_evaluate(\n",
    "    model, train_dataloader, valid_dataloader,\n",
    "    num_classes=10, learning_rate=0.0003, num_epochs=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
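    "### Quantum with PennyLane\n",
    "\n",
    "**Note:** in the run below the model does not learn. The loss stays at about 2.30 ≈ ln 10 (the cross-entropy of a uniform prediction over 10 classes) and the AUC stays at about 50% (chance level), unlike the PyTorch PennyLane run above. **TODO:** investigate."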
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10: 100%|██████████| 937/937 [01:14<00:00, 12.65batch/s, Loss = 2.3199, AUC = 50.00%]                                                                                                                                           \n",
      "Epoch   2/10: 100%|██████████| 937/937 [00:48<00:00, 19.31batch/s, Loss = 2.3050, AUC = 50.10%]                                                                                                                                           \n",
      "Epoch   3/10: 100%|██████████| 937/937 [00:48<00:00, 19.35batch/s, Loss = 2.3025, AUC = 50.01%]                                                                                                                                           \n",
      "Epoch   4/10: 100%|██████████| 937/937 [00:48<00:00, 19.32batch/s, Loss = 2.3020, AUC = 50.07%]                                                                                                                                           \n",
      "Epoch   5/10: 100%|██████████| 937/937 [00:48<00:00, 19.22batch/s, Loss = 2.3013, AUC = 50.18%]                                                                                                                                           \n",
      "Epoch   6/10: 100%|██████████| 937/937 [00:48<00:00, 19.30batch/s, Loss = 2.3012, AUC = 50.19%]                                                                                                                                           \n",
      "Epoch   7/10: 100%|██████████| 937/937 [00:48<00:00, 19.42batch/s, Loss = 2.3009, AUC = 50.24%]                                                                                                                                           \n",
      "Epoch   8/10: 100%|██████████| 937/937 [00:48<00:00, 19.41batch/s, Loss = 2.3012, AUC = 50.09%]                                                                                                                                           \n",
      "Epoch   9/10: 100%|██████████| 937/937 [00:48<00:00, 19.31batch/s, Loss = 2.3009, AUC = 50.40%]                                                                                                                                           \n",
      "Epoch  10/10: 100%|██████████| 937/937 [00:48<00:00, 19.21batch/s, Loss = 2.3011, AUC = 50.38%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 510.67s\n",
      "BEST AUC = 50.40% AT EPOCH 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# JAX quantum ViT with the same hyperparameters; no `qdevice` is passed, so the library's default device is used.\n",
    "model = qpcjax.quantum.VisionTransformer(\n",
    "    num_classes=10, patch_size=14, hidden_size=6,\n",
    "    num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3,\n",
    ")\n",
    "qpcjax.training.train_and_evaluate(\n",
    "    model, train_dataloader, valid_dataloader,\n",
    "    num_classes=10, learning_rate=0.0003, num_epochs=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
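    "### Quantum with PennyLane (Lightning-GPU device)\n",
    "\n",
    "Not working. During `model.init`, the circuit is mapped over the batch with `jax.vmap`, and the `lightning.gpu` device's `RX()` receives an array of angles where it expects plain floats. It fails with `TypeError: RX(): incompatible function arguments`. See: https://discuss.pennylane.ai/t/incompatible-function-arguments-error-on-lightning-qubit-with-jax/2900."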
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/tmp/ipykernel_443038/3515301543.py\", line 3, in <module>\n",
      "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py\", line 73, in train_and_evaluate\n",
      "    variables = model.init(params_key, x, train=False)\n",
      "                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1845, in init\n",
      "    _, v_out = self.init_with_output(\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1746, in init_with_output\n",
      "    return init_with_output(\n",
      "           ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 1034, in wrapper\n",
      "    return apply(fn, mutable=mutable, flags=init_flags)(\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 998, in wrapper\n",
      "    y = fn(root, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 2370, in scope_fn\n",
      "    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 153, in __call__\n",
      "    x = TransformerBlock(\n",
      "        ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 93, in __call__\n",
      "    attn_output = MultiHeadSelfAttention(embed_dim=self.hidden_size, num_heads=self.num_heads,\n",
      "                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 44, in __call__\n",
      "    q, k, v = [\n",
      "              ^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 45, in <listcomp>\n",
      "    proj(x).reshape(batch_size, seq_len, self.num_heads, head_dim).swapaxes(1, 2)\n",
      "    ^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py\", line 39, in __call__\n",
      "    x = jax.vmap(self.circuit, in_axes=(0, None))(x, weights)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py\", line 1240, in vmap_f\n",
      "    out_flat = batching.batch(\n",
      "               ^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py\", line 188, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 250, in cache_miss\n",
      "    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(\n",
      "                                                 ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 163, in _python_pjit_helper\n",
      "    out_flat = pjit_p.bind(*args_flat, **params)\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 2677, in bind\n",
      "    return self.bind_with_trace(top_trace, args, params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 383, in bind_with_trace\n",
      "    out = trace.process_primitive(self, map(trace.full_raise, args), params)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py\", line 398, in process_primitive\n",
      "    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)\n",
      "                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1415, in _pjit_batcher\n",
      "    vals_out = pjit_p.bind(\n",
      "               ^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 2677, in bind\n",
      "    return self.bind_with_trace(top_trace, args, params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 383, in bind_with_trace\n",
      "    out = trace.process_primitive(self, map(trace.full_raise, args), params)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 815, in process_primitive\n",
      "    return primitive.impl(*tracers, **params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1203, in _pjit_call_impl\n",
      "    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1187, in call_impl_cache_miss\n",
      "    out_flat, compiled = _pjit_call_impl_python(\n",
      "                         ^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1143, in _pjit_call_impl_python\n",
      "    return compiled.unsafe_call(*args), compiled\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py\", line 314, in wrapper\n",
      "    return func(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py\", line 1349, in __call__\n",
      "    results = self.xla_executable.execute_sharded(input_bufs)\n",
      "              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128 object at 0x7fcd033c9270>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,\n",
      "       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,\n",
      "       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,\n",
      "       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,\n",
      "       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,\n",
      "       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,\n",
      "       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,\n",
      "       -1.99065936, -0.95286273, -1.00063692, -1.26076869, -1.05675874,\n",
      "       -1.99065936, -1.26840019, -0.94005721, -1.38595578, -0.61030866,\n",
      "       -1.99065936, -0.71439892, -0.95327059, -1.08793455, -1.3481988 ,\n",
      "       -1.99065936, -0.16415302, -1.13727631, -1.14171319, -1.3928046 ,\n",
      "       -1.99065936, -1.33280401, -1.27190475, -0.27698251, -0.59783505,\n",
      "       -1.99065936, -1.07410815, -1.08089054, -1.12720024, -1.4855256 ,\n",
      "       -1.99065936, -1.62930775, -0.61359425, -0.72750374, -0.81329352,\n",
      "       -1.99065936, -1.70892522, -1.39563787, -1.09084622, -0.99745661,\n",
      "       -1.99065936, -1.39579129, -0.93821183, -1.11267586, -1.3618317 ,\n",
      "       -1.99065936, -1.19673384, -1.36525872, -0.85516461, -1.0932371 ,\n",
      "       -1.99065936, -1.19721424, -1.04066598, -0.64450208, -1.55273813,\n",
      "       -1.99065936, -0.96791736, -1.80360259, -1.30149484, -0.2839262 ,\n",
      "       -1.99065936, -1.46805732, -1.39138411, -1.27190827, -1.23702557,\n",
      "       -1.99065936, -0.08671108, -1.62621405, -1.57943044, -1.04211977,\n",
      "       -1.99065936, -1.15063848, -0.56190361, -0.70644403, -1.11980448,\n",
      "       -1.99065936, -1.38225553, -0.725611  , -1.45496257, -0.76781471,\n",
      "       -1.99065936, -1.25031726, -0.74027637, -1.02105564, -1.20400562,\n",
      "       -1.99065936, -1.59105694, -1.61286693, -1.15849376, -1.42042045,\n",
      "       -1.99065936, -0.74705437, -1.07704607, -1.05218371, -0.8873684 ,\n",
      "       -1.99065936, -0.73331575, -0.69374305, -1.12760715, -1.54806115,\n",
      "       -1.99065936, -1.2807564 , -1.3624274 , -1.00578041, -0.87068549,\n",
      "       -1.99065936, -1.19865234, -0.8911431 , -0.77395507, -0.94454815,\n",
      "       -1.99065936, -0.69528866, -0.91436118, -1.2872507 , -1.58517108,\n",
      "       -1.99065936, -1.36052242, -1.39846098, -1.23202642, -1.22379228,\n",
      "       -1.99065936, -0.69681703, -1.39031723, -1.2262722 , -0.76031492,\n",
      "       -1.99065936, -1.35668371, -0.63182182, -0.6709817 , -1.35896601,\n",
      "       -1.99065936, -1.28068756, -1.55500923, -1.16366544, -1.91953158,\n",
      "       -1.99065936, -1.46342654, -0.30489002, -1.38779702, -0.98844208,\n",
      "       -1.99065936, -0.18323726, -1.19045485, -1.05977132, -1.09098738,\n",
      "       -1.99065936, -0.68175041, -1.21468183, -0.13046826, -0.42463526,\n",
      "       -1.99065936, -1.69748725, -0.80771108, -1.10917939, -0.76152976,\n",
      "       -1.99065936, -1.46977833, -1.08981589, -1.06376841, -1.09351361,\n",
      "       -1.99065936, -1.39495866, -1.18848993, -0.82710038, -1.19726168,\n",
      "       -1.99065936,  0.13395401, -0.06145497, -1.45255525, -1.30299591,\n",
      "       -1.99065936, -0.45357115, -1.71149023, -0.6245115 , -0.91892305,\n",
      "       -1.99065936, -1.05042194, -1.43260254, -1.44138882, -0.97532864,\n",
      "       -1.99065936, -1.37557435, -1.54062235, -0.11966081, -1.32451361,\n",
      "       -1.99065936, -0.99472941, -0.97100577, -0.98086822, -1.27070776,\n",
      "       -1.99065936, -1.1962407 , -1.14998185, -0.63131049, -1.34136496,\n",
      "       -1.99065936, -0.99383569, -1.28917555, -1.52504055, -0.90961462,\n",
      "       -1.99065936, -0.66839026, -0.63105238, -1.54710814, -1.17531734,\n",
      "       -1.99065936, -1.1795461 , -1.44927538, -1.53803568, -1.58135178,\n",
      "       -1.99065936, -1.12465192, -1.00732128, -0.86009437, -1.34203859,\n",
      "       -1.99065936, -1.20963163, -0.9615587 , -0.41452135, -1.46465335,\n",
      "       -1.99065936, -1.55746538, -1.56816282, -1.2954955 , -1.1642513 ,\n",
      "       -1.99065936, -1.2685908 , -1.17232309, -0.81652236, -0.87647916,\n",
      "       -1.99065936, -0.80280256, -1.08518401, -1.42118504, -0.99281612,\n",
      "       -1.99065936, -1.42223473, -1.31208892, -1.16823907, -1.08258662,\n",
      "       -1.99065936, -0.1048175 , -1.7909315 , -1.39731142, -1.61698453,\n",
      "       -1.99065936, -0.82725419, -0.9678525 , -1.48618196, -1.66092401,\n",
      "       -1.99065936, -0.73012181, -1.12669918, -1.34689929, -0.67030933,\n",
      "       -1.99065936, -0.51910207, -0.9698971 , -1.58338126, -0.91350581,\n",
      "       -1.99065936, -1.18769819, -1.40629907, -0.93678502, -0.92327303,\n",
      "       -1.99065936, -1.47733292, -1.22804599, -1.21117743, -1.42201438,\n",
      "       -1.99065936, -0.95033965, -1.04628571, -0.39094974, -1.70580013,\n",
      "       -1.99065936, -1.27439642, -0.9638992 , -0.97455712, -0.56043667,\n",
      "       -1.99065936, -1.36208643, -0.82787886, -0.40731838, -1.49553269])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(551): apply_cq\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(572): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(45): <listcomp>\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(44): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(93): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(153): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_443038/3515301543.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "; current tracing scope: custom-call.28; current profiling annotation: XlaModule:#hlo_module=jit_circuit,program_id=4813#.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-14 06:44:18.259223: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:\n",
      "INTERNAL: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128 object at 0x7fcd033c9270>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,\n",
      "       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,\n",
      "       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,\n",
      "       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,\n",
      "       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,\n",
      "       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,\n",
      "       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,\n",
      "       -1.99065936, -0.95286273, -1.00063692, -1.26076869, -1.05675874,\n",
      "       -1.99065936, -1.26840019, -0.94005721, -1.38595578, -0.61030866,\n",
      "       -1.99065936, -0.71439892, -0.95327059, -1.08793455, -1.3481988 ,\n",
      "       -1.99065936, -0.16415302, -1.13727631, -1.14171319, -1.3928046 ,\n",
      "       -1.99065936, -1.33280401, -1.27190475, -0.27698251, -0.59783505,\n",
      "       -1.99065936, -1.07410815, -1.08089054, -1.12720024, -1.4855256 ,\n",
      "       -1.99065936, -1.62930775, -0.61359425, -0.72750374, -0.81329352,\n",
      "       -1.99065936, -1.70892522, -1.39563787, -1.09084622, -0.99745661,\n",
      "       -1.99065936, -1.39579129, -0.93821183, -1.11267586, -1.3618317 ,\n",
      "       -1.99065936, -1.19673384, -1.36525872, -0.85516461, -1.0932371 ,\n",
      "       -1.99065936, -1.19721424, -1.04066598, -0.64450208, -1.55273813,\n",
      "       -1.99065936, -0.96791736, -1.80360259, -1.30149484, -0.2839262 ,\n",
      "       -1.99065936, -1.46805732, -1.39138411, -1.27190827, -1.23702557,\n",
      "       -1.99065936, -0.08671108, -1.62621405, -1.57943044, -1.04211977,\n",
      "       -1.99065936, -1.15063848, -0.56190361, -0.70644403, -1.11980448,\n",
      "       -1.99065936, -1.38225553, -0.725611  , -1.45496257, -0.76781471,\n",
      "       -1.99065936, -1.25031726, -0.74027637, -1.02105564, -1.20400562,\n",
      "       -1.99065936, -1.59105694, -1.61286693, -1.15849376, -1.42042045,\n",
      "       -1.99065936, -0.74705437, -1.07704607, -1.05218371, -0.8873684 ,\n",
      "       -1.99065936, -0.73331575, -0.69374305, -1.12760715, -1.54806115,\n",
      "       -1.99065936, -1.2807564 , -1.3624274 , -1.00578041, -0.87068549,\n",
      "       -1.99065936, -1.19865234, -0.8911431 , -0.77395507, -0.94454815,\n",
      "       -1.99065936, -0.69528866, -0.91436118, -1.2872507 , -1.58517108,\n",
      "       -1.99065936, -1.36052242, -1.39846098, -1.23202642, -1.22379228,\n",
      "       -1.99065936, -0.69681703, -1.39031723, -1.2262722 , -0.76031492,\n",
      "       -1.99065936, -1.35668371, -0.63182182, -0.6709817 , -1.35896601,\n",
      "       -1.99065936, -1.28068756, -1.55500923, -1.16366544, -1.91953158,\n",
      "       -1.99065936, -1.46342654, -0.30489002, -1.38779702, -0.98844208,\n",
      "       -1.99065936, -0.18323726, -1.19045485, -1.05977132, -1.09098738,\n",
      "       -1.99065936, -0.68175041, -1.21468183, -0.13046826, -0.42463526,\n",
      "       -1.99065936, -1.69748725, -0.80771108, -1.10917939, -0.76152976,\n",
      "       -1.99065936, -1.46977833, -1.08981589, -1.06376841, -1.09351361,\n",
      "       -1.99065936, -1.39495866, -1.18848993, -0.82710038, -1.19726168,\n",
      "       -1.99065936,  0.13395401, -0.06145497, -1.45255525, -1.30299591,\n",
      "       -1.99065936, -0.45357115, -1.71149023, -0.6245115 , -0.91892305,\n",
      "       -1.99065936, -1.05042194, -1.43260254, -1.44138882, -0.97532864,\n",
      "       -1.99065936, -1.37557435, -1.54062235, -0.11966081, -1.32451361,\n",
      "       -1.99065936, -0.99472941, -0.97100577, -0.98086822, -1.27070776,\n",
      "       -1.99065936, -1.1962407 , -1.14998185, -0.63131049, -1.34136496,\n",
      "       -1.99065936, -0.99383569, -1.28917555, -1.52504055, -0.90961462,\n",
      "       -1.99065936, -0.66839026, -0.63105238, -1.54710814, -1.17531734,\n",
      "       -1.99065936, -1.1795461 , -1.44927538, -1.53803568, -1.58135178,\n",
      "       -1.99065936, -1.12465192, -1.00732128, -0.86009437, -1.34203859,\n",
      "       -1.99065936, -1.20963163, -0.9615587 , -0.41452135, -1.46465335,\n",
      "       -1.99065936, -1.55746538, -1.56816282, -1.2954955 , -1.1642513 ,\n",
      "       -1.99065936, -1.2685908 , -1.17232309, -0.81652236, -0.87647916,\n",
      "       -1.99065936, -0.80280256, -1.08518401, -1.42118504, -0.99281612,\n",
      "       -1.99065936, -1.42223473, -1.31208892, -1.16823907, -1.08258662,\n",
      "       -1.99065936, -0.1048175 , -1.7909315 , -1.39731142, -1.61698453,\n",
      "       -1.99065936, -0.82725419, -0.9678525 , -1.48618196, -1.66092401,\n",
      "       -1.99065936, -0.73012181, -1.12669918, -1.34689929, -0.67030933,\n",
      "       -1.99065936, -0.51910207, -0.9698971 , -1.58338126, -0.91350581,\n",
      "       -1.99065936, -1.18769819, -1.40629907, -0.93678502, -0.92327303,\n",
      "       -1.99065936, -1.47733292, -1.22804599, -1.21117743, -1.42201438,\n",
      "       -1.99065936, -0.95033965, -1.04628571, -0.39094974, -1.70580013,\n",
      "       -1.99065936, -1.27439642, -0.9638992 , -0.97455712, -0.56043667,\n",
      "       -1.99065936, -1.36208643, -0.82787886, -0.40731838, -1.49553269])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(551): apply_cq\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(572): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(45): <listcomp>\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(44): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(93): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(153): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_443038/3515301543.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "\n",
      "2023-08-14 06:44:18.259263: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning_gpu.lightning_gpu_qubit_ops.LightningGPU_C128 object at 0x7fcd033c9270>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,\n",
      "       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,\n",
      "       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,\n",
      "       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,\n",
      "       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,\n",
      "       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,\n",
      "       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,\n",
      "       -1.99065936, -0.95286273, -1.00063692, -1.26076869, -1.05675874,\n",
      "       -1.99065936, -1.26840019, -0.94005721, -1.38595578, -0.61030866,\n",
      "       -1.99065936, -0.71439892, -0.95327059, -1.08793455, -1.3481988 ,\n",
      "       -1.99065936, -0.16415302, -1.13727631, -1.14171319, -1.3928046 ,\n",
      "       -1.99065936, -1.33280401, -1.27190475, -0.27698251, -0.59783505,\n",
      "       -1.99065936, -1.07410815, -1.08089054, -1.12720024, -1.4855256 ,\n",
      "       -1.99065936, -1.62930775, -0.61359425, -0.72750374, -0.81329352,\n",
      "       -1.99065936, -1.70892522, -1.39563787, -1.09084622, -0.99745661,\n",
      "       -1.99065936, -1.39579129, -0.93821183, -1.11267586, -1.3618317 ,\n",
      "       -1.99065936, -1.19673384, -1.36525872, -0.85516461, -1.0932371 ,\n",
      "       -1.99065936, -1.19721424, -1.04066598, -0.64450208, -1.55273813,\n",
      "       -1.99065936, -0.96791736, -1.80360259, -1.30149484, -0.2839262 ,\n",
      "       -1.99065936, -1.46805732, -1.39138411, -1.27190827, -1.23702557,\n",
      "       -1.99065936, -0.08671108, -1.62621405, -1.57943044, -1.04211977,\n",
      "       -1.99065936, -1.15063848, -0.56190361, -0.70644403, -1.11980448,\n",
      "       -1.99065936, -1.38225553, -0.725611  , -1.45496257, -0.76781471,\n",
      "       -1.99065936, -1.25031726, -0.74027637, -1.02105564, -1.20400562,\n",
      "       -1.99065936, -1.59105694, -1.61286693, -1.15849376, -1.42042045,\n",
      "       -1.99065936, -0.74705437, -1.07704607, -1.05218371, -0.8873684 ,\n",
      "       -1.99065936, -0.73331575, -0.69374305, -1.12760715, -1.54806115,\n",
      "       -1.99065936, -1.2807564 , -1.3624274 , -1.00578041, -0.87068549,\n",
      "       -1.99065936, -1.19865234, -0.8911431 , -0.77395507, -0.94454815,\n",
      "       -1.99065936, -0.69528866, -0.91436118, -1.2872507 , -1.58517108,\n",
      "       -1.99065936, -1.36052242, -1.39846098, -1.23202642, -1.22379228,\n",
      "       -1.99065936, -0.69681703, -1.39031723, -1.2262722 , -0.76031492,\n",
      "       -1.99065936, -1.35668371, -0.63182182, -0.6709817 , -1.35896601,\n",
      "       -1.99065936, -1.28068756, -1.55500923, -1.16366544, -1.91953158,\n",
      "       -1.99065936, -1.46342654, -0.30489002, -1.38779702, -0.98844208,\n",
      "       -1.99065936, -0.18323726, -1.19045485, -1.05977132, -1.09098738,\n",
      "       -1.99065936, -0.68175041, -1.21468183, -0.13046826, -0.42463526,\n",
      "       -1.99065936, -1.69748725, -0.80771108, -1.10917939, -0.76152976,\n",
      "       -1.99065936, -1.46977833, -1.08981589, -1.06376841, -1.09351361,\n",
      "       -1.99065936, -1.39495866, -1.18848993, -0.82710038, -1.19726168,\n",
      "       -1.99065936,  0.13395401, -0.06145497, -1.45255525, -1.30299591,\n",
      "       -1.99065936, -0.45357115, -1.71149023, -0.6245115 , -0.91892305,\n",
      "       -1.99065936, -1.05042194, -1.43260254, -1.44138882, -0.97532864,\n",
      "       -1.99065936, -1.37557435, -1.54062235, -0.11966081, -1.32451361,\n",
      "       -1.99065936, -0.99472941, -0.97100577, -0.98086822, -1.27070776,\n",
      "       -1.99065936, -1.1962407 , -1.14998185, -0.63131049, -1.34136496,\n",
      "       -1.99065936, -0.99383569, -1.28917555, -1.52504055, -0.90961462,\n",
      "       -1.99065936, -0.66839026, -0.63105238, -1.54710814, -1.17531734,\n",
      "       -1.99065936, -1.1795461 , -1.44927538, -1.53803568, -1.58135178,\n",
      "       -1.99065936, -1.12465192, -1.00732128, -0.86009437, -1.34203859,\n",
      "       -1.99065936, -1.20963163, -0.9615587 , -0.41452135, -1.46465335,\n",
      "       -1.99065936, -1.55746538, -1.56816282, -1.2954955 , -1.1642513 ,\n",
      "       -1.99065936, -1.2685908 , -1.17232309, -0.81652236, -0.87647916,\n",
      "       -1.99065936, -0.80280256, -1.08518401, -1.42118504, -0.99281612,\n",
      "       -1.99065936, -1.42223473, -1.31208892, -1.16823907, -1.08258662,\n",
      "       -1.99065936, -0.1048175 , -1.7909315 , -1.39731142, -1.61698453,\n",
      "       -1.99065936, -0.82725419, -0.9678525 , -1.48618196, -1.66092401,\n",
      "       -1.99065936, -0.73012181, -1.12669918, -1.34689929, -0.67030933,\n",
      "       -1.99065936, -0.51910207, -0.9698971 , -1.58338126, -0.91350581,\n",
      "       -1.99065936, -1.18769819, -1.40629907, -0.93678502, -0.92327303,\n",
      "       -1.99065936, -1.47733292, -1.22804599, -1.21117743, -1.42201438,\n",
      "       -1.99065936, -0.95033965, -1.04628571, -0.39094974, -1.70580013,\n",
      "       -1.99065936, -1.27439642, -0.9638992 , -0.97455712, -0.56043667,\n",
      "       -1.99065936, -1.36208643, -0.82787886, -0.40731838, -1.49553269])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(551): apply_cq\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning_gpu/lightning_gpu.py(572): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(45): <listcomp>\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(44): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(93): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(153): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_443038/3515301543.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "; current tracing scope: custom-call.28; current profiling annotation: XlaModule:#hlo_module=jit_circuit,program_id=4813#.\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice=\"lightning.gpu\")\n",
    "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
    "except jaxlib.xla_extension.XlaRuntimeError as e:\n",
    "    print(traceback.format_exc())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with PennyLane with Lightning-GPU device and Catalyst\n",
    "\n",
    "Not supported: Catalyst refuses to compile circuits for the `lightning.gpu` device and raises a `CompileError`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/tmp/ipykernel_443038/3502656055.py\", line 3, in <module>\n",
      "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py\", line 73, in train_and_evaluate\n",
      "    variables = model.init(params_key, x, train=False)\n",
      "                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1845, in init\n",
      "    _, v_out = self.init_with_output(\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1746, in init_with_output\n",
      "    return init_with_output(\n",
      "           ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 1034, in wrapper\n",
      "    return apply(fn, mutable=mutable, flags=init_flags)(\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 998, in wrapper\n",
      "    y = fn(root, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 2370, in scope_fn\n",
      "    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 153, in __call__\n",
      "    x = TransformerBlock(\n",
      "        ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 93, in __call__\n",
      "    attn_output = MultiHeadSelfAttention(embed_dim=self.hidden_size, num_heads=self.num_heads,\n",
      "                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 44, in __call__\n",
      "    q, k, v = [\n",
      "              ^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 45, in <listcomp>\n",
      "    proj(x).reshape(batch_size, seq_len, self.num_heads, head_dim).swapaxes(1, 2)\n",
      "    ^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py\", line 39, in __call__\n",
      "    x = jax.vmap(self.circuit, in_axes=(0, None))(x, weights)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py\", line 1240, in vmap_f\n",
      "    out_flat = batching.batch(\n",
      "               ^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py\", line 188, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 560, in __call__\n",
      "    function, args = self._maybe_promote(self.compiled_function, *args)\n",
      "                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 535, in _maybe_promote\n",
      "    self.mlir_module = self.get_mlir(*r_sig)\n",
      "                       ^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 480, in get_mlir\n",
      "    mlir_module, ctx, jaxpr = tracer.get_mlir(self.qfunc, *self.c_sig)\n",
      "                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/jax_tracer.py\", line 65, in get_mlir\n",
      "    jaxpr = jax.make_jaxpr(func)(*args, **kwargs)\n",
      "            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py\", line 2432, in make_jaxpr_f\n",
      "    jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(\n",
      "                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py\", line 314, in wrapper\n",
      "    return func(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py\", line 2191, in trace_to_jaxpr_dynamic2\n",
      "    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)\n",
      "                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py\", line 2206, in trace_to_subjaxpr_dynamic2\n",
      "    ans = fun.call_wrapped(*in_tracers_)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py\", line 188, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/pennylane_extensions.py\", line 78, in __call__\n",
      "    raise CompileError(\n",
      "catalyst.utils.exceptions.CompileError: The lightning.gpu device is not supported for compilation at the moment.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice=\"lightning.gpu\", use_catalyst=True)\n",
    "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
    "except catalyst.CompileError as e:\n",
    "    print(traceback.format_exc())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with PennyLane with Lightning\n",
    "\n",
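    "Fails with the same `XlaRuntimeError` (`RX(): incompatible function arguments`) as the `lightning.gpu` run without Catalyst, not the Catalyst `CompileError` shown directly above."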
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/tmp/ipykernel_443038/3447823501.py\", line 3, in <module>\n",
      "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py\", line 73, in train_and_evaluate\n",
      "    variables = model.init(params_key, x, train=False)\n",
      "                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1845, in init\n",
      "    _, v_out = self.init_with_output(\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1746, in init_with_output\n",
      "    return init_with_output(\n",
      "           ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 1034, in wrapper\n",
      "    return apply(fn, mutable=mutable, flags=init_flags)(\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 998, in wrapper\n",
      "    y = fn(root, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 2370, in scope_fn\n",
      "    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 153, in __call__\n",
      "    x = TransformerBlock(\n",
      "        ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 93, in __call__\n",
      "    attn_output = MultiHeadSelfAttention(embed_dim=self.hidden_size, num_heads=self.num_heads,\n",
      "                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 44, in __call__\n",
      "    q, k, v = [\n",
      "              ^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 45, in <listcomp>\n",
      "    proj(x).reshape(batch_size, seq_len, self.num_heads, head_dim).swapaxes(1, 2)\n",
      "    ^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py\", line 39, in __call__\n",
      "    x = jax.vmap(self.circuit, in_axes=(0, None))(x, weights)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py\", line 1240, in vmap_f\n",
      "    out_flat = batching.batch(\n",
      "               ^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py\", line 188, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 250, in cache_miss\n",
      "    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(\n",
      "                                                 ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 163, in _python_pjit_helper\n",
      "    out_flat = pjit_p.bind(*args_flat, **params)\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 2677, in bind\n",
      "    return self.bind_with_trace(top_trace, args, params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 383, in bind_with_trace\n",
      "    out = trace.process_primitive(self, map(trace.full_raise, args), params)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py\", line 398, in process_primitive\n",
      "    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)\n",
      "                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1415, in _pjit_batcher\n",
      "    vals_out = pjit_p.bind(\n",
      "               ^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 2677, in bind\n",
      "    return self.bind_with_trace(top_trace, args, params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 383, in bind_with_trace\n",
      "    out = trace.process_primitive(self, map(trace.full_raise, args), params)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 815, in process_primitive\n",
      "    return primitive.impl(*tracers, **params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1203, in _pjit_call_impl\n",
      "    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1187, in call_impl_cache_miss\n",
      "    out_flat, compiled = _pjit_call_impl_python(\n",
      "                         ^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py\", line 1143, in _pjit_call_impl_python\n",
      "    return compiled.unsafe_call(*args), compiled\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py\", line 314, in wrapper\n",
      "    return func(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py\", line 1349, in __call__\n",
      "    results = self.xla_executable.execute_sharded(input_bufs)\n",
      "              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7fcc7838ccb0>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,\n",
      "       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,\n",
      "       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,\n",
      "       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,\n",
      "       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,\n",
      "       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,\n",
      "       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,\n",
      "       -1.99065936, -0.95286273, -1.00063692, -1.26076869, -1.05675874,\n",
      "       -1.99065936, -1.26840019, -0.94005721, -1.38595578, -0.61030866,\n",
      "       -1.99065936, -0.71439892, -0.95327059, -1.08793455, -1.3481988 ,\n",
      "       -1.99065936, -0.16415302, -1.13727631, -1.14171319, -1.3928046 ,\n",
      "       -1.99065936, -1.33280401, -1.27190475, -0.27698251, -0.59783505,\n",
      "       -1.99065936, -1.07410815, -1.08089054, -1.12720024, -1.4855256 ,\n",
      "       -1.99065936, -1.62930775, -0.61359425, -0.72750374, -0.81329352,\n",
      "       -1.99065936, -1.70892522, -1.39563787, -1.09084622, -0.99745661,\n",
      "       -1.99065936, -1.39579129, -0.93821183, -1.11267586, -1.3618317 ,\n",
      "       -1.99065936, -1.19673384, -1.36525872, -0.85516461, -1.0932371 ,\n",
      "       -1.99065936, -1.19721424, -1.04066598, -0.64450208, -1.55273813,\n",
      "       -1.99065936, -0.96791736, -1.80360259, -1.30149484, -0.2839262 ,\n",
      "       -1.99065936, -1.46805732, -1.39138411, -1.27190827, -1.23702557,\n",
      "       -1.99065936, -0.08671108, -1.62621405, -1.57943044, -1.04211977,\n",
      "       -1.99065936, -1.15063848, -0.56190361, -0.70644403, -1.11980448,\n",
      "       -1.99065936, -1.38225553, -0.725611  , -1.45496257, -0.76781471,\n",
      "       -1.99065936, -1.25031726, -0.74027637, -1.02105564, -1.20400562,\n",
      "       -1.99065936, -1.59105694, -1.61286693, -1.15849376, -1.42042045,\n",
      "       -1.99065936, -0.74705437, -1.07704607, -1.05218371, -0.8873684 ,\n",
      "       -1.99065936, -0.73331575, -0.69374305, -1.12760715, -1.54806115,\n",
      "       -1.99065936, -1.2807564 , -1.3624274 , -1.00578041, -0.87068549,\n",
      "       -1.99065936, -1.19865234, -0.8911431 , -0.77395507, -0.94454815,\n",
      "       -1.99065936, -0.69528866, -0.91436118, -1.2872507 , -1.58517108,\n",
      "       -1.99065936, -1.36052242, -1.39846098, -1.23202642, -1.22379228,\n",
      "       -1.99065936, -0.69681703, -1.39031723, -1.2262722 , -0.76031492,\n",
      "       -1.99065936, -1.35668371, -0.63182182, -0.6709817 , -1.35896601,\n",
      "       -1.99065936, -1.28068756, -1.55500923, -1.16366544, -1.91953158,\n",
      "       -1.99065936, -1.46342654, -0.30489002, -1.38779702, -0.98844208,\n",
      "       -1.99065936, -0.18323726, -1.19045485, -1.05977132, -1.09098738,\n",
      "       -1.99065936, -0.68175041, -1.21468183, -0.13046826, -0.42463526,\n",
      "       -1.99065936, -1.69748725, -0.80771108, -1.10917939, -0.76152976,\n",
      "       -1.99065936, -1.46977833, -1.08981589, -1.06376841, -1.09351361,\n",
      "       -1.99065936, -1.39495866, -1.18848993, -0.82710038, -1.19726168,\n",
      "       -1.99065936,  0.13395401, -0.06145497, -1.45255525, -1.30299591,\n",
      "       -1.99065936, -0.45357115, -1.71149023, -0.6245115 , -0.91892305,\n",
      "       -1.99065936, -1.05042194, -1.43260254, -1.44138882, -0.97532864,\n",
      "       -1.99065936, -1.37557435, -1.54062235, -0.11966081, -1.32451361,\n",
      "       -1.99065936, -0.99472941, -0.97100577, -0.98086822, -1.27070776,\n",
      "       -1.99065936, -1.1962407 , -1.14998185, -0.63131049, -1.34136496,\n",
      "       -1.99065936, -0.99383569, -1.28917555, -1.52504055, -0.90961462,\n",
      "       -1.99065936, -0.66839026, -0.63105238, -1.54710814, -1.17531734,\n",
      "       -1.99065936, -1.1795461 , -1.44927538, -1.53803568, -1.58135178,\n",
      "       -1.99065936, -1.12465192, -1.00732128, -0.86009437, -1.34203859,\n",
      "       -1.99065936, -1.20963163, -0.9615587 , -0.41452135, -1.46465335,\n",
      "       -1.99065936, -1.55746538, -1.56816282, -1.2954955 , -1.1642513 ,\n",
      "       -1.99065936, -1.2685908 , -1.17232309, -0.81652236, -0.87647916,\n",
      "       -1.99065936, -0.80280256, -1.08518401, -1.42118504, -0.99281612,\n",
      "       -1.99065936, -1.42223473, -1.31208892, -1.16823907, -1.08258662,\n",
      "       -1.99065936, -0.1048175 , -1.7909315 , -1.39731142, -1.61698453,\n",
      "       -1.99065936, -0.82725419, -0.9678525 , -1.48618196, -1.66092401,\n",
      "       -1.99065936, -0.73012181, -1.12669918, -1.34689929, -0.67030933,\n",
      "       -1.99065936, -0.51910207, -0.9698971 , -1.58338126, -0.91350581,\n",
      "       -1.99065936, -1.18769819, -1.40629907, -0.93678502, -0.92327303,\n",
      "       -1.99065936, -1.47733292, -1.22804599, -1.21117743, -1.42201438,\n",
      "       -1.99065936, -0.95033965, -1.04628571, -0.39094974, -1.70580013,\n",
      "       -1.99065936, -1.27439642, -0.9638992 , -0.97455712, -0.56043667,\n",
      "       -1.99065936, -1.36208643, -0.82787886, -0.40731838, -1.49553269])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning/lightning_qubit.py(413): apply_lightning\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning/lightning_qubit.py(435): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(45): <listcomp>\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(44): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(93): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(153): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_443038/3447823501.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "; current tracing scope: custom-call.28; current profiling annotation: XlaModule:#hlo_module=jit_circuit,program_id=4819#.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-14 06:44:19.334928: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:\n",
      "INTERNAL: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7fcc7838ccb0>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,\n",
      "       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,\n",
      "       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,\n",
      "       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,\n",
      "       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,\n",
      "       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,\n",
      "       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,\n",
      "       -1.99065936, -0.95286273, -1.00063692, -1.26076869, -1.05675874,\n",
      "       -1.99065936, -1.26840019, -0.94005721, -1.38595578, -0.61030866,\n",
      "       -1.99065936, -0.71439892, -0.95327059, -1.08793455, -1.3481988 ,\n",
      "       -1.99065936, -0.16415302, -1.13727631, -1.14171319, -1.3928046 ,\n",
      "       -1.99065936, -1.33280401, -1.27190475, -0.27698251, -0.59783505,\n",
      "       -1.99065936, -1.07410815, -1.08089054, -1.12720024, -1.4855256 ,\n",
      "       -1.99065936, -1.62930775, -0.61359425, -0.72750374, -0.81329352,\n",
      "       -1.99065936, -1.70892522, -1.39563787, -1.09084622, -0.99745661,\n",
      "       -1.99065936, -1.39579129, -0.93821183, -1.11267586, -1.3618317 ,\n",
      "       -1.99065936, -1.19673384, -1.36525872, -0.85516461, -1.0932371 ,\n",
      "       -1.99065936, -1.19721424, -1.04066598, -0.64450208, -1.55273813,\n",
      "       -1.99065936, -0.96791736, -1.80360259, -1.30149484, -0.2839262 ,\n",
      "       -1.99065936, -1.46805732, -1.39138411, -1.27190827, -1.23702557,\n",
      "       -1.99065936, -0.08671108, -1.62621405, -1.57943044, -1.04211977,\n",
      "       -1.99065936, -1.15063848, -0.56190361, -0.70644403, -1.11980448,\n",
      "       -1.99065936, -1.38225553, -0.725611  , -1.45496257, -0.76781471,\n",
      "       -1.99065936, -1.25031726, -0.74027637, -1.02105564, -1.20400562,\n",
      "       -1.99065936, -1.59105694, -1.61286693, -1.15849376, -1.42042045,\n",
      "       -1.99065936, -0.74705437, -1.07704607, -1.05218371, -0.8873684 ,\n",
      "       -1.99065936, -0.73331575, -0.69374305, -1.12760715, -1.54806115,\n",
      "       -1.99065936, -1.2807564 , -1.3624274 , -1.00578041, -0.87068549,\n",
      "       -1.99065936, -1.19865234, -0.8911431 , -0.77395507, -0.94454815,\n",
      "       -1.99065936, -0.69528866, -0.91436118, -1.2872507 , -1.58517108,\n",
      "       -1.99065936, -1.36052242, -1.39846098, -1.23202642, -1.22379228,\n",
      "       -1.99065936, -0.69681703, -1.39031723, -1.2262722 , -0.76031492,\n",
      "       -1.99065936, -1.35668371, -0.63182182, -0.6709817 , -1.35896601,\n",
      "       -1.99065936, -1.28068756, -1.55500923, -1.16366544, -1.91953158,\n",
      "       -1.99065936, -1.46342654, -0.30489002, -1.38779702, -0.98844208,\n",
      "       -1.99065936, -0.18323726, -1.19045485, -1.05977132, -1.09098738,\n",
      "       -1.99065936, -0.68175041, -1.21468183, -0.13046826, -0.42463526,\n",
      "       -1.99065936, -1.69748725, -0.80771108, -1.10917939, -0.76152976,\n",
      "       -1.99065936, -1.46977833, -1.08981589, -1.06376841, -1.09351361,\n",
      "       -1.99065936, -1.39495866, -1.18848993, -0.82710038, -1.19726168,\n",
      "       -1.99065936,  0.13395401, -0.06145497, -1.45255525, -1.30299591,\n",
      "       -1.99065936, -0.45357115, -1.71149023, -0.6245115 , -0.91892305,\n",
      "       -1.99065936, -1.05042194, -1.43260254, -1.44138882, -0.97532864,\n",
      "       -1.99065936, -1.37557435, -1.54062235, -0.11966081, -1.32451361,\n",
      "       -1.99065936, -0.99472941, -0.97100577, -0.98086822, -1.27070776,\n",
      "       -1.99065936, -1.1962407 , -1.14998185, -0.63131049, -1.34136496,\n",
      "       -1.99065936, -0.99383569, -1.28917555, -1.52504055, -0.90961462,\n",
      "       -1.99065936, -0.66839026, -0.63105238, -1.54710814, -1.17531734,\n",
      "       -1.99065936, -1.1795461 , -1.44927538, -1.53803568, -1.58135178,\n",
      "       -1.99065936, -1.12465192, -1.00732128, -0.86009437, -1.34203859,\n",
      "       -1.99065936, -1.20963163, -0.9615587 , -0.41452135, -1.46465335,\n",
      "       -1.99065936, -1.55746538, -1.56816282, -1.2954955 , -1.1642513 ,\n",
      "       -1.99065936, -1.2685908 , -1.17232309, -0.81652236, -0.87647916,\n",
      "       -1.99065936, -0.80280256, -1.08518401, -1.42118504, -0.99281612,\n",
      "       -1.99065936, -1.42223473, -1.31208892, -1.16823907, -1.08258662,\n",
      "       -1.99065936, -0.1048175 , -1.7909315 , -1.39731142, -1.61698453,\n",
      "       -1.99065936, -0.82725419, -0.9678525 , -1.48618196, -1.66092401,\n",
      "       -1.99065936, -0.73012181, -1.12669918, -1.34689929, -0.67030933,\n",
      "       -1.99065936, -0.51910207, -0.9698971 , -1.58338126, -0.91350581,\n",
      "       -1.99065936, -1.18769819, -1.40629907, -0.93678502, -0.92327303,\n",
      "       -1.99065936, -1.47733292, -1.22804599, -1.21117743, -1.42201438,\n",
      "       -1.99065936, -0.95033965, -1.04628571, -0.39094974, -1.70580013,\n",
      "       -1.99065936, -1.27439642, -0.9638992 , -0.97455712, -0.56043667,\n",
      "       -1.99065936, -1.36208643, -0.82787886, -0.40731838, -1.49553269])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning/lightning_qubit.py(413): apply_lightning\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning/lightning_qubit.py(435): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(45): <listcomp>\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(44): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(93): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(153): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_443038/3447823501.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "\n",
      "2023-08-14 06:44:19.334966: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: TypeError: RX(): incompatible function arguments. The following argument types are supported:\n",
      "    1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: List[int], arg1: bool, arg2: List[float]) -> None\n",
      "\n",
      "Invoked with: <pennylane_lightning.lightning_qubit_ops.StateVectorC128 object at 0x7fcc7838ccb0>, [0], False, [array([-1.99065936, -1.39067658, -1.40838223, -1.28091254, -1.236017  ,\n",
      "       -1.99065936, -0.79695004, -0.97037919, -0.83079853, -0.9329554 ,\n",
      "       -1.99065936, -1.00111164, -1.10162356, -1.36933823, -0.90581727,\n",
      "       -1.99065936, -0.77338261, -1.07736925, -1.27714229, -1.22361589,\n",
      "       -1.99065936, -1.41910524, -1.18701463, -0.96778634, -1.39457697,\n",
      "       -1.99065936, -1.48105851, -1.51591239, -0.46449003, -0.78824321,\n",
      "       -1.99065936, -0.71558203, -1.24144874, -1.09839354, -1.3799354 ,\n",
      "       -1.99065936, -0.95286273, -1.00063692, -1.26076869, -1.05675874,\n",
      "       -1.99065936, -1.26840019, -0.94005721, -1.38595578, -0.61030866,\n",
      "       -1.99065936, -0.71439892, -0.95327059, -1.08793455, -1.3481988 ,\n",
      "       -1.99065936, -0.16415302, -1.13727631, -1.14171319, -1.3928046 ,\n",
      "       -1.99065936, -1.33280401, -1.27190475, -0.27698251, -0.59783505,\n",
      "       -1.99065936, -1.07410815, -1.08089054, -1.12720024, -1.4855256 ,\n",
      "       -1.99065936, -1.62930775, -0.61359425, -0.72750374, -0.81329352,\n",
      "       -1.99065936, -1.70892522, -1.39563787, -1.09084622, -0.99745661,\n",
      "       -1.99065936, -1.39579129, -0.93821183, -1.11267586, -1.3618317 ,\n",
      "       -1.99065936, -1.19673384, -1.36525872, -0.85516461, -1.0932371 ,\n",
      "       -1.99065936, -1.19721424, -1.04066598, -0.64450208, -1.55273813,\n",
      "       -1.99065936, -0.96791736, -1.80360259, -1.30149484, -0.2839262 ,\n",
      "       -1.99065936, -1.46805732, -1.39138411, -1.27190827, -1.23702557,\n",
      "       -1.99065936, -0.08671108, -1.62621405, -1.57943044, -1.04211977,\n",
      "       -1.99065936, -1.15063848, -0.56190361, -0.70644403, -1.11980448,\n",
      "       -1.99065936, -1.38225553, -0.725611  , -1.45496257, -0.76781471,\n",
      "       -1.99065936, -1.25031726, -0.74027637, -1.02105564, -1.20400562,\n",
      "       -1.99065936, -1.59105694, -1.61286693, -1.15849376, -1.42042045,\n",
      "       -1.99065936, -0.74705437, -1.07704607, -1.05218371, -0.8873684 ,\n",
      "       -1.99065936, -0.73331575, -0.69374305, -1.12760715, -1.54806115,\n",
      "       -1.99065936, -1.2807564 , -1.3624274 , -1.00578041, -0.87068549,\n",
      "       -1.99065936, -1.19865234, -0.8911431 , -0.77395507, -0.94454815,\n",
      "       -1.99065936, -0.69528866, -0.91436118, -1.2872507 , -1.58517108,\n",
      "       -1.99065936, -1.36052242, -1.39846098, -1.23202642, -1.22379228,\n",
      "       -1.99065936, -0.69681703, -1.39031723, -1.2262722 , -0.76031492,\n",
      "       -1.99065936, -1.35668371, -0.63182182, -0.6709817 , -1.35896601,\n",
      "       -1.99065936, -1.28068756, -1.55500923, -1.16366544, -1.91953158,\n",
      "       -1.99065936, -1.46342654, -0.30489002, -1.38779702, -0.98844208,\n",
      "       -1.99065936, -0.18323726, -1.19045485, -1.05977132, -1.09098738,\n",
      "       -1.99065936, -0.68175041, -1.21468183, -0.13046826, -0.42463526,\n",
      "       -1.99065936, -1.69748725, -0.80771108, -1.10917939, -0.76152976,\n",
      "       -1.99065936, -1.46977833, -1.08981589, -1.06376841, -1.09351361,\n",
      "       -1.99065936, -1.39495866, -1.18848993, -0.82710038, -1.19726168,\n",
      "       -1.99065936,  0.13395401, -0.06145497, -1.45255525, -1.30299591,\n",
      "       -1.99065936, -0.45357115, -1.71149023, -0.6245115 , -0.91892305,\n",
      "       -1.99065936, -1.05042194, -1.43260254, -1.44138882, -0.97532864,\n",
      "       -1.99065936, -1.37557435, -1.54062235, -0.11966081, -1.32451361,\n",
      "       -1.99065936, -0.99472941, -0.97100577, -0.98086822, -1.27070776,\n",
      "       -1.99065936, -1.1962407 , -1.14998185, -0.63131049, -1.34136496,\n",
      "       -1.99065936, -0.99383569, -1.28917555, -1.52504055, -0.90961462,\n",
      "       -1.99065936, -0.66839026, -0.63105238, -1.54710814, -1.17531734,\n",
      "       -1.99065936, -1.1795461 , -1.44927538, -1.53803568, -1.58135178,\n",
      "       -1.99065936, -1.12465192, -1.00732128, -0.86009437, -1.34203859,\n",
      "       -1.99065936, -1.20963163, -0.9615587 , -0.41452135, -1.46465335,\n",
      "       -1.99065936, -1.55746538, -1.56816282, -1.2954955 , -1.1642513 ,\n",
      "       -1.99065936, -1.2685908 , -1.17232309, -0.81652236, -0.87647916,\n",
      "       -1.99065936, -0.80280256, -1.08518401, -1.42118504, -0.99281612,\n",
      "       -1.99065936, -1.42223473, -1.31208892, -1.16823907, -1.08258662,\n",
      "       -1.99065936, -0.1048175 , -1.7909315 , -1.39731142, -1.61698453,\n",
      "       -1.99065936, -0.82725419, -0.9678525 , -1.48618196, -1.66092401,\n",
      "       -1.99065936, -0.73012181, -1.12669918, -1.34689929, -0.67030933,\n",
      "       -1.99065936, -0.51910207, -0.9698971 , -1.58338126, -0.91350581,\n",
      "       -1.99065936, -1.18769819, -1.40629907, -0.93678502, -0.92327303,\n",
      "       -1.99065936, -1.47733292, -1.22804599, -1.21117743, -1.42201438,\n",
      "       -1.99065936, -0.95033965, -1.04628571, -0.39094974, -1.70580013,\n",
      "       -1.99065936, -1.27439642, -0.9638992 , -0.97455712, -0.56043667,\n",
      "       -1.99065936, -1.36208643, -0.82787886, -0.40731838, -1.49553269])]\n",
      "\n",
      "At:\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning/lightning_qubit.py(413): apply_lightning\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane_lightning/lightning_qubit.py(435): apply\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(320): execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/_qubit_device.py(603): batch_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/contextlib.py(81): inner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(210): fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/execution.py(287): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/pennylane/interfaces/jax_jit_tuple.py(191): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(185): _flat_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(45): pure_callback_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py(107): _callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1917): _wrapped_callback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1349): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1143): _pjit_call_impl_python\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1187): call_impl_cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1203): _pjit_call_impl\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(815): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(1415): _pjit_batcher\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py(398): process_primitive\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py(2677): bind\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(163): _python_pjit_helper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/pjit.py(250): cache_miss\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py(188): call_wrapped\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py(1240): vmap_f\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py(39): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(45): <listcomp>\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(44): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(93): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py(153): __call__\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(966): _call_wrapped_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(467): wrapped_module_method\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(2370): scope_fn\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(998): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py(1034): wrapper\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1746): init_with_output\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py(1845): init\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback\n",
      "  /global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py(73): train_and_evaluate\n",
      "  /tmp/ipykernel_443038/3447823501.py(3): <module>\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/zmqshell.py(546): run_cell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(740): execute_request\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(412): dispatch_shell\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(505): process_one\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelbase.py(516): dispatch_queue\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/events.py(80): _run\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(1922): _run_once\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/asyncio/base_events.py(607): run_forever\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel/kernelapp.py(736): start\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance\n",
      "  /global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>\n",
      "  <frozen runpy>(88): _run_code\n",
      "  <frozen runpy>(198): _run_module_as_main\n",
      "; current tracing scope: custom-call.28; current profiling annotation: XlaModule:#hlo_module=jit_circuit,program_id=4819#.\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    model = qpcjax.quantum.VisionTransformer(num_classes=10, patch_size=14, hidden_size=6, num_heads=2, num_transformer_blocks=4, mlp_hidden_size=3, qdevice=\"lightning.qubit\")\n",
    "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
    "except jaxlib.xla_extension.XlaRuntimeError:\n",
    "    # Expected failure with this backend: RX() receives an array of angles instead of floats\n",
    "    # (see the TypeError in the output). Print the traceback so Run All can continue.\n",
    "    print(traceback.format_exc())"
   ]
  },
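  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with PennyLane with Lightning and Catalyst\n",
    "\n",
    "This configuration also results in an error: PennyLane Lightning reports that \"the given interval conflicts with existing intervals\" while the model is being initialized."
   ]
  },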
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/tmp/ipykernel_443038/2603097982.py\", line 3, in <module>\n",
      "    qpcjax.training.train_and_evaluate(model, train_dataloader, valid_dataloader, num_classes=10, learning_rate=0.0003, num_epochs=10)\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/training.py\", line 73, in train_and_evaluate\n",
      "    variables = model.init(params_key, x, train=False)\n",
      "                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1845, in init\n",
      "    _, v_out = self.init_with_output(\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 1746, in init_with_output\n",
      "    return init_with_output(\n",
      "           ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 1034, in wrapper\n",
      "    return apply(fn, mutable=mutable, flags=init_flags)(\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/core/scope.py\", line 998, in wrapper\n",
      "    y = fn(root, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 2370, in scope_fn\n",
      "    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 153, in __call__\n",
      "    x = TransformerBlock(\n",
      "        ^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 93, in __call__\n",
      "    attn_output = MultiHeadSelfAttention(embed_dim=self.hidden_size, num_heads=self.num_heads,\n",
      "                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 44, in __call__\n",
      "    q, k, v = [\n",
      "              ^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/vit.py\", line 45, in <listcomp>\n",
      "    proj(x).reshape(batch_size, seq_len, self.num_heads, head_dim).swapaxes(1, 2)\n",
      "    ^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 467, in wrapped_module_method\n",
      "    return self._call_wrapped_method(fun, args, kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/flax/linen/module.py\", line 966, in _call_wrapped_method\n",
      "    y = fun(self, *args, **kwargs)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/u1/s/salcc/QuantumTransformers/quantum_transformers/qmlperfcomp/jax_backend/quantum/pennylane_backend.py\", line 39, in __call__\n",
      "    x = jax.vmap(self.circuit, in_axes=(0, None))(x, weights)\n",
      "        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/api.py\", line 1240, in vmap_f\n",
      "    out_flat = batching.batch(\n",
      "               ^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py\", line 188, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 569, in __call__\n",
      "    return self.jaxed_qfunc(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 663, in __call__\n",
      "    return self.jaxed_qfunc(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/custom_derivatives.py\", line 259, in __call__\n",
      "    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/custom_derivatives.py\", line 361, in bind\n",
      "    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,  # type: ignore\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py\", line 491, in process_custom_jvp_call\n",
      "    out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros)\n",
      "               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/custom_derivatives.py\", line 361, in bind\n",
      "    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,  # type: ignore\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 829, in process_custom_jvp_call\n",
      "    return fun.call_wrapped(*tracers)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/linear_util.py\", line 188, in call_wrapped\n",
      "    ans = self.f(*args, **dict(self.params, **kwargs))\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 590, in jaxed_qfunc\n",
      "    results = self.wrap_callback(qfunc, *args, **kwargs)\n",
      "              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/catalyst/compilation_pipelines.py\", line 603, in wrap_callback\n",
      "    return jax.pure_callback(qfunc, qfunc.jaxpr.out_avals, *args, vectorized=False, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py\", line 250, in pure_callback_api\n",
      "    return pure_callback(callback, result_shape_dtypes, *args,\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py\", line 192, in pure_callback\n",
      "    out_flat = pure_callback_p.bind(\n",
      "               ^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 380, in bind\n",
      "    return self.bind_with_trace(find_top_trace(args), args, params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 383, in bind_with_trace\n",
      "    out = trace.process_primitive(self, map(trace.full_raise, args), params)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/batching.py\", line 398, in process_primitive\n",
      "    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)\n",
      "                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/callback.py\", line 97, in pure_callback_batching_rule\n",
      "    outvals = lax_map(_batch_fun, batched_args)\n",
      "              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py\", line 1818, in map\n",
      "    _, ys = scan(g, (), xs)\n",
      "            ^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/traceback_util.py\", line 166, in reraise_with_filtered_traceback\n",
      "    return fun(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py\", line 263, in scan\n",
      "    out = scan_p.bind(*consts, *in_flat,\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py\", line 1038, in scan_bind\n",
      "    return core.AxisPrimitive.bind(scan_p, *args, **params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 2677, in bind\n",
      "    return self.bind_with_trace(top_trace, args, params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 383, in bind_with_trace\n",
      "    out = trace.process_primitive(self, map(trace.full_raise, args), params)\n",
      "          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/core.py\", line 815, in process_primitive\n",
      "    return primitive.impl(*tracers, **params)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/dispatch.py\", line 144, in apply_primitive\n",
      "    return compiled_fun(*args)\n",
      "           ^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/profiler.py\", line 314, in wrapper\n",
      "    return func(*args, **kwargs)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"/global/common/software/m4392/conda/gsoc/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py\", line 1349, in __call__\n",
      "    results = self.xla_executable.execute_sharded(input_bufs)\n",
      "              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: xla_python_gpu_callback XLA extension have thrown an exception: [/__w/catalyst/catalyst/runtime-build/_deps/pennylane_lightning-src/pennylane_lightning/src/simulator/KernelMap.hpp][Line:270][Method:assignKernelForOp]: Error in PennyLane Lightning: The given interval conflicts with existing intervals.; current profiling annotation: XlaModule:#hlo_module=jit_scan,program_id=4821#.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-14 06:44:20.747541: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:\n",
      "UNKNOWN: xla_python_gpu_callback XLA extension have thrown an exception: [/__w/catalyst/catalyst/runtime-build/_deps/pennylane_lightning-src/pennylane_lightning/src/simulator/KernelMap.hpp][Line:270][Method:assignKernelForOp]: Error in PennyLane Lightning: The given interval conflicts with existing intervals.\n",
      "2023-08-14 06:44:20.747573: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: xla_python_gpu_callback XLA extension have thrown an exception: [/__w/catalyst/catalyst/runtime-build/_deps/pennylane_lightning-src/pennylane_lightning/src/simulator/KernelMap.hpp][Line:270][Method:assignKernelForOp]: Error in PennyLane Lightning: The given interval conflicts with existing intervals.; current profiling annotation: XlaModule:#hlo_module=jit_scan,program_id=4821#.\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    model = qpcjax.quantum.VisionTransformer(\n",
    "        num_classes=10,\n",
    "        patch_size=14,\n",
    "        hidden_size=6,\n",
    "        num_heads=2,\n",
    "        num_transformer_blocks=4,\n",
    "        mlp_hidden_size=3,\n",
    "        qdevice=\"lightning.qubit\",\n",
    "        use_catalyst=True,\n",
    "    )\n",
    "    qpcjax.training.train_and_evaluate(\n",
    "        model,\n",
    "        train_dataloader,\n",
    "        valid_dataloader,\n",
    "        num_classes=10,\n",
    "        learning_rate=0.0003,\n",
    "        num_epochs=10,\n",
    "    )\n",
    "except Exception:\n",
    "    # In the recorded run, Catalyst + lightning.qubit raised an XlaRuntimeError from PennyLane Lightning\n",
    "    # (KernelMap interval conflict). Show the traceback instead of aborting Restart & Run All.\n",
    "    print(traceback.format_exc())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum with TensorCircuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch   1/10: 100%|██████████| 937/937 [00:44<00:00, 20.97batch/s, Loss = 2.0976, AUC = 68.97%]                                                                                                                                           \n",
      "Epoch   2/10: 100%|██████████| 937/937 [00:12<00:00, 76.27batch/s, Loss = 1.9541, AUC = 76.78%]                                                                                                                                           \n",
      "Epoch   3/10: 100%|██████████| 937/937 [00:12<00:00, 76.29batch/s, Loss = 1.7557, AUC = 80.60%]                                                                                                                                           \n",
      "Epoch   4/10: 100%|██████████| 937/937 [00:12<00:00, 73.74batch/s, Loss = 1.6315, AUC = 84.99%]                                                                                                                                           \n",
      "Epoch   5/10: 100%|██████████| 937/937 [00:12<00:00, 75.55batch/s, Loss = 1.5194, AUC = 87.82%]                                                                                                                                           \n",
      "Epoch   6/10: 100%|██████████| 937/937 [00:12<00:00, 74.77batch/s, Loss = 1.4198, AUC = 89.33%]                                                                                                                                           \n",
      "Epoch   7/10: 100%|██████████| 937/937 [00:11<00:00, 79.45batch/s, Loss = 1.2961, AUC = 91.21%]                                                                                                                                           \n",
      "Epoch   8/10: 100%|██████████| 937/937 [00:11<00:00, 80.37batch/s, Loss = 1.2002, AUC = 92.41%]                                                                                                                                           \n",
      "Epoch   9/10: 100%|██████████| 937/937 [00:12<00:00, 76.22batch/s, Loss = 1.1331, AUC = 93.22%]                                                                                                                                           \n",
      "Epoch  10/10: 100%|██████████| 937/937 [00:11<00:00, 78.23batch/s, Loss = 1.0649, AUC = 93.98%]                                                                                                                                           "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOTAL TIME = 154.64s\n",
      "BEST AUC = 93.98% AT EPOCH 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same ViT hyperparameters as the Catalyst cell, with TensorCircuit as the quantum backend\n",
    "model = qpcjax.quantum.VisionTransformer(\n",
    "    num_classes=10,\n",
    "    patch_size=14,\n",
    "    hidden_size=6,\n",
    "    num_heads=2,\n",
    "    num_transformer_blocks=4,\n",
    "    mlp_hidden_size=3,\n",
    "    qml_backend=\"tensorcircuit\",\n",
    ")\n",
    "qpcjax.training.train_and_evaluate(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    valid_dataloader,\n",
    "    num_classes=10,\n",
    "    learning_rate=0.0003,\n",
    "    num_epochs=10,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gsoc",
   "language": "python",
   "name": "gsoc"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
