{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "\n",
    "def to_float(arr, data_type):\n",
    "    \"\"\"Scale raw pixel values into the [0, 1] range.\n",
    "\n",
    "    \"int8\" data is divided by 255 and \"int16\" data by 4096, then clipped\n",
    "    to [0, 1]. Any other data_type raises ValueError.\n",
    "    \"\"\"\n",
    "    for type_name, divisor in ((\"int8\", 255.0), (\"int16\", 4096.0)):\n",
    "        if data_type == type_name:\n",
    "            return np.clip(arr / divisor, 0.0, 1.0)\n",
    "    raise ValueError(\"Select an appropriate data type.\")\n",
    "\n",
    "def handle_labels(arr, key_txt):\n",
    "    \"\"\"Remap label values according to a translation table.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    arr : np.ndarray\n",
    "        Label values to translate. The array passed in is not modified.\n",
    "    key_txt : str or path-like\n",
    "        Text file readable by np.loadtxt whose rows are\n",
    "        (source label, destination label) pairs.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        A new array where every element equal to a source label is replaced\n",
    "        by its destination label. All rows are matched against the ORIGINAL\n",
    "        values, so chained rows such as (1, 2) and (2, 3) map 1 -> 2 and\n",
    "        2 -> 3 instead of collapsing 1 -> 3.\n",
    "    \"\"\"\n",
    "    # ndmin=2 keeps a one-row key file iterable as (src, dst) pairs.\n",
    "    key_array = np.loadtxt(key_txt, ndmin=2)\n",
    "    src_arr = np.asarray(arr)\n",
    "    # Translate into a copy so the caller's array (e.g. a view into a larger\n",
    "    # patch) is left untouched.\n",
    "    trans_arr = src_arr.copy()\n",
    "\n",
    "    for src_l, dst_l in key_array:\n",
    "        if src_l != dst_l:\n",
    "            # Mask from the untouched input so earlier rows cannot feed later ones.\n",
    "            trans_arr[src_arr == src_l] = dst_l\n",
    "\n",
    "    return trans_arr\n",
    "\n",
    "def read_a_patch(fn, input_size=240, hr_label_index=8,\n",
    "                 hr_label_key=\"data/cheaseapeake_to_hr_labels.txt\"):\n",
    "    \"\"\"Load one patch and return its image bands plus the translated label band.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    fn : str or path-like\n",
    "        A ``.npz`` file (array stored under \"arr_0\") or a ``.npy`` file. After\n",
    "        squeezing, the array is indexed as (band, row, col) with a square\n",
    "        spatial extent.\n",
    "    input_size : int, default 240\n",
    "        Side length of the returned patch; larger patches are randomly cropped.\n",
    "    hr_label_index : int, default 8\n",
    "        Band holding the labels that are remapped with ``handle_labels``.\n",
    "    hr_label_key : str, default \"data/cheaseapeake_to_hr_labels.txt\"\n",
    "        Translation table for ``handle_labels`` (resolved relative to the\n",
    "        current working directory).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        float32 array of shape (5, input_size, input_size): bands 0-3 scaled\n",
    "        with ``to_float(..., \"int8\")`` followed by the translated label band.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the patch is not square or is smaller than ``input_size``.\n",
    "    \"\"\"\n",
    "    fn = str(fn)\n",
    "    if fn.endswith(\".npz\"):\n",
    "        # The context manager closes the archive even if reading fails.\n",
    "        with np.load(fn) as dl:\n",
    "            data = dl[\"arr_0\"].squeeze()\n",
    "    else:\n",
    "        data = np.load(fn).squeeze()\n",
    "\n",
    "    data_size = data.shape[1]\n",
    "    if data.shape[2] != data_size:\n",
    "        raise ValueError(f\"Patch {fn} is not square: spatial shape {data.shape[1:]}\")\n",
    "    if data_size < input_size:\n",
    "        raise ValueError(f\"Patch {fn} is {data_size} px wide, smaller than input_size={input_size}\")\n",
    "\n",
    "    # Random crop when the stored patch is larger than input_size.\n",
    "    if data_size > input_size:\n",
    "        # randint's upper bound is exclusive; +1 makes the last valid offset reachable.\n",
    "        max_offset = data_size - input_size + 1\n",
    "        x_idx = np.random.randint(0, max_offset)\n",
    "        y_idx = np.random.randint(0, max_offset)\n",
    "        data = data[:, y_idx: y_idx + input_size, x_idx: x_idx + input_size]\n",
    "\n",
    "    x = to_float(data[0:4, :, :], \"int8\").astype(np.float32)\n",
    "    y_hr = handle_labels(data[hr_label_index:hr_label_index+1, :, :], hr_label_key).astype(np.float32)\n",
    "\n",
    "    return np.concatenate((x, y_hr), axis=0)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Processed  1000 / 50000 patches.\n",
      "Processed  2000 / 50000 patches.\n",
      "Processed  3000 / 50000 patches.\n",
      "Processed  4000 / 50000 patches.\n",
      "Processed  5000 / 50000 patches.\n",
      "Processed  6000 / 50000 patches.\n",
      "Processed  7000 / 50000 patches.\n",
      "Processed  8000 / 50000 patches.\n",
      "Processed  9000 / 50000 patches.\n",
      "Processed 10000 / 50000 patches.\n",
      "Processed 11000 / 50000 patches.\n",
      "Processed 12000 / 50000 patches.\n",
      "Processed 13000 / 50000 patches.\n",
      "Processed 14000 / 50000 patches.\n",
      "Processed 15000 / 50000 patches.\n",
      "Processed 16000 / 50000 patches.\n",
      "Processed 17000 / 50000 patches.\n",
      "Processed 18000 / 50000 patches.\n",
      "Processed 19000 / 50000 patches.\n",
      "Processed 20000 / 50000 patches.\n",
      "Processed 21000 / 50000 patches.\n",
      "Processed 22000 / 50000 patches.\n",
      "Processed 23000 / 50000 patches.\n",
      "Processed 24000 / 50000 patches.\n",
      "Processed 25000 / 50000 patches.\n",
      "Processed 26000 / 50000 patches.\n",
      "Processed 27000 / 50000 patches.\n",
      "Processed 28000 / 50000 patches.\n",
      "Processed 29000 / 50000 patches.\n",
      "Processed 30000 / 50000 patches.\n",
      "Processed 31000 / 50000 patches.\n",
      "Processed 32000 / 50000 patches.\n",
      "Processed 33000 / 50000 patches.\n",
      "Processed 34000 / 50000 patches.\n",
      "Processed 35000 / 50000 patches.\n",
      "Processed 36000 / 50000 patches.\n",
      "Processed 37000 / 50000 patches.\n",
      "Processed 38000 / 50000 patches.\n",
      "Processed 39000 / 50000 patches.\n",
      "Processed 40000 / 50000 patches.\n",
      "Processed 41000 / 50000 patches.\n",
      "Processed 42000 / 50000 patches.\n",
      "Processed 43000 / 50000 patches.\n",
      "Processed 44000 / 50000 patches.\n",
      "Processed 45000 / 50000 patches.\n",
      "Processed 46000 / 50000 patches.\n",
      "Processed 47000 / 50000 patches.\n",
      "Processed 48000 / 50000 patches.\n",
      "Processed 49000 / 50000 patches.\n",
      "Processed 50000 / 50000 patches.\n",
      "Processed  1000 /  2500 patches.\n",
      "Processed  2000 /  2500 patches.\n"
     ]
    }
   ],
   "source": [
    "state = 'md_1m_2013'\n",
    "# Fixed seed so the random crops in read_a_patch are reproducible.\n",
    "SEED = 0\n",
    "# Per-patch layout produced by read_a_patch: 4 image bands + 1 label band, 240 x 240.\n",
    "PATCH_SHAPE = (5, 240, 240)\n",
    "\n",
    "dataset_dir = Path(\"/scratch/forest/datasets/chesapeake_data\")\n",
    "output_dir = Path(\"/scratch/forest/datasets/chesapeake_data_hdf5\")\n",
    "output_dir.mkdir(exist_ok=True)\n",
    "\n",
    "\n",
    "def convert_split(dataset_type):\n",
    "    \"\"\"Pack every patch listed in one split's CSV into a single HDF5 file.\n",
    "\n",
    "    Reads the \"patch_fn\" column of {state}_extended-{dataset_type}_patches.csv\n",
    "    in dataset_dir, loads each patch with read_a_patch and writes them, in CSV\n",
    "    order, to the \"dataset\" entry of output_dir/{state}_{dataset_type}.hdf5.\n",
    "    \"\"\"\n",
    "    csv_path = dataset_dir / f\"{state}_extended-{dataset_type}_patches.csv\"\n",
    "    patches = [str(dataset_dir / fn) for fn in pd.read_csv(str(csv_path))[\"patch_fn\"].values]\n",
    "    n_patch = len(patches)\n",
    "    hdf5_path = output_dir / f\"{state}_{dataset_type}.hdf5\"\n",
    "    with h5py.File(str(hdf5_path), \"w\") as f:\n",
    "        dset = f.create_dataset(\"dataset\", (n_patch, *PATCH_SHAPE), dtype=np.float32)\n",
    "        for idx, fn in enumerate(patches):\n",
    "            dset[idx, :, :, :] = read_a_patch(fn)\n",
    "            if (idx + 1) % 1000 == 0:\n",
    "                print(f\"Processed {idx+1:5d} / {n_patch:5d} patches.\")\n",
    "\n",
    "\n",
    "np.random.seed(SEED)\n",
    "for dataset_type in ('train', 'val'):\n",
    "    convert_split(dataset_type)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-open the training split read-only for inspection. The handle is left open\n",
    "# so `dset` stays usable in later cells; call f.close() when finished.\n",
    "f = h5py.File(str(output_dir / f\"{state}_train.hdf5\"), 'r')\n",
    "dset = f[\"dataset\"]\n",
    "dset.shape, dset.dtype\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.0 64-bit ('forest-segmentation': conda)",
   "metadata": {
    "interpreter": {
     "hash": "ea8ccbf35921feb25ac6076a977d10ed228a087a1890f38b57c754fe65c733fe"
    }
   }
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}