{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33eecf77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing the libraries \n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d2e3d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import *\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from tqdm import tqdm\n",
    "from torchvision.utils import save_image\n",
    "to_pil_image = transforms.ToPILImage()\n",
    "def image_to_vid(images, path='../outputs/generated_images.gif'):\n",
    "    \"\"\"Write a sequence of image tensors to an animated GIF.\n",
    "\n",
    "    :param images: iterable of image tensors accepted by ``to_pil_image``\n",
    "    :param path: output GIF path (defaults to the previously hard-coded location)\n",
    "    \"\"\"\n",
    "    # imageio is not imported anywhere else in this notebook, so import it where\n",
    "    # it is needed; otherwise this call fails with NameError.\n",
    "    import imageio\n",
    "\n",
    "    frames = [np.array(to_pil_image(img)) for img in images]\n",
    "    imageio.mimsave(path, frames)\n",
    "def _save_batch(images, filename):\n",
    "    # Shared helper: bring the batch back to the CPU and write it with save_image.\n",
    "    save_image(images.cpu(), filename)\n",
    "\n",
    "def save_reconstructed_imagesVAE(recon_images, epoch):\n",
    "    \"\"\"Save a batch of VAE reconstructions as VAE<epoch>.jpg.\"\"\"\n",
    "    _save_batch(recon_images, f\"VAE{epoch}.jpg\")\n",
    "\n",
    "def save_reconstructed_imagesEVAE(recon_images, epoch):\n",
    "    \"\"\"Save a batch of EVAE reconstructions as EVAE<epoch>.jpg.\"\"\"\n",
    "    _save_batch(recon_images, f\"EVAE{epoch}.jpg\")\n",
    "\n",
    "def save_reconstructed_imagesREAL(images, epoch):\n",
    "    \"\"\"Save a batch of real (input) images as REAL<epoch>.jpg.\"\"\"\n",
    "    _save_batch(images, f\"REAL{epoch}.jpg\")\n",
    "def _plot_two_curves(first, second, first_label, second_label, ylabel, filename):\n",
    "    # Draw two per-epoch curves (orange, then red) on one figure, save it, then show it.\n",
    "    plt.figure(figsize=(10, 7))\n",
    "    for values, colour, label in ((first, 'orange', first_label), (second, 'red', second_label)):\n",
    "        plt.plot(values, color=colour, label=label)\n",
    "    plt.xlabel('Epochs')\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.legend()\n",
    "    plt.savefig(filename)\n",
    "    plt.show()\n",
    "\n",
    "def save_train_loss_plot(train_loss, valid_loss):\n",
    "    \"\"\"Plot the per-epoch training losses of both models and save to lossMNIST.jpg.\n",
    "\n",
    "    Despite its name, ``valid_loss`` is drawn as the EVAE *training* loss curve\n",
    "    (label 'train loss EVAE'); ``train_loss`` is drawn as the VAE training loss.\n",
    "    \"\"\"\n",
    "    _plot_two_curves(train_loss, valid_loss,\n",
    "                     'train loss VAE', 'train loss EVAE',\n",
    "                     'Total loss (Reconstruction+B*I(K))/M', 'lossMNIST.jpg')\n",
    "\n",
    "def save_valid_loss_plot(valid_lossVAE, valid_lossEVAE):\n",
    "    \"\"\"Plot the per-epoch validation reconstruction losses of both models.\n",
    "\n",
    "    The figure is saved to lossMNIST_valid.jpg and then shown.\n",
    "    \"\"\"\n",
    "    _plot_two_curves(valid_lossVAE, valid_lossEVAE,\n",
    "                     'validation loss VAE', 'validation loss EVAE',\n",
    "                     'Reconstruction loss', 'lossMNIST_valid.jpg')\n",
    "\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d7c481",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "83637226",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    \"\"\"Convolutional Gaussian VAE used as the baseline model.\n",
    "\n",
    "    With the default kernel_size=4 the encoder maps a 32x32 input to 16x16, 8x8 and\n",
    "    4x4 (stride-2, padding-1 convs), then to 1x1 with a padding-0 conv, so the linear\n",
    "    heads see init_channel*8 features (assumes 32x32 inputs - TODO confirm for other\n",
    "    sizes). The decoder maps a (latent_dim, 1, 1) tensor back to\n",
    "    image_channel x 32 x 32 and ends in a Sigmoid, so outputs lie in [0, 1].\n",
    "\n",
    "    :param image_channel: number of input/output image channels\n",
    "    :param kernel_size: kernel size of every (transposed) convolution\n",
    "    :param latent_dim: size of the latent code\n",
    "    :param init_channel: channel width of the first encoder conv (doubled per stage)\n",
    "    :param m: only used to compute ``self.bm``, which this class never reads\n",
    "    \"\"\"\n",
    "    def __init__(self,image_channel=3,kernel_size=4,latent_dim=64,init_channel=32,m=100):\n",
    "        super(VAE, self).__init__()\n",
    "\n",
    "        # b(m) = m^(-2/9), same as EVAE.bm; not read anywhere in this class\n",
    "        self.bm=np.power(m,-2/9)\n",
    "        self.latent_dim=latent_dim\n",
    "        # encoder: image -> (batch, init_channel*8) feature vector\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel,out_channels=init_channel,kernel_size=kernel_size,stride=2,padding=1),            \n",
    "            nn.BatchNorm2d(init_channel),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.Conv2d(init_channel*2,out_channels=init_channel*4,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            # padding=0: takes the 4x4 map (32x32 input, default kernel) down to 1x1\n",
    "            nn.Conv2d(init_channel*4,out_channels=init_channel*8,kernel_size=kernel_size,stride=2,padding=0)   ,            \n",
    "            nn.BatchNorm2d(init_channel*8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "            #nn.AdaptiveAvgPool2d((100,init_channel*8)),\n",
    "            nn.Flatten( start_dim=1)\n",
    "            \n",
    "            #nn.Linear(init_channel*8,init_channel*8) #hidden\n",
    "            \n",
    "        )\n",
    "        \n",
    "        \n",
    "\n",
    "        # linear heads producing the posterior mean and log-variance\n",
    "        self.fully_connected_layer_a=nn.Linear(init_channel*8,latent_dim)\n",
    "        self.fully_connected_layer_b=nn.Linear(init_channel*8,latent_dim)\n",
    "\n",
    "        ## decoder\n",
    "        \n",
    "                # decoder: (latent_dim, 1, 1) -> (image_channel, 32, 32) with the default kernel_size=4\n",
    "        self.decoder = nn.Sequential(\n",
    "            \n",
    "            \n",
    "            \n",
    "            nn.ConvTranspose2d(self.latent_dim,out_channels=init_channel*4,kernel_size=kernel_size,stride=1,padding=0),            \n",
    "            nn.BatchNorm2d(init_channel*4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*4,out_channels=init_channel*2,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*2,out_channels=init_channel*1,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.BatchNorm2d(init_channel*1) ,\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            \n",
    "\n",
    "            nn.ConvTranspose2d(init_channel*1,out_channels=image_channel,kernel_size=kernel_size,stride=2,padding=1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        \n",
    "    def reparametrize(self,mu,log_var):\n",
    "        \"\"\"Gaussian reparameterisation trick.\n",
    "\n",
    "        :param mu: posterior mean\n",
    "        :param log_var: posterior log-variance (same shape as mu)\n",
    "        :return: mu + eps * exp(0.5 * log_var), with eps drawn by torch.randn_like\n",
    "        \"\"\"\n",
    "        \n",
    "        std=torch.exp(0.5*log_var)\n",
    "        epsilon=torch.randn_like(std)\n",
    "        sample=mu+epsilon*std  # z = mu + eps * std, eps ~ N(0, I)\n",
    "        return sample        \n",
    " \n",
    "\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Encode ``x``, draw a Gaussian latent sample and decode it.\n",
    "\n",
    "        :return: ``(reconstruction, mean, log_var)``; ``reconstruction`` comes from a\n",
    "            Sigmoid, so its values lie in [0, 1]\n",
    "        \"\"\"\n",
    "        #encoding\n",
    "\n",
    "        hidden =self.encoder(x)\n",
    "        \n",
    "        #hidden=torch.flatten(x, start_dim=1)\n",
    "     \n",
    "        mean=self.fully_connected_layer_a(hidden)\n",
    "        log_var=self.fully_connected_layer_b(hidden)\n",
    "\n",
    "        z=self.reparametrize(mean,log_var) # Gaussian sample: mu + eps*std, eps ~ N(0, I)\n",
    "        # reshape to (batch, latent_dim, 1, 1) for the first ConvTranspose2d\n",
    "        z=z.view(-1,self.latent_dim,1,1)\n",
    "        #decoding\n",
    "        reconstruction=self.decoder(z)\n",
    "\n",
    "        return reconstruction,mean,log_var "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a44dcb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossVAE(bce_loss, mu, logvar):\n",
    "    \"\"\"\n",
    "    Add the reconstruction loss and the KL divergence to a standard normal prior.\n",
    "\n",
    "    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "\n",
    "    The KL term is summed over every element (batch and latent dimensions), so it\n",
    "    is meant to be paired with a summed reconstruction loss such as\n",
    "    BCELoss(reduction='sum').\n",
    "\n",
    "    :param bce_loss: reconstruction loss (scalar tensor)\n",
    "    :param mu: the mean from the latent vector\n",
    "    :param logvar: log variance from the latent vector (same shape as mu)\n",
    "    :return: scalar tensor bce_loss + KL divergence\n",
    "    \"\"\"\n",
    "    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "    return bce_loss + kld"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bc27c73f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-squared-error reconstruction criterion (default 'mean' reduction).\n",
    "# NOTE(review): not referenced by the visible training code, which uses the\n",
    "# summed BCE `criterion` instead - confirm it is still needed.\n",
    "mse_cost_function = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6203b932",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_trainVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Run one training epoch of the Gaussian VAE.\n",
    "\n",
    "    :param model: VAE returning ``(reconstruction, mean, log_var)``\n",
    "    :param dataloader: iterable of batches whose first element is the image tensor\n",
    "    :param dataset: unused; kept for call-site compatibility\n",
    "    :param device: device the image batches are moved to\n",
    "    :param optimizer: optimizer over ``model``'s parameters\n",
    "    :param criterion: reconstruction loss, e.g. summed BCE\n",
    "    :return: mean per-batch value of ``final_lossVAE`` (reconstruction + KL)\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    counter = 0\n",
    "    # len(dataloader) also counts a final partial batch, unlike len(dataset) / batch_size.\n",
    "    for data in tqdm(dataloader, total=len(dataloader)):\n",
    "        counter += 1\n",
    "        images = data[0].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction, mean, log_var = model(images)\n",
    "        reconstruction_error = criterion(reconstruction, images)\n",
    "        loss = final_lossVAE(reconstruction_error, mean, log_var)\n",
    "        loss.backward()\n",
    "        running_loss += loss.item()\n",
    "        optimizer.step()\n",
    "    return running_loss / counter\n",
    "\n",
    "def model_validateVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Evaluate the VAE's reconstruction loss (no KL term) without updating it.\n",
    "\n",
    "    :param model: VAE returning ``(reconstruction, mean, log_var)``\n",
    "    :param dataloader: iterable of batches whose first element is the image tensor\n",
    "    :param dataset: unused; kept for call-site compatibility\n",
    "    :param device: device the image batches are moved to\n",
    "    :param optimizer: unused; kept for call-site compatibility\n",
    "    :param criterion: reconstruction loss, e.g. summed BCE\n",
    "    :return: mean per-batch reconstruction loss\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    counter = 0\n",
    "    # Nothing is backpropagated here, so no optimizer.zero_grad() is needed.\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(dataloader, total=len(dataloader)):\n",
    "            counter += 1\n",
    "            images = data[0].to(device)\n",
    "            reconstruction, _mean, _log_var = model(images)\n",
    "            running_loss += criterion(reconstruction, images).item()\n",
    "    return running_loss / counter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7aa3fc3d-fd36-4719-954c-bc98eac0f8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualLayer(nn.Module):\n",
    "    \"\"\"Residual block computing ``input + conv1x1(relu(conv3x3(input)))``.\n",
    "\n",
    "    Both convolutions are bias-free, and the 3x3 conv uses padding=1, so the spatial\n",
    "    size is preserved. The skip connection adds the block output to its input, so\n",
    "    the two must be broadcast-compatible (normally in_channels == out_channels).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels: int, out_channels: int):\n",
    "        super().__init__()\n",
    "        expand = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)\n",
    "        project = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)\n",
    "        # The attribute name `resblock` and the layer order fix the state_dict keys\n",
    "        # (resblock.0.weight, resblock.2.weight); keep them stable.\n",
    "        self.resblock = nn.Sequential(expand, nn.ReLU(True), project)\n",
    "\n",
    "    def forward(self, input):\n",
    "        shortcut = input\n",
    "        return shortcut + self.resblock(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e816c1d",
   "metadata": {},
   "source": [
    "## EVAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f379ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EVAE(nn.Module):\n",
    "    \"\"\"Convolutional autoencoder with a bounded, kernel-based latent sampler (EVAE).\n",
    "\n",
    "    Encoder geometry matches ``VAE``: with the default kernel_size=4 a 32x32 input\n",
    "    is reduced to a 1x1 map with init_channel*8 channels before the linear heads\n",
    "    (assumes 32x32 inputs - TODO confirm for other sizes). Unlike ``VAE``, the\n",
    "    decoder starts with a Linear layer and uses bias-free transposed convs.\n",
    "\n",
    "    :param image_channel: number of input/output image channels\n",
    "    :param kernel_size: kernel size of every (transposed) convolution\n",
    "    :param latent_dim: size of the latent code\n",
    "    :param init_channel: channel width of the first encoder conv (doubled per stage)\n",
    "    :param m: sets the bandwidth coefficient ``self.bm = m ** (-2/9)``\n",
    "    :param B: stored as ``self.B``; not used by the sampling computation\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_channel=3, kernel_size=4, latent_dim=64, init_channel=32, m=100, B=1):\n",
    "        super(EVAE, self).__init__()\n",
    "\n",
    "        # Bandwidth b(m) = m^(-2/9) that scales the kernel term in reparametrize.\n",
    "        self.bm = np.power(m, -2 / 9)\n",
    "        self.B = B\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        # encoder: image -> (batch, init_channel*8) feature vector\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channel, out_channels=init_channel, kernel_size=kernel_size, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(init_channel),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.Conv2d(init_channel, out_channels=init_channel * 2, kernel_size=kernel_size, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(init_channel * 2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.Conv2d(init_channel * 2, out_channels=init_channel * 4, kernel_size=kernel_size, stride=2, padding=1),\n",
    "            nn.BatchNorm2d(init_channel * 4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            # padding=0: takes the 4x4 map (32x32 input, default kernel) down to 1x1\n",
    "            nn.Conv2d(init_channel * 4, out_channels=init_channel * 8, kernel_size=kernel_size, stride=2, padding=0),\n",
    "            nn.BatchNorm2d(init_channel * 8),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.Flatten(start_dim=1),\n",
    "        )\n",
    "\n",
    "        # linear heads producing the latent location (mean) and log-scale (log_r)\n",
    "        self.fully_connected_layer_a = nn.Linear(init_channel * 8, latent_dim)\n",
    "        self.fully_connected_layer_b = nn.Linear(init_channel * 8, latent_dim)\n",
    "\n",
    "        # decoder: (batch, latent_dim) -> (batch, image_channel, 32, 32) for the\n",
    "        # default kernel_size=4, with Sigmoid output in [0, 1]\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(latent_dim, init_channel * 8 * 1 * 1),\n",
    "            nn.Unflatten(dim=-1, unflattened_size=(init_channel * 8, 1, 1)),  # 1 x 1\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel * 8, out_channels=init_channel * 4, kernel_size=kernel_size, stride=1, padding=0, bias=False),\n",
    "            nn.BatchNorm2d(init_channel * 4),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel * 4, out_channels=init_channel * 2, kernel_size=kernel_size, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(init_channel * 2),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel * 2, out_channels=init_channel, kernel_size=kernel_size, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(init_channel),\n",
    "            nn.ReLU(inplace=True),\n",
    "\n",
    "            nn.ConvTranspose2d(init_channel, out_channels=image_channel, kernel_size=kernel_size, stride=2, padding=1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def reparametrize(self, mu, log_r):\n",
    "        \"\"\"Draw a latent sample ``bm * (scale * k + mu) + scale * u``.\n",
    "\n",
    "        Here ``scale = exp(0.5 * log_r)``, ``u`` is uniform on [-1, 1) elementwise\n",
    "        (the uniform prior), and ``k`` is the median of three independent U(-1, 1)\n",
    "        draws, which follows the Epanechnikov kernel 3/4 * (1 - k^2) on [-1, 1].\n",
    "\n",
    "        :param mu: location tensor, broadcastable to the shape of ``log_r``\n",
    "        :param log_r: 2-D tensor (n, d); ``exp(0.5 * log_r)`` is the scale\n",
    "        :return: sample tensor of shape (n, d) on ``log_r``'s device\n",
    "        \"\"\"\n",
    "        n, d = log_r.size(0), log_r.size(1)\n",
    "        # Use the input's device rather than the notebook-global `device`, which is\n",
    "        # defined in a later cell (NameError on a fresh kernel or outside the\n",
    "        # notebook) and can disagree with where the model lives. Draws still come\n",
    "        # from the CPU generator and are then moved, so seeded runs reproduce.\n",
    "        uniform_k = torch.rand(n * d, 3).to(log_r.device)\n",
    "        U_k = 2 * uniform_k - 1\n",
    "        uniform_p = torch.rand(n, d).to(log_r.device)\n",
    "        U_p = 2 * uniform_p - 1  # uniform prior sample in [-1, 1)\n",
    "        median = torch.median(U_k, 1)[0].reshape(n, d)  # sample from the E kernel\n",
    "        scale = torch.exp(0.5 * log_r)\n",
    "        return self.bm * (scale * median + mu) + U_p * scale\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode ``x``, draw a latent sample and decode it.\n",
    "\n",
    "        :param x: image batch (assumes (batch, image_channel, 32, 32) - TODO confirm)\n",
    "        :return: ``(reconstruction, mean, log_r, zs)`` where ``zs`` is the latent sample\n",
    "        \"\"\"\n",
    "        hidden = self.encoder(x)\n",
    "        mean = self.fully_connected_layer_a(hidden)\n",
    "        log_r = self.fully_connected_layer_b(hidden)\n",
    "        zs = self.reparametrize(mean, log_r)\n",
    "        reconstruction = self.decoder(zs)\n",
    "        return reconstruction, mean, log_r, zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4c602b50-8be1-4940-9aa7-836a3ab633ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([44, 58, 37, 21, 60, 31, 76, 67, 43, 88, 66, 27, 23, 30, 14,  3, 19, 97,\n",
       "         7, 73,  3, 19, 94,  2, 73, 21, 64, 79, 28, 41, 70, 16, 59, 67, 78, 30,\n",
       "        66, 77, 53, 90, 31, 54, 28, 99, 13, 77, 18, 59, 71, 46, 91, 39,  6,  8,\n",
       "        57, 25, 74, 69, 80, 79, 40, 93, 78, 67, 66, 19, 92, 34, 72, 45, 77, 17,\n",
       "        98, 14, 78, 63,  4, 87, 40, 53, 29, 63, 62,  4, 51, 56, 28, 80, 68,  3,\n",
       "        82, 79,  6, 70, 58, 13, 51, 26, 74, 94])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check of torch.randint: 100 random indices in [0, 100); the result is only displayed.\n",
    "torch.randint(low=0, high=100, size=(100,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "79c70590",
   "metadata": {},
   "outputs": [],
   "source": [
    "def final_lossEVAE(bce_loss,log_r):\n",
    "    \"\"\"Total EVAE training loss.\n",
    "\n",
    "    Currently this is just the reconstruction loss: ``log_r`` is accepted but\n",
    "    ignored, so no regulariser is added.\n",
    "    \"\"\"\n",
    "    return bce_loss\n",
    "\n",
    "def model_trainEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Run one EVAE training epoch.\n",
    "\n",
    "    :param model: EVAE returning ``(reconstruction, mean, log_r, z)``\n",
    "    :param dataloader: iterable of batches whose first element is the image tensor\n",
    "    :param dataset: unused; kept for call-site compatibility\n",
    "    :param device: device the image batches are moved to\n",
    "    :param optimizer: optimizer over ``model``'s parameters\n",
    "    :param criterion: reconstruction loss, e.g. summed BCE\n",
    "    :return: ``(train_loss, z)`` - mean per-batch loss and the latent sample of the last batch\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    counter = 0\n",
    "    z = None\n",
    "    # len(dataloader) also counts a final partial batch, unlike len(dataset) / batch_size.\n",
    "    for data in tqdm(dataloader, total=len(dataloader)):\n",
    "        counter += 1\n",
    "        images = data[0].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        reconstruction, _, log_r, z = model(images)\n",
    "        loss = final_lossEVAE(criterion(reconstruction, images), log_r)\n",
    "        loss.backward()\n",
    "        running_loss += loss.item()\n",
    "        optimizer.step()\n",
    "    return running_loss / counter, z\n",
    "\n",
    "def model_validateEVAE(model,dataloader,dataset,device,optimizer,criterion):\n",
    "    \"\"\"Evaluate the EVAE's reconstruction loss without updating it.\n",
    "\n",
    "    :param model: EVAE whose first output is the reconstruction\n",
    "    :param dataloader: iterable of batches whose first element is the image tensor\n",
    "    :param dataset: unused; kept for call-site compatibility\n",
    "    :param device: device the image batches are moved to\n",
    "    :param optimizer: unused; kept for call-site compatibility\n",
    "    :param criterion: reconstruction loss, e.g. summed BCE\n",
    "    :return: ``(valid_loss, recon_images)`` - mean per-batch reconstruction loss and\n",
    "        the reconstructions of the last batch yielded by ``dataloader``\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    counter = 0\n",
    "    recon_images = None\n",
    "    # Nothing is backpropagated here, so no optimizer.zero_grad() is needed.\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(dataloader, total=len(dataloader)):\n",
    "            counter += 1\n",
    "            images = data[0].to(device)\n",
    "            reconstruction, *_ = model(images)\n",
    "            running_loss += criterion(reconstruction, images).item()\n",
    "            # Keep the most recent batch unconditionally. An index test such as\n",
    "            # i == int(len(dataset) / batch_size) - 1 never matches when the dataset\n",
    "            # is smaller than one batch (UnboundLocalError on return), and it picks\n",
    "            # the second-to-last batch when the final batch is partial.\n",
    "            recon_images = reconstruction\n",
    "    return running_loss / counter, recon_images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a445ce36",
   "metadata": {},
   "source": [
    "## dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0f8a070e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset MNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: ./\n",
       "    Split: Train"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): MNIST is downloaded here, but the training-setup cell loads CIFAR10,\n",
    "# so this download appears unused - confirm before removing.\n",
    "torchvision.datasets.MNIST(root='./', download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a8de03d4-0658-4a30-8158-f51211cc002a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2048"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch arithmetic; the result (2048) is only displayed.\n",
    "64 * 32"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788b980d",
   "metadata": {},
   "source": [
    "# training EVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b72a8856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed before building the models so weight init and shuffling are reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "latent_dim=64\n",
    "VAEmodel=VAE(latent_dim=latent_dim).to(device)\n",
    "EVAEmodel=EVAE(latent_dim=latent_dim,m=100).to(device)\n",
    "lr=0.0003\n",
    "epochs=100\n",
    "batch_size=100\n",
    "\n",
    "# Both models assume 32x32 inputs; CIFAR10 is already 32x32, so Resize is a safeguard.\n",
    "transform=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])\n",
    "\n",
    "# download=True is a no-op when the files already exist. It avoids a RuntimeError\n",
    "# ('Dataset not found') on a fresh machine, because the only download cell in\n",
    "# this notebook fetches MNIST, not CIFAR10.\n",
    "train_set=torchvision.datasets.CIFAR10(root='./',train=True,download=True,transform=transform)\n",
    "train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True)\n",
    "\n",
    "# Held-out set: evaluation does not need shuffling, and a fixed order makes the\n",
    "# last-batch reconstructions returned by model_validateEVAE comparable across epochs.\n",
    "test_set=torchvision.datasets.CIFAR10(root='./',train=False,download=True,transform=transform)\n",
    "test_loader=torch.utils.data.DataLoader(test_set,batch_size=batch_size,shuffle=False)\n",
    "optimizerVAE=optim.Adam(VAEmodel.parameters(),lr=lr)\n",
    "optimizerEVAE=optim.Adam(EVAEmodel.parameters(),lr=lr)\n",
    "criterion=nn.BCELoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "576ebcb9-e05d-4e83-bb4f-d4bb879ae972",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch1 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.45it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 104.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:188455.0131\n",
      "valid loss:181314.3730\n",
      "Epoch2 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.46it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 92.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:179446.5490\n",
      "valid loss:178234.6622\n",
      "Epoch3 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.66it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 111.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:177815.2917\n",
      "valid loss:177158.9847\n",
      "Epoch4 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.73it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:176987.5057\n",
      "valid loss:176518.9950\n",
      "Epoch5 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.35it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:176438.1384\n",
      "valid loss:176080.4492\n",
      "Epoch6 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.92it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:176099.0556\n",
      "valid loss:175793.2373\n",
      "Epoch7 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.46it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:175779.8094\n",
      "valid loss:175497.1245\n",
      "Epoch8 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.94it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:175560.4718\n",
      "valid loss:175383.0811\n",
      "Epoch9 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.68it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:175392.7302\n",
      "valid loss:175122.7870\n",
      "Epoch10 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.93it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:175239.6610\n",
      "valid loss:175025.5650\n",
      "Epoch11 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.14it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:175144.9624\n",
      "valid loss:175286.7436\n",
      "Epoch12 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.96it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 107.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:175013.1028\n",
      "valid loss:174888.9116\n",
      "Epoch13 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.30it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174950.2812\n",
      "valid loss:174871.0650\n",
      "Epoch14 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.33it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 103.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174844.9848\n",
      "valid loss:174830.8973\n",
      "Epoch15 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.03it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174808.8834\n",
      "valid loss:174678.0725\n",
      "Epoch16 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 71.09it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 112.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174751.7188\n",
      "valid loss:174769.1539\n",
      "Epoch17 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.00it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 111.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174695.3190\n",
      "valid loss:174639.9897\n",
      "Epoch18 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.64it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174636.2098\n",
      "valid loss:174641.5517\n",
      "Epoch19 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.85it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174646.4702\n",
      "valid loss:174560.1177\n",
      "Epoch20 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.67it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 106.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174582.5947\n",
      "valid loss:174685.7728\n",
      "Epoch21 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.14it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 106.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174541.5506\n",
      "valid loss:174577.8178\n",
      "Epoch22 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.40it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174554.6698\n",
      "valid loss:174557.5142\n",
      "Epoch23 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 74.02it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174495.8539\n",
      "valid loss:174522.1761\n",
      "Epoch24 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.17it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174490.3932\n",
      "valid loss:174487.6522\n",
      "Epoch25 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.57it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 111.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174477.1673\n",
      "valid loss:174478.2516\n",
      "Epoch26 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.34it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174458.4321\n",
      "valid loss:174451.6514\n",
      "Epoch27 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.58it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174442.7664\n",
      "valid loss:174452.0578\n",
      "Epoch28 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.95it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 105.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174420.6896\n",
      "valid loss:174507.9244\n",
      "Epoch29 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.24it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174406.6846\n",
      "valid loss:174386.8898\n",
      "Epoch30 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.94it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 111.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174393.4546\n",
      "valid loss:174390.5298\n",
      "Epoch31 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.32it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174365.0909\n",
      "valid loss:174379.3644\n",
      "Epoch32 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.73it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 103.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174329.1880\n",
      "valid loss:174393.1827\n",
      "Epoch33 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.53it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174334.8940\n",
      "valid loss:174371.7450\n",
      "Epoch34 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.31it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 97.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174335.3559\n",
      "valid loss:174354.0523\n",
      "Epoch35 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.96it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 98.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174321.2847\n",
      "valid loss:174420.3664\n",
      "Epoch36 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 74.59it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 107.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174288.0684\n",
      "valid loss:174362.3386\n",
      "Epoch37 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 76.25it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 94.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174270.6497\n",
      "valid loss:174401.5316\n",
      "Epoch38 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 75.78it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174274.9150\n",
      "valid loss:174401.2938\n",
      "Epoch39 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 75.96it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 117.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174255.6521\n",
      "valid loss:174345.0761\n",
      "Epoch40 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.80it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 104.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174243.1688\n",
      "valid loss:174345.4117\n",
      "Epoch41 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.83it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 114.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174226.0179\n",
      "valid loss:174327.7959\n",
      "Epoch42 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 74.85it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 112.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174221.0728\n",
      "valid loss:174329.7487\n",
      "Epoch43 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.96it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 114.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174227.3796\n",
      "valid loss:174391.5306\n",
      "Epoch44 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.32it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 114.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174205.1733\n",
      "valid loss:174300.9239\n",
      "Epoch45 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.72it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174209.8746\n",
      "valid loss:174298.9892\n",
      "Epoch46 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.59it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 102.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174189.4818\n",
      "valid loss:174329.3256\n",
      "Epoch47 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 74.01it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 112.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174186.6598\n",
      "valid loss:174331.7967\n",
      "Epoch48 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 76.77it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 120.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174156.1663\n",
      "valid loss:174301.2119\n",
      "Epoch49 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 74.00it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 116.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174146.6947\n",
      "valid loss:174330.8509\n",
      "Epoch50 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.71it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 112.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174140.3708\n",
      "valid loss:174279.0108\n",
      "Epoch51 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 71.43it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 114.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174143.6925\n",
      "valid loss:174254.1841\n",
      "Epoch52 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.66it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 104.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174161.8693\n",
      "valid loss:174315.2956\n",
      "Epoch53 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.94it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 111.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174136.3694\n",
      "valid loss:174275.3570\n",
      "Epoch54 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.88it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 98.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174126.6018\n",
      "valid loss:174262.9075\n",
      "Epoch55 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.44it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174112.9940\n",
      "valid loss:174385.6688\n",
      "Epoch56 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 71.37it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 106.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174099.8108\n",
      "valid loss:174286.0098\n",
      "Epoch57 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.14it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 107.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174106.7857\n",
      "valid loss:174261.0872\n",
      "Epoch58 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.48it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 107.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174084.6867\n",
      "valid loss:174250.1920\n",
      "Epoch59 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 67.08it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 107.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174087.8121\n",
      "valid loss:174266.6820\n",
      "Epoch60 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.06it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 97.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174073.8626\n",
      "valid loss:174284.1519\n",
      "Epoch61 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.79it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 107.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174081.5637\n",
      "valid loss:174238.6802\n",
      "Epoch62 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 73.66it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174087.8608\n",
      "valid loss:174263.9286\n",
      "Epoch63 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 71.05it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174041.3238\n",
      "valid loss:174334.8428\n",
      "Epoch64 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 67.68it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 101.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174062.4528\n",
      "valid loss:174257.0795\n",
      "Epoch65 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 67.28it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 105.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174055.4136\n",
      "valid loss:174267.3197\n",
      "Epoch66 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.00it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 94.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174039.8664\n",
      "valid loss:174262.2633\n",
      "Epoch67 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 67.91it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 104.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174037.3165\n",
      "valid loss:174258.6908\n",
      "Epoch68 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.90it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 100.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174032.3724\n",
      "valid loss:174248.7409\n",
      "Epoch69 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.24it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 112.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174046.5757\n",
      "valid loss:174262.5478\n",
      "Epoch70 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.80it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 98.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174020.0746\n",
      "valid loss:174285.7395\n",
      "Epoch71 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.11it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174025.9421\n",
      "valid loss:174222.8684\n",
      "Epoch72 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.77it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 97.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174033.6541\n",
      "valid loss:174252.9833\n",
      "Epoch73 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 65.09it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 96.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174025.0917\n",
      "valid loss:174254.9916\n",
      "Epoch74 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 67.78it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 102.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174006.8483\n",
      "valid loss:174227.6006\n",
      "Epoch75 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 66.93it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 105.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174006.3233\n",
      "valid loss:174238.7289\n",
      "Epoch76 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.12it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 105.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173981.6122\n",
      "valid loss:174334.2306\n",
      "Epoch77 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.88it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 91.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:174003.7254\n",
      "valid loss:174252.2981\n",
      "Epoch78 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 68.15it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 99.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173988.1699\n",
      "valid loss:174282.6994\n",
      "Epoch79 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 66.42it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 101.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173987.1657\n",
      "valid loss:174245.9647\n",
      "Epoch80 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 65.86it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 100.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173983.1955\n",
      "valid loss:174397.3011\n",
      "Epoch81 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 66.46it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 104.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173978.8248\n",
      "valid loss:174276.9877\n",
      "Epoch82 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 65.93it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173983.8967\n",
      "valid loss:174258.7502\n",
      "Epoch83 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.01it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 101.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173972.8204\n",
      "valid loss:174254.9505\n",
      "Epoch84 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 66.94it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173961.4702\n",
      "valid loss:174228.7022\n",
      "Epoch85 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.61it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173968.5227\n",
      "valid loss:174226.0567\n",
      "Epoch86 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.98it/s]\n",
      "100%|██████████| 100/100 [00:01<00:00, 91.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173963.2603\n",
      "valid loss:174258.7427\n",
      "Epoch87 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.52it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173956.1121\n",
      "valid loss:174228.1398\n",
      "Epoch88 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.11it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 100.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173941.5230\n",
      "valid loss:174247.8734\n",
      "Epoch89 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 69.39it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173937.4407\n",
      "valid loss:174322.2678\n",
      "Epoch90 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.69it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173942.5482\n",
      "valid loss:174261.8722\n",
      "Epoch91 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 71.20it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173929.0306\n",
      "valid loss:174220.5952\n",
      "Epoch92 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 71.34it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173943.2765\n",
      "valid loss:174230.8106\n",
      "Epoch93 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.94it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 104.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173934.6421\n",
      "valid loss:174401.2134\n",
      "Epoch94 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.59it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 108.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173927.8533\n",
      "valid loss:174249.1427\n",
      "Epoch95 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.38it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173930.2808\n",
      "valid loss:174304.7605\n",
      "Epoch96 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.81it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 110.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173924.4072\n",
      "valid loss:174248.7078\n",
      "Epoch97 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 72.95it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173922.9222\n",
      "valid loss:174248.6572\n",
      "Epoch98 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.32it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 105.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173914.9089\n",
      "valid loss:174237.7161\n",
      "Epoch99 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:06<00:00, 71.86it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 111.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173912.7696\n",
      "valid loss:174242.8811\n",
      "Epoch100 of 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:07<00:00, 70.58it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 109.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss:173908.6643\n",
      "valid loss:174243.8378\n"
     ]
    }
   ],
   "source": [
    "# Per-epoch loss history for the baseline VAE and the EVAE.\n",
    "# The VAE lists stay empty unless TRAIN_VAE is enabled below.\n",
    "grid_imagesVAE=[]\n",
    "train_lossVAE=[]\n",
    "valid_lossVAE=[]\n",
    "\n",
    "grid_imagesEVAE=[]\n",
    "train_lossEVAE=[]\n",
    "valid_lossEVAE=[]\n",
    "\n",
    "# Toggles replacing previously commented-out code paths.\n",
    "TRAIN_VAE = False          # also train/validate the baseline VAE each epoch\n",
    "SAVE_RECON_IMAGES = False  # write EVAE reconstructions to EVAE<epoch>.jpg\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch{epoch+1} of {epochs}\")\n",
    "    if TRAIN_VAE:\n",
    "        train_epoch_lossVAE=model_trainVAE(VAEmodel,train_loader,train_set,device,optimizerVAE,criterion)\n",
    "        valid_epoch_lossVAE=model_validateVAE(VAEmodel,test_loader,test_set,device,optimizerVAE,criterion)\n",
    "        train_lossVAE.append(train_epoch_lossVAE)\n",
    "        valid_lossVAE.append(valid_epoch_lossVAE)\n",
    "        print(f\"train loss VAE:{train_epoch_lossVAE:.4f}\")\n",
    "        print(f\"valid loss VAE:{valid_epoch_lossVAE:.4f}\")\n",
    "\n",
    "    # z is not used in this cell but stays at module scope in case later cells read it.\n",
    "    train_epoch_lossEVAE,z=model_trainEVAE(EVAEmodel,train_loader,train_set,device,optimizerEVAE,criterion)\n",
    "    valid_epoch_lossEVAE,recon_imagesEVAE=model_validateEVAE(EVAEmodel,test_loader,test_set,device,optimizerEVAE,criterion)\n",
    "    train_lossEVAE.append(train_epoch_lossEVAE)\n",
    "    valid_lossEVAE.append(valid_epoch_lossEVAE)\n",
    "    if SAVE_RECON_IMAGES:\n",
    "        save_reconstructed_imagesEVAE(recon_imagesEVAE,epoch+1)\n",
    "    print(f\"train loss:{train_epoch_lossEVAE:.4f}\")\n",
    "    print(f\"valid loss:{valid_epoch_lossEVAE:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "11c0a6d9-4116-4c06-8877-e1762b9cd3eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Data ------------------------------------------------------------\n",
    "dataroot = \"data/celeba\"  # root directory of the dataset\n",
    "workers = 2                # DataLoader worker processes\n",
    "batch_size = 100           # samples per training batch\n",
    "image_size = 64            # images are resized to image_size x image_size\n",
    "# NOTE(review): RGB images have 3 channels; nc = 4 is presumably\n",
    "# deliberate for this data/model (e.g. an extra channel) - confirm.\n",
    "nc = 4                     # channels per training image\n",
    "\n",
    "# ---- Model sizes -----------------------------------------------------\n",
    "nz = 64                    # length of the latent vector z (generator input)\n",
    "ngf = 64                   # base feature-map width of the generator\n",
    "ndf = 64                   # base feature-map width of the discriminator\n",
    "\n",
    "# ---- Optimisation ----------------------------------------------------\n",
    "# NOTE(review): the EVAE training loop iterates over `epochs`, not `num_epochs`.\n",
    "num_epochs = 5\n",
    "lr = 1e-4                  # learning rate for the Adam optimizers\n",
    "beta1 = 0.5                # Adam beta1\n",
    "\n",
    "# ---- Hardware --------------------------------------------------------\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "ngpu = 1                   # number of GPUs to use; 0 means CPU mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "50a0e6cb-87ba-4d80-9421-095c9524afee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom weights initialization: a per-module callback, e.g. netG.apply(weights_init)\n",
    "# (intended for ``netG`` and ``netD``).\n",
    "def weights_init(m):\n",
    "    \"\"\"DCGAN-style initialization for a single module.\n",
    "\n",
    "    * Conv* (incl. ConvTranspose*): weight ~ N(0, 0.02); bias left untouched.\n",
    "    * BatchNorm*: weight ~ N(1, 0.02), bias = 0.\n",
    "\n",
    "    Matching is by class name. A matching module that lacks the parameter\n",
    "    (e.g. a custom ``ConvBlock`` container, or ``BatchNorm`` built with\n",
    "    ``affine=False`` whose weight/bias are ``None``) is skipped instead of\n",
    "    raising ``AttributeError``.\n",
    "    \"\"\"\n",
    "    classname = m.__class__.__name__\n",
    "    weight = getattr(m, 'weight', None)\n",
    "    bias = getattr(m, 'bias', None)\n",
    "    if 'Conv' in classname:\n",
    "        if weight is not None:\n",
    "            nn.init.normal_(weight.data, 0.0, 0.02)\n",
    "    elif 'BatchNorm' in classname:\n",
    "        if weight is not None:\n",
    "            nn.init.normal_(weight.data, 1.0, 0.02)\n",
    "        if bias is not None:\n",
    "            nn.init.constant_(bias.data, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "51c5cc94-9aa2-491c-b2a7-01de8948075a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualLayer(nn.Module):\n",
    "    \"\"\"Two-layer MLP block with an identity skip connection: x + MLP(x).\n",
    "\n",
    "    The skip adds the input to the block output unchanged, so the output width\n",
    "    must equal the input width.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``in_channels != out_channels``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_channels: int,\n",
    "                 out_channels: int):\n",
    "        super(ResidualLayer, self).__init__()\n",
    "        if in_channels != out_channels:\n",
    "            # Without this check, a width of 1 on either side would silently\n",
    "            # broadcast in forward() instead of failing.\n",
    "            raise ValueError(\n",
    "                f\"ResidualLayer needs in_channels == out_channels for the identity skip connection, got {in_channels} and {out_channels}\")\n",
    "        self.resblock = nn.Sequential(nn.Linear(in_channels, out_channels, bias=True),\n",
    "                                      nn.GELU(),\n",
    "                                      nn.Linear(out_channels, out_channels, bias=True))\n",
    "\n",
    "    def forward(self, input):\n",
    "        return input + self.resblock(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "470f4829-c764-4813-9208-4ab14618e0a1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "30669492-9c8b-41f7-afce-9cafea6fbbe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator code: an MLP mapping Gaussian noise to a latent vector that the\n",
    "# discriminator scores against real latents in the training loop.\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"MLP generator: noise of width ``in_dim`` -> latent vector of width ``out_dim``.\n",
    "\n",
    "    Args:\n",
    "        ngpu: number of GPUs (stored only; multi-GPU wrapping is done by the caller).\n",
    "        in_dim: noise width; defaults to the global ``nz`` at construction time.\n",
    "        out_dim: output width; defaults to the global ``n_latent``.\n",
    "        hidden: base hidden width; defaults to the global ``ngf``.\n",
    "    \"\"\"\n",
    "    def __init__(self, ngpu, in_dim=None, out_dim=None, hidden=None):\n",
    "        super(Generator, self).__init__()\n",
    "        self.ngpu = ngpu\n",
    "        # Fall back to the notebook-level hyperparameters so Generator(ngpu) keeps working.\n",
    "        in_dim = nz if in_dim is None else in_dim\n",
    "        out_dim = n_latent if out_dim is None else out_dim\n",
    "        hidden = ngf if hidden is None else hidden\n",
    "\n",
    "        self.latent = nn.Sequential(\n",
    "            # in_dim -> 2*hidden\n",
    "            nn.Linear(in_dim, hidden * 2, bias=False),\n",
    "            nn.BatchNorm1d(hidden * 2),\n",
    "            nn.GELU(),\n",
    "            # 2*hidden -> 4*hidden\n",
    "            nn.Linear(hidden * 2, hidden * 4, bias=False),\n",
    "            nn.BatchNorm1d(hidden * 4),\n",
    "            nn.GELU(),\n",
    "            # 4*hidden -> 8*hidden\n",
    "            nn.Linear(hidden * 4, hidden * 8, bias=False),\n",
    "            nn.BatchNorm1d(hidden * 8),\n",
    "            nn.GELU(),\n",
    "            # 8*hidden -> out_dim, with no output activation\n",
    "            nn.Linear(hidden * 8, out_dim, bias=False),\n",
    "        )\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Map a 2-D (batch, in_dim) noise tensor to a (batch, out_dim) tensor.\"\"\"\n",
    "        return self.latent(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f3311c95-4b4f-414b-8a22-4f1ff376c731",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 64])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Width of the generator output, i.e. of the latent vectors the discriminator scores\n",
    "n_latent=64\n",
    "nz=64\n",
    "# Create the generator\n",
    "netG = Generator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netG = nn.DataParallel(netG, list(range(ngpu)))\n",
    "\n",
    "# ``weights_init`` is not applied to netG; it keeps PyTorch's default init.\n",
    "#netG.apply(weights_init)\n",
    "\n",
    "# Smoke test: 32 noise vectors of width nz should map to a (32, n_latent) output.\n",
    "# ``temp`` is reused by the discriminator smoke test in the next cell.\n",
    "temp = netG(torch.randn(32, nz, device=device))\n",
    "temp.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "ba26c7d1-c114-4cdb-a4bc-aa0d6c016ae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    \"\"\"MLP discriminator: vector of width ``in_dim`` -> probability that it is real.\n",
    "\n",
    "    Args:\n",
    "        ngpu: number of GPUs (stored only; multi-GPU wrapping is done by the caller).\n",
    "        in_dim: width of the input vectors (previously hard-coded to 64).\n",
    "        hidden: base hidden width; defaults to the global ``ndf``.\n",
    "\n",
    "    forward() returns a (batch, 1) tensor in (0, 1) from the final Sigmoid,\n",
    "    suitable for ``nn.BCELoss``.\n",
    "    \"\"\"\n",
    "    def __init__(self, ngpu, in_dim=64, hidden=None):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.ngpu = ngpu\n",
    "        hidden = ndf if hidden is None else hidden\n",
    "        self.main = nn.Sequential(\n",
    "            # in_dim -> hidden (no BatchNorm on the input layer)\n",
    "            nn.Linear(in_dim, hidden, bias=True),\n",
    "            nn.GELU(),\n",
    "            # hidden -> 2*hidden\n",
    "            nn.Linear(hidden, hidden * 2, bias=True),\n",
    "            nn.BatchNorm1d(hidden * 2),\n",
    "            nn.GELU(),\n",
    "            # 2*hidden -> 4*hidden\n",
    "            nn.Linear(hidden * 2, hidden * 4, bias=True),\n",
    "            nn.BatchNorm1d(hidden * 4),\n",
    "            nn.GELU(),\n",
    "            # 4*hidden -> 8*hidden\n",
    "            nn.Linear(hidden * 4, hidden * 8, bias=True),\n",
    "            nn.BatchNorm1d(hidden * 8),\n",
    "            nn.GELU(),\n",
    "            # 8*hidden -> single real/fake score squashed to (0, 1)\n",
    "            nn.Linear(hidden * 8, 1, bias=True),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, input):\n",
    "        return self.main(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "5c0b692e-ff08-4a46-a4f5-cfd1afdbcbd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 1])\n"
     ]
    }
   ],
   "source": [
    "# Create the Discriminator\n",
    "netD = Discriminator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netD = nn.DataParallel(netD, list(range(ngpu)))\n",
    "    \n",
    "# Apply ``weights_init``: it re-initialises Conv* layers (N(0, 0.02)) and BatchNorm*\n",
    "# layers (weight N(1, 0.02), bias 0); netD's Linear layers keep PyTorch defaults.\n",
    "netD.apply(weights_init)\n",
    "\n",
    "# Smoke test: score the generator output ``temp`` and print the result's shape\n",
    "t2=netD(temp)\n",
    "print(t2.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "ca52323d-3dc4-49ab-8c59-a375d897a4b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "netG = Generator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netG = nn.DataParallel(netG, list(range(ngpu)))\n",
    "\n",
    "# ``weights_init`` is not applied to netG; it keeps PyTorch's default init.\n",
    "#netG.apply(weights_init)\n",
    "\n",
    "# Create the Discriminator\n",
    "netD = Discriminator(ngpu).to(device)\n",
    "\n",
    "# Handle multi-GPU if desired\n",
    "if (device.type == 'cuda') and (ngpu > 1):\n",
    "    netD = nn.DataParallel(netD, list(range(ngpu)))\n",
    "\n",
    "# Re-initialise netD's BatchNorm (and any Conv) layers with ``weights_init``\n",
    "# (std 0.02); its Linear layers keep PyTorch's default init.\n",
    "netD.apply(weights_init)\n",
    "# Initialize the ``BCELoss`` function (netD already ends in a Sigmoid)\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "# Fixed batch of noise vectors used to visualise the generator's progress.\n",
    "# Its width must equal nz, the generator's input size.\n",
    "fixed_noise = torch.randn(100, nz, device=device)\n",
    "\n",
    "# Establish convention for real and fake labels during training\n",
    "real_label = 1.\n",
    "fake_label = 0.\n",
    "\n",
    "# Setup Adam optimizers for both G and D\n",
    "optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))\n",
    "optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "1c346bc1-d582-42a7-a87d-8c0a5c0ef133",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_penalty(D, xr, xf):\n",
    "    \"\"\"WGAN-GP gradient penalty on random interpolates between real and fake samples.\n",
    "\n",
    "    :param D: critic/discriminator module, called on the interpolated batch.\n",
    "    :param xr: real samples, shape [b, ...].\n",
    "    :param xf: fake samples, same shape as ``xr``.\n",
    "    :return: scalar tensor mean((||dD/dx at x_hat||_2 - 1)^2). It stays\n",
    "        differentiable w.r.t. D's parameters because create_graph=True.\n",
    "    \"\"\"\n",
    "    b = xr.size(0)\n",
    "    # One mixing coefficient per sample, broadcast over the remaining dims.\n",
    "    # Using xr's own batch size (not the global batch_size) keeps a smaller\n",
    "    # final batch working. Following xr's device/dtype instead of calling\n",
    "    # .cuda() keeps CPU runs working.\n",
    "    t = torch.rand(b, *([1] * (xr.dim() - 1)), device=xr.device, dtype=xr.dtype)\n",
    "    # interpolation\n",
    "    mid = t * xr + (1 - t) * xf\n",
    "    # set it to require grad info\n",
    "    mid.requires_grad_()\n",
    "\n",
    "    pred = D(mid)\n",
    "    grads = torch.autograd.grad(outputs=pred, inputs=mid,\n",
    "                                grad_outputs=torch.ones_like(pred),\n",
    "                                create_graph=True, retain_graph=True)[0]\n",
    "\n",
    "    # Per-sample L2 norm over all non-batch dims\n",
    "    gp = torch.pow(grads.reshape(b, -1).norm(2, dim=1) - 1, 2).mean()\n",
    "\n",
    "    return gp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "45464f07-2ace-43d6-a66b-a3439178b92d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training Loop...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:09<00:00, 55.07it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 48.88it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 48.23it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.34it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.65it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.74it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.95it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.09it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.26it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.94it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.41it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.81it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.35it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.79it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.30it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.21it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 55.76it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 56.09it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 55.93it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.37it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 56.54it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.46it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 56.09it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.92it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.95it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.34it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.13it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.28it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.23it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.63it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.04it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.75it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.62it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.87it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.68it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.89it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.65it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 55.87it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.48it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.78it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.43it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.36it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.97it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.34it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.94it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 55.83it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.38it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.32it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.17it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.19it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.50it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.35it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.39it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.76it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.03it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.94it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.27it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.69it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.94it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.34it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.07it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.63it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 55.75it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.40it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.85it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.92it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.14it/s]\n",
      "100%|██████████| 500/500 [00:08<00:00, 55.67it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.65it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.76it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.31it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.88it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.40it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.99it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.72it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.91it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.75it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.79it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.60it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.67it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.47it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.36it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.46it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.21it/s]\n",
      "100%|██████████| 500/500 [00:10<00:00, 49.71it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.37it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.44it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 55.47it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.17it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 50.79it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.51it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.70it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.68it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.76it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 52.95it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.40it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.00it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 54.77it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 51.25it/s]\n",
      "100%|██████████| 500/500 [00:09<00:00, 53.18it/s]\n"
     ]
    }
   ],
   "source": [
    "# Training Loop: fit a GAN in latent space. D learns to tell the latents that\n",
    "# EVAEmodel produces for real images apart from G(noise); G learns to fool D.\n",
    "# EVAEmodel is only run in eval mode under no_grad and is not optimised here.\n",
    "\n",
    "# Lists to keep track of progress\n",
    "img_list = []\n",
    "G_losses = []\n",
    "D_losses = []\n",
    "iters = 0\n",
    "# Epoch count for this loop (independent of num_epochs from the config cell)\n",
    "GAN_EPOCHS = 100\n",
    "print(\"Starting Training Loop...\")\n",
    "EVAEmodel.eval()\n",
    "for epoch in range(GAN_EPOCHS):\n",
    "    # len(train_loader) is the exact batch count, including a final partial batch\n",
    "    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "\n",
    "        ############################\n",
    "        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n",
    "        ###########################\n",
    "        ## Train with all-real batch\n",
    "        netD.zero_grad()\n",
    "        images = data[0].to(device)\n",
    "        with torch.no_grad():\n",
    "            # The last element of EVAEmodel's output is used as the real latent\n",
    "            *_, real_latent = EVAEmodel(images)\n",
    "        b_size = real_latent.size(0)\n",
    "        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)\n",
    "        output = netD(real_latent).view(-1)\n",
    "        errD_real = criterion(output, label)\n",
    "        errD_real.backward()\n",
    "        D_x = output.mean().item()\n",
    "\n",
    "        ## Train with all-fake batch\n",
    "        noise = torch.randn(b_size, nz, device=device)\n",
    "        fake = netG(noise)\n",
    "        label.fill_(fake_label)\n",
    "        # detach(): this pass must not send gradients into G\n",
    "        output = netD(fake.detach()).view(-1)\n",
    "        errD_fake = criterion(output, label)\n",
    "        # These gradients accumulate (sum) with the real-batch gradients\n",
    "        errD_fake.backward()\n",
    "        D_G_z1 = output.mean().item()\n",
    "        errD = errD_real + errD_fake\n",
    "        optimizerD.step()\n",
    "\n",
    "        ############################\n",
    "        # (2) Update G network: maximize log(D(G(z)))\n",
    "        ###########################\n",
    "        netG.zero_grad()\n",
    "        label.fill_(real_label)  # fake labels are real for generator cost\n",
    "        # D was just updated, so re-score the same fake batch\n",
    "        output = netD(fake).view(-1)\n",
    "        errG = criterion(output, label)\n",
    "        errG.backward()\n",
    "        D_G_z2 = output.mean().item()\n",
    "        optimizerG.step()\n",
    "\n",
    "        # Record losses so training progress can be plotted afterwards\n",
    "        G_losses.append(errG.item())\n",
    "        D_losses.append(errD.item())\n",
    "        iters += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00dfd02b-f687-47ef-a8b9-d6e9297a5f92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the discriminator loss from the last training iteration\n",
    "errD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "9561350d-d9d7-44ac-9198-a5bc6e12e46c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8387, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward0>)"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the generator loss from the last training iteration\n",
    "errG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "id": "1dd02599-89a4-46ba-a7c7-933f12f269b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAHFCAYAAAAHcXhbAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAN/RJREFUeJzt3Xl8VNX9//H3AMkkwRAgISSRJKBAQEFkqSxVCFCQ9cvShVaBoFBR1IpAeQhoCdayFsSKin6LAUtR6kattlRUtgpYQHDBgKghEzZxIhAgCyE5vz/4Zb6GJJCZLJMTXs/HYx4wZ+6593PuGTJv7tyb6zDGGAEAAFiqjr8LAAAAqAjCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMrmorV66Uw+HQrl27Sn19yJAhat68ebG25s2ba9y4cV5tZ9u2bUpOTtapU6d8K/QqtHbtWt14440KDg6Ww+HQ3r17S11u06ZNcjgcnkfdunXVpEkTDR06tMx5rQ5F761Dhw75rYYfOnToULH9FBAQoPDwcP3oRz/Sww8/rH379vm87uzsbCUnJ2vTpk2VVzDgBcIM4KU333xTjz32mFd9tm3bpjlz5hBmyum7777TmDFjdP3112v9+vXavn27Wrdufdk+c+fO1fbt27Vp0yY99thj2rZtm3r16qWDBw9WU9V2ePDBB7V9+3Zt3rxZf/nLXzR8+HC99dZb6tChgxYtWuTTOrOzszVnzhzCDPymnr8LAGzTsWNHf5fgtfz8fDkcDtWrZ8c/+S+//FL5+fkaPXq0evXqVa4+rVq1Urdu3SRJt912mxo2bKikpCStXr1ac+bMqcpyrRIXF+fZT5I0aNAgTZkyRSNHjtT06dPVrl07DRw40I8VAt7jyAzgpUu/ZiosLNQTTzyhhIQEBQcHq2HDhrrpppv01FNPSZKSk5P129/+VpLUokULz2H+ov/FFhYWauHChWrTpo2cTqciIyM1duxYHT58uNh2jTGaO3eu4uPjFRQUpC5dumjDhg1KTExUYmKiZ7mir13+8pe/aOrUqbr22mvldDr11Vdf6bvvvtOkSZN0ww036JprrlFkZKT69OmjrVu3FttW0VcSixYt0oIFC9S8eXMFBwcrMTHREzQeeeQRxcTEKCwsTCNGjNCJEyfKtf/eeustde/eXSEhIQoNDVW/fv20fft2z+vjxo3TrbfeKkkaNWqUHA5HsfGVV5cuXSRJ3377bbH2OXPmqGvXrmrcuLEaNGigTp06acWKFbr0nrvNmzfXkCFDtH79enXq1EnBwcFq06aNXnzxxRLb2rFjh3784x8rKChIMTExmjFjhvLz80ssV965TkxMVLt27bR9+3b16NFDwcHBat68uVJSUiRJ77zzjjp16qSQkBC1b99e69ev93r//FBwcLBWrFihgICAYkdnyvN+OXTokJo0aSLp4r4ten8X/Rv56quvdNddd6lVq1YKCQnRtddeq6FDh+qzzz6rUM3AD9nx3zSgihUUFOjChQsl2stzU/mFCxcqOTlZjz76qHr27Kn8/Hzt37/f85XShAkT9P333+vpp5/WG2+8oejoaEnSDTfcIEm677779MILL+iBBx7QkCFDdOjQIT322GPatGmTPv74Y0VEREiSZs2apXnz5umee+7RyJEjlZGRoQkTJig/P7/Ur2BmzJih7t27a/ny5apTp44iIyP13XffSZJmz56tqKgonT17Vm+++aYSExP1/vvvlwgNzzzzjG666SY988wzOnXqlKZOnaqhQ4eqa9euCggI0Isvvqj09HRNmzZNEyZM0FtvvXXZfbVmzRrdeeed6t+/v15++WXl5eVp4cKFnu3feuuteuyxx3TLLbfo/vvv19y5c9W7d281aNDgivNwqbS0NEkqsW8OHTqkiRMnKi4uTtLFIPLggw/qyJEj+t3vflds2U8++URTp07V
I488oqZNm+rPf/6zxo8fr5YtW6pnz56SpC+++EJ9+/ZV8+bNtXLlSoWEhOjZZ5/VmjVrStRU3rmWpOPHj+uuu+7S9OnT1axZMz399NO6++67lZGRoddee00zZ85UWFiYHn/8cQ0fPlzffPONYmJivN5PRWJiYtS5c2dt27ZNFy5cUL169fT9999Luvz7JTo6WuvXr9eAAQM0fvx4TZgwQZI8Aefo0aMKDw/X/Pnz1aRJE33//fdatWqVunbtqj179ighIcHnmgEPA1zFUlJSjKTLPuLj44v1iY+PN0lJSZ7nQ4YMMTfffPNlt7No0SIjyaSlpRVrT01NNZLMpEmTirV/9NFHRpKZOXOmMcaY77//3jidTjNq1Khiy23fvt1IMr169fK0bdy40UgyPXv2vOL4L1y4YPLz803fvn3NiBEjPO1paWlGkunQoYMpKCjwtC9dutRIMv/zP/9TbD2TJ082kszp06fL3FZBQYGJiYkx7du3L7bOM2fOmMjISNOjR48SY3j11VevOIaiZdeuXWvy8/NNdna2+fDDD01CQoK54YYbzMmTJy9bU35+vnn88cdNeHi4KSws9LwWHx9vgoKCTHp6uqctJyfHNG7c2EycONHTNmrUKBMcHGyOHz/uabtw4YJp06ZNsTkv71wbY0yvXr2MJLNr1y5PW2Zmpqlbt64JDg42R44c8bTv3bvXSDJ/+tOfLrufiuZ00aJFZS4zatQoI8l8++23pb5e1vvlu+++M5LM7NmzL1tD0TrOnz9vWrVqZR5++OErLg+UB18zAZJeeukl7dy5s8Sj6OuOy7nlllv0ySefaNKkSfr3v/+trKyscm9348aNklTi6qhbbrlFbdu21fvvvy/p4tGDvLw8/eIXvyi2XLdu3UpcbVXkpz/9aanty5cvV6dOnRQUFKR69eopICBA77//vlJTU0ssO2jQINWp838/Jtq2bStJGjx4cLHlitpdLlcZI5UOHDigo0ePasyYMcXWec011+inP/2pduzYoezs7DL7X8moUaMUEBCgkJAQ/fjHP1ZWVpbeeecdNWzYsNhyH3zwgX7yk58oLCxMdevWVUBAgH73u98pMzOzxFdlN998s+cIjiQFBQWpdevWSk9P97Rt3LhRffv2VdOmTT1tdevW1ahRo4qtq7xzXSQ6OlqdO3f2PG/cuLEiIyN18803FzsCU7Tvf1iTr0wpRyK9eb+U5sKFC5o7d65uuOEGBQYGql69egoMDNTBgwfLvQ7gSggzgC5+IHTp0qXEIyws7Ip9Z8yYoT/+8Y/asWOHBg4cqPDwcPXt27dclwVnZmZKkuerpx+KiYnxvF705w8/MIuU1lbWOpcsWaL77rtPXbt21euvv64dO3Zo586dGjBggHJyckos37hx42LPAwMDL9uem5tbai0/HENZYy0sLNTJkyfL7H8lCxYs0M6dO7V582bNmjVL3377rYYPH668vDzPMv/973/Vv39/SdL//u//6sMPP9TOnTs1a9YsSSqxD8LDw0tsx+l0FlsuMzNTUVFRJZa7tK28c13k0n0sXdzPvuz78kpPT5fT6fRsw9v3S2mmTJmixx57TMOHD9c//vEPffTRR9q5c6c6dOhQ7nUAV8I5M0AF1atXT1OmTNGUKVN06tQpvffee5o5c6Zuv/12ZWRkKCQkpMy+RR+Wx44dU7NmzYq9dvToUc85FEXLXXoyq3Tx3IrSjs44HI4SbatXr1ZiYqKee+65Yu1nzpy5/CArwQ/HeqmjR4+qTp06atSokc/rv+666zwn/fbs2VPBwcF69NFH9fTTT2vatGmSpFdeeUUBAQF6++23FRQU5Om7bt06n7cbHh6u48ePl2i/tK28c+0vR44c0e7du9WrVy/PVW+V8X5ZvXq1xo4dq7lz5xZrd7vdJY6aAb7iyAxQiRo2bKif/exnuv/++/X99997fmGa0+mUVPJ//n369JF08Qf+D+3cuVOpqanq27evJKlr165yOp1au3Zt
seV27Njh1dcLDofDU0uRTz/9tNjVRFUlISFB1157rdasWVPs64xz587p9ddf91zhVFmmT5+uli1bav78+Z4P36LL0+vWretZLicnR3/5y1983k7v3r31/vvvFwuaBQUFJeaqvHPtDzk5OZowYYIuXLig6dOne9rL+34p6/1d1jreeecdHTlypLLKBzgyA1TU0KFD1a5dO3Xp0kVNmjRRenq6li5dqvj4eLVq1UqS1L59e0nSU089paSkJAUEBCghIUEJCQm655579PTTT6tOnToaOHCg5wqX2NhYPfzww5IufuUwZcoUzZs3T40aNdKIESN0+PBhzZkzR9HR0cXOQbmcIUOG6Pe//71mz56tXr166cCBA3r88cfVokWLUq/mqkx16tTRwoULdeedd2rIkCGaOHGi8vLytGjRIp06dUrz58+v1O0FBARo7ty5+sUvfqGnnnpKjz76qAYPHqwlS5bojjvu0D333KPMzEz98Y9/LPFh641HH31Ub731lvr06aPf/e53CgkJ0TPPPKNz584VW668c13VXC6XduzYocLCQp0+fVp79uzxXJW2ePFiz9dwUvnfL6GhoYqPj9ff//539e3bV40bN1ZERITn8vaVK1eqTZs2uummm7R7924tWrSoxNEpoEL8fQYy4E9FVzPt3Lmz1NcHDx58xauZFi9ebHr06GEiIiJMYGCgiYuLM+PHjzeHDh0q1m/GjBkmJibG1KlTx0gyGzduNMZcvKJmwYIFpnXr1iYgIMBERESY0aNHm4yMjGL9CwsLzRNPPGGaNWtmAgMDzU033WTefvtt06FDh2JXllzuSqC8vDwzbdo0c+2115qgoCDTqVMns27dOpOUlFRsnGVd+VLWuq+0H39o3bp1pmvXriYoKMjUr1/f9O3b13z44Yfl2k5prrRs165dTaNGjcypU6eMMca8+OKLJiEhwTidTnPdddeZefPmmRUrVpS42iw+Pt4MHjy4xPp69epV7OoxY4z58MMPTbdu3YzT6TRRUVHmt7/9rXnhhRdKrLO8c92rVy9z4403lth2WTVJMvfff39Zu8gY839zWvSoW7euadSokencubOZPHmy2bdvX4k+5X2/GGPMe++9Zzp27GicTqeR5Pk3cvLkSTN+/HgTGRlpQkJCzK233mq2bt1a6n4EfOUwphy/SANAjZSWlqY2bdpo9uzZmjlzpr/LAQC/IMwAlvjkk0/08ssvq0ePHmrQoIEOHDighQsXKisrS59//nmZVzUBQG3HOTOAJerXr69du3ZpxYoVOnXqlMLCwpSYmKg//OEPBBkAVzWOzAAAAKtxaTYAALAaYQYAAFiNMAMAAKxW608ALiws1NGjRxUaGlrqr3cHAAA1jzFGZ86cUUxMzBV/MWitDzNHjx5VbGysv8sAAAA+yMjIuOJvjK71YSY0NFTSxZ3RoEEDP1cDAADKIysrS7GxsZ7P8cup9WGm6KulBg0aEGYAALBMeU4R4QRgAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKvV83cBAOByueR2u33qm5eXJ6fT6VPfiIgIxcXF+dQXQM1BmAHgVy6XSwlt2io3J9u3FTjqSKbQp65BwSE6sD+VQANYjjADwK/cbrdyc7IVPmSqAsJjveqb880und662qe++ZkZynx7sdxuN2EGsBxhBkCNEBAeK2dUS6/65Gdm+NwXQO3BCcAAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGC1ev4uAACuNi6XS26326e+
ERERiouLq+SKALsRZgCgGrlcLiW0aavcnGyf+gcFh+jA/lQCDfADhBkAqEZut1u5OdkKHzJVAeGxXvXNz8xQ5tuL5Xa7CTPADxBmAMAPAsJj5Yxq6e8ygFqBE4ABAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKv5NczMmzdPP/rRjxQaGqrIyEgNHz5cBw4cKLaMMUbJycmKiYlRcHCwEhMTtW/fPj9VDAAAahq/hpnNmzfr/vvv144dO7RhwwZduHBB/fv317lz5zzLLFy4UEuWLNGyZcu0c+dORUVFqV+/fjpz5owfKwcAADWFX29nsH79+mLPU1JSFBkZqd27d6tnz54yxmjp0qWaNWuWRo4cKUlatWqVmjZtqjVr1mjixIn+KBsAANQgNeqcmdOnT0uSGjduLElKS0vT8ePH1b9/f88yTqdTvXr10rZt2/xSIwAAqFlqzI0mjTGaMmWKbr31VrVr106SdPz4cUlS06ZNiy3btGlTpaenl7qevLw85eXleZ5nZWVVUcVA7eNyueR2u33qGxERcVXdydnXfZWamloF1QBXtxoTZh544AF9+umn+s9//lPiNYfDUey5MaZEW5F58+Zpzpw5VVIjUJu5XC4ltGmr3Jxsn/oHBYfowP7UqyLQVHRfAahcNSLMPPjgg3rrrbe0ZcsWNWvWzNMeFRUl6eIRmujoaE/7iRMnShytKTJjxgxNmTLF8zwrK0uxsbFVVDlQe7jdbuXmZCt8yFQFhHv3byY/M0OZby+W2+2+KsJMRfZVzje7dHrr6iqqDLg6+TXMGGP04IMP6s0339SmTZvUokWLYq+3aNFCUVFR2rBhgzp27ChJOn/+vDZv3qwFCxaUuk6n0ymn01nltQO1VUB4rJxRLf1dhhV82Vf5mRlVVA1w9fJrmLn//vu1Zs0a/f3vf1doaKjnHJmwsDAFBwfL4XBo8uTJmjt3rlq1aqVWrVpp7ty5CgkJ0R133OHP0gEAQA3h1zDz3HPPSZISExOLtaekpGjcuHGSpOnTpysnJ0eTJk3SyZMn1bVrV7377rsKDQ2t5moBAEBN5Pevma7E4XAoOTlZycnJVV8QAACwTo36PTMAAADeIswAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAq/k1zGzZskVDhw5VTEyMHA6H1q1bV+z1cePGyeFwFHt069bNP8UCAIAaya9h5ty5c+rQoYOWLVtW5jIDBgzQsWPHPI9//vOf1VghAACo6er5c+MDBw7UwIEDL7uM0+lUVFRUNVUEAABs49cwUx6bNm1SZGSkGjZsqF69eukPf/iDIiMjy1w+Ly9PeXl5nudZWVnVUSYA1Goul0tut9unvhEREYqLi6vkioD/U6PDzMCBA/Xzn/9c8fHxSktL02OPPaY+ffpo9+7dcjqdpfaZN2+e5syZU82VAkDt5XK5lNCmrXJzsn3qHxQcogP7Uwk0qDI1OsyMGjXK8/d27dqpS5cuio+P1zvvvKORI0eW
2mfGjBmaMmWK53lWVpZiY2OrvFYAqK3cbrdyc7IVPmSqAsK9+3man5mhzLcXy+12E2ZQZWp0mLlUdHS04uPjdfDgwTKXcTqdZR61AQD4LiA8Vs6olv4uAyjBqt8zk5mZqYyMDEVHR/u7FAAAUEP49cjM2bNn9dVXX3mep6Wlae/evWrcuLEaN26s5ORk/fSnP1V0dLQOHTqkmTNnKiIiQiNGjPBj1QAAoCbxa5jZtWuXevfu7XledK5LUlKSnnvuOX322Wd66aWXdOrUKUVHR6t3795au3atQkND/VUyAACoYfwaZhITE2WMKfP1f//739VYDQAAsJFV58wAAABcijADAACsRpgBAABWI8wAAACrEWYAAIDVfAozaWlplV0HAACAT3wKMy1btlTv3r21evVq5ebmVnZNAAAA5eZTmPnkk0/UsWNHTZ06VVFRUZo4caL++9//VnZtAAAAV+RTmGnXrp2WLFmiI0eOKCUlRcePH9ett96qG2+8UUuWLNF3331X2XUCAACUqkInANerV08jRozQ3/72Ny1YsEBff/21pk2bpmbNmmns2LE6duxYZdUJAABQqgqFmV27dmnSpEmKjo7WkiVLNG3aNH399df64IMPdOTIEQ0bNqyy6gQAACiVT/dmWrJkiVJSUnTgwAENGjRIL730kgYNGqQ6dS5moxYtWuj5559XmzZtKrVYAACAS/kUZp577jndfffduuuuuxQVFVXqMnFxcVqxYkWFigMAALgSn8LMwYMHr7hMYGCgkpKSfFk9AABAufl0zkxKSopeffXVEu2vvvqqVq1aVeGiAAAAysunMDN//nxFRESUaI+MjNTcuXMrXBQAAEB5+RRm0tPT1aJFixLt8fHxcrlcFS4KAACgvHwKM5GRkfr0009LtH/yyScKDw+vcFEAAADl5VOY+eUvf6nf/OY32rhxowoKClRQUKAPPvhADz30kH75y19Wdo0AAABl8ulqpieeeELp6enq27ev6tW7uIrCwkKNHTuWc2YAAEC18inMBAYGau3atfr973+vTz75RMHBwWrfvr3i4+Mruz4AAIDL8inMFGndurVat25dWbUAAAB4zacwU1BQoJUrV+r999/XiRMnVFhYWOz1Dz74oFKKAwAAuBKfwsxDDz2klStXavDgwWrXrp0cDkdl1wUAAFAuPoWZV155RX/72980aNCgyq4HAADAKz5dmh0YGKiWLVtWdi0AAABe8ynMTJ06VU899ZSMMZVdDwAAgFd8+prpP//5jzZu3Kh//etfuvHGGxUQEFDs9TfeeKNSigMAALgSn8JMw4YNNWLEiMquBQAAwGs+hZmUlJTKrgMAAMAnPp0zI0kXLlzQe++9p+eff15nzpyRJB09elRnz56ttOIAAACuxKcjM+np6RowYIBcLpfy8vLUr18/hYaGauHChcrNzdXy5csru04AAIBS+XRk5qGHHlKXLl108uRJBQcHe9pHjBih999/v9KKAwAAuBKfr2b68MMPFRgYWKw9Pj5eR44cqZTCAAAAysOnIzOFhYUqKCgo0X748GGFhoZWuCgAAIDy8inM9OvXT0uXLvU8dzgcOnv2rGbPns0tDgAAQLXy6WumJ598Ur1799YNN9yg3Nxc3XHHHTp48KAiIiL08ssvV3aNAAAAZfIpzMTExGjv3r16+eWX9fHHH6uwsFDjx4/XnXfeWeyEYAAAgKrmU5iRpODgYN199926++67K7MeABZLTU2tlj6V6WqpWZIiIiIUFxdXydVULZfLJbfb7VNfG8cL3/gUZl566aXLvj527FifigFgp4KzJyWHQ6NHj/Z3KeV2NdYcFByiA/tTrfmAd7lcSmjTVrk52T71t2288J1PYeahhx4q9jw/P1/Z2dkKDAxUSEgIYQa4yhTmnZWMUfiQqQoIj/Wqb843u3R66+oqqqxsV1vN+ZkZynx7sdxutzUf7m63W7k52VfNeOE7n8LMyZMnS7QdPHhQ9913n377299W
uCgAdgoIj5UzqqVXffIzM6qomvK5Wmq22dU2XnjP53szXapVq1aaP39+iaM2AAAAVanSwowk1a1bV0ePHq3MVQIAAFyWT18zvfXWW8WeG2N07NgxLVu2TD/+8Y8rpTAAAIDy8CnMDB8+vNhzh8OhJk2aqE+fPlq8eHFl1AUAAFAuPoWZwsLCyq4DAADAJ5V6zgwAAEB18+nIzJQpU8q97JIlS3zZBAAAQLn4FGb27Nmjjz/+WBcuXFBCQoIk6csvv1TdunXVqVMnz3IOh6NyqgQAACiDT2Fm6NChCg0N1apVq9SoUSNJF3+R3l133aXbbrtNU6dOrdQiAQAAyuLTOTOLFy/WvHnzPEFGkho1aqQnnniCq5kAAEC18inMZGVl6dtvvy3RfuLECZ05c6bCRQEAAJSXT2FmxIgRuuuuu/Taa6/p8OHDOnz4sF577TWNHz9eI0eOrOwaAQAAyuTTOTPLly/XtGnTNHr0aOXn519cUb16Gj9+vBYtWlSpBQIAAFyOT2EmJCREzz77rBYtWqSvv/5axhi1bNlS9evXr+z6AAAALqtCvzTv2LFjOnbsmFq3bq369evLGFNZdQEAAJSLT2EmMzNTffv2VevWrTVo0CAdO3ZMkjRhwgQuywYAANXKpzDz8MMPKyAgQC6XSyEhIZ72UaNGaf369ZVWHAAAwJX4dM7Mu+++q3//+99q1qxZsfZWrVopPT29UgoDAAAoD5+OzJw7d67YEZkibrdbTqezwkUBAACUl09hpmfPnnrppZc8zx0OhwoLC7Vo0SL17t270ooDAAC4Ep++Zlq0aJESExO1a9cunT9/XtOnT9e+ffv0/fff68MPP6zsGgEAAMrk05GZG264QZ9++qluueUW9evXT+fOndPIkSO1Z88eXX/99ZVdIwAAQJm8PjKTn5+v/v376/nnn9ecOXOqoiYAAIBy8/rITEBAgD7//HM5HI4Kb3zLli0aOnSoYmJi5HA4tG7dumKvG2OUnJysmJgYBQcHKzExUfv27avwdgEAQO3h09dMY8eO1YoVKyq88XPnzqlDhw5atmxZqa8vXLhQS5Ys0bJly7Rz505FRUWpX79+3JkbAAB4+HQC8Pnz5/XnP/9ZGzZsUJcuXUrck2nJkiXlWs/AgQM1cODAUl8zxmjp0qWaNWuW507cq1atUtOmTbVmzRpNnDjRl9IBAEAt41WY+eabb9S8eXN9/vnn6tSpkyTpyy+/LLZMZXz9JElpaWk6fvy4+vfv72lzOp3q1auXtm3bRpgBAACSvAwzrVq10rFjx7Rx40ZJF29f8Kc//UlNmzat9MKOHz8uSSXW3bRp08v+luG8vDzl5eV5nmdlZVV6bQAAO6SmpvrULyIiQnFxcZVcDaqKV2Hm0rti/+tf/9K5c+cqtaBLXXqkxxhz2aM/8+bN4yorALjKFZw9KTkcGj16tE/9g4JDdGB/KoHGEj6dM1Pk0nBTmaKioiRdPEITHR3taT9x4sRljwTNmDFDU6ZM8TzPyspSbGxsldUJAKh5CvPOSsYofMhUBYR79xmQn5mhzLcXy+12E2Ys4VWYcTgcJY6KVNY5Mpdq0aKFoqKitGHDBnXs2FHSxROPN2/erAULFpTZz+l0cn8oAIAkKSA8Vs6olv4uA1XM66+Zxo0b5wkLubm5uvfee0tczfTGG2+Ua31nz57VV1995XmelpamvXv3qnHjxoqLi9PkyZM1d+5ctWrVSq1atdLcuXMVEhKiO+64w5uyAQBALeZVmElKSir23NfvIovs2rWr2I0pi74eSkpK0sqVKzV9+nTl5ORo0qRJOnnypLp27ap3331XoaGhFdouAACoPbwKMykpKZW68cTExMued+NwOJScnKzk5ORK3S4AAKg9fPoNwAAAADUFYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1Qgz
AADAaoQZAABgNcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBq9fxdAIDK5XK55Ha7ve6XmppaBdUAF/ny/uI9ifIizAC1iMvlUkKbtsrNyfZ3KYAkqeDsScnh0OjRo/1dCmoxwgxQi7jdbuXmZCt8yFQFhMd61Tfnm106vXV1FVWGq1Vh3lnJGN6TqFKEGaAWCgiPlTOqpVd98jMzqqgagPckqhYnAAMAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZAABgNW40CVyGy+WS2+32qW9ERITi4uIquSIAwKUIM0AZXC6XEtq0VW5Otk/9g4JDdGB/KoEGAKoYYQYog9vtVm5OtsKHTFVAeKxXffMzM5T59mK53W7CDABUMcIMcAUB4bFyRrX0dxkAgDJwAjAAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArMa9mYAayOVyye12e90vNTW1CqpBbeLLe+RqfV/5Ou6IiAhuMFvNCDNADeNyuZTQpq1yc7L9XQpqkYKzJyWHQ6NHj/Z3KTVeRfdVUHCIDuxPJdBUI8IMUMO43W7l5mQrfMhUBYTHetU355tdOr11dRVVBpsV5p2VjOF9VQ4V2Vf5mRnKfHux3G43YaYaEWaAGiogPFbOqJZe9cnPzKiialBb8L4qP1/2FfyDE4ABAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKvV6DCTnJwsh8NR7BEVFeXvsgAAQA1S438D8I033qj33nvP87xu3bp+rAYAANQ0NT7M1KtXj6MxAACgTDU+zBw8eFAxMTFyOp3q2rWr5s6dq+uuu67M5fPy8pSXl+d5npWVVR1lAgDgkZqa6lO/iIgIblDpgxodZrp27aqXXnpJrVu31rfffqsnnnhCPXr00L59+xQeHl5qn3nz5mnOnDnVXCkAAFLB2ZOSw6HRo0f71D8oOEQH9qcSaLxUo8PMwIEDPX9v3769unfvruuvv16rVq3SlClTSu0zY8aMYq9lZWUpNta7W7gDAOCLwryzkjEKHzJVAeHeffbkZ2Yo8+3FcrvdhBkv1egwc6n69eurffv2OnjwYJnLOJ1OOZ3OaqwKAIDiAsJj5Yxq6e8yrho1+tLsS+Xl5Sk1NVXR0dH+LgUAANQQNTrMTJs2TZs3b1ZaWpo++ugj/exnP1NWVpaSkpL8XRoAAKghavTXTIcPH9avfvUrud1uNWnSRN26ddOOHTsUHx/v79IAAEANUaPDzCuvvOLvEgAAQA1Xo79mAgAAuBLCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAajX6l+bZwOVyye12+9Q3IiKCO6OWU0X2c15enk83H01NTfVpewBQEb7+7LmaP1MIMxXgcrmU0KatcnOyfeofFByiA/tTr9o3X3lVdD/LUUcyhZVbFABUsoKzJyWHQ6NHj/ap/9X8mUKYqQC3263cnGyFD5mqgPBYr/rmZ2Yo8+3FcrvdV+UbzxsV2c853+zS6a2rK9QXAKpDYd5ZyRg+U3xAmKkEAeGxcka19HcZtZ4v+zk/M6PCfQGgOvGZ4j1OAAYAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1bg301XI5XLJ7Xb71PdqvsU8ANR0qampPvWz/Wc7YeYq43K5lNCmrXJzsn3qfzXfYh4AaqqCsyclh0OjR4/2qb/tP9sJM1cZt9ut3JxsbjEPALVIYd5ZyZir9mc7YeYqxS3mAaD2uVp/tnMCMAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1QgzAADAaoQZ
AABgNcIMAACwGmEGAABYjTADAACsxr2ZUG1cLpfcbrfX/Xy9pX1N4EvtNo8XAPyBMINq4XK5lNCmrXJzsv1dSrUoOHtScjg0evRof5cCALUeYQbVwu12Kzcn26fb0+d8s0unt66uosqqRmHeWcmYq2a8AOBPhBlUK19uT5+fmVFF1VS9q228AOAPnAAMAACsRpgBAABWI8wAAACrEWYAAIDVCDMAAMBqhBkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKtxbyY/S01N9alfXl6enE5ntW2vouuojO0CAKqOrz+nIyIiFBcXV8nVeIcw4ycFZ09KDodGjx7t2wocdSRTWLlFXUGFawYA1DgV/dkeFByiA/tT/RpoCDN+Uph3VjJG4UOmKiA81qu+Od/s0umtqyvU1xeVUTMAoGapyM/2/MwMZb69WG63mzBzNQsIj5UzqqVXffIzMyrctyL8tV0AQNXx5Wd7TcEJwAAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArEaYAQAAViPMAAAAqxFmAACA1awIM88++6xatGihoKAgde7cWVu3bvV3SQAAoIao8WFm7dq1mjx5smbNmqU9e/botttu08CBA+VyufxdGgAAqAFqfJhZsmSJxo8frwkTJqht27ZaunSpYmNj9dxzz/m7NAAAUAPU6DBz/vx57d69W/379y/W3r9/f23bts1PVQEAgJqkRt9o0u12q6CgQE2bNi3W3rRpUx0/frzUPnl5ecrLy/M8P336tCQpKyur0us7e/bsxW0e/0qF53O96lt040X60pe+9KUvfa3t+/1hSRc/Dyv7c7ZofcaYKy9sarAjR44YSWbbtm3F2p944gmTkJBQap/Zs2cbSTx48ODBgwePWvDIyMi4Yl6o0UdmIiIiVLdu3RJHYU6cOFHiaE2RGTNmaMqUKZ7nhYWFSk9P180336yMjAw1aNCgSmv2h6ysLMXGxjI+C9XmsUmMz3aMz161YWzGGJ05c0YxMTFXXLZGh5nAwEB17txZGzZs0IgRIzztGzZs0LBhw0rt43Q65XQ6i7XVqXPx1KAGDRpYO6nlwfjsVZvHJjE+2zE+e9k+trCwsHItV6PDjCRNmTJFY8aMUZcuXdS9e3e98MILcrlcuvfee/1dGgAAqAFqfJgZNWqUMjMz9fjjj+vYsWNq166d/vnPfyo+Pt7fpQEAgBqgxocZSZo0aZImTZrkc3+n06nZs2eX+PqptmB89qrNY5MYn+0Yn71q89hK4zCmPNc8AQAA1Ew1+pfmAQAAXAlhBgAAWI0wAwAArEaYAQAAVquVYebQoUMaP368WrRooeDgYF1//fWaPXu2zp8/f9l+48aNk8PhKPbo1q1bNVVdPr6OzRij5ORkxcTEKDg4WImJidq3b181Ve2dP/zhD+rRo4dCQkLUsGHDcvWxYe6K+DI+m+bv5MmTGjNmjMLCwhQWFqYxY8bo1KlTl+1Tk+fv2WefVYsWLRQUFKTOnTtr69atl11+8+bN6ty5s4KCgnTddddp+fLl1VSpb7wZ36ZNm0rMk8Ph0P79+6ux4vLZsmWLhg4dqpiYGDkcDq1bt+6KfWyaO2/HZ9Pc+aJWhpn9+/ersLBQzz//vPbt26cnn3xSy5cv18yZM6/Yd8CAATp27Jjn8c9//rMaKi4/X8e2cOFCLVmyRMuWLdPOnTsVFRWlfv366cyZM9VUefmdP39eP//5z3Xfffd51a+mz10RX8Zn0/zdcccd2rt3r9avX6/169dr7969GjNmzBX71cT5W7t2rSZPnqxZs2Zpz549uu222zRw4EC5XK5Sl09LS9OgQYN02223ac+ePZo5c6Z+85vf6PXXX6/mysvH2/EVOXDgQLG5atWqVTVVXH7nzp1Thw4dtGzZsnItb9vceTu+IjbMnU8qejNIWyxcuNC0aNHissskJSWZYcOGVU9BlehKYyssLDRRUVFm/vz5nrbc3FwTFhZmli9fXh0l
+iQlJcWEhYWVa1kb566847Np/r744gsjyezYscPTtn37diPJ7N+/v8x+NXX+brnlFnPvvfcWa2vTpo155JFHSl1++vTppk2bNsXaJk6caLp161ZlNVaEt+PbuHGjkWROnjxZDdVVHknmzTffvOwyts3dD5VnfLbOXXnVyiMzpTl9+rQaN258xeU2bdqkyMhItW7dWr/+9a914sSJaqiuYq40trS0NB0/flz9+/f3tDmdTvXq1Uvbtm2rjhKrhY1zVx42zd/27dsVFhamrl27etq6deumsLCwK9Za0+bv/Pnz2r17d7H9Lkn9+/cvcyzbt28vsfztt9+uXbt2KT8/v8pq9YUv4yvSsWNHRUdHq2/fvtq4cWNVllltbJq7iqiNcyfV0q+ZLvX111/r6aefvuL9nAYOHKi//vWv+uCDD7R48WLt3LlTffr0UV5eXjVV6r3yjK3oruOX3mm8adOmJe5Ibisb5668bJq/48ePKzIyskR7ZGTkZWutifPndrtVUFDg1X4/fvx4qctfuHBBbre7ymr1hS/ji46O1gsvvKDXX39db7zxhhISEtS3b19t2bKlOkquUjbNnS9q89xJloWZ5OTkUk9g+uFj165dxfocPXpUAwYM0M9//nNNmDDhsusfNWqUBg8erHbt2mno0KH617/+pS+//FLvvPNOVQ5LUtWPTZIcDkex58aYEm1VxZfxecOfcydV/fgke+avtJquVKu/5+9yvN3vpS1fWntN4c34EhIS9Otf/1qdOnVS9+7d9eyzz2rw4MH64x//WB2lVjnb5s4btX3urLg3U5EHHnhAv/zlLy+7TPPmzT1/P3r0qHr37u2527a3oqOjFR8fr4MHD3rd11tVObaoqChJF//nER0d7Wk/ceJEif+JVBVvx1dR1Tl3UtWOz6b5+/TTT/Xtt9+WeO27777zqtbqnr/SREREqG7duiWOUlxuv0dFRZW6fL169RQeHl5ltfrCl/GVplu3blq9enVll1ftbJq7ylJb5k6yLMxEREQoIiKiXMseOXJEvXv3VufOnZWSkqI6dbw/CJWZmamMjIxiHyBVpSrH1qJFC0VFRWnDhg3q2LGjpIvfl2/evFkLFiyocO3l4c34KkN1zp1UteOzaf66d++u06dP67///a9uueUWSdJHH32k06dPq0ePHuXeXnXPX2kCAwPVuXNnbdiwQSNGjPC0b9iwQcOGDSu1T/fu3fWPf/yjWNu7776rLl26KCAgoErr9ZYv4yvNnj17/DpPlcWmuasstWXuJNXOq5mOHDliWrZsafr06WMOHz5sjh075nn8UEJCgnnjjTeMMcacOXPGTJ061Wzbts2kpaWZjRs3mu7du5trr73WZGVl+WMYpfJlbMYYM3/+fBMWFmbeeOMN89lnn5lf/epXJjo6ukaNrUh6errZs2ePmTNnjrnmmmvMnj17zJ49e8yZM2c8y9g4d0W8HZ8xds3fgAEDzE033WS2b99utm/fbtq3b2+GDBlSbBlb5u+VV14xAQEBZsWKFeaLL74wkydPNvXr1zeHDh0yxhjzyCOPmDFjxniW/+abb0xISIh5+OGHzRdffGFWrFhhAgICzGuvveavIVyWt+N78sknzZtvvmm+/PJL8/nnn5tHHnnESDKvv/66v4ZQpjNnznj+bUkyS5YsMXv27DHp6enGGPvnztvx2TR3vqiVYSYlJcVIKvXxQ5JMSkqKMcaY7Oxs079/f9OkSRMTEBBg4uLiTFJSknG5XH4YQdl8GZsxFy/vnT17tomKijJOp9P07NnTfPbZZ9VcffkkJSWVOr6NGzd6lrFx7op4Oz5j7Jq/zMxMc+edd5rQ0FATGhpq7rzzzhKXg9o0f88884yJj483gYGBplOnTmbz5s2e15KSkkyvXr2KLb9p0ybTsWNHExgYaJo3b26ee+65aq7YO96Mb8GCBeb66683QUFBplGjRubWW28177zzjh+qvrKiS5EvfSQlJRlj7J87b8dn09z5
wmHM/z/DCQAAwEJWXc0EAABwKcIMAACwGmEGAABYjTADAACsRpgBAABWI8wAAACrEWYAAIDVCDMArJSYmKjJkyf7uwwANQBhBkC1Gzp0qH7yk5+U+tr27dvlcDj08ccfV3NVAGxFmAFQ7caPH68PPvhA6enpJV578cUXdfPNN6tTp05+qAyAjQgzAKrdkCFDFBkZqZUrVxZrz87O1tq1azV8+HD96le/UrNmzRQSEqL27dvr5Zdfvuw6HQ6H1q1bV6ytYcOGxbZx5MgRjRo1So0aNVJ4eLiGDRumQ4cOVc6gAPgNYQZAtatXr57Gjh2rlStX6oe3h3v11Vd1/vx5TZgwQZ07d9bbb7+tzz//XPfcc4/GjBmjjz76yOdtZmdnq3fv3rrmmmu0ZcsW/ec//9E111yjAQMG6Pz585UxLAB+QpgB4Bd33323Dh06pE2bNnnaXnzxRY0cOVLXXnutpk2bpptvvlnXXXedHnzwQd1+++169dVXfd7eK6+8ojp16ujPf/6z2rdvr7Zt2yolJUUul6tYDQDsU8/fBQC4OrVp00Y9evTQiy++qN69e+vrr7/W1q1b9e6776qgoEDz58/X2rVrdeTIEeXl5SkvL0/169f3eXu7d+/WV199pdDQ0GLtubm5+vrrrys6HAB+RJgB4Dfjx4/XAw88oGeeeUYpKSmKj49X3759tWjRIj355JNaunSp2rdvr/r162vy5MmX/TrI4XAU+8pKkvLz8z1/LywsVOfOnfXXv/61RN8mTZpU3qAAVDvCDAC/+cUvfqGHHnpIa9as0apVq/TrX/9aDodDW7du1bBhwzR69GhJF4PIwYMH1bZt2zLX1aRJEx07dszz/ODBg8rOzvY879Spk9auXavIyEg1aNCg6gYFoNpxzgwAv7nmmms0atQozZw5U0ePHtW4ceMkSS1bttSGDRu0bds2paamauLEiTp+/Phl19WnTx8tW7ZMH3/8sXbt2qV7771XAQEBntfvvPNORUREaNiwYdq6davS0tK0efNmPfTQQzp8+HBVDhNAFSPMAPCr8ePH6+TJk/rJT36iuLg4SdJjjz2mTp066fbbb1diYqKioqI0fPjwy65n8eLFio2NVc+ePXXHHXdo2rRpCgkJ8bweEhKiLVu2KC4uTiNHjlTbtm119913KycnhyM1gOUc5tIvmQEAACzCkRkAAGA1wgwAALAaYQYAAFiNMAMAAKxGmAEAAFYjzAAAAKsRZgAAgNUIMwAAwGqEGQAAYDXCDAAAsBphBgAAWI0wAwAArPb/AAfjC43YMBxoAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Generate some random data for the histogram\n",
    "#data = np.random.randn(1000) # 1000 random numbers from a standard normal distribution\n",
    "\n",
    "# Create the histogram\n",
    "plt.hist(real_latent.cpu().detach().numpy()[20], bins=30, edgecolor='black') # 'bins' defines the number of bins, 'edgecolor' adds borders to the bars\n",
    "\n",
    "# Add labels and title for clarity\n",
    "plt.xlabel('Value')\n",
    "plt.ylabel('Frequency')\n",
    "plt.title('Histogram of Random Data')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "id": "add4e7a5-c1ef-4355-8a3a-3f2edc3919d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHFCAYAAAAaD0bAAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQKVJREFUeJzt3Xl4VOXd//HPAGGSaAiQEJJAEhAwrCKLsogQpIAglEUEq2wCRQtaIFDaoBRQSwQEUVDRRyBYHhEta9VaQNkUsITN6hMpaGAiJOKEJYYsBHJ+f/hj6pAFMkwyMznv13Wdq5z73Pc53zNnnHx6lhmLYRiGAAAATKSKpwsAAACoaAQgAABgOgQgAABgOgQgAABgOgQgAABgOgQgAABgOgQgAABgOgQgAABgOgQgAABgOgQgwAVJSUmyWCxKTk4udnm/fv3UoEEDp7YGDRpo9OjRZdrOnj17NHv2bJ0/f961Qk1o7dq1atGihQICAmSxWHT48OFi++3YsUMWi8UxVa1aVXXq1FH//v1LPK4V4ep768SJEx6r4ZdOnDjh9Dr5+fkpJCREd911l6ZMmaKvv/7a5XXn5ORo9uzZ2rFjh/sKBm4QAQioIBs2bNDMmTPLNGbPnj2aM2cOAegG/fjjjxoxYoQaNWqkjz/+WHv37tXtt99e6pi5c+dq79692rFjh2bOnKk9e/aoW7duOnbsWAVV7Rueeuop7d27Vzt37tRf//pXDRw4UJs3b1br1q21YMECl9aZk5OjOXPmEIDgEdU8XQBgFm3atPF0CWVWUFAgi8WiatV846PiP//5jwoKCjR8+HB169bthsY0adJEHTt2lCTde++9qlmzpkaNGqXVq1drzpw55VmuT4mOjna8TpLUt29fxcfHa/DgwZo+fbpatmypPn36eLBCoGw4AwRUkGsvgRUWFur5559XbGysAgICVLNmTd1xxx16+eWXJUmzZ8/WH/7wB0lSw4YNHZcgrv6/5cLCQs2fP19NmzaV1WpVWFiYRo4cqe+//95pu4ZhaO7cuYqJiZG/v7/at2+vrVu3Ki4uTnFxcY5+Vy8J/fWvf9XUqVNVr149Wa1WHT9+XD/++KMmTJig5s2b69Zbb1VYWJjuu+8+7d6922lbVy+XLFiwQPPmzVODBg0UEBCguLg4Rzj505/+pMjISAUHB2vQoEE6c+bMDb1+mzdvVqdOnRQYGKigoCD17NlTe/fudSwfPXq0unTpIkkaNmyYLBaL0/7dqPbt20uSfvjhB6f2OXPmqEOHDqpdu7Zq1Kihtm3bavny5br296QbNGigfv366eOPP1bbtm0VEBCgpk2basWKFUW2tW/fPt1zzz3y9/dXZGSkEhISVFBQUKTfjR7ruLg4tWzZUnv37lXnzp0VEBCgBg0aaOXKlZKkDz/8UG3btlVgYKBatWqljz/+uMyvzy8FBARo+fLl8vPzczoLdCPvlxMnTqhOnTqSfn5tr76/r/43cvz4cT322GNq0qSJAgMDVa9ePfXv31///ve/b6pm4Crf+L91gJe6cuWKLl++XKT92j+KxZk/f75mz56tZ555Rl27dlVBQYG++eYbx+WucePG6ezZs1qyZInWr1+viIgISVLz5s0lSb/73e/05ptv6sknn1S/fv104sQJzZw5Uzt27NDBgwcVGhoqSXr66aeVmJio8ePHa/DgwUpLS9O4ceNUUFBQ7OWhhIQEderUScuWLVOVKlUUFhamH3/8UZI0a9YshYeHKzs7Wxs2bFBcXJw++eSTIkHj1Vdf1R133KFXX31V58+f19SpU9W/f3916NBBfn5+WrFihU6ePKlp06Zp3Lhx2rx5c6mv1TvvvKNHH31UvXr10po1a5Sfn6/58+c7tt+lSxfNnDlTd999tyZOnKi5c+eqe/fuqlGjxnWPw7VSU1Mlqchrc+LECT3++OOKjo6W9HN4eeqpp3Tq1Cn9+c9/dup75MgRTZ06VX/6
059Ut25dvfXWWxo7dqwaN26srl27SpL+7//+Tz169FCDBg2UlJSkwMBAvfbaa3rnnXeK1HSjx1qSMjIy9Nhjj2n69OmqX7++lixZojFjxigtLU1/+9vfNGPGDAUHB+vZZ5/VwIED9d133ykyMrLMr9NVkZGRateunfbs2aPLly+rWrVqOnv2rKTS3y8RERH6+OOPdf/992vs2LEaN26cJDlC0enTpxUSEqIXXnhBderU0dmzZ7Vq1Sp16NBBhw4dUmxsrMs1A5IkA0CZrVy50pBU6hQTE+M0JiYmxhg1apRjvl+/fsadd95Z6nYWLFhgSDJSU1Od2lNSUgxJxoQJE5zav/jiC0OSMWPGDMMwDOPs2bOG1Wo1hg0b5tRv7969hiSjW7dujrbt27cbkoyuXbted/8vX75sFBQUGD169DAGDRrkaE9NTTUkGa1btzauXLniaF+8eLEhyfj1r3/ttJ7JkycbkowLFy6UuK0rV64YkZGRRqtWrZzW+dNPPxlhYWFG586di+zD+++/f919uNp37dq1RkFBgZGTk2N8/vnnRmxsrNG8eXPj3LlzpdZUUFBgPPvss0ZISIhRWFjoWBYTE2P4+/sbJ0+edLTl5uYatWvXNh5//HFH27Bhw4yAgAAjIyPD0Xb58mWjadOmTsf8Ro+1YRhGt27dDElGcnKyoy0zM9OoWrWqERAQYJw6dcrRfvjwYUOS8corr5T6Ol09pgsWLCixz7BhwwxJxg8//FDs8pLeLz/++KMhyZg1a1apNVxdx6VLl4wmTZoYU6ZMuW5/4Hq4BAbchLffflv79+8vMl29FFOau+++W0eOHNGECRP0z3/+U1lZWTe83e3bt0tSkafK7r77bjVr1kyffPKJpJ/PUuTn52vo0KFO/Tp27FjkKbWrHnzwwWLbly1bprZt28rf31/VqlWTn5+fPvnkE6WkpBTp27dvX1Wp8t+Pl2bNmkmSHnjgAad+V9ttNlsJeyodPXpUp0+f1ogRI5zWeeutt+rBBx/Uvn37lJOTU+L46xk2bJj8/PwUGBioe+65R1lZWfrwww9Vs2ZNp36ffvqpfvWrXyk4OFhVq1aVn5+f/vznPyszM7PIZbw777zTcaZIkvz9/XX77bfr5MmTjrbt27erR48eqlu3rqOtatWqGjZsmNO6bvRYXxUREaF27do55mvXrq2wsDDdeeedTmd6rr72v6zJVUYxZzzL8n4pzuXLlzV37lw1b95c1atXV7Vq1VS9enUdO3bshtcBlIYABNyEZs2aqX379kWm4ODg645NSEjQiy++qH379qlPnz4KCQlRjx49bugR7MzMTElyXBb7pcjISMfyq//7yz+yVxXXVtI6Fy1apN/97nfq0KGD1q1bp3379mn//v26//77lZubW6R/7dq1nearV69eanteXl6xtfxyH0ra18LCQp07d67E8dczb9487d+/Xzt37tTTTz+tH374QQMHDlR+fr6jz7/+9S/16tVLkvQ///M/+vzzz7V//349/fTTklTkNQgJCSmyHavV6tQvMzNT4eHhRfpd23ajx/qqa19j6efX2ZXX/kadPHlSVqvVsY2yvl+KEx8fr5kzZ2rgwIH6+9//ri+++EL79+9X69atb3gdQGm4BwjwkGrVqik+Pl7x8fE6f/68tm3bphkzZqh3795KS0tTYGBgiWOv/oFNT09X/fr1nZadPn3acU/I1X7X3tAr/XyvSHFngSwWS5G21atXKy4uTq+//rpT+08//VT6TrrBL/f1WqdPn1aVKlVUq1Ytl9d/2223OW587tq1qwICAvTMM89oyZIlmjZtmiTp3XfflZ+fnz744AP5+/s7xm7cuNHl7YaEhCgjI6NI+7VtN3qsPeXUqVM6cOCAunXr5nha0B3vl9WrV2vkyJGaO3euU7vdbi9ydg5wBWeAAC9Qs2ZNDRkyRBMnTtTZs2cdX4JntVolFT3DcN9990n6+Y/EL+3fv18pKSnq0aOHJKlDhw6y
Wq1au3atU799+/aV6dKHxWJx1HLVl19+6fQUVnmJjY1VvXr19M477zhdarl48aLWrVvneDLMXaZPn67GjRvrhRdecPzBvvpVAFWrVnX0y83N1V//+leXt9O9e3d98sknTuH0ypUrRY7VjR5rT8jNzdW4ceN0+fJlTZ8+3dF+o++Xkt7fJa3jww8/1KlTp9xVPkyOM0CAh/Tv318tW7ZU+/btVadOHZ08eVKLFy9WTEyMmjRpIklq1aqVJOnll1/WqFGj5Ofnp9jYWMXGxmr8+PFasmSJqlSpoj59+jieDIqKitKUKVMk/Xw5JD4+XomJiapVq5YGDRqk77//XnPmzFFERITTPTWl6devn5577jnNmjVL3bp109GjR/Xss8+qYcOGxT4F505VqlTR/Pnz9eijj6pfv356/PHHlZ+frwULFuj8+fN64YUX3Lo9Pz8/zZ07V0OHDtXLL7+sZ555Rg888IAWLVqkRx55ROPHj1dmZqZefPHFIn+gy+KZZ57R5s2bdd999+nPf/6zAgMD9eqrr+rixYtO/W70WJc3m82mffv2qbCwUBcuXNChQ4ccT/MtXLjQcYlQuvH3S1BQkGJiYrRp0yb16NFDtWvXVmhoqOOrBJKSktS0aVPdcccdOnDggBYsWFDkLBjgMk/fhQ34oqtPge3fv7/Y5Q888MB1nwJbuHCh0blzZyM0NNSoXr26ER0dbYwdO9Y4ceKE07iEhAQjMjLSqFKliiHJ2L59u2EYPz+JNG/ePOP22283/Pz8jNDQUGP48OFGWlqa0/jCwkLj+eefN+rXr29Ur17duOOOO4wPPvjAaN26tdMTOaU9QZWfn29MmzbNqFevnuHv72+0bdvW2LhxozFq1Cin/SzpiaGS1n291/GXNm7caHTo0MHw9/c3brnlFqNHjx7G559/fkPbKc71+nbo0MGoVauWcf78ecMwDGPFihVGbGysYbVajdtuu81ITEw0li9fXuQpvZiYGOOBBx4osr5u3bo5PXVnGIbx+eefGx07djSsVqsRHh5u/OEPfzDefPPNIuu80WPdrVs3o0WLFkW2XVJNkoyJEyeW9BIZhvHfY3p1qlq1qlGrVi2jXbt2xuTJk42vv/66yJgbfb8YhmFs27bNaNOmjWG1Wg1Jjv9Gzp07Z4wdO9YICwszAgMDjS5duhi7d+8u9nUEXGExjBv4whIAlUpqaqqaNm2qWbNmacaMGZ4uBwAqHAEIqOSOHDmiNWvWqHPnzqpRo4aOHj2q+fPnKysrS1999VWJT4MBQGXGPUBAJXfLLbcoOTlZy5cv1/nz5xUcHKy4uDj95S9/IfwAMC3OAAEAANPhMXgAAGA6BCAAAGA6BCAAAGA63ARdjMLCQp0+fVpBQUHF/iwAAADwPoZh6KefflJkZOR1v+iVAFSM06dPKyoqytNlAAAAF6SlpV33W8MJQMUICgqS9PMLWKNGDQ9XAwAAbkRWVpaioqIcf8dLQwAqxtXLXjVq1CAAAQDgY27k9hVuggYAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZDAAIAAKZTzdMFADAvm80mu93u0tjQ0FBFR0e7uSIAZkEAAuARNptNsU2bKS83x6Xx/gGBOvpNCiEIgEsIQAA8wm63Ky83RyH9psovJKpMYwsy05T5wULZ7XYCEACXEIAAeJRfSJSs4Y09XQYAk+EmaAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoeDUCJiYm66667FBQUpLCwMA0cOFBHjx516mMYhmbPnq3IyEgFBAQoLi5OX3/99XXXvW7dOjVv3lxWq1XNmzfXhg0byms3AACAj/FoANq5c6cmTpyoffv2aevWrbp8+bJ69eqlixcvOvrMnz9fixYt0tKlS7V//36Fh4er
Z8+e+umnn0pc7969ezVs2DCNGDFCR44c0YgRIzR06FB98cUXFbFbAADAy1kMwzA8XcRVP/74o8LCwrRz50517dpVhmEoMjJSkydP1h//+EdJUn5+vurWrat58+bp8ccfL3Y9w4YNU1ZWlv7xj3842u6//37VqlVLa9asuW4dWVlZCg4O1oULF1SjRg337BwAJwcPHlS7du0UPmpxmX8LLD/juDJWTdaBAwfUtm3bcqoQgK8py99vr7oH6MKFC5Kk2rVrS5JSU1OVkZGhXr16OfpYrVZ169ZNe/bsKXE9e/fudRojSb179y5xTH5+vrKyspwmAABQeXlNADIMQ/Hx8erSpYtatmwpScrIyJAk1a1b16lv3bp1HcuKk5GRUaYxiYmJCg4OdkxRUVE3sysAAMDLeU0AevLJJ/Xll18We4nKYrE4zRuGUaTtZsYkJCTowoULjiktLa2M1QMAAF9SzdMFSNJTTz2lzZs3a9euXapfv76jPTw8XNLPZ3QiIiIc7WfOnClyhueXwsPDi5ztKW2M1WqV1Wq9mV0AAAA+xKNngAzD0JNPPqn169fr008/VcOGDZ2WN2zYUOHh4dq6dauj7dKlS9q5c6c6d+5c4no7derkNEaStmzZUuoYAABgHh49AzRx4kS988472rRpk4KCghxnbYKDgxUQECCLxaLJkydr7ty5atKkiZo0aaK5c+cqMDBQjzzyiGM9I0eOVL169ZSYmChJmjRpkrp27ap58+ZpwIAB2rRpk7Zt26bPPvvMI/sJAAC8i0cD0Ouvvy5JiouLc2pfuXKlRo8eLUmaPn26cnNzNWHCBJ07d04dOnTQli1bFBQU5Ohvs9lUpcp/T2Z17txZ7777rp555hnNnDlTjRo10tq1a9WhQ4dy3ycAAOD9PBqAbuQriCwWi2bPnq3Zs2eX2GfHjh1F2oYMGaIhQ4bcRHUAAKCy8pqnwAAAACoKAQgAAJgOAQgAAJgOAQgAAJgOAQgAAJgOAQgAAJgOAQgAAJgOAQgAAJgOAQgAAJiOV/waPADPstlsstvtLo0NDQ1VdHS0mysqXzezv/n5+bJarS6N9cXXCqisCECAydlsNsU2baa83ByXxvsHBOroNyk+84f9ZvdXliqSUejSUF97rYDKjAAEmJzdbldebo5C+k2VX0hUmcYWZKYp84OFstvtPvNH/Wb2N/e7ZF3Yvdo0rxVQmRGAAEiS/EKiZA1v7OkyKowr+1uQmebyWADehZugAQCA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6RCAAACA6fBr8EAlYbPZZLfbyzwuJSWlHKqpGK7U7sv7C8B9CEBAJWCz2RTbtJnycnM8XUqFuJJ9TrJYNHz4cE+XAsBHEYCASsButysvN0ch/abKLySqTGNzv0vWhd2ry6my8lGYny0Zhmn2F4D7EYCASsQvJErW8MZlGlOQmVZO1ZQ/s+0vAPfhJmgAAGA6Hg1Au3btUv/+/RUZGSmLxaKNGzc6LbdYLMVOCxYsKHGdSUlJxY7Jy8sr570BAAC+wqMB6OLFi2rdurWWLl1a7PL09HSnacWKFbJYLHrwwQdLXW+NGjWKjPX39y+PXQAAAD7Io/cA9enTR3369ClxeXh4uNP8pk2b1L17d912222lrtdisRQZCwAAcJXP3AP0ww8/6MMPP9TYsWOv2zc7O1sxMTGqX7+++vXrp0OHDpXaPz8/X1lZWU4TAACovHwmAK1atUpBQUEaPHhwqf2aNm2qpKQkbd68WWvWrJG/v7/uueceHTt2rMQxiYmJCg4OdkxRUWV7rBYAAPgWnwlAK1as0KOPPnrde3k6duyo4cOHq3Xr1rr33nv13nvv6fbbb9eSJUtKHJOQkKALFy44prQ0HpMFAKAy84nvAdq9e7eOHj2qtWvXlnlslSpVdNddd5V6Bshqtcpqtd5MiQAA
wIf4xBmg5cuXq127dmrdunWZxxqGocOHDysiIqIcKgMAAL7Io2eAsrOzdfz4ccd8amqqDh8+rNq1ays6OlqSlJWVpffff18LFy4sdh0jR45UvXr1lJiYKEmaM2eOOnbsqCZNmigrK0uvvPKKDh8+rFdffbX8dwgAAPgEjwag5ORkde/e3TEfHx8vSRo1apSSkpIkSe+++64Mw9BvfvObYtdhs9lUpcp/T2SdP39e48ePV0ZGhoKDg9WmTRvt2rVLd999d/ntCAAA8CkeDUBxcXEyDKPUPuPHj9f48eNLXL5jxw6n+ZdeekkvvfSSO8oDAACVlE/cAwQAAOBOBCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA6BCAAAGA61TxdAADg+mw2m+x2u0tj8/PzZbVaXRobGhqq6Ohol8YC3owABABezmazKbZpM+Xl5ri2AksVySh0aah/QKCOfpNCCEKlQwACAC9nt9uVl5ujkH5T5RcSVaaxud8l68Lu1S6NLchMU+YHC2W32wlAqHQIQADgI/xComQNb1ymMQWZaS6PBSozboIGAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACm49EAtGvXLvXv31+RkZGyWCzauHGj0/LRo0fLYrE4TR07drzuetetW6fmzZvLarWqefPm2rBhQzntAQAA8EUeDUAXL15U69attXTp0hL73H///UpPT3dMH330Uanr3Lt3r4YNG6YRI0boyJEjGjFihIYOHaovvvjC3eUDAAAf5dHfAuvTp4/69OlTah+r1arw8PAbXufixYvVs2dPJSQkSJISEhK0c+dOLV68WGvWrLmpegEAQOXg9T+GumPHDoWFhalmzZrq1q2b/vKXvygsLKzE/nv37tWUKVOc2nr37q3FixeXOCY/P1/5+fmO+aysrJuuGwCKk5KSUiFjAJTOqwNQnz599NBDDykmJkapqamaOXOm7rvvPh04cEBWq7XYMRkZGapbt65TW926dZWRkVHidhITEzVnzhy31g4Av3Ql+5xksWj48OGeLgWAvDwADRs2zPHvli1bqn379oqJidGHH36owYMHlzjOYrE4zRuGUaTtlxISEhQfH++Yz8rKUlRU1E1UDgDOCvOzJcNQSL+p8gsp2+dL7nfJurB7dTlVBpiTVwega0VERCgmJkbHjh0rsU94eHiRsz1nzpwpclbol6xWa4lnlADAnfxComQNb1ymMQWZaeVUDWBePvU9QJmZmUpLS1NERESJfTp16qStW7c6tW3ZskWdO3cu7/IAAICP8OgZoOzsbB0/ftwxn5qaqsOHD6t27dqqXbu2Zs+erQcffFARERE6ceKEZsyYodDQUA0aNMgxZuTIkapXr54SExMlSZMmTVLXrl01b948DRgwQJs2bdK2bdv02WefVfj+AQAA7+TRAJScnKzu3bs75q/ehzNq1Ci9/vrr+ve//623335b58+fV0REhLp37661a9cqKCjIMcZms6lKlf+eyOrcubPeffddPfPMM5o5c6YaNWqktWvXqkOHDhW3YwAAwKt5NADFxcXJMIwSl//zn/+87jp27NhRpG3IkCEaMmTIzZQGAAAqMZ+6BwgAAMAdCEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0XApAqamp7q4DAACgwrgUgBo3bqzu3btr9erVysvLc3dNAAAA5cqlAHTkyBG1
adNGU6dOVXh4uB5//HH961//cndtAAAA5cKlANSyZUstWrRIp06d0sqVK5WRkaEuXbqoRYsWWrRokX788Ud31wkAAOA2N3UTdLVq1TRo0CC99957mjdvnr799ltNmzZN9evX18iRI5Wenl7q+F27dql///6KjIyUxWLRxo0bHcsKCgr0xz/+Ua1atdItt9yiyMhIjRw5UqdPny51nUlJSbJYLEUmLtUBAICrbioAJScna8KECYqIiNCiRYs0bdo0ffvtt/r000916tQpDRgwoNTxFy9eVOvWrbV06dIiy3JycnTw4EHNnDlTBw8e1Pr16/Wf//xHv/71r69bV40aNZSenu40+fv7u7yfAACgcqnmyqBFixZp5cqVOnr0qPr27au3335bffv2VZUqP+ephg0b6o033lDTpk1LXU+fPn3Up0+fYpcFBwdr69atTm1LlizR3XffLZvNpujo6BLXa7FYFB4eXsa9AgAAZuFSAHr99dc1ZswYPfbYYyUGjejoaC1fvvymirvWhQsXZLFYVLNmzVL7ZWdnKyYmRleuXNGdd96p5557Tm3atCmxf35+vvLz8x3zWVlZ7ioZAAB4IZcC0LFjx67bp3r16ho1apQrqy9WXl6e/vSnP+mRRx5RjRo1SuzXtGlTJSUlqVWrVsrKytLLL7+se+65R0eOHFGTJk2KHZOYmKg5c+a4rVYAAODdXLoHaOXKlXr//feLtL///vtatWrVTRd1rYKCAj388MMqLCzUa6+9Vmrfjh07avjw4WrdurXuvfdevffee7r99tu1ZMmSEsckJCTowoULjiktLc3duwAAALyISwHohRdeUGhoaJH2sLAwzZ0796aL+qWCggINHTpUqamp2rp1a6lnf4pTpUoV3XXXXaWetbJarapRo4bTBAAAKi+XAtDJkyfVsGHDIu0xMTGy2Ww3XdRVV8PPsWPHtG3bNoWEhJR5HYZh6PDhw4qIiHBbXQAAwLe5dA9QWFiYvvzySzVo0MCp/ciRI2UKKdnZ2Tp+/LhjPjU1VYcPH1bt2rUVGRmpIUOG6ODBg/rggw905coVZWRkSJJq166t6tWrS5JGjhypevXqKTExUZI0Z84cdezYUU2aNFFWVpZeeeUVHT58WK+++qoruwoAACohlwLQww8/rN///vcKCgpS165dJUk7d+7UpEmT9PDDD9/wepKTk9W9e3fHfHx8vCRp1KhRmj17tjZv3ixJuvPOO53Gbd++XXFxcZIkm83mePxeks6fP6/x48crIyNDwcHBatOmjXbt2qW7777blV0FAACVkEsB6Pnnn9fJkyfVo0cPVav28yoKCws1cuTIMt0DFBcXJ8MwSlxe2rKrduzY4TT/0ksv6aWXXrrhGgAAgPm4FICqV6+utWvX6rnnntORI0cUEBCgVq1aKSYmxt31AQAAuJ1LAeiq22+/Xbfffru7agHgo1JSUipkDAC4i0sB6MqVK0pKStInn3yiM2fOqLCw0Gn5p59+6pbiAHi3K9nnJItFw4cP93QpAFAmLgWgSZMmKSkpSQ888IBatmwpi8Xi7roA+IDC/GzJMBTSb6r8QqLKNDb3u2Rd2L26nCoDgNK5FIDeffddvffee+rbt6+76wHgg/xComQNb1ymMQWZfOM6AM9x6YsQq1evrsaNy/ZhBwAA4C1cCkBTp07Vyy+/fEOPqQMAAHgbly6BffbZZ9q+fbv+8Y9/qEWLFvLz83Navn79ercUBwAAUB5cCkA1a9bUoEGD3F0LAABAhXApAK1cudLddQAAAFQYl+4BkqTLly9r27ZteuONN/TTTz9Jkk6fPq3s7Gy3FQcAAFAeXDoDdPLkSd1///2y2WzKz89Xz549FRQUpPnz5ysvL0/Lli1zd50AAABu49IZoEmTJql9+/Y6d+6cAgICHO2DBg3SJ5984rbiAAAAyoPLT4F9/vnnql69ulN7TEyMTp065ZbCAAAAyotLZ4AKCwt15cqVIu3ff/+9goKCbroo
AACA8uRSAOrZs6cWL17smLdYLMrOztasWbP4eQwAAOD1XLoE9tJLL6l79+5q3ry58vLy9Mgjj+jYsWMKDQ3VmjVr3F0jAACAW7kUgCIjI3X48GGtWbNGBw8eVGFhocaOHatHH33U6aZowFfZbDbZ7XaXxoaGhio6OtrNFQEA3MmlACRJAQEBGjNmjMaMGePOegCPs9lsim3aTHm5OS6N9w8I1NFvUghBAODFXApAb7/9dqnLR44c6VIxgDew2+3Ky81RSL+p8guJKtPYgsw0ZX6wUHa7nQAEAF7MpQA0adIkp/mCggLl5OSoevXqCgwMJAChUvALiZI1vLGnywAAlAOXngI7d+6c05Sdna2jR4+qS5cu3AQNAAC8nsu/BXatJk2a6IUXXihydggAAMDbuC0ASVLVqlV1+vRpd64SAADA7Vy6B2jz5s1O84ZhKD09XUuXLtU999zjlsIAAADKi0sBaODAgU7zFotFderU0X333aeFCxe6oy4AAIBy41IAKiwsdHcdAAAAFcat9wABAAD4ApfOAMXHx99w30WLFrmyCQAAgHLjUgA6dOiQDh48qMuXLys2NlaS9J///EdVq1ZV27ZtHf0sFot7qgQAAHAjlwJQ//79FRQUpFWrVqlWrVqSfv5yxMcee0z33nuvpk6d6tYiAQAA3Mmle4AWLlyoxMRER/iRpFq1aun5558v01Ngu3btUv/+/RUZGSmLxaKNGzc6LTcMQ7Nnz1ZkZKQCAgIUFxenr7/++rrrXbdunZo3by6r1armzZtrw4YNN1wTAACo/FwKQFlZWfrhhx+KtJ85c0Y//fTTDa/n4sWLat26tZYuXVrs8vnz52vRokVaunSp9u/fr/DwcPXs2bPUbezdu1fDhg3TiBEjdOTIEY0YMUJDhw7VF198ccN1AQCAys2lADRo0CA99thj+tvf/qbvv/9e33//vf72t79p7NixGjx48A2vp0+fPnr++eeLHWMYhhYvXqynn35agwcPVsuWLbVq1Srl5OTonXfeKXGdixcvVs+ePZWQkKCmTZsqISFBPXr00OLFi13ZVQAAUAm5dA/QsmXLNG3aNA0fPlwFBQU/r6haNY0dO1YLFixwS2GpqanKyMhQr169HG1Wq1XdunXTnj179Pjjjxc7bu/evZoyZYpTW+/evUsNQPn5+crPz3fMZ2Vl3VzxAADZbDbZ7XaXxoaGhio6OtrNFQH/5VIACgwM1GuvvaYFCxbo22+/lWEYaty4sW655Ra3FZaRkSFJqlu3rlN73bp1dfLkyVLHFTfm6vqKk5iYqDlz5txEtQCAX7LZbIpt2kx5uTkujfcPCNTRb1IIQSg3LgWgq9LT05Wenq6uXbsqICBAhmG4/dH3a9d3I9so65iEhASn7zbKyspSVFSUC9UCACTJbrcrLzdHIf2myi+kbJ+nBZlpyvxgoex2OwEI5calAJSZmamhQ4dq+/btslgsOnbsmG677TaNGzdONWvWdMvvgYWHh0v6+YxORESEo/3MmTNFzvBcO+7asz3XG2O1WmW1Wm+yYgDAtfxComQNb+zpMoAiXLoJesqUKfLz85PNZlNgYKCjfdiwYfr444/dUljDhg0VHh6urVu3OtouXbqknTt3qnPnziWO69Spk9MYSdqyZUupYwAAgLm4dAZoy5Yt+uc//6n69es7tTdp0qTU+3OulZ2drePHjzvmU1NTdfjwYdWuXVvR0dGaPHmy5s6dqyZNmqhJkyaaO3euAgMD9cgjjzjGjBw5UvXq1VNiYqIkadKkSeratavmzZunAQMGaNOmTdq2bZs+++wzV3YVAABUQi4FoIsXLzqd+bnKbreX6VJScnKyunfv7pi/eh/OqFGjlJSUpOnTpys3N1cTJkzQuXPn1KFDB23ZskVBQUGOMTabTVWq/PdEVufOnfXuu+/qmWee0cyZM9WoUSOtXbtWHTp0cGVXAQBAJeRSAOratavefvttPffcc5J+vum4sLBQCxYscAo0
1xMXFyfDMEpcbrFYNHv2bM2ePbvEPjt27CjSNmTIEA0ZMuSG6wAAAObiUgBasGCB4uLilJycrEuXLmn69On6+uuvdfbsWX3++efurhEAAMCtXLoJunnz5vryyy919913q2fPnrp48aIGDx6sQ4cOqVGjRu6uEQAAwK3KfAaooKBAvXr10htvvMGXBwIAAJ9U5jNAfn5++uqrr9z+hYcAAAAVxaVLYCNHjtTy5cvdXQsAAECFcOkm6EuXLumtt97S1q1b1b59+yK/AbZo0SK3FAcAAFAeyhSAvvvuOzVo0EBfffWV2rZtK0n6z3/+49SHS2MAAMDblSkANWnSROnp6dq+fbukn3/64pVXXin1d7YAAAC8TZnuAbr2Swv/8Y9/6OLFi24tCAAAoLy5dA/QVaV9izOAsrPZbLLb7WUel5KSUg7VAEDlVaYAZLFYitzjwz0/gHvYbDbFNm2mvNwcT5cCAJVemQKQYRgaPXq04wdP8/Ly9MQTTxR5Cmz9+vXuqxAwCbvdrrzcHIX0myq/kKgyjc39LlkXdq8up8oAoPIpUwAaNWqU0/zw4cPdWgwAyS8kStbwxmUaU5CZVk7VAEDlVKYAtHLlyvKqAwAAoMK49E3QAAAAvowABAAATIcABAAATIcABAAATIcABAAATIcABAAATIcABAAATIcABAAATIcABAAATIcABAAATIcABAAATIcABAAATKdMP4YK4MakpKRUyBigIvB+RmVEAALc6Er2Ocli0fDhwz1dCnDTeD+jMiMAAW5UmJ8tGYZC+k2VX0hUmcbmfpesC7tXl1NlQNnxfkZlRgACyoFfSJSs4Y3LNKYgM62cqgFuDu9nVEbcBA0AAEzH6wNQgwYNZLFYikwTJ04stv+OHTuK7f/NN99UcOUAAMBbef0lsP379+vKlSuO+a+++ko9e/bUQw89VOq4o0ePqkaNGo75OnXqlFuNAADAt3h9ALo2uLzwwgtq1KiRunXrVuq4sLAw1axZsxwrAwAAvsrrL4H90qVLl7R69WqNGTNGFoul1L5t2rRRRESEevTooe3bt5faNz8/X1lZWU4TAACovHwqAG3cuFHnz5/X6NGjS+wTERGhN998U+vWrdP69esVGxurHj16aNeuXSWOSUxMVHBwsGOKiirb454AAMC3eP0lsF9avny5+vTpo8jIyBL7xMbGKjY21jHfqVMnpaWl6cUXX1TXrl2LHZOQkKD4+HjHfFZWFiEIAIBKzGcC0MmTJ7Vt2zatX7++zGM7duyo1atL/kIuq9Uqq9V6M+UBAAAf4jOXwFauXKmwsDA98MADZR576NAhRURElENVAADAF/nEGaDCwkKtXLlSo0aNUrVqziUnJCTo1KlTevvttyVJixcvVoMGDdSiRQvHTdPr1q3TunXrPFE6AADwQj4RgLZt2yabzaYxY8YUWZaeni6bzeaYv3TpkqZNm6ZTp04pICBALVq00Icffqi+fftWZMkAAMCL+UQA6tWrlwzDKHZZUlKS0/z06dM1ffr0CqgKAAD4Kp+5BwgAAMBdCEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0qnm6AKC82Gw22e32Mo9LSUkph2oA+ApXPzskKTQ0VNHR0W6uCOWBAIRKyWazKbZpM+Xl5ni6FAA+5GY/O/wDAnX0mxRCkA8gAKFSstvtysvNUUi/qfILiSrT2NzvknVh9+pyqgyAN7uZz46CzDRlfrBQdrudAOQDCECo1PxComQNb1ymMQWZaeVUDQBf4cpnB3wLN0EDAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADT8eoANHv2bFksFqcpPDy81DE7d+5Uu3bt5O/vr9tuu03Lli2roGoB
AICv8PrfAmvRooW2bdvmmK9atWqJfVNTU9W3b1/99re/1erVq/X5559rwoQJqlOnjh588MGKKBcAAPgArw9A1apVu+5Zn6uWLVum6OhoLV68WJLUrFkzJScn68UXXyQAAQAAB68PQMeOHVNkZKSsVqs6dOiguXPn6rbbbiu27969e9WrVy+ntt69e2v58uUqKCiQn59fsePy8/OVn5/vmM/KynLfDgAAKpzNZpPdbi/zuJSUlHKoBt7IqwNQhw4d9Pbbb+v222/XDz/8oOeff16dO3fW119/rZCQkCL9MzIyVLduXae2unXr6vLly7Lb7YqIiCh2O4mJiZozZ0657AMAoGLZbDbFNm2mvNwcT5cCL+bVAahPnz6Of7dq1UqdOnVSo0aNtGrVKsXHxxc7xmKxOM0bhlFs+y8lJCQ4rS8rK0tRUVE3UzoAwEPsdrvycnMU0m+q/ELK9lme+12yLuxeXU6VwZt4dQC61i233KJWrVrp2LFjxS4PDw9XRkaGU9uZM2dUrVq1Ys8YXWW1WmW1Wt1aKwDAs/xComQNb1ymMQWZaeVUDbyNVz8Gf638/HylpKSUeCmrU6dO2rp1q1Pbli1b1L59+xLv/wEAAObj1QFo2rRp2rlzp1JTU/XFF19oyJAhysrK0qhRoyT9fOlq5MiRjv5PPPGETp48qfj4eKWkpGjFihVavny5pk2b5qldAAAAXsirL4F9//33+s1vfiO73a46deqoY8eO2rdvn2JiYiRJ6enpstlsjv4NGzbURx99pClTpujVV19VZGSkXnnlFR6BBwAATrw6AL377rulLk9KSirS1q1bNx08eLCcKgIAAJWBV18CAwAAKA8EIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDoEIAAAYDpe/U3QgM1mk91uL/O4lJSUcqgGAFBZEIDgtWw2m2KbNlNebo6nSwEAVDIEIHgtu92uvNwchfSbKr+QqDKNzf0uWRd2ry6nygAAvo4ABK/nFxIla3jjMo0pyEwrp2oAAJUBN0EDAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADTIQABAADT4cdQfYzNZpPdbndpbGhoqKKjoyt8u/n5+bJarWUel5KS4tL2AFQOrn4G8NmBG0EA8iE2m02xTZspLzfHpfH+AYE6+k1KmUPQzW5XliqSUejaWACmcyX7nGSxaPjw4Z4uBZUYAciH2O125eXmKKTfVPmFRJVpbEFmmjI/WCi73V7mAHQz2839LlkXdq++qbEAzKUwP1syDJc+NyQ+O3BjCEA+yC8kStbwxj6x3YLMtJseC8CcXP2s47MDN4KboAEAgOkQgAAAgOl4dQBKTEzUXXfdpaCgIIWFhWngwIE6evRoqWN27Nghi8VSZPrmm28qqGoAAODtvDoA7dy5UxMnTtS+ffu0detWXb58Wb169dLFixevO/bo0aNKT093TE2aNKmAigEAgC/w6pugP/74Y6f5lStXKiwsTAcOHFDXrl1LHRsWFqaaNWuWY3UAAMBXefUZoGtduHBBklS7du3r9m3Tpo0iIiLUo0cPbd++vdS++fn5ysrKcpoAAEDl5TMByDAMxcfHq0uXLmrZsmWJ/SIiIvTmm29q3bp1Wr9+vWJjY9WjRw/t2rWrxDGJiYkKDg52TFFRZf/eCQAA4Du8+hLYLz355JP68ssv9dlnn5XaLzY2VrGxsY75Tp06KS0tTS+++GKJl80SEhIUHx/vmM/KyiIEAQBQifnEGaCnnnpKmzdv1vbt21W/fv0yj+/YsaOOHTtW4nKr1aoaNWo4TQAAoPLy6jNAhmHoqaee0oYNG7Rjxw41bNjQpfUcOnRIERERbq4OAAD4Kq8OQBMnTtQ777yjTZs2KSgoSBkZGZKk4OBgBQQESPr58tWpU6f09ttvS5IWL16sBg0aqEWLFrp0
6ZJWr16tdevWad26dR7bDwAA4F28OgC9/vrrkqS4uDin9pUrV2r06NGSpPT0dNlsNseyS5cuadq0aTp16pQCAgLUokULffjhh+rbt29FlQ0AALycVwcgwzCu2ycpKclpfvr06Zo+fXo5VQQAACoDn7gJGgAAwJ0IQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHQIQAAAwHSqeboAM7LZbLLb7WUel5KSUg7VAADcydXP6tDQUEVHR7s01tW/K766XXcgAFUwm82m2KbNlJeb4+lSAABudCX7nGSxaPjw4S6N9w8I1NFvUsocCm7274qvbdddCEAVzG63Ky83RyH9psovJKpMY3O/S9aF3avLqTIAwM0ozM+WDMOlz/eCzDRlfrBQdru9zIHgZv6u+OJ23YUA5CF+IVGyhjcu05iCzLRyqgYA4C6ufL6z3YrHTdAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0CEAAAMB0fCIAvfbaa2rYsKH8/f3Vrl077d69u9T+O3fuVLt27eTv76/bbrtNy5Ytq6BKAQCAL/D6ALR27VpNnjxZTz/9tA4dOqR7771Xffr0kc1mK7Z/amqq+vbtq3vvvVeHDh3SjBkz9Pvf/17r1q2r4MoBAIC38voAtGjRIo0dO1bjxo1Ts2bNtHjxYkVFRen1118vtv+yZcsUHR2txYsXq1mzZho3bpzGjBmjF198sYIrBwAA3sqrA9ClS5d04MAB9erVy6m9V69e2rNnT7Fj9u7dW6R/7969lZycrIKCgnKrFQAA+A6v/jV4u92uK1euqG7duk7tdevWVUZGRrFjMjIyiu1/+fJl2e12RUREFBmTn5+v/Px8x/yFCxckSVlZWTe7C0VkZ2f/vM2M4yq8lFemsVd/Dd6lsWe/lyQdOHDAUcONOnr0qOvbvZmaGctYxjK2jGM9uW2f/Iz28Hazs7Pd+rf26roMw7h+Z8OLnTp1ypBk7Nmzx6n9+eefN2JjY4sd06RJE2Pu3LlObZ999pkhyUhPTy92zKxZswxJTExMTExMTJVgSktLu27G8OozQKGhoapatWqRsz1nzpwpcpbnqvDw8GL7V6tWTSEhIcWOSUhIUHx8vGO+sLBQZ8+eVUhIiCwWy03uxc+ysrIUFRWltLQ01ahRwy3rhPtwfLwfx8j7cYy8mxmOj2EY+umnnxQZGXndvl4dgKpXr6527dpp69atGjRokKN969atGjBgQLFjOnXqpL///e9ObVu2bFH79u3l5+dX7Bir1Sqr1erUVrNmzZsrvgQ1atSotG+8yoDj4/04Rt6PY+TdKvvxCQ4OvqF+Xn0TtCTFx8frrbfe0ooVK5SSkqIpU6bIZrPpiSeekPTz2ZuRI0c6+j/xxBM6efKk4uPjlZKSohUrVmj58uWaNm2ap3YBAAB4Ga8+AyRJw4YNU2Zmpp599lmlp6erZcuW+uijjxQTEyNJSk9Pd/pOoIYNG+qjjz7SlClT9OqrryoyMlKvvPKKHnzwQU/tAgAA8DJeH4AkacKECZowYUKxy5KSkoq0devWTQcPHiznqsrGarVq1qxZRS61wTtwfLwfx8j7cYy8G8fHmcUwbuRZMQAAgMrD6+8BAgAAcDcCEAAAMB0CEAAAMB0CEAAAMB0CUAU7ceKExo4dq4YNGyogIECNGjXSrFmzdOnSJU+Xhv/vL3/5izp37qzAwMBy+0JMlM1rr72mhg0byt/fX+3atdPu3bs9XRJ+YdeuXerfv78iIyNlsVi0ceNGT5eEX0hMTNRdd92loKAghYWFaeDAgY7f8TIzAlAF++abb1RYWKg33nhDX3/9tV566SUtW7ZMM2bM8HRp+P8uXbqkhx56SL/73e88XQokrV27VpMnT9bTTz+tQ4cO
6d5771WfPn2cvv8LnnXx4kW1bt1aS5cu9XQpKMbOnTs1ceJE7du3T1u3btXly5fVq1cvXbx40dOleRSPwXuBBQsW6PXXX9d3333n6VLwC0lJSZo8ebLOnz/v6VJMrUOHDmrbtq1ef/11R1uzZs00cOBAJSYmerAyFMdisWjDhg0aOHCgp0tBCX788UeFhYVp586d6tq1q6fL8RjOAHmBCxcuqHbt2p4uA/A6ly5d0oEDB9SrVy+n9l69emnPnj0eqgrwbRcuXJAk0//dIQB52LfffqslS5Y4ftsMwH/Z7XZduXJFdevWdWqvW7euMjIyPFQV4LsMw1B8fLy6dOmili1berocjyIAucns2bNlsVhKnZKTk53GnD59Wvfff78eeughjRs3zkOVm4Mrxwfew2KxOM0bhlGkDcD1Pfnkk/ryyy+1Zs0aT5ficT7xW2C+4Mknn9TDDz9cap8GDRo4/n369Gl1795dnTp10ptvvlnO1aGsxwfeITQ0VFWrVi1ytufMmTNFzgoBKN1TTz2lzZs3a9euXapfv76ny/E4ApCbhIaGKjQ09Ib6njp1St27d1e7du20cuVKVanCibjyVpbjA+9RvXp1tWvXTlu3btWgQYMc7Vu3btWAAQM8WBngOwzD0FNPPaUNGzZox44datiwoadL8goEoAp2+vRpxcXFKTo6Wi+++KJ+/PFHx7Lw8HAPVoarbDabzp49K5vNpitXrujw4cOSpMaNG+vWW2/1bHEmFB8frxEjRqh9+/aOM6Y2m4375rxIdna2jh8/7phPTU3V4cOHVbt2bUVHR3uwMkjSxIkT9c4772jTpk0KCgpynFENDg5WQECAh6vzHB6Dr2BJSUl67LHHil3GofAOo0eP1qpVq4q0b9++XXFxcRVfEPTaa69p/vz5Sk9PV8uWLfXSSy+Z+vFdb7Njxw517969SPuoUaOUlJRU8QXBSUn3y61cuVKjR4+u2GK8CAEIAACYDjefAAAA0yEAAQAA0yEAAQAA0yEAAQAA0yEAAQAA0yEAAQAA0yEAAQAA0yEAATCNuLg4TZ482dNlAPACBCAAPqF///761a9+VeyyvXv3ymKx6ODBgxVcFQBfRQAC4BPGjh2rTz/9VCdPniyybMWKFbrzzjvVtm1bD1QGwBcRgAD4hH79+iksLKzIb0vl5ORo7dq1GjhwoH7zm9+ofv36CgwMVKtWrbRmzZpS12mxWLRx40antpo1azpt49SpUxo2bJhq1aqlkJAQDRgwQCdOnHDPTgHwGAIQAJ9QrVo1jRw5UklJSU4/HPz+++/r0qVLGjdunNq1a6cPPvhAX331lcaPH68RI0boiy++cHmbOTk56t69u2699Vbt2rVLn332mW699Vbdf//9unTpkjt2C4CHEIAA+IwxY8boxIkT2rFjh6NtxYoVGjx4sOrVq6dp06bpzjvv1G233aannnpKvXv31vvvv+/y9t59911VqVJFb731llq1aqVmzZpp5cqVstlsTjUA8D3VPF0AANyopk2bqnPnzlqxYoW6d++ub7/9Vrt379aWLVt05coVvfDCC1q7dq1OnTql/Px85efn65ZbbnF5ewcOHNDx48cVFBTk1J6Xl6dvv/32ZncHgAcRgAD4lLFjx+rJJ5/Uq6++qpUrVyomJkY9evTQggUL9NJLL2nx4sVq1aqVbrnlFk2ePLnUS1UWi8XpcpokFRQUOP5dWFiodu3a6X//93+LjK1Tp477dgpAhSMAAfApQ4cO1aRJk/TOO+9o1apV+u1vfyuLxaLdu3drwIABGj58uKSfw8uxY8fUrFmzEtdVp04dpaenO+aPHTumnJwcx3zbtm21du1ahYWFqUaNGuW3UwAqHPcAAfApt956q4YNG6YZM2bo9OnTGj16tCSpcePG2rp1q/bs2aOUlBQ9/vjjysjIKHVd9913n5YuXaqDBw8qOTlZTzzxhPz8/BzLH330UYWGhmrAgAHavXu3UlNTtXPnTk2aNEnf
f/99ee4mgHJGAALgc8aOHatz587pV7/6laKjoyVJM2fOVNu2bdW7d2/FxcUpPDxcAwcOLHU9CxcuVFRUlLp27apHHnlE06ZNU2BgoGN5YGCgdu3apejoaA0ePFjNmjXTmDFjlJubyxkhwMdZjGsvgAMAAFRynAECAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACmQwACAACm8/8Aniw0FENG9eQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Generate some random data for the histogram\n",
    "#data = np.random.randn(1000) # 1000 random numbers from a standard normal distribution\n",
    "\n",
    "# Create the histogram\n",
    "plt.hist(fake.cpu().detach().numpy()[20], bins=30, edgecolor='black') # 'bins' defines the number of bins, 'edgecolor' adds borders to the bars\n",
    "\n",
    "# Add labels and title for clarity\n",
    "plt.xlabel('Value')\n",
    "plt.ylabel('Frequency')\n",
    "plt.title('Histogram of Random Data')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "1bf9769f-04f3-4644-9e40-2265a638b7bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▏         | 7/500 [00:00<00:06, 72.47it/s]\n"
     ]
    }
   ],
   "source": [
    "# Sample images: netG maps Gaussian noise to a latent vector and the EVAE decoder turns it into an image.\n",
    "# Assumes EVAEmodel, netG, transform, batch_size, device and the save_* helpers come from earlier cells.\n",
    "import itertools\n",
    "\n",
    "EVAEmodel.eval()\n",
    "netG.eval()\n",
    "\n",
    "N_GENERATED = 64      # number of images to generate\n",
    "LATENT_DIM = 64       # size of the noise vector fed to netG\n",
    "REAL_BATCH_INDEX = 7  # which (unshuffled) training batch to save as the real reference\n",
    "\n",
    "train_set = torchvision.datasets.CIFAR10(root='./', train=True, download=False, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Fetch only the requested batch instead of looping with a manual counter and break.\n",
    "    data = next(itertools.islice(train_loader, REAL_BATCH_INDEX, None))[0].to(device)\n",
    "\n",
    "    noise = torch.randn(N_GENERATED, LATENT_DIM, device=device)\n",
    "    state1 = netG(noise)                        # recorded shape: torch.Size([64, 64])\n",
    "    reconstruction = EVAEmodel.decoder(state1)  # recorded shape: torch.Size([64, 3, 32, 32])\n",
    "\n",
    "    # NOTE(review): the real batch is written through the VAE helper (VAE0.jpg);\n",
    "    # save_reconstructed_imagesREAL would name it REAL0.jpg. Kept to preserve the file name.\n",
    "    save_reconstructed_imagesVAE(data, 0)\n",
    "    save_reconstructed_imagesEVAE(reconstruction, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c8b32783-d87e-4435-af5a-29b9546f3f2a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 64])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state1.shape  # torch.Size of state1 (netG output when the sampling cell ran last)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "b99a6ea1-4a3f-4337-9260-e1533ca3fddc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 32, 32])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reconstruction.shape  # torch.Size of the decoded images from the sampling cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "0a346abb-c6e5-4e16-aaf7-16848bdee53a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [03:04<00:00,  2.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FIDEVAE: 93.03988647460938\n"
     ]
    }
   ],
   "source": [
    "# FID between the CIFAR-10 training images and an equal number of generated images.\n",
    "# Assumes EVAEmodel, netG, train_loader, device and tqdm come from earlier cells.\n",
    "from torcheval.metrics import FrechetInceptionDistance\n",
    "\n",
    "EVAEmodel.eval()\n",
    "netG.eval()\n",
    "\n",
    "latent_dim = 64  # size of the noise vector fed to netG\n",
    "fidEVAE = FrechetInceptionDistance(device=device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # len(train_loader) counts the final partial batch, unlike int(len(train_set) / batch_size).\n",
    "    for batch in tqdm(train_loader, total=len(train_loader)):\n",
    "        real_images = batch[0].to(device)\n",
    "        # Match the real batch size so both sides of the FID see the same number of images.\n",
    "        noise = torch.randn(real_images.size(0), latent_dim, device=device)\n",
    "        generated_images = EVAEmodel.decoder(netG(noise))\n",
    "\n",
    "        # NOTE(review): torcheval's FID expects float (N, 3, H, W) images, presumably in [0, 1];\n",
    "        # confirm `transform` and the decoder output both use that range.\n",
    "        fidEVAE.update(real_images, is_real=True)\n",
    "        fidEVAE.update(generated_images, is_real=False)\n",
    "\n",
    "lossEVAE = fidEVAE.compute()\n",
    "\n",
    "print(f\"FIDEVAE: {float(lossEVAE)}\")\n"
   ]
  },
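  {
   "cell_type": "markdown",
   "id": "f97eaddc",
   "metadata": {},
   "source": [
    "# Unconditional image generation\n",
    "\n",
    "Decode latent codes drawn uniformly from $[-10, 10]^{64}$ through the EVAE's fully connected layer and decoder; no input images are used."
   ]
  },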
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "8d80202a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "# Decode latent codes drawn uniformly from [-LATENT_RANGE, LATENT_RANGE]; no input images are needed,\n",
    "# so this cell no longer depends on test_loader (whose creation is commented out above).\n",
    "# Assumes EVAEmodel, device and save_reconstructed_imagesEVAE come from earlier cells.\n",
    "EVAEmodel.eval()\n",
    "\n",
    "N_GENERATED = 64     # number of images to generate\n",
    "LATENT_DIM = 64      # input size of EVAEmodel.fully_connected_layer\n",
    "LATENT_RANGE = 10.0  # z ~ Uniform(-LATENT_RANGE, LATENT_RANGE)\n",
    "\n",
    "with torch.no_grad():\n",
    "    U_p = torch.rand(N_GENERATED, LATENT_DIM, device=device)\n",
    "    z = LATENT_RANGE * (2 * U_p - 1)  # rescale [0, 1) to [-LATENT_RANGE, LATENT_RANGE)\n",
    "    z = EVAEmodel.fully_connected_layer(z)\n",
    "    # NOTE(review): the sampling cell feeds netG's (64, 64) output straight into EVAEmodel.decoder;\n",
    "    # confirm this fully_connected_layer + (N, 256, 1, 1) path still matches the current model.\n",
    "    z = z.view(-1, 256, 1, 1)\n",
    "    generate_EVAE = EVAEmodel.decoder(z)\n",
    "    # Writes EVAE0.jpg, overwriting the file saved by the sampling cell.\n",
    "    save_reconstructed_imagesEVAE(generate_EVAE, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d848ec-b89e-4217-b316-be70176fafbe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
