{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import os, sys\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from datasets.data_processing import read_tfrecords_nosaic_mnist, decode_nosaic_mnist, sequential_slice, binarize_labels_nosaic_mnist, normalize_images_nosaic_mnist\n",
    "from models.backbones_lstm import LSTMModel\n",
    "from models.losses_v2 import get_loss_lstm, threshold_loss\n",
    "from utils.misc import load_yaml, set_gpu_devices, fix_random_seed\n",
    "from utils.performance_metrics_stylish import confmx_to_metrics, dict_confmx_to_dict_metrics, calc_binary_llrs, binary_varthtruncated_sprt_with_llrs, binary_confmx_to_bac, binary_np_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_gpu_devices(0)  # GPU settings: device index 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "# --- Checkpoint to evaluate (architecture settings below must match the ckpt) ---\n",
    "name_subdir = \"0thO_testrun_stat\"\n",
    "name_trial = \"0thO_testrun_20200303_011718077\"\n",
    "path_resume = \"/data/t-miyagawa/sprt/nosaic_mnist/ckptlogs/{}/{}\".format(name_subdir, name_trial)\n",
    "order_sprt = 0  # must be consistent with ckpt\n",
    "width_lstm = 1024  # must be consistent with ckpt\n",
    "activation = \"tanh\"  # (maybe) must be consistent with ckpt\n",
    "dropout = 0.\n",
    "\n",
    "# --- Data ---\n",
    "tfr_train = \"/data/t-miyagawa/nosaic_mnist/nosaic_mnist_train.tfrecords\"\n",
    "tfr_test = \"/data/t-miyagawa/nosaic_mnist/nosaic_mnist_test.tfrecords\"\n",
    "duration = 20  # sequence length in frames\n",
    "batch_size = 200\n",
    "nb_cls = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Logit Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "##################################\n",
    "# Read the train/test TFRecords; the reader also returns a validation split\n",
    "parsed_image_dataset_train, parsed_image_dataset_valid, parsed_image_dataset_test = \\\n",
    "    read_tfrecords_nosaic_mnist(\n",
    "        record_file_train=tfr_train,\n",
    "        record_file_test=tfr_test,\n",
    "        batch_size=batch_size,\n",
    "        shuffle_buffer_size=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model (width_lstm / activation must match the checkpoint being restored)\n",
    "##################################\n",
    "model = LSTMModel(nb_cls=nb_cls, width_lstm=width_lstm,\n",
    "                  dropout=dropout, activation=activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored from /data/t-miyagawa/sprt/nosaic_mnist/ckptlogs/0thO_testrun_stat/0thO_testrun_20200303_011718077/ckpt_step16380_mbac0.79547-4\n"
     ]
    }
   ],
   "source": [
    "# Restore parameters\n",
    "#################################\n",
    "# Only the network weights are tracked here; if the training-time checkpoint also\n",
    "# stored step/optimizer, those entries are intentionally left unrestored.\n",
    "ckpt = tf.train.Checkpoint(net=model)\n",
    "ckpt_manager_restore = tf.train.CheckpointManager(ckpt, path_resume, max_to_keep=3)\n",
    "path_ckpt = ckpt_manager_restore.latest_checkpoint\n",
    "if path_ckpt is None:\n",
    "    # restore(None) would silently keep the random initialization\n",
    "    raise FileNotFoundError(\"No checkpoint found in {}\".format(path_resume))\n",
    "ckpt.restore(path_ckpt).expect_partial()\n",
    "print(\"Restored from {}\".format(path_ckpt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"PeepholeLSTM_\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense (Dense)                multiple                  803840    \n",
      "_________________________________________________________________\n",
      "batch_normalization (BatchNo multiple                  4096      \n",
      "_________________________________________________________________\n",
      "activation_1 (Activation)    multiple                  0         \n",
      "_________________________________________________________________\n",
      "peephole_lstm_cell (Peephole multiple                  8395776   \n",
      "_________________________________________________________________\n",
      "rnn (RNN)                    multiple                  8395776   \n",
      "_________________________________________________________________\n",
      "batch_normalization_1 (Batch multiple                  4096      \n",
      "_________________________________________________________________\n",
      "activation_3 (Activation)    multiple                  0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              multiple                  2048      \n",
      "=================================================================\n",
      "Total params: 9,209,856\n",
      "Trainable params: 9,205,760\n",
      "Non-trainable params: 4,096\n",
      "_________________________________________________________________\n",
      "Evaluation Iter:  50Finish calc LLRs.\n"
     ]
    }
   ],
   "source": [
    "# Calc LLRs\n",
    "##################################\n",
    "# Collect logits on the validation split (the thresholds below are fitted on it).\n",
    "# Batches are gathered in lists and concatenated once after the loop.\n",
    "list_logits = []\n",
    "list_labels = []\n",
    "for iter_b, feats in enumerate(parsed_image_dataset_valid):\n",
    "    # Decode features and binarize classification labels\n",
    "    x_batch, y_batch = decode_nosaic_mnist(feats)\n",
    "    x_batch = tf.reshape(x_batch, (-1, duration, 784))  # 784 = flattened 28x28 frame\n",
    "    x_batch = normalize_images_nosaic_mnist(x_batch)\n",
    "    y_batch = binarize_labels_nosaic_mnist(y_batch)\n",
    "    if iter_b == 0:\n",
    "        model.build(x_batch.shape)\n",
    "        model.summary()\n",
    "\n",
    "    # Forward pass with training=False; the returned loss is not needed here\n",
    "    _, logits_concat = get_loss_lstm(model, x_batch, y_batch, training=False, order_sprt=order_sprt)\n",
    "    list_logits.append(logits_concat)\n",
    "    list_labels.append(y_batch)\n",
    "\n",
    "    # Verbose\n",
    "    if ((iter_b + 1) % 10 == 0) or (iter_b == 0):\n",
    "        sys.stdout.write(\"\\rEvaluation Iter: {:3d}\".format(iter_b + 1))\n",
    "        sys.stdout.flush()\n",
    "\n",
    "if not list_logits:\n",
    "    raise ValueError(\"parsed_image_dataset_valid yielded no batches\")\n",
    "logits_all = tf.concat(list_logits, 0)\n",
    "labels_all = tf.concat(list_labels, 0)\n",
    "\n",
    "# Calc LLRs\n",
    "llrs_all = calc_binary_llrs(logits_all)\n",
    "print(\"\\nFinish calc LLRs.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use llrs_all and labels_all from now on."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Threshold Adaptation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Examples of variable thresholds\n",
    "\n",
    "- default\n",
    "- linear\n",
    "- quadratic\n",
    "- adapted (**TODO**: which optimizer?)\n",
    "- adapted vs. adapted_scaled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate variable thresholds\n",
    "# (duration comes from the Parameters cell; no local redefinition)\n",
    "alpha = np.linspace(\n",
    "    start=1e-4,\n",
    "    stop=0.5,\n",
    "    num=duration,\n",
    "    endpoint=True)\n",
    "beta = alpha\n",
    "\n",
    "# Calc variable thresholds and sanity-check their ordering and sign at every frame\n",
    "thresh = np.array([np.log(beta/(1-alpha)), np.log((1-beta)/alpha)])\n",
    "if not (np.all(thresh[1] >= thresh[0]) and np.all(thresh[1] * thresh[0] <= 0)):\n",
    "    raise ValueError(\"thresh must be thresh[1] >= thresh[0] and thresh[1] * thresh[0] <= 0. Now thresh = {}\".format(thresh))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Truncated SPRT using the variable alpha/beta defined above\n",
    "(confmx, mean_hittime, var_hittime, truncate_rate) = binary_varthtruncated_sprt_with_llrs(\n",
    "    llrs_all, labels_all, alpha, beta)\n",
    "bac = binary_confmx_to_bac(confmx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-1. Loss function 2 (set goal)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calc straight (constant-in-time) thresholds as the baseline.\n",
    "# Separate *_base names keep the variable alpha/beta/thresh intact, so re-running\n",
    "# the truncated-SPRT cell afterwards still uses the variable thresholds.\n",
    "init_alpha_base = 1e-4\n",
    "\n",
    "alpha_base = np.full(duration, init_alpha_base)\n",
    "beta_base = alpha_base\n",
    "thresh_base = np.array([np.log(beta_base/(1-alpha_base)), np.log((1-beta_base)/alpha_base)])\n",
    "if not (np.all(thresh_base[1] >= thresh_base[0]) and np.all(thresh_base[1] * thresh_base[0] <= 0)):\n",
    "    raise ValueError(\"thresh must be thresh[1] >= thresh[0] and thresh[1] * thresh[0] <= 0. Now thresh = {}\".format(thresh_base))\n",
    "\n",
    "# Truncated SPRT with straight thresholds (Calc mean_hittime_base and bac_base)\n",
    "confmx_base, mean_hittime_base, var_hittime_base, truncate_rate_base = \\\n",
    "    binary_varthtruncated_sprt_with_llrs(\n",
    "        llrs_all,\n",
    "        labels_all,\n",
    "        alpha_base,\n",
    "        beta_base)\n",
    "\n",
    "bac_base = binary_confmx_to_bac(confmx_base)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Score the variable thresholds (`bac`, `mean_hittime`) against the straight baseline (`bac_base`, `mean_hittime_base`) with `threshold_loss`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14758193492889404 0.0 0.14758193492889404\n"
     ]
    }
   ],
   "source": [
    "# threshold_loss settings (duration comes from the Parameters cell)\n",
    "m_bac = 0.01\n",
    "m_mht = 1 # frame base\n",
    "\n",
    "loss = threshold_loss(bac, mean_hittime, bac_base, mean_hittime_base, duration=duration, m_bac=m_bac, m_mht=m_mht)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss1 / loss2 are not defined anywhere in this notebook (they raised NameError), so they are not printed\n",
    "print(loss)\n",
    "print(mean_hittime)\n",
    "print(truncate_rate)\n",
    "print(bac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.64547527"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# bac of the variable-threshold truncated SPRT, as a NumPy value\n",
    "bac.numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Neyman-Pearson Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NP test at a single LLR threshold.\n",
    "# Dedicated names avoid clobbering the SPRT's `thresh` array and `confmx`.\n",
    "thresh_np = 0.\n",
    "length_np = 10  # passed as `length` to binary_np_test (presumably frames used) -- TODO confirm\n",
    "confmx_np = binary_np_test(\n",
    "        llrs_all,\n",
    "        labels_all,\n",
    "        length=length_np,\n",
    "        thresh=thresh_np)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=47803, shape=(), dtype=float64, numpy=0.7602133787161678>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confmx_to_metrics(confmx_np)[\"BAC\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
