{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "import pathlib\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np\n",
    "from data.data_transforms import ImageDataTransform\n",
    "from data.image_data import CelebaDataset, ImageNetDataset\n",
    "from data.operators import GaussianBlurOperator, InpaintingOperator, NoiseScheduler\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib\n",
    "import torch\n",
    "from torch import nn\n",
    "import scipy.ndimage\n",
    "from skimage.metrics import peak_signal_noise_ratio as psnr\n",
    "from pl_modules.ncsn_module import NCSN_Module\n",
    "from data.metrics import psnr, mse, LPIPS\n",
    "from data.operators import create_operator, create_noise_schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: degradation scheduling for ImageNet inpainting.\n",
    "\n",
    "# Settings of the forward (degradation) operator, consumed by create_operator.\n",
    "operator_config = dict(\n",
    "    type='inpainting',\n",
    "    mask_type='gaussian',\n",
    "    mask_min_std=0.1,\n",
    "    mask_max_std=30,\n",
    "    mask_pow=4,\n",
    "    scheduling='linear',\n",
    ")\n",
    "fwd_operator = create_operator(operator_config)\n",
    "\n",
    "# sigma_min / sigma_max are handed to the data transform as its noise_config.\n",
    "noise_config = dict(sigma_min=0.01, sigma_max=0.05)\n",
    "\n",
    "# Note: despite its name, the transform is constructed with is_train=False.\n",
    "train_transform = ImageDataTransform(\n",
    "    is_train=False,\n",
    "    operator_config=operator_config,\n",
    "    noise_config=noise_config,\n",
    "    dt=0.001,\n",
    ")\n",
    "\n",
    "dataset = ImageNetDataset(\n",
    "    root='/data/imagenet/',  # Change ImageNet directory here\n",
    "    split='train',\n",
    "    transform=train_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions\n",
    "# Use the GPU when one is present; otherwise fall back to the CPU so the cell still runs.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "# Only the .loss attribute of the LPIPS wrapper is kept, moved to the chosen device;\n",
    "# it is what the 'lpips' metric calls as lpips(x, y).\n",
    "lpips = LPIPS('vgg').loss.to(device)\n",
    "def edge_distance(dataset_clean, t_i, t_j, fwd_operator, metric):\n",
    "    \"\"\"Average distance between the degradations of each sample at two times.\n",
    "\n",
    "    Every element of dataset_clean is passed through fwd_operator at t_i and\n",
    "    at t_j, the two results are compared with the selected metric, and the\n",
    "    values are averaged over len(dataset_clean).\n",
    "\n",
    "    Args:\n",
    "        dataset_clean: sized iterable of clean samples.\n",
    "        t_i: earlier time; must be strictly less than t_j.\n",
    "        t_j: later time.\n",
    "        fwd_operator: callable invoked as fwd_operator(sample, t).\n",
    "        metric: 'psnr' (negated PSNR), 'mse' or 'lpips'.\n",
    "\n",
    "    Returns:\n",
    "        The metric summed over all samples, divided by their count.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if t_i is not less than t_j.\n",
    "        NotImplementedError: for any other metric name.\n",
    "    \"\"\"\n",
    "    assert t_i < t_j\n",
    "    # Lambdas keep the lookup of psnr / mse / lpips lazy, as in an if/elif chain.\n",
    "    metric_table = {\n",
    "        'psnr': lambda x, y: -psnr(x, y),\n",
    "        'mse': lambda x, y: mse(x, y),\n",
    "        'lpips': lambda x, y: lpips(x, y),\n",
    "    }\n",
    "    if metric not in metric_table:\n",
    "        raise NotImplementedError(\"Other metrics not implemented.\")\n",
    "    score = metric_table[metric]\n",
    "\n",
    "    count = len(dataset_clean)\n",
    "    total = sum(\n",
    "        score(fwd_operator(sample, t_i), fwd_operator(sample, t_j))\n",
    "        for sample in dataset_clean\n",
    "    )\n",
    "    return total / count\n",
    "\n",
    "def sorted_insert(dicts, new_item, key):\n",
    "    \"\"\"Insert new_item into dicts, a list kept in descending order of item[key].\n",
    "\n",
    "    The item is placed in front of the first element whose value under key is\n",
    "    strictly smaller, so it lands after any elements with an equal value. When\n",
    "    no such element exists it goes to the end.\n",
    "\n",
    "    Args:\n",
    "        dicts: list of dicts; modified in place.\n",
    "        new_item: dict to insert.\n",
    "        key: key whose values are compared.\n",
    "\n",
    "    Returns:\n",
    "        The same list object that was passed in.\n",
    "    \"\"\"\n",
    "    position = next(\n",
    "        (idx for idx, existing in enumerate(dicts) if new_item[key] > existing[key]),\n",
    "        len(dicts),\n",
    "    )\n",
    "    dicts.insert(position, new_item)\n",
    "    return dicts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Greedy refinement of the time schedule: repeatedly split the edge with the largest\n",
    "# distance at the grid point that minimizes the larger of the two resulting distances.\n",
    "N = 2000 # Number of discretization points\n",
    "m = 100 # Number of desired steps\n",
    "num_data_samples = 30\n",
    "metric = 'lpips'\n",
    "\n",
    "with torch.no_grad():\n",
    "    sample_clean = [dataset[i]['clean'].unsqueeze(0).to(device) for i in range(num_data_samples)]\n",
    "    t = torch.linspace(0.0, 1.0, N)\n",
    "    t = t.to(device)\n",
    "    total_dist = edge_distance(sample_clean, t[0], t[N-1], fwd_operator, metric)\n",
    "    # edges is kept sorted by decreasing 'd' (see sorted_insert), so edges[0] is the largest.\n",
    "    edges = [{'start': 0, 'end': N-1, 'd': total_dist}]\n",
    "    for i in range(m):\n",
    "        max_edge = edges.pop(0)\n",
    "        max_val = max_edge['d']\n",
    "        print('Adding point ', i+1, ' over interval ', max_edge['start'], ' - ', max_edge['end'], ' with distance ', max_edge['d'])\n",
    "        found = False\n",
    "        for j in range(max_edge['start']+1, max_edge['end']):\n",
    "            dist_1 = edge_distance(sample_clean, t[max_edge['start']], t[j], fwd_operator, metric)\n",
    "            dist_2 = edge_distance(sample_clean, t[j], t[max_edge['end']], fwd_operator, metric)\n",
    "            new_edges_max = max(dist_1, dist_2)\n",
    "            if new_edges_max < max_val:\n",
    "                found = True\n",
    "                # Report before max_val is overwritten, otherwise both printed values are equal.\n",
    "                print('New possible split found, because ', new_edges_max, ' < ', max_val)\n",
    "                print('New edges: {} - {}, {} - {}'.format(max_edge['start'], j, j, max_edge['end']))\n",
    "                max_val = new_edges_max\n",
    "                new_edge_1 = {'start': max_edge['start'], 'end': j, 'd': dist_1}\n",
    "                new_edge_2 = {'start': j, 'end': max_edge['end'], 'd': dist_2}\n",
    "            else:\n",
    "                print('Not a good split, because ', new_edges_max, ' >= ', max_val)\n",
    "        if found:\n",
    "            edges = sorted_insert(edges, new_edge_1, 'd')\n",
    "            edges = sorted_insert(edges, new_edge_2, 'd')\n",
    "        else:\n",
    "            # Put the unsplit edge back at the front (it is still the largest); dropping it\n",
    "            # would remove its end point from the final schedule.\n",
    "            edges.insert(0, max_edge)\n",
    "            print('Stopping, no other split reduces max.')\n",
    "            break\n",
    "\n",
    "        # Plain floats (via .item()) so the values sort, plot and print uniformly.\n",
    "        t_schedule = [t[d['start']].item() for d in edges] + [1.0]\n",
    "        t_schedule.sort()\n",
    "        plt.plot(t_schedule)\n",
    "        plt.title(str(i+1)+' points added')\n",
    "        plt.show()\n",
    "        \n",
    "t_schedule = [t[d['end']].item() for d in edges] + [0.0]\n",
    "t_schedule.sort()\n",
    "# One uniform time per schedule entry (the previous len(mask_factor) used an undefined name).\n",
    "times = np.linspace(0.0, 1.0, len(t_schedule))\n",
    "print(times, t_schedule) # Output: uniform times and their mapped value in the schedule."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ddpm",
   "language": "python",
   "name": "ddpm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
