{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cdf347c1-6090-46fe-bfae-098eea52af1d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import manifold\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "device='cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "49eada7d-9a8b-4046-865c-6676c22a3a8e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Define a Convolutional Autoencoder model\n",
    "class ConvAutoencoder(nn.Module):\n",
    "    def __init__(self, embedding_dim=8):\n",
    "        super(ConvAutoencoder, self).__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=0, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1),\n",
    "              nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=0, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=0),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "        self.embedding_encoder = nn.Sequential(\n",
    "            nn.Linear(128 * 2 * 2, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, embedding_dim)\n",
    "            )\n",
    "        self.embedding_decoder = nn.Sequential(\n",
    "            nn.Linear(embedding_dim, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 128 * 2 * 2)\n",
    "            )\n",
    "        self.embedding_classifier = nn.Sequential(\n",
    "            nn.Linear(embedding_dim,64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64,64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64,62))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.encoder(x)\n",
    "        x = x.view(x.size(0), -1)  # Flatten        \n",
    "        embedding = self.embedding_encoder(x)\n",
    "        logits = self.embedding_classifier(embedding)\n",
    "        #embedding = embedding/torch.norm(embedding,p=2,dim=-1,keepdim=True)\n",
    "        xhat = self.decoder(self.embedding_decoder(embedding).view(x.size(0), 128, 2, 2))\n",
    "        return xhat, embedding,logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3ffe03ba-a6bb-433e-8bca-1bbc7269bcef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size=128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5ccc4e3a-dc8a-4b2a-b43e-e0937905c4ca",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ConvAutoencoder(\n",
       "  (encoder): Sequential(\n",
       "    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (5): ReLU()\n",
       "    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (7): ReLU()\n",
       "    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (9): ReLU()\n",
       "    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
       "  )\n",
       "  (decoder): Sequential(\n",
       "    (0): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n",
       "    (5): ReLU()\n",
       "    (6): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (7): ReLU()\n",
       "    (8): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))\n",
       "    (9): ReLU()\n",
       "    (10): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (11): Sigmoid()\n",
       "  )\n",
       "  (embedding_encoder): Sequential(\n",
       "    (0): Linear(in_features=512, out_features=512, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=512, out_features=6, bias=True)\n",
       "  )\n",
       "  (embedding_decoder): Sequential(\n",
       "    (0): Linear(in_features=6, out_features=512, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
       "  )\n",
       "  (embedding_classifier): Sequential(\n",
       "    (0): Linear(in_features=6, out_features=64, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=64, out_features=64, bias=True)\n",
       "    (3): ReLU()\n",
       "    (4): Linear(in_features=64, out_features=62, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transform = transforms.ToTensor()\n",
    "emnist_data = torchvision.datasets.EMNIST(root='./data', train=True,split='byclass', download=True, transform=transform)\n",
    "emnist_test_loader = DataLoader(emnist_data, batch_size=batch_size, shuffle=False)\n",
    "# Load your trained autoencoder\n",
    "autoencoder = ConvAutoencoder(embedding_dim=6).to(device)\n",
    "autoencoder.load_state_dict(torch.load('AE_EMNIST_1.pt'))  # Load your model\n",
    "autoencoder.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "102b4863-01c1-41e1-b68d-5b92540af8e0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embeddings are calculated\n"
     ]
    }
   ],
   "source": [
    "autoencoder.eval()\n",
    "test_encode, test_targets,X_list,label_list = [], [],[],[]\n",
    "for x_val, y_val in emnist_test_loader:\n",
    "    x_val = x_val.to(device)\n",
    "\n",
    "    xhat,zhat,_ = autoencoder(x_val)\n",
    "    # yhat = model.decoder(zhat)\n",
    "    test_encode.append(zhat.detach())\n",
    "    test_targets.append(y_val.detach())\n",
    "    X_list.append(zhat.detach().numpy())\n",
    "    label_list.append(y_val.detach().numpy())\n",
    "X_list=np.vstack(X_list)\n",
    "label_list=np.concatenate(label_list)\n",
    "EMNIST=(X_list,label_list)\n",
    "torch.save(EMNIST,'../data/EMNIST.pt')\n",
    "\n",
    "print('Embeddings are calculated')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3b90472e-9db8-44f1-8c60-dca53eb09b04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1cdc02d3-2200-4967-b0bf-47dcf3ae792d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "cat(): argument 'tensors' (position 1) must be tuple of Tensors, not numpy.ndarray",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[23], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m selected_labels\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m20\u001b[39m,\u001b[38;5;241m10\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m test_encode \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(test_encode)\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m      3\u001b[0m test_targets \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(test_targets)\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m      5\u001b[0m \u001b[38;5;66;03m# Select a subset of classes\u001b[39;00m\n",
      "\u001b[0;31mTypeError\u001b[0m: cat(): argument 'tensors' (position 1) must be tuple of Tensors, not numpy.ndarray"
     ]
    }
   ],
   "source": [
    "selected_labels=np.random.randint(0,20,10)\n",
    "test_encode = torch.cat(test_encode).cpu().numpy()\n",
    "test_targets = torch.cat(test_targets).cpu().numpy()\n",
    "\n",
    "# Select a subset of classes\n",
    "selected_classes = np.random.randint(0,20,10)  # Replace with your chosen class indices\n",
    "mask = np.isin(test_targets, selected_classes)\n",
    "\n",
    "# Filter the data\n",
    "z_subset = test_encode[mask]\n",
    "Y_subset = test_targets[mask]\n",
    "\n",
    "# Apply t-SNE to the subset\n",
    "tsne = manifold.TSNE(n_components=2, init=\"pca\", random_state=0)\n",
    "X_2d_subset = tsne.fit_transform(z_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1a864f-1382-4e68-b40f-de1bfdd03f82",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9dbb1a-f8bb-44b6-948c-6a6272779a1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 10))\n",
    "\n",
    "LABELS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', \n",
    "          'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',\n",
    "          'a', 'b', 'd', 'e', 'f', 'g', 'h', 'n', 'q', 'r', 't']\n",
    "\n",
    "# Iterate over each class in the selected_classes and plot them separately\n",
    "for class_index in np.unique(Y_subset):\n",
    "    # Select data points that belong to the current class\n",
    "    indices = Y_subset == class_index\n",
    "    plt.scatter(X_2d_subset[indices, 0], X_2d_subset[indices, 1], label=f' Labels[class_index]', s=1)\n",
    "plt.legend(bbox_to_anchor=(0.63, 0.6), loc=\"upper left\")\n",
    "\n",
    "# Adding legend\n",
    "#plt.legend()\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40aa5ab8-ecaa-4971-b77d-77bbe8e2daef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
