{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6cc5f23e-d04b-4e23-afdf-47f3b1c3efcb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building table: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 6695.76it/s]\n",
      "Building table: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60000/60000 [00:08<00:00, 6809.02it/s]\n"
     ]
    }
   ],
   "source": [
    "import os, time\n",
    "from random import seed, sample\n",
    "\n",
    "import tqdm\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from tql import Database, Query, Table\n",
    "from tql.sqrl import SQRL\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "\n",
    "\n",
    "test_data = datasets.MNIST(\n",
    "        root = 'data',\n",
    "        train = False,\n",
    "        transform = transform,\n",
    "        download = True,\n",
    "    )\n",
    "\n",
    "train_data = datasets.MNIST(\n",
    "        root = 'data',\n",
    "        train = True,\n",
    "        transform = transform,\n",
    "        download = True,\n",
    "    )\n",
    "\n",
    "\n",
    "db = Database(\"mnist\")\n",
    "db.register_dataset(test_data, \"test\")\n",
    "db.register_dataset(train_data, \"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "47ef78e1-246e-4cb6-9172-39dc7f2222e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# defining the base model\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output\n",
    "\n",
    "class MNISTNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MNISTNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        # torch.nn.init.kaiming_normal_(self.conv1.weight)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        # torch.nn.init.kaiming_normal_(self.conv2.weight)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        # torch.nn.init.kaiming_normal_(self.fc1.weight)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "        # torch.nn.init.kaiming_normal_(self.fc2.weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = x\n",
    "        return output\n",
    "    \n",
    "class SumNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SumNet, self).__init__()\n",
    "        self.base = MNISTNet()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.base(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        # output = torch.sum(x, dim=1)\n",
    "        return output\n",
    "    \n",
    "    \n",
    "class QuerySumNet(nn.Module):\n",
    "    def __init__(self, db: Database):\n",
    "        super(QuerySumNet, self).__init__()\n",
    "        self.base = MNISTNet()\n",
    "        self.db = db\n",
    "        self.query = Query(\"op_sum\", base=\"op\").project(lambda x: F.log_softmax(x, dim=1))\n",
    "    \n",
    "    def forward(self, ip):\n",
    "        x = self.base(ip)\n",
    "        \n",
    "        x = self.query(self.db, tensors={\"op\": x.view(-1, 10)}, batch_size=len(x), disable=True).rows\n",
    "        \n",
    "        output = torch.stack(x)\n",
    "        \n",
    "        return output\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "17b2bcac-a593-413b-9ff5-469f7ba64668",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 235/235 [00:06<00:00, 37.06it/s]\n",
      "Projecting: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60000/60000 [00:00<00:00, 1249159.84it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[(tensor(2),),\n",
       " (tensor(8),),\n",
       " (tensor(8),),\n",
       " (tensor(7),),\n",
       " (tensor(3),),\n",
       " (tensor(6),),\n",
       " (tensor(4),),\n",
       " (tensor(4),),\n",
       " (tensor(7),),\n",
       " (tensor(2),)]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(100)\n",
    "\n",
    "gt_model = QuerySumNet(db)\n",
    "\n",
    "gt = Query(\"random_gt\", base=\"train\").project(lambda img, label: zip(img, torch.argmax(gt_model(img), dim=1)))(db, batch_size=256)\n",
    "gt.project(lambda img, label: label).head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8ff95455-7cd9-4c33-b042-4b1e17af31af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training loop for query\n",
    "\n",
    "device='cuda'\n",
    "\n",
    "def train_model(model, db, gt_table_name, device, optimizer, batch_size):\n",
    "    model.train()\n",
    "    \n",
    "    global closs\n",
    "    closs = 0\n",
    "    \n",
    "    def get_loss(imgs, tgts):\n",
    "        optimizer.zero_grad()\n",
    "        imgs, tgts = imgs.to(device), tgts.to(device)\n",
    "        ops = model(imgs)\n",
    "        # print(imgs.shape)\n",
    "        # print(ops.shape)\n",
    "        loss = F.nll_loss(ops, tgts)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        global closs, csum\n",
    "        closs += loss\n",
    "        \n",
    "        return zip(imgs, torch.argmax(ops, dim=1), tgts)\n",
    "    \n",
    "    forward_pass = Query(\"forward\", base=gt_table_name).project(get_loss)(db, batch_size=batch_size)\n",
    "    return closs, (len(forward_pass.filter(lambda imgs, pred, tgt: pred == tgt, disable=True)) / len(forward_pass))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d9c6269e-e978-4610-a8e6-9d82fc3b2bee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training for epoch 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:09<00:00, 12.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.85395, Loss: 55.29526901245117\n",
      "Training for epoch 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:09<00:00, 12.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9698166666666667, Loss: 12.407113075256348\n",
      "Training for epoch 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:09<00:00, 12.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.97705, Loss: 9.259622573852539\n",
      "Training for epoch 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:09<00:00, 12.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9811833333333333, Loss: 7.708920478820801\n",
      "Training for epoch 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:08<00:00, 14.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9830333333333333, Loss: 6.863566875457764\n",
      "Training for epoch 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:08<00:00, 14.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9837666666666667, Loss: 6.510781764984131\n",
      "Training for epoch 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:08<00:00, 14.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9849166666666667, Loss: 6.2134599685668945\n",
      "Training for epoch 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:08<00:00, 14.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9851, Loss: 5.977859020233154\n",
      "Training for epoch 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projecting: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:08<00:00, 14.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Accuracy: 0.9851833333333333, Loss: 6.101978778839111\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(42)\n",
    "torch.autograd.set_detect_anomaly(True)\n",
    "\n",
    "query_model = QuerySumNet(db).to('cuda')\n",
    "optimizer_q = optim.Adadelta(query_model.parameters(), lr=1.0)\n",
    "scheduler_q = StepLR(optimizer_q, step_size=1, gamma=0.6)\n",
    "\n",
    "for epoch in range(1, 10):\n",
    "    print(f\"Training for epoch {epoch}\")\n",
    "\n",
    "    closs_q, acc_q = train_model(query_model, db, \"train\", \"cuda\", optimizer_q, 512)\n",
    "    print(f\"Query: Accuracy: {acc_q}, Loss: {closs_q}\")\n",
    "\n",
    "    scheduler_q.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a820be09-e915-4365-9859-ccb06c42ecb2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3a73176-0f42-489b-ae28-4f75d171f6b4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
