{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import os, sys\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, BatchNormalization, Activation\n",
    "\n",
    "from datasets.data_processing import read_tfrecords_nosaic_mnist, decode_nosaic_mnist, sequential_slice, binarize_labels_nosaic_mnist, normalize_images_nosaic_mnist\n",
    "from models.backbones_lstm import LSTMModel\n",
    "from models.losses_v2 import get_loss_lstm\n",
    "from utils.misc import load_yaml, set_gpu_devices, fix_random_seed\n",
    "from utils.performance_metrics_stylish import multiplet_sequential_confmx, binary_llr_sequential_confmx, binary_truncated_sprt, confmx_to_metrics, seqconfmx_to_list_metrics, list_metrics_to_list_bac, dict_confmx_to_dict_metrics, run_truncated_sprt, calc_binary_llrs\n",
    "\n",
    "# GPU settings: pass device index 0 to the project helper utils.misc.set_gpu_devices\n",
    "# (presumably selects which GPU TensorFlow may use -- confirm in utils.misc).\n",
    "set_gpu_devices(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMModel(tf.keras.Model):\n",
    "    \"\"\"Peephole-LSTM sequence classifier written as a TF 2.0 Keras subclassed model.\n",
    "\n",
    "    call() applies, in order: a per-frame feature extractor (Dense -> BatchNormalization\n",
    "    -> activation), an RNN wrapping a PeepholeLSTMCell that returns every time step,\n",
    "    and a per-frame logit block (BatchNormalization -> activation -> Dense).\n",
    "\n",
    "    Remark:\n",
    "        If you are to use the N-th-order SPRT,\n",
    "        the inputs argument in __call__ must have shape (batch, N, feature dimension),\n",
    "        not (batch, duration (e.g. =20 for nosaic MNIST), feature dimension),\n",
    "        because the memory state is defined to be deleted after the inputs.shape[1] time step.\n",
    "        The reshape (batch, duration, feat dim) -> (batch*(duration-N+1), N, feat dim) can\n",
    "        be performed with datasets.data_processing.sequential_slice_nosaic_mnist().\n",
    "\n",
    "    NOTE(review): this notebook-local definition shadows the LSTMModel imported from\n",
    "    models.backbones_lstm in the first cell.\n",
    "    \"\"\"\n",
    "    def __init__(self, nb_cls, width_lstm, dropout=0., activation=\"tanh\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            nb_cls: An int. The dimension of the output logit vectors.\n",
    "            width_lstm: An int. The number of units of the feature extraction\n",
    "                fc layer and of the LSTM cell.\n",
    "            dropout: A float in [0, 1]. Dropout rate, not keep_prob. The same value\n",
    "                is passed as both dropout and recurrent_dropout of the LSTM cell.\n",
    "            activation: A string. Activation used by the feature extraction layer,\n",
    "                the LSTM cell and the logit block.\n",
    "                Note that recurrent_activation (i.e., input, output, and forget gate\n",
    "                activation) is not passed here, so the library default of the cell\n",
    "                is used (the original note says \"sigmoid\" for tf.keras.layers.LSTM;\n",
    "                confirm the default of PeepholeLSTMCell for your TF version).\n",
    "        \"\"\"\n",
    "        super(LSTMModel, self).__init__(name=\"PeepholeLSTM_\")\n",
    "\n",
    "        # Parameters\n",
    "        self.nb_cls = nb_cls\n",
    "        self.width_lstm = width_lstm\n",
    "        self.dropout = dropout\n",
    "        self.activation = activation\n",
    "\n",
    "        # Feature extraction fully-connected layer\n",
    "        # NOTE(review): this Dense layer already applies the activation, and\n",
    "        # fc_bn_act_featext() applies it again after BatchNormalization. Kept as is to\n",
    "        # preserve the original numerics -- confirm the double activation is intended.\n",
    "        self.fc_featext = Dense(self.width_lstm, activation=Activation(self.activation), use_bias=True)\n",
    "        self.bn_featext = BatchNormalization()\n",
    "        self.activation_featext = Activation(self.activation)\n",
    "\n",
    "        # LSTM cell (one rate is used for both input dropout and recurrent dropout)\n",
    "        self.lstm_cell = tf.keras.experimental.PeepholeLSTMCell(\n",
    "            units=self.width_lstm,\n",
    "            activation=Activation(self.activation),\n",
    "            unit_forget_bias=True,\n",
    "            dropout=self.dropout,\n",
    "            recurrent_dropout=self.dropout)\n",
    "\n",
    "        # RNN: returns the outputs of all time steps followed by the final states\n",
    "        self.rnn = tf.keras.layers.RNN(\n",
    "            self.lstm_cell,\n",
    "            return_sequences=True,\n",
    "            return_state=True)\n",
    "\n",
    "        # Logit generation block (final Dense has no activation and no bias)\n",
    "        self.bn_logit = BatchNormalization()\n",
    "        self.activation_logit = Activation(self.activation)\n",
    "        self.fc_logit = Dense(self.nb_cls, activation=None, use_bias=False)\n",
    "\n",
    "    def fc_bn_act_featext(self, x, training=None):\n",
    "        \"\"\"Feature extraction: Dense -> BatchNormalization -> activation.\n",
    "        Args:\n",
    "            x: A Tensor. Input feature with shape=(batch*duration, feature dimension),\n",
    "                e.g. (batch*duration, 784) for nosaic MNIST.\n",
    "            training: A boolean or None. Training flag forwarded to BatchNormalization.\n",
    "        Return:\n",
    "            x: A Tensor. Extracted feature with shape=(batch*duration, self.width_lstm)\n",
    "        \"\"\"\n",
    "        x = self.fc_featext(x)\n",
    "        x = self.bn_featext(x, training=training)\n",
    "        x = self.activation_featext(x)\n",
    "        return x\n",
    "\n",
    "    def bn_act_fc_logit(self, x, training=None):\n",
    "        \"\"\"Logit generation: BatchNormalization -> activation -> Dense.\n",
    "        Args:\n",
    "            x: A Tensor. Output of LSTM with shape=(batch*duration, self.width_lstm).\n",
    "            training: A boolean or None. Training flag forwarded to BatchNormalization.\n",
    "        Return:\n",
    "            x: A Tensor. Logit with shape=(batch*duration, self.nb_cls)\n",
    "        \"\"\"\n",
    "        x = self.bn_logit(x, training=training)\n",
    "        x = self.activation_logit(x)\n",
    "        x = self.fc_logit(x)\n",
    "        return x\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        \"\"\"Calc logits.\n",
    "        Args:\n",
    "            inputs: A Tensor with shape=(batch, duration, feature dimension). E.g. (128, 20, 784) for nosaic MNIST.\n",
    "            training: A boolean or None. Training flag used in BatchNormalization and dropout.\n",
    "                Defaults to None, following the Keras convention for call().\n",
    "        Returns:\n",
    "            outputs: A Tensor with shape=(batch, duration, nb_cls).\n",
    "        \"\"\"\n",
    "        # Static shape parameters\n",
    "        duration = inputs.shape[1] # 20 by default for nosaic MNIST\n",
    "        feat_dim = inputs.shape[-1] # 784 for nosaic MNIST; read from the input instead of hard-coded\n",
    "\n",
    "        # Feature extraction (frames are flattened so that the fc/BN layers see 2-D input)\n",
    "        feats = tf.reshape(inputs, (-1, feat_dim)) # (B, T, feat_dim) -> (BT, feat_dim)\n",
    "        feats = self.fc_bn_act_featext(feats, training=training) # (BT, feat_dim) -> (BT, self.width_lstm)\n",
    "        feats = tf.reshape(feats, (-1, duration, self.width_lstm)) # (BT, self.width_lstm) -> (B, T, self.width_lstm)\n",
    "\n",
    "        # Feedforward: keep the per-step outputs, discard the two final states\n",
    "        outputs, _, _ = self.rnn(feats, training=training)\n",
    "\n",
    "        # Make logits\n",
    "        outputs = tf.reshape(outputs, (-1, self.width_lstm)) # (B, T, self.width_lstm) -> (BT, self.width_lstm)\n",
    "        outputs = self.bn_act_fc_logit(outputs, training=training) # (BT, self.width_lstm) -> (BT, nb_cls)\n",
    "        outputs = tf.reshape(outputs, (-1, duration, self.nb_cls)) # (B, T, nb_cls)\n",
    "\n",
    "        return outputs # A Tensor with shape=(batch, duration, nb_cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy input batch: one random sequence of 20 frames with 784 features each.\n",
    "SEED = 0  # fixed so that re-running the notebook feeds the model the same inputs\n",
    "batch_size = 1\n",
    "duration = 20\n",
    "feat_dim = 784\n",
    "\n",
    "# A local RandomState keeps the draw reproducible without touching NumPy's global RNG.\n",
    "rng = np.random.RandomState(SEED)\n",
    "inputs = rng.rand(batch_size, duration, feat_dim).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model: 2 output logits, LSTM width 256, dropout rate 0.5.\n",
    "model = LSTMModel(nb_cls=2, width_lstm=256, dropout=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass with training=True, the flag used by BatchNormalization and dropout.\n",
    "outputs = model(inputs, training=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=53105, shape=(1, 20, 2), dtype=float32, numpy=\n",
       "array([[[-2.1808226e-05, -6.4713859e-05],\n",
       "        [-8.7819194e-07, -5.7905949e-05],\n",
       "        [ 1.0402497e-05, -4.7697973e-05],\n",
       "        [ 1.5821472e-05, -3.6216345e-05],\n",
       "        [ 1.7641585e-05, -2.4853445e-05],\n",
       "        [ 1.7236704e-05, -1.4429903e-05],\n",
       "        [ 1.5462310e-05, -5.3561926e-06],\n",
       "        [ 1.2868548e-05,  2.2318188e-06],\n",
       "        [ 9.8215369e-06,  8.3692930e-06],\n",
       "        [ 6.5719855e-06,  1.3187805e-05],\n",
       "        [ 3.2944104e-06,  1.6864358e-05],\n",
       "        [ 1.0993799e-07,  1.9588446e-05],\n",
       "        [-2.8995350e-06,  2.1542070e-05],\n",
       "        [-5.6819445e-06,  2.2888966e-05],\n",
       "        [-8.2079923e-06,  2.3769831e-05],\n",
       "        [-1.0465878e-05,  2.4301398e-05],\n",
       "        [-1.2456623e-05,  2.4577792e-05],\n",
       "        [-1.4190584e-05,  2.4673149e-05],\n",
       "        [-1.5684229e-05,  2.4644336e-05],\n",
       "        [-1.6957800e-05,  2.4534304e-05]]], dtype=float32)>"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the logits; the stored output has shape (1, 20, 2) = (batch, duration, nb_cls).\n",
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
