{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba98f1e0-a5a9-4103-9ccf-a08857562199",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%load_ext line_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8582bc3b-aba1-4afb-b964-f6c2bd8fcb24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from time import  time\n",
    "\n",
    "from torchvision.utils import make_grid\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST, LSUN, CelebA\n",
    "from torchvision import datasets\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "from bhsgan import DiscriminatorBhsMnist, GeneratorBhsMnist, DiscriminatorBhsCifar, GeneratorBhsCifar\n",
    "from fgan import DiscriminatorfCifar, GeneratorfCifar\n",
    "from wgan import GeneratorWassersteinCifar, DiscriminatorWassersteinCifar\n",
    "from ipmbhsgan import DiscriminatorIpmMnist, GeneratorIpmMnist\n",
    "from trainer import (Trainer, TrainingParams, get_dis_loss_bhs,\n",
    "                     get_dis_loss_ipm, get_dis_loss_wasserstein,\n",
    "                     get_gen_loss_bhs, get_gen_loss_ipm, get_dis_loss_bhs_2,\n",
    "                     get_gen_loss_wasserstein,get_dis_loss_kl, get_gen_loss_kl, get_dis_loss_rkl, get_gen_loss_rkl,\n",
    "                     get_dis_loss_gan, get_gen_loss_gan, get_dis_loss_p, get_gen_loss_p)\n",
    "from utils import get_device, get_noise, init_weights, plot_tensor_images, plot_losses, BhsActivation, save_models_state_dict, load_model_state_dict, RevKlActivation, GanGanActivation\n",
    "from wgan import DiscriminatorWassersteinCifar, GeneratorWassersteinCifar\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from torchvision.datasets import CIFAR10\n",
    "from fid import InceptionV3\n",
    "from dataset import SubCIFAR10\n",
    "from fid import calculate_frechet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "448ed295-c197-4ef8-9dfe-9910900acb18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create dataloader\n",
    "batch_size = 64\n",
    "device = get_device()\n",
    "image_size = 64\n",
    "num_epochs = 50\n",
    "\n",
    "train_transform = transforms.Compose([\n",
    "        transforms.Resize(image_size),\n",
    "        #transforms.CenterCrop(image_size),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "dataset = SubCIFAR10(root='cifar10/', transform=train_transform, download=True, include_list=[])\n",
    "\n",
    "dataloader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers = 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da898a4f-3cdd-47b1-b562-b3514d3fcece",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f847a5-4ecf-494b-950d-71b59cfe44c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_dim = 100\n",
    "latent_dim_bhs = 100\n",
    "batch_size_bhs = 128\n",
    "test_noise = torch.reshape(get_noise(25, latent_dim, device), (25, latent_dim, 1, 1))\n",
    "test_noise_bhs = torch.reshape(get_noise(25, latent_dim_bhs, device), (25, latent_dim_bhs, 1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c21ab2c-e28a-47b6-a906-2742ff65d94b",
   "metadata": {},
   "source": [
    "### Wasserstein GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b436f40-a302-4855-9d27-b5e7429be321",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_params_wasserstein = TrainingParams(lr_dis=0.0002,\n",
    "                                 lr_gen=0.0002,\n",
    "                                 num_epochs=num_epochs,\n",
    "                                 num_dis_updates=2,\n",
    "                                 num_gen_updates=2,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size,\n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b11f523-4f0a-4e06-ad02-38a17fec2093",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_wgan = \"WGAN_CIFAR\"\n",
    "generator_wasserstein = GeneratorWassersteinCifar(z_dim=latent_dim).apply(init_weights)\n",
    "discriminator_wasserstein = DiscriminatorWassersteinCifar().apply(init_weights)\n",
    "trainer_wgan = Trainer(training_params_wasserstein, generator_wasserstein, discriminator_wasserstein, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7097946-8068-4b23-8305-df3492f5c1a6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop\n",
    "trained_wgan = trainer_wgan.train_gan(dataloader, get_dis_loss_wasserstein, get_gen_loss_wasserstein, True, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b152f22-71e6-4f1a-9c1e-bbc07fd12d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_wgan, save_path_wgan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f18433da-a86f-43f0-bbc0-d12d8f5d51ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_wasserstein = generator_wasserstein(test_noise)\n",
    "plot_tensor_images(generated_images_wasserstein, num_images=8, size=(1, 64, 64), unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9235c88-8128-48be-869c-cbc63f24018b",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_model_state_dict(generator_wasserstein, Path(save_path_wgan) / \"generator.pt\", map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e4a272a-2c8c-4861-a83b-7c15f22c837e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_wgan.generator_losses, trained_wgan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "410d37bc-569e-486d-9e03-d2c0f6729cf5",
   "metadata": {},
   "source": [
    "### BHS GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d693581-b91d-44af-8277-72df293c9088",
   "metadata": {},
   "outputs": [],
   "source": [
    "# old seed 97\n",
    "torch.manual_seed(121)\n",
    "random.seed(121)\n",
    "training_params_bhs = TrainingParams(lr_dis=0.00005,\n",
    "                                 lr_gen=0.00005,\n",
    "                                 num_epochs=50,\n",
    "                                 num_dis_updates=2,\n",
    "                                 num_gen_updates=2,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size, \n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a21490a-351c-4084-b425-715509c786f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_bhsgan = \"BHSGAN_CIFAR_5\"\n",
    "final_activation = BhsActivation\n",
    "generator_bhs = GeneratorBhsCifar(z_dim=latent_dim_bhs, image_size=image_size).apply(init_weights)\n",
    "discriminator_bhs = DiscriminatorBhsCifar(final_activation, image_size=image_size).apply(init_weights)\n",
    "trainer_bhsgan = Trainer(training_params_bhs, generator_bhs, discriminator_bhs, device=device, calculate_fid=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a4dc9f-262e-4ab5-82ad-66e7bef07fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer_bhsgan = Trainer(training_params_bhs, generator_bhs, discriminator_bhs, device=device, calculate_fid=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab4f5d0-0635-48e5-9efc-a71ddac95c8c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "##### training loop\n",
    "torch.autograd.set_detect_anomaly(False)\n",
    "trained_bhsgan = trainer_bhsgan.train_gan(dataloader, get_dis_loss_bhs_2, get_gen_loss_bhs, True, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "971b8a31-9d83-48e3-986d-e4aea200279c",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_bhsgan, save_path_bhsgan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "648492da-5d54-4efd-83c3-ebbdea5ee004",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_bhs = generator_bhs(test_noise_bhs)\n",
    "plot_tensor_images(generated_images_bhs, num_images=24, unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec1d607-4145-4239-a984-06a16c6518aa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_bhsgan.generator_losses, trained_bhsgan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6da9af0-0260-4aa2-b1e8-06156330817d",
   "metadata": {},
   "source": [
    "## KL GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e4a855-391e-4f5e-b4a3-fbcd0c6de799",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_params_kl = TrainingParams(lr_dis=0.0002,\n",
    "                                 lr_gen=0.0002,\n",
    "                                 num_epochs=num_epochs,\n",
    "                                 num_dis_updates=2,\n",
    "                                 num_gen_updates=2,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size,\n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe1eec8b-1d01-4fb8-99af-9b6392ca70ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_klgan = \"KLGAN_CIFAR_2\"\n",
    "final_activation = nn.Identity\n",
    "generator_kl = GeneratorfCifar(z_dim=latent_dim).apply(init_weights)\n",
    "discriminator_kl = DiscriminatorfCifar(final_activation).apply(init_weights)\n",
    "trainer_klgan = Trainer(training_params_kl, generator_kl, discriminator_kl, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a79cd78b-e0d0-45d9-80cb-ae2d6b792609",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop\n",
    "trained_klgan = trainer_klgan.train_gan(dataloader, get_dis_loss_kl, get_gen_loss_kl, True, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93a11721-9655-429a-ad1d-a873d7654afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_klgan, save_path_klgan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf7cea5-a193-4745-9c31-7e8754412346",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_kl = trained_klgan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_kl, num_images=8, unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "406a3046-37ff-4468-9712-826ab80274fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_klgan.generator_losses, trained_klgan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4b17e4f-ddef-43c1-90b4-e13d9f49a697",
   "metadata": {},
   "source": [
    "### RV KL GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ecc8e64-691c-4626-bc26-28a4addcc8d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "training_params_rkl = TrainingParams(lr_dis=0.00005,\n",
    "                                 lr_gen=0.00005,\n",
    "                                 num_epochs=num_epochs,\n",
    "                                 num_dis_updates=2,\n",
    "                                 num_gen_updates=2,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size,\n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18130719-5361-42d9-8cf6-781f08f272ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_rvklgan = \"RVKLGAN_CIFAR_2\"\n",
    "final_activation = RevKlActivation\n",
    "generator_rvkl = GeneratorfCifar(z_dim=latent_dim).apply(init_weights)\n",
    "discriminator_rvkl = DiscriminatorfCifar(final_activation).apply(init_weights)\n",
    "trainer_rvklgan = Trainer(training_params_rkl, generator_rvkl, discriminator_rvkl, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52bb4ef-48bf-4b65-8492-54a6193caf14",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop\n",
    "trained_rvklgan = trainer_rvklgan.train_gan(dataloader, get_dis_loss_rkl, get_gen_loss_rkl, False, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f407128e-5171-4c8a-9000-af23a879e2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_rvklgan, save_path_rvklgan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4154f109-32fc-4a73-80f1-db6f33cbce20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_rvkl = trained_rvklgan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_rvkl, num_images=8, unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9d19557-c5a3-4a09-80d2-99c50730d83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_rvklgan.generator_losses, trained_rvklgan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2246d09f-1337-41d2-ac56-cac0e9b04295",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Standard GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15e9f4dd-c20c-4c1f-8259-67b1a465dcce",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "training_params_gan = TrainingParams(lr_dis=0.0002,\n",
    "                                 lr_gen=0.0002,\n",
    "                                 num_epochs=num_epochs,\n",
    "                                 num_dis_updates=2,\n",
    "                                 num_gen_updates=2,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size,\n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a766b40-3a36-4e7d-a404-b8b1bcef45a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_gan = \"GAN_CIFAR_2\"\n",
    "final_activation = nn.Sigmoid\n",
    "generator_gan = GeneratorfCifar(z_dim=latent_dim).apply(init_weights)\n",
    "discriminator_gan = DiscriminatorfCifar(final_activation).apply(init_weights)\n",
    "trainer_gan = Trainer(training_params_gan, generator_gan, discriminator_gan, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5700b1f8-150b-4242-8a19-a98e4ca078b9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop\n",
    "trained_gan = trainer_gan.train_gan(dataloader, get_dis_loss_gan, get_gen_loss_gan, False, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb670001-b2a9-4ca5-aa34-4f0a4f72bef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_gan, save_path_gan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2921c2-80f0-47fa-8d77-79a1bb0103c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_gan = trained_gan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_gan, num_images=8, unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ec76b4-c65b-4cf5-b1fc-5df2706b2c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_gan.generator_losses, trained_gan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2c0243e-6cb3-45f5-8131-0210e7132db9",
   "metadata": {},
   "source": [
    "### Pearson GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a82f59-4ed5-474d-8135-29eab5a815a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "training_params_pearson = TrainingParams(lr_dis=0.0002,\n",
    "                                 lr_gen=0.0002,\n",
    "                                 num_epochs=num_epochs,\n",
    "                                 num_dis_updates=2,\n",
    "                                 num_gen_updates=2,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size,\n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6bf8ae3-c781-4e52-95f5-e5e6105a80f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_pearsongan = \"PEARSONGAN_CIFAR\"\n",
    "final_activation = nn.Identity\n",
    "generator_pearson = GeneratorfCifar(z_dim=latent_dim).apply(init_weights)\n",
    "discriminator_pearson = DiscriminatorfCifar(final_activation).apply(init_weights)\n",
    "trainer_pearsongan = Trainer(training_params_pearson, generator_pearson, discriminator_pearson, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3421e010-d62c-48f6-94d9-22c7887a2e0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop\n",
    "trained_pearsongan = trainer_pearsongan.train_gan(dataloader, get_dis_loss_p, get_gen_loss_p, True, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08ef7bd4-5a51-485d-b218-86724d2fad2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_pearsongan, save_path_pearsongan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c2ba6b5-6073-463d-b76b-157874836b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_pearson = trained_pearsongan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_pearson, num_images=8, unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4819dd32-128a-41e1-a94b-0962ce325b7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_pearsongan.generator_losses, trained_pearsongan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "818e7edb-59b7-4582-accd-3dc0172e74e1",
   "metadata": {},
   "source": [
    "## IPMBHSGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df8054ce-e0fe-4f23-b0fb-9d268327dabd",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(96)\n",
    "random.seed(96)\n",
    "training_params_ipm = TrainingParams(lr_dis=0.0002,\n",
    "                                 lr_gen=0.0002,\n",
    "                                 num_epochs=num_epochs,\n",
    "                                 num_dis_updates=5,\n",
    "                                 num_gen_updates=1,\n",
    "                                 beta_1=0.5,\n",
    "                                 beta_2=0.999,\n",
    "                                 weight_decay=0,\n",
    "                                 batch_size=batch_size,\n",
    "                                 lr_annealing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d525f3ad-5e84-4c3d-a27c-40bdf33eddc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path_ipmngan = \"IPMGAN_CIFAR\"\n",
    "generator_ipm = GeneratorWassersteinCifar(z_dim=latent_dim).apply(init_weights)\n",
    "discriminator_ipm = DiscriminatorWassersteinCifar().apply(init_weights)\n",
    "trainer_ipmgan = Trainer(training_params_ipm, generator_ipm, discriminator_ipm, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f229337-a0c6-4bd3-95c0-8b803f7517fd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop\n",
    "trained_ipmgan = trainer_ipmgan.train_gan(dataloader, get_dis_loss_ipm, get_gen_loss_ipm, False, is_color_picture=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5273463-b14b-4f8c-9e24-065fee6a770a",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_models_state_dict(trained_ipmgan, save_path_ipmgan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a743f055-9ccb-459a-85cb-1c2edd06a4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show generated images\n",
    "generated_images_ipm = trained_ipmgan.generator(test_noise)\n",
    "plot_tensor_images(generated_images_ipm, num_images=8, unflat=False, tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c499274b-5d8e-4ff3-b929-c864ea439f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plot_losses(trained_ipmgan.generator_losses, trained_ipmgan.discriminator_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65c20edc-c74b-4acf-9006-f86845dbd476",
   "metadata": {},
   "source": [
    "### Evaluate with FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40a9f2d6-0e9e-4dd6-ab05-d6fb6a58b5db",
   "metadata": {},
   "outputs": [],
   "source": [
    "block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]\n",
    "model = InceptionV3([block_idx])\n",
    "model=model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91a31c3b-6691-4f30-b394-22a5cdbfafcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_transform = transforms.Compose([\n",
    "        transforms.Resize(image_size),\n",
    "        #transforms.CenterCrop(image_size),\n",
    "        transforms.ToTensor()\n",
    "])\n",
    "dataset_fid = SubCIFAR10(root='cifar10/', transform=train_transform, download=True, include_list=[])\n",
    "\n",
    "dataloader_fid = DataLoader(\n",
    "    dataset_fid,\n",
    "    batch_size=200,\n",
    "    shuffle=True,\n",
    "    num_workers = 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24c95b41-d263-4225-9926-36d7445b0db3",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(97)\n",
    "random.seed(97)\n",
    "# sample onr batch of real data\n",
    "real_images = iter(dataloader_fid)._next_data()[0]\n",
    "# generate noise\n",
    "fid_noise = torch.reshape(get_noise(16, latent_dim, device), (16, latent_dim, 1, 1))\n",
    "fid_noise_bhs = torch.reshape(get_noise(100, latent_dim_bhs, device), (100, latent_dim_bhs, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d51692a7-5db6-4541-8fbd-879c78535d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get generators\n",
    "generator_wasserstein.load_state_dict(torch.load(\"WGAN_CIFAR/generator.pt\"))\n",
    "generator_bhs.load_state_dict(torch.load(\"BHSGAN_CIFAR_3/generator.pt\"))\n",
    "generator_kl.load_state_dict(torch.load(\"KLGAN_CIFAR_2/generator.pt\"))\n",
    "generator_rvkl.load_state_dict(torch.load(\"RVKLGAN_CIFAR_2/generator.pt\"))\n",
    "generator_gan.load_state_dict(torch.load(\"GAN_CIFAR_2/generator.pt\"))\n",
    "generator_pearson.load_state_dict(torch.load(\"PEARSONGAN_CIFAR/generator.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cda3b985-9c1b-445c-8c3e-9f1932960693",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate images\n",
    "# generate images\n",
    "generated_images_wasserstein = generator_wasserstein(fid_noise)\n",
    "generated_images_bhs = generator_bhs(fid_noise_bhs)\n",
    "generated_images_kl = generator_kl(fid_noise)\n",
    "generated_images_rvkl = generator_rvkl(fid_noise)\n",
    "generated_images_gan = generator_gan(fid_noise)\n",
    "generated_images_pearson = generator_pearson(fid_noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5041d277-c113-41fc-a102-9a0ced28c1c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_tensor_images(\n",
    "    image_tensor,\n",
    "    num_images=16,\n",
    "    size=(1, 28, 28),\n",
    "    save_fig=False,\n",
    "    epoch=0,\n",
    "    unflat=True,\n",
    "    tanh_activation=False,\n",
    "):\n",
    "    images_to_plot = image_tensor.detach().cpu()\n",
    "\n",
    "    if unflat:\n",
    "        images_to_plot = images_to_plot.view(-1, *size)\n",
    "\n",
    "    if tanh_activation:\n",
    "        #images_to_plot = images_to_plot * 0.5 + 0.5\n",
    "        pass\n",
    "\n",
    "    image_grid = make_grid(images_to_plot[:num_images], nrow=4, normalize=True, value_range=(-1,1))\n",
    "    plt.axis(\"off\")\n",
    "    #plt.tight_layout()\n",
    "    plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n",
    "    if save_fig:\n",
    "        plt.savefig(\"image_at_epoch_{:04d}.png\".format(epoch))\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdf5ed0f-c23a-4b49-85fa-e278be27c74a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_tensor_images(real_images, size=(3,64,64),  tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9f48103-81e9-4ed7-bbcf-a24c952f8a9a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#generated_images_wasserstein = trained_wgan.generator(fid_noise)\n",
    "plot_tensor_images(generated_images_wasserstein, size=(3,64,64),  tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "179bf882-a2c1-423b-909e-bcae8b276a93",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_tensor_images(generated_images_bhs, size=(3,64,64), tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dee2b67a-de0b-4c58-b463-979266cee4e3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_tensor_images(generated_images_kl, size=(3,64,64), tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb5a0a8-248c-4ac7-baef-5bba7d84aba2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_tensor_images(generated_images_rvkl, size=(3,64,64), tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59b2bd4e-9c51-4a07-ac0d-6b4566181058",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_tensor_images(generated_images_gan, size=(3,64,64), tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69f421d0-54e5-4bef-9ca2-a5ac0c943ddb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plot_tensor_images(generated_images_pearson, size=(3,64,64), tanh_activation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e7befb-78b3-4a68-934a-286b4965017d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_ipm = calculate_frechet(dataloader_fid , trained_ipmgan.generator, model, device, 200)\n",
    "fid_ipm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a738d8d3-2874-4821-befa-b53a840a1be9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_w = calculate_frechet(dataloader_fid , trained_wgan.generator, model, device, 200)\n",
    "fid_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "838934a3-9f56-4e2e-8753-28dc46b8e05e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_bhs = calculate_frechet(dataloader_fid , trained_bhsgan.generator, model, device, 200) #np.mean([calculate_frechet(real_images , generated_images_bhs, model) for real_images, _ in dataloader])\n",
    "fid_bhs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8955d2a3-167d-40a7-aa7a-3184f22c9d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_bhs2 = calculate_frechet(dataloader_fid , trained_bhsgan.generator, model, device, 200) #np.mean([calculate_frechet(real_images , generated_images_bhs, model) for real_images, _ in dataloader])\n",
    "fid_bhs2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "226fd139-22f2-4496-a077-a043d3eaf162",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_kl = calculate_frechet(dataloader_fid , generator_kl, model, device, 200)\n",
    "fid_kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e37eae-b25f-470c-89ab-6176bc3532f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_rvkl = calculate_frechet(dataloader_fid , generator_rvkl, model, device, 200)\n",
    "fid_rvkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6b1ede-df2d-4fd4-b9e0-767cb383c114",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_gan =  calculate_frechet(dataloader_fid , generator_gan, model, device, 200)\n",
    "fid_gan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3259fb4-f390-4962-beff-4d00ff05e66c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fid_pearson = calculate_frechet(dataloader_fid , generator_pearson, model, device, 200)\n",
    "fid_pearson"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "173e1fde-1b42-472c-b5cd-239c1ee40b2d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#fid_scores_wasserstein = []\n",
    "#fid_scores_bhs = []\n",
    "#fid_scores_kl = []\n",
    "#fid_scores_rvkl = []\n",
    "#fid_scores_gan = []\n",
    "#fid_scores_pearson = []\n",
    "\n",
    "for seed in range(79, 101):\n",
    "    torch.manual_seed(seed)\n",
    "    random.seed(seed)\n",
    "    print(f\"iteration {seed}\")\n",
    "    fid_scores_wasserstein.append(calculate_frechet(dataloader_fid , generator_wasserstein, model, device, 200))\n",
    "    fid_scores_bhs.append(calculate_frechet(dataloader_fid , generator_bhs, model, device, 200))\n",
    "    fid_scores_kl.append(calculate_frechet(dataloader_fid , generator_kl, model, device, 200))\n",
    "    fid_scores_rvkl.append(calculate_frechet(dataloader_fid , generator_rvkl, model, device, 200))\n",
    "    fid_scores_gan.append(calculate_frechet(dataloader_fid , generator_gan, model, device, 200))\n",
    "    fid_scores_pearson.append(calculate_frechet(dataloader_fid , generator_pearson, model, device, 200))\n",
    "\n",
    "results_dict = {\"mean_wasserstein\": np.mean(fid_scores_wasserstein),\n",
    "                \"two_std_wasserstein\": np.std(fid_scores_wasserstein),\n",
    "                \"mean_bhs\": np.mean(fid_scores_bhs),\n",
    "                \"two_std_bhs\": np.std(fid_scores_bhs),\n",
    "                \"mean_kl\": np.mean(fid_scores_kl),\n",
    "                \"two_std_kl\": np.std(fid_scores_kl),\n",
    "                \"mean_rvkl\": np.mean(fid_scores_rvkl),\n",
    "                \"two_std_rvkl\": np.std(fid_scores_rvkl),\n",
    "                \"mean_gan\": np.mean(fid_scores_gan),\n",
    "                \"two_std_gan\": np.std(fid_scores_gan),\n",
    "                \"mean_pearson\": np.mean(fid_scores_gan),\n",
    "                \"two_std_pearson\": np.std(fid_scores_pearson)\n",
    "               }\n",
    "\n",
    "results_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42995a9c-5946-46e3-8b05-013eda0a4c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dict = {\"mean_wasserstein\": np.mean(fid_scores_wasserstein),\n",
    "                \"two_std_wasserstein\": np.std(fid_scores_wasserstein),\n",
    "                \"mean_bhs\": np.mean(fid_scores_bhs),\n",
    "                \"two_std_bhs\": np.std(fid_scores_bhs),\n",
    "                \"mean_kl\": np.mean(fid_scores_kl),\n",
    "                \"two_std_kl\": np.std(fid_scores_kl),\n",
    "                \"mean_rvkl\": np.mean(fid_scores_rvkl),\n",
    "                \"two_std_rvkl\": np.std(fid_scores_rvkl),\n",
    "                \"mean_gan\": np.mean(fid_scores_gan),\n",
    "                \"two_std_gan\": np.std(fid_scores_gan),\n",
    "                \"mean_pearson\": np.mean(fid_scores_gan),\n",
    "                \"two_std_pearson\": np.std(fid_scores_pearson)\n",
    "               }\n",
    "\n",
    "results_dict"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bhsgan",
   "language": "python",
   "name": "bhsgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
