{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e067d815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project modules (GAN is defined in MagGN_Cifar.py).\n",
    "from MagGN_Cifar import *\n",
    "from args_GAN_Cifar import get_parser\n",
    "from plot_gan_training import *\n",
    "\n",
    "# Explicit imports come after the star imports so a star import cannot shadow them.\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "# torch is used directly in this notebook (torch.device, seeding); import it explicitly\n",
    "# instead of relying on it leaking through `from MagGN_Cifar import *`.\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "76efa8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the default experiment options. args=[] keeps argparse from reading\n",
    "# sys.argv, which inside Jupyter holds the kernel's own launch arguments.\n",
    "parser = get_parser()\n",
    "opt = parser.parse_args(args=[])\n",
    "\n",
    "img_size = opt.img_size\n",
    "channels = opt.channels\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "latent_dim = opt.latent_dim\n",
    "n_epochs = opt.n_epochs\n",
    "batch_size = opt.batch_size\n",
    "lr = opt.lr\n",
    "opt_betas = opt.opt_betas\n",
    "step = opt.step\n",
    "normalize = opt.normalize\n",
    "\n",
    "# Seed torch's global RNG before the model is built so that weight init and\n",
    "# latent sampling (torch.randn) are reproducible across Restart & Run All.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Magnitude-overlap parameters (arguments of maggn.compute_magnitude_overlap,\n",
    "# whose call is currently commented out in the training cell).\n",
    "max_t = 10\n",
    "min_t = 0\n",
    "steps = 100\n",
    "num_samples = 100\n",
    "overlap_normalize = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ce5766dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output locations and plot settings (the model itself is instantiated further below).\n",
    "dataset_name = 'Cifar10'\n",
    "# Alternative output folder: f'GAN_{dataset_name}_results'\n",
    "folder_name = f'GAN_{dataset_name}_results/Paper_Results'\n",
    "os.makedirs(folder_name, exist_ok=True)\n",
    "\n",
    "# Epochs to plot: step, 2*step, ... up to the largest multiple of step <= n_epochs.\n",
    "epochs_to_plot = list(range(step, n_epochs + 1, step))\n",
    "\n",
    "# Passed as step_name to the plot_* helpers; the alternative value is 'Epoch'.\n",
    "step_name = 'Step'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97686a27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiscale magnitude schedule: t_list holds the scales t; epoch_list holds the epochs\n",
    "# at which training enlarges the active set of (sorted) t values used in the loss.\n",
    "t_list = [0.4, 0.6, 1.2, 2.4]\n",
    "epoch_list = [1, 201, 501, 701]\n",
    "\n",
    "# Catch schedule mistakes before a long training run starts.\n",
    "assert len(t_list) == len(epoch_list), \"t_list and epoch_list must have the same length\"\n",
    "assert epoch_list == sorted(epoch_list), \"epoch_list must be in increasing order\"\n",
    "assert all(e <= n_epochs for e in epoch_list), \"epoch_list schedules a t value after the last epoch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6fcec277",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing Cifar10\n",
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/admin/path/Magnitude-Distance/cifar_data/cifar-10-python.tar.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 170498071/170498071 [00:16<00:00, 10575553.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /Users/admin/path/Magnitude-Distance/cifar_data/cifar-10-python.tar.gz to /Users/admin/path/Magnitude-Distance/cifar_data\n",
      "CIFAR-10 dataloader created with 5000 images\n"
     ]
    }
   ],
   "source": [
    "# Model variant to train; 'MagGN_improved' is the other available name.\n",
    "name = 'MagGN_conv'\n",
    "\n",
    "# Build the model from the options parsed in the config cell\n",
    "# (construction also prepares the dataset and its dataloader).\n",
    "maggn = GAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    opt_betas=opt_betas,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a5c890",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training and time it. perf_counter is a monotonic clock intended for measuring\n",
    "# intervals, unlike time.time(), which can jump if the system clock is adjusted.\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "loss_G_list, generator_grad_norm_list_mag_t_scheduler, maggn_gen_data = maggn.train_MagGN_multiscale(\n",
    "    n_epochs=n_epochs,\n",
    "    batch_size=batch_size,\n",
    "    normalize=normalize,\n",
    "    t_list=t_list,\n",
    "    epoch_list=epoch_list,\n",
    "    loss_normalize=True,\n",
    "    feature_space=False,\n",
    ")\n",
    "\n",
    "end_time = time.perf_counter()\n",
    "elapsed_time = end_time - start_time\n",
    "print(f\"{name} Training time: {elapsed_time:.2f} seconds\", flush=True)\n",
    "\n",
    "# Post-training analysis (disabled): magnitude overlap and plots. Uncomment to run.\n",
    "# maggn.compute_magnitude_overlap(n_epochs, max_t = max_t, min_t = min_t, steps = steps, num_samples = num_samples, normalize = overlap_normalize)\n",
    "# # test_name = f'{name}_with_multiscale_t_{\"_\".join([str(t) for t in t_list])}'\n",
    "\n",
    "# # if normalize:\n",
    "# #     test_name = f'Normalized{test_name}'\n",
    "# # t_name = f'multiscale_t_{\"_\".join([str(t) for t in t_list])}'\n",
    "# test_name = maggn.model_name\n",
    "\n",
    "# # 1. Training Losses\n",
    "# plot_training_losses(loss_G_list, folder_name = folder_name, name=test_name, step_name=step_name)\n",
    "\n",
    "# # 2.a) Generator Gradient Norms\n",
    "# plot_generator_grad_norms(\n",
    "#     generator_grad_norm_list_mag_t_scheduler, folder_name, name=test_name, step_name=step_name)\n",
    "\n",
    "# # 2.b) Generator Gradient Norms different visualizations\n",
    "# plot_generator_grad_norms(\n",
    "#     generator_grad_norm_list_mag_t_scheduler, folder_name, name=test_name, visualization='log', step_name=step_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
