{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchsummary\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import Module, Sequential, Linear, ReLU, Tanh, MaxPool2d, Conv2d, ConvTranspose2d, BatchNorm2d\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(Module):\n",
    "    \"\"\"Exploratory conv autoencoder over a square grid built from a flat vector.\n",
    "\n",
    "    NOTE(review): a later cell in this notebook redefines ``CNN``; once that\n",
    "    cell runs, its definition shadows this one.\n",
    "\n",
    "    Args:\n",
    "        dim: length of the flat input vector. ``n = int(dim ** 0.5)`` is the\n",
    "            side of the square grid that ``linear_in`` projects onto.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        # Zero-argument super() does not look up the global name ``CNN``,\n",
    "        # which the later redefinition rebinds to a different class.\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim = dim\n",
    "        self.n = int(self.dim ** 0.5)\n",
    "\n",
    "        self.linear_in = Linear(self.dim, self.n**2)\n",
    "        # NOTE(review): linear_out is not applied in forward(). Its in_features,\n",
    "        # n*(n - 2) + 1 == (n - 1)**2, does not match the decoder output\n",
    "        # (1 x 17 x 17 = 289 values for dim=2601, versus 2500).\n",
    "        self.linear_out = Linear(self.n*(self.n - 2)+1, self.dim)\n",
    "\n",
    "        # The first conv (padding=3) grows the grid by 4 pixels (51 -> 55 for\n",
    "        # dim=2601); each MaxPool2d(2, 2) then roughly halves it.\n",
    "        self.encoder = Sequential(\n",
    "            Conv2d(1, 4, kernel_size=3, padding=3),\n",
    "            nn.BatchNorm2d(4),\n",
    "            ReLU(True),\n",
    "            MaxPool2d(2, 2),\n",
    "            Conv2d(4, 32, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(32),\n",
    "            ReLU(True),\n",
    "            MaxPool2d(2, 2),\n",
    "            Conv2d(32, 128, kernel_size=3, padding=1),\n",
    "        )\n",
    "\n",
    "        # The first transposed conv (padding=1) keeps the grid size; the two\n",
    "        # unpadded ones each add 2 pixels (13 -> 15 -> 17 for dim=2601).\n",
    "        self.decoder = Sequential(\n",
    "            ConvTranspose2d(128, 32, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm2d(32),\n",
    "            ReLU(True),\n",
    "            ConvTranspose2d(32, 4, kernel_size=3),\n",
    "            nn.BatchNorm2d(4),\n",
    "            ReLU(True),\n",
    "            ConvTranspose2d(4, 1, kernel_size=3),\n",
    "            Tanh())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Project ``x`` onto an n x n grid and return the decoded image.\n",
    "\n",
    "        Args:\n",
    "            x: tensor of shape [batch_size, dim].\n",
    "\n",
    "        Returns:\n",
    "            The decoder output of shape [batch_size, 1, H, W] (H = W = 17 for\n",
    "            dim=2601); it is not projected back to ``dim`` features.\n",
    "        \"\"\"\n",
    "        n = self.n\n",
    "        # Reduce dimensionality to form a grid: [batch_size, dim] -> [batch_size, 1, n, n]\n",
    "        x = self.linear_in(x)\n",
    "        cnn_input = x.reshape(-1, n, n)[:, None, :]\n",
    "        encoded = self.encoder(cnn_input)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 2601])\n",
      "torch.Size([2, 2601])\n",
      "torch.Size([2, 1, 51, 51])\n",
      "torch.Size([2, 1, 17, 17]) 289 2500\n",
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Linear-1                 [-1, 2601]       6,767,802\n",
      "            Conv2d-2            [-1, 4, 55, 55]              40\n",
      "       BatchNorm2d-3            [-1, 4, 55, 55]               8\n",
      "              ReLU-4            [-1, 4, 55, 55]               0\n",
      "         MaxPool2d-5            [-1, 4, 27, 27]               0\n",
      "            Conv2d-6           [-1, 32, 27, 27]           1,184\n",
      "       BatchNorm2d-7           [-1, 32, 27, 27]              64\n",
      "              ReLU-8           [-1, 32, 27, 27]               0\n",
      "         MaxPool2d-9           [-1, 32, 13, 13]               0\n",
      "           Conv2d-10          [-1, 128, 13, 13]          36,992\n",
      "  ConvTranspose2d-11           [-1, 32, 13, 13]          36,896\n",
      "      BatchNorm2d-12           [-1, 32, 13, 13]              64\n",
      "             ReLU-13           [-1, 32, 13, 13]               0\n",
      "  ConvTranspose2d-14            [-1, 4, 15, 15]           1,156\n",
      "      BatchNorm2d-15            [-1, 4, 15, 15]               8\n",
      "             ReLU-16            [-1, 4, 15, 15]               0\n",
      "  ConvTranspose2d-17            [-1, 1, 17, 17]              37\n",
      "             Tanh-18            [-1, 1, 17, 17]               0\n",
      "================================================================\n",
      "Total params: 6,844,251\n",
      "Trainable params: 6,844,251\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.01\n",
      "Forward/backward pass size (MB): 1.21\n",
      "Params size (MB): 26.11\n",
      "Estimated Total Size (MB): 27.33\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "cnn = CNN(2601)\n",
    "\n",
    "# torchsummary defaults to device=\"cuda\" and then builds CUDA inputs whenever\n",
    "# a GPU is available; the model here stays on the CPU, so pin the device.\n",
    "torchsummary.summary(cnn, input_size=(2601,), device=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 29, 29])"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape sanity check (stride 1): side = 25 + 2*padding - kernel_size + 1 = 25 + 6 - 3 + 1 = 29\n",
    "Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=3)(torch.rand(1, 25, 25)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(Module):\n",
    "    \"\"\"Conv autoencoder mapping a flat ``dim`` vector to a ``dim`` vector.\n",
    "\n",
    "    ``linear_in`` projects the input onto an n x n grid (``n = int(dim ** 0.5)``),\n",
    "    the encoder/decoder process it as a 1-channel image, and ``linear_out``\n",
    "    maps the flattened decoder output back to ``dim`` features.\n",
    "\n",
    "    Args:\n",
    "        dim: length of the flat input and output vectors.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        # Zero-argument super() does not look up the global name ``CNN``,\n",
    "        # which is defined more than once in this notebook.\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim = dim\n",
    "        self.n = int(self.dim ** 0.5)\n",
    "\n",
    "        self.linear_in = Linear(self.dim, self.n**2)\n",
    "\n",
    "        self.encoder = Sequential(\n",
    "            Conv2d(1, 32, kernel_size=3, padding=1),\n",
    "            ReLU(True),\n",
    "            MaxPool2d(2, 2),\n",
    "            Conv2d(32, 64, kernel_size=3, padding=1),\n",
    "            ReLU(True),\n",
    "            MaxPool2d(2, 2),\n",
    "            Conv2d(64, 128, kernel_size=3, padding=1),\n",
    "            ReLU(True))\n",
    "\n",
    "        self.decoder = Sequential(\n",
    "            ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),\n",
    "            ReLU(True),\n",
    "            ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1),\n",
    "            ReLU(True),\n",
    "            ConvTranspose2d(32, 1, kernel_size=4, padding=1),\n",
    "            Tanh())\n",
    "\n",
    "        # The decoded side length depends on n and on the number of pooling\n",
    "        # stages (24 x 24 for dim=2601 with two MaxPool2d layers), so a\n",
    "        # hard-coded in_features such as n*(n - 2) + 1 == (n - 1)**2 breaks\n",
    "        # forward() as soon as the architecture changes. Measure it instead.\n",
    "        with torch.no_grad():\n",
    "            probe = torch.zeros(1, 1, self.n, self.n)\n",
    "            decoded_features = self.decoder(self.encoder(probe)).numel()\n",
    "        self.linear_out = Linear(decoded_features, self.dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map ``x`` of shape [batch_size, dim] to a tensor of shape [batch_size, dim].\"\"\"\n",
    "        n = self.n\n",
    "        # Reduce dimensionality to form a grid: [batch_size, dim] -> [batch_size, 1, n, n]\n",
    "        grid = self.linear_in(x).reshape(-1, 1, n, n)\n",
    "        decoded = self.decoder(self.encoder(grid))\n",
    "        # Recover dimensionality: [batch_size, 1, H, W] -> [batch_size, dim]\n",
    "        return self.linear_out(decoded.flatten(start_dim=1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn = CNN(2601)\n",
    "\n",
    "# torchsummary defaults to device=\"cuda\" and then builds CUDA inputs whenever\n",
    "# a GPU is available; the model here stays on the CPU, so pin the device.\n",
    "torchsummary.summary(cnn, input_size=(2601,), device=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input dimension used above: a 51 x 51 grid, flattened.\n",
    "51 ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "firedrake",
   "language": "python",
   "name": "firedrake"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
