{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wasserstein GAN (WGAN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import time\n",
    "import glob\n",
    "import imageio\n",
    "from IPython import display\n",
    "import cv2\n",
    "import pathlib\n",
    "import zipfile\n",
    "import torch\n",
    "import sys\n",
    "import pandas as pd \n",
    "\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset\n",
    "from torchvision.utils import make_grid\n",
    "import torch.optim as optim\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "from skimage import io, transform\n",
    "\n",
    "from torchsummary import summary\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a range of x values\n",
    "x = torch.linspace(-10, 100, 1000)\n",
    "\n",
    "# Compute the softplus values for the range of x\n",
    "softplus_x = torch.nn.functional.elu(x) + 1\n",
    "\n",
    "# Plot the softplus function\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(x, softplus_x, label='Softplus(x)', color='blue')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('Softplus(x)')\n",
    "plt.title('Softplus Function')\n",
    "plt.grid(True)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Device Mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decide which device we want to run on\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "\n",
    "    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):\n",
    "        super(Generator, self).__init__()\n",
    "        \n",
    "        self.z_dim = z_dim\n",
    "        \n",
    "        self.gen = nn.Sequential(\n",
    "            \n",
    "            self.get_generator_block(z_dim, \n",
    "                                     hidden_dim * 4,\n",
    "                                     kernel_size=3, \n",
    "                                     stride=2),\n",
    "            \n",
    "            self.get_generator_block(hidden_dim * 4, \n",
    "                                     hidden_dim * 2,\n",
    "                                     kernel_size=4,\n",
    "                                     stride = 1),\n",
    "            \n",
    "            self.get_generator_block(hidden_dim * 2,\n",
    "                                     hidden_dim ,\n",
    "                                     kernel_size=3,\n",
    "                                     stride = 2,\n",
    "                                    ),\n",
    "\n",
    "            self.get_generator_final_block(hidden_dim,\n",
    "                                           im_chan,\n",
    "                                           kernel_size=4,\n",
    "                                           stride=2)\n",
    "            \n",
    "\n",
    "        )\n",
    "        \n",
    "        \n",
    "    def get_generator_block(self, input_channel, output_channel, kernel_size, stride = 1, padding = 0):\n",
    "        return nn.Sequential(\n",
    "                nn.ConvTranspose2d(input_channel, output_channel, kernel_size, stride, padding),\n",
    "                nn.BatchNorm2d(output_channel),\n",
    "                nn.ReLU(inplace=True),\n",
    "        )\n",
    "    \n",
    "    \n",
    "    def get_generator_final_block(self, input_channel, output_channel, kernel_size, stride = 1, padding = 0):\n",
    "        return  nn.Sequential(\n",
    "                nn.ConvTranspose2d(input_channel, output_channel, kernel_size, stride, padding),\n",
    "                nn.Tanh()\n",
    "            )\n",
    "    \n",
    "    \n",
    "    def forward(self, noise):\n",
    "        x = noise.view(len(noise), self.z_dim, 1, 1)\n",
    "        return self.gen(x)\n",
    "    \n",
    "    \n",
    "    \n",
    "summary(Generator(100).to(device), (100,))\n",
    "print(Generator(100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Critic / Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Critic(nn.Module):\n",
    "\n",
    "    def __init__(self, im_chan=1, hidden_dim=16):\n",
    "        super(Critic, self).__init__()\n",
    "        self.disc = nn.Sequential(\n",
    "            self.get_critic_block(im_chan,\n",
    "                                         hidden_dim * 4,\n",
    "                                         kernel_size=4,\n",
    "                                         stride=2),\n",
    "            \n",
    "            self.get_critic_block(hidden_dim * 4,\n",
    "                                         hidden_dim * 8,\n",
    "                                         kernel_size=4,\n",
    "                                         stride=2,),\n",
    "            \n",
    "            self.get_critic_final_block(hidden_dim * 8,\n",
    "                                               1,\n",
    "                                               kernel_size=4,\n",
    "                                               stride=2,),\n",
    "\n",
    "        )\n",
    "\n",
    "        \n",
    "    def get_critic_block(self, input_channel, output_channel, kernel_size, stride = 1, padding = 0):\n",
    "        return nn.Sequential(\n",
    "                nn.Conv2d(input_channel, output_channel, kernel_size, stride, padding),\n",
    "                nn.BatchNorm2d(output_channel),\n",
    "                nn.LeakyReLU(0.2, inplace=True)\n",
    "        )\n",
    "    \n",
    "    \n",
    "    def get_critic_final_block(self, input_channel, output_channel, kernel_size, stride = 1, padding = 0):\n",
    "        return  nn.Sequential(\n",
    "                nn.Conv2d(input_channel, output_channel, kernel_size, stride, padding),\n",
    "            )\n",
    "    \n",
    "    def forward(self, image):\n",
    "        return self.disc(image)\n",
    "    \n",
    "summary(Critic().to(device) , (1,28,28))\n",
    "print(Critic())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Noise Creator Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_noise(n_samples, z_dim, device='cpu'):\n",
    "    return torch.randn(n_samples,z_dim,device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST Dataset Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_dim = 100\n",
    "batch_size = 128\n",
    "\n",
    "fixed_noise = get_noise(batch_size, z_dim, device=device)\n",
    "\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "dataloader = DataLoader(\n",
    "    MNIST('.', download=True, transform=train_transform),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loaded Data Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "dataiter = iter(dataloader)\n",
    "images,labels = dataiter._next_data()\n",
    "print ('Time is {} sec'.format(time.time()-start))\n",
    "\n",
    "plt.figure(figsize=(8,8))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Training Images\")\n",
    "plt.imshow(np.transpose(make_grid(images.to(device), padding=2, normalize=True).cpu(),(1,2,0)))\n",
    "\n",
    "print('Shape of loading one batch:', images.shape)\n",
    "print('Total no. of batches present in trainloader:', len(dataloader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.0002\n",
    "beta_1 = 0.5 \n",
    "beta_2 = 0.999\n",
    "\n",
    "def weights_init(m):\n",
    "    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):\n",
    "        torch.nn.init.normal_(m.weight, 0.0, 0.02)\n",
    "    if isinstance(m, nn.BatchNorm2d):\n",
    "        torch.nn.init.normal_(m.weight, 0.0, 0.02)\n",
    "        torch.nn.init.constant_(m.bias, 0)\n",
    "\n",
    "        \n",
    "gen = Generator(z_dim).to(device)\n",
    "gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))\n",
    "\n",
    "crit  = Critic().to(device) \n",
    "crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))\n",
    "\n",
    "gen = gen.apply(weights_init)\n",
    "crit = crit.apply(weights_init)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test = get_noise(batch_size, z_dim, device=device)\n",
    "# test_gan = gen(test)\n",
    "# grid_img = np.transpose(make_grid(test_gan[0].to(device), padding=2, normalize=True).cpu().detach(),(1,2,0))\n",
    "\n",
    "# plt.figure(figsize=(8,8))\n",
    "# plt.axis(\"off\")\n",
    "# plt.title(\"Training Images\")\n",
    "# plt.imshow(grid_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient Penalty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_penalty(gradient):\n",
    "    gradient = gradient.view(len(gradient), -1)\n",
    "\n",
    "    gradient_norm = gradient.norm(2, dim=1)\n",
    "    \n",
    "    penalty = torch.mean((gradient_norm - 1)**2)\n",
    "    return penalty"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gen_loss(crit_fake_pred):\n",
    "    gen_loss = -1. * torch.mean(crit_fake_pred)\n",
    "    return gen_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):\n",
    "    crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp\n",
    "    return crit_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Training Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gradient(crit, real, fake, epsilon):\n",
    "\n",
    "    mixed_images = real * epsilon + fake * (1 - epsilon)\n",
    "\n",
    "    mixed_scores = crit(mixed_images)\n",
    "    \n",
    "    gradient = torch.autograd.grad(\n",
    "        inputs=mixed_images,\n",
    "        outputs=mixed_scores,\n",
    "        grad_outputs=torch.ones_like(mixed_scores), \n",
    "        create_graph=True,\n",
    "        retain_graph=True,\n",
    "        \n",
    "    )[0]\n",
    "    return gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), show_fig=False, epoch=0):\n",
    "    image_unflat = image_tensor.detach().cpu().view(-1, *size)\n",
    "    image_grid = make_grid(image_unflat[:num_images], nrow=5)\n",
    "    plt.axis('off')\n",
    "    plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n",
    "    if show_fig:\n",
    "        plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n",
    "        \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_gan(crit, gen, dataloader):\n",
    "    n_epochs = 5\n",
    "    cur_step = 0\n",
    "    total_steps = 0\n",
    "    start_time = time.time()\n",
    "    cur_step = 0\n",
    "\n",
    "    generator_losses = []\n",
    "    critic_losses = []\n",
    "\n",
    "    C_mean_losses = []\n",
    "    G_mean_losses = []\n",
    "\n",
    "    c_lambda = 10\n",
    "    crit_repeats = 5\n",
    "    display_step = 50\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        cur_step = 0\n",
    "        start = time.time()\n",
    "        for real, _ in dataloader:\n",
    "            cur_batch_size = len(real)\n",
    "            real = real.to(device)\n",
    "\n",
    "            mean_iteration_critic_loss = 0\n",
    "            for _ in range(crit_repeats):\n",
    "                ### Update critic ###\n",
    "                crit_opt.zero_grad()\n",
    "                fake_noise = get_noise(cur_batch_size, z_dim, device=device)\n",
    "                fake = gen(fake_noise)\n",
    "                crit_fake_pred = crit(fake.detach())\n",
    "                crit_real_pred = crit(real)\n",
    "\n",
    "                epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)\n",
    "                gradient = get_gradient(crit, real, fake.detach(), epsilon)\n",
    "                gp = gradient_penalty(gradient)\n",
    "                crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)\n",
    "\n",
    "                # Keep track of the average critic loss in this batch\n",
    "                mean_iteration_critic_loss += crit_loss.item() / crit_repeats\n",
    "                # Update gradients\n",
    "                crit_loss.backward(retain_graph=True)\n",
    "                # Update optimizer\n",
    "                crit_opt.step()\n",
    "            critic_losses += [mean_iteration_critic_loss]\n",
    "\n",
    "            ### Update generator ###\n",
    "            gen_opt.zero_grad()\n",
    "            fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)\n",
    "            fake_2 = gen(fake_noise_2)\n",
    "            crit_fake_pred = crit(fake_2)\n",
    "\n",
    "            gen_loss = get_gen_loss(crit_fake_pred)\n",
    "            gen_loss.backward()\n",
    "\n",
    "            # Update the weights\n",
    "            gen_opt.step()\n",
    "\n",
    "            # Keep track of the average generator loss\n",
    "            generator_losses += [gen_loss.item()]\n",
    "\n",
    "            cur_step += 1\n",
    "            total_steps += 1\n",
    "\n",
    "            print_val = f\"Epoch: {epoch}/{n_epochs} Steps:{cur_step}/{len(dataloader)}\\t\"\n",
    "            print_val += f\"Epoch_Run_Time: {(time.time()-start):.6f}\\t\"\n",
    "            print_val += f\"Loss_C : {mean_iteration_critic_loss:.6f}\\t\"\n",
    "            print_val += f\"Loss_G : {gen_loss:.6f}\\t\"  \n",
    "            print(print_val, end='\\r',flush = True)\n",
    "\n",
    "            ### Visualization code ###\n",
    "    #         if cur_step % display_step == 0 and cur_step > 0:\n",
    "    #             print()\n",
    "    #             gen_mean = sum(generator_losses[-display_step:]) / display_step\n",
    "    #             crit_mean = sum(critic_losses[-display_step:]) / display_step\n",
    "    #             print(f\"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}\")\n",
    "    #             show_tensor_images(fake)\n",
    "    #             show_tensor_images(real)\n",
    "    #             step_bins = 20\n",
    "    #             num_examples = (len(generator_losses) // step_bins) * step_bins\n",
    "    #             plt.plot(\n",
    "    #                 range(num_examples // step_bins), \n",
    "    #                 torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),\n",
    "    #                 label=\"Generator Loss\"\n",
    "    #             )\n",
    "    #             plt.plot(\n",
    "    #                 range(num_examples // step_bins), \n",
    "    #                 torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),\n",
    "    #                 label=\"Critic Loss\"\n",
    "    #             )\n",
    "    #             plt.legend()\n",
    "    #             plt.show()\n",
    "\n",
    "        print()\n",
    "        gen_mean = sum(generator_losses[-cur_step:]) / cur_step\n",
    "        crit_mean = sum(critic_losses[-cur_step:]) / cur_step\n",
    "\n",
    "        C_mean_losses.append(crit_mean)\n",
    "        G_mean_losses.append(gen_mean)\n",
    "\n",
    "        print_val = f\"Epoch: {epoch}/{n_epochs} Total Steps:{total_steps}\\t\"\n",
    "        print_val += f\"Total_Time : {(time.time() - start_time):.6f}\\t\"\n",
    "        print_val += f\"Loss_C : {mean_iteration_critic_loss:.6f}\\t\"\n",
    "        print_val += f\"Loss_G : {gen_loss:.6f}\\t\"\n",
    "        print_val += f\"Loss_C_Mean : {crit_mean:.6f}\\t\"\n",
    "        print_val += f\"Loss_G_Mean : {gen_mean:.6f}\\t\"\n",
    "        print(print_val)\n",
    "\n",
    "        fake_noise = fixed_noise\n",
    "        fake = gen(fake_noise)\n",
    "\n",
    "        show_tensor_images(fake, show_fig=True,epoch=epoch)\n",
    "\n",
    "        cur_step = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 1\n",
    "cur_step = 0\n",
    "total_steps = 0\n",
    "start_time = time.time()\n",
    "cur_step = 0\n",
    "\n",
    "generator_losses = []\n",
    "critic_losses = []\n",
    "\n",
    "C_mean_losses = []\n",
    "G_mean_losses = []\n",
    "\n",
    "c_lambda = 10\n",
    "crit_repeats = 5\n",
    "display_step = 50\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    cur_step = 0\n",
    "    start = time.time()\n",
    "    for real, _ in dataloader:\n",
    "        cur_batch_size = len(real)\n",
    "        real = real.to(device)\n",
    "\n",
    "        mean_iteration_critic_loss = 0\n",
    "        for _ in range(crit_repeats):\n",
    "            ### Update critic ###\n",
    "            crit_opt.zero_grad()\n",
    "            fake_noise = get_noise(cur_batch_size, z_dim, device=device)\n",
    "            fake = gen(fake_noise)\n",
    "            crit_fake_pred = crit(fake.detach())\n",
    "            crit_real_pred = crit(real)\n",
    "\n",
    "            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)\n",
    "            gradient = get_gradient(crit, real, fake.detach(), epsilon)\n",
    "            gp = gradient_penalty(gradient)\n",
    "            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)\n",
    "\n",
    "            # Keep track of the average critic loss in this batch\n",
    "            mean_iteration_critic_loss += crit_loss.item() / crit_repeats\n",
    "            # Update gradients\n",
    "            crit_loss.backward(retain_graph=True)\n",
    "            # Update optimizer\n",
    "            crit_opt.step()\n",
    "        critic_losses += [mean_iteration_critic_loss]\n",
    "\n",
    "        ### Update generator ###\n",
    "        gen_opt.zero_grad()\n",
    "        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)\n",
    "        fake_2 = gen(fake_noise_2)\n",
    "        crit_fake_pred = crit(fake_2)\n",
    "        \n",
    "        gen_loss = get_gen_loss(crit_fake_pred)\n",
    "        gen_loss.backward()\n",
    "\n",
    "        # Update the weights\n",
    "        gen_opt.step()\n",
    "\n",
    "        # Keep track of the average generator loss\n",
    "        generator_losses += [gen_loss.item()]\n",
    "        \n",
    "        cur_step += 1\n",
    "        total_steps += 1\n",
    "        \n",
    "        print_val = f\"Epoch: {epoch}/{n_epochs} Steps:{cur_step}/{len(dataloader)}\\t\"\n",
    "        print_val += f\"Epoch_Run_Time: {(time.time()-start):.6f}\\t\"\n",
    "        print_val += f\"Loss_C : {mean_iteration_critic_loss:.6f}\\t\"\n",
    "        print_val += f\"Loss_G : {gen_loss:.6f}\\t\"  \n",
    "        print(print_val, end='\\r',flush = True)\n",
    "\n",
    "        ### Visualization code ###\n",
    "#         if cur_step % display_step == 0 and cur_step > 0:\n",
    "#             print()\n",
    "#             gen_mean = sum(generator_losses[-display_step:]) / display_step\n",
    "#             crit_mean = sum(critic_losses[-display_step:]) / display_step\n",
    "#             print(f\"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}\")\n",
    "#             show_tensor_images(fake)\n",
    "#             show_tensor_images(real)\n",
    "#             step_bins = 20\n",
    "#             num_examples = (len(generator_losses) // step_bins) * step_bins\n",
    "#             plt.plot(\n",
    "#                 range(num_examples // step_bins), \n",
    "#                 torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),\n",
    "#                 label=\"Generator Loss\"\n",
    "#             )\n",
    "#             plt.plot(\n",
    "#                 range(num_examples // step_bins), \n",
    "#                 torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),\n",
    "#                 label=\"Critic Loss\"\n",
    "#             )\n",
    "#             plt.legend()\n",
    "#             plt.show()\n",
    "\n",
    "    print()\n",
    "    gen_mean = sum(generator_losses[-cur_step:]) / cur_step\n",
    "    crit_mean = sum(critic_losses[-cur_step:]) / cur_step\n",
    "    \n",
    "    C_mean_losses.append(crit_mean)\n",
    "    G_mean_losses.append(gen_mean)\n",
    "    \n",
    "    print_val = f\"Epoch: {epoch}/{n_epochs} Total Steps:{total_steps}\\t\"\n",
    "    print_val += f\"Total_Time : {(time.time() - start_time):.6f}\\t\"\n",
    "    print_val += f\"Loss_C : {mean_iteration_critic_loss:.6f}\\t\"\n",
    "    print_val += f\"Loss_G : {gen_loss:.6f}\\t\"\n",
    "    print_val += f\"Loss_C_Mean : {crit_mean:.6f}\\t\"\n",
    "    print_val += f\"Loss_G_Mean : {gen_mean:.6f}\\t\"\n",
    "    print(print_val)\n",
    "    \n",
    "    fake_noise = fixed_noise\n",
    "    fake = gen(fake_noise)\n",
    "    \n",
    "    show_tensor_images(fake, show_fig=True,epoch=epoch)\n",
    "    \n",
    "    cur_step = 0\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# After Tranning Loss Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.title(\"Generator and Discriminator Loss During Training\")\n",
    "plt.plot(generator_losses,label=\"G-Loss\")\n",
    "plt.plot(critic_losses,label=\"C-Loss\")\n",
    "plt.xlabel(\"iterations\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,5))\n",
    "plt.title(\"Generator and Discriminator Loss During Training\")\n",
    "plt.plot(G_mean_losses,label=\"G-Loss\")\n",
    "plt.plot(C_mean_losses,label=\"C-Loss\")\n",
    "plt.xlabel(\"iterations\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Animated GIF Create & Show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anim_file = 'WGAN-GAN.gif'\n",
    "\n",
    "with imageio.get_writer(anim_file, mode='I') as writer:\n",
    "  filenames = glob.glob('image*.png')\n",
    "  filenames = sorted(filenames)\n",
    "  for filename in filenames:\n",
    "    image = imageio.imread(filename)\n",
    "    writer.append_data(image)\n",
    "  image = imageio.imread(filename)\n",
    "  writer.append_data(image)\n",
    "\n",
    "\n",
    "!pip install -q git+https://github.com/tensorflow/docs\n",
    "import tensorflow_docs.vis.embed as embed\n",
    "embed.embed_file(anim_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing WGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_new_gen_images(tensor_img, num_img=25):\n",
    "    tensor_img = (tensor_img + 1) / 2\n",
    "    unflat_img = tensor_img.detach().cpu()\n",
    "    img_grid = make_grid(unflat_img[:num_img], nrow=5)\n",
    "    plt.imshow(img_grid.permute(1, 2, 0).squeeze(),cmap='gray')\n",
    "    plt.show()\n",
    "\n",
    "num_image = 25\n",
    "noise = get_noise(num_image, z_dim, device=device)\n",
    "with torch.no_grad():\n",
    "    fake_img = gen(noise)\n",
    "\n",
    "show_new_gen_images(fake_img.reshape(num_image,1,28,28))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Resources\n",
    "\n",
    "[ProteinGAN](https://colab.research.google.com/github/https-deeplearning-ai/GANs-Public/blob/master/ProteinGAN.ipynb)\n",
    "\n",
    "[GAN to WGAN ](https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html)\n",
    "\n",
    "[Improved Training of Wasserstein GANs (Gulrajani et al., 2017)](https://arxiv.org/abs/1704.00028)\n",
    "\n",
    "[Wasserstein GAN (Arjovsky, Chintala, and Bottou, 2017))](https://arxiv.org/abs/1701.07875)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bhsgan",
   "language": "python",
   "name": "bhsgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
