{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import pc_e\n",
    "\n",
    "from datamodules import EMNIST as MyDataModule\n",
    "from get_arch import get_architecture\n",
    "from custom_callbacks import ErrorConvergenceCallback\n",
    "\n",
    "from lightning import Trainer\n",
    "from lightning.pytorch.loggers import WandbLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 518,
     "referenced_widgets": [
      "edb0d3e75d4d42d9b95b3ed163cc9f9b",
      "4e159612454242df9cacca54cba58e9f",
      "402fd128704c48b3904c9f494dda0896",
      "fa3ae80d5a4f41d48ff16b111c8f5c55",
      "a6abbf039aa544d9a51b8eb06dec6de2",
      "bef11c875d4d4939a57ac45b8ec3a1cd",
      "74e980cb3e094ab6a12a6869b2abfccc",
      "772711bab2854b64b378becfbdab8ccb",
      "8faa48effcf24fec9441dd5deb8d43bf",
      "e0212cf5f87a4f03b0939b3ca6850bce",
      "d5dd9d32d94340f6ad58b9745010864c",
      "1ac733579b6c4d8699533fea53c5f153",
      "f036496c33f847b7b9fe663bd8ff9570",
      "8d8ef98cdbcc4977a7d60364e932e5d8",
      "3557f1a0affd4fc7bd5268adfa32af44",
      "f6114d15557a4aef93fbf779cdb1997e",
      "03d3af91c6ef4e04b2d169dd3a0cb7d4",
      "83dcedb024f64a9d85fcaed775ba1d49",
      "311abec69664416cbfb6d1545fc3cfcf",
      "a53079022f0d4f4091ad84cd3ac3dd90",
      "df82df3832e347d5b7c8a030e0977d44",
      "3a7f48d25ad74459a5120a5a9cf7709c",
      "f086af1c26cc4fba9027763bba332f77",
      "7a9bc10b5b4c464fba521c029a1e6da6",
      "b011e5dc68d74cb49ca4e24e0db9b51e",
      "9ffc68c487844c3fb94b542e03de3bd8",
      "4694098d432847158fb337ec76564641",
      "86eb49972f254bc0a1875e850a131759",
      "c243d99cae0c4d76ae4c41875f09a53f",
      "4121557f07a04cd19778f3dd37514438",
      "188805d523a84a1c83ad325a52b27d9f",
      "1c6abb592d554be4a52817437a6a00db",
      "10c2312b240c4a1f98b6c77bfed8ce6e",
      "6086a5a7dc4d4432bfdc949e413fa84b",
      "e284b6d633bd4d63a569030700e1f99a",
      "85542ba39d13415b89e6af9fbe214c2f",
      "991df7ec1a274974b8e866b887883297",
      "dfefc8388c254320bf57ca7e6b9e4b7b",
      "31546413008d45b8855d56ce12b62e99",
      "14c7406558b04b70926fd61d0773e097",
      "93bd49fee19045278eadc81678233e59",
      "64401f010e374afb95757a27df0de60e",
      "13822cea6b1745e2b5b2f769868def73",
      "c0a4fd26d68a4ec7b82072e85ea3929a",
      "2b7d0df9631c4e7dacbb97190aa447c0",
      "e7ad62c762754452af9debf6d991eca4",
      "6f194edb74a64a878ae4f6b7f56d8e96",
      "74bcbe689e9d4136a3f078f0905f894e",
      "535edc2e49424ffaab0b5f6c2d23caef",
      "7204bcbab9d543fdad92caab57bff964",
      "604675ed8e964509bc5eedc32b75f409",
      "5cf835f387f341fdae187940c354142b",
      "0c26259762fa40d5ac55d28e32ab61d7",
      "fb01801bcff1400da252469ce23202e0",
      "3c9b41ee5e30405ba91f0bada742304d",
      "54d3388364e7428d9a2c786501c9ee3d",
      "2941723e1089478eb35f4418ee54e892",
      "10248902688040e596ed8a651dfbed1b",
      "7a4232fe1274444f86c5b4ae534ad9d1",
      "dad46dd2c81d4b57bf098895624373ad",
      "dd35fc3a9e1a45bd8d82dbbe548ddaf3",
      "d87047437436469a9b9683047780754e",
      "07f7afc1352f40d9892f5bbf0015543c",
      "2202be93b0d54964becc18ef0a70769f",
      "c747efed23aa4f41ba4bc7e35ef7e643",
      "832a6e746b0b4cbd89fb4a225e4d9ca2"
     ]
    },
    "id": "oYEs9RZQMu_D",
    "outputId": "2566d589-18bb-460a-e748-56c0053297d9"
   },
   "outputs": [],
   "source": [
    "importlib.reload(pc_e)\n",
    "from pc_e import PCE\n",
    "\n",
    "\n",
    "# 0: load dataset as Lightning DataModule\n",
    "\n",
    "batch_size = 64\n",
    "datamodule = MyDataModule(batch_size)\n",
    "print(\"Training on\", datamodule.dataset_name)\n",
    "\n",
    "# 1: Set up Lightning trainer\n",
    "# logger = False\n",
    "logger = WandbLogger(project=\"XXX\", entity=\"XXX\", mode=\"online\")\n",
    "\n",
    "trainer = Trainer(\n",
    "    accelerator=\"gpu\",\n",
    "    devices=1,\n",
    "    logger=logger,\n",
    "    callbacks=[ErrorConvergenceCallback()],\n",
    "    max_epochs=8,\n",
    "    inference_mode=False,  # inference_mode would interfere with the state backward pass\n",
    "    limit_predict_batches=1,  # enable 1-batch prediction\n",
    ")\n",
    "\n",
    "# 2: Get architecture that belongs to this dataset\n",
    "architecture = get_architecture(dataset=datamodule.dataset_name)\n",
    "\n",
    "# 3: Initiate model and train it\n",
    "pc = PCE(architecture, iters=16, e_lr=0.1, w_lr=0.001)\n",
    "trainer.fit(pc, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 220,
     "referenced_widgets": [
      "e486251f9ee5432185e05546c8078e17",
      "cf35bbc27f4843de94b387c24107b906",
      "7088147f9e4b43b290e0245afc2d7b64",
      "dc960bb695f440fb882626a868b6e0aa",
      "64fa667e5237403b8ed54e3102bdf61b",
      "bfe8fd585db94dac8c2318def102f25f",
      "e87eb11e10304ee696c4f11e1b854222",
      "4a8412e4828c49e1848993881adcfce4",
      "c9b20badde544365a29f3e6ee0ebbc17",
      "c8a1012fcc1c4feda5ea5adc81e225ce",
      "a93ee649652f4e90aabc8b4dbadd87a5"
     ]
    },
    "id": "3Qrjmi4sZ-bK",
    "outputId": "a0e04c82-cb1e-40cc-c0cd-f9d388f3d2dd"
   },
   "outputs": [],
   "source": [
    "trainer.test(pc, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 499
    },
    "id": "t7JauwTwPKwZ",
    "outputId": "224862b7-b544-4b30-c5a6-9d138cc898e0"
   },
   "outputs": [],
   "source": [
    "# pc.predict_targets = {\"color\", \"shape\", \"angle\"}\n",
    "# pc.predict_targets = {\"img\"}\n",
    "pc.predict_targets = {\"y\"}\n",
    "trainer.callbacks.append(datamodule.prediction_callback())\n",
    "trainer.predict(pc, datamodule=datamodule)\n",
    "trainer.callbacks = trainer.callbacks[:-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "torchenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
