{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdf347c1-6090-46fe-bfae-098eea52af1d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import manifold\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "device='cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49eada7d-9a8b-4046-865c-6676c22a3a8e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class ConvAutoencoder(nn.Module):\n",
    "    \"\"\"Convolutional autoencoder with a fully connected bottleneck.\n",
    "\n",
    "    A stack of six convolutions produces a 128-channel feature map, which is\n",
    "    flattened and reduced to `embedding_dim` values by a two-layer MLP. A\n",
    "    mirrored MLP and six transposed convolutions map the embedding back to a\n",
    "    single-channel image squashed into (0, 1) by a sigmoid. The linear layers\n",
    "    hard-code a 128 x 2 x 2 feature map, which is what the encoder produces\n",
    "    for 28 x 28 inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding_dim=8):\n",
    "        \"\"\"Build the layers; `embedding_dim` is the width of the bottleneck.\"\"\"\n",
    "        super().__init__()\n",
    "        flat_features = 128 * 2 * 2  # length of the flattened encoder output\n",
    "\n",
    "        # Arguments equal to the PyTorch defaults (stride=1, padding=0) are omitted.\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, kernel_size=3),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(128, 128, kernel_size=3, stride=2),\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, output_padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, 1, kernel_size=3),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "        # MLP pair that squeezes the flattened features to the embedding and back.\n",
    "        self.embedding_encoder = nn.Sequential(\n",
    "            nn.Linear(flat_features, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, embedding_dim),\n",
    "        )\n",
    "        self.embedding_decoder = nn.Sequential(\n",
    "            nn.Linear(embedding_dim, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, flat_features),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `(reconstruction, embedding)` for a batch of images `x`.\"\"\"\n",
    "        batch = x.size(0)\n",
    "        features = self.encoder(x).view(batch, -1)  # one flat vector per sample\n",
    "        embedding = self.embedding_encoder(features)\n",
    "        # L2 normalisation of the embedding is intentionally left disabled:\n",
    "        # embedding = embedding / torch.norm(embedding, p=2, dim=-1, keepdim=True)\n",
    "        unflattened = self.embedding_decoder(embedding).view(batch, 128, 2, 2)\n",
    "        xhat = self.decoder(unflattened)\n",
    "        return xhat, embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ccc4e3a-dc8a-4b2a-b43e-e0937905c4ca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "transform = transforms.ToTensor()\n",
    "\n",
    "# NOTE: despite the loader's name, this is the EMNIST 'byclass' *training* split (train=True).\n",
    "emnist_data = torchvision.datasets.EMNIST(\n",
    "    root='./data', train=True, split='byclass', download=True, transform=transform\n",
    ")\n",
    "# shuffle=False keeps the batches in dataset order.\n",
    "emnist_test_loader = DataLoader(emnist_data, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Restore the trained autoencoder; embedding_dim has to match the checkpoint.\n",
    "autoencoder = ConvAutoencoder(embedding_dim=6).to(device)\n",
    "# map_location remaps tensors that were saved on another device (e.g. a GPU)\n",
    "# onto `device`, so the checkpoint also loads on a CPU-only machine.\n",
    "autoencoder.load_state_dict(torch.load('AE_EMNIST_0.pt', map_location=device))\n",
    "autoencoder.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "102b4863-01c1-41e1-b68d-5b92540af8e0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Embed every image with the trained autoencoder and collect the matching labels.\n",
    "autoencoder.eval()\n",
    "test_encode, test_targets = [], []\n",
    "# Inference only: no_grad() stops autograd from recording a graph for each batch.\n",
    "with torch.no_grad():\n",
    "    for x_val, y_val in emnist_test_loader:\n",
    "        x_val = x_val.to(device)\n",
    "\n",
    "        # Only the embedding (second output) is kept; the reconstruction is unused.\n",
    "        xhat, zhat = autoencoder(x_val)\n",
    "        # .cpu() is a no-op on the CPU and keeps the NumPy conversion valid on other devices.\n",
    "        test_encode.append(zhat.cpu().numpy())\n",
    "        test_targets.append(y_val.cpu().numpy())\n",
    "\n",
    "# Stack the per-batch arrays into one embedding matrix and one label vector.\n",
    "X_list = np.vstack(test_encode)\n",
    "label_list = np.concatenate(test_targets)\n",
    "EMNIST = (X_list, label_list)\n",
    "# Stored as an (embeddings, labels) tuple of NumPy arrays.\n",
    "torch.save(EMNIST, '../data/EMNIST.pt')\n",
    "print('Embeddings are calculated')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b90472e-9db8-44f1-8c60-dca53eb09b04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# test_encode / test_targets are lists of NumPy arrays, which torch.cat rejects,\n",
    "# so they are joined with NumPy. New names are used (instead of overwriting the\n",
    "# lists) so that this cell can be re-run safely.\n",
    "z_all = np.vstack(test_encode)\n",
    "y_all = np.concatenate(test_targets)\n",
    "\n",
    "# Select a subset of classes: 10 distinct labels out of the first 20, drawn from a\n",
    "# seeded generator so the figure is reproducible.\n",
    "rng = np.random.default_rng(0)\n",
    "selected_classes = rng.choice(20, size=10, replace=False)  # Replace with your chosen class indices\n",
    "mask = np.isin(y_all, selected_classes)\n",
    "\n",
    "# Filter the data\n",
    "z_subset = z_all[mask]\n",
    "Y_subset = y_all[mask]\n",
    "\n",
    "# Apply t-SNE to the subset\n",
    "tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)\n",
    "X_2d_subset = tsne.fit_transform(z_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4c82f5-cbe6-4484-8fa9-0f3f8979efe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cdc02d3-2200-4967-b0bf-47dcf3ae792d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Scatter the 2-D t-SNE coordinates, one colour and legend entry per class.\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "\n",
    "for label in np.unique(Y_subset):\n",
    "    in_class = Y_subset == label  # boolean mask of the points with this label\n",
    "    ax.scatter(\n",
    "        X_2d_subset[in_class, 0],\n",
    "        X_2d_subset[in_class, 1],\n",
    "        label=f'Class {label}',\n",
    "        s=1,\n",
    "    )\n",
    "\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "777acaa1-ab41-4149-be2b-9265a8adfb9c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
