{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33eecf77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the libraries \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d2e3d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import *\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "to_pil_image = transforms.ToPILImage()\n",
    "def image_to_vid(images):\n",
    "    imgs = [np.array(to_pil_image(img)) for img in images]\n",
    "    imageio.mimsave('../outputs/generated_images.gif', imgs)\n",
    "def save_reconstructed_imagesVAE(recon_images, epoch):\n",
    "    save_image(recon_images.cpu(), f\"VAE{epoch}.jpg\")\n",
    "def save_reconstructed_imagesEVAE(recon_images, epoch):\n",
    "    save_image(recon_images.cpu(), f\"EVAE{epoch}.jpg\")\n",
    "def save_reconstructed_imagesREAL(images, epoch):\n",
    "    save_image(images.cpu(), f\"REAL{epoch}.jpg\")\n",
    "def save_train_loss_plot(train_loss, valid_loss):\n",
    "    # loss plots\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    plt.plot(train_loss, color='orange', label='train loss VAE')\n",
    "    plt.plot(valid_loss, color='red', label='train loss EVAE')\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Total loss (Reconstruction+B*I(K))/M')\n",
    "    plt.legend()\n",
    "    plt.savefig('lossMNIST.jpg')\n",
    "    plt.show()\n",
    "def save_valid_loss_plot(valid_lossVAE, valid_lossEVAE):\n",
    "    # loss plots\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    plt.plot(valid_lossVAE, color='orange', label='validation loss VAE')\n",
    "    plt.plot(valid_lossEVAE, color='red', label='validation loss EVAE')\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel('Reconstruction loss')\n",
    "    plt.legend()\n",
    "    plt.savefig('lossMNIST_valid.jpg')\n",
    "    plt.show()\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d7c481",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9e816c1d",
   "metadata": {},
   "source": [
    "##EVAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6f379ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EVAE(nn.Module):\n",
    "    def __init__(self,image_channel=3,kernel_size=5,latent_dim=64,init_channel=32,m=100,B=1):\n",
    "        super(EVAE, self).__init__()\n",
    "\n",
    "        self.bm=np.power(m,-2/9)#nn.Parameter(torch.ones(1),requires_grad=True)#np.power(m,-2/9)\n",
    "        self.B=B\n",
    "        self.latent_dim=latent_dim\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=2),            \n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel*2,out_channels=init_channel*4,kernel_size=kernel_size,stride=2,padding=2),\n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel*4,out_channels=init_channel*8,kernel_size=kernel_size,stride=2,padding=2),\n",
    "            nn.BatchNorm2d(init_channel*8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            # nn.Conv2d(init_channel*8,out_channels=init_channel*16,kernel_size=kernel_size,stride=2,padding=2) ,            \n",
    "            # nn.BatchNorm2d(init_channel*16),\n",
    "            # nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "            #nn.AdaptiveAvgPool2d((100,init_channel*16)),\n",
    "            nn.Flatten( start_dim=1)\n",
    "            \n",
    "            #nn.Linear(init_channel*8,init_channel*8) #hidden\n",
    "            \n",
    "        )\n",
    "        \n",
    "        \n",
    "\n",
    "        ## fully connected laeyer for learning representation\n",
    "        self.fully_connected_layer_a=nn.Linear(init_channel*8*8*8,latent_dim)\n",
    "        self.fully_connected_layer_b=nn.Linear(init_channel*8*8*8,latent_dim)\n",
    "\n",
    "\n",
    "        #self.fully_connected_layer=nn.Linear(latent_dim,init_channel*16*8*8)\n",
    "        ## decoder\n",
    "        \n",
    "                # encoder\n",
    "        self.decoder = nn.Sequential(\n",
    "            \n",
    "            \n",
    "            nn.Linear(latent_dim, init_channel*8 *8*8),\n",
    "            nn.Unflatten(dim=-1, unflattened_size=(init_channel*8, 8,8)),     # 1 x 1     \n",
    "            \n",
    "            nn.ConvTranspose2d(init_channel*8,out_channels=init_channel*8,kernel_size=kernel_size,stride=1,padding=2),            \n",
    "            nn.BatchNorm2d(init_channel*8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*8,out_channels=init_channel*4,kernel_size=kernel_size,stride=2,padding=2,\n",
    "                                              output_padding=1),\n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*4,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=2,\n",
    "                                              output_padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2) ,\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*2,out_channels=image_channel,kernel_size=kernel_size,stride=2,padding=2,\n",
    "                                              output_padding=1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        \n",
    "        \n",
    " \n",
    "\n",
    "    def reparametrize(self,mu,log_r):\n",
    "        \n",
    "        coef= self.bm #b(m)\n",
    "        n=log_r.size(dim=0)\n",
    "        d=log_r.size(dim=1)\n",
    "        B=self.B\n",
    "        #beta=torch.exp(0.5*log_beta)\n",
    "        uniform_k=torch.rand(n*d,3).to(device)\n",
    "        U_k=2*uniform_k-1\n",
    "        \n",
    "        uniform_p=torch.rand(n,d).to(device)\n",
    "        U_p=2*uniform_p-1  #uniform prior\n",
    "        #z=self.prior(U_p)\n",
    "        median=torch.median(U_k,1)[0].reshape(n,d) #sample from E kernel\n",
    "        sample=coef*(torch.exp(0.5*log_r)*median+mu)+1*U_p*torch.exp(0.5*log_r)\n",
    "        return sample#,U_p/2\n",
    "    def forward(self,x):\n",
    "        #encoding\n",
    "\n",
    "        hidden =self.encoder(x)\n",
    "        #print(hidden.size())\n",
    "        #hidden=torch.flatten(x, start_dim=1)\n",
    "        #hidden=self.seq_l(hidden)\n",
    "        #batch,_,_,_=x.shape\n",
    "        #hidden=F.adaptive_avg_pool2d(x,1).reshape(batch,-1)     \n",
    "        \n",
    "        mean=self.fully_connected_layer_a(hidden)\n",
    "        log_r=self.fully_connected_layer_b(hidden)\n",
    "        #log_beta=self.fully_connected_layer_beta(hidden)\n",
    "        #rand_ind=torch.randint(hidden.size(0), (hidden.size(0),))\n",
    "        #mean=mean[rand_ind]\n",
    "\n",
    "        \n",
    "        zs=self.reparametrize(mean,log_r) #uniform distribution\n",
    "\n",
    "        #z=torch.concat((K,z),dim=1)\n",
    "        \n",
    "        #zs=torch.tanh(zs)#+zs#self.fully_connected_layer(zs)\n",
    "      \n",
    "        # z=z.view(-1,self.latent_dim,4,4)\n",
    "        #decoding\n",
    "        reconstruction=self.decoder(zs)\n",
    "\n",
    "        return reconstruction,mean,log_r,zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "4c602b50-8be1-4940-9aa7-836a3ab633ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8192"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "32*16*4*4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "79c70590",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossEVAE(bce_loss,log_r):\n",
    "\n",
    "\n",
    "    return bce_loss\n",
    "\n",
    "def model_trainEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    model.train()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    for i, data in tqdm(enumerate(dataloader),total=int(len(dataset)/dataloader.batch_size)):\n",
    "        counter+=1\n",
    "        data=data[0]\n",
    "        data=data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction,mean,log_r,z=model(data)\n",
    "\n",
    "        bce_loss= criterion(reconstruction,data)\n",
    "        \n",
    "        loss=final_lossEVAE(bce_loss,log_r)\n",
    "        #print(loss)\n",
    "        loss.backward()\n",
    "        running_loss+=loss.item()\n",
    "        optimizer.step()\n",
    "    train_loss=running_loss/counter\n",
    "    #print(log_var)\n",
    "    return train_loss,z\n",
    "\n",
    "def model_validateEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    model.eval()\n",
    "    running_loss=0.0\n",
    "    counter=0\n",
    "    with torch.no_grad():\n",
    "        for i,data in tqdm(enumerate(dataloader),total=int(len(dataset)/dataloader.batch_size)):\n",
    "            counter+=1\n",
    "            data=data[0]\n",
    "            data=data.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            reconstruction,*_=model(data)\n",
    "            bce_loss= criterion(reconstruction,data)\n",
    "            loss=bce_loss#final_loss(bce_loss,mu,log_var)\n",
    "            running_loss+=loss.item()\n",
    "            if i==int(len(dataset)/dataloader.batch_size)-1:\n",
    "                recon_images=reconstruction\n",
    "        valid_loss=running_loss/counter\n",
    "        return valid_loss,recon_images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a445ce36",
   "metadata": {},
   "source": [
    "## dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "0f8a070e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: ./\n",
       "    Split: Train"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torchvision.datasets.MNIST('./',download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a8de03d4-0658-4a30-8158-f51211cc002a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2048"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "32*64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788b980d",
   "metadata": {},
   "source": [
    "# training EVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "b72a8856",
   "metadata": {},
   "outputs": [],
   "source": [
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "latent_dim=64\n",
    "#VAEmodel=VAE(latent_dim=latent_dim).to(device)\n",
    "EVAEmodel=EVAE(latent_dim=latent_dim,m=100).to(device)\n",
    "lr=0.0003\n",
    "epochs=50\n",
    "batch_size=100\n",
    "image_size=64\n",
    "transform=transforms.Compose([\n",
    "                                  transforms.CenterCrop(140),\n",
    "                                  transforms.Resize((image_size,image_size)),\n",
    "                                  transforms.ToTensor()\n",
    "\n",
    "                              ])\n",
    "\n",
    "#training set\n",
    "train_set=torchvision.datasets.CelebA(root='./',split='train',download=False,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)\n",
    "\n",
    "#test set\n",
    "test_set=torchvision.datasets.CelebA(root='./',split='valid',download=False,transform=transform)\n",
    "test_loader=torch.utils.data.DataLoader(test_set,batch_size=batch_size,shuffle=True)\n",
    "optimizerEVAE=optim.Adam(EVAEmodel.parameters(),lr=lr)\n",
    "\n",
    "criterion=torch.nn.BCELoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "576ebcb9-e05d-4e83-bb4f-d4bb879ae972",
   "metadata": {},
   "outputs": [],
   "source": [
    "##### grid_imagesVAE=[]\n",
    "train_lossVAE=[]\n",
    "valid_lossVAE=[]\n",
    "\n",
    "grid_imagesEVAE=[]\n",
    "train_lossEVAE=[]\n",
    "valid_lossEVAE=[]\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch{epoch+1} of {epochs}\")\n",
    "    # train_epoch_lossVAE=model_trainVAE(VAEmodel,train_loader,train_set,device,optimizerVAE,criterion)\n",
    "    # valid_epoch_lossVAE=model_validateVAE(VAEmodel,test_loader,test_set,device,optimizerVAE,criterion)\n",
    "    # train_lossVAE.append(train_epoch_lossVAE)\n",
    "    # valid_lossVAE.append(valid_epoch_lossVAE)\n",
    "    \n",
    "    train_epoch_lossEVAE,z=model_trainEVAE(EVAEmodel,train_loader,train_set,device,optimizerEVAE,criterion)\n",
    "    valid_epoch_lossEVAE,recon_imagesEVAE=model_validateEVAE(EVAEmodel,test_loader,test_set,device,optimizerEVAE,criterion)\n",
    "    train_lossEVAE.append(train_epoch_lossEVAE)\n",
    "    valid_lossEVAE.append(valid_epoch_lossEVAE)\n",
    "  #  mus.append(mu)\n",
    "   # varss.append(logvar.exp())\n",
    "#     #save images recon\n",
    "#     save_reconstructed_imagesVAE(recon_imagesVAE,epoch+1)\n",
    "    save_reconstructed_imagesEVAE(recon_imagesEVAE,epoch+1)\n",
    "#     save_reconstructed_imagesREAL(real_images,epoch+1)\n",
    "    #grid_images.append(image_grid)\n",
    "    # print(f\"train loss:{train_epoch_lossVAE:.4f}\")\n",
    "    # print(f\"valid loss:{valid_epoch_lossVAE:.4f}\")\n",
    "    print(f\"train loss:{train_epoch_lossEVAE:.4f}\")\n",
    "    print(f\"valid loss:{valid_epoch_lossEVAE:.4f}\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d848ec-b89e-4217-b316-be70176fafbe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
