{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import os\n",
    "\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "from pathlib import Path\n",
    "import yaml\n",
    "import json\n",
    "import pickle\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from glob import glob\n",
    "import io\n",
    "import torch_utils\n",
    "import dnnlib\n",
    "from torch_utils import misc\n",
    "import torch.autograd.forward_ad as fwAD\n",
    "from functorch import jvp\n",
    "\n",
    "device = 'cuda'\n",
    "\n",
    "# The checkpoint loaded below is unconditional ('uncond' in its file name).\n",
    "has_labels = False\n",
    "\n",
    "# Pretrained EDM VP checkpoint (AFHQv2 64x64), as cached by dnnlib's downloader.\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load checkpoints from a trusted source.\n",
    "with open('/path/to/.cache/dnnlib/downloads/c3238fd63e57b9ea1562999e0c55f411_https___nvlabs-fi-cdn.nvidia.com_edm_pretrained_edm-afhqv2-64x64-uncond-vp.pkl', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "net = data['ema'].to(device)\n",
    "\n",
    "batch_size = 10\n",
    "\n",
    "c = dnnlib.EasyDict()\n",
    "path_to_dataset = \"/path/to/datasets/afhqv2/afhqv2-64x64.zip\"\n",
    "c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=path_to_dataset, use_labels=has_labels, xflip=False, cache=True)\n",
    "c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1, prefetch_factor=2)\n",
    "\n",
    "dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)\n",
    "\n",
    "dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=0, num_replicas=1, seed=0)\n",
    "dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_size, **c.data_loader_kwargs))\n",
    "\n",
    "images, labels = next(dataset_iterator)\n",
    "# Rescale pixel values from [0, 255] to [-1, 1].\n",
    "images = images.to(device).to(torch.float32) / 127.5 - 1\n",
    "labels = labels.to(device)\n",
    "\n",
    "img_size = 64\n",
    "\n",
    "def display_image(image_pt, normalize=False, ax=None):\n",
    "    \"\"\"Plot one (3, img_size, img_size) image tensor whose values are nominally in [-1, 1].\n",
    "\n",
    "    Args:\n",
    "        image_pt: CHW image tensor on any device; it is not modified.\n",
    "        normalize: if True, first divide by the maximum absolute value (useful for noisy images).\n",
    "        ax: matplotlib Axes to draw into; if None, the image is shown with plt.show().\n",
    "    \"\"\"\n",
    "    assert image_pt.shape == (3, img_size, img_size)\n",
    "    image_pt = image_pt.detach()\n",
    "    if normalize:\n",
    "        # Out-of-place so the caller's tensor (often a row view of a batch) is left untouched;\n",
    "        # an all-zero image is shown as-is rather than divided by zero.\n",
    "        max_abs = torch.max(torch.abs(image_pt))\n",
    "        if max_abs > 0:\n",
    "            image_pt = image_pt / max_abs\n",
    "    # [-1, 1] -> [0, 255]; clip before the uint8 cast so out-of-range pixels saturate instead of wrapping around.\n",
    "    image_np = ((image_pt.cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).clip(0, 255).astype(np.uint8)\n",
    "    if ax is None:\n",
    "        plt.imshow(image_np)\n",
    "        plt.show()\n",
    "    else:\n",
    "        ax.imshow(image_np)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: denoise one batch at a small noise level and compare by eye.\n",
    "# Named demo_sigmas (not sigmas) so it cannot clobber the numpy sigma grid used by the cells below.\n",
    "demo_sigmas = 0.1 * torch.ones((batch_size,), device=device)\n",
    "noisy_images = images + torch.randn_like(images) * demo_sigmas[:, None, None, None]\n",
    "\n",
    "with torch.no_grad():  # inference only\n",
    "    if has_labels:\n",
    "        cleaned_images = net(noisy_images, demo_sigmas, labels)\n",
    "    else:\n",
    "        cleaned_images = net(noisy_images, demo_sigmas)\n",
    "\n",
    "display_image(images[0])\n",
    "display_image(noisy_images[0], normalize=True)\n",
    "display_image(cleaned_images[0])\n",
    "display_image(cleaned_images[0], normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For adjacent noise levels on a log-spaced grid, measure how much the denoiser's score at the\n",
    "# larger sigma changes when the network is conditioned on the smaller sigma instead. The cells\n",
    "# below accumulate this discrepancy to place sampling steps.\n",
    "min_sigma = 0.002\n",
    "max_sigma = 80\n",
    "num_sigmas = 100\n",
    "num_repeats = 10\n",
    "\n",
    "# Seed the noise draws so the estimated schedule is reproducible.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "log_sigmas = np.linspace(np.log(min_sigma), np.log(max_sigma), num_sigmas)\n",
    "sigmas = np.exp(log_sigmas)\n",
    "\n",
    "def _denoise(x, sigma_pt, labels):\n",
    "    # Same label handling as the sanity-check cell: labels are only passed when has_labels is set.\n",
    "    return net(x, sigma_pt, labels) if has_labels else net(x, sigma_pt)\n",
    "\n",
    "# Shape (num_repeats, len(sigmas) - 1, batch_size). Each entry is a batch-level RMS, so all values\n",
    "# along the last axis are equal; downstream cells read column 0.\n",
    "score_diffs = np.zeros((num_repeats, len(sigmas) - 1, batch_size))\n",
    "with torch.no_grad():  # inference only -- never build an autograd graph\n",
    "    for repeat_idx in tqdm(range(num_repeats)):\n",
    "        images, labels = next(dataset_iterator)\n",
    "        images = images.to(device).to(torch.float32) / 127.5 - 1\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        for i in range(len(sigmas) - 1):\n",
    "            left_sigma = sigmas[i]\n",
    "            right_sigma = sigmas[i + 1]\n",
    "\n",
    "            # Sample at the higher noise level, then denoise it conditioned on both sigmas.\n",
    "            right_noisy_images = images + torch.randn_like(images) * right_sigma\n",
    "\n",
    "            left_sigmas_pt = torch.ones((batch_size,), device=device) * left_sigma\n",
    "            right_sigmas_pt = torch.ones((batch_size,), device=device) * right_sigma\n",
    "\n",
    "            right_to_left_cleaned_images = _denoise(right_noisy_images, left_sigmas_pt, labels)\n",
    "            right_to_right_cleaned_images = _denoise(right_noisy_images, right_sigmas_pt, labels)\n",
    "\n",
    "            # Score estimate (D(x; sigma) - x) / sigma^2 under each conditioning.\n",
    "            right_to_left_score = (right_to_left_cleaned_images - right_noisy_images) / (left_sigmas_pt[:, None, None, None] ** 2)\n",
    "            right_to_right_score = (right_to_right_cleaned_images - right_noisy_images) / (right_sigmas_pt[:, None, None, None] ** 2)\n",
    "\n",
    "            # Per-sample right_sigma^2 * ||score difference||^2, then the square root of its batch mean.\n",
    "            per_sample_sq = torch.sum(right_sigmas_pt[:, None, None, None] ** 2 * (right_to_right_score - right_to_left_score) ** 2, dim=(1, 2, 3))\n",
    "            score_diffs[repeat_idx, i, :] = per_sample_sq.mean().sqrt().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# score_diffs[:, i] measures the step from sigmas[i] to sigmas[i + 1], so the cumulative sum\n",
    "# through index i is the total reached at sigmas[i + 1]: plot it against sigmas[1:].\n",
    "avg_cum_sum = np.cumsum(np.mean(score_diffs, axis=0)[:, 0])\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(sigmas[1:], avg_cum_sum)\n",
    "ax.set_xscale('log')\n",
    "ax.set(xlabel='sigma', ylabel='cumulative score discrepancy', title='Cumulative score discrepancy vs. sigma')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import interpolate\n",
    "\n",
    "# avg_cum_sum[i] is the cumulative discrepancy reached at sigmas[i + 1]; by construction it is 0 at\n",
    "# sigmas[0]. Adding that knot makes f(0) == log(sigmas[0]) instead of a PCHIP extrapolation.\n",
    "cum_knots = np.concatenate([[0.0], avg_cum_sum])\n",
    "log_sigma_knots = np.log(sigmas)\n",
    "# PCHIP needs strictly increasing x; this fails if some interval contributed zero discrepancy.\n",
    "assert np.all(np.diff(cum_knots) > 0), 'cumulative discrepancy must be strictly increasing'\n",
    "# f: cumulative discrepancy -> log(sigma); inv_f: log(sigma) -> cumulative discrepancy.\n",
    "f = interpolate.PchipInterpolator(x=cum_knots, y=log_sigma_knots)\n",
    "inv_f = interpolate.PchipInterpolator(x=log_sigma_knots, y=cum_knots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each step budget, place sigmas so every step covers an equal share of the total cumulative\n",
    "# discrepancy, then format the result as a descending sigma schedule.\n",
    "for num_steps in [40, 100]:\n",
    "    max_cum_sum = avg_cum_sum[-1]\n",
    "    linear_spacing = np.linspace(0, max_cum_sum, num_steps)\n",
    "\n",
    "    spaced_log_sigmas = f(linear_spacing)\n",
    "    print(np.exp(spaced_log_sigmas))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(sigmas[1:], avg_cum_sum)\n",
    "    ax.scatter(np.exp(spaced_log_sigmas), linear_spacing)\n",
    "    ax.set_xscale('log')\n",
    "    ax.set(xlabel='sigma', ylabel='cumulative score discrepancy', title=f'{num_steps} steps')\n",
    "    plt.show()\n",
    "\n",
    "    # Descending schedule with its endpoints pinned to the grid range defined by min_sigma/max_sigma,\n",
    "    # plus a trailing 0.0 (presumably the sampler's terminal sigma -- confirm against the consumer).\n",
    "    regularized_linear_spacing = np.flip(np.exp(spaced_log_sigmas))\n",
    "    regularized_linear_spacing[-1] = min_sigma\n",
    "    regularized_linear_spacing[0] = max_sigma\n",
    "    regularized_linear_spacing = np.concatenate([regularized_linear_spacing, [0.0]])\n",
    "    print(regularized_linear_spacing)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
