{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e067d815",
   "metadata": {},
   "outputs": [],
   "source": [
    "from MagGN_CelebA import *\n",
    "from MagGN_CelebA import GAN as MagGN_CelebAGAN\n",
    "from GANs_CelebA import *\n",
    "from GANs_CelebA import GAN as Standard_CelebAGAN\n",
    "from Inception_score import *\n",
    "from args_GAN_CelebA import get_parser\n",
    "from plot_gan_training import *\n",
    "import time\n",
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76efa8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = get_parser()\n",
    "opt = parser.parse_args(args=[])  \n",
    "\n",
    "img_size = opt.img_size\n",
    "channels = opt.channels\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "latent_dim = opt.latent_dim\n",
    "n_epochs = opt.n_epochs\n",
    "n_critic = opt.n_critic\n",
    "lambda_gp = opt.lambda_gp\n",
    "clip_value = opt.clip_value\n",
    "batch_size = opt.batch_size\n",
    "lr = opt.lr\n",
    "step = opt.step\n",
    "normalize = opt.normalize\n",
    "\n",
    "# Instantiate WGAN\n",
    "dataset_name='CELEBA'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06df53d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute metrics (you already have this)\n",
    "real_images = load_celeba_real_images(\n",
    "    num_images=10000,  # How many real images to use\n",
    "    img_size=64,       # Must match your generated images\n",
    "    batch_size=64,\n",
    "    device=device\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcec277",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing CELEBA dataset\n",
      "Error preparing CelebA dataset: [Errno 45] Operation not supported: '/home/s2670758'\n",
      "Generator loaded from MagGN_multiscale/CELEBA_feature_space_NormMagGN_yogesh-t_multiscale-0.0004_0.01_0.08_0.3-epochs-1_81_201_281_with_normalized_loss_with_avgpool_8_hybridcoeff_0.2/400epochs/generator.pth.\n",
      "Loss and gradient norms loaded from MagGN_multiscale/CELEBA_feature_space_NormMagGN_yogesh-t_multiscale-0.0004_0.01_0.08_0.3-epochs-1_81_201_281_with_normalized_loss_with_avgpool_8_hybridcoeff_0.2/400epochs/loss_lists.h5.\n"
     ]
    }
   ],
   "source": [
    "# name='MagGN_improved'\n",
    "t_list = [0.01, 0.3, 2.0, 10.0]\n",
    "epoch_list = [1, 81, 201, 281]\n",
    "hybrid_coeff = 0\n",
    "name='MagGN_conv'\n",
    "maggn = MagGN_CelebAGAN(\n",
    "batch_size=batch_size,\n",
    "lr=lr,\n",
    "latent_dim=latent_dim,\n",
    "img_size=img_size,\n",
    "channels=channels,\n",
    "step=step,\n",
    "device=device,\n",
    "name=name,\n",
    "dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "loss_G_list, generator_grad_norm_list_mag_t_scheduler, maggn_gen_data = maggn.train_MagGN_multiscale(\n",
    "n_epochs=n_epochs,\n",
    "batch_size=batch_size,\n",
    "normalize=normalize,\n",
    "t_list=t_list,\n",
    "epoch_list=epoch_list,\n",
    "loss_normalize=True,\n",
    "feature_space=True,\n",
    "avg_pool_size=8,\n",
    "hybrid_coeff =hybrid_coeff\n",
    ")\n",
    "\n",
    "print(f'MagGN t-list:{t_list}, avgpool:{8}, hybrid_coeff:{hybrid_coeff}')\n",
    "maggn_results = compute_inception_scores_CelebA(maggn_gen_data, real_images=real_images, device=device)\n",
    "print(f\"MagGN IS: {maggn_results['inception_score_mean']:.3f} ± {maggn_results['inception_score_std']:.3f}\")\n",
    "if 'fid' in maggn_results:\n",
    "    print(f\"MagGN FID: {maggn_results['fid']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5f106ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# name='MagGN_improved'\n",
    "t_list = [0.001, 0.03, 0.2, 1.0]\n",
    "epoch_list = [1, 81, 201, 281]\n",
    "\n",
    "hybrid_coeff = 0.2\n",
    "name='MagGN_conv'\n",
    "maggn = MagGN_CelebAGAN(\n",
    "batch_size=batch_size,\n",
    "lr=lr,\n",
    "latent_dim=latent_dim,\n",
    "img_size=img_size,\n",
    "channels=channels,\n",
    "step=step,\n",
    "device=device,\n",
    "name=name,\n",
    "dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "loss_G_list, generator_grad_norm_list_mag_t_scheduler, maggn_gen_data = maggn.train_MagGN_multiscale(\n",
    "n_epochs=n_epochs,\n",
    "batch_size=batch_size,\n",
    "normalize=normalize,\n",
    "t_list=t_list,\n",
    "epoch_list=epoch_list,\n",
    "loss_normalize=True,\n",
    "feature_space=True,\n",
    "avg_pool_size=8,\n",
    "hybrid_coeff =hybrid_coeff\n",
    ")\n",
    "\n",
    "print(f'MagGN t-list:{t_list}, avgpool:{8}, hybrid_coeff:{hybrid_coeff}')\n",
    "maggn_results = compute_inception_scores_CelebA(maggn_gen_data, real_images=real_images, device=device)\n",
    "print(f\"MagGN IS: {maggn_results['inception_score_mean']:.3f} ± {maggn_results['inception_score_std']:.3f}\")\n",
    "if 'fid' in maggn_results:\n",
    "    print(f\"MagGN FID: {maggn_results['fid']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4306ddb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# name='MagGN_improved'\n",
    "t_list = [0.0003, 0.009, 0.06, 0.3]\n",
    "epoch_list = [1, 81, 201, 281]\n",
    "\n",
    "hybrid_coeff = 0.2\n",
    "name='MagGN_conv'\n",
    "maggn = MagGN_CelebAGAN(\n",
    "batch_size=batch_size,\n",
    "lr=lr,\n",
    "latent_dim=latent_dim,\n",
    "img_size=img_size,\n",
    "channels=channels,\n",
    "step=step,\n",
    "device=device,\n",
    "name=name,\n",
    "dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "loss_G_list, generator_grad_norm_list_mag_t_scheduler, maggn_gen_data = maggn.train_MagGN_multiscale(\n",
    "n_epochs=n_epochs,\n",
    "batch_size=batch_size,\n",
    "normalize=normalize,\n",
    "t_list=t_list,\n",
    "epoch_list=epoch_list,\n",
    "loss_normalize=True,\n",
    "feature_space=True,\n",
    "avg_pool_size=8,\n",
    "hybrid_coeff =hybrid_coeff\n",
    ")\n",
    "\n",
    "print(f'MagGN t-list:{t_list}, avgpool:{8}, hybrid_coeff:{hybrid_coeff}')\n",
    "maggn_results = compute_inception_scores_CelebA(maggn_gen_data, real_images=real_images, device=device)\n",
    "print(f\"MagGN IS: {maggn_results['inception_score_mean']:.3f} ± {maggn_results['inception_score_std']:.3f}\")\n",
    "if 'fid' in maggn_results:\n",
    "    print(f\"MagGN FID: {maggn_results['fid']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec8e320",
   "metadata": {},
   "outputs": [],
   "source": [
    "# name='MagGN_improved'\n",
    "t_list = [0.0003, 0.009, 0.06, 0.3]\n",
    "epoch_list = [1, 81, 201, 281]\n",
    "\n",
    "hybrid_coeff = 0\n",
    "name='MagGN_conv'\n",
    "maggn = MagGN_CelebAGAN(\n",
    "batch_size=batch_size,\n",
    "lr=lr,\n",
    "latent_dim=latent_dim,\n",
    "img_size=img_size,\n",
    "channels=channels,\n",
    "step=step,\n",
    "device=device,\n",
    "name=name,\n",
    "dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "loss_G_list, generator_grad_norm_list_mag_t_scheduler, maggn_gen_data = maggn.train_MagGN_multiscale(\n",
    "n_epochs=n_epochs,\n",
    "batch_size=batch_size,\n",
    "normalize=normalize,\n",
    "t_list=t_list,\n",
    "epoch_list=epoch_list,\n",
    "loss_normalize=True,\n",
    "feature_space=True,\n",
    "avg_pool_size=8,\n",
    "hybrid_coeff =hybrid_coeff\n",
    ")\n",
    "\n",
    "print(f'MagGN t-list:{t_list}, avgpool:{8}, hybrid_coeff:{hybrid_coeff}')\n",
    "maggn_results = compute_inception_scores_CelebA(maggn_gen_data, real_images=real_images, device=device)\n",
    "print(f\"MagGN IS: {maggn_results['inception_score_mean']:.3f} ± {maggn_results['inception_score_std']:.3f}\")\n",
    "if 'fid' in maggn_results:\n",
    "    print(f\"MagGN FID: {maggn_results['fid']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "61e010a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77064344",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing CELEBA dataset\n",
      "Error preparing CelebA dataset: [Errno 45] Operation not supported: '/home/s2670758'\n",
      "Critic loaded from GAN/CELEBA_WGAN-GP_yogesh_tuned_parameteres/400epochs/critic.pth.\n",
      "Generator loaded from GAN/CELEBA_WGAN-GP_yogesh_tuned_parameteres/400epochs/generator.pth.\n",
      "Loss and gradient norms loaded from GAN/CELEBA_WGAN-GP_yogesh_tuned_parameteres/400epochs/loss_lists.h5.\n"
     ]
    }
   ],
   "source": [
    "name='WGAN-GP_conv_tuned_parameteres'\n",
    "\n",
    "wgan_gp = Standard_CelebAGAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    n_critic=n_critic,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "# Run training\n",
    "start_time = time.time()\n",
    "loss_C_list, loss_G_list, generator_grad_norm_list_wgan_gp, wgan_gp_gen_data = wgan_gp.train_WGAN_GP(\n",
    "    n_epochs=n_epochs,\n",
    "    n_critic=n_critic,\n",
    "    batch_size=batch_size,\n",
    "    lambda_gp=lambda_gp\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85c65afe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing CELEBA dataset\n",
      "Error preparing CelebA dataset: [Errno 45] Operation not supported: '/home/s2670758'\n",
      "Critic loaded from GAN/CELEBA_WGAN_yogesh_tuned_parameteres/400epochs/critic.pth.\n",
      "Generator loaded from GAN/CELEBA_WGAN_yogesh_tuned_parameteres/400epochs/generator.pth.\n",
      "Loss and gradient norms loaded from GAN/CELEBA_WGAN_yogesh_tuned_parameteres/400epochs/loss_lists.h5.\n"
     ]
    }
   ],
   "source": [
    "name='WGAN_conv_tuned_parameteres'\n",
    "\n",
    "wgan = Standard_CelebAGAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    n_critic=n_critic,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name\n",
    ")\n",
    "\n",
    "# Run training\n",
    "start_time = time.time()\n",
    "loss_C_list, loss_G_list, generator_grad_norm_list_wgan, wgan_gen_data = wgan.train_WGAN(\n",
    "    n_epochs=n_epochs,\n",
    "    n_critic=n_critic,\n",
    "    batch_size=batch_size,\n",
    "    clip_value=clip_value\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af20251",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "real_images = load_celeba_real_images(\n",
    "    num_images=10000,  # How many real images to use\n",
    "    img_size=64,       # Must match your generated images\n",
    "    batch_size=64,\n",
    "    device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "445ddb0e",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'shape'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mAttributeError\u001b[39m                            Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m maggn_results = \u001b[43mcompute_inception_scores_CelebA\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m      2\u001b[39m \u001b[43m    \u001b[49m\u001b[43mmaggn_gen_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m      3\u001b[39m \u001b[43m    \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdevice\u001b[49m\n\u001b[32m      4\u001b[39m \u001b[43m)\u001b[49m\n\u001b[32m      5\u001b[39m wgan_gp_results = compute_inception_scores_CelebA(\n\u001b[32m      6\u001b[39m     wgan_gp_gen_data,\n\u001b[32m      7\u001b[39m     device=device\n\u001b[32m      8\u001b[39m )\n\u001b[32m      9\u001b[39m wgan_results = compute_inception_scores_CelebA(\n\u001b[32m     10\u001b[39m     wgan_gen_data,\n\u001b[32m     11\u001b[39m     device=device\n\u001b[32m     12\u001b[39m )\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/path/Magnitude-Distance/CelebA_Experiment/Inception_score.py:163\u001b[39m, in \u001b[36mcompute_inception_scores_CelebA\u001b[39m\u001b[34m(generated_images, real_images, device, batch_size)\u001b[39m\n\u001b[32m    160\u001b[39m     generated_images = torch.from_numpy(generated_images).float()\n\u001b[32m    162\u001b[39m \u001b[38;5;66;03m# Ensure 4D tensor\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m163\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[43mgenerated_images\u001b[49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m) == \u001b[32m5\u001b[39m:\n\u001b[32m    164\u001b[39m     generated_images = generated_images.reshape(-\u001b[32m1\u001b[39m, *generated_images.shape[\u001b[32m2\u001b[39m:])\n\u001b[32m    166\u001b[39m \u001b[38;5;66;03m# Load model\u001b[39;00m\n",
      "\u001b[31mAttributeError\u001b[39m: 'list' object has no attribute 'shape'"
     ]
    }
   ],
   "source": [
    "wgan_gp_results = compute_inception_scores_CelebA(wgan_gp_gen_data, real_images=real_images, device=device)\n",
    "wgan_results = compute_inception_scores_CelebA(wgan_gen_data, real_images=real_images, device=device)\n",
    "\n",
    "# Print results (use the _results variables, not _gen_data)\n",
    "print(f\"WGAN-GP IS: {wgan_gp_results['inception_score_mean']:.3f} ± {wgan_gp_results['inception_score_std']:.3f}\")\n",
    "print(f\"WGAN IS: {wgan_results['inception_score_mean']:.3f} ± {wgan_results['inception_score_std']:.3f}\")\n",
    "\n",
    "# If you computed FID:\n",
    "if 'fid' in maggn_results:\n",
    "    print(f\"WGAN-GP FID: {wgan_gp_results['fid']:.3f}\")\n",
    "    print(f\"WGAN FID: {wgan_results['fid']:.3f}\")\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
