{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "589fb211-66fe-4eb7-ba69-6485afc83f8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building table: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 10742.97it/s]\n",
      "Building table: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60000/60000 [00:05<00:00, 11510.08it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import tqdm\n",
    "\n",
    "import unittest, torch\n",
    "\n",
    "from random import seed, sample\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "from tql import Database, Query, Table\n",
    "from tql.sqrl import SQRL\n",
    "\n",
    "import time\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "# MNIST splits as plain ToTensor images (no Normalize), exposed to tql below.\n",
    "test_data = datasets.MNIST(root='data', train=False, transform=ToTensor(), download=True)\n",
    "train_data = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)\n",
    "\n",
    "# Register both splits under the table names the queries refer to ('test', 'train').\n",
    "db = Database('mnist')\n",
    "db.register_dataset(test_data, 'test')\n",
    "db.register_dataset(train_data, 'train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f627237e-9607-494b-91d7-c2f927c2ee41",
   "metadata": {},
   "outputs": [],
   "source": [
    "class row_sum_batch(torch.nn.Module):\n",
    "    \"\"\"Sum the pixel values of one image row for every image in a batch.\"\"\"\n",
    "\n",
    "    def __init__(self, row):\n",
    "        super().__init__()\n",
    "        self.row = row  # index of the image row to sum\n",
    "\n",
    "    def forward(self, image_tensor, label):\n",
    "        \"\"\"Return an (N,) tensor with the sum of row ``self.row`` of each image.\n",
    "\n",
    "        ``image_tensor`` may be (N, C, H, W), in which case channel 0 is used, or\n",
    "        (N, H, W). ``label`` is accepted because the tql projections pass it, but\n",
    "        it is not used.\n",
    "        \"\"\"\n",
    "        if image_tensor.dim() == 4:\n",
    "            image_tensor = image_tensor[:, 0]\n",
    "        return image_tensor[:, self.row].sum(dim=1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1e2c8ea-5bf6-4b85-803b-325e619e0443",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000])\n"
     ]
    }
   ],
   "source": [
    "print(test_data.targets.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d93d4a24-f462-4f04-8d55-a15fd21d3726",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _append_row10_sum(imgs, labels):\n",
    "    # Third column: pixel sum of image row 10.\n",
    "    return zip(imgs, labels, row_sum_batch(10)(imgs, labels))\n",
    "\n",
    "def _append_rule5_flag(imgs, label, sum_row_10):\n",
    "    # Fourth column: True when the label is 5 and the row-10 sum lies in the rule-5 range.\n",
    "    in_range = torch.logical_and(0.7984313845634461 <= sum_row_10, sum_row_10 <= 15.972551059722871)\n",
    "    return zip(imgs, label, sum_row_10, torch.logical_and(label == 5.0, in_range))\n",
    "\n",
    "q = (\n",
    "    Query('satisfies_rule_sum10', base='train')\n",
    "    .project(_append_row10_sum, batch_size=256)\n",
    "    .project(_append_rule5_flag, batch_size=256)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aee73776-cd1e-4608-9e3f-a2b65f340aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = q(database=db, disable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "564d1bf0-6f6c-4276-b724-e428a304d1e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.optim.lr_scheduler import StepLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fc3a88e5-21a6-496c-8ec0-d15ce396d218",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small MNIST CNN: two conv layers, max-pool, two linear layers, log-softmax output.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Attribute names are kept as-is so saved state_dicts (e.g. mnist_cnn.pt) still load.\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        # 9216 = 64 * 12 * 12: a 28x28 input is 24x24 after two 3x3 convs, 12x12 after pooling.\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map an (N, 1, 28, 28) batch to (N, 10) log-probabilities.\"\"\"\n",
    "        features = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)\n",
    "        flat = torch.flatten(self.dropout1(features), 1)\n",
    "        hidden = self.dropout2(F.relu(self.fc1(flat)))\n",
    "        return F.log_softmax(self.fc2(hidden), dim=1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c950132f-98b3-4406-8dc7-28cf5a57605d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is CUDA enabled? True\n"
     ]
    }
   ],
   "source": [
    "print(\"Is CUDA enabled?\", torch.cuda.is_available())\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "model = Net().to('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "965ee2dc-4891-4cdb-a1fd-d3f0961f2de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to restore weights saved in mnist_cnn.pt into `model`.\n",
    "#model.load_state_dict(torch.load(\"mnist_cnn.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a8a479ce-cb5e-4ce5-86eb-811c40ad2efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def train(args, model, device, train_loader, optimizer, epoch, log_interval=100):\n",
    "    \"\"\"Train ``model`` for one epoch with the NLL loss.\n",
    "\n",
    "    Args:\n",
    "        args: unused; kept so existing call sites keep working.\n",
    "        model: network returning log-probabilities (consumed by ``F.nll_loss``).\n",
    "        device: device each batch is moved to.\n",
    "        train_loader: loader yielding ``(data, target)`` batches, with a ``.dataset``.\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        epoch: epoch number, used only in the progress line.\n",
    "        log_interval: print progress every ``log_interval`` batches (default 100).\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % log_interval == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "\n",
    "def test(model, device, test_loader):\n",
    "    \"\"\"Evaluate ``model`` on ``test_loader`` and print the average NLL loss and accuracy.\"\"\"\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    n_samples = len(test_loader.dataset)\n",
    "    test_loss /= n_samples\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, n_samples, 100. * correct / n_samples))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ff4b0aee-33e5-49a0-a652-27ce4f6bd2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU without a GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a009196e-f86c-45e5-a9f7-18f2c65b96e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalized MNIST for the training loop; 0.1307 / 0.3081 are presumably the\n",
    "# MNIST training-set mean and std.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,)),\n",
    "])\n",
    "# Same root as the un-normalized datasets in the first cell, so MNIST is stored\n",
    "# once; download=True on both splits lets this cell run on a fresh machine.\n",
    "dataset1 = datasets.MNIST('data', train=True, download=True, transform=transform)\n",
    "dataset2 = datasets.MNIST('data', train=False, download=True, transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "49ecdf67-8481-4558-ab88-3067ed9926fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "# Loader options shared by both loaders; each dict below gets its own copy.\n",
    "cuda_kwargs = dict(num_workers=1, pin_memory=True, shuffle=True)\n",
    "train_kwargs = {'batch_size': batch_size, **cuda_kwargs}\n",
    "test_kwargs = {'batch_size': batch_size, **cuda_kwargs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6fcdac64-18f3-4792-9690-3bcfe3c4e715",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10000, 28, 28])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset2.data.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2c3d0c20-67a9-4226-adc6-26307faa72d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(dataset=dataset1, **train_kwargs)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=dataset2, **test_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6a957312-9149-444d-ad82-2c2849e46d5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 5880.6593017578125\n",
      "20 11753.51318359375\n",
      "30 17637.39697265625\n",
      "40 22961.92437362671\n",
      "\n",
      "Test set: Average loss: 2.2962, Accuracy: 989/10000 (10%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "test(model=model, device=device, test_loader=test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d0f5bb2d-25e6-4078-8229-ee876bf5d768",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
       "  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
       "  (dropout1): Dropout(p=0.25, inplace=False)\n",
       "  (dropout2): Dropout(p=0.5, inplace=False)\n",
       "  (fc1): Linear(in_features=9216, out_features=128, bias=True)\n",
       "  (fc2): Linear(in_features=128, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model  # display the layer summary of the CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "907f0697-5948-435d-a72a-a2ef300c7ead",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "class row_sum_batch(torch.nn.Module):\n",
    "    \"\"\"Sum the pixel values of one image row for every image in a batch.\"\"\"\n",
    "\n",
    "    def __init__(self, row):\n",
    "        super().__init__()\n",
    "        self.row = row  # index of the image row to sum\n",
    "\n",
    "    def forward(self, image_tensor, label):\n",
    "        \"\"\"Return an (N,) tensor with the sum of row ``self.row`` of each image.\n",
    "\n",
    "        ``image_tensor`` may be (N, C, H, W), in which case channel 0 is used, or\n",
    "        (N, H, W). Indexing a 4-D batch directly with ``[:, row]`` would hit the\n",
    "        size-1 channel dimension. ``label`` is accepted because the tql\n",
    "        projections pass it, but it is not used.\n",
    "        \"\"\"\n",
    "        if image_tensor.dim() == 4:\n",
    "            image_tensor = image_tensor[:, 0]\n",
    "        return image_tensor[:, self.row].sum(dim=1)\n",
    "    \n",
    "class Rule_Row_10(torch.nn.Module):\n",
    "    \"\"\"Rule: a digit's row-10 pixel sum must lie inside a per-digit interval.\n",
    "\n",
    "    ``BOUNDS`` maps each digit label to an inclusive ``(low, high)`` range. A\n",
    "    sample satisfies the rule when its label has an entry and its row-10 sum\n",
    "    lies inside that entry's range.\n",
    "\n",
    "    NOTE(review): the thresholds look like they were measured on plain ToTensor\n",
    "    images, while ``dataset1``/``dataset2`` apply ``transforms.Normalize`` --\n",
    "    confirm the rule and the loaders use the same pixel scale.\n",
    "    \"\"\"\n",
    "\n",
    "    # digit label -> (lowest, highest) accepted sum of image row 10\n",
    "    BOUNDS = {\n",
    "        0: (2.7396470713615417, 16.41815811157229),\n",
    "        1: (1.2509803771972656, 7.072000136375429),\n",
    "        2: (1.1560196113586425, 11.986333203315747),\n",
    "        3: (1.3907843112945557, 11.763921070098917),\n",
    "        4: (1.8400196647644043, 10.57637233734131),\n",
    "        5: (0.7984313845634461, 15.972551059722871),\n",
    "        6: (1.2137647527456283, 10.93266674041748),\n",
    "        7: (1.4729411888122559, 19.431844177246106),\n",
    "        8: (1.8323529064655304, 14.988234996795654),\n",
    "        9: (1.9174902319908143, 14.0),\n",
    "    }\n",
    "\n",
    "    def __init__(self, bounds=None):\n",
    "        super().__init__()\n",
    "        # Copy so overriding bounds on one instance never mutates the class-level table.\n",
    "        self.bounds = dict(self.BOUNDS if bounds is None else bounds)\n",
    "\n",
    "    def forward(self, label, sum_row_10):\n",
    "        \"\"\"Return an (N, 1) bool tensor that is True where the rule is satisfied.\n",
    "\n",
    "        ``label`` and ``sum_row_10`` may each be shaped (N,) or (N, 1). Both are\n",
    "        flattened first so they are compared element-wise; mixing the two shapes\n",
    "        would otherwise broadcast to an (N, N) matrix.\n",
    "        \"\"\"\n",
    "        label = label.reshape(-1)\n",
    "        sum_row_10 = sum_row_10.reshape(-1)\n",
    "        satisfied = torch.zeros(sum_row_10.shape, dtype=torch.bool, device=sum_row_10.device)\n",
    "        for digit, (low, high) in self.bounds.items():\n",
    "            in_range = torch.logical_and(low <= sum_row_10, sum_row_10 <= high)\n",
    "            satisfied = torch.logical_or(satisfied, torch.logical_and(label == float(digit), in_range))\n",
    "        return satisfied.reshape(len(sum_row_10), 1)\n",
    "\n",
    "\n",
    "rule = Rule_Row_10()\n",
    "\n",
    "def train_with_rules(model, device, train_loader, optimizer, epoch, rule_lambda=0.01, rule_only=False):\n",
    "    \"\"\"Train ``model`` for one epoch and compute a rule loss through tql.\n",
    "\n",
    "    Per batch, the images, predicted labels and log-probabilities are registered\n",
    "    in a fresh ``Database``; a ``Query`` joins them, adds the row-10 pixel sum and\n",
    "    multiplies ``rule(label, row_sum)`` with the log-probability of the predicted\n",
    "    label. ``rule_loss`` is the sum of those products over the batch.\n",
    "\n",
    "    Args:\n",
    "        model: network returning (N, 10) log-probabilities.\n",
    "        device: device each batch is moved to.\n",
    "        train_loader: loader yielding ``(data, target)`` batches, with a ``.dataset``.\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        epoch: epoch number, used only in the progress line.\n",
    "        rule_lambda: weight for ``rule_loss`` in a combined objective; currently\n",
    "            unused because either the NLL loss or the rule loss is optimised alone.\n",
    "        rule_only: if True, optimise ``rule_loss`` alone; otherwise optimise the\n",
    "            NLL loss and only report ``rule_loss``.\n",
    "\n",
    "    NOTE(review): with ``rule_only=True`` this only trains if tql keeps the\n",
    "    registered ``output`` attached to the autograd graph; otherwise\n",
    "    ``loss.backward()`` is expected to raise instead of silently doing nothing.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    def add_row_sum(imgs, labels, outputs):\n",
    "        return zip(imgs, labels, outputs, row_sum_batch(10)(imgs, labels))\n",
    "\n",
    "    def add_rule_term(imgs, label, outputs, sum_row_10):\n",
    "        # Flatten to (N, 10) so gather works whether tql yields (N, 10) or (N, 1, 10) outputs.\n",
    "        picked = outputs.reshape(len(label), -1).gather(1, label.reshape((len(label), 1)))\n",
    "        return zip(imgs, label, sum_row_10, rule(label, sum_row_10) * picked)\n",
    "\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "\n",
    "        pred = output.argmax(dim=1, keepdim=True)\n",
    "        db = Database(\"mnist_{}\".format(batch_idx))\n",
    "        db.register_dataset(data, \"mnist_data_{}\".format(batch_idx), disable=True)\n",
    "        db.register_dataset(pred, \"mnist_preds_{}\".format(batch_idx), disable=True)\n",
    "        db.register_dataset(output.reshape((len(pred), 1, 10)), \"mnist_ops_{}\".format(batch_idx), disable=True)\n",
    "\n",
    "        q = (Query('satisfies_rule_sum10', \"mnist_data_{}\".format(batch_idx))\n",
    "             .join(\"mnist_preds_{}\".format(batch_idx))\n",
    "             .join(\"mnist_ops_{}\".format(batch_idx), key=lambda idx, *row: idx[-1])\n",
    "             .project(add_row_sum, batch_size=256)\n",
    "             .project(add_rule_term, batch_size=256))\n",
    "        res_tab = q(db, disable=True).project(lambda *row: row[3], disable=True)\n",
    "\n",
    "        # Assumes each row is a per-sample tensor (test_with_rules sums them the same way).\n",
    "        # torch.cat keeps the autograd graph; torch.tensor(rows, requires_grad=True)\n",
    "        # would build a new leaf that is disconnected from ``model``.\n",
    "        rule_terms = [torch.as_tensor(r).reshape(-1) for r in res_tab.rows]\n",
    "        rule_loss = torch.cat(rule_terms).sum() if rule_terms else output.new_zeros(())\n",
    "\n",
    "        if rule_only:\n",
    "            loss = rule_loss\n",
    "        else:\n",
    "            # The combined objective ``nll + rule_lambda * rule_loss`` is disabled for now.\n",
    "            loss = F.nll_loss(output, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if batch_idx % 100 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tRule Loss: {:.3f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item(), rule_loss.item()))\n",
    "        \n",
    "\n",
    "def test_with_rules(model, device, test_loader):\n",
    "    \"\"\"Evaluate ``model`` and print average NLL loss, average rule loss and accuracy.\n",
    "\n",
    "    Per batch, tql joins images, predicted labels and log-probabilities and the\n",
    "    rule column ``rule(label, row_sum) * log_prob[label]`` is summed; the total is\n",
    "    divided by the dataset size at the end.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    rule_loss = 0.0\n",
    "\n",
    "    def add_row_sum(imgs, labels, outputs):\n",
    "        return zip(imgs, labels, outputs, row_sum_batch(10)(imgs, labels))\n",
    "\n",
    "    def add_rule_term(imgs, label, outputs, sum_row_10):\n",
    "        # Flatten to (N, 10) so gather works whether tql yields (N, 10) or (N, 1, 10) outputs.\n",
    "        picked = outputs.reshape(len(label), -1).gather(1, label.reshape((len(label), 1)))\n",
    "        return zip(imgs, label, sum_row_10, rule(label, sum_row_10) * picked)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for i, (data, target) in enumerate(test_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "            db = Database(\"mnist_{}\".format(i))\n",
    "            db.register_dataset(data, \"mnist_data_{}\".format(i), disable=True)\n",
    "            db.register_dataset(pred, \"mnist_preds_{}\".format(i), disable=True)\n",
    "            db.register_dataset(output.reshape((len(pred), 1, 10)), \"mnist_ops_{}\".format(i), disable=True)\n",
    "\n",
    "            q = (Query('satisfies_rule_sum10', \"mnist_data_{}\".format(i))\n",
    "                 .join(\"mnist_preds_{}\".format(i))\n",
    "                 .join(\"mnist_ops_{}\".format(i), key=lambda idx, *row: idx[-1])\n",
    "                 .project(add_row_sum, batch_size=256)\n",
    "                 .project(add_rule_term, batch_size=256))\n",
    "            # Run the query once and keep only the rule column.\n",
    "            rule_column = q(db, disable=True).project(lambda *row: row[3], disable=True)\n",
    "            # float() accepts a one-element tensor as well as the int 0 sum() returns for an empty table.\n",
    "            rule_loss += float(sum(rule_column))\n",
    "\n",
    "    n_samples = len(test_loader.dataset)\n",
    "    test_loss /= n_samples\n",
    "    rule_loss /= n_samples\n",
    "    print('\\nTest set: Average loss: {:.4f}, Average rule loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, rule_loss, correct, n_samples, 100. * correct / n_samples))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f4d8d67f-164b-44cc-9972-fc8df6a1227a",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "index 10 is out of bounds for dimension 1 with size 1",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_371703/3976546880.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_with_rules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/tmp/ipykernel_371703/3178372673.py\u001b[0m in \u001b[0;36mtest_with_rules\u001b[0;34m(model, device, test_loader)\u001b[0m\n\u001b[1;32m    153\u001b[0m             \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow_sum_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    154\u001b[0m             \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfoo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m             \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    156\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    157\u001b[0m             \u001b[0mres\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1192\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1195\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1196\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/model-debugging/tql/query.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, database, **kwargs)\u001b[0m\n\u001b[1;32m    368\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcached\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdatabase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    369\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mdatabase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 370\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatabase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    371\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    372\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/model-debugging/tql/query.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, database, **kwargs)\u001b[0m\n\u001b[1;32m    350\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpipeline\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"No operations registered. Register operations using the register, join, filter, cols, order_by, and group_by functions\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    351\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"At least one table should be registered as the base when running the query\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 352\u001b[0;31m         \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdatabase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_pipeline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpipeline\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    353\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcached\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    354\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/model-debugging/tql/database.py\u001b[0m in \u001b[0;36mexecute_pipeline\u001b[0;34m(self, pipeline, name, **kwargs)\u001b[0m\n\u001b[1;32m    153\u001b[0m                     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mop_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    154\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m                     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mop_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    156\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    157\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/model-debugging/tql/table.py\u001b[0m in \u001b[0;36mproject\u001b[0;34m(self, cols, batch_size, disable, shuffle)\u001b[0m\n\u001b[1;32m    301\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    302\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mrow_batch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Projecting\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdisable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 303\u001b[0;31m                 \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcols\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mrow_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    304\u001b[0m                 \u001b[0mprojected_rows\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    305\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_371703/3178372673.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(imgs, labels, outputs)\u001b[0m\n\u001b[1;32m    151\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msum_row_10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msum_row_10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 153\u001b[0;31m             \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow_sum_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    154\u001b[0m             \u001b[0mq\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfoo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    155\u001b[0m             \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1192\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1195\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1196\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_371703/3178372673.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, image_tensor, label)\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimage_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m         \u001b[0;31m#print(image_tensor.shape, label.shape)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mimage_tensor\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mRule_Row_10\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mIndexError\u001b[0m: index 10 is out of bounds for dimension 1 with size 1"
     ]
    }
   ],
   "source": [
    "test_with_rules(model, device, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83355eeb-4521-468b-953e-7d6115cd87d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rules + original loss: train with the standard NLL loss, driving each epoch\n",
    "# through a TQL projection query over the train table.\n",
    "model = Net().to('cuda')\n",
    "model = nn.DataParallel(model)\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=1.0)\n",
    "scheduler = StepLR(optimizer, step_size=1, gamma=0.7)\n",
    "\n",
    "def forward_train_query(model, device, table, optimizer, db, batch_size=256):\n",
    "    \"\"\"Run one training epoch over `table` by projecting it through the model.\n",
    "\n",
    "    The projection calls `get_loss` on batches of rows; each call performs one\n",
    "    optimizer step.\n",
    "\n",
    "    Returns:\n",
    "        (total_loss, accuracy): the sum of per-batch NLL losses as a Python\n",
    "        float, and the fraction of rows whose argmax prediction equals the target.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "\n",
    "    def get_loss(imgs, tgts):\n",
    "        \"\"\"Take one optimizer step on a batch and return (img, pred, target) rows.\"\"\"\n",
    "        nonlocal total_loss\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(imgs.to(device))\n",
    "        loss = F.nll_loss(outputs, tgts.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # .item() yields a plain float; summing the loss tensor itself would chain\n",
    "        # every batch's autograd history together for the whole epoch.\n",
    "        total_loss += loss.item()\n",
    "        # Return the un-moved inputs and CPU predictions so the result table does\n",
    "        # not hold a device copy of every batch.\n",
    "        preds = outputs.argmax(dim=1).cpu()\n",
    "        return zip(imgs, preds, tgts)\n",
    "\n",
    "    forward_pass = Query(\"forward\", base=table).project(get_loss)(db, batch_size=batch_size)\n",
    "    n_correct = len(forward_pass.filter(lambda imgs, pred, tgt: pred == tgt))\n",
    "    return total_loss, n_correct / len(forward_pass)\n",
    "\n",
    "for epoch in range(1, 5):\n",
    "    closs, tr_acc = forward_train_query(model, device, \"train\", optimizer, db, batch_size=256)\n",
    "    print(f\"epoch {epoch}: train loss {closs:.4f}, train acc {tr_acc:.4f}\")\n",
    "    test_with_rules(model, device, test_loader)\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670a5391-1ed0-42e7-9cfc-57c0aaee97bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rules only: fresh model trained with train_with_rules(rule_only=True), using a\n",
    "# much smaller Adadelta learning rate than the combined-loss run.\n",
    "model = nn.DataParallel(Net().to('cuda'))\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=0.0001)\n",
    "scheduler = StepLR(optimizer, step_size=1, gamma=0.7)\n",
    "\n",
    "for epoch in range(1, 10):\n",
    "    train_with_rules(\n",
    "        model, device, train_loader, optimizer, epoch,\n",
    "        rule_lambda=1.0, rule_only=True,\n",
    "    )\n",
    "    test_with_rules(model, device, test_loader)\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03dab4ca-25d5-4ea9-89f1-b042fe31fa41",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
