{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5080fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import libraries\n",
    "import os, sys\n",
    "import h5py\n",
    "import zipfile\n",
    "\n",
    "import torch\n",
    "from torch.functional import F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch import nn\n",
    "from torchvision import transforms\n",
    "from torchvision import datasets\n",
    "\n",
    "from copy import deepcopy\n",
    "from tqdm.notebook import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import numpy as np\n",
    "import random\n",
    "import lpips\n",
    "from IPython.display import clear_output\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from utils.utils import (downsample, upsample)\n",
    "from utils.tb_utils import prepare_imgs_for_plotting\n",
    "from utils.utils import unfreeze, freeze\n",
    "from utils.distributions import DatasetSampler\n",
    "from utils.fid_score import (get_generated_inception_stats, get_hr_inception_stats, \n",
    "                             calculate_frechet_distance)\n",
    "\n",
    "from models.edsr_G import EDSR\n",
    "from models.upsample_plus_unet import UNet\n",
    "\n",
    "from aim19_datasets import AugDataset, TestDataset\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da0aaccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation configuration (consumed by the model, dataset and DataLoader cells below).\n",
    "BATCH_SIZE = 1 # batch size of the evaluation DataLoader\n",
    "NUM_WORKERS = 20 # worker processes of the evaluation DataLoader\n",
    "SCALE_FACTOR = 4 # passed to UNet / EDSR and to the AIM19 TestDataset\n",
    "CROP_SIZE = None # passed to TestDataset(crop_size=...) for AIM19; alternative value: 128\n",
    "DATASET =  'CelebA' # 'CelebA' or 'AIM19'\n",
    "G_ARCH = 'UNet' # 'UNet' or 'EDSR'\n",
    "UPSAMPLE = 'bilinear' # for UNet only\n",
    "BASE_FACTOR = 64 # passed to UNet(base_factor=...)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b662617b",
   "metadata": {},
   "source": [
    "## Load pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f13cdd7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number GPUs by PCI bus id and expose only device 0 to this process.\n",
    "os.environ.update({\n",
    "    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',\n",
    "    'CUDA_VISIBLE_DEVICES': '0',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebf1bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the generator selected by G_ARCH (its weights are loaded in the next cell).\n",
    "if G_ARCH == 'UNet':\n",
    "    G = UNet(3, 3, scale_factor=SCALE_FACTOR, base_factor=BASE_FACTOR, upsample=UPSAMPLE)\n",
    "elif G_ARCH == 'EDSR':\n",
    "    G = EDSR(scale_factor=SCALE_FACTOR, device='cuda').cuda()\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `G` in a later cell.\n",
    "    raise ValueError(\"Unknown G_ARCH %r; expected 'UNet' or 'EDSR'.\" % (G_ARCH,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eae110f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deserialize the checkpoint on CPU so loading does not depend on the GPU index it was saved from\n",
    "# (only device 0 is visible to this process); load_state_dict copies the values into G's parameters.\n",
    "G.load_state_dict(torch.load('path_to_state_dict', map_location='cpu'))\n",
    "\n",
    "G.cuda();\n",
    "# `freeze` is a project helper from utils.utils.\n",
    "# NOTE(review): presumably disables gradients; confirm it also puts G in eval mode.\n",
    "freeze(G);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0a92a3a",
   "metadata": {},
   "source": [
    "## Load Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af3a4baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the evaluation dataset selected by DATASET.\n",
    "if DATASET == 'CelebA':\n",
    "    # Resize to 64x64, convert to a tensor and normalize each channel with mean 0.5 / std 0.5\n",
    "    # (maps ToTensor's [0, 1] output to [-1, 1]).\n",
    "    transform = transforms.Compose([\n",
    "            transforms.Resize((64, 64)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "        ])\n",
    "    dataset = datasets.ImageFolder('path_to_test_dataset', \n",
    "                                   transform=transform)\n",
    "elif DATASET == 'AIM19':\n",
    "    # Project dataset built from an HR folder and an LR folder.\n",
    "    dataset = TestDataset(hr_dir='path_to_hr_test', \n",
    "                          lr_dir='path_to_lr_test',\n",
    "                          scale_factor=SCALE_FACTOR, crop_size=CROP_SIZE)\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `dataset` in the next cell.\n",
    "    raise ValueError(\"Unknown DATASET %r; expected 'CelebA' or 'AIM19'.\" % (DATASET,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e6b240",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed-order loader over the evaluation dataset.\n",
    "dataloader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f7750cc",
   "metadata": {},
   "source": [
    "## LPIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f48cb06b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# IMPORTANT: LPIPS inputs must be scaled to [-1, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30cc0a88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LPIPS metric with the 'alex' backbone, moved to the GPU.\n",
    "loss_fn_alex = lpips.LPIPS(net='alex').cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4593fb7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"===> Calculate LPIPS.\")\n",
    "if DATASET not in ('CelebA', 'AIM19'):\n",
    "    # Fail early instead of hitting a NameError on `out` at the final print.\n",
    "    raise ValueError(\"Unknown DATASET %r; expected 'CelebA' or 'AIM19'.\" % (DATASET,))\n",
    "\n",
    "lpips_scores = []\n",
    "# Inference only: no_grad keeps autograd from recording a graph through G and the LPIPS network.\n",
    "with torch.no_grad():\n",
    "    for X, second in tqdm(dataloader):\n",
    "        X = X.cuda()\n",
    "        if DATASET == 'CelebA':\n",
    "            # The second loader element is ignored; the LR input is produced by `downsample`.\n",
    "            Y = downsample(X).cuda()\n",
    "        else:\n",
    "            # AIM19: the second loader element is the LR input fed to G.\n",
    "            Y = second.cuda()\n",
    "        G_Y = G(Y)\n",
    "        X = torch.clamp(X, -1, 1)\n",
    "        G_Y = torch.clamp(G_Y, -1, 1)\n",
    "\n",
    "        # Collect one LPIPS distance per image so the loop works for any BATCH_SIZE.\n",
    "        lpips_scores.extend(loss_fn_alex(X, G_Y).flatten().tolist())\n",
    "        del X, Y, G_Y\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "losses = np.asarray(lpips_scores)\n",
    "out = losses.mean()\n",
    "print('mean LPIPS = %f'%out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0abb5d5e",
   "metadata": {},
   "source": [
    "## FID"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "353c8319",
   "metadata": {},
   "source": [
    "For **CelebA**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb67711d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inception statistics (mu, sigma) of the reference set, as the helper's name suggests.\n",
    "# get_hr_inception_stats is the one imported from utils.fid_score in the first cell; the AIM19 section\n",
    "# below re-imports the same function names from utils.aim19_fid_score, so once that section has run\n",
    "# these CelebA cells resolve to the AIM19 versions.\n",
    "mu1, sigma1 = get_hr_inception_stats(verbose=True, batch_size=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55057a37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the first argument is `upsample` (imported from utils.utils), not the loaded model `G`,\n",
    "# so this presumably scores the `upsample` baseline rather than G -- confirm which one is intended.\n",
    "mu2, sigma2 = get_generated_inception_stats(upsample, verbose=True, batch_size=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94b4e290",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frechet distance between the two (mu, sigma) pairs computed above; shown as the cell output.\n",
    "calculate_frechet_distance(mu1, sigma1, mu2, sigma2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "347e23d5",
   "metadata": {},
   "source": [
    "For **AIM19**:\n",
    "- Datasets of random crops for FID calculation are stored as h5 files and prepared using utils/aim19_prepare_data.py file.\n",
    "- h5dataset from aim19_h5_datasets.py is used to extract images from h5 format.\n",
    "- h5dataset outputs LR in $[-1, 1]$ and HR in $[0, 1]$, channels first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1684680",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This section is AIM19-only: the assert stops execution here for any other DATASET.\n",
    "assert DATASET == 'AIM19'\n",
    "from aim19_h5_datasets import h5dataset\n",
    "# NOTE: these three names shadow the functions imported from utils.fid_score in the first cell.\n",
    "from utils.aim19_fid_score import (get_hr_inception_stats, get_generated_inception_stats, \n",
    "                                   calculate_frechet_distance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "255d0872",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Prepare datasets for test partition.')\n",
    "try:\n",
    "    # Reuse precomputed HR Inception statistics when the .npz file can be read.\n",
    "    stats = np.load('path_to_hr_test_inception_stats in .npz format')\n",
    "    mu, sigma = stats['mu'], stats['sigma']\n",
    "except Exception:\n",
    "    # File missing, unreadable or lacking the expected keys: recompute from the HR test data.\n",
    "    # `except Exception` (not a bare `except`) so a kernel interrupt is not swallowed.\n",
    "    d = h5dataset(partition='test', mode='hr')\n",
    "    mu, sigma = get_hr_inception_stats(dataset=d, batch_size=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a37293",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AIM19-only: FID between G's outputs on the LR test data and the HR statistics (mu, sigma)\n",
    "# prepared in the previous cell.\n",
    "assert DATASET == 'AIM19'\n",
    "d = h5dataset(partition='test', mode='lr')\n",
    "m, s = get_generated_inception_stats(G=G, dataset=d, batch_size=50, verbose=True)\n",
    "fid = calculate_frechet_distance(m, s, mu, sigma)\n",
    "print('Test FID = %f'%fid)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07da11ce",
   "metadata": {},
   "source": [
    "## SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18b9dcba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.metrics import structural_similarity as compare_ssim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39af4bb5",
   "metadata": {},
   "source": [
    "**IMPORTANT:** Inputs must be in [0, 255]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4332b717",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"===> Calculate SSIM.\")\n",
    "if DATASET not in ('CelebA', 'AIM19'):\n",
    "    # Fail early instead of hitting a NameError on `losses` after the loop.\n",
    "    raise ValueError(\"Unknown DATASET %r; expected 'CelebA' or 'AIM19'.\" % (DATASET,))\n",
    "\n",
    "ssim_scores = []\n",
    "# Inference only: no_grad keeps autograd from recording a graph through G.\n",
    "with torch.no_grad():\n",
    "    for X, second in tqdm(dataloader):\n",
    "        if DATASET == 'CelebA':\n",
    "            # The second loader element is ignored; the LR input is produced by `downsample`.\n",
    "            X = X.cuda()\n",
    "            Y = downsample(X).cuda()\n",
    "        else:\n",
    "            # AIM19: the second loader element is the LR input fed to G.\n",
    "            Y = second.cuda()\n",
    "        G_Y = G(Y)\n",
    "\n",
    "        # x * 0.5 + 0.5 maps [-1, 1] to [0, 1]; clamp, then convert to uint8 [0, 255], channels last.\n",
    "        X = torch.clamp(X.mul(0.5).add(0.5), 0, 1)\n",
    "        G_Y = torch.clamp(G_Y.mul(0.5).add(0.5), 0, 1)\n",
    "        X = (X.permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype(np.uint8)\n",
    "        G_Y = (G_Y.permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype(np.uint8)\n",
    "\n",
    "        # Score every image of the batch (not only index 0) so any BATCH_SIZE works.\n",
    "        # NOTE(review): `multichannel` was deprecated in newer scikit-image releases in favour of\n",
    "        # `channel_axis`; check the installed version before upgrading.\n",
    "        for hr_img, sr_img in zip(X, G_Y):\n",
    "            ssim_scores.append(float(compare_ssim(hr_img, sr_img, multichannel=True)))\n",
    "        del X, Y, G_Y\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "losses = np.asarray(ssim_scores)\n",
    "out = losses.mean()\n",
    "print('mean SSIM = %f'%out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ab2ea71",
   "metadata": {},
   "source": [
    "## PSNR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a5dd2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from piq import psnr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aed9d73d",
   "metadata": {},
   "source": [
    "**IMPORTANT:** Inputs must be in [0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "071f878c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"===> Calculate PSNR.\")\n",
    "if DATASET not in ('CelebA', 'AIM19'):\n",
    "    # Fail early instead of hitting a NameError on `losses` after the loop.\n",
    "    raise ValueError(\"Unknown DATASET %r; expected 'CelebA' or 'AIM19'.\" % (DATASET,))\n",
    "\n",
    "psnr_scores = []\n",
    "# Inference only: no_grad keeps autograd from recording a graph through G.\n",
    "with torch.no_grad():\n",
    "    for X, second in tqdm(dataloader):\n",
    "        X = X.cuda()\n",
    "        if DATASET == 'CelebA':\n",
    "            # The second loader element is ignored; the LR input is produced by `downsample`.\n",
    "            Y = downsample(X).cuda()\n",
    "        else:\n",
    "            # AIM19: the second loader element is the LR input fed to G.\n",
    "            Y = second.cuda()\n",
    "        G_Y = G(Y)\n",
    "\n",
    "        # x * 0.5 + 0.5 maps [-1, 1] to [0, 1], the range required for PSNR here.\n",
    "        X = torch.clamp(X.mul(0.5).add(0.5), 0, 1)\n",
    "        G_Y = torch.clamp(G_Y.mul(0.5).add(0.5), 0, 1)\n",
    "\n",
    "        # reduction='none' gives one PSNR per image, so both datasets and any BATCH_SIZE are handled alike.\n",
    "        psnr_scores.extend(psnr(X, G_Y, reduction='none').flatten().tolist())\n",
    "        del X, Y, G_Y\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "losses = np.asarray(psnr_scores)\n",
    "out = losses.mean()\n",
    "print('mean PSNR = %f'%out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87cf8411",
   "metadata": {},
   "source": [
    "# AIM19 Color Palettes (visualization & variance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b3b1920",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cycle(iterable):\n",
    "    \"\"\"Yield the items of `iterable` endlessly, re-iterating it from the start on every pass.\n",
    "\n",
    "    Nothing is cached: each pass calls iter() on `iterable` again, so an object whose\n",
    "    iteration changes between passes is re-read each time.\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        yield from iterable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80691239",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmented crop datasets for the palette plots (AugDataset is a project class from aim19_datasets):\n",
    "# crop_size 128 for the HR folder and 32 for the LR folder, with flips and rotations enabled.\n",
    "# NOTE(review): 128 / 32 = 4, presumably chosen to match SCALE_FACTOR -- confirm.\n",
    "HR_dataset = AugDataset(datadir='path_to_aim19_hr_test', crop_size=128, \n",
    "                        flips=True, rotations=True)\n",
    "LR_dataset = AugDataset(datadir='path_to_aim19_lr_test', crop_size=32, \n",
    "                        flips=True, rotations=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa409406",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed-order loaders over the HR / LR crop datasets, 100 items per batch.\n",
    "X_dataloader = DataLoader(\n",
    "    HR_dataset, batch_size=100, num_workers=20, shuffle=False,\n",
    ")\n",
    "Y_dataloader = DataLoader(\n",
    "    LR_dataset, batch_size=100, num_workers=20, shuffle=False,\n",
    ")\n",
    "# Endless batch streams; cycle() returns a generator, which is already its own iterator.\n",
    "X_iter = cycle(X_dataloader)\n",
    "Y_iter = cycle(Y_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0feed933",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_rgb_cloud(cloud, ax):\n",
    "    \"\"\"Scatter the rows of `cloud` as points in RGB space, each drawn in its own colour.\n",
    "\n",
    "    cloud: array indexed as cloud[:, 0..2]; columns are used as red, green and blue coordinates.\n",
    "    ax:    axes object providing the 3D methods used here (set_zticks, set_zlabel, scatter, ...).\n",
    "    \"\"\"\n",
    "    # Point colours must lie in [0, 1]; the coordinates themselves are plotted unclipped.\n",
    "    point_colors = np.clip(cloud, 0, 1)\n",
    "    for clear_ticks in (ax.set_yticks, ax.set_xticks, ax.set_zticks):\n",
    "        clear_ticks([])\n",
    "    red, green, blue = cloud[:, 0], cloud[:, 1], cloud[:, 2]\n",
    "    ax.scatter(red, green, blue, c=point_colors)\n",
    "    ax.set_xlabel('Red', labelpad=-10)\n",
    "    ax.set_ylabel('Green', labelpad=-10)\n",
    "    ax.set_zlabel('Blue', labelpad=-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a953ae96",
   "metadata": {},
   "outputs": [],
   "source": [
    "SIZE = 128*8  # RGB points sampled per repetition\n",
    "s = 100       # number of repetitions\n",
    "CHUNK = 20    # images sent through the generator per forward pass\n",
    "pc_var_OTS = np.zeros((s))\n",
    "\n",
    "# Build and load the EDSR generator once: it is identical for every repetition.\n",
    "# A separate name is used so the model evaluated by the metric cells (`G`) is not overwritten.\n",
    "G_palette = EDSR(scale_factor=SCALE_FACTOR, device='cuda').cuda()\n",
    "G_palette.load_state_dict(torch.load('path_to_state_dict'))\n",
    "freeze(G_palette);\n",
    "\n",
    "for k in tqdm(range(s)):\n",
    "\n",
    "    fig = plt.figure(figsize=(4, 4), dpi=100)\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "    # Take the next batch from the endless iterator defined above, instead of creating a new\n",
    "    # DataLoader iterator (and its worker processes) on every repetition.\n",
    "    Y = next(Y_iter)\n",
    "\n",
    "    pushed_chunks = []\n",
    "    # Inference only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, Y.shape[0], CHUNK):\n",
    "            Y_chunk = torch.as_tensor(Y[start:start + CHUNK], device='cuda', dtype=torch.float32)\n",
    "            # (x + 1) / 2 maps [-1, 1] to [0, 1]; then reshape (N, C, H, W) -> (N*H*W, C),\n",
    "            # i.e. one row per output pixel.\n",
    "            Y_push = G_palette(Y_chunk).add(1).div(2).permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2)\n",
    "            pushed_chunks.append(Y_push.cpu())\n",
    "            del Y_chunk, Y_push\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # Stack the pixel rows of all chunks along dim 0. (Concatenating along dim 1 would append the\n",
    "    # later chunks as extra columns that are never read, leaving only the first chunk in use.)\n",
    "    Y_pushed = torch.cat(pushed_chunks, dim=0).numpy()\n",
    "    # Sample whole rows so that every plotted point is the actual RGB colour of one output pixel;\n",
    "    # sampling each channel with independent indices would mix channels of different pixels.\n",
    "    idx = np.random.choice(Y_pushed.shape[0], size=SIZE)\n",
    "    Y_pushed = Y_pushed[idx]\n",
    "    # Palette variance: sum of the per-channel variances of the sampled points.\n",
    "    pc_var_OTS[k] = np.sum(np.var(Y_pushed, axis=0))\n",
    "    plot_rgb_cloud(Y_pushed, ax)\n",
    "    ax.set_xlim(0, 1); ax.set_ylim(0, 1); ax.set_zlim(0, 1); ax.title.set_text('OTS (ours)')\n",
    "\n",
    "    # Replace the previous repetition's figure in the cell output.\n",
    "    clear_output(wait=True)\n",
    "    plt.show(); plt.close(fig)\n",
    "\n",
    "del G_palette\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ade0628",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of the per-repetition palette variances stored in pc_var_OTS.\n",
    "print('Variance of OTS (ours) color palette = %.2f +- %.2f'%(pc_var_OTS.mean(), pc_var_OTS.std()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
