{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e067d815",
   "metadata": {},
   "outputs": [],
   "source": [
    "from GANs_mnist import *\n",
    "import os\n",
    "from args_mnist import get_parser\n",
    "from plot_gan_training import *\n",
    "import time\n",
    "# torch is used directly (e.g. device selection); import it explicitly\n",
    "# instead of relying on it leaking out of the star imports above.\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76efa8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration ---------------------------------------------------------\n",
    "# parse_args(args=[]) ignores the Jupyter kernel's own argv, so every value\n",
    "# read from `opt` below is the parser default from args_mnist.get_parser().\n",
    "parser = get_parser()\n",
    "opt = parser.parse_args(args=[])\n",
    "\n",
    "img_size = opt.img_size\n",
    "channels = opt.channels\n",
    "# Prefer CUDA when available (opt.device is intentionally not used).\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Seed torch's RNGs so reruns are reproducible.\n",
    "# NOTE(review): only torch is seeded here; if GANs_mnist also draws from\n",
    "# numpy or `random`, seed those as well.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "minimax = opt.minimax\n",
    "latent_dim = opt.latent_dim\n",
    "# The next three values are hard-coded here instead of read from `opt`;\n",
    "# edit them directly to rerun the experiment with other settings.\n",
    "n_critic = 5     # passed to GAN(...) and train_WGAN_GP(...)\n",
    "n_epochs = 2500\n",
    "step = 250       # interval for GAN(step=...) and epochs_to_plot (overrides opt.step)\n",
    "batch_size = opt.batch_size\n",
    "lr = opt.lr\n",
    "clip_value = opt.clip_value\n",
    "lambda_gp = opt.lambda_gp\n",
    "t = opt.t\n",
    "normalize = opt.normalize\n",
    "beta = opt.beta\n",
    "img_rows = opt.img_rows\n",
    "img_cols = opt.img_cols\n",
    "\n",
    "# Magnitude-overlap parameters: only consumed by the (currently disabled)\n",
    "# wgan_gp.compute_magnitude_overlap call in the training cell.\n",
    "max_t = 10\n",
    "min_t = 0\n",
    "steps = 100\n",
    "num_samples = 100\n",
    "overlap_normalize = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ce5766dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output location and plotting schedule (no model is built in this cell).\n",
    "dataset_name = 'MNIST'\n",
    "folder_name = f'GAN_{dataset_name}_results/Paper_Results'\n",
    "os.makedirs(folder_name, exist_ok=True)\n",
    "# Every `step` up to and including n_epochs,\n",
    "# e.g. step=250, n_epochs=2500 -> [250, 500, ..., 2500].\n",
    "epochs_to_plot = list(range(step, n_epochs + 1, step))\n",
    "# Passed to the plot helpers as step_name; 'Epoch' is the alternative label.\n",
    "step_name = 'Step'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcec277",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing MNIST\n",
      "Critic loaded from GAN/MNIST_WGAN_mlp/500epochs/critic.pth.\n",
      "Generator loaded from GAN/MNIST_WGAN_mlp/500epochs/generator.pth.\n",
      "Loss and gradient norms loaded from GAN/MNIST_WGAN_mlp/500epochs/loss_lists.h5.\n",
      "WGAN_mlp Training time: 0.05 seconds\n"
     ]
    }
   ],
   "source": [
    "name = 'WGAN-GP_mlp'\n",
    "# Set to True to also run the magnitude-overlap analysis after training\n",
    "# (uses max_t, min_t, steps, num_samples, overlap_normalize from the config cell).\n",
    "compute_overlap = False\n",
    "\n",
    "wgan_gp = GAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    n_critic=n_critic,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "# Run training. perf_counter() is monotonic, so the measured duration is not\n",
    "# skewed by wall-clock adjustments the way time.time() differences can be.\n",
    "start_time = time.perf_counter()\n",
    "loss_C_list, loss_G_list, generator_grad_norm_list_wgan_gp, wgan_gp_gen_data = wgan_gp.train_WGAN_GP(\n",
    "    n_epochs=n_epochs,\n",
    "    n_critic=n_critic,\n",
    "    batch_size=batch_size,\n",
    "    lambda_gp=lambda_gp\n",
    ")\n",
    "elapsed_time = time.perf_counter() - start_time\n",
    "print(f\"{name} Training time: {elapsed_time:.2f} seconds\", flush=True)\n",
    "\n",
    "# Plot helpers receive folder_name and pdf_path (presumably where figures are saved).\n",
    "test_name = f'{name}-test{n_epochs}epochs'\n",
    "pdf_path = f\"{folder_name}/{test_name}_plots.pdf\"\n",
    "\n",
    "if compute_overlap:\n",
    "    wgan_gp.compute_magnitude_overlap(\n",
    "        n_epochs, max_t=max_t, min_t=min_t, steps=steps,\n",
    "        num_samples=num_samples, normalize=overlap_normalize)\n",
    "\n",
    "# 1. Training losses\n",
    "plot_training_losses(loss_G_list, loss_C_list, folder_name, pdf_path, name=test_name, step_name=step_name)\n",
    "\n",
    "# 2.a) Generator gradient norms\n",
    "plot_generator_grad_norms(\n",
    "    generator_grad_norm_list_wgan_gp, folder_name, pdf_path, name=test_name, step_name=step_name)\n",
    "\n",
    "# 2.b) Generator gradient norms, visualization='log' variant\n",
    "plot_generator_grad_norms(\n",
    "    generator_grad_norm_list_wgan_gp, folder_name, pdf_path, name=test_name, visualization='log', step_name=step_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
