{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "In this notebook we perform the comparison experiments with a numpy implementation of HIO which can also be used to use the  Gerchberg-Saxton and Input-Output. In the follwing we compute all three algorithms for each dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "Loaded util functions\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from dataset import *\n",
    "from train_abs import *\n",
    "from test_abs import *\n",
    "from utils import *\n",
    "from skimage.transform import rotate\n",
    "from skimage.metrics import structural_similarity, mean_squared_error, peak_signal_noise_ratio\n",
    "from skimage.util import crop\n",
    "from skimage.registration import phase_cross_correlation\n",
    "from tqdm import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded MNIST dataset: x_train(100000, 32, 32), x_valid(140000, 32, 32)\n",
      "Loaded EMNIST dataset: x_train(100000, 32, 32), x_valid(24800, 32, 32)\n",
      "Loaded FashionMNIST gray dataset: x_train(60000, 32, 32), x_valid(10000, 32, 32)\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "using Gray image\n",
      "Loaded CIFAR10 gray dataset: x_train(50000, 32, 32), x_valid(10000, 32, 32)\n",
      "Using downloaded and verified file: ./data/train_32x32.mat\n",
      "Using downloaded and verified file: ./data/test_32x32.mat\n",
      "using Gray image\n",
      "Loaded SVHN gray dataset: x_train(73257, 32, 32), x_valid(26032, 32, 32)\n"
     ]
    }
   ],
   "source": [
    "x_train_mnist, x_test_mnist = load_MNIST(32)\n",
    "x_train_emnist, x_test_emnist = load_EMNIST(32)\n",
    "x_train_fmnist, x_test_fmnist = load_FMNIST(32)\n",
    "x_train_cifar, x_test_cifar = load_CIFAR10(32, channel = -1)\n",
    "x_train_svhn, x_test_svhn = load_SVHN(32,channel = -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ATTENTION: In requires a lot of computational time to rerun this notebook, if the whole datasets are used"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_test_mnist = 10000 # as the authors use also only 10000\n",
    "n_test_emnist = len(x_test_emnist)\n",
    "n_test_fmnist = len(x_test_fmnist)\n",
    "n_test_cifar = len(x_test_cifar)\n",
    "n_test_svhn = len(x_test_svhn)\n",
    "\n",
    "x_test_mnist = x_test_mnist[:n_test_mnist]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import methods for experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import augmentation method\n",
    "np.random.seed(417)\n",
    "def image_magnitudes_oversample(image, pad):\n",
    "    image_padded = np.pad(image, pad, 'constant')\n",
    "    magnitudes_oversampled = np.abs(np.fft.fft2(image_padded))\n",
    "    mask = np.pad(np.ones_like(image), pad, 'constant')\n",
    "    return magnitudes_oversampled, mask\n",
    "\n",
    "# psnr without register\n",
    "def psnr_crop(img, gt, pad):\n",
    "    img = crop(img, pad)\n",
    "    img[img>1]=1\n",
    "    mini = min(np.min(img), np.min(gt))\n",
    "    maxi = max(np.max(img), np.max(gt))\n",
    "    dst = maxi - mini\n",
    "    psnr = peak_signal_noise_ratio(gt, img, data_range= dst)\n",
    "    return psnr\n",
    "\n",
    "# psnr with register\n",
    "def psnr_register_flip(img, gt, pad):\n",
    "    img = crop(img, pad)\n",
    "    img[img>1]=1\n",
    "    img = rotate(img, 180)\n",
    "    s,_,_ = phase_cross_correlation(gt, img)\n",
    "    shift = (int(s[0]) , int(s[1]))\n",
    "    img = np.roll(img, shift, axis=0)\n",
    "    mini = min(np.min(img), np.min(gt))\n",
    "    maxi = max(np.max(img), np.max(gt))\n",
    "    dst = maxi - mini\n",
    "    psnr = peak_signal_noise_ratio(gt, img, data_range= dst)\n",
    "    return psnr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import retrieval algorithm\n",
    "def fienup_phase_retrieval(mag, mask=None, beta=0.8, \n",
    "                           steps=200, mode='hybrid', verbose=True):\n",
    "    \"\"\"\n",
    "    Implementation of Fienup's phase-retrieval methods. This function\n",
    "    implements the input-output, the output-output and the hybrid method.\n",
    "    \n",
    "    Note: Mode 'output-output' and beta=1 results in \n",
    "    the Gerchberg-Saxton algorithm.\n",
    "    \n",
    "    Parameters:\n",
    "        mag: Measured magnitudes of Fourier transform\n",
    "        mask: Binary array indicating where the image should be\n",
    "              if padding is known\n",
    "        beta: Positive step size\n",
    "        steps: Number of iterations\n",
    "        mode: Which algorithm to use\n",
    "              (can be 'input-output', 'output-output' or 'hybrid')\n",
    "        verbose: If True, progress is shown\n",
    "    \n",
    "    Returns:\n",
    "        x: Reconstructed image\n",
    "    \n",
    "    Author: Tobias Uelwer\n",
    "    Date: 30.12.2018\n",
    "    \n",
    "    References:\n",
    "    [1] E. Osherovich, Numerical methods for phase retrieval, 2012,\n",
    "        https://arxiv.org/abs/1203.4756\n",
    "    [2] J. R. Fienup, Phase retrieval algorithms: a comparison, 1982,\n",
    "        https://www.osapublishing.org/ao/abstract.cfm?uri=ao-21-15-2758\n",
    "    [3] https://github.com/cwg45/Image-Reconstruction\n",
    "    \"\"\"\n",
    "    \n",
    "    assert beta > 0, 'step size must be a positive number'\n",
    "    assert steps > 0, 'steps must be a positive number'\n",
    "    assert mode == 'input-output' or mode == 'output-output'\\\n",
    "        or mode == 'hybrid',\\\n",
    "    'mode must be \\'input-output\\', \\'output-output\\' or \\'hybrid\\''\n",
    "    \n",
    "    if mask is None:\n",
    "        mask = np.ones(mag.shape)\n",
    "        \n",
    "    assert mag.shape == mask.shape, 'mask and mag must have same shape'\n",
    "    \n",
    "    # sample random phase and initialize image x \n",
    "    y_hat = mag*np.exp(1j*2*np.pi*np.random.rand(*mag.shape))\n",
    "    x = np.zeros(mag.shape)\n",
    "    \n",
    "    # previous iterate\n",
    "    x_p = None\n",
    "        \n",
    "    # main loop\n",
    "    for i in range(1, steps+1):\n",
    "        # show progress\n",
    "        if i % 100 == 0 and verbose: \n",
    "            print(\"step\", i, \"of\", steps)\n",
    "        \n",
    "        # inverse fourier transform\n",
    "        y = np.real(np.fft.ifft2(y_hat))\n",
    "        \n",
    "        # previous iterate\n",
    "        if x_p is None:\n",
    "            x_p = y\n",
    "        else:\n",
    "            x_p = x \n",
    "        \n",
    "        # updates for elements that satisfy object domain constraints\n",
    "        if mode == \"output-output\" or mode == \"hybrid\":\n",
    "            x = y\n",
    "            \n",
    "        # find elements that violate object domain constraints \n",
    "        # or are not masked\n",
    "        indices = np.logical_or(np.logical_and(y<0, mask), \n",
    "                                np.logical_not(mask))\n",
    "        \n",
    "        # updates for elements that violate object domain constraints\n",
    "        if mode == \"hybrid\" or mode == \"input-output\":\n",
    "            x[indices] = x_p[indices]-beta*y[indices] \n",
    "        elif mode == \"output-output\":\n",
    "            x[indices] = y[indices]-beta*y[indices] \n",
    "        \n",
    "        # fourier transform\n",
    "        x_hat = np.fft.fft2(x)\n",
    "        \n",
    "        # satisfy fourier domain constraints\n",
    "        # (replace magnitude with input magnitude)\n",
    "        y_hat = mag*np.exp(1j*np.angle(x_hat))\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_mnist.shape\n",
    "pad = height//2\n",
    "x_test_mnist_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_mnist])\n",
    "print(x_test_mnist_prep.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [09:09<00:00, 18.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.533217767885693 3.8119663142058\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_mnist\n",
    "psnr_list_max = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_mnist_prep[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_mnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_mnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max.append(psnr)\n",
    "    \n",
    "psnr_list_max = np.array(psnr_list_max)\n",
    "print(np.mean(psnr_list_max), np.std(psnr_list_max))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gerchberg-Saxton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_mnist.shape\n",
    "pad = height//2\n",
    "x_test_mnist_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_mnist])\n",
    "print(x_test_mnist_prep_u.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [08:58<00:00, 18.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.817387072364363 2.4410112407637867\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_mnist\n",
    "psnr_list_max_u = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_mnist_prep_u[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, beta=1, mask=m, mode='output-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_mnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_mnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_u.append(psnr)\n",
    "    \n",
    "psnr_list_max_u = np.array(psnr_list_max_u)\n",
    "print(np.mean(psnr_list_max_u), np.std(psnr_list_max_u))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input-Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_mnist.shape\n",
    "pad = height//2\n",
    "x_test_mnist_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_mnist])\n",
    "print(x_test_mnist_prep_io.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [08:45<00:00, 19.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.799093051755992 1.3515794111192863\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_mnist\n",
    "psnr_list_max_io = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_mnist_prep_io[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_mnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_mnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_io.append(psnr)\n",
    "    \n",
    "psnr_list_max_io = np.array(psnr_list_max_io)\n",
    "print(np.mean(psnr_list_max_io), np.std(psnr_list_max_io))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EMNSIT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(24800, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_emnist.shape\n",
    "pad = height//2\n",
    "x_test_emnist_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_emnist])\n",
    "print(x_test_emnist_prep.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 24800/24800 [22:44<00:00, 18.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.809815104812078 3.934378529537481\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_emnist\n",
    "psnr_list_max_emnist = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_emnist_prep[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_emnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_emnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_emnist.append(psnr)\n",
    "    \n",
    "psnr_list_max_emnist = np.array(psnr_list_max_emnist)\n",
    "print(np.mean(psnr_list_max_emnist), np.std(psnr_list_max_emnist))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gerchberg-Saxton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(24800, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_emnist.shape\n",
    "pad = height//2\n",
    "x_test_emnist_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_emnist])\n",
    "print(x_test_emnist_prep_u.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 24800/24800 [22:40<00:00, 18.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.985711223304026 2.4161293402906083\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_emnist\n",
    "psnr_list_max_emnist_u = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_emnist_prep_u[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100,beta=1, mask=m, mode='output-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_emnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_emnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_emnist_u.append(psnr)\n",
    "    \n",
    "psnr_list_max_emnist_u = np.array(psnr_list_max_emnist_u)\n",
    "print(np.mean(psnr_list_max_emnist_u), np.std(psnr_list_max_emnist_u))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input-Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(24800, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_emnist.shape\n",
    "pad = height//2\n",
    "x_test_emnist_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_emnist])\n",
    "print(x_test_emnist_prep_io.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 24800/24800 [22:03<00:00, 18.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.852396506159906 1.4600540026659212\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_emnist\n",
    "psnr_list_max_emnist_io = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_emnist_prep_io[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_emnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_emnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_emnist_io.append(psnr)\n",
    "    \n",
    "psnr_list_max_emnist_io = np.array(psnr_list_max_emnist_io)\n",
    "print(np.mean(psnr_list_max_emnist_io), np.std(psnr_list_max_emnist_io))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FMNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_fmnist.shape\n",
    "pad = height//2\n",
    "x_test_fmnist_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_fmnist])\n",
    "print(x_test_fmnist_prep.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [09:14<00:00, 18.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14.061361823805145 8.536005081683598\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_fmnist\n",
    "psnr_list_max_fmnist = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_fmnist_prep[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_fmnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_fmnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_fmnist.append(psnr)\n",
    "    \n",
    "psnr_list_max_fmnist = np.array(psnr_list_max_fmnist)\n",
    "print(np.mean(psnr_list_max_fmnist), np.std(psnr_list_max_fmnist))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gerchberg-Saxton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_fmnist.shape\n",
    "pad = height//2\n",
    "x_test_fmnist_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_fmnist])\n",
    "print(x_test_fmnist_prep_u.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [09:06<00:00, 18.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.245463183268203 3.632738956327884\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_fmnist\n",
    "psnr_list_max_fmnist_u = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_fmnist_prep_u[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, beta=1, mask=m, mode='output-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_fmnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_fmnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_fmnist_u.append(psnr)\n",
    "    \n",
    "psnr_list_max_fmnist = np.array(psnr_list_max_fmnist_u)\n",
    "print(np.mean(psnr_list_max_fmnist_u), np.std(psnr_list_max_fmnist_u))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input-Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_fmnist.shape\n",
    "pad = height//2\n",
    "x_test_fmnist_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_fmnist])\n",
    "print(x_test_fmnist_prep_io.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [08:56<00:00, 18.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.737355971097722 2.625021556657955\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_fmnist\n",
    "psnr_list_max_fmnist_io = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_fmnist_prep_io[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, mode='input-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_fmnist[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_fmnist[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_fmnist_io.append(psnr)\n",
    "    \n",
    "psnr_list_max_fmnist_io = np.array(psnr_list_max_fmnist_io)\n",
    "print(np.mean(psnr_list_max_fmnist_io), np.std(psnr_list_max_fmnist_io))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SVHN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(26032, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_svhn.shape\n",
    "pad = height//2\n",
    "x_test_svhn_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_svhn])\n",
    "print(x_test_svhn_prep.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 26032/26032 [23:41<00:00, 18.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31.89992718180047 16.449739338830508\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_svhn\n",
    "psnr_list_max_svhn = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_svhn_prep[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m, verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_svhn[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_svhn[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_svhn.append(psnr)\n",
    "    \n",
    "psnr_list_max_svhn = np.array(psnr_list_max_svhn)\n",
    "print(np.mean(psnr_list_max_svhn), np.std(psnr_list_max_svhn))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gerchberg-Saxton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(26032, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_svhn.shape\n",
    "pad = height//2\n",
    "x_test_svhn_prep_u = np.array([image_magnitudes_oversample(x, pad) for x in x_test_svhn])\n",
    "print(x_test_svhn_prep_u.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 26032/26032 [23:14<00:00, 18.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17.888926745678102 3.77238934951962\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_svhn\n",
    "psnr_list_max_svhn_u = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_svhn_prep_u[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100,beta=1, mask=m,mode='output-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_svhn[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_svhn[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_svhn_u.append(psnr)\n",
    "    \n",
    "psnr_list_max_svhn_u = np.array(psnr_list_max_svhn_u)\n",
    "print(np.mean(psnr_list_max_svhn_u), np.std(psnr_list_max_svhn_u))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input-Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(26032, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_svhn.shape\n",
    "pad = height//2\n",
    "x_test_svhn_prep_io = np.array([image_magnitudes_oversample(x, pad) for x in x_test_svhn])\n",
    "print(x_test_svhn_prep_io.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 26032/26032 [22:39<00:00, 19.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.680730588725579 1.8482517180509195\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(417)\n",
    "n_test = n_test_svhn\n",
    "psnr_list_max_svhn_io = []\n",
    "\n",
    "for i in tqdm(range(n_test)):\n",
    "    mag, m = x_test_svhn_prep_io[i]\n",
    "    rec = fienup_phase_retrieval(mag, steps=100, mask=m,mode='input-output', verbose=False)\n",
    "    psnr1 = psnr_crop(rec, x_test_svhn[i], pad)\n",
    "    psnr2 = psnr_register_flip(rec, x_test_svhn[i], pad)\n",
    "    psnr = max(psnr1, psnr2)\n",
    "    psnr_list_max_svhn_io.append(psnr)\n",
    "    \n",
    "psnr_list_max_svhn_io = np.array(psnr_list_max_svhn_io)\n",
    "print(np.mean(psnr_list_max_svhn_io), np.std(psnr_list_max_svhn_io))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "N, height, width = x_test_cifar.shape\n",
    "pad = height//2\n",
    "x_test_cifar_prep = np.array([image_magnitudes_oversample(x, pad) for x in x_test_cifar])\n",
    "print(x_test_cifar_prep.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [08:53<00:00, 18.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.337736982769762 13.922134641919158\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# HIO on CIFAR: fienup_phase_retrieval is called without an explicit mode,\n",
    "# so whatever default mode it defines is used.\n",
    "# The seed is reset first so the run is repeatable.\n",
    "np.random.seed(417)\n",
    "n_test = n_test_cifar\n",
    "psnr_list_max_cifar = []\n",
    "\n",
    "for idx in tqdm(range(n_test)):\n",
    "    magnitudes, support = x_test_cifar_prep[idx]\n",
    "    target = x_test_cifar[idx]\n",
    "    recon = fienup_phase_retrieval(magnitudes, mask=support, steps=100, verbose=False)\n",
    "    # Score the reconstruction with both PSNR helpers and keep the higher value.\n",
    "    score_crop = psnr_crop(recon, target, pad)\n",
    "    score_flip = psnr_register_flip(recon, target, pad)\n",
    "    psnr_list_max_cifar.append(max(score_crop, score_flip))\n",
    "\n",
    "# Report mean and standard deviation of the per-image best PSNR.\n",
    "psnr_list_max_cifar = np.array(psnr_list_max_cifar)\n",
    "print(psnr_list_max_cifar.mean(), psnr_list_max_cifar.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gerchberg-Saxton"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "# Gerchberg-Saxton input: call image_magnitudes_oversample on every CIFAR\n",
    "# test image, passing half the image height as the pad argument.\n",
    "# NOTE(review): this is the same expression that builds x_test_cifar_prep;\n",
    "# if image_magnitudes_oversample is deterministic the array could be reused - confirm.\n",
    "N, height, width = x_test_cifar.shape\n",
    "pad = height // 2\n",
    "x_test_cifar_prep_u = np.array(\n",
    "    [image_magnitudes_oversample(img, pad) for img in x_test_cifar]\n",
    ")\n",
    "print(x_test_cifar_prep_u.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [08:48<00:00, 18.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16.384498253497863 3.081373773670048\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Gerchberg-Saxton on CIFAR: fienup_phase_retrieval is run with\n",
    "# mode='output-output' and beta=1.\n",
    "# The seed is reset first so the run is repeatable.\n",
    "np.random.seed(417)\n",
    "n_test = n_test_cifar\n",
    "psnr_list_max_cifar_u = []\n",
    "\n",
    "for idx in tqdm(range(n_test)):\n",
    "    magnitudes, support = x_test_cifar_prep_u[idx]\n",
    "    target = x_test_cifar[idx]\n",
    "    recon = fienup_phase_retrieval(\n",
    "        magnitudes, mask=support, steps=100, beta=1, mode='output-output', verbose=False\n",
    "    )\n",
    "    # Score the reconstruction with both PSNR helpers and keep the higher value.\n",
    "    score_crop = psnr_crop(recon, target, pad)\n",
    "    score_flip = psnr_register_flip(recon, target, pad)\n",
    "    psnr_list_max_cifar_u.append(max(score_crop, score_flip))\n",
    "\n",
    "# Report mean and standard deviation of the per-image best PSNR.\n",
    "psnr_list_max_cifar_u = np.array(psnr_list_max_cifar_u)\n",
    "print(psnr_list_max_cifar_u.mean(), psnr_list_max_cifar_u.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input-Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 2, 64, 64)\n"
     ]
    }
   ],
   "source": [
    "# Input-Output input: call image_magnitudes_oversample on every CIFAR\n",
    "# test image, passing half the image height as the pad argument.\n",
    "# NOTE(review): this is the same expression that builds x_test_cifar_prep;\n",
    "# if image_magnitudes_oversample is deterministic the array could be reused - confirm.\n",
    "N, height, width = x_test_cifar.shape\n",
    "pad = height // 2\n",
    "x_test_cifar_prep_io = np.array(\n",
    "    [image_magnitudes_oversample(img, pad) for img in x_test_cifar]\n",
    ")\n",
    "print(x_test_cifar_prep_io.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [08:44<00:00, 19.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.794638984232505 1.7285389535306375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Input-Output on CIFAR: fienup_phase_retrieval is run with mode='input-output'.\n",
    "# The seed is reset first so the run is repeatable.\n",
    "np.random.seed(417)\n",
    "n_test = n_test_cifar\n",
    "psnr_list_max_cifar_io = []\n",
    "\n",
    "for idx in tqdm(range(n_test)):\n",
    "    magnitudes, support = x_test_cifar_prep_io[idx]\n",
    "    target = x_test_cifar[idx]\n",
    "    recon = fienup_phase_retrieval(\n",
    "        magnitudes, mask=support, steps=100, mode='input-output', verbose=False\n",
    "    )\n",
    "    # Score the reconstruction with both PSNR helpers and keep the higher value.\n",
    "    score_crop = psnr_crop(recon, target, pad)\n",
    "    score_flip = psnr_register_flip(recon, target, pad)\n",
    "    psnr_list_max_cifar_io.append(max(score_crop, score_flip))\n",
    "\n",
    "# Report mean and standard deviation of the per-image best PSNR.\n",
    "psnr_list_max_cifar_io = np.array(psnr_list_max_cifar_io)\n",
    "print(psnr_list_max_cifar_io.mean(), psnr_list_max_cifar_io.std())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
