{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e067d815",
   "metadata": {},
   "outputs": [],
   "source": [
    "from MagGN_Cifar import *\n",
    "from MagGN_Cifar import GAN as MagGN_CifarGAN\n",
    "from GANs_Cifar import *\n",
    "from GANs_Cifar import GAN as Standard_CifarGAN\n",
    "from Inception_score import *\n",
    "from args_GAN_Cifar import get_parser\n",
    "from plot_gan_training import *\n",
    "\n",
    "sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))\n",
    "from model_loader import ModelLoader\n",
    "\n",
    "import time\n",
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76efa8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = get_parser()\n",
    "opt = parser.parse_args(args=[])  \n",
    "\n",
    "img_size = opt.img_size\n",
    "channels = opt.channels\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "latent_dim = opt.latent_dim\n",
    "n_epochs = opt.n_epochs\n",
    "n_critic = opt.n_critic\n",
    "lambda_gp = opt.lambda_gp\n",
    "clip_value = opt.clip_value\n",
    "batch_size = opt.batch_size\n",
    "lr = opt.lr\n",
    "opt_betas = opt.opt_betas\n",
    "step = opt.step\n",
    "normalize = opt.normalize\n",
    "\n",
    "# Instantiate WGAN\n",
    "dataset_name='Cifar10'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcec277",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing CELEBA dataset\n",
      "Error preparing CelebA dataset: [Errno 45] Operation not supported: '/home/s2670758'\n",
      "Generator loaded from MagGN_multiscale/CELEBA_feature_space_NormMagGN_conv-t_multiscale-0.0004_0.01_0.08_0.3-epochs-1_81_201_281_with_normalized_loss_with_avgpool_8_hybridcoeff_0.2/400epochs/generator.pth.\n",
      "Loss and gradient norms loaded from MagGN_multiscale/CELEBA_feature_space_NormMagGN_conv-t_multiscale-0.0004_0.01_0.08_0.3-epochs-1_81_201_281_with_normalized_loss_with_avgpool_8_hybridcoeff_0.2/400epochs/loss_lists.h5.\n"
     ]
    }
   ],
   "source": [
    "# Compute metrics (you already have this)\n",
    "real_images = load_cifar10_real_images(\n",
    "    num_images=3000,  # How many real images to use\n",
    "    img_size=img_size,       # Must match your generated images\n",
    "    batch_size=batch_size,\n",
    "    device=device,\n",
    "    class_label=3\n",
    ")\n",
    "\n",
    "\n",
    "name='WGAN-GP_conv'\n",
    "pretrained_generator_path = f'./GAN/Cifar10_WGAN-GP_conv_pretrained_with_freezing_Cifar10_NormMagGN_conv-t_multiscale-0.5_1.5_3.0_5.0-epochs-1_201_501_701_with_normalized_loss_1000epochs/500epochs/generator.pth'\n",
    "lr = 5e-5  # Half of original 1e-4\n",
    "\n",
    "wgan_gp_pretrained = Standard_CifarGAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    opt_betas=opt_betas,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    n_critic=n_critic,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name,\n",
    "    pretrained_generator_path = pretrained_generator_path\n",
    ")\n",
    "\n",
    "# wgan_gp_pretrained.model_name = 'Cifar10_WGAN-GP_conv_pretrained_with_freezing_Cifar10_NormMagGN_conv-t_multiscale-0.5_1.5_3.0_5.0-epochs-1_201_501_701_with_normalized_loss_1000epochs'\n",
    "wgan_gp_pretrained.load_pretrained_generator(pretrained_generator_path, freeze_generator=True)\n",
    "wgan_gp_pretrained.model_loader = ModelLoader(dataset_name, wgan_gp_pretrained.model_name, device, step)\n",
    "wgan_gp_pretrained.model_loader.set_paths(n_epochs)\n",
    "wgan_gp_pretrained_gen_data = wgan_gp_pretrained.model_loader.generated_images(wgan_gp_pretrained.G, wgan_gp_pretrained.latent_dim, wgan_gp_pretrained.device, epoch = 500, num_samples=3000, save=False, nrow=8, save_img=False)\n",
    "\n",
    "\n",
    "\n",
    "print(f'Pretrained with 0.5_1.5_3.0_5.0-epochs-1_201_501_701')\n",
    "results = compute_inception_scores_CIFAR10(\n",
    "        wgan_gp_pretrained_gen_data, \n",
    "        real_images=real_images,  # Or use 'auto' with class_label=3\n",
    "        device=device,\n",
    "        class_label=3  # Only used if real_images='auto'\n",
    "    )\n",
    "    \n",
    "print(f\"MagGN IS: {results['inception_score_mean']:.3f} ± {results['inception_score_std']:.3f}\")\n",
    "if 'fid' in results:\n",
    "    print(f\"MagGN FID: {results['fid']:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77064344",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing CELEBA dataset\n",
      "Error preparing CelebA dataset: [Errno 45] Operation not supported: '/home/s2670758'\n",
      "Critic loaded from GAN/CELEBA_WGAN-GP_conv_tuned_parameteres/400epochs/critic.pth.\n",
      "Generator loaded from GAN/CELEBA_WGAN-GP_conv_tuned_parameteres/400epochs/generator.pth.\n",
      "Loss and gradient norms loaded from GAN/CELEBA_WGAN-GP_conv_tuned_parameteres/400epochs/loss_lists.h5.\n"
     ]
    }
   ],
   "source": [
    "name='WGAN-GP_conv'\n",
    "\n",
    "wgan_gp = Standard_CifarGAN(\n",
    "    batch_size=batch_size,\n",
    "    lr=lr,\n",
    "    opt_betas=opt_betas,\n",
    "    latent_dim=latent_dim,\n",
    "    img_size=img_size,\n",
    "    channels=channels,\n",
    "    n_critic=n_critic,\n",
    "    step=step,\n",
    "    device=device,\n",
    "    name=name,\n",
    "    dataset_name=dataset_name\n",
    ")\n",
    "n_epochs = 1600\n",
    "# Run training\n",
    "start_time = time.time()\n",
    "loss_C_list, loss_G_list, generator_grad_norm_list_wgan_gp, wgan_gp_gen_data = wgan_gp.train_WGAN_GP(\n",
    "    n_epochs=n_epochs,\n",
    "    n_critic=n_critic,\n",
    "    batch_size=batch_size,\n",
    "    lambda_gp=lambda_gp\n",
    ")\n",
    "\n",
    "wgan_gp_results = compute_inception_scores_CIFAR10(\n",
    "        wgan_gp_gen_data, \n",
    "        real_images=real_images,  # Or use 'auto' with class_label=3\n",
    "        device=device,\n",
    "        class_label=3  # Only used if real_images='auto'\n",
    "    )\n",
    "\n",
    "# Print results (use the _results variables, not _gen_data)\n",
    "print(f\"WGAN-GP IS: {wgan_gp_results['inception_score_mean']:.3f} ± {wgan_gp_results['inception_score_std']:.3f}\")\n",
    "\n",
    "# If you computed FID:\n",
    "if 'fid' in wgan_gp_results:\n",
    "    print(f\"WGAN-GP FID: {wgan_gp_results['fid']:.3f}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_mag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
