{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# line_profiler: enables the %lprun magic\n",
    "%load_ext line_profiler\n",
    "# autoreload: pick up edits to the local modules (bhsgan, trainer, utils, ...) without restarting\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from time import  time\n",
    "\n",
    "from torchvision.utils import make_grid\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import random\n",
    "\n",
    "from bhsgan import DiscriminatorBhsMnist, GeneratorBhsMnist\n",
    "from ipmbhsgan import DiscriminatorIpmMnist, GeneratorIpmMnist\n",
    "from trainer import *\n",
    "from utils import *\n",
    "#get_device, get_noise, init_weights, plot_tensor_images, plot_losses, Positive, RevKlActivation\n",
    "from wgan import DiscriminatorWassersteinMnist, GeneratorWassersteinMnist\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from fid import InceptionV3, calculate_frechet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST training split as tensors, served in shuffled batches\n",
    "batch_size = 128\n",
    "device = get_device()\n",
    "\n",
    "train_transform = transforms.Compose([transforms.ToTensor()])\n",
    "mnist_train = MNIST('.', download=True, transform=train_transform)\n",
    "\n",
    "dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time fetching a single batch, then show it as an image grid.\n",
    "start = time()\n",
    "# public iterator protocol; _next_data() is a private DataLoader-iterator method\n",
    "images, labels = next(iter(dataloader))\n",
    "print('Time is {} sec'.format(time() - start))\n",
    "\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Training Images\")\n",
    "# the batch is already on the CPU, so build the grid there (no device round-trip)\n",
    "plt.imshow(make_grid(images, padding=2, normalize=True).permute(1, 2, 0).numpy())\n",
    "plt.show()\n",
    "\n",
    "print('Shape of loading one batch:', images.shape)\n",
    "print('Total no. of batches present in trainloader:', len(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noise / batch sizes; the *_bhs knobs are separate but currently equal to latent_dim and batch_size.\n",
    "latent_dim, latent_dim_bhs = 100, 100\n",
    "batch_size_bhs = 128\n",
    "\n",
    "# fixed preview noise (25 samples, assuming get_noise(n, dim, device))\n",
    "test_noise = get_noise(25, latent_dim, device)\n",
    "test_noise_bhs = get_noise(25, latent_dim_bhs, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First I train a Wasserstein GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed torch and Python's random module for this run\n",
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "training_params = TrainingParams(\n",
    "    batch_size=batch_size,\n",
    "    num_epochs=50,\n",
    "    lr_dis=0.0002, lr_gen=0.0002,\n",
    "    num_dis_updates=4, num_gen_updates=3,\n",
    "    beta_1=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build generator and critic (critic input size 28*28), each initialised with init_weights\n",
    "generator_wasserstein = GeneratorWassersteinMnist(latent_dim)\n",
    "generator_wasserstein.apply(init_weights)\n",
    "discriminator_wasserstein = DiscriminatorWassersteinMnist(28 * 28)\n",
    "discriminator_wasserstein.apply(init_weights)\n",
    "\n",
    "trainer_wgan = Trainer(training_params, generator_wasserstein, discriminator_wasserstein, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# training loop with the Wasserstein critic / generator losses\n",
    "trained_wgan = trainer_wgan.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_wasserstein,\n",
    "    get_gen_loss_wasserstein,\n",
    "    True,  # NOTE(review): positional flag defined by Trainer.train_gan; passing it by keyword would be clearer\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show generated images from a fixed seed: 25 noise vectors drawn, num_images=24 requested.\n",
    "torch.manual_seed(12)\n",
    "test_noise = get_noise(25, latent_dim, device)\n",
    "\n",
    "generated_images_wasserstein = trained_wgan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_wasserstein, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss histories stored on the trained WGAN\n",
    "plot_losses(\n",
    "    trained_wgan.generator_losses,\n",
    "    trained_wgan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the working directory: the same State_Dicts/ paths the FID section loads from,\n",
    "# instead of an absolute path that only exists on one machine.\n",
    "torch.save(generator_wasserstein.state_dict(), \"State_Dicts/WS_Gen.pt\")\n",
    "torch.save(discriminator_wasserstein.state_dict(), \"State_Dicts/WS_Disc.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Universal f-GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "# universal f-GAN; unlike most runs this also sets beta_2, weight_decay and lr_annealing explicitly\n",
    "training_params = TrainingParams(\n",
    "    batch_size=batch_size,\n",
    "    num_epochs=5,\n",
    "    lr_dis=0.0002, lr_gen=0.0002,\n",
    "    num_dis_updates=5, num_gen_updates=1,\n",
    "    beta_1=0.5, beta_2=0.999,\n",
    "    weight_decay=0,\n",
    "    lr_annealing=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BHS-style networks; the discriminator is built with UniversalActivation\n",
    "generator_uf = GeneratorBhsMnist(latent_dim)\n",
    "discriminator_uf = DiscriminatorBhsMnist(UniversalActivation, 28 * 28)\n",
    "\n",
    "trainer_uf = Trainer(training_params, generator_uf, discriminator_uf, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the universal f-GAN\n",
    "trained_ufgan = trainer_uf.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_uf,\n",
    "    get_gen_loss_uf,\n",
    "    True,  # NOTE(review): positional flag defined by Trainer.train_gan\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(15)\n",
    "test_noise_uf = get_noise(24, latent_dim, device)\n",
    "\n",
    "generated_images_uf = trained_ufgan.generator(test_noise_uf)\n",
    "plot_tensor_images(generated_images_uf, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss histories stored on the trained universal f-GAN\n",
    "plot_losses(\n",
    "    trained_ufgan.generator_losses,\n",
    "    trained_ufgan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BHS GAN\n",
    "$f^*(x) = 2(-1+\\sqrt{1+x})\\exp(\\sqrt{1+x})$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "training_params_bhs = TrainingParams(lr_dis=0.0002,\n",
    "                                     lr_gen=0.0002,\n",
    "                                     num_epochs=50,\n",
    "                                     num_dis_updates=4,\n",
    "                                     num_gen_updates=3,\n",
    "                                     beta_1=0.5,\n",
    "                                     batch_size=batch_size_bhs)\n",
    "# BHS-specific names: later cells use trainer_bhs / generator_bhs / discriminator_bhs,\n",
    "# and reusing the *_uf names would overwrite the universal f-GAN models.\n",
    "generator_bhs = GeneratorBhsMnist(latent_dim_bhs)\n",
    "discriminator_bhs = DiscriminatorBhsMnist(Positive, 28*28)\n",
    "trainer_bhs = Trainer(training_params_bhs, generator_bhs, discriminator_bhs, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train the BHS GAN (flag is False here, True for the WGAN / universal f-GAN runs)\n",
    "trained_bhsgan = trainer_bhs.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_bhs,\n",
    "    get_gen_loss_bhs,\n",
    "    False,  # NOTE(review): positional flag defined by Trainer.train_gan\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(15)\n",
    "test_noise_bhs = get_noise(24, latent_dim_bhs, device)\n",
    "\n",
    "generated_images_bhs = trained_bhsgan.generator(test_noise_bhs)\n",
    "plot_tensor_images(generated_images_bhs, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# loss histories stored on the trained BHS GAN\n",
    "plot_losses(\n",
    "    trained_bhsgan.generator_losses,\n",
    "    trained_bhsgan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the working directory: the same State_Dicts/ paths the FID section loads from,\n",
    "# instead of an absolute path that only exists on one machine.\n",
    "torch.save(generator_bhs.state_dict(), \"State_Dicts/BHS_Gen.pt\")\n",
    "torch.save(discriminator_bhs.state_dict(), \"State_Dicts/BHS_Disc.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## KL GAN\n",
    "$f^*(x) = \\exp(x-1)$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "training_params_KL = TrainingParams(\n",
    "    batch_size=batch_size,\n",
    "    num_epochs=2,\n",
    "    lr_dis=0.0002, lr_gen=0.0002,\n",
    "    num_dis_updates=1, num_gen_updates=1,\n",
    "    beta_1=0.5, beta_2=0.999,\n",
    "    weight_decay=0,\n",
    "    lr_annealing=False,\n",
    ")\n",
    "\n",
    "generator_KL = GeneratorBhsMnist(latent_dim)\n",
    "# NOTE(review): same UniversalActivation as the universal f-GAN; confirm this is the intended KL output activation\n",
    "discriminator_KL = DiscriminatorBhsMnist(UniversalActivation, 28 * 28)\n",
    "trainer_kl = Trainer(training_params_KL, generator_KL, discriminator_KL, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train the KL GAN\n",
    "trained_klgan = trainer_kl.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_kl,\n",
    "    get_gen_loss_kl,\n",
    "    True,  # NOTE(review): positional flag defined by Trainer.train_gan\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# loss histories stored on the trained KL GAN\n",
    "plot_losses(\n",
    "    trained_klgan.generator_losses,\n",
    "    trained_klgan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(12)\n",
    "test_noise = get_noise(24, latent_dim, device)\n",
    "generated_images_kl = trained_klgan.generator(test_noise)\n",
    "# keep num_images equal to the number of generated samples above\n",
    "plot_tensor_images(generated_images_kl, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint the KL GAN (paths relative to the working directory)\n",
    "torch.save(discriminator_KL.state_dict(), \"State_Dicts/KL_Disc.pt\")\n",
    "torch.save(generator_KL.state_dict(), \"State_Dicts/KL_Gen.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rev KL GAN\n",
    "$f^*(x) = -1 - \\log(-x)$, defined for $x < 0$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "training_params_RKL = TrainingParams(lr_dis=0.0002,\n",
    "                                     lr_gen=0.0002,\n",
    "                                     num_epochs=50,\n",
    "                                     num_dis_updates=4,\n",
    "                                     num_gen_updates=3,\n",
    "                                     beta_1=0.5,\n",
    "                                     batch_size=batch_size_bhs)\n",
    "# latent_dim matches the noise this generator is later sampled with (get_noise(..., latent_dim, ...))\n",
    "generator_RKL = GeneratorBhsMnist(latent_dim)\n",
    "discriminator_RKL = DiscriminatorBhsMnist(RevKlActivation, 28 * 28)\n",
    "trainer_rkl = Trainer(training_params_RKL, generator_RKL, discriminator_RKL, device=device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train the reverse-KL GAN\n",
    "trained_rklgan = trainer_rkl.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_rkl,\n",
    "    get_gen_loss_rkl,\n",
    "    False,  # NOTE(review): positional flag defined by Trainer.train_gan\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# loss histories stored on the trained reverse-KL GAN\n",
    "plot_losses(\n",
    "    trained_rklgan.generator_losses,\n",
    "    trained_rklgan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(12)\n",
    "test_noise = get_noise(24, latent_dim, device)\n",
    "generated_images_rkl = trained_rklgan.generator(test_noise)\n",
    "# keep num_images equal to the number of generated samples above\n",
    "plot_tensor_images(generated_images_rkl, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint the reverse-KL GAN (paths relative to the working directory)\n",
    "torch.save(discriminator_RKL.state_dict(), \"State_Dicts/RKL_Disc.pt\")\n",
    "torch.save(generator_RKL.state_dict(), \"State_Dicts/RKL_Gen.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Pearson GAN\n",
    "$f^*(x) = \\frac{1}{4} x^2 + x$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "# NOTE(review): lr_gen is 10x lr_dis here, unlike the other runs; confirm this is intentional\n",
    "training_params_pgan = TrainingParams(lr_dis=0.0002,\n",
    "                                      lr_gen=0.002,\n",
    "                                      num_epochs=50,\n",
    "                                      num_dis_updates=4,\n",
    "                                      num_gen_updates=3,\n",
    "                                      beta_1=0.5,\n",
    "                                      batch_size=batch_size_bhs)\n",
    "generator_pgan = GeneratorBhsMnist(latent_dim)\n",
    "discriminator_pgan = DiscriminatorBhsMnist(nn.Identity,28 * 28)\n",
    "# a dedicated name, so the KL GAN's trainer_kl is not overwritten\n",
    "trainer_pgan = Trainer(training_params_pgan, generator_pgan, discriminator_pgan, device=device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trained_pgan = trainer_pgan.train_gan(dataloader, get_dis_loss_p, get_gen_loss_p, False, flatten_dim=28 * 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(12)\n",
    "test_noise = get_noise(24, latent_dim, device)\n",
    "generated_images_pgan = trained_pgan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_pgan, num_images=24)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# full loss curves, then a zoom on the first 25 recorded generator losses\n",
    "plot_losses(trained_pgan.generator_losses, trained_pgan.discriminator_losses)\n",
    "plt.plot(trained_pgan.generator_losses[:25])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint the Pearson GAN (paths relative to the working directory)\n",
    "torch.save(discriminator_pgan.state_dict(), \"State_Dicts/P_Disc.pt\")\n",
    "torch.save(generator_pgan.state_dict(), \"State_Dicts/P_Gen.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Standard GAN\n",
    "$f^*(x) = -\\log(1-\\exp(x))$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): seed 95 here vs 96 in the other runs\n",
    "torch.manual_seed(95)\n",
    "random.seed(95)\n",
    "\n",
    "training_params_GAN = TrainingParams(\n",
    "    batch_size=batch_size_bhs,\n",
    "    num_epochs=50,\n",
    "    lr_dis=0.0002, lr_gen=0.0002,\n",
    "    num_dis_updates=4, num_gen_updates=3,\n",
    "    beta_1=0.5,\n",
    ")\n",
    "\n",
    "generator_GAN = GeneratorBhsMnist(latent_dim)\n",
    "discriminator_GAN = DiscriminatorBhsMnist(nn.Sigmoid, 28 * 28)\n",
    "trainer_gan = Trainer(training_params_GAN, generator_GAN, discriminator_GAN, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train the standard GAN\n",
    "trained_gan = trainer_gan.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_gan,\n",
    "    get_gen_loss_gan,\n",
    "    False,  # NOTE(review): positional flag defined by Trainer.train_gan\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(4)\n",
    "test_noise = get_noise(24, latent_dim, device)\n",
    "\n",
    "generated_images_gan = trained_gan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_gan, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# loss histories stored on the trained standard GAN\n",
    "plot_losses(\n",
    "    trained_gan.generator_losses,\n",
    "    trained_gan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint the standard GAN (paths relative to the working directory)\n",
    "torch.save(discriminator_GAN.state_dict(), \"State_Dicts/GAN_Disc.pt\")\n",
    "torch.save(generator_GAN.state_dict(), \"State_Dicts/GAN_Gen.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## IPM BHS GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "\n",
    "training_params_ipm = TrainingParams(\n",
    "    batch_size=batch_size,\n",
    "    num_epochs=50,\n",
    "    lr_dis=0.0002, lr_gen=0.0002,\n",
    "    num_dis_updates=4, num_gen_updates=3,\n",
    "    beta_1=0.5,\n",
    ")\n",
    "\n",
    "# IPM networks, initialised with init_weights like the WGAN models\n",
    "generator_ipm = GeneratorIpmMnist(latent_dim)\n",
    "generator_ipm.apply(init_weights)\n",
    "discriminator_ipm = DiscriminatorIpmMnist(28 * 28)\n",
    "discriminator_ipm.apply(init_weights)\n",
    "trainer_ipm = Trainer(training_params_ipm, generator_ipm, discriminator_ipm, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# training loop for the IPM BHS GAN\n",
    "trained_ipmgan = trainer_ipm.train_gan(\n",
    "    dataloader,\n",
    "    get_dis_loss_ipm,\n",
    "    get_gen_loss_ipm,\n",
    "    False,  # NOTE(review): positional flag defined by Trainer.train_gan\n",
    "    flatten_dim=28 * 28,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 24 samples from a fixed seed\n",
    "torch.manual_seed(14)\n",
    "test_noise = get_noise(24, latent_dim, device)\n",
    "\n",
    "generated_images_ipm = trained_ipmgan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_ipm, num_images=24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss histories stored on the trained IPM GAN\n",
    "plot_losses(\n",
    "    trained_ipmgan.generator_losses,\n",
    "    trained_ipmgan.discriminator_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the IPM models themselves under IPM-specific names, so the RKL checkpoints are not\n",
    "# overwritten; relative to the working directory like the other State_Dicts/ files.\n",
    "torch.save(generator_ipm.state_dict(), \"State_Dicts/IPM_Gen.pt\")\n",
    "torch.save(discriminator_ipm.state_dict(), \"State_Dicts/IPM_Disc.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate with FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# InceptionV3 feature extractor for FID, using the block that yields 2048-dim features\n",
    "block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]\n",
    "model = InceptionV3([block_idx], normalize_input=False)\n",
    "# follow the notebook's device rather than hard-coding CUDA\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample one batch of real data (public iterator protocol, not the private _next_data())\n",
    "real_images = next(iter(dataloader))[0]\n",
    "# generate as many noise vectors as there are real images, so both FID sets have equal size\n",
    "fid_noise = get_noise(len(real_images), latent_dim, device)\n",
    "fid_noise_bhs = get_noise(len(real_images), latent_dim_bhs, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload saved generator weights; map_location lets checkpoints written on another device load here\n",
    "generator_wasserstein.load_state_dict(torch.load(\"State_Dicts/WS_Gen.pt\", map_location=device))\n",
    "generator_bhs.load_state_dict(torch.load(\"State_Dicts/BHS_Gen.pt\", map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate images for FID; pure inference, so skip autograd graph construction\n",
    "with torch.no_grad():\n",
    "    generated_images_wasserstein = generator_wasserstein(fid_noise)\n",
    "    generated_images_bhs = generator_bhs(fid_noise_bhs)\n",
    "    generated_images_kl = trained_klgan.generator(fid_noise)\n",
    "    generated_images_rvkl = trained_rklgan.generator(fid_noise)\n",
    "    generated_images_gan = trained_gan.generator(fid_noise)\n",
    "    generated_images_pearson = trained_pgan.generator(fid_noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between the real batch and Wasserstein GAN samples\n",
    "fid_wasserstein = calculate_frechet(real_images, generated_images_wasserstein, model)\n",
    "fid_wasserstein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between the real batch and BHS GAN samples\n",
    "fid_bhs = calculate_frechet(real_images, generated_images_bhs, model)\n",
    "fid_bhs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between the real batch and KL GAN samples\n",
    "fid_kl = calculate_frechet(real_images, generated_images_kl, model)\n",
    "fid_kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between the real batch and reverse-KL GAN samples\n",
    "fid_rvkl = calculate_frechet(real_images, generated_images_rvkl, model)\n",
    "fid_rvkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between the real batch and standard GAN samples\n",
    "fid_gan = calculate_frechet(real_images, generated_images_gan, model)\n",
    "fid_gan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID between the real batch and Pearson GAN samples\n",
    "fid_pearson = calculate_frechet(real_images, generated_images_pearson, model)\n",
    "fid_pearson"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bhsgan",
   "language": "python",
   "name": "bhsgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
