{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "10886849-81af-49c7-adce-895290141663",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Intro\n",
    "\n",
    "This notebook will create the semi-synthetic Grassy MNIST dataset used in the experiments for \"Feature Selection in the Contrastive Analysis Setting\". Much of this code was adapted from that of \"Exploring Patterns Enriched in a Dataset with Contrastive Principal Component Analysis\", Nature Communications (2018) (https://github.com/abidlabs/contrastive).\n",
    "\n",
    "The notebook will automatically download any necessary files."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d99842b-54c5-4f17-8a21-925159010af2",
   "metadata": {},
   "source": [
    "### First, we download the images of grass from ImageNet\n",
    "\n",
    "There are a lot of links to query here, so this step can take a while (~15-20 minutes in my case.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bb9335b-f2d5-42ca-b6d9-e3403b25f045",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▊            | 1143/1275 [40:24<01:25,  1.55it/s]"
     ]
    }
   ],
   "source": [
    "import hashlib\n",
    "import os\n",
    "import urllib\n",
    "import urllib.request\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import requests\n",
    "from tqdm import tqdm\n",
    "\n",
    "def url_to_image(url, timeout=30):\n",
    "    \"\"\"Download an image and decode it into an OpenCV array.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    url : str\n",
    "        Address of the image to fetch.\n",
    "    timeout : float, optional\n",
    "        Seconds to wait for the server before giving up, so that a single\n",
    "        unresponsive host cannot stall a long download loop.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray or None\n",
    "        Result of ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``; ``None`` when the\n",
    "        downloaded bytes are empty or cannot be decoded as an image.\n",
    "    \"\"\"\n",
    "    # `with` closes the HTTP response even if reading fails part-way.\n",
    "    with urllib.request.urlopen(url, timeout=timeout) as resp:\n",
    "        encoded = np.frombuffer(resp.read(), dtype=np.uint8)\n",
    "\n",
    "    if encoded.size == 0:\n",
    "        # Nothing to decode; report it the same way as an undecodable payload.\n",
    "        return None\n",
    "\n",
    "    return cv2.imdecode(encoded, cv2.IMREAD_COLOR)\n",
    "\n",
    "\n",
    "# URL that contains a list of links to images identified by the ImageNet team as\n",
    "# containing grass\n",
    "grass_url = \"https://image-net.org/api/imagenet.synset.geturls?wnid=n12102133\"\n",
    "page = requests.get(grass_url)\n",
    "# Stop here if the listing itself could not be fetched; otherwise an error page\n",
    "# would be parsed as if it were a list of image links.\n",
    "page.raise_for_status()\n",
    "# splitlines() handles CRLF as well as LF separators; blank entries are dropped.\n",
    "img_url_list = [line.strip() for line in page.content.decode('UTF-8').splitlines() if line.strip()]\n",
    "\n",
    "\n",
    "# Create a new folder to store our grassy images\n",
    "grassy_img_path = \"./grassy_imgs\"\n",
    "os.makedirs(grassy_img_path, exist_ok=True)\n",
    "\n",
    "\n",
    "n_saved = 0\n",
    "for i, img_url in enumerate(tqdm(img_url_list)):\n",
    "    # Unfortunately some of the links in our list are broken (ImageNet has been around a long time!)\n",
    "    # so we'll need to watch out for any exceptions when attempting to download them.\n",
    "    try:\n",
    "        I = url_to_image(img_url)\n",
    "        # A link may answer with something that is not an image (e.g. an HTML page);\n",
    "        # the decoder then yields None and there is nothing to save.\n",
    "        if I is not None:\n",
    "            # Name each file after the position of its link in the list\n",
    "            save_path = os.path.join(grassy_img_path, \"{}.jpg\".format(i))\n",
    "            cv2.imwrite(save_path, I)\n",
    "            n_saved += 1\n",
    "    except Exception:\n",
    "        # Best-effort download: skip the broken link. Catching Exception (rather than\n",
    "        # using a bare except) still lets KeyboardInterrupt stop this long loop.\n",
    "        continue\n",
    "\n",
    "print(\"Saved {} of {} listed images\".format(n_saved, len(img_url_list)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd79498e-d6b3-4472-8f20-7e6bde579a3b",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Next, we load our downloaded images into memory and display some as a sanity check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "966ba776-f24a-48b0-9aef-5d4a118305a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "def resize_and_crop(img, size=(100,100), crop_type='middle'):\n",
    "    \"\"\"Resize a PIL image so it covers ``size``, then crop away the overflow.\n",
    "\n",
    "    The aspect ratio is preserved while resizing: one side ends up exactly at\n",
    "    the requested length and the other at least as long; the surplus along that\n",
    "    other side is then cropped.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    img : PIL.Image.Image\n",
    "        Image to transform.\n",
    "    size : tuple of int\n",
    "        Requested ``(width, height)`` in pixels.\n",
    "    crop_type : {'top', 'middle', 'bottom'}\n",
    "        Part of the image kept along the overflowing axis: its start, its\n",
    "        centre or its end.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    PIL.Image.Image\n",
    "        Image of the requested size.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``crop_type`` is not one of the supported values.\n",
    "    \"\"\"\n",
    "    # Validate up front so a bad value is reported even when no crop is needed.\n",
    "    if crop_type not in ('top', 'middle', 'bottom'):\n",
    "        raise ValueError('ERROR: invalid value for crop_type')\n",
    "\n",
    "    def start_offset(excess):\n",
    "        # First kept pixel along an axis that overflows by `excess` pixels.\n",
    "        if crop_type == 'top':\n",
    "            return 0\n",
    "        if crop_type == 'middle':\n",
    "            return int(round(excess / 2))\n",
    "        return excess  # 'bottom'\n",
    "\n",
    "    target_w, target_h = size\n",
    "    src_w, src_h = img.size\n",
    "    img_ratio = src_w / float(src_h)\n",
    "    ratio = target_w / float(target_h)\n",
    "    # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in Pillow >= 10.\n",
    "    resample = Image.LANCZOS\n",
    "\n",
    "    if ratio > img_ratio:\n",
    "        # Source is relatively taller than the target: match the width, trim rows.\n",
    "        img = img.resize((target_w, int(round(target_w * src_h / src_w))), resample)\n",
    "        first_row = start_offset(img.size[1] - target_h)\n",
    "        # The end is derived from the start so the crop is always target_h rows tall.\n",
    "        img = img.crop((0, first_row, img.size[0], first_row + target_h))\n",
    "    elif ratio < img_ratio:\n",
    "        # Source is relatively wider than the target: match the height, trim columns.\n",
    "        img = img.resize((int(round(target_h * src_w / src_h)), target_h), resample)\n",
    "        first_col = start_offset(img.size[0] - target_w)\n",
    "        img = img.crop((first_col, 0, first_col + target_w, img.size[1]))\n",
    "    else:\n",
    "        # Same aspect ratio: a plain resize is enough, no crop required.\n",
    "        img = img.resize((target_w, target_h), resample)\n",
    "    return img\n",
    "\n",
    "\n",
    "\n",
    "natural_images = list() # List of pictures; each entry is a 100x100 grayscale image flattened to 10000 values\n",
    "# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting makes the\n",
    "# row order of `natural_images` (and therefore the seeded split made later) reproducible.\n",
    "for filename in tqdm(sorted(os.listdir(grassy_img_path))):\n",
    "    if filename.lower().endswith((\".jpeg\", \".jpg\")):\n",
    "        # The context manager releases the file handle once the pixels are converted.\n",
    "        with Image.open(os.path.join(grassy_img_path, filename)) as im:\n",
    "            gray = im.convert(mode=\"L\") #convert to grayscale\n",
    "        gray = resize_and_crop(gray) #resize and crop each picture to be 100px by 100px\n",
    "        natural_images.append(np.asarray(gray).reshape(-1))\n",
    "\n",
    "if not natural_images:\n",
    "    raise RuntimeError(\"No grass images found in {}; did the download step succeed?\".format(grassy_img_path))\n",
    "\n",
    "natural_images=np.asarray(natural_images,dtype=float)\n",
    "natural_images/=255 #rescale to be 0-1\n",
    "print(\"Array of grass images:\",natural_images.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a2e81f4-db75-493d-b106-ad2a499f7607",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: un-flatten the first grass picture back to 100x100 and display it\n",
    "first_grass_image = natural_images[0].reshape(100, 100)\n",
    "plt.imshow(first_grass_image, cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c7bb04e-f7df-4189-8243-a57ef0f2a98b",
   "metadata": {},
   "source": [
    "### Next, we download the MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e8f7e4a-adb0-45fe-aa69-bc49985e3ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_folder = \"./raw/\"\n",
    "os.makedirs(raw_folder, exist_ok=True)\n",
    "\n",
    "mirrors = [\n",
    "    'http://yann.lecun.com/exdb/mnist/',\n",
    "    'https://ossci-datasets.s3.amazonaws.com/mnist/',\n",
    "]\n",
    "\n",
    "# (file name, expected MD5 hex digest of the compressed file)\n",
    "resources = [\n",
    "    (\"train-images-idx3-ubyte.gz\", \"f68b3c2dcbeaaa9fbdd348bbdeb94873\"),\n",
    "    (\"train-labels-idx1-ubyte.gz\", \"d53e105ee54ea40749a09fcbcd1e9432\"),\n",
    "    (\"t10k-images-idx3-ubyte.gz\", \"9fb629c4189551a2d022fa330f9573f3\"),\n",
    "    (\"t10k-labels-idx1-ubyte.gz\", \"ec29112dd5afa0611ce80d1b7f02629c\")\n",
    "]\n",
    "\n",
    "# download files, falling back to the next mirror whenever one fails\n",
    "for filename, md5 in resources:\n",
    "    destination = os.path.join(raw_folder, filename)\n",
    "    for mirror in mirrors:\n",
    "        url = \"{}{}\".format(mirror, filename)\n",
    "        print(\"Downloading {}\".format(url))\n",
    "        try:\n",
    "            r = requests.get(url, timeout=60)\n",
    "            # requests does not raise on 4xx/5xx by itself; without this an HTML\n",
    "            # error page would be written to disk as if it were the archive.\n",
    "            r.raise_for_status()\n",
    "        except requests.exceptions.RequestException as error:\n",
    "            print(\n",
    "                \"Failed to download (trying next):\\n{}\".format(error)\n",
    "            )\n",
    "            print()\n",
    "            continue\n",
    "\n",
    "        # Reject corrupted or substituted payloads before anything is written.\n",
    "        if hashlib.md5(r.content).hexdigest() != md5:\n",
    "            print(\"Checksum mismatch (trying next): {}\".format(url))\n",
    "            print()\n",
    "            continue\n",
    "\n",
    "        with open(destination, 'wb') as f:\n",
    "            f.write(r.content)\n",
    "        print()\n",
    "        break\n",
    "    else:\n",
    "        # Every mirror failed for this file\n",
    "        raise RuntimeError(\"Error downloading {}\".format(filename))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3e3f538-5c68-40dd-b468-1aefa54680fc",
   "metadata": {},
   "source": [
    "### Load MNIST into memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c10a2ae-5232-4ae1-82aa-b183f8d89b72",
   "metadata": {},
   "outputs": [],
   "source": [
    "import struct\n",
    "import gzip\n",
    "\n",
    "def _load_uint8(f):\n",
    "    \"\"\"Parse an IDX stream of unsigned bytes into a NumPy array.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    f : binary file-like object\n",
    "        Stream positioned at the first byte of the IDX header.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        ``uint8`` array shaped according to the dimensions in the header.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the header declares a data type other than unsigned byte, or the\n",
    "        stream ends before all announced values were read.\n",
    "    \"\"\"\n",
    "    # The 4-byte magic number ends with the data-type code and the number of dimensions.\n",
    "    idx_dtype, ndim = struct.unpack('BBBB', f.read(4))[2:]\n",
    "    if idx_dtype != 0x08:\n",
    "        # Only one byte per value is read below, so any other type would be misread silently.\n",
    "        raise ValueError(\"Unsupported IDX data type 0x{:02x}; expected 0x08 (unsigned byte)\".format(idx_dtype))\n",
    "\n",
    "    # Each dimension is stored as a big-endian unsigned 32-bit integer.\n",
    "    shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))\n",
    "    buffer_length = int(np.prod(shape))\n",
    "    payload = f.read(buffer_length)\n",
    "    if len(payload) != buffer_length:\n",
    "        raise ValueError(\"Truncated IDX data: expected {} bytes, got {}\".format(buffer_length, len(payload)))\n",
    "    return np.frombuffer(payload, dtype=np.uint8).reshape(shape)\n",
    "\n",
    "def load_idx(path: str) -> np.ndarray:\n",
    "    \"\"\"Load an IDX file from disk into a NumPy array.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    path : str\n",
    "        Location of the file to read. A path ending in '.gz' is opened\n",
    "        through `gzip`; any other path is opened as a plain binary file.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        The decoded array, of dtype ``uint8``.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    http://yann.lecun.com/exdb/mnist/\n",
    "    \"\"\"\n",
    "    # Pick the opener from the file extension, then hand the stream to the parser.\n",
    "    if path.endswith('.gz'):\n",
    "        stream = gzip.open(path, 'rb')\n",
    "    else:\n",
    "        stream = open(path, 'rb')\n",
    "\n",
    "    with stream as f:\n",
    "        return _load_uint8(f)\n",
    "    \n",
    "# Build the paths with os.path.join so they stay valid whether or not\n",
    "# `raw_folder` ends with a path separator.\n",
    "data = load_idx(os.path.join(raw_folder, \"train-images-idx3-ubyte.gz\")).astype(\"float32\")\n",
    "data = data / 255  # rescale pixel values to 0-1\n",
    "\n",
    "labels = load_idx(os.path.join(raw_folder, \"train-labels-idx1-ubyte.gz\")).astype(\"long\")\n",
    "\n",
    "# Sanity check: print the first label and show the matching digit\n",
    "print(labels[0])\n",
    "plt.imshow(data[0], cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61ce0eaf-0dfe-47f1-b03d-f56571ea64e6",
   "metadata": {},
   "source": [
    "### Now we create Grassy MNIST by superimposing the MNIST digits onto our images of grass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81e121c-e8d3-41e6-93c8-6b3f3bdca9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0) # for reproducibility\n",
    "\n",
    "# MNIST digits to superimpose, one flattened 28*28 = 784-pixel row per digit, and\n",
    "# their labels. These names were used below (and when saving) but never defined.\n",
    "# NOTE(review): assumes every training digit is used as foreground - confirm.\n",
    "foreground = data.reshape(data.shape[0], -1)\n",
    "foreground_labels = labels\n",
    "\n",
    "rand_indices =  np.random.permutation(natural_images.shape[0]) # just shuffles the indices\n",
    "split = len(rand_indices) // 2\n",
    "target_indices = rand_indices[0:split] # choose the first half of images to be superimposed on target\n",
    "background_indices = rand_indices[split:] # choose the second half of images to be background dataset\n",
    "\n",
    "\n",
    "def random_grass_patch(candidate_indices):\n",
    "    \"\"\"Return a random 28x28 crop, flattened to 784 values, of a random grass image.\n",
    "\n",
    "    The image is drawn from the rows of `natural_images` listed in\n",
    "    `candidate_indices`. The image index is drawn first and the crop corner\n",
    "    second, matching the order of random draws of the original inline code.\n",
    "    \"\"\"\n",
    "    idx = np.random.choice(candidate_indices) # randomly pick a image\n",
    "    loc = np.random.randint(70,size=(2)) # randomly pick a region in the image\n",
    "    grass = np.reshape(natural_images[idx,:],[100,100])\n",
    "    return grass[loc[0]:loc[0]+28, loc[1]:loc[1]+28].reshape(-1)\n",
    "\n",
    "\n",
    "target = np.zeros(foreground.shape)\n",
    "background = np.zeros(foreground.shape)\n",
    "\n",
    "for i in range(target.shape[0]):\n",
    "    # Target sample: a half-intensity digit on top of grass from the first half\n",
    "    target[i] = 0.5*foreground[i] + random_grass_patch(target_indices)\n",
    "\n",
    "    # Background sample: grass only, taken from the other half of the images\n",
    "    background[i] = random_grass_patch(background_indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cddfcd69-abf1-4f57-a7e6-5ae875194fec",
   "metadata": {},
   "source": [
    "### We visualize some samples from our dataset as a sanity check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "686133bf-8b9d-467b-98a6-69db191394ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_show=8\n",
    "\n",
    "# One row of the first `n_show` samples per dataset. The former per-sample\n",
    "# `idx = np.random.randint(5000)` draws were never used and have been removed.\n",
    "for title, samples in ((\"Target samples\", target), (\"Background samples\", background)):\n",
    "    # squeeze=False keeps `axes` two-dimensional even when n_show == 1\n",
    "    fig, axes = plt.subplots(1, n_show, figsize=(12, 2), squeeze=False)\n",
    "    for i, ax in enumerate(axes[0]):\n",
    "        ax.imshow(np.reshape(samples[i,:],[28,28]),cmap='gray', interpolation='bicubic')\n",
    "        ax.axis('off')\n",
    "    fig.suptitle(title)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cfea1ac-f4c0-42ef-bf26-8957e0f8db84",
   "metadata": {},
   "source": [
    "### Finally, we save the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e997e8f7-0da1-4a58-9c93-2200f5288b77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the output folder exists (no error if it is already there)\n",
    "output_dir = 'data/Grassy_MNIST'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Write each array to its own .npy file: grass-only samples, digit-on-grass\n",
    "# samples, and the labels that go with the digit-on-grass samples.\n",
    "np.save(\"data/Grassy_MNIST/background.npy\", arr=background)\n",
    "np.save(\"data/Grassy_MNIST/target.npy\", arr=target)\n",
    "np.save(\"data/Grassy_MNIST/target_labels.npy\", arr=foreground_labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
