{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# All third-party/stdlib imports used in this notebook.\n",
    "# Later cells re-import what they use so each cell can also be run on its own.\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import torch\n",
    "\n",
    "# Load the generator stored under the 'G_ema' key of the training snapshot.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load checkpoints you trust.\n",
    "with open('/workspace/iclr2025allimportant/ours-25000pkl/mnist-ours-25000-fid8.74.pkl', 'rb') as f:\n",
    "    G = pickle.load(f)['G_ema'].cuda()  # load the generator model\n",
    "\n",
    "z = torch.randn([1, G.z_dim]).cuda()  # random latent vector\n",
    "c = torch.tensor([[0, 1]], dtype=torch.int64).cuda()  # conditioning vector, shape: [1, 2]\n",
    "\n",
    "print(z.shape)\n",
    "print(c.shape)\n",
    "\n",
    "# Generate an image; inference only, so skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    img = G(z, c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 512])\n",
      "torch.Size([1, 2])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIw0lEQVR4nO3cv45cZxnA4Tn7x8maRE4UCEVuAKWkoEd0iCYdHTcA3AyCnpKGLi3QcgEIcQMpTKLYBLyxszuH7kcRRZw39qxnd5+nfvfTp9mZ+c1p3mVd13UHALvd7uR1XwCA4yEKAEQUAIgoABBRACCiAEBEAYCIAgA52zp4eno6Oni/348vA8B2FxcXo/lnz5793xlPCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkGVd13XT4LIc+i4AHNCWr3tPCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoA5Ox1X4BXZBnOr5M/WIeHA7eVJwUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIjdR3fFeD2RfUZHbfJzbbj3arne/ger98m940kBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgCx5uJGDfYRnAzXC1zPxrlDhmsuJqsrltPZ4eu1tRi3nScFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACI3Uc3avtemJP97OThOHfJcN3Q+z9+b/Ps4798NrwMt50nBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxJqLG3Q2aPDV76anW3TBNo//vH11xfm7b4zO/urJ8+l1ODKeFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIHYf3aCT3bp9+IeD2WPzyfbRdz54ODr6ye7Z9uFldPRuGb7k618Hwz+anX0spruMlsGLvk4+D9wYTwoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJBlXddNC0iWZbhIBu64d9/77ubZzz/99IA3OaDpdrTrg9yCV2TL170nBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxJoL7qyHbz8czT/717MD3eQWOx3O7w9yC14Ray4AGBEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDE7iNeuTd2DzbPPl9fHPAm98Pp8KN5fb599uRkePiL7fP71aKkm2b3EQAjogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQM5e9wV4PU4Hs9ebFqH8z/Od1RU36Y1ffH80//z3/9w8e306XEUx2YoxfF9xMzwpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAlnVdN20gWZbJUhNu3Plw3noiNjj/5cVo/uS3l5tnn0+/UuxKemlbvu49KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAINZc3FdWBnAIHw9mfzY8e/oTdj+cvwesuQBgRBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABC7j+4ru4++5mJ5MJq/PH+xefbNizdHZ3/59MvR/K10Ppy/Osgt7hW7jwAYEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAMTuozvidDh//cfB8EfDww/o0QfvbZ59+slns8PHL+JwfuB8t/3z9tW2j/Dx8ZVy4+w+AmBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiDUXx2r6cg83HUyOv6VLFOY/efYHucW3cjL4D+2n/6Fj+YdO/z/Hcu9bzJoLAEZEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAOXvdF+AbHHjP
y+j46bvkajh/KEe0y2hqvwz+Q+O9ZJYI8c08KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAINZc3BFv//T90fwXHz8+0E12u7M3LzbPXj2/PNg9Tnaz9Q/74fqHh4PfVM/+Nty58eFk+Jaurbil177rPCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAECWdV03bSBZltkeGV7Og2GvX6zD3ToH9OBs+0qtF9dXs8MnL8vwJZn+Qtrb3fN1/9g+uvxgdrSX++Vt+br3pABARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgGzfR8BLmxT4xa+/Mzz9i+H84by4Gqyu+M1bs8N/9e/Z/MDxLAo5Hm9/+O5o/j9//3zz7P50eJnr4TzfiicFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFADIsq7rumlwWQ59lzvvfLf9Nfxq278F5v4wmP358Oyzwe/MK9umbtqWr3tPCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiDUXN2r7a/joe2+NTn76+IvpZbgr/jQbv/jJ9vfh5elw3cr1bJybZc0FACOiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGA2H10T73z6NHm2SdPnh7wJrfY1fbRs/Ph0RfbZ5fL2dnr5KM8XH3EcbP7CIARUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQOw+OlLT13vZzeb3635y+MxgX85B31XT9+wB9/ysp7PDT663/17bL8OLb/vIcwfZfQTAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBrLgDuCWsuABgRBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoAJCzrYMXFxejgy8vL8eXAWC7B+cPXvmZnhQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACDLuq7r674EAMfBkwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCAPkvKhMtbXmweqcAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Load the generator model\n",
    "with open('/workspace/iclr2025allimportant/ours-25000pkl/mnist-ours-25000-fid8.74.pkl', 'rb') as f:\n",
    "    G = pickle.load(f)['G_ema'].cuda()  # 加载生成器模型\n",
    "\n",
    "# Generate latent vector z and conditional input c\n",
    "z = torch.randn([1, G.z_dim]).cuda()    # 随机潜在向量\n",
    "c = torch.tensor([[1,0]], dtype=torch.int64).cuda()  # Shape: [1, 1]\n",
    "\n",
    "# Print shapes\n",
    "print(z.shape)\n",
    "print(c.shape)\n",
    "\n",
    "# Generate image\n",
    "img = G(z, c)\n",
    "\n",
    "# Convert tensor to image format\n",
    "img = img.detach().cpu().numpy()  # Move to CPU and convert to numpy\n",
    "img = (img + 1) / 2  # Normalize to [0, 1] if the output is in range [-1, 1]\n",
    "\n",
    "# If the image is in (batch, channels, height, width) format, remove the batch dimension\n",
    "img = np.transpose(img[0], (1, 2, 0))  # Convert from CHW to HWC format\n",
    "\n",
    "# Display the image\n",
    "plt.imshow(img)\n",
    "plt.axis('off')  # Hide axis\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image generation complete.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "\n",
    "# Load the generator model\n",
    "with open('/workspace/iclr2025allimportant/ours-25000pkl/mnist-ours-25000-fid8.74.pkl', 'rb') as f:\n",
    "    G = pickle.load(f)['G_ema'].cuda()  # 加载生成器模型\n",
    "\n",
    "# Directories to save the images\n",
    "red_dir = '/workspace/images/red'\n",
    "green_dir = '/workspace/images/green'\n",
    "\n",
    "# Create directories if they don't exist\n",
    "os.makedirs(red_dir, exist_ok=True)\n",
    "os.makedirs(green_dir, exist_ok=True)\n",
    "\n",
    "# Function to process and save the image\n",
    "def process_image(img, lo=-1, hi=1):\n",
    "    \"\"\"\n",
    "    Process the image using the specified lo and hi values.\n",
    "    \"\"\"\n",
    "    img = np.asarray(img, dtype=np.float32)\n",
    "    img = (img - lo) * (255 / (hi - lo))  # Scale to [0, 255]\n",
    "    img = np.rint(img).clip(0, 255).astype(np.uint8)  # Clip values and convert to uint8\n",
    "    return img\n",
    "\n",
    "# Function to generate and save images\n",
    "def generate_and_save_images(num_images, condition, save_dir, label):\n",
    "    for i in range(num_images):\n",
    "        # Generate latent vector z and conditional input c\n",
    "        z = torch.randn([1, G.z_dim]).cuda()    # 随机潜在向量\n",
    "        c = torch.tensor([condition], dtype=torch.int64).cuda()  # Conditional vector\n",
    "\n",
    "        # Generate image\n",
    "        img = G(z, c)\n",
    "\n",
    "        # Convert tensor to image format\n",
    "        img = img.detach().cpu().numpy()  # Move to CPU and convert to numpy\n",
    "        img = np.transpose(img[0], (1, 2, 0))  # Convert from CHW to HWC format\n",
    "\n",
    "        # Process the image using the lo and hi values\n",
    "        processed_img = process_image(img, lo=-1, hi=1)\n",
    "\n",
    "        # Convert to PIL image\n",
    "        pil_img = Image.fromarray(processed_img)\n",
    "\n",
    "        # Save the image\n",
    "        pil_img.save(os.path.join(save_dir, f\"{label}_{i+1:05d}.png\"))\n",
    "\n",
    "# Generate 60,000 images for each category\n",
    "num_images = 60000\n",
    "\n",
    "# Generate red images (condition: c = [[0,1]])\n",
    "generate_and_save_images(num_images, [0, 1], red_dir, \"red\")\n",
    "\n",
    "# Generate green images (condition: c = [[1,0]])\n",
    "generate_and_save_images(num_images, [1, 0], green_dir, \"green\")\n",
    "\n",
    "print(\"Image generation complete.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sgan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
