{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Stare.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lz4Hj26nCn8R",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Show the attached GPU(s); the notebook metadata requests a GPU accelerator.\n",
        "! nvidia-smi"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3Bj1aK7kFBy9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import glob\n",
        "import math\n",
        "import numpy as np\n",
        "import tensorflow as tf"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s5CsOJmvCsyd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from src import dataset\n",
        "from src import preprocess\n",
        "from src import generator\n",
        "from src import evaluate\n",
        "from src import sprunet\n",
        "from src import metrics\n",
        "from src import losses"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ws_cFnkeFuSj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def shuffle(img_files, msk_files, seed):\n",
        "  \"\"\"Shuffle two parallel file lists with one shared permutation.\n",
        "\n",
        "  Ten seeded shuffles (seeds 0, seed, 2*seed, ..., 9*seed) are applied to an\n",
        "  index array and both lists are reordered by the result, so every image keeps\n",
        "  its mask. This gives the same order as shuffling each list in place with\n",
        "  those seeds.\n",
        "\n",
        "  Args:\n",
        "    img_files: sequence of image paths.\n",
        "    msk_files: sequence of mask paths, aligned index-for-index with img_files.\n",
        "    seed: integer base seed.\n",
        "\n",
        "  Returns:\n",
        "    (img_files, msk_files) as new lists in shuffled order; inputs are untouched.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the two sequences differ in length (pairing would break).\n",
        "  \"\"\"\n",
        "  if len(img_files) != len(msk_files):\n",
        "    raise ValueError(\n",
        "        f'img_files and msk_files differ in length: '\n",
        "        f'{len(img_files)} != {len(msk_files)}')\n",
        "  order = np.arange(len(img_files))\n",
        "  for i in range(10):\n",
        "    # Seeding the global RNG leaves np.random in the same state a per-list\n",
        "    # in-place shuffle would, for any later code that draws from it.\n",
        "    np.random.seed(seed*i)\n",
        "    np.random.shuffle(order)\n",
        "  return [img_files[k] for k in order], [msk_files[k] for k in order]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cTsWbpBfoRBA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dataset.prepare_stare_dataset()\n",
        "\n",
        "stare_img_files = sorted(glob.glob('data/stare_dataset/images/*'))\n",
        "stare_msk_files = sorted(glob.glob('data/stare_dataset/masks/*'))\n",
        "\n",
        "# Images and masks are paired only by sorted position, so fail fast if dataset\n",
        "# preparation produced no files or left the two folders out of step.\n",
        "assert stare_img_files, 'no STARE images found in data/stare_dataset/images'\n",
        "assert len(stare_img_files) == len(stare_msk_files), (\n",
        "    f'{len(stare_img_files)} images vs {len(stare_msk_files)} masks')\n",
        "\n",
        "seed = 50\n",
        "stare_img_files, stare_msk_files = shuffle(stare_img_files, stare_msk_files, seed)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xcGG_d4PGnbr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Square size handed to preprocess.read_img/read_msk and used as the model input size.\n",
        "IMG_DIM = 512\n",
        "\n",
        "# This notebook works on the STARE files only.\n",
        "img_files, msk_files = stare_img_files, stare_msk_files"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FHhMsvqcHKiC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Read each (image, mask) pair at IMG_DIM x IMG_DIM, image first, then its mask.\n",
        "pairs = [\n",
        "    (preprocess.read_img(img_path, (IMG_DIM, IMG_DIM)),\n",
        "     preprocess.read_msk(msk_path, (IMG_DIM, IMG_DIM)))\n",
        "    for img_path, msk_path in zip(img_files, msk_files)\n",
        "]\n",
        "\n",
        "imgs = np.array([pair[0] for pair in pairs])\n",
        "msks = np.array([pair[1] for pair in pairs])\n",
        "\n",
        "print(imgs.shape)\n",
        "print(msks.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J4FQjzKvHroJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Run the enhancement on every loaded image (iterates imgs along its first axis).\n",
        "enhanced_imgs = np.array(list(map(preprocess.enhance_image_v3, imgs)))\n",
        "print(enhanced_imgs.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2oodFuw8H7yL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_samples = 16  # the remaining shuffled pairs form the validation set\n",
        "\n",
        "train_imgs, valid_imgs = np.split(enhanced_imgs, [train_samples])\n",
        "train_msks, valid_msks = np.split(msks, [train_samples])\n",
        "\n",
        "print(train_imgs.shape)\n",
        "print(valid_imgs.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RyLQyplVJPuv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Raw mask values above this count as foreground (mask > true_thresh -> 1);\n",
        "# passed to both generators and reused when scoring the validation set.\n",
        "true_thresh = 50\n",
        "# Samples per generator batch; also used to derive steps per epoch.\n",
        "batch_size = 5"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GFkNtQqiIWvm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# The last flag is True for training and False for validation (presumably it\n",
        "# toggles shuffling/augmentation -- confirm in src/generator.py).\n",
        "train_dataset = generator.Dataset(train_imgs, train_msks, true_thresh, batch_size, True)\n",
        "train_generator = train_dataset.get_generator()\n",
        "\n",
        "valid_dataset = generator.Dataset(valid_imgs, valid_msks, true_thresh, batch_size, False)\n",
        "valid_generator = valid_dataset.get_generator()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MGWB3XO6Jlkd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Peek at one validation batch to sanity-check the model inputs and masks.\n",
        "sample_imgs, sample_msks = next(iter(valid_generator))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_tJT_7v1JpXi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# The batch comes from valid_generator, so it is labelled as validation data.\n",
        "evaluate.show_images(sample_imgs[0:3,:,:,0:3], 'Validation Images (channels 0-2)', (9,3.5))\n",
        "evaluate.show_images(sample_imgs[0:3,:,:,3:6], 'Validation Images (channels 3-5)', (9,3.5))\n",
        "evaluate.show_images(sample_msks[0:3,:,:,0], 'Validation Masks', (9,3.5))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0XAsOSNtqR3e",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 6 input channels to match the generator batches (channels 0-2 and 3-5 are\n",
        "# previewed separately). drop=0.5 is presumably a dropout rate -- confirm in src/sprunet.py.\n",
        "model = sprunet.get_model((IMG_DIM, IMG_DIM, 6), drop=0.5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LWdZwuznu-4d",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "base_learning_rate = 0.001\n",
        "# `lr` is a deprecated alias rejected by recent Keras releases; use `learning_rate`.\n",
        "# The metric is logged as 'iou'/'val_iou', which the checkpoint callback monitors.\n",
        "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
        "              loss=[losses.bce_dice_loss],\n",
        "              metrics=[metrics.iou])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0q1P4qwOu7S0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# exist_ok avoids the check-then-create race and makes the cell re-runnable.\n",
        "os.makedirs('weights', exist_ok=True)\n",
        "\n",
        "# One weights file per val_iou improvement; epoch and score go into the name.\n",
        "cpk_path = 'weights/weights_{epoch:03d}-{val_iou:.4f}.h5'\n",
        "\n",
        "checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
        "  filepath=cpk_path,\n",
        "  monitor='val_iou',\n",
        "  mode='max',\n",
        "  verbose=1,\n",
        "  save_best_only=True,\n",
        "  save_weights_only=True,\n",
        "  save_freq='epoch'\n",
        "  )\n",
        "\n",
        "# Cut the learning rate 10x after 5 epochs without val_loss improvement.\n",
        "reducelr = tf.keras.callbacks.ReduceLROnPlateau(\n",
        "  monitor='val_loss',\n",
        "  factor=0.1,\n",
        "  patience=5,\n",
        "  verbose=1)\n",
        "\n",
        "# Stop after 10 epochs without val_loss improvement. NOTE: restore_best_weights\n",
        "# restores the best-val_loss weights into `model`, which may differ from the\n",
        "# best-val_iou checkpoint written to disk above.\n",
        "earlystop = tf.keras.callbacks.EarlyStopping(\n",
        "  monitor='val_loss',\n",
        "  patience=10,\n",
        "  verbose=1,\n",
        "  restore_best_weights=True\n",
        ")\n",
        "\n",
        "callbacks = [checkpoint, reducelr, earlystop]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VBAHlHLZu4ES",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "start_epoch = 0\n",
        "final_epoch = 50\n",
        "# Each training epoch draws this many passes' worth of batches from the train\n",
        "# generator (independent of final_epoch, despite sharing the value 50).\n",
        "train_repeats = 50\n",
        "train_steps = math.ceil(len(train_imgs)/batch_size)*train_repeats\n",
        "valid_steps = math.ceil(len(valid_imgs)/batch_size)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iVQkbmg5vdgk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Train with the callbacks above; history.history keeps the per-epoch metrics.\n",
        "history = model.fit(\n",
        "  x=train_generator,\n",
        "  validation_data=valid_generator,\n",
        "  steps_per_epoch=train_steps,\n",
        "  validation_steps=valid_steps,\n",
        "  initial_epoch=start_epoch,\n",
        "  epochs=final_epoch,\n",
        "  callbacks=callbacks,\n",
        "  verbose=1,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_GzCjuJ5faqg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import pickle\n",
        "\n",
        "# Persist the per-epoch metrics so the curves can be re-plotted without retraining.\n",
        "with open('history', 'wb') as history_file:\n",
        "  pickle.dump(history.history, history_file)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DWXQy-1Dw43T",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# The tracked metric is IoU (metrics.iou), so `acc`/`val_acc` hold IoU values.\n",
        "acc, val_acc = history.history['iou'], history.history['val_iou']\n",
        "loss, val_loss = history.history['loss'], history.history['val_loss']"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A6E9zkPNy6Hv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Train/validation loss and IoU history; (16,5) is presumably the figure size.\n",
        "evaluate.plot_training_curves(loss, val_loss, acc, val_acc, (16,5))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fhs34zzjzATt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Divide by 255 -- presumably the same input scaling the generators apply;\n",
        "# confirm in src/generator.py.\n",
        "x_true = valid_imgs/255\n",
        "\n",
        "# Add a trailing channel axis, then binarize with the generators' true_thresh.\n",
        "y_true = (np.expand_dims(valid_msks, -1) > true_thresh).astype(np.uint8)\n",
        "\n",
        "y_pred = model.predict(x_true)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iGiG6tUjy8xN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Decision threshold handed to evaluate.calc_score; presumably y_pred above it\n",
        "# counts as foreground.\n",
        "pred_thresh = 0.9"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GFIyRKI8zIbf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Score the validation predictions against the binarized ground-truth masks.\n",
        "score = evaluate.calc_score(y_true, y_pred, thresh=pred_thresh)\n",
        "evaluate.print_score(score)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GZocDr1bzPDD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Validation sample k is imgs[train_samples + k] (raw, unenhanced), so derive the\n",
        "# slice from train_samples instead of hard-coding 16 to keep rows aligned.\n",
        "n_show = 3\n",
        "evaluate.show_images(imgs[train_samples:train_samples+n_show,:,:,0:3], 'Images', (9,3.5))\n",
        "evaluate.show_images(y_true[:n_show,:,:,0], 'Ground Truth Masks', (9,3.5))\n",
        "evaluate.show_images(y_pred[:n_show,:,:,0], 'Predicted Masks', (9,3.5))"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}