{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "33eecf77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the libraries \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1d2e3d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import *\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "to_pil_image = transforms.ToPILImage()\n",
    "def image_to_vid(images, path='../outputs/generated_images.gif'):\n",
    "    \"\"\"Write `images` as the frames of an animated GIF at `path`.\n",
    "\n",
    "    Each element is converted with transforms.ToPILImage and the frames are\n",
    "    written with PIL's GIF writer (save_all + append_images), so no extra\n",
    "    package is needed. The parent directory of `path` is created if missing.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `images` yields no frames.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    frames = [to_pil_image(img) for img in images]\n",
    "    if not frames:\n",
    "        raise ValueError('image_to_vid needs at least one image')\n",
    "    out_dir = os.path.dirname(path)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "    frames[0].save(path, save_all=True, append_images=frames[1:])\n",
    "def _save_tagged_image(tensor, tag, epoch):\n",
    "    \"\"\"Move `tensor` to the CPU and write it with save_image to '<tag><epoch>.jpg'.\"\"\"\n",
    "    save_image(tensor.cpu(), f'{tag}{epoch}.jpg')\n",
    "def save_reconstructed_imagesVAE(recon_images, epoch):\n",
    "    \"\"\"Save VAE reconstructions for `epoch` as VAE<epoch>.jpg.\"\"\"\n",
    "    _save_tagged_image(recon_images, 'VAE', epoch)\n",
    "def save_reconstructed_imagesEVAE(recon_images, epoch):\n",
    "    \"\"\"Save EVAE reconstructions for `epoch` as EVAE<epoch>.jpg.\"\"\"\n",
    "    _save_tagged_image(recon_images, 'EVAE', epoch)\n",
    "def save_reconstructed_imagesREAL(images, epoch):\n",
    "    \"\"\"Save the real input images for `epoch` as REAL<epoch>.jpg.\"\"\"\n",
    "    _save_tagged_image(images, 'REAL', epoch)\n",
    "def _save_two_curve_plot(first, second, first_label, second_label, ylabel, filename):\n",
    "    \"\"\"Draw two per-epoch curves (orange, then red), save to `filename`, then show.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(10, 7))\n",
    "    ax.plot(first, color='orange', label=first_label)\n",
    "    ax.plot(second, color='red', label=second_label)\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.legend()\n",
    "    fig.savefig(filename)\n",
    "    plt.show()\n",
    "def save_train_loss_plot(train_loss, valid_loss):\n",
    "    \"\"\"Plot VAE vs EVAE training-loss curves and save them to lossMNIST.jpg.\n",
    "\n",
    "    Despite its name, `valid_loss` is drawn as the EVAE *training* loss;\n",
    "    `train_loss` is drawn as the VAE training loss.\n",
    "    \"\"\"\n",
    "    _save_two_curve_plot(train_loss, valid_loss, 'train loss VAE', 'train loss EVAE',\n",
    "                         'Total loss (Reconstruction+B*I(K))/M', 'lossMNIST.jpg')\n",
    "def save_valid_loss_plot(valid_lossVAE, valid_lossEVAE):\n",
    "    \"\"\"Plot VAE vs EVAE validation reconstruction losses; saves lossMNIST_valid.jpg.\"\"\"\n",
    "    _save_two_curve_plot(valid_lossVAE, valid_lossEVAE, 'validation loss VAE', 'validation loss EVAE',\n",
    "                         'Reconstruction loss', 'lossMNIST_valid.jpg')\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d32eaae4-6198-4285-bf37-d258bcc63f99",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    \"\"\"Convolutional VAE with a diagonal-Gaussian posterior.\n",
    "\n",
    "    Encoder: four (Conv2d -> BatchNorm2d -> ReLU) stages with stride 2, then\n",
    "    Flatten. With the default kernel_size=4 a 32x32 input shrinks\n",
    "    32 -> 16 -> 8 -> 4 -> 1, so the flattened feature has init_channel*8\n",
    "    entries, which is what the two linear heads expect.\n",
    "    Decoder: the latent vector is viewed as a (latent_dim, 1, 1) map and\n",
    "    upsampled 1 -> 4 -> 8 -> 16 -> 32 by ConvTranspose2d stages, ending in a\n",
    "    Sigmoid so outputs lie in [0, 1].\n",
    "\n",
    "    Attribute names and Sequential layer indices are kept stable so saved\n",
    "    state_dicts load unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_channel=1, kernel_size=4, latent_dim=64, init_channel=32, m=100):\n",
    "        super(VAE, self).__init__()\n",
    "        # b(m) = m^(-2/9); not used by VAE's own methods, kept as an attribute.\n",
    "        self.bm = np.power(m, -2 / 9)\n",
    "        self.latent_dim = latent_dim\n",
    "        c, k = init_channel, kernel_size\n",
    "\n",
    "        # (in_channels, out_channels, padding) for each stride-2 encoder stage;\n",
    "        # the last stage uses padding=0 to reach a 1x1 map.\n",
    "        encoder_stages = [(image_channel, c, 1), (c, 2 * c, 1), (2 * c, 4 * c, 1), (4 * c, 8 * c, 0)]\n",
    "        encoder_layers = []\n",
    "        for c_in, c_out, pad in encoder_stages:\n",
    "            encoder_layers += [\n",
    "                nn.Conv2d(c_in, out_channels=c_out, kernel_size=k, stride=2, padding=pad),\n",
    "                nn.BatchNorm2d(c_out),\n",
    "                nn.ReLU(inplace=True),\n",
    "            ]\n",
    "        encoder_layers.append(nn.Flatten(start_dim=1))\n",
    "        self.encoder = nn.Sequential(*encoder_layers)\n",
    "\n",
    "        # Linear heads for the posterior mean (a) and log-variance (b).\n",
    "        self.fully_connected_layer_a = nn.Linear(8 * c, latent_dim)\n",
    "        self.fully_connected_layer_b = nn.Linear(8 * c, latent_dim)\n",
    "\n",
    "        # (in_channels, out_channels, stride, padding) for each hidden decoder stage.\n",
    "        decoder_stages = [(latent_dim, 4 * c, 1, 0), (4 * c, 2 * c, 2, 1), (2 * c, c, 2, 1)]\n",
    "        decoder_layers = []\n",
    "        for c_in, c_out, stride, pad in decoder_stages:\n",
    "            decoder_layers += [\n",
    "                nn.ConvTranspose2d(c_in, out_channels=c_out, kernel_size=k, stride=stride, padding=pad),\n",
    "                nn.BatchNorm2d(c_out),\n",
    "                nn.ReLU(inplace=True),\n",
    "            ]\n",
    "        decoder_layers += [\n",
    "            nn.ConvTranspose2d(c, out_channels=image_channel, kernel_size=k, stride=2, padding=1),\n",
    "            nn.Sigmoid(),\n",
    "        ]\n",
    "        self.decoder = nn.Sequential(*decoder_layers)\n",
    "\n",
    "    def reparametrize(self, mu, log_var):\n",
    "        \"\"\"Gaussian reparameterization: mu + eps * exp(0.5 * log_var), eps ~ N(0, I).\"\"\"\n",
    "        std = torch.exp(0.5 * log_var)\n",
    "        return mu + torch.randn_like(std) * std\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode, sample a latent, decode.\n",
    "\n",
    "        Returns (reconstruction, mean, log_var).\n",
    "        \"\"\"\n",
    "        features = self.encoder(x)\n",
    "        mean = self.fully_connected_layer_a(features)\n",
    "        log_var = self.fully_connected_layer_b(features)\n",
    "        z = self.reparametrize(mean, log_var)\n",
    "        reconstruction = self.decoder(z.view(-1, self.latent_dim, 1, 1))\n",
    "        return reconstruction, mean, log_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "bec7e1cf-d5d5-4120-ad67-d5bb4a690343",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EVAE(nn.Module):\n",
    "    \"\"\"VAE variant whose latent sample mixes an Epanechnikov-kernel draw and a uniform prior draw.\n",
    "\n",
    "    Encoder/decoder topology matches VAE (a 32x32 input reaches 1x1 with the\n",
    "    default kernel_size=4), except the decoder BatchNorm2d layers use\n",
    "    affine=False. fully_connected_layer_a gives the location `mu` and\n",
    "    fully_connected_layer_b gives `log_r`, the log of the squared scale used in\n",
    "    reparametrize. Attribute names and Sequential indices are unchanged so\n",
    "    existing state_dicts still load.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_channel=1, kernel_size=4, latent_dim=64, init_channel=32, m=100, B=1):\n",
    "        super(EVAE, self).__init__()\n",
    "        # b(m) = m^(-2/9): scale applied to every latent sample in reparametrize.\n",
    "        self.bm = np.power(m, -2 / 9)\n",
    "        # Stored hyper-parameter; not read by this class's methods.\n",
    "        self.B = B\n",
    "        self.latent_dim = latent_dim\n",
    "        c, k = init_channel, kernel_size\n",
    "\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel, out_channels=c, kernel_size=k, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(c),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(c, out_channels=c * 2, kernel_size=k, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(c * 2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(c * 2, out_channels=c * 4, kernel_size=k, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(c * 4),\n",
    "            nn.ReLU(inplace=True),\n",
    "            # padding=0 takes the 4x4 map down to 1x1 (with k=4).\n",
    "            nn.Conv2d(c * 4, out_channels=c * 8, kernel_size=k, stride=2, padding=0),\n",
    "            nn.BatchNorm2d(c * 8),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Flatten(start_dim=1),\n",
    "        )\n",
    "\n",
    "        # Heads for the location (a) and log squared scale (b) of the latent.\n",
    "        self.fully_connected_layer_a = nn.Linear(c * 8, latent_dim)\n",
    "        self.fully_connected_layer_b = nn.Linear(c * 8, latent_dim)\n",
    "\n",
    "        # Decoder: (latent_dim, 1, 1) -> (image_channel, 32, 32) with k=4.\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(latent_dim, out_channels=c * 4, kernel_size=k, stride=1, padding=0),\n",
    "            nn.BatchNorm2d(c * 4, affine=False),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ConvTranspose2d(c * 4, out_channels=c * 2, kernel_size=k, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(c * 2, affine=False),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ConvTranspose2d(c * 2, out_channels=c, kernel_size=k, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(c, affine=False),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ConvTranspose2d(c, out_channels=image_channel, kernel_size=k, stride=2, padding=1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def reparametrize(self, mu, log_r):\n",
    "        \"\"\"Draw one latent sample per entry of the 2-D tensors `mu` / `log_r` (shape (n, d)).\n",
    "\n",
    "        sample = b(m) * (r * K + mu) + b(m) * U_p * r, where r = exp(0.5 * log_r),\n",
    "        K is the median of three U(-1, 1) draws (density 3/4 * (1 - x^2) on\n",
    "        [-1, 1], i.e. the Epanechnikov kernel) and U_p ~ U(-1, 1).\n",
    "        Noise is created directly on the device and dtype of `log_r`, so the\n",
    "        method does not depend on a notebook-global `device`.\n",
    "        \"\"\"\n",
    "        # Plain Python float so type promotion follows log_r's dtype.\n",
    "        coef = float(self.bm)\n",
    "        n, d = log_r.size(0), log_r.size(1)\n",
    "        noise_kw = dict(device=log_r.device, dtype=log_r.dtype)\n",
    "        U_k = 2 * torch.rand(n * d, 3, **noise_kw) - 1\n",
    "        U_p = 2 * torch.rand(n, d, **noise_kw) - 1  # uniform prior draw\n",
    "        K = U_k.median(dim=1).values.reshape(n, d)  # Epanechnikov-kernel draw\n",
    "        r = torch.exp(0.5 * log_r)\n",
    "        return coef * (r * K + mu) + coef * U_p * r\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode, sample a latent, decode.\n",
    "\n",
    "        Returns (reconstruction, mean, log_r, zs), where zs is the 2-D latent\n",
    "        sample that is reshaped to (-1, latent_dim, 1, 1) for the decoder.\n",
    "        \"\"\"\n",
    "        hidden = self.encoder(x)\n",
    "        mean = self.fully_connected_layer_a(hidden)\n",
    "        log_r = self.fully_connected_layer_b(hidden)\n",
    "        zs = self.reparametrize(mean, log_r)\n",
    "        reconstruction = self.decoder(zs.view(-1, self.latent_dim, 1, 1))\n",
    "        return reconstruction, mean, log_r, zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "46fe3b8b-efee-4ec4-9257-8a4d86b33f28",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "latent_dim = 16  # shared latent size of both autoencoders\n",
    "\n",
    "# EVAE is built before VAE (same construction, hence same RNG, order as before).\n",
    "EVAEmodel = EVAE(latent_dim=latent_dim, m=100).to(device)\n",
    "VAEmodel = VAE(latent_dim=latent_dim).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d7c481",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "PATH = \"EVAEmodelMNIST_state_dict_dz8.pth\"\n",
    "PATHvae = \"VAEmodelMNIST_state_dict_dz16.pth\"\n",
    "PATHae = \"AEmodelMNIST_state_dict_dz8.pth\"\n",
    "# NOTE(review): PATH / PATHae are presumably latent-size-8 checkpoints while\n",
    "# latent_dim is 16 above; confirm before loading them into these models.\n",
    "# EVAEmodel.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))\n",
    "# EVAEmodel.eval()\n",
    "# map_location lets a GPU-saved checkpoint load on a CPU-only kernel;\n",
    "# weights_only=True refuses to unpickle arbitrary objects from the file.\n",
    "VAEmodel.load_state_dict(torch.load(PATHvae, map_location=device, weights_only=True))\n",
    "VAEmodel.eval()  # pretrained encoder is used below as a fixed feature extractor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "id": "206a59cf-bccf-4ecd-8eb3-6c16524c1213",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downstream classifier over latent vectors\n",
    "class Classifier(nn.Module):\n",
    "    \"\"\"MLP that maps a latent vector to class logits.\n",
    "\n",
    "    Three Linear -> BatchNorm1d -> GELU stages (widths 2w, 4w, 8w) followed by a\n",
    "    Linear to `num_classes` outputs. No softmax is applied; the notebook pairs\n",
    "    it with nn.CrossEntropyLoss, which expects raw logits.\n",
    "\n",
    "    Args:\n",
    "        ngpu: stored as self.ngpu; not otherwise used.\n",
    "        in_features: input width. None falls back to the notebook-level `nz`\n",
    "            (read at construction time), so Classifier(ngpu) keeps working.\n",
    "        width: base hidden width w. None falls back to the notebook-level `ngf`.\n",
    "        num_classes: number of output logits (10 by default).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ngpu, in_features=None, width=None, num_classes=10):\n",
    "        super(Classifier, self).__init__()\n",
    "        self.ngpu = ngpu\n",
    "        if in_features is None:\n",
    "            in_features = nz\n",
    "        if width is None:\n",
    "            width = ngf\n",
    "        self.main = nn.Sequential(\n",
    "            nn.Linear(in_features, width * 2),\n",
    "            nn.BatchNorm1d(width * 2),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(width * 2, width * 4),\n",
    "            nn.BatchNorm1d(width * 4),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(width * 4, width * 8),\n",
    "            nn.BatchNorm1d(width * 8),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(width * 8, num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Return raw logits whose last dimension is num_classes.\"\"\"\n",
    "        return self.main(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "id": "446e9f3e-e302-4e19-b215-2e969a86a450",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of the latent vector fed to Classifier; must match the encoder's latent_dim.\n",
    "nz = 16\n",
    "\n",
    "# Base hidden width of Classifier (its hidden layers use 2x, 4x and 8x this).\n",
    "ngf = 64\n",
    "\n",
    "# Not read by any cell shown in this notebook.\n",
    "ndf = 64\n",
    "\n",
    "# Stored on Classifier as self.ngpu; not otherwise used.\n",
    "ngpu = 1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c84cbd41-53dc-4986-a414-80a60850766d",
   "metadata": {},
   "source": [
    "## Downstream classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "id": "efad5c6b-4a9a-4f71-91d8-b29f4e23cfc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_loss(bce_loss):\n",
    "    \"\"\"Identity hook for composing a total loss; returns `bce_loss` unchanged.\"\"\"\n",
    "    total = bce_loss\n",
    "    return total\n",
    "\n",
    "def model_train_cls(model, dataloader, dataset, device, optimizer, criterion, feature_model=None):\n",
    "    \"\"\"Train the latent-space classifier `model` for one epoch; return the mean batch loss.\n",
    "\n",
    "    Each image batch is mapped to features by the pretrained autoencoder\n",
    "    `feature_model` (defaults to the notebook-level VAEmodel):\n",
    "    fully_connected_layer_a(encoder(image)), i.e. the posterior mean.\n",
    "    The extractor is put in eval mode and run under torch.no_grad(), so only\n",
    "    `model` is trained and no gradients accumulate on the extractor.\n",
    "\n",
    "    Args:\n",
    "        model: classifier whose parameters `optimizer` updates.\n",
    "        dataloader: yields batches whose [0] is images and [1] is labels.\n",
    "        dataset: kept for interface compatibility; not used.\n",
    "        device: device each batch is moved to.\n",
    "        optimizer: optimizer over `model`'s parameters.\n",
    "        criterion: loss on (logits, labels), e.g. nn.CrossEntropyLoss.\n",
    "        feature_model: nn.Module exposing `encoder` and `fully_connected_layer_a`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` yields no batches.\n",
    "    \"\"\"\n",
    "    if feature_model is None:\n",
    "        feature_model = VAEmodel\n",
    "    feature_model.eval()\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    n_batches = 0\n",
    "    for batch in tqdm(dataloader, total=len(dataloader)):\n",
    "        image = batch[0].to(device)\n",
    "        label = batch[1].to(device)\n",
    "        with torch.no_grad():\n",
    "            latent_mu = feature_model.fully_connected_layer_a(feature_model.encoder(image))\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(model(latent_mu), label)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        n_batches += 1\n",
    "    if n_batches == 0:\n",
    "        raise ValueError('dataloader yielded no batches')\n",
    "    return running_loss / n_batches\n",
    "\n",
    "def model_validate_cls(model, dataloader, dataset, device, optimizer, criterion, feature_model=None):\n",
    "    \"\"\"Return the classification accuracy of `model` on `dataloader`, in percent.\n",
    "\n",
    "    Features are the posterior means fully_connected_layer_a(encoder(image))\n",
    "    of `feature_model` (defaults to the notebook-level VAEmodel). The count of\n",
    "    correct predictions is divided by len(dataset), so the loader is expected\n",
    "    to cover the whole dataset once.\n",
    "\n",
    "    Args:\n",
    "        optimizer, criterion: accepted for signature symmetry; not used.\n",
    "        dataset: its length is the accuracy denominator.\n",
    "        feature_model: nn.Module exposing `encoder` and `fully_connected_layer_a`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataset` is empty.\n",
    "    \"\"\"\n",
    "    if feature_model is None:\n",
    "        feature_model = VAEmodel\n",
    "    if len(dataset) == 0:\n",
    "        raise ValueError('cannot compute accuracy on an empty dataset')\n",
    "    feature_model.eval()\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dataloader, total=len(dataloader)):\n",
    "            image = batch[0].to(device)\n",
    "            label = batch[1].to(device)\n",
    "            latent_mu = feature_model.fully_connected_layer_a(feature_model.encoder(image))\n",
    "            # softmax is monotonic, so argmax over raw logits gives the same class.\n",
    "            predicted_class = model(latent_mu).argmax(dim=1)\n",
    "            correct += (predicted_class == label).sum().item()\n",
    "    return 100 * correct / len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "id": "39fe9819-8411-4e50-9ce6-02756d1b3096",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)  # reproducible classifier init and batch shuffling\n",
    "epochs = 50\n",
    "batch_size = 100\n",
    "lr = 0.0003\n",
    "# The encoder's flatten size assumes 32x32 inputs, so MNIST digits are resized.\n",
    "transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])\n",
    "\n",
    "# download=True only fetches the files when they are not already under root.\n",
    "train_set = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# Evaluation order does not affect accuracy, so the test loader is not shuffled.\n",
    "test_set = torchvision.datasets.MNIST(root='./', train=False, download=True, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "netG = Classifier(ngpu).to(device)  # classifier; optimizer only sees its parameters\n",
    "optimizer = optim.Adam(netG.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f91b7c66-c6a1-4dee-b2cf-f2dc0f05d04e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Per-epoch history of the classifier trained on frozen VAE encoder features.\n",
    "train_loss_cls = []\n",
    "valid_acc_cls = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch{epoch+1} of {epochs}\")\n",
    "\n",
    "    train_epoch_loss_cls = model_train_cls(netG, train_loader, train_set, device, optimizer, criterion)\n",
    "    valid_epoch_acc_cls = model_validate_cls(netG, test_loader, test_set, device, optimizer, criterion)\n",
    "    train_loss_cls.append(train_epoch_loss_cls)\n",
    "    valid_acc_cls.append(valid_epoch_acc_cls)\n",
    "\n",
    "    print(f\"train loss:{train_epoch_loss_cls:.4f}\")\n",
    "    # model_validate_cls returns accuracy in percent, not a loss.\n",
    "    print(f\"valid accuracy:{valid_epoch_acc_cls:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac0931fd-a7b9-43b4-afa9-b1a15d95778b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
