{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "import numpy as np\n",
    "import math\n",
    "from torch.nn import init\n",
    "from torch.nn.utils import rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LockedDropout(nn.Module):\n",
    "    def __init__(self, dropout):\n",
    "        super().__init__()\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, x):\n",
    "        dropout = self.dropout\n",
    "        if not self.training:\n",
    "            return x\n",
    "        m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - dropout)\n",
    "        mask = Variable(m.div_(1 - dropout), requires_grad=False)\n",
    "        mask = mask.expand_as(x)\n",
    "        return mask * x\n",
    "\n",
    "class BiAttention(nn.Module):\n",
    "    def __init__(self, input_size, dropout):\n",
    "        super().__init__()\n",
    "        self.dropout = LockedDropout(dropout)\n",
    "        self.input_linear = nn.Linear(input_size, 1, bias=False)\n",
    "        self.memory_linear = nn.Linear(input_size, 1, bias=False)\n",
    "\n",
    "        self.dot_scale = nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5)))\n",
    "\n",
    "    def forward(self, input, memory, mask):\n",
    "        bsz, input_len, memory_len = input.size(0), input.size(1), memory.size(1)\n",
    "\n",
    "        input = self.dropout(input)\n",
    "        memory = self.dropout(memory)\n",
    "\n",
    "        input_dot = self.input_linear(input)\n",
    "        memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)\n",
    "        cross_dot = torch.bmm(input * self.dot_scale, memory.permute(0, 2, 1).contiguous())\n",
    "        att = input_dot + memory_dot + cross_dot\n",
    "        att = att - 1e30 * (1 - mask[:,None])\n",
    "\n",
    "        weight_one = F.softmax(att, dim=-1)\n",
    "        output_one = torch.bmm(weight_one, memory)\n",
    "        weight_two = F.softmax(att.max(dim=-1)[0], dim=-1).view(bsz, 1, input_len)\n",
    "        output_two = torch.bmm(weight_two, input)\n",
    "\n",
    "        return torch.cat([input, output_one, input*output_one, output_two*output_one], dim=-1)\n",
    "    \n",
    "class BiAttentionSeq2Span(nn.Module):\n",
    "    def __init__(self, input_size, dropout):\n",
    "        super().__init__()\n",
    "        self.dropout = LockedDropout(dropout)\n",
    "        self.input_linear = nn.Linear(input_size, 1, bias=False)\n",
    "        self.memory_linear = nn.Linear(input_size, 1, bias=False)\n",
    "\n",
    "        self.dot_scale = nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5)))\n",
    "\n",
    "    def forward(self, input, memory):\n",
    "        '''\n",
    "        memery always has shape (batch, 2, hidden_dim)\n",
    "        '''\n",
    "        bsz, input_len, memory_len = input.size(0), input.size(1), memory.size(1)\n",
    "\n",
    "#         print(input.size())\n",
    "#         print(memory.size())\n",
    "        input = self.dropout(input)\n",
    "        memory = self.dropout(memory)\n",
    "\n",
    "        input_dot = self.input_linear(input)\n",
    "        memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)\n",
    "        cross_dot = torch.bmm(input * self.dot_scale, memory.permute(0, 2, 1).contiguous())\n",
    "        att = input_dot + memory_dot + cross_dot\n",
    "\n",
    "        weight_one = F.softmax(att, dim=-1)\n",
    "        output_one = torch.bmm(weight_one, memory)\n",
    "        weight_two = F.softmax(att.max(dim=-1)[0], dim=-1).view(bsz, 1, input_len)\n",
    "        output_two = torch.bmm(weight_two, input)\n",
    "\n",
    "        return torch.cat([input, output_one, input*output_one, output_two*output_one], dim=-1)\n",
    "\n",
    "class GateLayer(nn.Module):\n",
    "    def __init__(self, d_input, d_output):\n",
    "        super(GateLayer, self).__init__()\n",
    "        self.linear = nn.Linear(d_input, d_output)\n",
    "        self.gate = nn.Linear(d_input, d_output)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, input):\n",
    "        return self.linear(input) * self.sigmoid(self.gate(input))\n"
   ]
  },
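  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check for `BiAttention` and `GateLayer` on random tensors. This is a minimal sketch: the batch size, sequence lengths, and hidden width below are arbitrary illustration values, not anything the model prescribes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal shape check with arbitrary sizes (illustration only).\n",
    "bsz, c_len, q_len, hidden = 2, 7, 5, 4\n",
    "\n",
    "context = torch.randn(bsz, c_len, hidden)\n",
    "question = torch.randn(bsz, q_len, hidden)\n",
    "q_mask = torch.ones(bsz, q_len)  # 1 = real token, 0 = padding\n",
    "\n",
    "att = BiAttention(input_size=hidden, dropout=0.2)\n",
    "fused = att(context, question, q_mask)\n",
    "print(fused.size())  # (bsz, c_len, hidden * 4)\n",
    "\n",
    "gate = GateLayer(hidden * 4, hidden)\n",
    "print(gate(fused).size())  # (bsz, c_len, hidden)"
   ]
  },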
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderRNN(nn.Module):\n",
    "    def __init__(self, input_size, num_units, nlayers, concat, bidir, dropout, return_last):\n",
    "        super().__init__()\n",
    "        self.rnns = []\n",
    "        for i in range(nlayers):\n",
    "            if i == 0:\n",
    "                input_size_ = input_size\n",
    "                output_size_ = num_units\n",
    "            else:\n",
    "                input_size_ = num_units if not bidir else num_units * 2\n",
    "                output_size_ = num_units\n",
    "            self.rnns.append(nn.GRU(input_size_, output_size_, 1, bidirectional=bidir, batch_first=True))\n",
    "        self.rnns = nn.ModuleList(self.rnns)\n",
    "        self.init_hidden = nn.ParameterList([nn.Parameter(torch.Tensor(2 if bidir else 1, 1, num_units).zero_()) for _ in range(nlayers)])\n",
    "        self.dropout = LockedDropout(dropout)\n",
    "        self.concat = concat\n",
    "        self.nlayers = nlayers\n",
    "        self.return_last = return_last\n",
    "\n",
    "        # self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        for rnn in self.rnns:\n",
    "            for name, p in rnn.named_parameters():\n",
    "                if 'weight' in name:\n",
    "                    p.data.normal_(std=0.1)\n",
    "                else:\n",
    "                    p.data.zero_()\n",
    "\n",
    "    def get_init(self, bsz, i):\n",
    "        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()\n",
    "\n",
    "    def forward(self, input, input_lengths=None):\n",
    "        bsz, slen = input.size(0), input.size(1)\n",
    "\n",
    "        outputs = []\n",
    "        if input_lengths is not None:\n",
    "            lens = input_lengths.data.cpu().numpy()\n",
    "            sort_idx = np.argsort(-lens)\n",
    "            idx_dict = {sort_idx[i_]: i_ for i_ in range(lens.shape[0])}\n",
    "            revert_idx = np.array([idx_dict[i_] for i_ in range(lens.shape[0])])\n",
    "            input = input[sort_idx, :]\n",
    "        output = input\n",
    "            \n",
    "        for i in range(self.nlayers):\n",
    "            hidden = self.get_init(bsz, i)\n",
    "            output = self.dropout(output)\n",
    "#             print(output.size())\n",
    "            if input_lengths is not None:\n",
    "                output = rnn.pack_padded_sequence(output, lens[sort_idx], batch_first=True)\n",
    "            output, hidden = self.rnns[i](output, hidden)\n",
    "#             print(output.size())\n",
    "            if input_lengths is not None:\n",
    "                output, _ = rnn.pad_packed_sequence(output, batch_first=True)\n",
    "                if output.size(1) < slen: # used for parallel\n",
    "                    padding = Variable(output.data.new(1, 1, 1).zero_())\n",
    "                    output = torch.cat([output, padding.expand(output.size(0), slen-output.size(1), output.size(2))], dim=1)\n",
    "            if self.return_last:\n",
    "                outputs.append(hidden.permute(1, 0, 2).contiguous().view(bsz, -1))\n",
    "            else:\n",
    "                outputs.append(output)\n",
    "        \n",
    "        if input_lengths is not None: \n",
    "            if self.concat:\n",
    "                return torch.cat(outputs, dim=2)[revert_idx,:]\n",
    "            return outputs[-1][revert_idx,:]\n",
    "        else:\n",
    "            if self.concat:\n",
    "                return torch.cat(outputs, dim=2)\n",
    "            return outputs[-1]"
   ]
  },
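  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A smoke test for `EncoderRNN` with variable-length input, again using arbitrary illustration sizes. With `bidir=True` each layer emits `num_units * 2` features, and `concat=True` concatenates all `nlayers` layer outputs; the lengths do not need to be pre-sorted because the module sorts and restores the batch order itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test with arbitrary sizes (illustration only).\n",
    "bsz, seq_len, in_dim, units = 3, 6, 5, 4\n",
    "\n",
    "x = torch.randn(bsz, seq_len, in_dim)\n",
    "lengths = torch.tensor([4, 6, 2])  # unsorted on purpose; the encoder sorts internally\n",
    "\n",
    "encoder = EncoderRNN(in_dim, units, nlayers=2, concat=True, bidir=True,\n",
    "                     dropout=0.2, return_last=False)\n",
    "out = encoder(x, lengths)\n",
    "print(out.size())  # (bsz, seq_len, units * 2 * nlayers)"
   ]
  },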
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HotpotReaderParaOnly(nn.Module):\n",
    "    def __init__(self, config, word_mat):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.word_dim = config.glove_dim\n",
    "        self.word_emb = nn.Embedding(len(word_mat), len(word_mat[0]), padding_idx=0)\n",
    "        self.word_emb.weight.data.copy_(torch.from_numpy(word_mat))\n",
    "        self.word_emb.weight.requires_grad = False\n",
    "\n",
    "        self.hidden = config.hidden\n",
    "\n",
    "        self.rnn = EncoderRNN(self.word_dim, config.hidden, 1, True, True, 1-config.keep_prob, False)\n",
    "\n",
    "        self.qc_att = BiAttention(config.hidden*2, 1-config.keep_prob)\n",
    "        self.linear_1 = nn.Sequential(\n",
    "                nn.Linear(config.hidden*8, config.hidden),\n",
    "                nn.ReLU()\n",
    "            )\n",
    "\n",
    "        self.rnn_2 = EncoderRNN(config.hidden, config.hidden, 1, False, True, 1-config.keep_prob, False)\n",
    "        self.self_att = BiAttention(config.hidden*2, 1-config.keep_prob)\n",
    "        self.linear_2 = nn.Sequential(\n",
    "                nn.Linear(config.hidden*8, config.hidden),\n",
    "                nn.ReLU()\n",
    "            )\n",
    "\n",
    "        self.cache_S = 0\n",
    "\n",
    "\n",
    "    def forward(self, ques_idxs, context_idxs, context_lens):\n",
    "        para_size, ques_size, bsz = context_idxs.size(1), ques_idxs.size(1), context_idxs.size(0)\n",
    "\n",
    "        context_mask = (context_idxs > 0).float()\n",
    "        ques_mask = (ques_idxs > 0).float()\n",
    "\n",
    "        context_word = self.word_emb(context_idxs)\n",
    "        ques_word = self.word_emb(ques_idxs)\n",
    "\n",
    "        context_output = self.rnn(context_word, context_lens)\n",
    "        ques_output = self.rnn(ques_word)\n",
    "\n",
    "        output = self.qc_att(context_output, ques_output, ques_mask)\n",
    "        output = self.linear_1(output)\n",
    "\n",
    "        output_t = self.rnn_2(output, context_lens)\n",
    "        output_t = self.self_att(output_t, output_t, context_mask)\n",
    "        output_t = self.linear_2(output_t)\n",
    "\n",
    "        output = output + output_t\n",
    "\n",
    "        output_start = output\n",
    "\n",
    "        output_start = self.rnn_start(output_start, context_lens)\n",
    "        logit1 = self.linear_start(output_start).squeeze(2) - 1e30 * (1 - context_mask)\n",
    "        output_end = torch.cat([output, output_start], dim=2)\n",
    "        output_end = self.rnn_end(output_end, context_lens)\n",
    "        logit2 = self.linear_end(output_end).squeeze(2) - 1e30 * (1 - context_mask)\n",
    "\n",
    "        return logit1, logit2"
   ]
  }
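  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal end-to-end sketch that runs `HotpotReaderParaOnly` on random token ids. The `SimpleNamespace` config and the random `word_mat` are stand-ins for the real config object and GloVe matrix; only the attributes the model actually reads (`glove_dim`, `hidden`, `keep_prob`) are assumed, and all sizes are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "# Stand-in config and embedding matrix (illustration only; real runs load GloVe).\n",
    "config = SimpleNamespace(glove_dim=50, hidden=16, keep_prob=0.8)\n",
    "vocab_size = 100\n",
    "word_mat = np.random.randn(vocab_size, config.glove_dim).astype(np.float32)\n",
    "\n",
    "model = HotpotReaderParaOnly(config, word_mat)\n",
    "model.eval()  # disable dropout for a deterministic shape check\n",
    "\n",
    "bsz, ctx_len, q_len = 2, 30, 10\n",
    "context_idxs = torch.randint(1, vocab_size, (bsz, ctx_len))\n",
    "ques_idxs = torch.randint(1, vocab_size, (bsz, q_len))\n",
    "context_idxs[1, 25:] = 0  # pad the shorter context (index 0 is the padding id)\n",
    "context_lens = torch.tensor([30, 25])\n",
    "\n",
    "logit1, logit2 = model(ques_idxs, context_idxs, context_lens)\n",
    "print(logit1.size(), logit2.size())  # start/end logits, each (bsz, ctx_len)"
   ]
  }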
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
