{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33eecf77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the libraries \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d2e3d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import *\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "to_pil_image = transforms.ToPILImage()\n",
    "def image_to_vid(images):\n",
    "    \"\"\"Write the given images, in order, to '../outputs/generated_images.gif'.\n",
    "\n",
    "    :param images: iterable of images accepted by ``to_pil_image``\n",
    "    \"\"\"\n",
    "    # imageio is used below but is not imported by any cell above, so a fresh\n",
    "    # kernel raised NameError here; import it locally to keep the helper self-contained.\n",
    "    import imageio\n",
    "\n",
    "    frames = [np.array(to_pil_image(img)) for img in images]\n",
    "    imageio.mimsave('../outputs/generated_images.gif', frames)\n",
    "def save_reconstructed_imagesVAE(recon_images, epoch):\n",
    "    \"\"\"Move ``recon_images`` to the CPU and write them with ``save_image`` to 'VAE{epoch}.jpg'.\"\"\"\n",
    "    cpu_batch = recon_images.cpu()\n",
    "    target = f\"VAE{epoch}.jpg\"\n",
    "    save_image(cpu_batch, target)\n",
    "def save_reconstructed_imagesEVAE(recon_images, epoch):\n",
    "    \"\"\"Move ``recon_images`` to the CPU and write them with ``save_image`` to 'EVAE{epoch}.jpg'.\"\"\"\n",
    "    cpu_batch = recon_images.cpu()\n",
    "    target = f\"EVAE{epoch}.jpg\"\n",
    "    save_image(cpu_batch, target)\n",
    "def save_reconstructed_imagesREAL(images, epoch):\n",
    "    \"\"\"Move ``images`` to the CPU and write them with ``save_image`` to 'REAL{epoch}.jpg'.\"\"\"\n",
    "    cpu_batch = images.cpu()\n",
    "    target = f\"REAL{epoch}.jpg\"\n",
    "    save_image(cpu_batch, target)\n",
    "def save_train_loss_plot(train_loss, valid_loss):\n",
    "    \"\"\"Plot two loss curves against epochs, save the figure to 'lossMNIST.jpg' and show it.\n",
    "\n",
    "    :param train_loss: sequence drawn in orange with the legend 'train loss VAE'\n",
    "    :param valid_loss: sequence drawn in red with the legend 'train loss EVAE'\n",
    "        (despite the parameter name, the legend labels it as a training curve)\n",
    "    \"\"\"\n",
    "    curves = (\n",
    "        (train_loss, 'orange', 'train loss VAE'),\n",
    "        (valid_loss, 'red', 'train loss EVAE'),\n",
    "    )\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    for series, colour, legend_text in curves:\n",
    "        plt.plot(series, color=colour, label=legend_text)\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Total loss (Reconstruction+B*I(K))/M')\n",
    "    plt.legend()\n",
    "    plt.savefig('lossMNIST.jpg')\n",
    "    plt.show()\n",
    "def save_valid_loss_plot(valid_lossVAE, valid_lossEVAE):\n",
    "    \"\"\"Plot both validation curves against epochs, save to 'lossMNIST_valid.jpg' and show.\n",
    "\n",
    "    :param valid_lossVAE: sequence drawn in orange with the legend 'validation loss VAE'\n",
    "    :param valid_lossEVAE: sequence drawn in red with the legend 'validation loss EVAE'\n",
    "    \"\"\"\n",
    "    curves = (\n",
    "        (valid_lossVAE, 'orange', 'validation loss VAE'),\n",
    "        (valid_lossEVAE, 'red', 'validation loss EVAE'),\n",
    "    )\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    for series, colour, legend_text in curves:\n",
    "        plt.plot(series, color=colour, label=legend_text)\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Reconstruction loss')\n",
    "    plt.legend()\n",
    "    plt.savefig('lossMNIST_valid.jpg')\n",
    "    plt.show()\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "83637226",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    \"\"\"Convolutional variational auto-encoder with a Gaussian latent code.\n",
    "\n",
    "    The encoder is four stride-2 convolutions (each followed by batch norm and\n",
    "    ReLU) and a flatten; two linear heads map the flattened features to the\n",
    "    latent mean and log-variance. The decoder is four transposed convolutions\n",
    "    ending in a sigmoid.\n",
    "\n",
    "    NOTE(review): the linear heads take ``init_channel*8`` input features, i.e.\n",
    "    they expect the encoder to end on a 1x1 map (true for 32x32 inputs) --\n",
    "    confirm before feeding other image sizes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,image_channel=1,kernel_size=4,latent_dim=64,init_channel=32,m=100):\n",
    "        super().__init__()\n",
    "\n",
    "        # bm is stored for parity with EVAE; no method of this class reads it.\n",
    "        self.bm = np.power(m, -2/9)\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        # Encoder plan: (in_channels, out_channels, padding) for each stride-2 stage.\n",
    "        # Layers are created in the same order as before so that state_dict keys\n",
    "        # and random initialisation are unchanged.\n",
    "        encoder_plan = (\n",
    "            (image_channel, init_channel, 1),\n",
    "            (init_channel, init_channel*2, 1),\n",
    "            (init_channel*2, init_channel*4, 1),\n",
    "            (init_channel*4, init_channel*8, 0),\n",
    "        )\n",
    "        encoder_layers = []\n",
    "        for c_in, c_out, pad in encoder_plan:\n",
    "            encoder_layers.append(nn.Conv2d(c_in, out_channels=c_out, kernel_size=kernel_size, stride=2, padding=pad))\n",
    "            encoder_layers.append(nn.BatchNorm2d(c_out))\n",
    "            encoder_layers.append(nn.ReLU(inplace=True))\n",
    "        encoder_layers.append(nn.Flatten(start_dim=1))\n",
    "        self.encoder = nn.Sequential(*encoder_layers)\n",
    "\n",
    "        # Linear heads for the latent mean (a) and log-variance (b).\n",
    "        feature_dim = init_channel*8\n",
    "        self.fully_connected_layer_a = nn.Linear(feature_dim, latent_dim)\n",
    "        self.fully_connected_layer_b = nn.Linear(feature_dim, latent_dim)\n",
    "\n",
    "        # Decoder plan: (in_channels, out_channels, stride, padding) for the\n",
    "        # normalised stages; the final transposed convolution feeds a sigmoid.\n",
    "        decoder_plan = (\n",
    "            (latent_dim, init_channel*4, 1, 0),\n",
    "            (init_channel*4, init_channel*2, 2, 1),\n",
    "            (init_channel*2, init_channel*1, 2, 1),\n",
    "        )\n",
    "        decoder_layers = []\n",
    "        for c_in, c_out, step, pad in decoder_plan:\n",
    "            decoder_layers.append(nn.ConvTranspose2d(c_in, out_channels=c_out, kernel_size=kernel_size, stride=step, padding=pad))\n",
    "            decoder_layers.append(nn.BatchNorm2d(c_out))\n",
    "            decoder_layers.append(nn.ReLU(inplace=True))\n",
    "        decoder_layers.append(nn.ConvTranspose2d(init_channel*1, out_channels=image_channel, kernel_size=kernel_size, stride=2, padding=1))\n",
    "        decoder_layers.append(nn.Sigmoid())\n",
    "        self.decoder = nn.Sequential(*decoder_layers)\n",
    "\n",
    "    def reparametrize(self,mu,log_var):\n",
    "        \"\"\"Return ``mu + eps*exp(0.5*log_var)`` with ``eps`` drawn by ``torch.randn_like``.\"\"\"\n",
    "        sigma = torch.exp(0.5*log_var)\n",
    "        noise = torch.randn_like(sigma)\n",
    "        return mu + noise*sigma\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Encode ``x``, sample a latent code and decode it.\n",
    "\n",
    "        :return: ``(reconstruction, mean, log_var)``\n",
    "        \"\"\"\n",
    "        features = self.encoder(x)\n",
    "\n",
    "        mean = self.fully_connected_layer_a(features)\n",
    "        log_var = self.fully_connected_layer_b(features)\n",
    "\n",
    "        # Gaussian reparametrisation, then reshape to (batch, latent_dim, 1, 1)\n",
    "        # so the transposed convolutions can consume it.\n",
    "        latent = self.reparametrize(mean, log_var)\n",
    "        latent = latent.view(-1, self.latent_dim, 1, 1)\n",
    "\n",
    "        reconstruction = self.decoder(latent)\n",
    "        return reconstruction, mean, log_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f4ca378-dd1a-449b-9dc9-f365529c690e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a44dcb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossVAE(bce_loss, mu, logvar):\n",
    "    \"\"\"\n",
    "    Add the KL-divergence term to the reconstruction loss.\n",
    "    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    :param bce_loss: reconstruction loss\n",
    "    :param mu: the mean from the latent vector\n",
    "    :param logvar: log variance from the latent vector\n",
    "    :return: bce_loss plus the KL term summed over every element of mu/logvar\n",
    "    \"\"\"\n",
    "    kl_integrand = 1 + logvar - mu.pow(2) - logvar.exp()\n",
    "    kl_divergence = -0.5 * torch.sum(kl_integrand)\n",
    "    return bce_loss + kl_divergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bc27c73f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled alternative to final_lossVAE that averages the KL term over the batch; kept for reference.\n",
    "# def final_loss(bce_loss,mu,log_var):\n",
    "#     #print(torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0))\n",
    "#     return bce_loss+torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n",
    "# Mean-squared-error criterion. NOTE(review): none of the training functions defined\n",
    "# in this notebook section reference it -- confirm whether it is still needed.\n",
    "mse_cost_function = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6203b932",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_trainVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Run one training pass of ``model`` over ``dataloader`` with the VAE objective.\n",
    "\n",
    "    :param model: module whose forward returns ``(reconstruction, mean, log_var)``\n",
    "    :param dataloader: yields batches whose first element is the input tensor\n",
    "    :param dataset: only its length is used, to size the progress bar\n",
    "    :param device: device the inputs are moved to\n",
    "    :param optimizer: optimizer stepped once per batch\n",
    "    :param criterion: reconstruction loss called as ``criterion(reconstruction, inputs)``\n",
    "    :return: sum of the per-batch losses divided by the number of batches\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    batches_seen = 0\n",
    "    loss_sum = 0.0\n",
    "    progress = tqdm(enumerate(dataloader), total=int(len(dataset)/dataloader.batch_size))\n",
    "    for _, batch in progress:\n",
    "        batches_seen += 1\n",
    "        inputs = batch[0].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction, mean, log_var = model(inputs)\n",
    "\n",
    "        # Reconstruction term plus KL divergence.\n",
    "        recon_term = criterion(reconstruction, inputs)\n",
    "        batch_loss = final_lossVAE(recon_term, mean, log_var)\n",
    "\n",
    "        batch_loss.backward()\n",
    "        loss_sum += batch_loss.item()\n",
    "        optimizer.step()\n",
    "    return loss_sum/batches_seen\n",
    "\n",
    "def model_validateVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Evaluate ``model`` on ``dataloader`` using the reconstruction loss only.\n",
    "\n",
    "    The KL term is not included here, so the value is not directly comparable\n",
    "    with the training loss returned by ``model_trainVAE``.\n",
    "\n",
    "    :param dataset: only its length is used, to size the progress bar\n",
    "    :param optimizer: only ``zero_grad()`` is called on it; no step is taken\n",
    "    :return: sum of the per-batch reconstruction losses divided by the number of batches\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batches_seen = 0\n",
    "    loss_sum = 0.0\n",
    "    with torch.no_grad():\n",
    "        progress = tqdm(enumerate(dataloader), total=int(len(dataset)/dataloader.batch_size))\n",
    "        for _, batch in progress:\n",
    "            batches_seen += 1\n",
    "            inputs = batch[0].to(device)\n",
    "            optimizer.zero_grad()\n",
    "            reconstruction, mean, log_var = model(inputs)\n",
    "            loss_sum += criterion(reconstruction, inputs).item()\n",
    "    return loss_sum/batches_seen\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e816c1d",
   "metadata": {},
   "source": [
    "## EVAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6f379ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EVAE(nn.Module):\n",
    "    \"\"\"Convolutional auto-encoder whose latent code is drawn by a uniform/median sampler.\n",
    "\n",
    "    The encoder is four stride-2 convolutions (each with batch norm and ReLU)\n",
    "    followed by a flatten; two linear heads produce ``mean`` and ``log_r``.\n",
    "    The decoder is four transposed convolutions ending in a sigmoid; its batch\n",
    "    norms are created with ``affine=False``.\n",
    "\n",
    "    :param image_channel: channels of the input and of the reconstruction\n",
    "    :param kernel_size: kernel size shared by every (transposed) convolution\n",
    "    :param latent_dim: width of the latent code\n",
    "    :param init_channel: channel count after the first convolution\n",
    "    :param m: only used to derive the scale ``bm = m ** (-2/9)``\n",
    "    :param B: stored as ``self.B``; no method of this class reads it\n",
    "\n",
    "    NOTE(review): the linear heads take ``init_channel*8`` input features, i.e.\n",
    "    they expect the encoder to end on a 1x1 map (true for 32x32 inputs) --\n",
    "    confirm before feeding other image sizes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,image_channel=1,kernel_size=4,latent_dim=64,init_channel=32,m=100,B=1):\n",
    "        super(EVAE, self).__init__()\n",
    "\n",
    "        self.bm = np.power(m, -2/9)\n",
    "        self.B = B\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        # Encoder plan: (in_channels, out_channels, padding) for each stride-2 stage.\n",
    "        # Layers are created in the original order so state_dict keys are unchanged.\n",
    "        encoder_plan = (\n",
    "            (image_channel, init_channel, 1),\n",
    "            (init_channel, init_channel*2, 1),\n",
    "            (init_channel*2, init_channel*4, 1),\n",
    "            (init_channel*4, init_channel*8, 0),\n",
    "        )\n",
    "        encoder_layers = []\n",
    "        for c_in, c_out, pad in encoder_plan:\n",
    "            encoder_layers.append(nn.Conv2d(c_in, out_channels=c_out, kernel_size=kernel_size, stride=2, padding=pad))\n",
    "            encoder_layers.append(nn.BatchNorm2d(c_out))\n",
    "            encoder_layers.append(nn.ReLU(inplace=True))\n",
    "        encoder_layers.append(nn.Flatten(start_dim=1))\n",
    "        self.encoder = nn.Sequential(*encoder_layers)\n",
    "\n",
    "        # Linear heads for the latent location (a) and log-scale (b).\n",
    "        self.fully_connected_layer_a = nn.Linear(init_channel*8, latent_dim)\n",
    "        self.fully_connected_layer_b = nn.Linear(init_channel*8, latent_dim)\n",
    "\n",
    "        # Decoder plan: (in_channels, out_channels, stride, padding) for the\n",
    "        # normalised stages; batch norms here carry no learnable affine terms.\n",
    "        decoder_plan = (\n",
    "            (latent_dim, init_channel*4, 1, 0),\n",
    "            (init_channel*4, init_channel*2, 2, 1),\n",
    "            (init_channel*2, init_channel*1, 2, 1),\n",
    "        )\n",
    "        decoder_layers = []\n",
    "        for c_in, c_out, step, pad in decoder_plan:\n",
    "            decoder_layers.append(nn.ConvTranspose2d(c_in, out_channels=c_out, kernel_size=kernel_size, stride=step, padding=pad))\n",
    "            decoder_layers.append(nn.BatchNorm2d(c_out, affine=False))\n",
    "            decoder_layers.append(nn.ReLU(inplace=True))\n",
    "        decoder_layers.append(nn.ConvTranspose2d(init_channel*1, out_channels=image_channel, kernel_size=kernel_size, stride=2, padding=1))\n",
    "        decoder_layers.append(nn.Sigmoid())\n",
    "        self.decoder = nn.Sequential(*decoder_layers)\n",
    "\n",
    "    def reparametrize(self,mu,log_r):\n",
    "        \"\"\"Draw a latent sample from ``mu`` and ``log_r``.\n",
    "\n",
    "        With ``r = exp(0.5*log_r)`` the result is ``bm*(r*median + mu) + U_p*r``,\n",
    "        where ``median`` is the element-wise median of three independent\n",
    "        Uniform(-1, 1) draws and ``U_p`` is one further Uniform(-1, 1) draw.\n",
    "\n",
    "        The noise is created on ``log_r.device``. The previous version called\n",
    "        ``.to(device)`` with a notebook-global ``device``, which raised NameError\n",
    "        when that global did not exist yet and could place the noise on a\n",
    "        different device than the model.\n",
    "\n",
    "        :param mu: tensor of shape (batch, latent_dim)\n",
    "        :param log_r: tensor of the same shape as ``mu``\n",
    "        :return: tensor of shape (batch, latent_dim)\n",
    "        \"\"\"\n",
    "        n = log_r.size(dim=0)\n",
    "        d = log_r.size(dim=1)\n",
    "        noise_device = log_r.device\n",
    "\n",
    "        # Draw order (kernel noise first, then prior noise) matches the original.\n",
    "        U_k = 2*torch.rand(n*d, 3, device=noise_device) - 1\n",
    "        U_p = 2*torch.rand(n, d, device=noise_device) - 1  # uniform prior\n",
    "\n",
    "        # Median of three uniforms per latent entry (the original comment calls\n",
    "        # this a sample from the 'E kernel').\n",
    "        median = torch.median(U_k, 1)[0].reshape(n, d)\n",
    "        radius = torch.exp(0.5*log_r)\n",
    "        return self.bm*(radius*median + mu) + U_p*radius\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Encode ``x``, sample a latent code and decode it.\n",
    "\n",
    "        :return: ``(reconstruction, mean, log_r, zs)`` where ``zs`` is the sampled\n",
    "            latent code of shape (batch, latent_dim)\n",
    "        \"\"\"\n",
    "        hidden = self.encoder(x)\n",
    "\n",
    "        mean = self.fully_connected_layer_a(hidden)\n",
    "        log_r = self.fully_connected_layer_b(hidden)\n",
    "        zs = self.reparametrize(mean, log_r)\n",
    "\n",
    "        # Reshape to (batch, latent_dim, 1, 1) for the transposed convolutions.\n",
    "        z = zs.view(-1, self.latent_dim, 1, 1)\n",
    "        reconstruction = self.decoder(z)\n",
    "\n",
    "        return reconstruction, mean, log_r, zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "79c70590",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossEVAE(bce_loss,log_r):\n",
    "    \"\"\"Return the EVAE training loss, which currently is the reconstruction loss alone.\n",
    "\n",
    "    :param bce_loss: reconstruction loss\n",
    "    :param log_r: kept in the signature for callers; unused while the\n",
    "        regularisation term below is disabled\n",
    "    :return: ``bce_loss`` unchanged\n",
    "    \"\"\"\n",
    "    # Disabled regulariser, kept for reference:\n",
    "    # 2*3/5* torch.exp(torch.logsumexp(torch.logsumexp(-0.5*log_r , dim = 1), dim = 0))\n",
    "    # The unused size lookups on log_r were removed; they only made the function\n",
    "    # fail for inputs without a second dimension.\n",
    "    return bce_loss\n",
    "\n",
    "def model_trainEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Run one training pass of ``model`` over ``dataloader`` with the EVAE objective.\n",
    "\n",
    "    :param model: module whose forward returns ``(reconstruction, mean, log_r, z)``\n",
    "    :param dataloader: yields batches whose first element is the input tensor\n",
    "    :param dataset: only its length is used, to size the progress bar\n",
    "    :param device: device the inputs are moved to\n",
    "    :param optimizer: optimizer stepped once per batch\n",
    "    :param criterion: reconstruction loss called as ``criterion(reconstruction, inputs)``\n",
    "    :return: ``(train_loss, z)`` -- the sum of per-batch losses divided by the\n",
    "        number of batches, and the latent sample of the last batch\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    batches_seen = 0\n",
    "    loss_sum = 0.0\n",
    "    progress = tqdm(enumerate(dataloader), total=int(len(dataset)/dataloader.batch_size))\n",
    "    for _, batch in progress:\n",
    "        batches_seen += 1\n",
    "        inputs = batch[0].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction, mean, log_r, z = model(inputs)\n",
    "\n",
    "        recon_term = criterion(reconstruction, inputs)\n",
    "        batch_loss = final_lossEVAE(recon_term, log_r)\n",
    "\n",
    "        batch_loss.backward()\n",
    "        loss_sum += batch_loss.item()\n",
    "        optimizer.step()\n",
    "    return loss_sum/batches_seen, z\n",
    "\n",
    "def model_validateEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Evaluate ``model`` on ``dataloader`` using the reconstruction loss only.\n",
    "\n",
    "    :param model: module whose forward returns the reconstruction as its first output\n",
    "    :param dataloader: yields batches whose first element is the input tensor\n",
    "    :param dataset: kept for signature compatibility; no longer read\n",
    "    :param device: device the inputs are moved to\n",
    "    :param optimizer: only ``zero_grad()`` is called on it; no step is taken\n",
    "    :param criterion: reconstruction loss called as ``criterion(reconstruction, inputs)``\n",
    "    :return: ``(valid_loss, recon_images)`` -- the sum of per-batch losses divided\n",
    "        by the number of batches, and the reconstruction of the final batch\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    counter = 0\n",
    "    # Always bound, and overwritten on every batch so it ends up holding the final\n",
    "    # batch. The old index test int(len(dataset)/batch_size)-1 never matched when the\n",
    "    # dataset was smaller than one batch (UnboundLocalError) and picked a non-final\n",
    "    # batch when the dataset size was not a multiple of the batch size.\n",
    "    recon_images = None\n",
    "    with torch.no_grad():\n",
    "        for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):\n",
    "            counter += 1\n",
    "            data = data[0].to(device)\n",
    "            optimizer.zero_grad()\n",
    "            reconstruction, *_ = model(data)\n",
    "            running_loss += criterion(reconstruction, data).item()\n",
    "            recon_images = reconstruction\n",
    "    valid_loss = running_loss/counter\n",
    "    return valid_loss, recon_images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a445ce36",
   "metadata": {},
   "source": [
    "## dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0f8a070e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: ./\n",
       "    Split: Train"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download MNIST into the current directory; the dataset cells further down load it with download=False.\n",
    "torchvision.datasets.MNIST('./',download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a8de03d4-0658-4a30-8158-f51211cc002a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2048"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch arithmetic (evaluates to 2048). NOTE(review): no cell shown nearby reads this value -- confirm and delete.\n",
    "32*64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788b980d",
   "metadata": {},
   "source": [
    "# training EVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "b72a8856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for the EVAE run: device, model, hyper-parameters, data and optimizer.\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "latent_dim=32\n",
    "# The AE and VAE baselines are disabled; only the EVAE is instantiated.\n",
    "#AEmodel=AE(latent_dim=latent_dim).to(device)\n",
    "#VAEmodel=VAE(latent_dim=latent_dim).to(device)\n",
    "EVAEmodel=EVAE(latent_dim=latent_dim,m=100).to(device)\n",
    "lr=0.0003\n",
    "epochs=50\n",
    "batch_size=100\n",
    "\n",
    "# Images are resized to 32x32 so that the encoder's four stride-2 convolutions\n",
    "# reduce them to a 1x1 feature map, which is what the linear heads expect.\n",
    "transform=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])\n",
    "\n",
    "# Training split (expects the data to have been downloaded to './' already).\n",
    "train_set=torchvision.datasets.MNIST(root='./',train=True,download=False,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)\n",
    "\n",
    "# Test split. NOTE(review): this loader also shuffles, so the 'last batch' of\n",
    "# reconstructions returned by validation differs from epoch to epoch -- confirm intent.\n",
    "test_set=torchvision.datasets.MNIST(root='./',train=False,download=False,transform=transform)\n",
    "test_loader=torch.utils.data.DataLoader(test_set,batch_size=batch_size,shuffle=True)\n",
    "#optimizerAE=optim.Adam(AEmodel.parameters(),lr=lr)\n",
    "#optimizerVAE=optim.Adam(VAEmodel.parameters(),lr=lr)\n",
    "optimizerEVAE=optim.Adam(EVAEmodel.parameters(),lr=lr)\n",
    "# Binary cross-entropy summed (not averaged) over every pixel of the batch.\n",
    "# NOTE: the GAN setup cell later rebinds `criterion` to a mean-reduced BCELoss.\n",
    "criterion=nn.BCELoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "576ebcb9-e05d-4e83-bb4f-d4bb879ae972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch loss histories. The VAE lists stay empty while the VAE run below is disabled.\n",
    "train_lossVAE=[]\n",
    "valid_lossVAE=[]\n",
    "\n",
    "grid_imagesEVAE=[]\n",
    "train_lossEVAE=[]\n",
    "valid_lossEVAE=[]\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch{epoch+1} of {epochs}\")\n",
    "\n",
    "    # VAE baseline, currently disabled:\n",
    "    #train_epoch_lossVAE=model_trainVAE(VAEmodel,train_loader,train_set,device,optimizerVAE,criterion)\n",
    "    #valid_epoch_lossVAE=model_validateVAE(VAEmodel,test_loader,test_set,device,optimizerVAE,criterion)\n",
    "    #train_lossVAE.append(train_epoch_lossVAE)\n",
    "    #valid_lossVAE.append(valid_epoch_lossVAE)\n",
    "\n",
    "    train_epoch_lossEVAE,z=model_trainEVAE(EVAEmodel,train_loader,train_set,device,optimizerEVAE,criterion)\n",
    "    valid_epoch_lossEVAE,recon_imagesEVAE=model_validateEVAE(EVAEmodel,test_loader,test_set,device,optimizerEVAE,criterion)\n",
    "    # Record the history. These appends were commented out, which left the lists\n",
    "    # created above empty and gave the loss-plot helpers nothing to draw.\n",
    "    train_lossEVAE.append(train_epoch_lossEVAE)\n",
    "    valid_lossEVAE.append(valid_epoch_lossEVAE)\n",
    "\n",
    "    # Optional: dump the reconstructions of the last validation batch.\n",
    "    #save_reconstructed_imagesEVAE(recon_imagesEVAE,epoch+1)\n",
    "    print(f\"train loss:{train_epoch_lossEVAE:.4f}\")\n",
    "    print(f\"valid loss:{valid_epoch_lossEVAE:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c460f332-689d-413e-8df1-5c9ffb58b491",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the weights of the model that was actually trained above. The previous\n",
    "# version saved `VAEmodel`, which no cell creates (its constructor call is commented\n",
    "# out), so Restart-and-Run-All stopped here with NameError. The file name now carries\n",
    "# the model kind and the configured latent size.\n",
    "# NOTE(review): if a later cell loads 'VAEmodelMNIST_state_dict_dz16.pth', re-enable the VAE run instead.\n",
    "PATH = f\"EVAEmodelMNIST_state_dict_dz{latent_dim}.pth\"\n",
    "torch.save(EVAEmodel.state_dict(), PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "11c0a6d9-4116-4c06-8877-e1762b9cd3eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for the latent-space GAN. Several names below come from the DCGAN\n",
    "# template; the notes say how the generator/discriminator cells actually use them.\n",
    "\n",
    "# Dataset root from the template. NOTE(review): no cell shown here reads it -- confirm and remove.\n",
    "dataroot = \"data/celeba\"\n",
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# Number of workers for dataloader. NOTE(review): not passed to any DataLoader shown here.\n",
    "workers = 2\n",
    "\n",
    "# Batch size during training (rebinds the value set in the EVAE configuration cell)\n",
    "batch_size = 100\n",
    "\n",
    "# Image size from the template. NOTE(review): the MNIST transform above resizes to\n",
    "#   32x32 and does not read this value.\n",
    "image_size = 64\n",
    "\n",
    "# Channel count from the template. NOTE(review): not read by the Generator or Discriminator below.\n",
    "nc =4\n",
    "\n",
    "# Width of the noise vector fed to the Generator's first Linear layer\n",
    "nz = 32\n",
    "\n",
    "# Base hidden width of the Generator (its Linear layers use ngf*2, ngf*4, ngf*8)\n",
    "ngf = 64\n",
    "\n",
    "# Base hidden width of the Discriminator (its Linear layers use ndf, ndf*2, ndf*4, ndf*8)\n",
    "ndf = 64\n",
    "\n",
    "# Number of training epochs. NOTE(review): the GAN training loop iterates range(50),\n",
    "# not num_epochs -- confirm which is intended.\n",
    "num_epochs = 5\n",
    "\n",
    "# Learning rate for the GAN optimizers (rebinds the `lr` used for the EVAE optimizer)\n",
    "lr = 0.0002\n",
    "\n",
    "# Beta1 hyperparameter for Adam optimizers\n",
    "beta1 = 0.5\n",
    "\n",
    "# Number of GPUs available. Values above 1 wrap the networks in nn.DataParallel.\n",
    "ngpu = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "50a0e6cb-87ba-4d80-9421-095c9524afee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# custom weights initialization called on ``netG`` and ``netD``\n",
    "def weights_init(m):\n",
    "    \"\"\"Re-initialise ``m`` in place when its class name marks it as a conv or batch-norm layer.\n",
    "\n",
    "    Layers whose class name contains 'Conv' get weights drawn from N(0, 0.02);\n",
    "    otherwise layers whose class name contains 'BatchNorm' get weights drawn from\n",
    "    N(1, 0.02) and a zero bias. Every other module (e.g. ``nn.Linear``) is left\n",
    "    untouched. Intended to be passed to ``module.apply``.\n",
    "    \"\"\"\n",
    "    layer_kind = m.__class__.__name__\n",
    "    if 'Conv' in layer_kind:\n",
    "        nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
    "        return\n",
    "    if 'BatchNorm' in layer_kind:\n",
    "        nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
    "        nn.init.constant_(m.bias.data, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "30669492-9c8b-41f7-afce-9cafea6fbbe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator Code\n",
    "class Generator(nn.Module):\n",
    "    def __init__(self, ngpu):\n",
    "        super(Generator, self).__init__()\n",
    "        self.ngpu = ngpu\n",
    "\n",
    "        self.main = nn.Sequential(\n",
    "            # input is Z, going into a convolution\n",
    "            nn.Linear( nz, ngf * 2,bias=True),\n",
    "            nn.BatchNorm1d(ngf * 2),\n",
    "            nn.GELU(),\n",
    "            # state size. (ngf*8) x 4 x 4\n",
    "            nn.Linear(ngf * 2, ngf * 4, bias=True),\n",
    "            nn.BatchNorm1d(ngf * 4),\n",
    "            nn.GELU(),\n",
    "            \n",
    "            # state size. (ngf*4) x 8 x 8\n",
    "            nn.Linear( ngf * 4, ngf * 8, bias=True),\n",
    "            nn.BatchNorm1d(ngf * 8),\n",
    "            nn.GELU(),\n",
    "            # state size. (ngf*2) x 16 x 16\n",
    "            nn.Linear( ngf * 8, 32, bias=True),\n",
    "            #nn.Tanh()\n",
    "            # state size. (nc) x 32 x 32\n",
    "        )\n",
    "    \n",
    "    def forward(self, input):\n",
    "        x=self.main(input)\n",
    "        #print(x.size())\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "f3311c95-4b4f-414b-8a22-4f1ff376c731",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 32])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create the generator\n",
    "netG = Generator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netG = nn.DataParallel(netG, list(range(ngpu)))\n",
    "\n",
    "# Apply the ``weights_init`` function to randomly initialize all weights\n",
    "#  to ``mean=0``, ``stdev=0.02``.\n",
    "#netG.apply(weights_init)\n",
    "\n",
    "# Print the model\n",
    "temp=netG(torch.randn(32,32).to(device))\n",
    "temp.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "ba26c7d1-c114-4cdb-a4bc-aa0d6c016ae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    \"\"\"MLP mapping a 32-wide vector to a single sigmoid output per sample.\n",
    "\n",
    "    The first layer is ``Linear(32, ndf)`` with ``GELU``; the hidden layers widen\n",
    "    to ``ndf*2``, ``ndf*4`` and ``ndf*8``, each followed by ``BatchNorm1d`` and ``GELU``;\n",
    "    the output layer is ``Linear(ndf*8, 1)`` followed by ``Sigmoid``. ``ndf`` is\n",
    "    read from module-level configuration.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ngpu):\n",
    "        super().__init__()\n",
    "        self.ngpu = ngpu\n",
    "\n",
    "        # Input layer has no batch norm; modules are created in the same order as\n",
    "        # before so state_dict keys stay main.0 .. main.11.\n",
    "        stack = [nn.Linear(32, ndf, bias=True), nn.GELU()]\n",
    "        hidden = (ndf, ndf * 2, ndf * 4, ndf * 8)\n",
    "        for fan_in, fan_out in zip(hidden[:-1], hidden[1:]):\n",
    "            stack.append(nn.Linear(fan_in, fan_out, bias=True))\n",
    "            stack.append(nn.BatchNorm1d(fan_out))\n",
    "            stack.append(nn.GELU())\n",
    "        # One sigmoid output per sample.\n",
    "        stack.append(nn.Linear(hidden[-1], 1, bias=True))\n",
    "        stack.append(nn.Sigmoid())\n",
    "        self.main = nn.Sequential(*stack)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Return ``self.main(input)``.\"\"\"\n",
    "        return self.main(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "5c0b692e-ff08-4a46-a4f5-cfd1afdbcbd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 1])\n"
     ]
    }
   ],
   "source": [
    "# Create the Discriminator\n",
    "netD = Discriminator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netD = nn.DataParallel(netD, list(range(ngpu)))\n",
    "    \n",
    "# Apply the ``weights_init`` function to randomly initialize all weights\n",
    "# like this: ``to mean=0, stdev=0.2``.\n",
    "netD.apply(weights_init)\n",
    "\n",
    "# Print the model\n",
    "t2=netD(temp)\n",
    "print(t2.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "ca52323d-3dc4-49ab-8c59-a375d897a4b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-create both networks so that training starts from fresh weights.\n",
    "netG = Generator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netG = nn.DataParallel(netG, list(range(ngpu)))\n",
    "\n",
    "# The custom ``weights_init`` initialisation is disabled for the generator.\n",
    "#netG.apply(weights_init)\n",
    "\n",
    "# Create the Discriminator\n",
    "netD = Discriminator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netD = nn.DataParallel(netD, list(range(ngpu)))\n",
    "    \n",
    "# Apply ``weights_init``: it re-initialises the BatchNorm1d layers\n",
    "# (weight ~ N(1, 0.02), bias 0); Linear layers are left at their defaults.\n",
    "netD.apply(weights_init)\n",
    "# Initialize the ``BCELoss`` function (mean reduction).\n",
    "# NOTE: this rebinds the `criterion` that the EVAE cells created with reduction=\"sum\".\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "# Create batch of latent vectors that we will use to visualize\n",
    "#  the progression of the generator. The width must equal nz, the input size of\n",
    "#  the generator's first layer; it was hard-coded to 48, which netG cannot consume.\n",
    "fixed_noise = torch.randn(100, nz, device=device)\n",
    "\n",
    "# Establish convention for real and fake labels during training\n",
    "real_label = 1.\n",
    "fake_label = 0.\n",
    "\n",
    "# Setup Adam optimizers for both G and D\n",
    "optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))\n",
    "optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45464f07-2ace-43d6-a66b-a3439178b92d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training Loop: fit the GAN to latent codes produced by the trained EVAE.\n",
    "# 'Real' samples are EVAE latent codes of MNIST images; 'fake' samples are netG outputs.\n",
    "\n",
    "# Lists to keep track of progress\n",
    "img_list = []\n",
    "G_losses = []\n",
    "D_losses = []\n",
    "iters = 0\n",
    "print(\"Starting Training Loop...\")\n",
    "# The EVAE only supplies latent codes here, so keep it in eval mode.\n",
    "EVAEmodel.eval()\n",
    "# NOTE(review): the epoch count is hard-coded; the configuration cell defines num_epochs = 5.\n",
    "for epoch in range(50):\n",
    "    # For each batch in the dataloader\n",
    "    for i, data in tqdm(enumerate(train_loader),total=len(train_loader)):\n",
    "        \n",
    "        ############################\n",
    "        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n",
    "        ###########################\n",
    "        ## Train with all-real batch\n",
    "        netD.zero_grad()\n",
    "        # Format batch\n",
    "        images = data[0].to(device)\n",
    "        # Encode without tracking gradients. Previously the EVAE forward pass built\n",
    "        # an autograd graph, so errD_real.backward() also accumulated gradients into\n",
    "        # the EVAE parameters, which are never optimised or zeroed in this loop.\n",
    "        with torch.no_grad():\n",
    "            *_,real_latent= EVAEmodel(images)\n",
    "        real_latent=real_latent.view(-1,32)\n",
    "        b_size = real_latent.size(0)\n",
    "        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)\n",
    "        # Forward pass real batch through D\n",
    "        output = netD(real_latent).view(-1)\n",
    "        # Calculate loss on all-real batch\n",
    "        errD_real = criterion(output, label)\n",
    "        # Calculate gradients for D in backward pass\n",
    "        errD_real.backward()\n",
    "        D_x = output.mean().item()\n",
    "\n",
    "        ## Train with all-fake batch\n",
    "        # Generate batch of noise vectors\n",
    "        noise = torch.randn(b_size, nz, device=device)\n",
    "        # Generate fake latent batch with G\n",
    "        fake = netG(noise)\n",
    "        label.fill_(fake_label)\n",
    "        # Classify all fake batch with D (detached so that G receives no gradient here)\n",
    "        output = netD(fake.detach()).view(-1)\n",
    "        # Calculate D's loss on the all-fake batch\n",
    "        errD_fake = criterion(output, label)\n",
    "        # Calculate the gradients for this batch, accumulated (summed) with previous gradients\n",
    "        errD_fake.backward()\n",
    "        D_G_z1 = output.mean().item()\n",
    "        # Compute error of D as sum over the fake and the real batches\n",
    "        errD = errD_real + errD_fake\n",
    "        # Update D\n",
    "        optimizerD.step()\n",
    "\n",
    "        ############################\n",
    "        # (2) Update G network: maximize log(D(G(z)))\n",
    "        ###########################\n",
    "        netG.zero_grad()\n",
    "        label.fill_(real_label)  # fake labels are real for generator cost\n",
    "        # Since we just updated D, perform another forward pass of all-fake batch through D\n",
    "        output = netD(fake).view(-1)\n",
    "        # Calculate G's loss based on this output\n",
    "        errG = criterion(output, label)\n",
    "        # Calculate gradients for G\n",
    "        errG.backward()\n",
    "        D_G_z2 = output.mean().item()\n",
    "        # Update G\n",
    "        optimizerG.step()\n",
    "\n",
    "        # Record the losses; these lists were created above but never filled.\n",
    "        G_losses.append(errG.item())\n",
    "        D_losses.append(errD.item())\n",
    "\n",
    "        iters += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "8b43da78-d5aa-46e1-af9a-99a65e78117e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the generator's parameters (its state_dict) to disk.\n",
    "PATH = \"EVAEmodel_netG_MNIST_state_dict.pth\"\n",
    "torch.save(obj=netG.state_dict(), f=PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "1bf9769f-04f3-4644-9e40-2265a638b7bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 7/600 [00:00<00:13, 45.22it/s]\n"
     ]
    }
   ],
   "source": [
    "# Qualitative check: decode one batch of generator samples with the EVAE decoder\n",
    "# and save it alongside one batch of real training images.\n",
    "EVAEmodel.eval()\n",
    "netG.eval()\n",
    "# Rebuild the training loader with shuffle=False so the batch picked below is\n",
    "# taken in dataset order.\n",
    "train_set=torchvision.datasets.MNIST(root='./',train=True,download=False,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=False)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i,data in tqdm(enumerate(train_loader),total=len(train_loader)):\n",
    "        # Only the batch at index 7 is used; skip the others without copying them to the device.\n",
    "        if i!=7:\n",
    "            continue\n",
    "        data=data[0].to(device)\n",
    "\n",
    "        # 64 noise vectors of width nz, the same width the training loop feeds to netG.\n",
    "        noise=torch.randn(64, nz, device=device)\n",
    "        state1=netG(noise)\n",
    "\n",
    "        # One latent vector per sample, reshaped to (N, C, 1, 1) for the decoder;\n",
    "        # the channel count is inferred instead of being hard-coded.\n",
    "        z=state1.view(state1.size(0),-1,1,1)\n",
    "        # decoding\n",
    "        reconstruction=EVAEmodel.decoder(z)\n",
    "\n",
    "        # NOTE: the real batch is written through the VAE helper (VAE0.jpg),\n",
    "        # the generated batch through the EVAE helper (EVAE0.jpg).\n",
    "        save_reconstructed_imagesVAE(data,0)\n",
    "        save_reconstructed_imagesEVAE(reconstruction,0)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "0a346abb-c6e5-4e16-aaf7-16848bdee53a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 600/600 [03:41<00:00,  2.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FIDEVAE: 7.855438232421875\n"
     ]
    }
   ],
   "source": [
    "# Quantitative check: Frechet Inception Distance between the real images in\n",
    "# train_loader and images decoded from generator samples.\n",
    "from scipy import signal\n",
    "from torcheval.metrics import FrechetInceptionDistance\n",
    "\n",
    "EVAEmodel.eval()\n",
    "netG.eval()\n",
    "\n",
    "fidEVAE = FrechetInceptionDistance(device=device)\n",
    "with torch.no_grad():\n",
    "    # len(train_loader) also counts a partial final batch\n",
    "    for i,data in tqdm(enumerate(train_loader),total=len(train_loader)):\n",
    "        data=data[0].to(device)\n",
    "        # generation: one generator sample per real image, noise width nz as in the training loop\n",
    "        state1=netG(torch.randn(data.size(0), nz, device=device))\n",
    "        # reshape to (N, C, 1, 1) for the decoder; channel count inferred, not hard-coded\n",
    "        state2=EVAEmodel.decoder(state1.view(state1.size(0),-1,1,1))\n",
    "\n",
    "        # repeat(1, 3, 1, 1) tiles the channel axis three times before the update\n",
    "        # (assumes single-channel images, as with MNIST -- TODO confirm for other data)\n",
    "        fidEVAE.update(data.repeat(1, 3, 1,1), is_real=True)\n",
    "        fidEVAE.update(state2.repeat(1, 3, 1,1), is_real=False)\n",
    "\n",
    "lossEVAE=fidEVAE.compute()\n",
    "\n",
    "print(f\"FIDEVAE: {float(lossEVAE)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d848ec-b89e-4217-b316-be70176fafbe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
