{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e067d815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project modules first, so the explicit imports below cannot be shadowed by a\n",
    "# name re-exported through a wildcard import.\n",
    "from GANs_CelebA import *\n",
    "from args_GAN_CelebA import get_parser\n",
    "from plot_gan_training import *\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76efa8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse only the parser defaults: inside Jupyter, sys.argv holds the kernel's own\n",
    "# launch arguments, which argparse would otherwise try to interpret.\n",
    "parser = get_parser()\n",
    "opt = parser.parse_args(args=[])\n",
    "\n",
    "# GAN.__init__ reads the CelebA location from $DATA_ROOT and otherwise falls back to a\n",
    "# hard-coded absolute home-directory path, which cannot be created on other machines\n",
    "# (OSError). setdefault never overrides a DATA_ROOT that is already exported.\n",
    "# TODO confirm: assumes the notebook runs from CelebA_Experiment/, so this resolves to\n",
    "# the repository-level celeba_data/ folder.\n",
    "os.environ.setdefault(\"DATA_ROOT\", os.path.abspath(os.path.join(\"..\", \"celeba_data\")))\n",
    "\n",
    "img_size = opt.img_size\n",
    "channels = opt.channels\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "latent_dim = opt.latent_dim\n",
    "n_critic = opt.n_critic\n",
    "n_epochs = opt.n_epochs\n",
    "batch_size = opt.batch_size\n",
    "# Hard-coded learning rate: deliberately overrides the parser default opt.lr.\n",
    "lr = 1e-4\n",
    "clip_value = opt.clip_value\n",
    "lambda_gp = opt.lambda_gp\n",
    "step = opt.step\n",
    "\n",
    "# Seed torch's global RNG (CPU and all CUDA devices) so torch-driven randomness is\n",
    "# repeatable across runs; some GPU kernels may still be nondeterministic.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Magnitude-overlap analysis parameters (for compute_magnitude_overlap)\n",
    "max_t = 10\n",
    "min_t = 0\n",
    "steps = 100\n",
    "num_samples = 100\n",
    "overlap_normalize = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ce5766dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output locations and plot labelling (the models themselves are built in later cells).\n",
    "dataset_name = 'CELEBA'\n",
    "# Alternative location without the subfolder: f'GAN_{dataset_name}_results'\n",
    "folder_name = f'GAN_{dataset_name}_results/Paper_Results'\n",
    "os.makedirs(folder_name, exist_ok=True)\n",
    "# Multiples of `step` from `step` up to and including n_epochs.\n",
    "epochs_to_plot = list(range(step, n_epochs + 1, step))\n",
    "# x-axis label passed to the plotting helpers ('Epoch' is the alternative).\n",
    "step_name = 'Step'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcec277",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Earlier run names: 'WGAN_improved', 'WGAN_conv'.\n",
    "# NOTE(review): 'parameteres' is misspelled, but `name` is passed to GAN(name=...) and\n",
    "# flows into test_name/pdf_path below, so renaming it changes output filenames.\n",
    "name = 'WGAN_conv_tuned_parameteres'\n",
    "\n",
    "wgan = GAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    n_critic=n_critic,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "# Run training; perf_counter is a monotonic clock intended for measuring durations.\n",
    "start_time = time.perf_counter()\n",
    "loss_C_list, loss_G_list, generator_grad_norm_list_wgan, wgan_gen_data = wgan.train_WGAN(\n",
    "    n_epochs=n_epochs,\n",
    "    n_critic=n_critic,\n",
    "    batch_size=batch_size,\n",
    "    clip_value=clip_value\n",
    ")\n",
    "elapsed_time = time.perf_counter() - start_time\n",
    "print(f\"{name} Training time: {elapsed_time:.2f} seconds\", flush=True)\n",
    "# Optional magnitude-overlap analysis:\n",
    "# wgan.compute_magnitude_overlap(n_epochs, max_t=max_t, min_t=min_t, steps=steps, num_samples=num_samples, normalize=overlap_normalize)\n",
    "\n",
    "test_name = f'{name}-test{n_epochs}epochs'\n",
    "pdf_path = f\"{folder_name}/{test_name}_plots.pdf\"\n",
    "\n",
    "# 1. Training losses\n",
    "plot_training_losses(loss_G_list, loss_C_list, folder_name, pdf_path, name=test_name, step_name=step_name)\n",
    "\n",
    "# 2.a) Generator gradient norms (default visualization)\n",
    "plot_generator_grad_norms(\n",
    "    generator_grad_norm_list_wgan, folder_name, pdf_path, name=test_name, step_name=step_name)\n",
    "\n",
    "# 2.b) Generator gradient norms with visualization='log'\n",
    "plot_generator_grad_norms(\n",
    "    generator_grad_norm_list_wgan, folder_name, pdf_path, name=test_name, visualization='log', step_name=step_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae0a37bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WGAN-GP run, kept for comparison with the WGAN cell above. Disabled by default so a\n",
    "# Restart & Run All trains only one model; set RUN_WGAN_GP = True to train it.\n",
    "RUN_WGAN_GP = False\n",
    "\n",
    "if RUN_WGAN_GP:\n",
    "    name = 'WGAN-GP_improved'\n",
    "\n",
    "    wgan_gp = GAN(\n",
    "        batch_size=batch_size,\n",
    "        lr=lr,\n",
    "        latent_dim=latent_dim,\n",
    "        img_size=img_size,\n",
    "        channels=channels,\n",
    "        n_critic=n_critic,\n",
    "        step=step,\n",
    "        device=device,\n",
    "        name=name,\n",
    "        dataset_name=dataset_name\n",
    "    )\n",
    "\n",
    "    # Train with gradient penalty; the *_wgan_gp names avoid overwriting the WGAN\n",
    "    # cell's loss lists. perf_counter is a monotonic clock intended for durations.\n",
    "    start_time = time.perf_counter()\n",
    "    loss_C_list_wgan_gp, loss_G_list_wgan_gp, generator_grad_norm_list_wgan_gp, wgan_gp_gen_data = wgan_gp.train_WGAN_GP(\n",
    "        n_epochs=n_epochs,\n",
    "        n_critic=n_critic,\n",
    "        batch_size=batch_size,\n",
    "        lambda_gp=lambda_gp\n",
    "    )\n",
    "    elapsed_time = time.perf_counter() - start_time\n",
    "    print(f\"{name} Training time: {elapsed_time:.2f} seconds\", flush=True)\n",
    "    # Optional magnitude-overlap analysis:\n",
    "    # wgan_gp.compute_magnitude_overlap(n_epochs, max_t=max_t, min_t=min_t, steps=steps, num_samples=num_samples, normalize=overlap_normalize)\n",
    "\n",
    "    test_name = f'{name}-test{n_epochs}epochs'\n",
    "    pdf_path = f\"{folder_name}/{test_name}_plots.pdf\"\n",
    "\n",
    "    # 1. Training losses\n",
    "    plot_training_losses(loss_G_list_wgan_gp, loss_C_list_wgan_gp, folder_name, pdf_path, name=test_name, step_name=step_name)\n",
    "\n",
    "    # 2.a) Generator gradient norms (default visualization)\n",
    "    plot_generator_grad_norms(\n",
    "        generator_grad_norm_list_wgan_gp, folder_name, pdf_path, name=test_name, step_name=step_name)\n",
    "\n",
    "    # 2.b) Generator gradient norms with visualization='log'\n",
    "    plot_generator_grad_norms(\n",
    "        generator_grad_norm_list_wgan_gp, folder_name, pdf_path, name=test_name, visualization='log', step_name=step_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
