{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e067d815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: standard library, then third-party, then local project modules.\n",
    "# Explicit names (rather than `import *`) show where every symbol used below comes from.\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "from MagGN_CelebA import GAN\n",
    "from args_GAN_CelebA import get_parser\n",
    "from plot_gan_training import plot_generator_grad_norms, plot_training_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76efa8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "# Passing args=[] stops argparse from reading the Jupyter kernel's own sys.argv,\n",
    "# so the options take the defaults declared in get_parser().\n",
    "parser = get_parser()\n",
    "opt = parser.parse_args(args=[])\n",
    "\n",
    "# Use the GPU when torch can see one, otherwise the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Model / training settings copied out of the parsed options.\n",
    "img_size = opt.img_size\n",
    "channels = opt.channels\n",
    "latent_dim = opt.latent_dim\n",
    "n_epochs = opt.n_epochs\n",
    "batch_size = opt.batch_size\n",
    "lr = opt.lr\n",
    "step = opt.step\n",
    "normalize = opt.normalize  # passed to maggn.train_MagGN_multiscale\n",
    "\n",
    "# Magnitude-overlap settings, passed to maggn.compute_magnitude_overlap after training.\n",
    "# NOTE(review): min_t/max_t presumably bound the range of scales t and `steps` is the\n",
    "# number of grid points - confirm against compute_magnitude_overlap.\n",
    "max_t = 10\n",
    "min_t = 0\n",
    "steps = 100\n",
    "num_samples = 100\n",
    "overlap_normalize = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce5766dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output locations and plot labels (the model itself is constructed in a later cell).\n",
    "dataset_name = 'CELEBA'\n",
    "# Alternative output folder: folder_name = f'GAN_{dataset_name}_results'\n",
    "folder_name = f'GAN_{dataset_name}_results/Paper_Results'\n",
    "os.makedirs(folder_name, exist_ok=True)\n",
    "\n",
    "# Multiples of `step`, from `step` up to and including `n_epochs`.\n",
    "epochs_to_plot = list(range(step, n_epochs + 1, step))\n",
    "\n",
    "# Label passed to the plotting helpers as `step_name` (alternative: 'Epoch').\n",
    "step_name = 'Step'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97686a27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiscale schedule handed to train_MagGN_multiscale as `epoch_list` and `t_list`.\n",
    "# Each row below is an (epoch, t) pair, so the two lists stay aligned index by index.\n",
    "# NOTE(review): presumably scale t becomes active at its paired epoch - confirm against\n",
    "# GAN.train_MagGN_multiscale, and check that every epoch here is <= n_epochs.\n",
    "epoch_list, t_list = (list(column) for column in zip(\n",
    "    (1, 0.03),\n",
    "    (81, 0.9),\n",
    "    (201, 6.0),\n",
    "    (281, 10.0),\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcec277",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model. `name` is passed to GAN as the model name (alternative: 'MagGN_improved').\n",
    "# NOTE(review): an earlier run of this cell ended with a DataLoader worker abort\n",
    "# ('Broken pipe', Abort trap 6); if it recurs, check the DataLoader settings used inside GAN\n",
    "# (e.g. num_workers).\n",
    "name = 'MagGN_conv'\n",
    "\n",
    "maggn = GAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a5c890",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train with the multiscale magnitude loss, then compute the magnitude overlap and plot.\n",
    "# perf_counter is a monotonic clock, so the measured duration is unaffected by\n",
    "# system-clock adjustments during a long training run.\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "loss_G_list, generator_grad_norm_list_mag_t_scheduler, maggn_gen_data = maggn.train_MagGN_multiscale(\n",
    "    n_epochs=n_epochs,\n",
    "    batch_size=batch_size,\n",
    "    normalize=normalize,\n",
    "    t_list=t_list,\n",
    "    epoch_list=epoch_list,\n",
    "    loss_normalize=True,\n",
    "    feature_space=True,\n",
    "    avg_pool_size=8,\n",
    "    # hybrid_coeff=0.3,\n",
    ")\n",
    "end_time = time.perf_counter()\n",
    "elapsed_time = end_time - start_time\n",
    "print(f\"{name} Training time: {elapsed_time:.2f} seconds\", flush=True)\n",
    "\n",
    "# Magnitude overlap using the settings from the config cell; n_epochs is passed as the\n",
    "# first argument (presumably the epoch/checkpoint to evaluate - confirm).\n",
    "maggn.compute_magnitude_overlap(\n",
    "    n_epochs,\n",
    "    max_t=max_t,\n",
    "    min_t=min_t,\n",
    "    steps=steps,\n",
    "    num_samples=num_samples,\n",
    "    normalize=overlap_normalize,\n",
    ")\n",
    "\n",
    "# Every plot is tagged with the model's own name and given folder_name.\n",
    "test_name = maggn.model_name\n",
    "\n",
    "# 1. Training losses\n",
    "plot_training_losses(loss_G_list, folder_name=folder_name, name=test_name, step_name=step_name)\n",
    "\n",
    "# 2a. Generator gradient norms (default visualization)\n",
    "plot_generator_grad_norms(\n",
    "    generator_grad_norm_list_mag_t_scheduler, folder_name, name=test_name, step_name=step_name)\n",
    "\n",
    "# 2b. Generator gradient norms, log visualization\n",
    "plot_generator_grad_norms(\n",
    "    generator_grad_norm_list_mag_t_scheduler, folder_name, name=test_name, visualization='log', step_name=step_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
