{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33eecf77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the libraries \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d2e3d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import *\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "to_pil_image = transforms.ToPILImage()\n",
    "def image_to_vid(images):\n",
    "    imgs = [np.array(to_pil_image(img)) for img in images]\n",
    "    imageio.mimsave('../outputs/generated_images.gif', imgs)\n",
    "def save_reconstructed_imagesVAE(recon_images, epoch):\n",
    "    save_image(recon_images.cpu(), f\"VAE{epoch}.jpg\")\n",
    "def save_reconstructed_imagesEVAE(recon_images, epoch):\n",
    "    save_image(recon_images.cpu(), f\"EVAE{epoch}.jpg\")\n",
    "def save_reconstructed_imagesREAL(images, epoch):\n",
    "    save_image(images.cpu(), f\"REAL{epoch}.jpg\")\n",
    "def save_train_loss_plot(train_loss, valid_loss):\n",
    "    # loss plots\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    plt.plot(train_loss, color='orange', label='train loss VAE')\n",
    "    plt.plot(valid_loss, color='red', label='train loss EVAE')\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Total loss (Reconstruction+B*I(K))/M')\n",
    "    plt.legend()\n",
    "    plt.savefig('lossMNIST.jpg')\n",
    "    plt.show()\n",
    "def save_valid_loss_plot(valid_lossVAE, valid_lossEVAE):\n",
    "    # loss plots\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    plt.plot(valid_lossVAE, color='orange', label='validation loss VAE')\n",
    "    plt.plot(valid_lossEVAE, color='red', label='validation loss EVAE')\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Reconstruction loss')\n",
    "    plt.legend()\n",
    "    plt.savefig('lossMNIST_valid.jpg')\n",
    "    plt.show()\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d7c481",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "83637226",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self,image_channel=1,kernel_size=4,latent_dim=64,init_channel=32,m=100):\n",
    "        super(VAE, self).__init__()\n",
    "\n",
    "        self.bm=np.power(m,-2/9)\n",
    "\n",
    "        # encoder\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel,out_channels=init_channel,kernel_size=kernel_size,stride=2,padding=1),            \n",
    "            nn.BatchNorm2d(init_channel),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel*2,out_channels=init_channel*4,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.Conv2d(init_channel*4,out_channels=init_channel*8,kernel_size=kernel_size,stride=2,padding=0)   ,            \n",
    "            nn.BatchNorm2d(init_channel*8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "            #nn.AdaptiveAvgPool2d((100,init_channel*8)),\n",
    "            nn.Flatten( start_dim=1)\n",
    "            \n",
    "            #nn.Linear(init_channel*8,init_channel*8) #hidden\n",
    "            \n",
    "        )\n",
    "        \n",
    "        \n",
    "\n",
    "        ## fully connected laeyer for learning representation\n",
    "        self.fully_connected_layer_a=nn.Linear(init_channel*8,latent_dim)\n",
    "        self.fully_connected_layer_b=nn.Linear(init_channel*8,latent_dim)\n",
    "        self.fully_connected_layer_beta=nn.Linear(init_channel*8,latent_dim)\n",
    "\n",
    "        self.fully_connected_layer=nn.Linear(latent_dim,init_channel*8)\n",
    "        ## decoder\n",
    "        \n",
    "                # encoder\n",
    "        self.decoder = nn.Sequential(\n",
    "            \n",
    "            \n",
    "            \n",
    "            nn.ConvTranspose2d(init_channel*8,out_channels=init_channel*4,kernel_size=kernel_size,stride=1,padding=0),            \n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*4,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*2,out_channels=init_channel*1,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*1) ,\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*1,out_channels=image_channel,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        \n",
    "    def reparametrize(self,mu,log_var):\n",
    "        \n",
    "        std=torch.exp(0.5*log_var)\n",
    "        epsilon=torch.randn_like(std)\n",
    "        sample=mu+epsilon*std  # sampling\n",
    "        return sample        \n",
    " \n",
    "\n",
    "\n",
    "    def forward(self,x):\n",
    "        #encoding\n",
    "\n",
    "        hidden =self.encoder(x)\n",
    "        \n",
    "        #hidden=torch.flatten(x, start_dim=1)\n",
    "     \n",
    "        mean=self.fully_connected_layer_a(hidden)\n",
    "        log_var=self.fully_connected_layer_b(hidden)\n",
    "\n",
    "        z=self.reparametrize(mean,log_var) #uniform distribution\n",
    "\n",
    "        z=z.to(device)\n",
    "        \n",
    "        z=self.fully_connected_layer(z)\n",
    "      \n",
    "        z=z.view(-1,256,1,1)\n",
    "        #decoding\n",
    "        reconstruction=self.decoder(z)\n",
    "\n",
    "        return reconstruction,mean,log_var "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a44dcb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossVAE(bce_loss, mu, logvar):\n",
    "    \"\"\"\n",
    "    This function will add the reconstruction loss (BCELoss) and the \n",
    "    KL-Divergence.\n",
    "    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    :param bce_loss: recontruction loss\n",
    "    :param mu: the mean from the latent vector\n",
    "    :param logvar: log variance from the latent vector\n",
    "    \"\"\"\n",
    "\n",
    "    BCE = bce_loss \n",
    "    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "    return BCE + KLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bc27c73f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def final_loss(bce_loss,mu,log_var):\n",
    "#     #print(torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0))\n",
    "#     return bce_loss+torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n",
    "mse_cost_function = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6203b932",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_trainVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    model.train()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    for i, data in tqdm(enumerate(dataloader),total=int(len(dataset)/dataloader.batch_size)):\n",
    "        counter+=1\n",
    "        data=data[0]\n",
    "        data=data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction,mean,log_var  =model(data)\n",
    "        \n",
    "\n",
    "        reconstruction_error=criterion(reconstruction,data)\n",
    "        # Combining the loss functions\n",
    "        #print(reconstruction_error)\n",
    "        loss =final_lossVAE( reconstruction_error,mean,log_var  )#mse_u + mse_f+reconstruction_error+mse_fv#+mse_F_f#+reconstruction_manifold#+mse_ulb+mse_uub\n",
    "        #loss=final_loss(bce_loss,mu,log_var)\n",
    "        #print(loss)\n",
    "        loss.backward()\n",
    "        running_loss+=loss.item()\n",
    "        optimizer.step()\n",
    "    #print(mean)\n",
    "    #print(log_var)\n",
    "    train_loss=running_loss/counter\n",
    "    return train_loss\n",
    "\n",
    "def model_validateVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    model.eval()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    with torch.no_grad():\n",
    "        for i,data in tqdm(enumerate(dataloader),total=int(len(dataset)/dataloader.batch_size)):\n",
    "            counter+=1\n",
    "            data=data[0]\n",
    "            data=data.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            #reconstruction,mu,log_var=model(data)\n",
    "            \n",
    "            reconstruction,mean,log_var =model(data)\n",
    "            #bce_loss= criterion(reconstruction,data)\n",
    "            \n",
    "\n",
    "            reconstruction_error=criterion(reconstruction,data)\n",
    "\n",
    "            #values.append(fid.compute())\n",
    "            \n",
    "\n",
    "            loss=reconstruction_error#final_loss(reconstruction_error,mean,log_var)\n",
    "            running_loss+=loss.item()\n",
    "            \n",
    "\n",
    "        valid_loss=running_loss/counter\n",
    "        #print(counter)\n",
    "        return valid_loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e816c1d",
   "metadata": {},
   "source": [
    "##EVAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6f379ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EVAE(nn.Module):\n",
    "    def __init__(self,image_channel=1,kernel_size=4,latent_dim=64,init_channel=32,m=100,B=1):\n",
    "        super(EVAE, self).__init__()\n",
    "\n",
    "        self.bm=np.power(m,-2/9)\n",
    "        self.B=B\n",
    "        # encoder\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel,out_channels=init_channel,kernel_size=kernel_size,stride=2,padding=1),            \n",
    "            nn.BatchNorm2d(init_channel),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel*2,out_channels=init_channel*4,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.Conv2d(init_channel*4,out_channels=init_channel*8,kernel_size=kernel_size,stride=2,padding=0)   ,            \n",
    "            nn.BatchNorm2d(init_channel*8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "            #nn.AdaptiveAvgPool2d((100,init_channel*8)),\n",
    "            nn.Flatten( start_dim=1)\n",
    "            \n",
    "            #nn.Linear(init_channel*8,init_channel*8) #hidden\n",
    "            \n",
    "        )\n",
    "        \n",
    "        \n",
    "\n",
    "        ## fully connected laeyer for learning representation\n",
    "        self.fully_connected_layer_a=nn.Linear(init_channel*8,latent_dim)\n",
    "        self.fully_connected_layer_b=nn.Linear(init_channel*8,latent_dim)\n",
    "        \n",
    "        ## decoder\n",
    "        \n",
    "                # encoder\n",
    "        self.decoder = nn.Sequential(\n",
    "            \n",
    "\n",
    "            nn.Linear(latent_dim, init_channel*8 *1*1),\n",
    "            nn.Unflatten(dim=-1, unflattened_size=(init_channel*8, 1,1)),     # 1 x 1    \n",
    "            \n",
    "            nn.ConvTranspose2d(init_channel*8,out_channels=init_channel*4,kernel_size=kernel_size,stride=1,padding=0),            \n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*4,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*2,out_channels=init_channel*1,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*1) ,\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*1,out_channels=image_channel,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        \n",
    "        \n",
    " \n",
    "\n",
    "    def reparametrize(self,mu,log_r):\n",
    "        \n",
    "        coef= self.bm #b(m)\n",
    "        n=log_r.size(dim=0)\n",
    "        d=log_r.size(dim=1)\n",
    "        B=self.B\n",
    "        #beta=torch.exp(0.5*log_beta)\n",
    "        uniform_k=torch.rand(n*d,3).to(device)\n",
    "        U_k=2*uniform_k-1\n",
    "        \n",
    "        gaussian_p=torch.randn(n,d).to(device)\n",
    "        #U_p=2*uniform_p-1  #gaussian prior\n",
    "        \n",
    "        median=torch.median(U_k,1)[0].reshape(n,d) #sample from E kernel\n",
    "        sample=coef*(torch.exp(0.5*log_r)*median+mu)+torch.exp(0.5*log_r)*gaussian_p\n",
    "        return sample\n",
    "    def forward(self,x):\n",
    "        #encoding\n",
    "\n",
    "        hidden =self.encoder(x)\n",
    "        \n",
    "        #hidden=torch.flatten(x, start_dim=1)\n",
    "        #hidden=self.seq_l(hidden)\n",
    "     \n",
    "        mean=self.fully_connected_layer_a(hidden)\n",
    "        log_r=self.fully_connected_layer_b(hidden)\n",
    "\n",
    "\n",
    "        \n",
    "        zs=self.reparametrize(mean,log_r)\n",
    "\n",
    "        #decoding\n",
    "        reconstruction=self.decoder(zs)\n",
    "\n",
    "        return reconstruction,mean,log_r,zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "79c70590",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossEVAE(bce_loss,log_r):\n",
    "    M=log_r.size(dim=1)\n",
    "    N=log_r.size(dim=0)\n",
    "    return bce_loss\n",
    "\n",
    "def model_trainEVAE(model,dataloader,dataset,device,optimizer,criterion,B):\n",
    "    model.train()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    for i, data in tqdm(enumerate(dataloader),total=int(len(dataset)/dataloader.batch_size)):\n",
    "        counter+=1\n",
    "        data=data[0]\n",
    "        data=data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction,mean,log_r,z=model(data)\n",
    "\n",
    "        bce_loss= criterion(reconstruction,data)\n",
    "        \n",
    "        loss=final_lossEVAE(bce_loss,log_r)\n",
    "        #print(loss)\n",
    "        loss.backward()\n",
    "        running_loss+=loss.item()\n",
    "        optimizer.step()\n",
    "    train_loss=running_loss/counter\n",
    "    #print(log_var)\n",
    "    return train_loss\n",
    "\n",
    "def model_validateEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    model.eval()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    with torch.no_grad():\n",
    "        for i,data in tqdm(enumerate(dataloader),total=int(len(dataset)/dataloader.batch_size)):\n",
    "            counter+=1\n",
    "            data=data[0]\n",
    "            data=data.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            reconstruction,*_=model(data)\n",
    "            bce_loss= criterion(reconstruction,data)\n",
    "            loss=bce_loss#final_loss(bce_loss,mu,log_var)\n",
    "            running_loss+=loss.item()\n",
    "\n",
    "        valid_loss=running_loss/counter\n",
    "        return valid_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a445ce36",
   "metadata": {},
   "source": [
    "## dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0f8a070e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: ./\n",
       "    Split: Train"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torchvision.datasets.MNIST('./',download=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788b980d",
   "metadata": {},
   "source": [
    "# training VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b72a8856",
   "metadata": {},
   "outputs": [],
   "source": [
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "latent_dim=64\n",
    "VAEmodel=VAE(latent_dim=latent_dim).to(device)\n",
    "EVAEmodel=EVAE(latent_dim=latent_dim).to(device)\n",
    "lr=0.0003\n",
    "epochs=50\n",
    "batch_size=100\n",
    "\n",
    "transform=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])\n",
    "\n",
    "#training set transforms.Resize((32,32)),\n",
    "train_set=torchvision.datasets.MNIST(root='./',train=True,download=False,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)\n",
    "\n",
    "#test set\n",
    "test_set=torchvision.datasets.MNIST(root='./',train=False,download=False,transform=transform)\n",
    "test_loader=torch.utils.data.DataLoader(test_set,batch_size=batch_size,shuffle=True)\n",
    "optimizerVAE=optim.Adam(VAEmodel.parameters(),lr=lr)\n",
    "optimizerEVAE=optim.Adam(EVAEmodel.parameters(),lr=lr)\n",
    "criterion=nn.BCELoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2994894",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_imagesVAE=[]\n",
    "train_lossVAE=[]\n",
    "valid_lossVAE=[]\n",
    "\n",
    "grid_imagesEVAE=[]\n",
    "train_lossEVAE=[]\n",
    "valid_lossEVAE=[]\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch{epoch+1} of {epochs}\")\n",
    "    train_epoch_lossVAE=model_trainVAE(VAEmodel,train_loader,train_set,device,optimizerVAE,criterion)\n",
    "    valid_epoch_lossVAE=model_validateVAE(VAEmodel,test_loader,test_set,device,optimizerVAE,criterion)\n",
    "    train_lossVAE.append(train_epoch_lossVAE)\n",
    "    valid_lossVAE.append(valid_epoch_lossVAE)\n",
    "    \n",
    "    train_epoch_lossEVAE=model_trainEVAE(EVAEmodel,train_loader,train_set,device,optimizerEVAE,criterion,B=0.1)\n",
    "    valid_epoch_lossEVAE=model_validateEVAE(EVAEmodel,test_loader,test_set,device,optimizerEVAE,criterion)\n",
    "    train_lossEVAE.append(train_epoch_lossEVAE)\n",
    "    valid_lossEVAE.append(valid_epoch_lossEVAE)\n",
    "  #  mus.append(mu)\n",
    "   # varss.append(logvar.exp())\n",
    "#     #save images recon\n",
    "#     save_reconstructed_imagesVAE(recon_imagesVAE,epoch+1)\n",
    "#     save_reconstructed_imagesEVAE(recon_imagesEVAE,epoch+1)\n",
    "#     save_reconstructed_imagesREAL(real_images,epoch+1)\n",
    "    #grid_images.append(image_grid)\n",
    "    print(f\"train loss:{train_epoch_lossVAE:.4f}\")\n",
    "    print(f\"valid loss:{valid_epoch_lossVAE:.4f}\")\n",
    "    print(f\"train loss:{train_epoch_lossEVAE:.4f}\")\n",
    "    print(f\"valid loss:{valid_epoch_lossEVAE:.4f}\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b764d9e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
