{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "# from models import CnnModel, RnnModel\n",
    "import numpy as np\n",
    "\n",
    "import math\n",
    "from torch.autograd import Variable  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CnnModel(nn.Module):\n",
    "    \n",
    "    def __init__(self, args):\n",
    "        \"\"\"\n",
    "        args.hidden_dim -- dimension of filters\n",
    "        args.embedding_dim -- dimension of word embeddings\n",
    "        args.kernel_size -- kernel size of the conv1d\n",
    "        args.layer_num -- number of CNN layers\n",
    "        \"\"\"\n",
    "        super(CnnModel, self).__init__()\n",
    "\n",
    "        self.args = args\n",
    "        if args.kernel_size % 2 == 0:\n",
    "            raise ValueError(\"args.kernel_size should be an odd number\")\n",
    "            \n",
    "        self.conv_layers = nn.Sequential()\n",
    "        for i in range(args.layer_num):\n",
    "            if i == 0:\n",
    "                input_dim = args.embedding_dim\n",
    "            else:\n",
    "                input_dim = args.hidden_dim\n",
    "            self.conv_layers.add_module('conv_layer{:d}'.format(i), nn.Conv1d(in_channels=input_dim, \n",
    "                                                  out_channels=args.hidden_dim, kernel_size=args.kernel_size,\n",
    "                                                                             padding=(args.kernel_size-1)/2))\n",
    "            self.conv_layers.add_module('relu{:d}'.format(i), nn.ReLU())\n",
    "        \n",
    "    def forward(self, embeddings):\n",
    "        \"\"\"\n",
    "        Given input embeddings in shape of (batch_size, sequence_length, embedding_dim) generate a \n",
    "        sentence embedding tensor (batch_size, sequence_length, hidden_dim)\n",
    "        Inputs:\n",
    "            embeddings -- sequence of word embeddings, (batch_size, sequence_length, embedding_dim)\n",
    "        Outputs:\n",
    "            hiddens -- sentence embedding tensor, (batch_size, hidden_dim, sequence_length)       \n",
    "        \"\"\"\n",
    "        embeddings_ = embeddings.transpose(1, 2) #(batch_size, embedding_dim, sequence_length)\n",
    "        hiddens = self.conv_layers(embeddings_)\n",
    "        return hiddens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RnnModel(nn.Module):\n",
    "\n",
    "    def __init__(self, args):\n",
    "        \"\"\"\n",
    "        args.hidden_dim -- dimension of filters\n",
    "        args.embedding_dim -- dimension of word embeddings\n",
    "        args.layer_num -- number of RNN layers   \n",
    "        args.cell_type -- type of RNN cells, GRU or LSTM\n",
    "        \"\"\"\n",
    "        super(RnnModel, self).__init__()\n",
    "        \n",
    "        self.args = args\n",
    " \n",
    "        if args.cell_type == 'GRU':\n",
    "            self.rnn_layer = nn.GRU(input_size=args.embedding_dim, \n",
    "                                    hidden_size=args.hidden_dim//2, \n",
    "                                    num_layers=args.layer_num, bidirectional=True)\n",
    "        elif args.cell_type == 'LSTM':\n",
    "            self.rnn_layer = nn.LSTM(input_size=args.embedding_dim, \n",
    "                                     hidden_size=args.hidden_dim//2, \n",
    "                                     num_layers=args.layer_num, bidirectional=True)\n",
    "    \n",
    "    def forward(self, embeddings, mask=None):\n",
    "        \"\"\"\n",
    "        Inputs:\n",
    "            embeddings -- sequence of word embeddings, (batch_size, sequence_length, embedding_dim)\n",
    "            mask -- a float tensor of masks, (batch_size, length)\n",
    "        Outputs:\n",
    "            hiddens -- sentence embedding tensor, (batch_size, hidden_dim, sequence_length)\n",
    "        \"\"\"\n",
    "        embeddings_ = embeddings.transpose(0, 1) #(sequence_length, batch_size, embedding_dim)\n",
    "        \n",
    "        if mask is not None:\n",
    "            seq_lengths = list(torch.sum(mask, dim=1).cpu().data.numpy())\n",
    "            seq_lengths = [int(x) for x in seq_lengths]\n",
    "#             seq_lengths = map(int, seq_lengths)\n",
    "            inputs_ = torch.nn.utils.rnn.pack_padded_sequence(embeddings_, seq_lengths)\n",
    "        else:\n",
    "            inputs_ = embeddings_\n",
    "        \n",
    "        hidden, _ = self.rnn_layer(inputs_) #(sequence_length, batch_size, hidden_dim (* 2 if bidirectional))\n",
    "        \n",
    "        if mask is not None:\n",
    "            hidden, _ = torch.nn.utils.rnn.pad_packed_sequence(hidden) #(length, batch_size, hidden_dim)\n",
    "        \n",
    "        return hidden.permute(1, 2, 0) #(batch_size, hidden_dim, sequence_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \n",
    "    def __init__(self, args):\n",
    "        \"\"\"\n",
    "        Inputs:\n",
    "        args.model_type -- \"CNN\" or \"RNN\"\n",
    "        if use CNN:\n",
    "            args.hidden_dim -- dimension of filters\n",
    "            args.embedding_dim -- dimension of word embeddings\n",
    "            args.kernel_size -- kernel size of the conv1d\n",
    "            args.layer_num -- number of CNN layers\n",
    "        if use RNN:\n",
    "            args.hidden_dim -- dimension of filters\n",
    "            args.embedding_dim -- dimension of word embeddings\n",
    "            args.layer_num -- number of RNN layers   \n",
    "            args.cell_type -- type of RNN cells, \"GRU\" or \"LSTM\"\n",
    "        \"\"\"\n",
    "        super(Encoder, self).__init__()\n",
    "        \n",
    "        self.args = args\n",
    "        \n",
    "        if args.model_type == \"CNN\":\n",
    "            self.encoder_model = CnnModel(args)\n",
    "        elif args.model_type == \"RNN\":\n",
    "            self.encoder_model = RnnModel(args)\n",
    "                \n",
    "    def forward(self, x, z, mask=None):\n",
    "        \"\"\"\n",
    "        Given input x in shape of (batch_size, sequence_length) generate a \n",
    "        regression value of each input\n",
    "        Inputs:\n",
    "            x -- input sequence of word embeddings, (batch_size, sequence_length, embedding_dim)\n",
    "            z -- input rationale, ``binary'' mask, (batch_size, sequence_length)\n",
    "        Outputs:\n",
    "            output -- hidden values at all time step, shape: (batch_size, hidden_dim, sequence_length) \n",
    "        \"\"\"\n",
    "        masked_input = x * z.unsqueeze(-1) #(batch_size, sequence_length, embedding_dim)        \n",
    "        hiddens = self.encoder_model(masked_input, mask) #(batch_size, hidden_dim, sequence_length)        \n",
    "        return hiddens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ClassificationEncoder(Encoder):\n",
    "    \n",
    "    def __init__(self, args):\n",
    "        super(ClassificationEncoder, self).__init__(args)\n",
    "        self.num_classes = args.num_classes\n",
    "        self.output_layer = nn.Linear(args.hidden_dim, self.num_classes)\n",
    "                \n",
    "    def forward(self, x, z, mask=None):\n",
    "        \"\"\"\n",
    "        Given input x in shape of (batch_size, sequence_length) generate a \n",
    "        regression value of each input\n",
    "        Inputs:\n",
    "            x -- input sequence of word embeddings, (batch_size, sequence_length, embedding_dim)\n",
    "            z -- input rationale, ``binary'' mask, (batch_size, sequence_length)\n",
    "        Outputs:\n",
    "            output -- output of the regression value, a vector of size batch_size\n",
    "        \"\"\"\n",
    "        hiddens = super(ClassificationEncoder, self).forward(x, z, mask)\n",
    "        \n",
    "        if mask is not None:\n",
    "            neg_inf = -1.0e6\n",
    "            hiddens = hiddens + (1 - mask.unsqueeze(1)) * neg_inf \n",
    "        \n",
    "        max_hidden = torch.max(hiddens, -1)[0] #(batch_size, hidden_dim)        \n",
    "        output = self.output_layer(max_hidden)\n",
    "        \n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class MLPClassificationEncoder(nn.Module):\n",
    "    def __init__(self, args, num_classes=2):\n",
    "        \"\"\"\n",
    "        Inputs:\n",
    "        model_type -- \"Linear\" or \"MLP\"\n",
    "        if use MLP:\n",
    "            hidden_dim -- dimension of filters\n",
    "        \"\"\"\n",
    "        super(MLPClassificationEncoder, self).__init__()\n",
    "        \n",
    "        self.args = args\n",
    "        self.num_classes = num_classes        \n",
    "\n",
    "        self.hidden_layer = nn.Sequential()\n",
    "        self.hidden_layer.add_module('linear', nn.Linear(args.embedding_dim, args.hidden_dim))\n",
    "        self.hidden_layer.add_module('relu', nn.ReLU())\n",
    "        self.output_layer = nn.Linear(args.hidden_dim, self.num_classes)\n",
    "\n",
    "    def forward(self, x, z, mask=None):\n",
    "        \"\"\"\n",
    "        Given input x in shape of (batch_size, sequence_length) generate a\n",
    "        list of embeddings\n",
    "        Inputs:\n",
    "            x -- input sequence of word embeddings, (batch_size, sequence_length, embedding_dim)\n",
    "        Outputs:\n",
    "            output -- hidden values at all time step, shape: (batch_size, hidden_dim, sequence_length)\n",
    "        \"\"\"\n",
    "        masked_input = x * z.unsqueeze(-1)\n",
    "        hiddens = self.hidden_layer(masked_input) #(batch_size, sequence_length, hidden)\n",
    "\n",
    "        if mask is not None:\n",
    "            neg_inf = -1.0e6\n",
    "            hiddens = hiddens + (1 - mask.unsqueeze(-1)) * neg_inf \n",
    "        \n",
    "        max_hidden = torch.max(hiddens, dim=1)[0] #(batch_size, hidden_dim)            \n",
    "        output = self.output_layer(max_hidden).squeeze(-1) #(batch_size,)\n",
    "        return output"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
