{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yksKKJyWuRlm"
      },
      "outputs": [],
      "source": [
        "%pip install \"faiss-gpu-cu11[fix-cuda]\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kfQlxxojwvvz"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "# Colab registers an import hook for cv2 that can trigger a recursion error when OpenCV is\n",
        "# imported; remove every meta-path finder whose repr mentions cv2 (filtered in place).\n",
        "if \"google.colab\" in sys.modules:\n",
        "    # NOTE(review): not a documented OpenCV setting as far as we know; kept as-is, confirm it is needed.\n",
        "    os.environ[\"OPENCV_AVOID_COLORMAP\"] = \"1\"\n",
        "    sys.meta_path[:] = [hook for hook in sys.meta_path if \"cv2\" not in str(hook).lower()]\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import faiss\n",
        "\n",
        "print(f\"FAISS version: {faiss.__version__}\")\n",
        "print(f\"Number of GPUs detected: {faiss.get_num_gpus()}\")"
      ],
      "metadata": {
        "id": "wnlaM37PkG2A"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V7OMDVTduLBS"
      },
      "outputs": [],
      "source": [
        "%pip install \"private-evolution[image] @ git+https://github.com/microsoft/DPSDA.git\""
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi --list-gpus"
      ],
      "metadata": {
        "id": "ZYVFKph9Ywo1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T7W9ZRtgvYQy"
      },
      "outputs": [],
      "source": [
        "# from pe.data.image import Cifar10\n",
        "from pe.logging import setup_logging\n",
        "#from pe.runner import PE\n",
        "from pe.population import Population\n",
        "#from pe.api.image import ImprovedDiffusion270M\n",
        "from pe.embedding.image import Inception\n",
        "from pe.histogram import NearestNeighbors\n",
        "from pe.callback import SaveCheckpoints\n",
        "from pe.callback import SampleImages\n",
        "from pe.callback import ComputeFID\n",
        "from pe.logger import ImageFile\n",
        "from pe.logger import CSVPrint\n",
        "from pe.logger import LogPrint\n",
        "\n",
        "import pandas as pd\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "pd.options.mode.copy_on_write = True"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# PE CIFAR data class"
      ],
      "metadata": {
        "id": "87hmYptkj5_m"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
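        "Redefines the `Cifar10` data class so that it loads only the first `num_per_class` images of each selected class instead of the whole dataset."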
      ],
      "metadata": {
        "id": "6YXF57-_xVJg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qoS3r6GCxIo7"
      },
      "outputs": [],
      "source": [
        "# Custom CIFAR10 data class that keeps only a few images from selected classes\n",
        "import numpy as np\n",
        "import torch\n",
        "import torchvision\n",
        "import tempfile\n",
        "import pandas as pd\n",
        "\n",
        "from pe.data import Data\n",
        "from pe.constant.data import LABEL_ID_COLUMN_NAME\n",
        "from pe.constant.data import IMAGE_DATA_COLUMN_NAME\n",
        "\n",
        "CIFAR10_LABEL_NAMES = [\"plane\"]\n",
        "\n",
        "# All ten CIFAR10 class names, ordered like torchvision's integer targets 0-9.\n",
        "_CIFAR10_ALL_LABEL_NAMES = [\n",
        "    \"plane\",\n",
        "    \"car\",\n",
        "    \"bird\",\n",
        "    \"cat\",\n",
        "    \"deer\",\n",
        "    \"dog\",\n",
        "    \"frog\",\n",
        "    \"horse\",\n",
        "    \"ship\",\n",
        "    \"truck\",\n",
        "]\n",
        "\n",
        "\n",
        "class Cifar10(Data):\n",
        "    \"\"\"A subset of the CIFAR10 dataset with at most `num_per_class` images per selected class.\"\"\"\n",
        "\n",
        "    def __init__(self, cifar10class=None, split=\"train\", num_per_class=50):\n",
        "        \"\"\"Constructor.\n",
        "\n",
        "        :param cifar10class: Names of the CIFAR10 classes to keep, taken from `_CIFAR10_ALL_LABEL_NAMES`.\n",
        "            Defaults to [\"plane\"]\n",
        "        :type cifar10class: list[str], optional\n",
        "        :param split: The split to load, either \"train\" or \"test\". Defaults to \"train\"\n",
        "        :type split: str, optional\n",
        "        :param num_per_class: Maximum number of images per class (the first ones in dataset order); None keeps\n",
        "            all of them. Defaults to 50\n",
        "        :type num_per_class: int or None, optional\n",
        "        :raises ValueError: If `split` is invalid, or `cifar10class` is empty or contains an unknown name\n",
        "        \"\"\"\n",
        "        # None instead of a list literal avoids a shared mutable default argument.\n",
        "        if cifar10class is None:\n",
        "            cifar10class = [\"plane\"]\n",
        "        if split not in [\"train\", \"test\"]:\n",
        "            raise ValueError(f\"Invalid split: {split}\")\n",
        "        if len(cifar10class) == 0:\n",
        "            raise ValueError(\"cifar10class must contain at least one class name\")\n",
        "        unknown_names = [name for name in cifar10class if name not in _CIFAR10_ALL_LABEL_NAMES]\n",
        "        if unknown_names:\n",
        "            raise ValueError(f\"Unknown CIFAR10 class names: {unknown_names}\")\n",
        "        train = split == \"train\"\n",
        "\n",
        "        with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "            dataset = torchvision.datasets.CIFAR10(root=tmp_dir, train=train, download=True)\n",
        "            targets = np.asarray(dataset.targets)\n",
        "            all_data = []\n",
        "            all_targets = []\n",
        "            for label_id, name in enumerate(cifar10class):\n",
        "                # Select images by the class's real CIFAR10 index, but label them with the class's position in\n",
        "                # `cifar10class` so the label id indexes into metadata[\"label_info\"].\n",
        "                cifar10_idx = _CIFAR10_ALL_LABEL_NAMES.index(name)\n",
        "                class_indices = np.flatnonzero(targets == cifar10_idx)[:num_per_class]\n",
        "                all_data.append(dataset.data[class_indices])\n",
        "                all_targets.extend([label_id] * len(class_indices))\n",
        "\n",
        "            final_data = np.concatenate(all_data, axis=0)\n",
        "\n",
        "        data_frame = pd.DataFrame({\n",
        "            IMAGE_DATA_COLUMN_NAME: list(final_data),\n",
        "            LABEL_ID_COLUMN_NAME: all_targets,\n",
        "        })\n",
        "\n",
        "        metadata = {\"label_info\": [{\"name\": n} for n in cifar10class]}\n",
        "        super().__init__(data_frame=data_frame, metadata=metadata)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# PE variations"
      ],
      "metadata": {
        "id": "qnhRgHZn0Y35"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
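        "Instead of using a single variation degree per PE iteration, `variation_degrees` is a list of lists: entry *i* holds all the variation degrees used at iteration *i*. Each synthetic sample receives one variation per degree, so the variation API returns `len(variation_degrees[i])` times as many samples as it was given."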
      ],
      "metadata": {
        "id": "QrxSXcNrkSu-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tempfile\n",
        "import os\n",
        "\n",
        "from pe.api import API\n",
        "from pe.logging import execution_logger\n",
        "from pe.data import Data\n",
        "from pe.constant.data import IMAGE_DATA_COLUMN_NAME\n",
        "from pe.constant.data import IMAGE_MODEL_LABEL_COLUMN_NAME\n",
        "from pe.constant.data import LABEL_ID_COLUMN_NAME\n",
        "from pe.constant.data import PARENT_SYN_DATA_INDEX_COLUMN_NAME\n",
        "from pe.api.util import ConstantList\n",
        "from pe.util import download\n",
        "\n",
        "from improved_diffusion.script_util import NUM_CLASSES\n",
        "from pe.api.image.improved_diffusion_lib.unet import create_model\n",
        "from pe.api.image.improved_diffusion_lib.gaussian_diffusion import create_gaussian_diffusion\n",
        "\n",
        "\n",
        "class ImprovedDiffusion(API):\n",
        "    \"\"\"The image API that utilizes improved diffusion models from https://arxiv.org/abs/2102.09672.\n",
        "\n",
        "    Each PE iteration may use several variation degrees: every input sample then receives one variation per\n",
        "    degree.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        variation_degrees,\n",
        "        model_path,\n",
        "        model_image_size=64,\n",
        "        num_channels=192,\n",
        "        num_res_blocks=3,\n",
        "        learn_sigma=True,\n",
        "        class_cond=True,\n",
        "        use_checkpoint=False,\n",
        "        attention_resolutions=\"16,8\",\n",
        "        num_heads=4,\n",
        "        num_heads_upsample=-1,\n",
        "        use_scale_shift_norm=True,\n",
        "        dropout=0.0,\n",
        "        diffusion_steps=4000,\n",
        "        sigma_small=False,\n",
        "        noise_schedule=\"cosine\",\n",
        "        use_kl=False,\n",
        "        predict_xstart=False,\n",
        "        rescale_timesteps=False,\n",
        "        rescale_learned_sigmas=False,\n",
        "        timestep_respacing=\"100\",\n",
        "        batch_size=2000,\n",
        "        use_ddim=True,\n",
        "        clip_denoised=True,\n",
        "        use_data_parallel=True,\n",
        "    ):\n",
        "        \"\"\"Constructor.\n",
        "        See https://github.com/openai/improved-diffusion for the explanation of the parameters not listed here.\n",
        "\n",
        "        :param variation_degrees: The variation degrees utilized at each PE iteration. Entry ``i`` is either a\n",
        "            single int or a list of ints; with a list, every sample gets one variation per listed degree. If a\n",
        "            single int is provided instead of a list, that one degree is used for all iterations.\n",
        "        :type variation_degrees: int or list[int] or list[list[int]]\n",
        "        :param model_path: The path of the model checkpoint\n",
        "        :type model_path: str\n",
        "        :param diffusion_steps: The total number of diffusion steps, defaults to 4000\n",
        "        :type diffusion_steps: int, optional\n",
        "        :param timestep_respacing: The step configurations for image generation utilized at each PE iteration. If a\n",
        "            single str is provided, the same step configuration will be used for all iterations. Defaults to \"100\"\n",
        "        :type timestep_respacing: str or list[str], optional\n",
        "        :param batch_size: The batch size for image generation, defaults to 2000\n",
        "        :type batch_size: int, optional\n",
        "        :param use_data_parallel: Whether to use data parallel during image generation, defaults to True\n",
        "        :type use_data_parallel: bool, optional\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self._model = create_model(\n",
        "            image_size=model_image_size,\n",
        "            num_channels=num_channels,\n",
        "            num_res_blocks=num_res_blocks,\n",
        "            learn_sigma=learn_sigma,\n",
        "            class_cond=class_cond,\n",
        "            use_checkpoint=use_checkpoint,\n",
        "            attention_resolutions=attention_resolutions,\n",
        "            num_heads=num_heads,\n",
        "            num_heads_upsample=num_heads_upsample,\n",
        "            use_scale_shift_norm=use_scale_shift_norm,\n",
        "            dropout=dropout,\n",
        "        )\n",
        "        self._device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        self._model.load_state_dict(torch.load(model_path, map_location=\"cpu\"))\n",
        "        self._model.to(self._device)\n",
        "        self._model.eval()\n",
        "        # One diffusion process and sampler per distinct step configuration.\n",
        "        all_timestep_respacing = (\n",
        "            set(timestep_respacing) if isinstance(timestep_respacing, list) else {timestep_respacing}\n",
        "        )\n",
        "        self._timestep_respacing_to_diffusion = {}\n",
        "        self._timestep_respacing_to_sampler = {}\n",
        "        for sub_timestep_respacing in all_timestep_respacing:\n",
        "            self._timestep_respacing_to_diffusion[sub_timestep_respacing] = create_gaussian_diffusion(\n",
        "                steps=diffusion_steps,\n",
        "                learn_sigma=learn_sigma,\n",
        "                sigma_small=sigma_small,\n",
        "                noise_schedule=noise_schedule,\n",
        "                use_kl=use_kl,\n",
        "                predict_xstart=predict_xstart,\n",
        "                rescale_timesteps=rescale_timesteps,\n",
        "                rescale_learned_sigmas=rescale_learned_sigmas,\n",
        "                timestep_respacing=sub_timestep_respacing,\n",
        "            )\n",
        "            self._timestep_respacing_to_sampler[sub_timestep_respacing] = Sampler(\n",
        "                model=self._model, diffusion=self._timestep_respacing_to_diffusion[sub_timestep_respacing]\n",
        "            )\n",
        "            if use_data_parallel:\n",
        "                self._timestep_respacing_to_sampler[sub_timestep_respacing] = torch.nn.DataParallel(\n",
        "                    self._timestep_respacing_to_sampler[sub_timestep_respacing]\n",
        "                )\n",
        "        if isinstance(timestep_respacing, str):\n",
        "            self._timestep_respacing = ConstantList(timestep_respacing)\n",
        "        else:\n",
        "            self._timestep_respacing = timestep_respacing\n",
        "        self._batch_size = batch_size\n",
        "        self._use_ddim = use_ddim\n",
        "        self._image_size = model_image_size\n",
        "        self._clip_denoised = clip_denoised\n",
        "        self._class_cond = class_cond\n",
        "        if isinstance(variation_degrees, int):\n",
        "            self._variation_degrees = ConstantList(variation_degrees)\n",
        "        else:\n",
        "            self._variation_degrees = variation_degrees\n",
        "\n",
        "    def random_api(self, label_info, num_samples):\n",
        "        \"\"\"Generating random synthetic data.\n",
        "\n",
        "        :param label_info: The info of the label; its name is logged and it is stored in the returned metadata\n",
        "        :type label_info: omegaconf.dictconfig.DictConfig\n",
        "        :param num_samples: The number of random samples to generate\n",
        "        :type num_samples: int\n",
        "        :return: The data object of the generated synthetic data\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        label_name = label_info.name\n",
        "        execution_logger.info(f\"RANDOM API: creating {num_samples} samples for label {label_name}\")\n",
        "        samples, labels = sample(\n",
        "            sampler=self._timestep_respacing_to_sampler[self._timestep_respacing[0]],\n",
        "            start_t=0,\n",
        "            num_samples=num_samples,\n",
        "            batch_size=self._batch_size,\n",
        "            use_ddim=self._use_ddim,\n",
        "            image_size=self._image_size,\n",
        "            clip_denoised=self._clip_denoised,\n",
        "            class_cond=self._class_cond,\n",
        "            device=self._device,\n",
        "        )\n",
        "        samples = _round_to_uint8((samples + 1.0) * 127.5)\n",
        "        samples = samples.transpose(0, 2, 3, 1)\n",
        "        torch.cuda.empty_cache()\n",
        "        data_frame = pd.DataFrame(\n",
        "            {\n",
        "                IMAGE_DATA_COLUMN_NAME: list(samples),\n",
        "                IMAGE_MODEL_LABEL_COLUMN_NAME: list(labels),\n",
        "                LABEL_ID_COLUMN_NAME: 0,\n",
        "            }\n",
        "        )\n",
        "        metadata = {\"label_info\": [label_info]}\n",
        "        execution_logger.info(f\"RANDOM API: finished creating {num_samples} samples for label {label_name}\")\n",
        "        return Data(data_frame=data_frame, metadata=metadata)\n",
        "\n",
        "    def variation_api(self, syn_data):\n",
        "        \"\"\"Generating variations of the synthetic data.\n",
        "\n",
        "        At iteration ``i`` every input sample gets one variation per entry of ``variation_degrees[i]``, so the output\n",
        "        holds ``len(syn_data.data_frame) * len(variation_degrees[i])`` samples, grouped by degree (all samples for\n",
        "        the first degree, then all samples for the second degree, ...).\n",
        "\n",
        "        :param syn_data: The data object of the synthetic data\n",
        "        :type syn_data: :py:class:`pe.data.Data`\n",
        "        :raises ValueError: If no variation degree is given for the current iteration\n",
        "        :return: The data object of the variation of the input synthetic data\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        execution_logger.info(f\"VARIATION API: creating variations for {len(syn_data.data_frame)} samples\")\n",
        "        images = np.stack(syn_data.data_frame[IMAGE_DATA_COLUMN_NAME].values)\n",
        "        labels = np.array(syn_data.data_frame[IMAGE_MODEL_LABEL_COLUMN_NAME].values)\n",
        "        iteration = getattr(syn_data.metadata, \"iteration\", -1)\n",
        "        variation_degrees = self._variation_degrees[iteration + 1]\n",
        "        # A plain int (single degree for this iteration, or a single int passed to the constructor) is\n",
        "        # treated as a one-element list; iterating over it directly would raise a TypeError.\n",
        "        if isinstance(variation_degrees, (int, np.integer)):\n",
        "            variation_degrees = [variation_degrees]\n",
        "        num_degrees = len(variation_degrees)\n",
        "        if num_degrees == 0:\n",
        "            raise ValueError(f\"No variation degrees given for iteration {iteration + 1}\")\n",
        "        timestep_respacing = self._timestep_respacing[iteration + 1]\n",
        "\n",
        "        execution_logger.info(\n",
        "            f\"VARIATION API parameters: variation_degrees={variation_degrees}, timestep_respacing={timestep_respacing}, \"\n",
        "            f\"iteration={iteration}\"\n",
        "        )\n",
        "\n",
        "        # The inputs are identical for every degree, so convert them to model tensors once.\n",
        "        images = images.astype(np.float32) / 127.5 - 1.0\n",
        "        images = images.transpose(0, 3, 1, 2)\n",
        "        start_image = torch.Tensor(images).to(self._device)\n",
        "        model_labels = None if not self._class_cond else torch.LongTensor(labels).to(self._device)\n",
        "\n",
        "        final_vars = []\n",
        "        final_labels = []\n",
        "        for variation_degree in variation_degrees:\n",
        "            variations, _ = sample(\n",
        "                sampler=self._timestep_respacing_to_sampler[timestep_respacing],\n",
        "                start_t=variation_degree,\n",
        "                start_image=start_image,\n",
        "                labels=model_labels,\n",
        "                num_samples=images.shape[0],\n",
        "                batch_size=self._batch_size,\n",
        "                use_ddim=self._use_ddim,\n",
        "                image_size=self._image_size,\n",
        "                clip_denoised=self._clip_denoised,\n",
        "                class_cond=self._class_cond,\n",
        "                device=self._device,\n",
        "            )\n",
        "            variations = _round_to_uint8((variations + 1.0) * 127.5)\n",
        "            variations = variations.transpose(0, 2, 3, 1)\n",
        "            torch.cuda.empty_cache()\n",
        "            final_vars += list(variations)\n",
        "            final_labels += list(labels)\n",
        "\n",
        "        # Outputs are grouped by degree, so per-sample input columns are tiled once per degree.\n",
        "        data_frame = pd.DataFrame(\n",
        "            {\n",
        "                IMAGE_DATA_COLUMN_NAME: final_vars,\n",
        "                IMAGE_MODEL_LABEL_COLUMN_NAME: final_labels,\n",
        "                LABEL_ID_COLUMN_NAME: np.tile(syn_data.data_frame[LABEL_ID_COLUMN_NAME].values, num_degrees),\n",
        "            }\n",
        "        )\n",
        "        # Data created by random_api carries no parent-index column (e.g. when the population applies variations\n",
        "        # to its initial samples), so only propagate the column when the input has it.\n",
        "        if PARENT_SYN_DATA_INDEX_COLUMN_NAME in syn_data.data_frame.columns:\n",
        "            data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME] = np.tile(\n",
        "                syn_data.data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME].values, num_degrees\n",
        "            )\n",
        "        execution_logger.info(f\"VARIATION API: finished creating variations for {len(syn_data.data_frame)} samples\")\n",
        "        return Data(data_frame=data_frame, metadata=syn_data.metadata)\n",
        "\n",
        "\n",
        "def sample(\n",
        "    sampler,\n",
        "    num_samples,\n",
        "    start_t,\n",
        "    batch_size,\n",
        "    use_ddim,\n",
        "    image_size,\n",
        "    clip_denoised,\n",
        "    class_cond,\n",
        "    device,\n",
        "    start_image=None,\n",
        "    labels=None,\n",
        "):\n",
        "    \"\"\"Generate `num_samples` images with `sampler`, at most `batch_size` images per sampler call.\n",
        "\n",
        "    :param sampler: A :py:class:`Sampler`, optionally wrapped in `torch.nn.DataParallel`\n",
        "    :param num_samples: The number of images to generate\n",
        "    :param start_t: Forwarded to the sampler as `start_t` (clamped to be >= 0)\n",
        "    :param batch_size: The maximum number of images per sampler call\n",
        "    :param use_ddim: Whether to use the DDIM sampling loop\n",
        "    :param image_size: The height and width of the generated images\n",
        "    :param clip_denoised: Forwarded to the diffusion sampling loop\n",
        "    :param class_cond: Whether class labels are passed to the model as `y`\n",
        "    :param device: The device used for the noise and for randomly drawn class labels\n",
        "    :param start_image: Optional batch of images to start from, one row per generated sample\n",
        "    :param labels: Optional class labels aligned with the rows of `start_image`; drawn uniformly at random from\n",
        "        `NUM_CLASSES` classes when None\n",
        "    :raises ValueError: If `batch_size` is not positive, or `start_image` has fewer than `num_samples` rows\n",
        "    :return: The generated images as a numpy array, and their class labels (zeros when not class conditional)\n",
        "    \"\"\"\n",
        "    if batch_size <= 0:\n",
        "        raise ValueError(f\"batch_size must be positive, got {batch_size}\")\n",
        "    # Without this check the loop below would never terminate once start_image runs out of rows.\n",
        "    if start_image is not None and start_image.shape[0] < num_samples:\n",
        "        raise ValueError(\n",
        "            f\"start_image has {start_image.shape[0]} rows but {num_samples} samples were requested\"\n",
        "        )\n",
        "    all_images = []\n",
        "    all_labels = []\n",
        "    cnt = 0\n",
        "    while cnt < num_samples:\n",
        "        current_batch_size = min(batch_size, num_samples - cnt)\n",
        "        # Use one window for every per-sample input, so that the noise, start images and labels of a batch\n",
        "        # always have the same number of rows (even when start_image has more rows than num_samples).\n",
        "        batch_slice = slice(cnt, cnt + current_batch_size)\n",
        "        shape = (current_batch_size, 3, image_size, image_size)\n",
        "        model_kwargs = {}\n",
        "        if class_cond:\n",
        "            if labels is None:\n",
        "                classes = torch.randint(\n",
        "                    low=0,\n",
        "                    high=NUM_CLASSES,\n",
        "                    size=(current_batch_size,),\n",
        "                    device=device,\n",
        "                )\n",
        "            else:\n",
        "                classes = labels[batch_slice]\n",
        "            model_kwargs[\"y\"] = classes\n",
        "        batch_images = sampler(\n",
        "            clip_denoised=clip_denoised,\n",
        "            model_kwargs=model_kwargs,\n",
        "            start_t=max(start_t, 0),\n",
        "            start_image=None if start_image is None else start_image[batch_slice],\n",
        "            use_ddim=use_ddim,\n",
        "            noise=torch.randn(*shape, device=device),\n",
        "            image_size=image_size,\n",
        "        )\n",
        "        all_images.append(batch_images.detach().cpu().numpy())\n",
        "        if class_cond:\n",
        "            all_labels.append(classes.detach().cpu().numpy())\n",
        "        cnt += batch_images.shape[0]\n",
        "        execution_logger.info(f\"Created {cnt} samples\")\n",
        "\n",
        "    if not all_images:\n",
        "        # Nothing was requested (num_samples <= 0); np.concatenate cannot handle an empty list.\n",
        "        empty_labels = np.zeros(shape=(0,), dtype=np.int64 if class_cond else np.float64)\n",
        "        return np.zeros((0, 3, image_size, image_size), dtype=np.float32), empty_labels\n",
        "    all_images = np.concatenate(all_images, axis=0)[:num_samples]\n",
        "    if class_cond:\n",
        "        all_labels = np.concatenate(all_labels, axis=0)[:num_samples]\n",
        "    else:\n",
        "        all_labels = np.zeros(shape=(num_samples,))\n",
        "    return all_images, all_labels\n",
        "\n",
        "\n",
        "class Sampler(torch.nn.Module):\n",
        "    \"\"\"Runs a diffusion model's whole sampling loop inside one `forward` call.\n",
        "\n",
        "    Bundling the model and the diffusion process into a single module reduces the communication rounds between\n",
        "    GPUs when the module is wrapped in `torch.nn.DataParallel`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, diffusion):\n",
        "        super().__init__()\n",
        "        self._model = model\n",
        "        self._diffusion = diffusion\n",
        "\n",
        "    def forward(self, clip_denoised, model_kwargs, start_t, start_image, use_ddim, noise, image_size):\n",
        "        \"\"\"Sample one batch whose size is the first dimension of `noise`, using DDIM or ancestral sampling.\"\"\"\n",
        "        if use_ddim:\n",
        "            loop_fn = self._diffusion.ddim_sample_loop\n",
        "        else:\n",
        "            loop_fn = self._diffusion.p_sample_loop\n",
        "        batch_shape = (noise.shape[0], 3, image_size, image_size)\n",
        "        return loop_fn(\n",
        "            self._model,\n",
        "            batch_shape,\n",
        "            clip_denoised=clip_denoised,\n",
        "            model_kwargs=model_kwargs,\n",
        "            start_t=max(start_t, 0),\n",
        "            start_image=start_image,\n",
        "            noise=noise,\n",
        "            device=noise.device,\n",
        "        )\n",
        "\n",
        "\n",
        "def _round_to_uint8(image):\n",
        "    \"\"\"Clip `image` to [0, 255], round to the nearest integer and cast to uint8.\"\"\"\n",
        "    clipped = np.clip(image, 0, 255)\n",
        "    return np.around(clipped).astype(np.uint8)\n",
        "\n",
        "\n",
        "class ImprovedDiffusion270M(ImprovedDiffusion):\n",
        "    \"\"\"The class-conditional ImageNet-64 improved diffusion model with 270M parameters.\"\"\"\n",
        "\n",
        "    #: The URL of the checkpoint path\n",
        "    CHECKPOINT_URL = \"https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_cond_270M_250K.pt\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        variation_degrees,\n",
        "        model_path=None,\n",
        "        batch_size=2000,\n",
        "        timestep_respacing=\"100\",\n",
        "        use_data_parallel=True,\n",
        "    ):\n",
        "        \"\"\"The \"Class-conditional ImageNet-64 model (270M parameters, trained for 250K iterations)\" model from the\n",
        "        Improved Diffusion paper.\n",
        "\n",
        "        :param variation_degrees: The variation degrees utilized at each PE iteration. Entry ``i`` is either a\n",
        "            single int or a list of ints (one variation per listed degree); a single int applies to all iterations\n",
        "        :type variation_degrees: int or list[int] or list[list[int]]\n",
        "        :param model_path: The path of the model checkpoint. If not provided, or if the path does not exist, the\n",
        "            checkpoint will be downloaded from the `CHECKPOINT_URL` (to `model_path` when given, otherwise to a new\n",
        "            temporary file)\n",
        "        :type model_path: str\n",
        "        :param batch_size: The batch size for image generation, defaults to 2000\n",
        "        :type batch_size: int, optional\n",
        "        :param timestep_respacing: The step configuration for image generation, either one for all iterations or a\n",
        "            list with one per PE iteration, defaults to \"100\"\n",
        "        :type timestep_respacing: str or list[str], optional\n",
        "        :param use_data_parallel: Whether to use data parallel during image generation, defaults to True\n",
        "        :type use_data_parallel: bool, optional\n",
        "        \"\"\"\n",
        "        if model_path is None or not os.path.exists(model_path):\n",
        "            model_path = self._download_checkpoint(model_path)\n",
        "        super().__init__(\n",
        "            variation_degrees=variation_degrees,\n",
        "            model_path=model_path,\n",
        "            model_image_size=64,\n",
        "            num_channels=192,\n",
        "            num_res_blocks=3,\n",
        "            learn_sigma=True,\n",
        "            class_cond=True,\n",
        "            use_checkpoint=False,\n",
        "            attention_resolutions=\"16,8\",\n",
        "            num_heads=4,\n",
        "            num_heads_upsample=-1,\n",
        "            use_scale_shift_norm=True,\n",
        "            dropout=0.0,\n",
        "            diffusion_steps=4000,\n",
        "            sigma_small=False,\n",
        "            noise_schedule=\"cosine\",\n",
        "            use_kl=False,\n",
        "            predict_xstart=False,\n",
        "            rescale_timesteps=False,\n",
        "            rescale_learned_sigmas=False,\n",
        "            timestep_respacing=timestep_respacing,\n",
        "            batch_size=batch_size,\n",
        "            use_ddim=True,\n",
        "            clip_denoised=True,\n",
        "            use_data_parallel=use_data_parallel,\n",
        "        )\n",
        "\n",
        "    def _download_checkpoint(self, model_path):\n",
        "        \"\"\"Download the checkpoint from `CHECKPOINT_URL`.\n",
        "\n",
        "        :param model_path: The destination path, or None to download into a new temporary file\n",
        "        :type model_path: str or None\n",
        "        :return: The path of the downloaded checkpoint\n",
        "        :rtype: str\n",
        "        \"\"\"\n",
        "        execution_logger.info(f\"Downloading ImprovedDiffusion checkpoint from {self.CHECKPOINT_URL}\")\n",
        "        if model_path is None:\n",
        "            # tempfile.mktemp is deprecated and race-prone; create the file atomically and keep it after closing.\n",
        "            with tempfile.NamedTemporaryFile(suffix=\".pt\", delete=False) as tmp_file:\n",
        "                model_path = tmp_file.name\n",
        "        download(url=self.CHECKPOINT_URL, fname=model_path)\n",
        "        execution_logger.info(f\"Finished downloading ImprovedDiffusion checkpoint to {model_path}\")\n",
        "        return model_path\n"
      ],
      "metadata": {
        "id": "B3BLUEQoj89c"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# PE population function"
      ],
      "metadata": {
        "id": "JPw82Chk0TzW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "from pe.population import Population\n",
        "from pe.data import Data\n",
        "from pe.constant.data import DP_HISTOGRAM_COLUMN_NAME\n",
        "from pe.constant.data import POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME\n",
        "from pe.constant.data import PARENT_SYN_DATA_INDEX_COLUMN_NAME\n",
        "from pe.constant.data import FROM_LAST_FLAG_COLUMN_NAME\n",
        "from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME\n",
        "from pe.logging import execution_logger\n",
        "\n",
        "\n",
        "class PEPopulation(Population):\n",
        "    \"\"\"The default population algorithm for Private Evolution.\"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        api,\n",
        "        histogram_threshold=None,\n",
        "        initial_variation_api_fold=0,\n",
        "        next_variation_api_fold=1,\n",
        "        keep_selected=False,\n",
        "        selection_mode=\"sample\",\n",
        "    ):\n",
        "        \"\"\"Store the population settings and reject configurations that would produce no data.\n",
        "\n",
        "        :param api: Object exposing the random and variation APIs\n",
        "        :type api: :py:class:`pe.api.API`\n",
        "        :param histogram_threshold: Threshold for clipping the DP histogram; None disables clipping. Defaults to\n",
        "            None\n",
        "        :type histogram_threshold: float, optional\n",
        "        :param initial_variation_api_fold: How many variation rounds to apply to the initial synthetic data.\n",
        "            Defaults to 0\n",
        "        :type initial_variation_api_fold: int, optional\n",
        "        :param next_variation_api_fold: How many variation rounds to apply to the next synthetic data. Defaults to 1\n",
        "        :type next_variation_api_fold: int, optional\n",
        "        :param keep_selected: Whether the selected data is kept in the next synthetic data. Defaults to False\n",
        "        :type keep_selected: bool, optional\n",
        "        :param selection_mode: How data is selected: \"sample\" (random sampling proportional to the histogram) or\n",
        "            \"rank\" (take the top samples according to the histogram). Defaults to \"sample\"\n",
        "        :type selection_mode: str, optional\n",
        "        :raises ValueError: If next_variation_api_fold is 0 and keep_selected is False\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self._api = api\n",
        "        self._selection_mode = selection_mode\n",
        "        self._histogram_threshold = histogram_threshold\n",
        "        self._keep_selected = keep_selected\n",
        "        self._initial_variation_api_fold = initial_variation_api_fold\n",
        "        self._next_variation_api_fold = next_variation_api_fold\n",
        "        # With no variations and nothing kept, the next generation would be empty.\n",
        "        produces_nothing = self._next_variation_api_fold == 0 and not self._keep_selected\n",
        "        if produces_nothing:\n",
        "            raise ValueError(\n",
        "                \"next_variation_api_fold should be greater than 0 or keep_selected should be True. Otherwise, next \"\n",
        "                \"synthetic data will be empty.\"\n",
        "            )\n",
        "\n",
        "    def initial(self, label_info, num_samples):\n",
        "        \"\"\"Create the initial population for one label.\n",
        "\n",
        "        The random API produces `num_samples` samples (tagged with fold id -1); the variation API is then applied\n",
        "        `initial_variation_api_fold` times to that same random data, each result tagged with its fold id.\n",
        "\n",
        "        :param label_info: The label info\n",
        "        :type label_info: omegaconf.dictconfig.DictConfig\n",
        "        :param num_samples: The number of random samples to generate\n",
        "        :type num_samples: int\n",
        "        :return: The concatenation of the random data and all of its variations\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        sample_desc = f\"{num_samples}*{self._initial_variation_api_fold + 1}\"\n",
        "        label_name = label_info.name\n",
        "        execution_logger.info(f\"Population: generating {sample_desc} initial synthetic samples for label {label_name}\")\n",
        "        random_data = self._api.random_api(label_info=label_info, num_samples=num_samples)\n",
        "        random_data.data_frame[VARIATION_API_FOLD_ID_COLUMN_NAME] = -1\n",
        "        parts = [random_data]\n",
        "        for fold_id in range(self._initial_variation_api_fold):\n",
        "            fold_data = self._api.variation_api(syn_data=random_data)\n",
        "            fold_data.data_frame[VARIATION_API_FOLD_ID_COLUMN_NAME] = fold_id\n",
        "            parts.append(fold_data)\n",
        "        data = Data.concat(parts)\n",
        "        execution_logger.info(\n",
        "            f\"Population: finished generating {sample_desc} initial synthetic samples for label {label_name}\"\n",
        "        )\n",
        "        return data\n",
        "\n",
        "    def _post_process_histogram(self, syn_data):\n",
        "        \"\"\"Post process the histogram of synthetic data (e.g., clipping).\n",
        "\n",
        "        :param syn_data: The synthetic data\n",
        "        :type syn_data: :py:class:`pe.data.Data`\n",
        "        :return: The synthetic data with post-processed histogram in the column\n",
        "            :py:const:`pe.constant.data.POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME`\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        count = syn_data.data_frame[DP_HISTOGRAM_COLUMN_NAME].to_numpy()\n",
        "        if self._histogram_threshold is not None:\n",
        "            clipped_count = np.clip(count, a_min=self._histogram_threshold, a_max=None)\n",
        "            clipped_count -= self._histogram_threshold\n",
        "        else:\n",
        "            clipped_count = count\n",
        "        syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME] = clipped_count\n",
        "        return syn_data\n",
        "\n",
        "    def _select_data(self, syn_data, num_samples):\n",
        "        \"\"\"Select data from the synthetic data according to `selection_mode`.\n",
        "\n",
        "        :param syn_data: The synthetic data\n",
        "        :type syn_data: :py:class:`pe.data.Data`\n",
        "        :param num_samples: The number of samples to select\n",
        "        :type num_samples: int\n",
        "        :raises ValueError: If the selection mode is not supported\n",
        "        :return: The selected data\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        if self._selection_mode == \"sample\":\n",
        "            count = syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME].to_numpy()\n",
        "            prob = count / count.sum()\n",
        "            indices = np.random.choice(len(syn_data.data_frame), size=num_samples, p=prob)\n",
        "            new_data_frame = syn_data.data_frame.iloc[indices]\n",
        "            new_data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME] = syn_data.data_frame.index[indices]\n",
        "            return Data(data_frame=new_data_frame, metadata=syn_data.metadata)\n",
        "        elif self._selection_mode == \"rank\":\n",
        "            count = syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME].to_numpy()\n",
        "            indices = np.argsort(count)[::-1][:num_samples]\n",
        "            new_data_frame = syn_data.data_frame.iloc[indices]\n",
        "            new_data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME] = syn_data.data_frame.index[indices]\n",
        "            return Data(data_frame=new_data_frame, metadata=syn_data.metadata)\n",
        "        else:\n",
        "            raise ValueError(f\"Selection mode {self._selection_mode} is not supported\")\n",
        "\n",
        "    def next(self, syn_data, num_samples):\n",
        "        \"\"\"Generate the next synthetic data.\n",
        "\n",
        "        :param syn_data: The synthetic data\n",
        "        :type syn_data: :py:class:`pe.data.Data`\n",
        "        :param num_samples: The number of samples to generate\n",
        "        :type num_samples: int\n",
        "        :return: The next synthetic data\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        execution_logger.info(\n",
        "            f\"Population: generating {num_samples}*{self._next_variation_api_fold} \" \"next synthetic samples\"\n",
        "        )\n",
        "        syn_data = self._post_process_histogram(syn_data)\n",
        "        selected_data = self._select_data(syn_data, num_samples)\n",
        "        selected_data.data_frame[FROM_LAST_FLAG_COLUMN_NAME] = 1\n",
        "        selected_data.data_frame[VARIATION_API_FOLD_ID_COLUMN_NAME] = -1\n",
        "        variation_data_list = []\n",
        "        for variation_api_fold_id in range(self._next_variation_api_fold):\n",
        "            variation_data = self._api.variation_api(syn_data=selected_data)\n",
        "            # variation_data.data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME] = selected_data.data_frame[\n",
        "            #     PARENT_SYN_DATA_INDEX_COLUMN_NAME\n",
        "            # ].values\n",
        "            variation_data.data_frame[FROM_LAST_FLAG_COLUMN_NAME] = 0\n",
        "            variation_data.data_frame[VARIATION_API_FOLD_ID_COLUMN_NAME] = variation_api_fold_id\n",
        "            variation_data_list.append(variation_data)\n",
        "        new_syn_data = Data.concat(variation_data_list + ([selected_data] if self._keep_selected else []))\n",
        "        execution_logger.info(\n",
        "            f\"Population: finished generating {num_samples}*{self._next_variation_api_fold} \" \"next synthetic samples\"\n",
        "        )\n",
        "        return new_syn_data"
      ],
      "metadata": {
        "id": "nUsVMGsCxgPe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# PE run function"
      ],
      "metadata": {
        "id": "kgWnuu010RfJ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "changes the PE run"
      ],
      "metadata": {
        "id": "oA_BnskBxelu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "from pe.dp import Gaussian\n",
        "from pe.data import Data\n",
        "from pe.constant.data import LABEL_ID_COLUMN_NAME\n",
        "from pe.logging import execution_logger\n",
        "\n",
        "\n",
        "class new_PE(object):\n",
        "    \"\"\"The class that runs the PE algorithm.\"\"\"\n",
        "\n",
        "    def __init__(self, priv_data, population, histogram, dp=None, loggers=[], callbacks=[]):\n",
        "        \"\"\"Constructor.\n",
        "\n",
        "        :param priv_data: The private data\n",
        "        :type priv_data: :py:class:`pe.data.Data`\n",
        "        :param population: The population algorithm\n",
        "        :type population: :py:class:`pe.population.Population`\n",
        "        :param histogram: The histogram algorithm\n",
        "        :type histogram: :py:class:`pe.histogram.Histogram`\n",
        "        :param dp: The DP algorithm, defaults to None, in which case the Gaussian mechanism\n",
        "            :py:class:`pe.dp.Gaussian` is used\n",
        "        :type dp: :py:class:`pe.dp.DP`, optional\n",
        "        :param loggers: The list of loggers, defaults to []\n",
        "        :type loggers: list[:py:class:`pe.logger.Logger`], optional\n",
        "        :param callbacks: The list of callbacks, defaults to []\n",
        "        :type callbacks: list[Callable or :py:class:`pe.callback.Callback`], optional\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self._priv_data = priv_data\n",
        "        self._population = population\n",
        "        self._histogram = histogram\n",
        "        if dp is None:\n",
        "            dp = Gaussian()\n",
        "        self._dp = dp\n",
        "        self._loggers = loggers\n",
        "        self._callbacks = callbacks\n",
        "\n",
        "    def load_checkpoint(self, checkpoint_path):\n",
        "        \"\"\"Load a checkpoint.\n",
        "\n",
        "        :param checkpoint_path: The path to the checkpoint\n",
        "        :type checkpoint_path: str\n",
        "        :return: The synthetic data\n",
        "        :rtype: :py:class:`pe.data.Data` or None\n",
        "        \"\"\"\n",
        "        syn_data = Data()\n",
        "        if not syn_data.load_checkpoint(checkpoint_path):\n",
        "            return None\n",
        "        return syn_data\n",
        "\n",
        "    def _log_metrics(self, syn_data):\n",
        "        \"\"\"Log metrics.\n",
        "\n",
        "        :param syn_data: The synthetic data\n",
        "        :type syn_data: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        if not self._callbacks:\n",
        "            return\n",
        "        metric_items = []\n",
        "        for callback in self._callbacks:\n",
        "            metric_items.extend(callback(syn_data) or [])\n",
        "        for logger in self._loggers:\n",
        "            logger.log(iteration=syn_data.metadata.iteration, metric_items=metric_items)\n",
        "        for metric_item in metric_items:\n",
        "            metric_item.clean_up()\n",
        "\n",
        "    def _get_num_samples_per_label_id(self, num_samples, fraction_per_label_id):\n",
        "        \"\"\"Get the number of samples per label id given the total number of samples\n",
        "\n",
        "        :param num_samples: The total number of samples\n",
        "        :type num_samples: int\n",
        "        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be\n",
        "            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the\n",
        "            private data. Defaults to None\n",
        "        :type fraction_per_label_id: list[float], optional\n",
        "        :raises ValueError: If the length of fraction_per_label_id is not the same as the number of labels\n",
        "        :raises ValueError: If the number of samples is so small that the number of samples for some label ids is zero\n",
        "        :return: The number of samples per label id\n",
        "        :rtype: np.ndarray\n",
        "        \"\"\"\n",
        "        if fraction_per_label_id is None:\n",
        "            execution_logger.warning(\n",
        "                \"fraction_per_label_id is not provided. Assuming the fraction of label ids in private data is public \"\n",
        "                \"information.\"\n",
        "            )\n",
        "            fraction_per_label_id = self._priv_data.data_frame[LABEL_ID_COLUMN_NAME].value_counts().to_dict()\n",
        "            fraction_per_label_id = [\n",
        "                0 if i not in fraction_per_label_id else fraction_per_label_id[i]\n",
        "                for i in range(len(self._priv_data.metadata.label_info))\n",
        "            ]\n",
        "        if len(fraction_per_label_id) != len(self._priv_data.metadata.label_info):\n",
        "            raise ValueError(\"fraction_per_label_id should have the same length as the number of labels.\")\n",
        "        fraction_per_label_id = np.array(fraction_per_label_id)\n",
        "        fraction_per_label_id = fraction_per_label_id / np.sum(fraction_per_label_id)\n",
        "\n",
        "        target_num_samples_per_label_id = fraction_per_label_id * num_samples\n",
        "        num_samples_per_label_id = np.floor(target_num_samples_per_label_id).astype(int)\n",
        "        num_samples_left = num_samples - np.sum(num_samples_per_label_id)\n",
        "        ids = np.argsort(target_num_samples_per_label_id - num_samples_per_label_id)[::-1]\n",
        "        num_samples_per_label_id[ids[:num_samples_left]] += 1\n",
        "        assert np.sum(num_samples_per_label_id) == num_samples\n",
        "        if np.any(num_samples_per_label_id == 0):\n",
        "            raise ValueError(\"num_samples is so small that the number of samples for some label ids is zero.\")\n",
        "        return num_samples_per_label_id\n",
        "\n",
        "    def _clean_up_loggers(self):\n",
        "        \"\"\"Clean up loggers.\"\"\"\n",
        "        for logger in self._loggers:\n",
        "            logger.clean_up()\n",
        "\n",
        "    def evaluate(self, checkpoint_path):\n",
        "        \"\"\"Evaluate the synthetic data.\n",
        "\n",
        "        :param checkpoint_path: The path to the checkpoint\n",
        "        :type checkpoint_path: str\n",
        "        \"\"\"\n",
        "        syn_data = self.load_checkpoint(checkpoint_path)\n",
        "        execution_logger.info(f\"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}\")\n",
        "        self._log_metrics(syn_data)\n",
        "\n",
        "    def run(\n",
        "        self,\n",
        "        delta,\n",
        "        epsilon=None,\n",
        "        noise_multiplier=None,\n",
        "        checkpoint_path=None,\n",
        "        save_checkpoint=True,\n",
        "        fraction_per_label_id=None,\n",
        "    ):\n",
        "        \"\"\"Run the PE algorithm.\n",
        "        :param delta: The delta value of DP\n",
        "        :type delta: float\n",
        "        :param epsilon: The epsilon value of DP, defaults to None\n",
        "        :type epsilon: float, optional\n",
        "        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None\n",
        "        :type noise_multiplier: float, optional\n",
        "        :param checkpoint_path: The path to load and save the checkpoint, defaults to None\n",
        "        :type checkpoint_path: str, optional\n",
        "        :param save_checkpoint: Whether to save the checkpoint, defaults to True\n",
        "        :type save_checkpoint: bool, optional\n",
        "        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be\n",
        "            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the\n",
        "            private data. Defaults to None\n",
        "        :type fraction_per_label_id: list[float], optional\n",
        "        :return: The synthetic data\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        try:\n",
        "            num_priv_samples = self._priv_data.data_frame.shape[0] #changed\n",
        "            T = int(np.log(num_priv_samples*epsilon)) + 1\n",
        "\n",
        "            # Set privacy budget.\n",
        "            self._dp.set_epsilon_and_delta(\n",
        "                num_iterations= T,\n",
        "                epsilon=epsilon,\n",
        "                delta=delta,\n",
        "                noise_multiplier = None)\n",
        "\n",
        "            print(self._dp._noise_multiplier)\n",
        "\n",
        "            num_syn_samples = np.max([10,int(num_priv_samples/self._dp._noise_multiplier*5**(1/2048 - 1))])#according to our theory\n",
        "\n",
        "            # Generate or load initial data.\n",
        "            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):\n",
        "                execution_logger.info(\n",
        "                    f\"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}\"\n",
        "                )\n",
        "            else:\n",
        "                # num_samples_per_label_id = self._get_num_samples_per_label_id(\n",
        "                #     num_samples=num_syn_samples,\n",
        "                #     fraction_per_label_id=fraction_per_label_id,\n",
        "                # )#changed\n",
        "                syn_data_list = []\n",
        "                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):\n",
        "                    syn_data = self._population.initial(\n",
        "                        label_info=label_info,\n",
        "                        num_samples=num_syn_samples,#changed\n",
        "                    )\n",
        "                    syn_data.set_label_id(label_id)\n",
        "                    syn_data_list.append(syn_data)\n",
        "                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)\n",
        "                syn_data.data_frame.reset_index(drop=True, inplace=True)\n",
        "                syn_data.metadata.iteration = 0\n",
        "                self._log_metrics(syn_data)\n",
        "\n",
        "            #make the variations degrees for the last iteration very small and increase the amount of vars so the synthetic data is close to num_samples.\n",
        "            last_iter_variation_apis_needed = num_priv_samples//num_syn_samples\n",
        "            last_iter_variation_degrees = [2, 4, 6, 8] * (last_iter_variation_apis_needed // 4) + [2, 4, 6, 8][:last_iter_variation_apis_needed % 4]\n",
        "            self._population._api._variation_degrees = self._population._api._variation_degrees[:T] + [last_iter_variation_degrees]\n",
        "\n",
        "            # Run PE iterations.\n",
        "            for iteration in range(syn_data.metadata.iteration + 1, T+1):\n",
        "                print('number of variation degrees: ',self._population._api._variation_degrees[iteration])\n",
        "                execution_logger.info(f\"PE iteration {iteration}\")\n",
        "                # num_samples_per_label_id = self._get_num_samples_per_label_id(\n",
        "                #     num_samples=num_samples_schedule[iteration],\n",
        "                #     fraction_per_label_id=fraction_per_label_id,\n",
        "                # )\n",
        "                syn_data_list = []\n",
        "                priv_data_list = []\n",
        "\n",
        "                # Generate synthetic data for each label.\n",
        "                for label_id in range(len(self._priv_data.metadata.label_info)):\n",
        "                    execution_logger.info(f\"Label {label_id}\")\n",
        "                    sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)\n",
        "                    sub_syn_data = syn_data.filter_label_id(label_id=label_id)\n",
        "\n",
        "                    # DP NN histogram.\n",
        "                    sub_priv_data, sub_syn_data = self._histogram.compute_histogram(\n",
        "                        priv_data=sub_priv_data, syn_data=sub_syn_data\n",
        "                    )\n",
        "                    priv_data_list.append(sub_priv_data)\n",
        "                    sub_syn_data = self._dp.add_noise(syn_data=sub_syn_data)\n",
        "\n",
        "                    # Generate variations ###this selects the next synthetic data according to noisy hist and also runs var_API\n",
        "                    sub_syn_data = self._population.next(\n",
        "                        syn_data=sub_syn_data,\n",
        "                        num_samples= num_syn_samples, #changed\n",
        "                    )\n",
        "                    sub_syn_data.set_label_id(label_id)\n",
        "                    syn_data_list.append(sub_syn_data)\n",
        "\n",
        "                syn_data = Data.concat(syn_data_list)\n",
        "                syn_data.data_frame.reset_index(drop=True, inplace=True)\n",
        "                syn_data.metadata.iteration = iteration\n",
        "\n",
        "                new_priv_data = Data.concat(priv_data_list)\n",
        "                self._priv_data = self._priv_data.merge(new_priv_data)\n",
        "\n",
        "                if save_checkpoint:\n",
        "                    syn_data.save_checkpoint(checkpoint_path)\n",
        "                self._log_metrics(syn_data)\n",
        "        finally:\n",
        "            self._clean_up_loggers()\n",
        "\n",
        "        return syn_data\n",
        "\n",
        "    def run_variable_T(\n",
        "        self,\n",
        "        num_iters,\n",
        "        delta,\n",
        "        epsilon=None,\n",
        "        noise_multiplier=None,\n",
        "        checkpoint_path=None,\n",
        "        save_checkpoint=True,\n",
        "        fraction_per_label_id=None,\n",
        "    ):\n",
        "        \"\"\"Run the PE algorithm.\n",
        "        :param delta: The delta value of DP\n",
        "        :type delta: float\n",
        "        :param epsilon: The epsilon value of DP, defaults to None\n",
        "        :type epsilon: float, optional\n",
        "        :param noise_multiplier: The noise multiplier of the DP mechanism, defaults to None\n",
        "        :type noise_multiplier: float, optional\n",
        "        :param checkpoint_path: The path to load and save the checkpoint, defaults to None\n",
        "        :type checkpoint_path: str, optional\n",
        "        :param save_checkpoint: Whether to save the checkpoint, defaults to True\n",
        "        :type save_checkpoint: bool, optional\n",
        "        :param fraction_per_label_id: The fraction of samples for each label id. The fraction does not have to be\n",
        "            normalized. When it is None, the fraction is assumed to be the same as the fraction of label ids in the\n",
        "            private data. Defaults to None\n",
        "        :type fraction_per_label_id: list[float], optional\n",
        "        :return: The synthetic data\n",
        "        :rtype: :py:class:`pe.data.Data`\n",
        "        \"\"\"\n",
        "        try:\n",
        "            num_priv_samples = self._priv_data.data_frame.shape[0] #changed\n",
        "            T = num_iters\n",
        "\n",
        "            # Set privacy budget.\n",
        "            self._dp.set_epsilon_and_delta(\n",
        "                num_iterations= T,#changed\n",
        "                epsilon=epsilon,\n",
        "                delta=delta,\n",
        "                noise_multiplier = None)\n",
        "\n",
        "            num_syn_samples = int(num_priv_samples/self._dp._noise_multiplier*5**(1/2048 - 1)) #max([10, int(num_priv_samples/20), 1 + int(/self._dp._noise_multiplier)])#changed\n",
        "\n",
        "            # Generate or load initial data.\n",
        "            if checkpoint_path is not None and (syn_data := self.load_checkpoint(checkpoint_path)):\n",
        "                execution_logger.info(\n",
        "                    f\"Loaded checkpoint from {checkpoint_path}, iteration={syn_data.metadata.iteration}\"\n",
        "                )\n",
        "            else:\n",
        "                # num_samples_per_label_id = self._get_num_samples_per_label_id(\n",
        "                #     num_samples=num_syn_samples,\n",
        "                #     fraction_per_label_id=fraction_per_label_id,\n",
        "                # )#changed\n",
        "                syn_data_list = []\n",
        "                for label_id, label_info in enumerate(self._priv_data.metadata.label_info):\n",
        "                    syn_data = self._population.initial(\n",
        "                        label_info=label_info,\n",
        "                        num_samples=num_syn_samples,#changed\n",
        "                    )\n",
        "                    syn_data.set_label_id(label_id)\n",
        "                    syn_data_list.append(syn_data)\n",
        "                syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)\n",
        "                syn_data.data_frame.reset_index(drop=True, inplace=True)\n",
        "                syn_data.metadata.iteration = 0\n",
        "                self._log_metrics(syn_data)\n",
        "\n",
        "            #make the variations degrees for the last iteration very small and increase the amount of vars so the synthetic data is close to num_samples.\n",
        "            last_iter_variation_apis_needed = num_priv_samples//num_syn_samples\n",
        "            last_iter_variation_degrees = [2, 4, 6, 8] * (last_iter_variation_apis_needed // 4) + [2, 4, 6, 8][:last_iter_variation_apis_needed % 4]\n",
        "            self._population._api._variation_degrees = self._population._api._variation_degrees[:T] + [last_iter_variation_degrees]\n",
        "\n",
        "            # Run PE iterations.\n",
        "            for iteration in range(syn_data.metadata.iteration + 1, T+1):\n",
        "                print('number of variation degrees: ',self._population._api._variation_degrees[iteration])\n",
        "                execution_logger.info(f\"PE iteration {iteration}\")\n",
        "                # num_samples_per_label_id = self._get_num_samples_per_label_id(\n",
        "                #     num_samples=num_samples_schedule[iteration],\n",
        "                #     fraction_per_label_id=fraction_per_label_id,\n",
        "                # )\n",
        "                syn_data_list = []\n",
        "                priv_data_list = []\n",
        "\n",
        "                # Generate synthetic data for each label.\n",
        "                for label_id in range(len(self._priv_data.metadata.label_info)):\n",
        "                    execution_logger.info(f\"Label {label_id}\")\n",
        "                    sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)\n",
        "                    sub_syn_data = syn_data.filter_label_id(label_id=label_id)\n",
        "\n",
        "                    # DP NN histogram.\n",
        "                    sub_priv_data, sub_syn_data = self._histogram.compute_histogram(\n",
        "                        priv_data=sub_priv_data, syn_data=sub_syn_data\n",
        "                    )\n",
        "                    priv_data_list.append(sub_priv_data)\n",
        "                    sub_syn_data = self._dp.add_noise(syn_data=sub_syn_data)\n",
        "\n",
        "                    # Generate variations ###this selects the next synthetic data according to noisy hist and also runs var_API\n",
        "                    sub_syn_data = self._population.next(\n",
        "                        syn_data=sub_syn_data,\n",
        "                        num_samples= num_syn_samples, #changed\n",
        "                    )\n",
        "                    sub_syn_data.set_label_id(label_id)\n",
        "                    syn_data_list.append(sub_syn_data)\n",
        "\n",
        "                syn_data = Data.concat(syn_data_list)\n",
        "                syn_data.data_frame.reset_index(drop=True, inplace=True)\n",
        "                syn_data.metadata.iteration = iteration\n",
        "\n",
        "                new_priv_data = Data.concat(priv_data_list)\n",
        "                self._priv_data = self._priv_data.merge(new_priv_data)\n",
        "\n",
        "                if save_checkpoint:\n",
        "                    syn_data.save_checkpoint(checkpoint_path)\n",
        "                self._log_metrics(syn_data)\n",
        "        finally:\n",
        "            self._clean_up_loggers()\n",
        "\n",
        "        return syn_data"
      ],
      "metadata": {
        "id": "dDsgVgAaVE48"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Experiment design"
      ],
      "metadata": {
        "id": "A3VPXbAO2hmm"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "fixed T:"
      ],
      "metadata": {
        "id": "zi2CA7vF2pVl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Experiment():\n",
        "\n",
        "  def __init__(self, num_experiments, num_samples, epsilon, delta, cifar10class):\n",
        "    self.num_experiments = num_experiments #should be a positive integer\n",
        "    self.num_samples = num_samples #should be a list of increasing positive integers\n",
        "    self.epsilon = epsilon #should be positive real\n",
        "    self.delta = delta #should be in (0,1)\n",
        "    self.cifar10class = cifar10class #should be e.g plane or something like that\n",
        "\n",
        "  def run(self):\n",
        "    for num_samples in self.num_samples:\n",
        "      api = ImprovedDiffusion270M(\n",
        "            variation_degrees= [list(range(2, 23, 5))]*20,\n",
        "            timestep_respacing=\"100\")\n",
        "      for num_exp in range(1, self.num_experiments + 1):\n",
        "        print(f'-------------------------------- Experiment {num_exp} with {num_samples} samples has started--------------------------------')\n",
        "        exp_folder = \"/content/drive/MyDrive/newPE_experiments/fixed_T/experiment_\" + str(num_exp) + \"eps_\" + str(self.epsilon)+\"_delta_\" + str(self.delta)+\"_class_\" + str(self.cifar10class) + \"_num_samples_\" + str(num_samples)\n",
        "        setup_logging(log_file=os.path.join(exp_folder, \"log.txt\"))\n",
        "        data = Cifar10(cifar10class = self.cifar10class, num_per_class = num_samples)\n",
        "        embedding = Inception(res=32, batch_size=100)\n",
        "        histogram = NearestNeighbors(\n",
        "            embedding=embedding,\n",
        "            mode=\"L2\",\n",
        "            lookahead_degree=0,\n",
        "            api=api,\n",
        "        )\n",
        "        population = PEPopulation(api=api, histogram_threshold=0)\n",
        "        save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, \"checkpoint\"))\n",
        "        sample_images = SampleImages()\n",
        "        compute_fid = ComputeFID(priv_data=data, embedding=embedding)\n",
        "        image_file = ImageFile(output_folder=exp_folder)\n",
        "        csv_print = CSVPrint(output_folder=exp_folder)\n",
        "        log_print = LogPrint()\n",
        "        pe_runner = new_PE(\n",
        "            priv_data=data,\n",
        "            population=population,\n",
        "            histogram=histogram,\n",
        "            callbacks=[save_checkpoints, sample_images, compute_fid],\n",
        "            loggers=[image_file, csv_print, log_print],\n",
        "        )\n",
        "        pe_runner.run(\n",
        "            delta=self.delta,\n",
        "            epsilon = self.epsilon,\n",
        "            checkpoint_path=os.path.join(exp_folder, \"checkpoint\"),\n",
        "        )"
      ],
      "metadata": {
        "id": "6ZvTuuDfiNRO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "vary T:"
      ],
      "metadata": {
        "id": "pYcdmtGs2rH9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Experiment_variable_T():\n",
        "  \"\"\"Repeat PE runs on CIFAR-10 data while varying the number of PE iterations T.\n",
        "\n",
        "  For every T in `num_iters`, `num_experiments` independent runs are launched with\n",
        "  the same (epsilon, delta). Each run writes its log, checkpoints, images and CSV\n",
        "  output to its own folder under `output_root`.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_experiments, num_iters, epsilon, delta, cifar10class,\n",
        "               num_per_class=300,\n",
        "               output_root=\"/content/drive/MyDrive/newPE_experiments/vary_T\"):\n",
        "    \"\"\"Store the experiment grid.\n",
        "\n",
        "    Args:\n",
        "      num_experiments: runs per value of T; should be a positive integer.\n",
        "      num_iters: values of T; should be a list of increasing positive integers.\n",
        "      epsilon: privacy parameter epsilon; should be a positive real.\n",
        "      delta: privacy parameter delta; should be in (0, 1).\n",
        "      cifar10class: forwarded to `Cifar10(cifar10class=...)`; the cells below pass\n",
        "        a list such as [\"plane\"].\n",
        "      num_per_class: forwarded to `Cifar10(num_per_class=...)`; defaults to 300.\n",
        "      output_root: parent directory of the per-run output folders.\n",
        "    \"\"\"\n",
        "    self.num_experiments = num_experiments\n",
        "    self.num_iters = num_iters\n",
        "    self.epsilon = epsilon\n",
        "    self.delta = delta\n",
        "    self.cifar10class = cifar10class\n",
        "    self.num_per_class = num_per_class\n",
        "    self.output_root = output_root\n",
        "\n",
        "  def run(self):\n",
        "    \"\"\"Run every (T, experiment) combination sequentially.\"\"\"\n",
        "    for num_iter in self.num_iters:\n",
        "      # A fixed 20-entry schedule only covers T <= 20 if one entry is consumed per PE\n",
        "      # iteration (assumption - TODO confirm against ImprovedDiffusion270M), so build\n",
        "      # at least num_iter entries. For T <= 20 this is exactly a 20-entry schedule.\n",
        "      api = ImprovedDiffusion270M(\n",
        "            variation_degrees=[list(range(2, 23, 5))] * max(20, num_iter),\n",
        "            timestep_respacing=\"100\")\n",
        "      for num_exp in range(1, self.num_experiments + 1):\n",
        "        print(f'-------------------------------- Experiment {num_exp} with {num_iter} iterations has started--------------------------------')\n",
        "        # NOTE(review): there is no separator between the run number and \"eps_\"\n",
        "        # (e.g. \"experiment_1eps_5_...\"); kept unchanged so output/checkpoint paths\n",
        "        # stay the same as for earlier runs.\n",
        "        exp_folder = (self.output_root + \"/experiment_\" + str(num_exp) + \"eps_\" + str(self.epsilon)\n",
        "                      + \"_delta_\" + str(self.delta) + \"_class_\" + str(self.cifar10class)\n",
        "                      + \"_num_iters_\" + str(num_iter))\n",
        "        setup_logging(log_file=os.path.join(exp_folder, \"log.txt\"))\n",
        "        data = Cifar10(cifar10class=self.cifar10class, num_per_class=self.num_per_class)\n",
        "        embedding = Inception(res=32, batch_size=100)\n",
        "        histogram = NearestNeighbors(\n",
        "            embedding=embedding,\n",
        "            mode=\"L2\",\n",
        "            lookahead_degree=0,\n",
        "            api=api,\n",
        "        )\n",
        "        population = PEPopulation(api=api, histogram_threshold=0)\n",
        "        save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, \"checkpoint\"))\n",
        "        sample_images = SampleImages()\n",
        "        compute_fid = ComputeFID(priv_data=data, embedding=embedding)\n",
        "        image_file = ImageFile(output_folder=exp_folder)\n",
        "        csv_print = CSVPrint(output_folder=exp_folder)\n",
        "        log_print = LogPrint()\n",
        "        pe_runner = new_PE(\n",
        "            priv_data=data,\n",
        "            population=population,\n",
        "            histogram=histogram,\n",
        "            callbacks=[save_checkpoints, sample_images, compute_fid],\n",
        "            loggers=[image_file, csv_print, log_print],\n",
        "        )\n",
        "        pe_runner.run_variable_T(\n",
        "            num_iters=num_iter,\n",
        "            delta=self.delta,\n",
        "            epsilon=self.epsilon,\n",
        "            checkpoint_path=os.path.join(exp_folder, \"checkpoint\"),\n",
        "        )"
      ],
      "metadata": {
        "id": "XJDhYW3JHprQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Experiments with CIFAR-10 class `plane`"
      ],
      "metadata": {
        "id": "BbpnNBg22TRr"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Fixed T (experiments over the `num_samples` values below)"
      ],
      "metadata": {
        "id": "1ZWE7xyJiN6_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Fixed T, CIFAR-10 class \"plane\" (num_experiments=3, epsilon=5, delta=1e-4).\n",
        "num_experiments, epsilon, delta = 3, 5, 10**-4\n",
        "num_samples = list(range(50, 651, 100))  # [50, 150, ..., 650]\n",
        "cifar10class = [\"plane\"]\n",
        "experiment = Experiment(\n",
        "    num_experiments, num_samples, epsilon, delta, cifar10class\n",
        ")\n",
        "experiment.run()"
      ],
      "metadata": {
        "id": "KLFw9qGljMdj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Variable T (experiments over the numbers of PE iterations in `num_iters`)"
      ],
      "metadata": {
        "id": "cuNzHTwY3BBS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Variable T, CIFAR-10 class \"plane\": T = 4, 8, 12, 16, 20 (epsilon=5, delta=1e-4).\n",
        "num_experiments, epsilon, delta = 3, 5, 10**-4\n",
        "num_iters = list(range(4, 21, 4))\n",
        "cifar10class = [\"plane\"]\n",
        "experiment = Experiment_variable_T(\n",
        "    num_experiments=num_experiments, num_iters=num_iters,\n",
        "    epsilon=epsilon, delta=delta, cifar10class=cifar10class,\n",
        ")\n",
        "experiment.run()"
      ],
      "metadata": {
        "id": "NtxfWOvRLyOe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Experiments with CIFAR-10 class `dog`"
      ],
      "metadata": {
        "id": "2kBXU59D25Ej"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Fixed T (experiments over the `num_samples` values below)"
      ],
      "metadata": {
        "id": "dz1D_T3GH9G_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Fixed T, CIFAR-10 class \"dog\" (num_experiments=3, epsilon=5, delta=1e-4).\n",
        "num_experiments, epsilon, delta = 3, 5, 10**-4\n",
        "num_samples = list(range(50, 651, 100))  # [50, 150, ..., 650]\n",
        "cifar10class = [\"dog\"]\n",
        "experiment = Experiment(\n",
        "    num_experiments, num_samples, epsilon, delta, cifar10class\n",
        ")\n",
        "experiment.run()"
      ],
      "metadata": {
        "id": "jLroiLO-wc2W"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Variable T (experiments over the numbers of PE iterations in `num_iters`)"
      ],
      "metadata": {
        "id": "uzO-cI8VICF3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Variable T, CIFAR-10 class \"dog\": T = 4, 8, 12, 16, 20 (epsilon=5, delta=1e-4).\n",
        "num_experiments, epsilon, delta = 3, 5, 10**-4\n",
        "num_iters = list(range(4, 21, 4))\n",
        "cifar10class = [\"dog\"]\n",
        "experiment = Experiment_variable_T(\n",
        "    num_experiments=num_experiments, num_iters=num_iters,\n",
        "    epsilon=epsilon, delta=delta, cifar10class=cifar10class,\n",
        ")\n",
        "experiment.run()"
      ],
      "metadata": {
        "id": "M0WnL0ge1Vqd"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100",
      "machine_shape": "hm",
      "collapsed_sections": [
        "kgWnuu010RfJ"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}