{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "67cafed8-df12-464c-b7b9-dc568050b470",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "from sklearn import manifold\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "device='cpu'\n",
    "os.chdir('.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c876ceab-10e5-42f0-b1e9-ccce28ed06bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Convolutional autoencoder with a low-dimensional embedding bottleneck\n",
    "class ConvAutoencoder(nn.Module):\n",
    "    \"\"\"Convolutional autoencoder with a fully connected bottleneck.\n",
    "\n",
    "    For a 1 x 28 x 28 input the convolutional encoder yields a\n",
    "    128 x 2 x 2 feature map. Two linear layers compress it to\n",
    "    `embedding_dim` values; a mirrored linear stack followed by\n",
    "    transposed convolutions and a sigmoid rebuilds a 1 x 28 x 28 image.\n",
    "\n",
    "    Args:\n",
    "        embedding_dim: Number of values in the bottleneck embedding.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding_dim=8):\n",
    "        super().__init__()\n",
    "        # Attribute names, layer order and registration order are kept\n",
    "        # fixed so that saved state dicts keep matching this class.\n",
    "        conv_layers = [\n",
    "            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),\n",
    "        ]\n",
    "        self.encoder = nn.Sequential(*conv_layers)\n",
    "\n",
    "        deconv_layers = [\n",
    "            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=0, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=0, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=0),\n",
    "            nn.Sigmoid(),\n",
    "        ]\n",
    "        self.decoder = nn.Sequential(*deconv_layers)\n",
    "\n",
    "        # Flattened feature map -> embedding\n",
    "        self.embedding_encoder = nn.Sequential(\n",
    "            nn.Linear(128 * 2 * 2, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, embedding_dim),\n",
    "        )\n",
    "        # Embedding -> flattened feature map\n",
    "        self.embedding_decoder = nn.Sequential(\n",
    "            nn.Linear(embedding_dim, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 128 * 2 * 2),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `(xhat, embedding)` for a batch of images `x`.\"\"\"\n",
    "        n_samples = x.size(0)\n",
    "        # Flatten the convolutional feature map to one row per sample\n",
    "        features = self.encoder(x).view(n_samples, -1)\n",
    "        embedding = self.embedding_encoder(features)\n",
    "        # L2 normalisation of the embedding is intentionally left disabled.\n",
    "        expanded = self.embedding_decoder(embedding)\n",
    "        xhat = self.decoder(expanded.view(n_samples, 128, 2, 2))\n",
    "        return xhat, embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b9e2f8be-2c8b-4f9f-bf13-cb2a8b4f230a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Number of images per batch passed to the DataLoader\n",
    "batch_size=256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "20f3979c-d20f-4ba9-859c-90c6bb175f6c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ConvAutoencoder(\n",
       "  (encoder): Sequential(\n",
       "    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (5): ReLU()\n",
       "    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (7): ReLU()\n",
       "    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (9): ReLU()\n",
       "    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
       "  )\n",
       "  (decoder): Sequential(\n",
       "    (0): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n",
       "    (5): ReLU()\n",
       "    (6): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (7): ReLU()\n",
       "    (8): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))\n",
       "    (9): ReLU()\n",
       "    (10): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (11): Sigmoid()\n",
       "  )\n",
       "  (embedding_encoder): Sequential(\n",
       "    (0): Linear(in_features=512, out_features=512, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=512, out_features=4, bias=True)\n",
       "  )\n",
       "  (embedding_decoder): Sequential(\n",
       "    (0): Linear(in_features=4, out_features=512, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform = transforms.ToTensor()\n",
    "mnist_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "# NOTE: despite its name, this loader iterates the MNIST *training* split (train=True).\n",
    "# shuffle=False keeps the batches in dataset order.\n",
    "mnist_test_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=False)\n",
    "# Load the trained autoencoder weights.\n",
    "# map_location remaps tensors saved on another device (e.g. a GPU checkpoint)\n",
    "# onto `device`, so loading also works on a CPU-only machine.\n",
    "# torch.load unpickles the file: only load checkpoints from a trusted source.\n",
    "autoencoder = ConvAutoencoder(embedding_dim=4).to(device)\n",
    "autoencoder.load_state_dict(torch.load('AE_MNIST_0.pt', map_location=device))\n",
    "autoencoder.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d293da14-36c8-42d6-8035-c629ad28c771",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embeddings are calculated\n"
     ]
    }
   ],
   "source": [
    "# Encode every image served by the loader and collect embeddings and labels.\n",
    "autoencoder.eval()\n",
    "test_encode, test_targets = [], []\n",
    "# Inference only: no_grad avoids building an autograd graph for each batch.\n",
    "with torch.no_grad():\n",
    "    for x_val, y_val in mnist_test_loader:\n",
    "        x_val = x_val.to(device)\n",
    "        _, zhat = autoencoder(x_val)  # the reconstruction is not needed here\n",
    "        # .cpu() keeps the NumPy conversion valid whatever `device` is.\n",
    "        test_encode.append(zhat.cpu().numpy())\n",
    "        test_targets.append(y_val.cpu().numpy())\n",
    "X_list = np.vstack(test_encode)\n",
    "label_list = np.concatenate(test_targets)\n",
    "MNIST_data = (X_list, label_list)\n",
    "# Make sure the output directory exists, then save the tuple built above.\n",
    "os.makedirs('../data', exist_ok=True)\n",
    "torch.save(MNIST_data, '../data/MNIST.pt')\n",
    "\n",
    "\n",
    "print('Embeddings are calculated')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a2a2968f-de06-4c8b-af93-9be2cf81188d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 4)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the stacked embedding array: (number of images, embedding_dim)\n",
    "MNIST_data[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5400ea3f-4ce1-4dea-8bb7-548bd13f3141",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -2.0414,   0.8055, -48.8841,  16.4660],\n",
       "        [-14.5873,  25.7937, -25.8517,  63.5655],\n",
       "        [  8.2878,  22.7776,  39.0892,   4.5647],\n",
       "        ...,\n",
       "        [-54.7844, -29.8298, -16.5093,  30.3908],\n",
       "        [ 12.1600, -24.6894, -34.8477,  25.3909],\n",
       "        [ -8.7653, -20.8837, -67.9662,  12.9058]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the per-batch embeddings into one tensor with one row per image.\n",
    "# torch.vstack needs a sequence of tensors; calling it with no argument raises TypeError.\n",
    "A = torch.vstack([torch.as_tensor(batch) for batch in test_encode])\n",
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "987c4279-def6-48a6-8b6b-9890b12fd5d8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([96, 1, 28, 28])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the last batch left over from the loop; show the shape rather than dumping the tensor.\n",
    "x_val.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "11560bdd-249f-4516-8689-5d52b92bac9e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataloader.DataLoader at 0x7f5ad8051a50>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the DataLoader object's repr\n",
    "mnist_test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec2136c2-f025-4778-a94a-4d9396d57fc8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
