{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6734f8f1-d299-4e7e-a3ae-5da0fa64ac02",
   "metadata": {},
   "source": [
    "To evaluate the BHS GAN, I use the following procedure:\n",
    "1. Sample from a mixture of two normal distributions with high variance (strongly overlapping components) --> the GAN should perform poorly\n",
    "2. Train the GAN on the sampled data\n",
    "3. Generate new data and estimate its density\n",
    "4. Compare the actual and estimated densities\n",
    "5. Sample from a mixture of two normal distributions with low variance (well-separated components) --> the GAN should perform quite well\n",
    "6. Train the GAN on the sampled data\n",
    "7. Generate new data and estimate its density\n",
    "8. Compare the actual and estimated densities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81cccca-a527-4dfc-a371-0fa611edcce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the local project modules (bhsgan, trainer, utils, ...) automatically\n",
    "# when their source changes, so edits apply without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9314c0-cae7-4c51-be3e-02d7c570cde5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from bhsgan import DiscriminatorBhsSim, GeneratorBhsSim, GeneratorBhsSimNormal\n",
    "from dataset import SampleDataset\n",
    "from ipmbhsgan import DiscriminatorIpmSim, GeneratorIpmSim\n",
    "from trainer import (Trainer, TrainingParams,\n",
    "                     get_dis_loss_bhs, get_gen_loss_bhs,\n",
    "                     get_dis_loss_gan, get_gen_loss_gan,\n",
    "                     get_dis_loss_ipm, get_gen_loss_ipm,\n",
    "                     get_dis_loss_kl, get_gen_loss_kl,\n",
    "                     get_dis_loss_p, get_gen_loss_p,\n",
    "                     get_dis_loss_rkl, get_gen_loss_rkl,\n",
    "                     get_dis_loss_wasserstein, get_gen_loss_wasserstein)\n",
    "from utils import (GanGanActivation, Positive, RevKlActivation, get_device,\n",
    "                   get_noise, init_weights, load_model_state_dict, plot_losses,\n",
    "                   plot_tensor_images, save_models_state_dict)\n",
    "from wgan import (DiscriminatorWassersteinSim, GeneratorWassersteinSim,\n",
    "                  GeneratorWassersteinSimNormal)\n",
    "\n",
    "# All tensors (data, noise, network weights) default to float64.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "# Seed every RNG the notebook draws from. The training data is sampled with the\n",
    "# global NumPy RNG (np.random.normal), so it must be seeded as well; otherwise\n",
    "# the data, and hence every result below, changes on each run.\n",
    "SEED = 96\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0613cb97-6864-48ed-9196-6427702ea9ec",
   "metadata": {},
   "source": [
    "## Sample the training data (two-component normal mixtures)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "185c5974-45de-49ad-a4a2-3c475a144b8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters of the two normal components that make up each training set.\n",
    "# low variance: well-separated components (the GANs should do well)\n",
    "mean_1_low = 0\n",
    "sd_1_low = 1\n",
    "mean_2_low = 5\n",
    "sd_2_low = 1\n",
    "# high variance: strongly overlapping components (the GANs should struggle)\n",
    "mean_1_high = 0\n",
    "sd_1_high = 4\n",
    "mean_2_high = 5\n",
    "sd_2_high = 4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f3f44f2-8b76-4771-9926-685e76e9c0fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# High-variance training data: an equal mixture of two normals with means\n",
    "# mean_1_high / mean_2_high and standard deviations sd_1_high / sd_2_high.\n",
    "# Each component is drawn directly; bootstrapping a larger draw with\n",
    "# np.random.choice (sampling with replacement) would duplicate points.\n",
    "normal_sample_1_high = np.random.normal(mean_1_high, sd_1_high, 5000)\n",
    "normal_sample_2_high = np.random.normal(mean_2_high, sd_2_high, 5000)\n",
    "sample_high = np.concatenate([normal_sample_1_high, normal_sample_2_high])\n",
    "# Pair consecutive values into rows: every 2-D point comes from a single\n",
    "# component (first half of the rows from component 1, second half from 2).\n",
    "# NOTE: the variable name is historical - this is normal-mixture data, not beta data.\n",
    "beta_sample_high = np.reshape(sample_high, (-1, 2))\n",
    "sns.set_style('whitegrid')\n",
    "sns.histplot(beta_sample_high);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f6134cb-ff21-420a-b4b4-e00a3733966a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Low-variance training data: an equal mixture of two normals with means\n",
    "# mean_1_low / mean_2_low and standard deviations sd_1_low / sd_2_low.\n",
    "# Each component is drawn directly; bootstrapping a larger draw with\n",
    "# np.random.choice (sampling with replacement) would duplicate points.\n",
    "normal_sample_1_low = np.random.normal(mean_1_low, sd_1_low, 5000)\n",
    "normal_sample_2_low = np.random.normal(mean_2_low, sd_2_low, 5000)\n",
    "sample_low = np.concatenate([normal_sample_1_low, normal_sample_2_low])\n",
    "# Pair consecutive values into rows: every 2-D point comes from a single\n",
    "# component (first half of the rows from component 1, second half from 2).\n",
    "# NOTE: the variable name is historical - this is normal-mixture data, not beta data.\n",
    "beta_sample_low = np.reshape(sample_low, (-1, 2))\n",
    "sns.set_style('whitegrid')\n",
    "sns.histplot(beta_sample_low);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fd9e87-e7a0-4ca4-bb40-febfc477eff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the sample arrays (one 2-D point per row) in the project's SampleDataset\n",
    "# so they can be fed to torch DataLoaders.\n",
    "training_set_high = SampleDataset(beta_sample_high)\n",
    "training_set_low = SampleDataset(beta_sample_low)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "507f402b-ff2a-4a2f-920e-d66656b14515",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed latent noise shared by all trained generators, so their outputs are comparable.\n",
    "# NOTE(review): presumably 10000 noise vectors of dimension 2 - confirm the\n",
    "# argument order in utils.get_noise.\n",
    "test_noise = get_noise(10000, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "154f5dac-219a-4725-aa4f-62f9d622e817",
   "metadata": {},
   "source": [
    "## Train BHS GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1021e08-f96c-4164-9409-8fc42fd2f55d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters shared by every GAN trained in this notebook.\n",
    "training_params = TrainingParams(\n",
    "    lr_dis=0.0002,\n",
    "    lr_gen=0.0002,\n",
    "    num_epochs=15,\n",
    "    num_dis_updates=4,\n",
    "    num_gen_updates=3,\n",
    "    beta_1=0.5,\n",
    "    batch_size=128,\n",
    ")\n",
    "\n",
    "# Training runs on the CPU. NOTE(review): `device` is not passed to Trainer\n",
    "# anywhere in this notebook.\n",
    "device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3994ec9-b43e-4bc0-8d5a-0d50b164fe55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the dataloaders: shuffled mini-batches of each training set, all\n",
    "# built with the same settings.\n",
    "loader_kwargs = dict(batch_size=training_params.batch_size, shuffle=True, num_workers=1)\n",
    "dataloader_high = torch.utils.data.DataLoader(training_set_high, **loader_kwargs)\n",
    "dataloader_low = torch.utils.data.DataLoader(training_set_low, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ab2132c-6fc6-48d0-8c37-3be2ccfefac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise one independent BHS generator/discriminator pair per training set.\n",
    "# The dataloaders are created in the previous cell and are not rebuilt here.\n",
    "# NOTE(review): `Positive` is passed as the discriminator's final activation;\n",
    "# presumably it constrains the output to be positive - see utils.Positive.\n",
    "final_activation = Positive\n",
    "generator_high = GeneratorBhsSimNormal()\n",
    "discriminator_high = DiscriminatorBhsSim(final_activation)\n",
    "generator_low = GeneratorBhsSimNormal()\n",
    "discriminator_low = DiscriminatorBhsSim(final_activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c8e5b2-9c1f-4285-969c-f9223dada271",
   "metadata": {},
   "outputs": [],
   "source": [
    "# init Trainer: one per training set, both using the shared hyper-parameters\n",
    "trainer_high = Trainer(training_params, generator_high, discriminator_high)\n",
    "trainer_low = Trainer(training_params, generator_low, discriminator_low)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f082927-0828-4586-bb1e-2b5e99345cdf",
   "metadata": {},
   "source": [
    "### Train on high variance samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30094a3-373e-4500-8b4c-1a991c786147",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (BHS losses, high-variance data)\n",
    "# NOTE(review): the positional flag is False here and True only for the\n",
    "# Wasserstein runs; confirm its meaning in trainer.Trainer.train_gan.\n",
    "trained_bhs_high = trainer_high.train_gan(dataloader_high, get_dis_loss_bhs, get_gen_loss_bhs, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a274af86-cbf1-4c49-b681-e21faf180268",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise. This is inference\n",
    "# only, so no autograd graph is built. ravel() flattens the output to 1-D without\n",
    "# hard-coding its size, so the cell keeps working if test_noise changes.\n",
    "with torch.no_grad():\n",
    "    generated_data = trained_bhs_high.generator(test_noise)\n",
    "generated_sample = generated_data.detach().numpy().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f49a3a3-6f2f-4105-afe3-372e28c59715",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot resulting density (compare with the high-variance training data above)\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample, bins=50)\n",
    "ax.set(title=\"BHS GAN, high variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b414599-f38c-4126-b960-8044a4cd5bc0",
   "metadata": {},
   "source": [
    "### Train on low variance samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1881c63-5518-4a52-9c32-0dfe5d7d6c8c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (BHS losses, low-variance data)\n",
    "trained_bhs_low = trainer_low.train_gan(dataloader_low, get_dis_loss_bhs, get_gen_loss_bhs, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e4e4052-5aae-486e-966f-4dc10b76460a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise. This is inference\n",
    "# only, so no autograd graph is built. ravel() flattens the output to 1-D without\n",
    "# hard-coding its size.\n",
    "with torch.no_grad():\n",
    "    generated_data_low = trained_bhs_low.generator(test_noise)\n",
    "generated_sample_low = generated_data_low.detach().numpy().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d7efd3-e8b6-4ef1-bd1a-7e70d297fd45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot resulting density (compare with the low-variance training data above)\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_low, bins=50)\n",
    "ax.set(title=\"BHS GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b898b706-0e03-4a02-9b0a-b788df8ec052",
   "metadata": {},
   "source": [
    "## Wasserstein GAN "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1630b6c0-5962-4942-840f-16b2f1b8c760",
   "metadata": {},
   "source": [
    "### High Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "429e81e2-215e-45fc-8ebf-448a7150dfd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WGAN on the high-variance data: its own dataloader, nets and trainer.\n",
    "# NOTE(review): this loader uses the same settings as dataloader_high.\n",
    "dataloader_wasserstein_high = torch.utils.data.DataLoader(training_set_high, batch_size=training_params.batch_size,\n",
    "                                         shuffle=True, num_workers=1)\n",
    "generator_wasserstein_high = GeneratorWassersteinSimNormal()\n",
    "discriminator_wasserstein_high = DiscriminatorWassersteinSim()\n",
    "trainer_wgan_high = Trainer(training_params, generator_wasserstein_high, discriminator_wasserstein_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85d8cc4e-c7f6-48d3-9207-93fbad1fbf04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (Wasserstein losses, high-variance data; the positional flag is True\n",
    "# for the WGAN runs only)\n",
    "trained_wgan_high = trainer_wgan_high.train_gan(dataloader_wasserstein_high, get_dis_loss_wasserstein, get_gen_loss_wasserstein, True, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13123456-9972-4a9b-842e-4cbb21259cd8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sample from the trained WGAN on the fixed test noise (inference only, so no\n",
    "# autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_wasserstein_high = trained_wgan_high.generator(test_noise)\n",
    "generated_sample_wasserstein_high = generated_data_wasserstein_high.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_wasserstein_high, bins=100)\n",
    "ax.set(title=\"Wasserstein GAN, high variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be2b5ad8-3a6b-46a4-ae27-8d4d0c410ae7",
   "metadata": {},
   "source": [
    "### Low Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e49b863a-60ab-4534-a2a0-d4749e51136d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WGAN on the low-variance data: its own dataloader, nets and trainer.\n",
    "# NOTE(review): this loader uses the same settings as dataloader_low.\n",
    "dataloader_wasserstein_low = torch.utils.data.DataLoader(training_set_low, batch_size=training_params.batch_size,\n",
    "                                         shuffle=True, num_workers=1)\n",
    "generator_wasserstein_low = GeneratorWassersteinSimNormal()\n",
    "discriminator_wasserstein_low = DiscriminatorWassersteinSim()\n",
    "trainer_wgan_low = Trainer(training_params, generator_wasserstein_low, discriminator_wasserstein_low)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6845be2-a440-4b4d-8e7b-cfce2074832f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (Wasserstein losses, low-variance data)\n",
    "trained_wgan_low = trainer_wgan_low.train_gan(dataloader_wasserstein_low, get_dis_loss_wasserstein, get_gen_loss_wasserstein, True, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd1a0122-3e4f-41f9-8f7e-0c5b385d655b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained WGAN on the fixed test noise (inference only, so no\n",
    "# autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_wasserstein_low = trained_wgan_low.generator(test_noise)\n",
    "generated_sample_wasserstein_low = generated_data_wasserstein_low.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_wasserstein_low, bins=50)\n",
    "ax.set(title=\"Wasserstein GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9071f1ff-b397-438f-9768-b1b28743fa3c",
   "metadata": {},
   "source": [
    "## Standard GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64b04606-0564-40e2-afdb-2c51ee53a964",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the nets for the standard GAN. The dataloaders (dataloader_high /\n",
    "# dataloader_low) from the \"Train BHS GAN\" section are reused, not rebuilt.\n",
    "# nn.Sigmoid is the discriminator's final activation (presumably applied to its\n",
    "# output, giving values in (0, 1)).\n",
    "final_activation = nn.Sigmoid\n",
    "generator_gan_high = GeneratorBhsSimNormal()\n",
    "discriminator_gan_high = DiscriminatorBhsSim(final_activation)\n",
    "generator_gan_low = GeneratorBhsSimNormal()\n",
    "discriminator_gan_low = DiscriminatorBhsSim(final_activation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc6055df-f4ac-406f-a4c6-6cb5d28f7fe1",
   "metadata": {},
   "source": [
    "### High Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7053ac38-5a90-48f3-b9bf-b183924abe36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the standard GAN on the high-variance data\n",
    "trainer_gan_high = Trainer(training_params, generator_gan_high, discriminator_gan_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6bf1f4-81e9-4731-a45a-7229ed5617ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (standard GAN losses, high-variance data)\n",
    "trained_gan_high = trainer_gan_high.train_gan(dataloader_high, get_dis_loss_gan, get_gen_loss_gan, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5634783f-f017-4c97-a44e-5f57e8a393e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_gan_high = trained_gan_high.generator(test_noise)\n",
    "generated_sample_gan_high = generated_data_gan_high.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_gan_high, bins=50)\n",
    "ax.set(title=\"Standard GAN, high variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4729fb3-1474-4fb7-9cd6-76542f68d6ff",
   "metadata": {},
   "source": [
    "### Low Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9dcdef9-ff2c-452f-9299-5bcb2a9ce2b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the standard GAN on the low-variance data\n",
    "trainer_gan_low = Trainer(training_params, generator_gan_low, discriminator_gan_low)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "287dc1a4-0a54-4583-a145-fe5db03b3b59",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (standard GAN losses, low-variance data)\n",
    "trained_gan_low = trainer_gan_low.train_gan(dataloader_low, get_dis_loss_gan, get_gen_loss_gan, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7872b510-8422-40ab-bd05-dc8160e567fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_gan_low = trained_gan_low.generator(test_noise)\n",
    "generated_sample_gan_low = generated_data_gan_low.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_gan_low, bins=50)\n",
    "ax.set(title=\"Standard GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43cd61ca-fb98-4fbd-a800-3550de6ad721",
   "metadata": {},
   "source": [
    "## Pearson GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a1816b-91e4-425a-9e8e-8c0cc4307dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the nets for the Pearson chi-squared GAN. The dataloaders\n",
    "# (dataloader_high / dataloader_low) from the \"Train BHS GAN\" section are reused.\n",
    "# nn.Identity as the final activation leaves the discriminator output\n",
    "# unconstrained (presumably applied as its last layer).\n",
    "final_activation = nn.Identity\n",
    "generator_pearson_high = GeneratorBhsSimNormal()\n",
    "discriminator_pearson_high = DiscriminatorBhsSim(final_activation)\n",
    "generator_pearson_low = GeneratorBhsSimNormal()\n",
    "discriminator_pearson_low = DiscriminatorBhsSim(final_activation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6393bf1f-25d8-40ee-ac09-882b408a17a2",
   "metadata": {},
   "source": [
    "### High Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62e96d01-e2c1-4fd5-8d8e-407f509b0b8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the Pearson GAN on the high-variance data\n",
    "trainer_pearson_high = Trainer(training_params, generator_pearson_high, discriminator_pearson_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346064b6-28bc-4505-9050-199a33a79d72",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (Pearson losses, high-variance data)\n",
    "trained_pearson_high = trainer_pearson_high.train_gan(dataloader_high, get_dis_loss_p, get_gen_loss_p, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce42a82c-f667-4f5e-8718-cbe74562b147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_pearson_high = trained_pearson_high.generator(test_noise)\n",
    "generated_sample_pearson_high = generated_data_pearson_high.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_pearson_high, bins=50)\n",
    "ax.set(title=\"Pearson GAN, high variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ef66c41-40e7-40bb-8e37-84e1422e5f2b",
   "metadata": {},
   "source": [
    "### Low Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7dea231-9e54-4943-8434-de6195b5cb29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the Pearson GAN on the low-variance data\n",
    "trainer_pearson_low = Trainer(training_params, generator_pearson_low, discriminator_pearson_low)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ef4f659-dd41-446d-98aa-fd7828c86503",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (Pearson losses, low-variance data)\n",
    "trained_pearson_low = trainer_pearson_low.train_gan(dataloader_low, get_dis_loss_p, get_gen_loss_p, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be422007-d076-4257-bdad-50309b2529b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_pearson_low = trained_pearson_low.generator(test_noise)\n",
    "generated_sample_pearson_low = generated_data_pearson_low.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_pearson_low, bins=50)\n",
    "ax.set(title=\"Pearson GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d58eb419-5416-42f1-af68-d27a60b85443",
   "metadata": {},
   "source": [
    "## KL GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1890094d-6095-4d48-b299-ce4978e2e83e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the nets for the KL GAN. The dataloaders (dataloader_high /\n",
    "# dataloader_low) from the \"Train BHS GAN\" section are reused, not rebuilt.\n",
    "# nn.Identity as the final activation leaves the discriminator output\n",
    "# unconstrained (presumably applied as its last layer).\n",
    "final_activation = nn.Identity\n",
    "generator_kl_high = GeneratorBhsSimNormal()\n",
    "discriminator_kl_high = DiscriminatorBhsSim(final_activation)\n",
    "generator_kl_low = GeneratorBhsSimNormal()\n",
    "discriminator_kl_low = DiscriminatorBhsSim(final_activation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277ee51d-f795-4340-ba43-ef1891614acd",
   "metadata": {},
   "source": [
    "### High Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeb7ec83-1f1e-4478-aca9-4b3157a19f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the KL GAN on the high-variance data\n",
    "trainer_kl_high = Trainer(training_params, generator_kl_high, discriminator_kl_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df3e9c0a-9e95-49fb-8bf0-ceb5da6997b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (KL losses, high-variance data)\n",
    "trained_kl_high = trainer_kl_high.train_gan(dataloader_high, get_dis_loss_kl, get_gen_loss_kl, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a54d28a-48b8-49c4-bfc7-618431a20827",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_kl_high = trained_kl_high.generator(test_noise)\n",
    "generated_sample_kl_high = generated_data_kl_high.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_kl_high, bins=50)\n",
    "ax.set(title=\"KL GAN, high variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b54ae430-14ca-4943-a473-a7717b820ab3",
   "metadata": {},
   "source": [
    "### Low Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "607f9c45-d203-49c3-b78a-4940b59bbc37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the KL GAN on the low-variance data\n",
    "trainer_kl_low = Trainer(training_params, generator_kl_low, discriminator_kl_low)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42790624-8578-49cf-88dc-978e2c33def4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (KL losses, low-variance data)\n",
    "trained_kl_low = trainer_kl_low.train_gan(dataloader_low, get_dis_loss_kl, get_gen_loss_kl, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "949a9f44-03b3-41dc-a1f5-0c97b18ee0f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_kl_low = trained_kl_low.generator(test_noise)\n",
    "generated_sample_kl_low = generated_data_kl_low.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_kl_low, bins=50)\n",
    "ax.set(title=\"KL GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "262a0307-a4c3-4a7e-a90c-3e99be1760b4",
   "metadata": {},
   "source": [
    "## Reverse KL GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "523d5c2a-1beb-47bb-8edb-00edf6deecd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the nets for the reverse-KL GAN. The dataloaders (dataloader_high /\n",
    "# dataloader_low) from the \"Train BHS GAN\" section are reused, not rebuilt.\n",
    "# NOTE(review): RevKlActivation is the discriminator's final activation; its\n",
    "# output range is defined in utils.RevKlActivation.\n",
    "final_activation = RevKlActivation\n",
    "generator_rvkl_high = GeneratorBhsSimNormal()\n",
    "discriminator_rvkl_high = DiscriminatorBhsSim(final_activation)\n",
    "generator_rvkl_low = GeneratorBhsSimNormal()\n",
    "discriminator_rvkl_low = DiscriminatorBhsSim(final_activation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3eed8ec5-d599-4ed9-a6a4-fd9571c77b56",
   "metadata": {},
   "source": [
    "### High Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c3d57b-0e1c-4c63-a132-af6064b05597",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the reverse-KL GAN on the high-variance data\n",
    "trainer_rvkl_high = Trainer(training_params, generator_rvkl_high, discriminator_rvkl_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e41e9bc-21ca-4b02-a643-ff29ee5a1f99",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (reverse-KL losses, high-variance data)\n",
    "trained_rvkl_high = trainer_rvkl_high.train_gan(dataloader_high, get_dis_loss_rkl, get_gen_loss_rkl, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86785b6f-e3c1-474e-aa9c-f124645813ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_rvkl_high = trained_rvkl_high.generator(test_noise)\n",
    "generated_sample_rvkl_high = generated_data_rvkl_high.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_rvkl_high, bins=50)\n",
    "ax.set(title=\"Reverse KL GAN, high variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ba130e8-b145-4446-a52b-5c857cf980c0",
   "metadata": {},
   "source": [
    "### Low Variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49445071-6a39-465c-8715-895569590fa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer for the reverse-KL GAN on the low-variance data\n",
    "trainer_rvkl_low = Trainer(training_params, generator_rvkl_low, discriminator_rvkl_low)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e595c29-74b9-477f-8064-8027b56bef96",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# training loop (reverse-KL losses, low-variance data)\n",
    "trained_rvkl_low = trainer_rvkl_low.train_gan(dataloader_low, get_dis_loss_rkl, get_gen_loss_rkl, False, print_intermediate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02625f84-031b-483e-a9ee-690bca8b646d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained generator on the fixed test noise (inference only, so\n",
    "# no autograd graph) and flatten to 1-D without hard-coding the output size.\n",
    "with torch.no_grad():\n",
    "    generated_data_rvkl_low = trained_rvkl_low.generator(test_noise)\n",
    "generated_sample_rvkl_low = generated_data_rvkl_low.detach().numpy().ravel()\n",
    "# plot resulting density\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.histplot(generated_sample_rvkl_low, bins=50)\n",
    "ax.set(title=\"Reverse KL GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d822d82",
   "metadata": {},
   "source": [
    "## IPM version of the BHS GAN (trained on the low-variance data only)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30257dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize nets for the IPM variant of the BHS GAN\n",
    "generator_ipm = GeneratorIpmSim()\n",
    "discriminator_ipm = DiscriminatorIpmSim()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da9ecbce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The IPM variant is trained on the low-variance data only.\n",
    "dataloader_ipm = torch.utils.data.DataLoader(training_set_low, batch_size=training_params.batch_size,\n",
    "                                         shuffle=True, num_workers=1)\n",
    "# init Trainer\n",
    "trainer_ipm_gan = Trainer(training_params, generator_ipm, discriminator_ipm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f36156",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training loop (IPM losses, low-variance data)\n",
    "# NOTE(review): unlike the other runs, print_intermediate is left at the\n",
    "# Trainer.train_gan default here.\n",
    "trained_ipm_gan = trainer_ipm_gan.train_gan(dataloader_ipm, get_dis_loss_ipm, get_gen_loss_ipm, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "326f4a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained IPM generator on the fixed test noise (inference only,\n",
    "# so no autograd graph). Flatten with ravel() rather than a hard-coded reshape:\n",
    "# test_noise comes from get_noise(10000, 2), so the output cannot have exactly\n",
    "# 1000 elements and reshape(..., (1, 1000)) would raise.\n",
    "with torch.no_grad():\n",
    "    generated_data_ipm = trained_ipm_gan.generator(test_noise)\n",
    "generated_sample_ipm = generated_data_ipm.detach().numpy().ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "732beb70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot resulting density as a kernel density estimate of the generated values.\n",
    "# `bw` is deprecated since seaborn 0.11 and was only forwarded to `bw_method`,\n",
    "# so pass `bw_method` directly.\n",
    "sns.set_style('whitegrid')\n",
    "ax = sns.kdeplot(generated_sample_ipm, bw_method=0.5)\n",
    "ax.set(title=\"IPM BHS GAN, low variance: generated values\", xlabel=\"value\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4dd7aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the generator and discriminator losses recorded during IPM training\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.set_title(\"Generator and Discriminator Loss During Training\")\n",
    "ax.plot(trained_ipm_gan.generator_losses, label=\"G-Loss\")\n",
    "ax.plot(trained_ipm_gan.discriminator_losses, label=\"D-Loss\")\n",
    "ax.set_xlabel(\"iterations\")\n",
    "ax.set_ylabel(\"Loss\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da20e06f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check the mean of the generated data (the standard deviation follows in the next cell)\n",
    "np.mean(generated_sample_ipm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e46f0bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# standard deviation of the generated IPM samples\n",
    "np.std(generated_sample_ipm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "575f63e0-b79d-48e9-a2f3-7ea92c8e60a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_points = np.linspace(-1, 2, 100)\n",
    "\n",
    "\n",
    "def conjugate(points):\n",
    "    \"\"\"Evaluate 2 * u * exp(u) elementwise, where u = -1 + sqrt(1 + s) for s >= 0\n",
    "    and u = -1 - sqrt(1 + s) for s < 0.\n",
    "\n",
    "    NOTE(review): u = log t solves (log t)^2 + 2 log t = s, i.e. f'(t) = s for\n",
    "    f(t) = t * log(t)^2, so this looks like the convex conjugate of that f -\n",
    "    confirm which root is intended on the s < 0 branch.\n",
    "    \"\"\"\n",
    "    root = np.sqrt(1 + points)\n",
    "    log_t = np.where(points >= 0, -1 + root, -1 - root)\n",
    "    return 2 * log_t * np.exp(log_t)\n",
    "\n",
    "\n",
    "plt.plot(data_points, conjugate(data_points))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "876da1fd-fd50-49a2-9788-6816f69fb2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(points):\n",
    "    \"\"\"Piecewise function: t * log(t)^2 for t >= exp(-1), -t * log(t)^2 below.\n",
    "\n",
    "    Non-positive inputs, where log is undefined, map to NaN. They are masked out\n",
    "    before np.log, so no divide/invalid RuntimeWarnings are emitted for the\n",
    "    negative part of the plotting grid.\n",
    "\n",
    "    NOTE(review): the two pieces jump from exp(-1) to -exp(-1) at t = exp(-1),\n",
    "    so f is discontinuous there - confirm this is intended.\n",
    "    \"\"\"\n",
    "    points = np.asarray(points, dtype=float)\n",
    "    values = np.full(points.shape, np.nan)\n",
    "    positive = points > 0\n",
    "    t = points[positive]\n",
    "    magnitude = t * np.log(t) ** 2\n",
    "    values[positive] = np.where(t >= np.exp(-1), magnitude, -magnitude)\n",
    "    return values\n",
    "\n",
    "\n",
    "plt.plot(data_points, f(data_points));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e55f5e-3e57-437b-af39-3dc8b5023fcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: the s < 0 branch of `conjugate` evaluated at s = -1\n",
    "s = -1\n",
    "2 * (-1 - np.sqrt(1 + s)) * np.exp(-1 - np.sqrt(1 + s))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b7f303-f3ac-4edd-8b33-d23a1cd55e4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the conjugate on the same grid that is plotted above\n",
    "conjugate(points=data_points)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bhsgan",
   "language": "python",
   "name": "bhsgan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
