{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "from .encoder import Encoder, ClassificationEncoder, MLPClassificationEncoder\n",
    "# from utils.utils import l21_norm, continuity_loss_func, corner_detection, fused_sparsity_loss_batch\n",
    "from .basic_nlp_models import BasicNLPModel\n",
    "import numpy as np\n",
    "import copy\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderRNN(nn.Module):\n",
    "    def __init__(self, input_size, num_units, nlayers, concat, bidir, dropout, return_last):\n",
    "        super().__init__()\n",
    "        self.rnns = []\n",
    "        for i in range(nlayers):\n",
    "            if i == 0:\n",
    "                input_size_ = input_size\n",
    "                output_size_ = num_units\n",
    "            else:\n",
    "                input_size_ = num_units if not bidir else num_units * 2\n",
    "                output_size_ = num_units\n",
    "            self.rnns.append(nn.GRU(input_size_, output_size_, 1, bidirectional=bidir, batch_first=True))\n",
    "        self.rnns = nn.ModuleList(self.rnns)\n",
    "        self.init_hidden = nn.ParameterList([nn.Parameter(torch.Tensor(2 if bidir else 1, 1, num_units).zero_()) for _ in range(nlayers)])\n",
    "        self.dropout = LockedDropout(dropout)\n",
    "        self.concat = concat\n",
    "        self.nlayers = nlayers\n",
    "        self.return_last = return_last\n",
    "\n",
    "        # self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        for rnn in self.rnns:\n",
    "            for name, p in rnn.named_parameters():\n",
    "                if 'weight' in name:\n",
    "                    p.data.normal_(std=0.1)\n",
    "                else:\n",
    "                    p.data.zero_()\n",
    "\n",
    "    def get_init(self, bsz, i):\n",
    "        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()\n",
    "\n",
    "    def forward(self, input, input_lengths=None):\n",
    "        bsz, slen = input.size(0), input.size(1)\n",
    "\n",
    "        outputs = []\n",
    "        if input_lengths is not None:\n",
    "            lens = input_lengths.data.cpu().numpy()\n",
    "            sort_idx = np.argsort(-lens)\n",
    "            idx_dict = {sort_idx[i_]: i_ for i_ in range(lens.shape[0])}\n",
    "            revert_idx = np.array([idx_dict[i_] for i_ in range(lens.shape[0])])\n",
    "            input = input[sort_idx, :]\n",
    "        output = input\n",
    "            \n",
    "        for i in range(self.nlayers):\n",
    "            hidden = self.get_init(bsz, i)\n",
    "            output = self.dropout(output)\n",
    "#             print(output.size())\n",
    "            if input_lengths is not None:\n",
    "                output = rnn.pack_padded_sequence(output, lens[sort_idx], batch_first=True)\n",
    "            output, hidden = self.rnns[i](output, hidden)\n",
    "#             print(output.size())\n",
    "            if input_lengths is not None:\n",
    "                output, _ = rnn.pad_packed_sequence(output, batch_first=True)\n",
    "                if output.size(1) < slen: # used for parallel\n",
    "                    padding = Variable(output.data.new(1, 1, 1).zero_())\n",
    "                    output = torch.cat([output, padding.expand(output.size(0), slen-output.size(1), output.size(2))], dim=1)\n",
    "            if self.return_last:\n",
    "                outputs.append(hidden.permute(1, 0, 2).contiguous().view(bsz, -1))\n",
    "            else:\n",
    "                outputs.append(output)\n",
    "        \n",
    "        if input_lengths is not None: \n",
    "            if self.concat:\n",
    "                return torch.cat(outputs, dim=2)[revert_idx,:]\n",
    "            return outputs[-1][revert_idx,:]\n",
    "        else:\n",
    "            if self.concat:\n",
    "                return torch.cat(outputs, dim=2)\n",
    "            return outputs[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelationalMatchLSTMEncoder(nn.Module):\n",
    "    def __init__(self, args):\n",
    "        super(RelationalMatchLSTMEncoder, self).__init__()\n",
    "        # create an encoder layer\n",
    "        encoder_args = copy.deepcopy(args)\n",
    "        encoder_args.model_type = 'RNN'\n",
    "        encoder_args.input_dim = args.embedding_dim\n",
    "        encoder_args.layer_num = args.layer_num\n",
    "        \n",
    "        self.encoder = Encoder(encoder_args) \n",
    "\n",
    "        # create an output layer\n",
    "        match_lstm_args = copy.deepcopy(args)\n",
    "        match_lstm_args.model_type = 'RNN'\n",
    "        match_lstm_args.layer_num = 1\n",
    "        match_lstm_args.embedding_dim = match_lstm_args.hidden_dim*4\n",
    "        match_lstm_args.hidden_dim = args.mlstm_hidden_dim\n",
    "        match_lstm_args.num_classes = 1\n",
    "        \n",
    "        self.match_lstm = ClassificationEncoder(match_lstm_args)\n",
    "\n",
    "        \n",
    "    def _calculate_similarity_matrix(self, q_hiddens, p_hiddens):\n",
    "        \"\"\"\n",
    "        Inputs: \n",
    "            q_hiddens -- (batch_size, hidden_dim, sequence_length_q)                                                    \n",
    "            rw_hiddens -- (batch_size, hidden_dim, sequence_length_p)    \n",
    "        \"\"\"\n",
    "       \n",
    "        #(batch_size, q_length, p_length)\n",
    "        similarity_matrix = torch.bmm(q_hiddens.transpose(1, 2), p_hiddens)\n",
    "        \n",
    "        return similarity_matrix\n",
    "    \n",
    "    \n",
    "    def _get_r_matching_representations(self, q_hiddens, p_hiddens, similarity_matrix, mask_matrix=None):\n",
    "        \"\"\"                                                                                                             \n",
    "        This function takes the sequences of hidden states and the word-by-word attention matrix,                       \n",
    "        and returns                                                                                                     \n",
    "                                                                                                                        \n",
    "        Inputs:                                                                                                         \n",
    "            q_hiddens -- (batch_size, hidden_dim, sequence_length_q)                                                    \n",
    "            p_hiddens -- (batch_size, hidden_dim, sequence_length_p)                                                    \n",
    "            similarity_matrix -- (batch_size, sequence_length_q, sequence_length_p)     \n",
    "            z_matrix -- (batch_size, q_length, p_length)\n",
    "                                                                                                                        \n",
    "        Outputs:                                                                                                        \n",
    "            q_matching_states -- (batch_size, sequence_length_q， hidden_dim * 4)                                       \n",
    "        \"\"\"\n",
    "#         attention_softmax = F.softmax(similarity_matrix, dim=2) * z_matrix\n",
    "        neg_inf = -1.0e6\n",
    "        if mask_matrix is not None:\n",
    "            attention_softmax = F.softmax(similarity_matrix + (1 * mask_matrix) * neg_inf, dim=2)\n",
    "        else:\n",
    "            attention_softmax = F.softmax(similarity_matrix, dim=2)\n",
    "        \n",
    "        # shape: (batch_size, sequence_length_q, hidden_dim)                                                            \n",
    "        q_hiddens_tilda = torch.bmm(attention_softmax, p_hiddens.transpose(1, 2))\n",
    "        q_hiddens_ = q_hiddens.transpose(1, 2)\n",
    "\n",
    "        # shape: (batch_size, sequence_length_q, hidden_dim * 4)                                                        \n",
    "        q_matching_states = torch.cat([q_hiddens_, q_hiddens_tilda,\n",
    "                                      q_hiddens_ - q_hiddens_tilda,\n",
    "                                      q_hiddens_ * q_hiddens_tilda], dim=2)\n",
    "        return q_matching_states, attention_softmax\n",
    "        \n",
    "    def forward(self, q_embeddings, r_embeddings, z_q, z_rw, q_mask, r_mask):\n",
    "        \n",
    "        q_len = torch.sum(q_mask, dim=1).cpu().data.numpy()\n",
    "        \n",
    "        sort_idx = np.argsort(-q_len)\n",
    "        idx_dict = {sort_idx[i_]: i_ for i_ in range(q_len.shape[0])}\n",
    "        revert_idx = np.array([idx_dict[i_] for i_ in range(q_len.shape[0])])\n",
    "        \n",
    "        q_hiddens_sort_ = self.encoder(q_embeddings[sort_idx,:,:], z_q[sort_idx,:], q_mask[sort_idx,:]) #(batch_size, hidden_dim, q_length)\n",
    "        q_hiddens = q_hiddens_sort_[revert_idx, :, :].contiguous()\n",
    "#         q_hiddens = self.encoder(q_embeddings, z_q, q_mask)\n",
    "        \n",
    "#         print(q_hiddens.size())\n",
    "\n",
    "        r_mask_flat_ = r_mask.view(r_mask.size(0)*r_mask.size(1), -1) #(batch_size*num_r, rw_length)\n",
    "        r_embeddings_flat_ = r_embeddings.view(r_mask.size(0)*r_mask.size(1), r_mask.size(2), -1)\n",
    "        \n",
    "        r_len = torch.sum(r_mask_flat_, dim=1).cpu().data.numpy()\n",
    "        sort_idx = np.argsort(-r_len)\n",
    "        idx_dict = {sort_idx[i_]: i_ for i_ in range(r_len.shape[0])}\n",
    "        revert_idx = np.array([idx_dict[i_] for i_ in range(r_len.shape[0])])\n",
    "        \n",
    "        r_hiddens_sort_ = self.encoder(r_embeddings_flat_[sort_idx,:,:], \n",
    "                                   z_rw.view(z_rw.size(0)*z_rw.size(1), -1)[sort_idx,:],\n",
    "                                   r_mask_flat_[sort_idx,:]) #(batch*num_r, hidden_dim, rw_length)\n",
    "        \n",
    "        r_hiddens_ = r_hiddens_sort_[revert_idx,:,:].contiguous()\n",
    "        \n",
    "#         print(r_hiddens_.size())\n",
    "        \n",
    "#         rw_hiddens = rw_hiddens_.view(rw.size(0), rw.size(1), -1, rw.size(2)) #(batch, num_r, hidden_dim, rw_length)\n",
    "        \n",
    "        # expand to (batch*num_r, hidden_dim, q_length) and (batch*num_r, q_length)\n",
    "#         q_hiddens_expand_ = q_hiddens.unsqueeze(1).expand(q_mask.size(0), r_mask.size(1), -1, q_mask.size(1))\n",
    "#         q_hiddens_expand_ = q_hiddens_expand_.contiguous().view(q_mask.size(0)*r_mask.size(1), -1, q_mask.size(1))\n",
    "#         q_mask_expand_ = q_mask.unsqueeze(1).expand(q_mask.size(0), r_mask.size(1), q_mask.size(1))\n",
    "#         q_mask_expand_ = q_mask_expand_.contiguous().view(-1, q_mask.size(1))\n",
    "        \n",
    "        q_hiddens_expand_ = q_hiddens.unsqueeze(1).expand(q_mask.size(0), r_mask.size(1), -1, q_hiddens.size(2))\n",
    "        q_hiddens_expand_ = q_hiddens_expand_.contiguous().view(q_mask.size(0)*r_mask.size(1), -1, q_hiddens.size(2))\n",
    "        q_mask_expand_ = torch.ones(q_mask.size(0) * r_mask.size(1), q_hiddens.size(2)).cuda()\n",
    "\n",
    "    \n",
    "        # generate word-by-word similarity between q and p, (batch_size*num_r, rw_length, q_length)            \n",
    "        similarity_matrix = self._calculate_similarity_matrix(r_hiddens_, q_hiddens_expand_)\n",
    "        mask_matrix = torch.bmm(r_mask_flat_.unsqueeze(2), q_mask_expand_.unsqueeze(1))\n",
    "        \n",
    "        # (batch_size*num_r, rw_length， hidden_dim * 4)                                                              \n",
    "        r_matching_states, _ = self._get_r_matching_representations(r_hiddens_, q_hiddens_expand_,\n",
    "                                                                 similarity_matrix, mask_matrix)\n",
    "        # (batch_size*num_r,)\n",
    "#         predict = self.match_lstm(r_matching_states, \n",
    "#                                   z_rw.view(z_rw.size(0)*z_rw.size(1), -1),\n",
    "#                                   r_mask_flat_)\n",
    "#         predict = predict.view(r_mask.size(0), r_mask.size(1)) # (batch, num_r)\n",
    "\n",
    "        predict_sort_ = self.match_lstm(r_matching_states[sort_idx,:,:], \n",
    "                                  z_rw.view(z_rw.size(0)*z_rw.size(1), -1)[sort_idx,:],\n",
    "                                  r_mask_flat_[sort_idx,:])\n",
    "        predict = predict_sort_[revert_idx,:].contiguous().view(r_mask.size(0), r_mask.size(1)) # (batch, num_r)\n",
    "        \n",
    "        return predict\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MatchLSTMEncoder(nn.Module):\n",
    "    def __init__(self, args):\n",
    "        super(RelationalMatchLSTMEncoder, self).__init__()\n",
    "        # create an encoder layer\n",
    "        encoder_args = copy.deepcopy(args)\n",
    "        encoder_args.model_type = 'RNN'\n",
    "        encoder_args.input_dim = args.embedding_dim\n",
    "        encoder_args.layer_num = args.layer_num\n",
    "        \n",
    "        self.encoder = Encoder(encoder_args) \n",
    "\n",
    "        # create an output layer\n",
    "        match_lstm_args = copy.deepcopy(args)\n",
    "        match_lstm_args.model_type = 'RNN'\n",
    "        match_lstm_args.layer_num = 1\n",
    "        match_lstm_args.embedding_dim = match_lstm_args.hidden_dim*4\n",
    "        match_lstm_args.hidden_dim = args.mlstm_hidden_dim\n",
    "        match_lstm_args.num_classes = 1\n",
    "        \n",
    "        self.match_lstm = ClassificationEncoder(match_lstm_args)\n",
    "\n",
    "        \n",
    "    def _calculate_similarity_matrix(self, q_hiddens, p_hiddens):\n",
    "        \"\"\"\n",
    "        Inputs: \n",
    "            q_hiddens -- (batch_size, hidden_dim, sequence_length_q)                                                    \n",
    "            rw_hiddens -- (batch_size, hidden_dim, sequence_length_p)    \n",
    "        \"\"\"\n",
    "       \n",
    "        #(batch_size, q_length, p_length)\n",
    "        similarity_matrix = torch.bmm(q_hiddens.transpose(1, 2), p_hiddens)\n",
    "        \n",
    "        return similarity_matrix\n",
    "    \n",
    "    \n",
    "    def _get_matching_representations(self, q_hiddens, p_hiddens, similarity_matrix, mask_matrix=None):\n",
    "        \"\"\"                                                                                                             \n",
    "        This function takes the sequences of hidden states and the word-by-word attention matrix,                       \n",
    "        and returns                                                                                                     \n",
    "                                                                                                                        \n",
    "        Inputs:                                                                                                         \n",
    "            q_hiddens -- (batch_size, hidden_dim, sequence_length_q)                                                    \n",
    "            p_hiddens -- (batch_size, hidden_dim, sequence_length_p)                                                    \n",
    "            similarity_matrix -- (batch_size, sequence_length_q, sequence_length_p)     \n",
    "            z_matrix -- (batch_size, q_length, p_length)\n",
    "                                                                                                                        \n",
    "        Outputs:                                                                                                        \n",
    "            q_matching_states -- (batch_size, sequence_length_q， hidden_dim * 4)                                       \n",
    "        \"\"\"\n",
    "#         attention_softmax = F.softmax(similarity_matrix, dim=2) * z_matrix\n",
    "        neg_inf = -1.0e6\n",
    "        if mask_matrix is not None:\n",
    "            attention_softmax = F.softmax(similarity_matrix + (1 * mask_matrix) * neg_inf, dim=2)\n",
    "        else:\n",
    "            attention_softmax = F.softmax(similarity_matrix, dim=2)\n",
    "        \n",
    "        # shape: (batch_size, sequence_length_q, hidden_dim)                                                            \n",
    "        q_hiddens_tilda = torch.bmm(attention_softmax, p_hiddens.transpose(1, 2))\n",
    "        q_hiddens_ = q_hiddens.transpose(1, 2)\n",
    "\n",
    "        # shape: (batch_size, sequence_length_q, hidden_dim * 4)                                                        \n",
    "        q_matching_states = torch.cat([q_hiddens_, q_hiddens_tilda,\n",
    "                                      q_hiddens_ - q_hiddens_tilda,\n",
    "                                      q_hiddens_ * q_hiddens_tilda], dim=2)\n",
    "        return q_matching_states, attention_softmax\n",
    "        \n",
    "    def forward(self, q_embeddings, r_embeddings, z_q, z_rw, q_mask, r_mask):\n",
    "        \n",
    "        q_len = torch.sum(q_mask, dim=1).cpu().data.numpy()\n",
    "        \n",
    "        q_sort_idx = np.argsort(-q_len)\n",
    "        q_idx_dict = {q_sort_idx[i_]: i_ for i_ in range(q_len.shape[0])}\n",
    "        q_revert_idx = np.array([q_idx_dict[i_] for i_ in range(q_len.shape[0])])\n",
    "        \n",
    "        q_hiddens_sort_ = self.encoder(q_embeddings[q_sort_idx,:,:], z_q[q_sort_idx,:], \n",
    "                                       q_mask[q_sort_idx,:]) #(batch_size, hidden_dim, q_length)\n",
    "        q_hiddens = q_hiddens_sort_[q_revert_idx, :, :].contiguous()\n",
    "#         q_hiddens = self.encoder(q_embeddings, z_q, q_mask)\n",
    "        \n",
    "#         print(q_hiddens.size())\n",
    "\n",
    "        r_mask_flat_ = r_mask.view(r_mask.size(0)*r_mask.size(1), -1) #(B1*B2, L2)\n",
    "        r_embeddings_flat_ = r_embeddings.view(r_mask.size(0)*r_mask.size(1), r_mask.size(2), -1)\n",
    "        \n",
    "        r_len = torch.sum(r_mask_flat_, dim=1).cpu().data.numpy()\n",
    "        sort_idx = np.argsort(-r_len)\n",
    "        idx_dict = {sort_idx[i_]: i_ for i_ in range(r_len.shape[0])}\n",
    "        revert_idx = np.array([idx_dict[i_] for i_ in range(r_len.shape[0])])\n",
    "        \n",
    "        r_hiddens_sort_ = self.encoder(r_embeddings_flat_[sort_idx,:,:], \n",
    "                                   z_rw.view(z_rw.size(0)*z_rw.size(1), -1)[sort_idx,:],\n",
    "                                   r_mask_flat_[sort_idx,:]) #(B1*B2, H, L2)\n",
    "        \n",
    "        r_hiddens_ = r_hiddens_sort_[revert_idx,:,:].contiguous()\n",
    "        \n",
    "        q_hiddens_expand_ = q_hiddens.unsqueeze(1).expand(q_mask.size(0), r_mask.size(1), -1, q_hiddens.size(2))\n",
    "        q_hiddens_expand_ = q_hiddens_expand_.contiguous().view(q_mask.size(0)*r_mask.size(1), -1, q_hiddens.size(2))\n",
    "        q_mask_expand_ = torch.ones(q_mask.size(0) * r_mask.size(1), q_hiddens.size(2)).cuda() #(B1*B2, L1)\n",
    "\n",
    "        qe_len = torch.sum(q_mask_expand_, dim=1).cpu().data.numpy()\n",
    "        qe_sort_idx = np.argsort(-qe_len)\n",
    "        qe_idx_dict = {qe_sort_idx[i_]: i_ for i_ in range(qe_len.shape[0])}\n",
    "        qe_revert_idx = np.array([qe_idx_dict[i_] for i_ in range(qe_len.shape[0])])\n",
    "    \n",
    "        # generate word-by-word similarity between q and p, (B1*B2, L1, L2)\n",
    "        similarity_matrix = self._calculate_similarity_matrix(q_hiddens_expand_, r_hiddens_)\n",
    "        mask_matrix = torch.bmm(q_mask_expand_.unsqueeze(2), r_mask_flat_.unsqueeze(1)) #(B1*B2, L1, L2)\n",
    "        \n",
    "        # (B1*B2, L1， hidden_dim * 4)                                                              \n",
    "        q_matching_states, _ = self._get_matching_representations(r_hiddens_, q_hiddens_expand_,\n",
    "                                                                 similarity_matrix, mask_matrix)\n",
    "\n",
    "        \n",
    "        \n",
    "        predict_sort_ = self.match_lstm(r_matching_states[qe_sort_idx,:,:], \n",
    "                                  z_rw.view(z_rw.size(0)*z_rw.size(1), -1)[qe_sort_idx,:],\n",
    "                                  q_mask_expand_[qe_sort_idx,:])\n",
    "        predict = predict_sort_[qe_revert_idx,:].contiguous().view(r_mask.size(0), r_mask.size(1)) # (B1, B2)\n",
    "        \n",
    "        return predict\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MatchLSTMRankingModel(BasicNLPModel):\n",
    "    \"\"\"Ranks num_r candidates per query with a RelationalMatchLSTMEncoder.\"\"\"\n",
    "\n",
    "    def __init__(self, words_lookup, args):\n",
    "        super(MatchLSTMRankingModel, self).__init__(words_lookup, args)\n",
    "\n",
    "        self.num_classes = args.num_classes\n",
    "        self.hidden_dim = args.hidden_dim\n",
    "        self.is_cuda = args.cuda\n",
    "        self.model_type = args.model_type\n",
    "\n",
    "        # create an encoder layer\n",
    "        self.match_lstm = RelationalMatchLSTMEncoder(args)\n",
    "\n",
    "        # per-example losses, averaged in loss(); reduction='none' replaces the deprecated reduce=False\n",
    "        self.loss_func = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    def _create_embed_layer(self, vocab_size, emb_dim):\n",
    "        \"\"\"Trainable nn.Embedding of shape (vocab_size, emb_dim).\"\"\"\n",
    "        embed_layer = nn.Embedding(vocab_size, emb_dim)\n",
    "        embed_layer.weight.requires_grad = True\n",
    "        return embed_layer\n",
    "\n",
    "    def forward(self, q, r, q_mask, r_mask):\n",
    "        \"\"\"Score every candidate r[b, k] against query q[b].\n",
    "\n",
    "        Args:\n",
    "            q: (batch_size, q_length) token ids\n",
    "            r: (batch_size, num_r, r_length) token ids\n",
    "            q_mask, r_mask: masks matching q and r\n",
    "\n",
    "        Returns:\n",
    "            (batch_size, num_r) scores from the match-LSTM encoder.\n",
    "        \"\"\"\n",
    "        q_embeddings = self.embedding_layer(q)  # (batch_size, q_length, embedding_dim)\n",
    "        rw_embeddings = self.embedding_layer(r)  # (batch_size, num_r, r_length, embedding_dim)\n",
    "\n",
    "        # every token fully weighted; ones_like keeps the inputs' device instead of forcing .cuda()\n",
    "        z_q = torch.ones_like(q).float()  # (batch, q_length)\n",
    "        z_r = torch.ones_like(r).float()  # (batch, num_r, r_length)\n",
    "\n",
    "        return self.match_lstm(q_embeddings, rw_embeddings, z_q, z_r, q_mask, r_mask)\n",
    "\n",
    "    def loss(self, predict, label):\n",
    "        \"\"\"Mean cross-entropy of the candidate scores against label.\n",
    "\n",
    "        NOTE(review): presumably label holds the index of the correct candidate - confirm.\n",
    "        \"\"\"\n",
    "        prediction_loss = self.loss_func(predict, label)  # (batch_size, )\n",
    "        return torch.mean(prediction_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MatchLSTMForwardRankingModel(BasicNLPModel):\n",
    "    \"\"\"Scores every output sequence against every input sequence (B1 x B2) with a\n",
    "    RelationalMatchLSTMEncoder.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, words_lookup, args):\n",
    "        super(MatchLSTMForwardRankingModel, self).__init__(words_lookup, args)\n",
    "\n",
    "        self.num_classes = args.num_classes\n",
    "        self.hidden_dim = args.hidden_dim\n",
    "        self.is_cuda = args.cuda\n",
    "        self.model_type = args.model_type\n",
    "\n",
    "        # create an encoder layer\n",
    "        self.match_lstm = RelationalMatchLSTMEncoder(args)\n",
    "\n",
    "        # per-example losses, averaged in loss(); reduction='none' replaces the deprecated reduce=False\n",
    "        self.loss_func = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    def _create_embed_layer(self, vocab_size, emb_dim):\n",
    "        \"\"\"Trainable nn.Embedding of shape (vocab_size, emb_dim).\"\"\"\n",
    "        embed_layer = nn.Embedding(vocab_size, emb_dim)\n",
    "        embed_layer.weight.requires_grad = True\n",
    "        return embed_layer\n",
    "\n",
    "    def forward(self, i, o, i_mask, o_mask):\n",
    "        \"\"\"Score each of the B2 outputs against each of the B1 inputs.\n",
    "\n",
    "        Args:\n",
    "            i: (B1, L1) input token ids\n",
    "            o: (B2, L2) output token ids, shared by all inputs\n",
    "            i_mask: (B1, L1) input mask\n",
    "            o_mask: (B2, L2) output mask\n",
    "\n",
    "        Returns:\n",
    "            (B1, B2) scores.\n",
    "        \"\"\"\n",
    "        num_i, num_o, o_len = i.size(0), o.size(0), o.size(1)\n",
    "\n",
    "        i_embeddings = self.embedding_layer(i)  # (B1, L1, E)\n",
    "        o_embeddings = self.embedding_layer(o)  # (B2, L2, E)\n",
    "        # pair every input with every output\n",
    "        o_embeddings = o_embeddings.unsqueeze(0).expand(num_i, num_o, o_len,\n",
    "                                                        o_embeddings.size(2)).contiguous()  # (B1, B2, L2, E)\n",
    "\n",
    "        # every token fully weighted; ones_like keeps the inputs' device instead of forcing .cuda()\n",
    "        z_i = torch.ones_like(i).float()  # (B1, L1)\n",
    "        z_o = torch.ones_like(o).float()  # (B2, L2)\n",
    "        z_o = z_o.unsqueeze(0).expand(num_i, num_o, o_len).contiguous()  # (B1, B2, L2)\n",
    "        o_mask = o_mask.unsqueeze(0).expand(num_i, num_o, o_len).contiguous()  # (B1, B2, L2)\n",
    "\n",
    "        return self.match_lstm(i_embeddings, o_embeddings, z_i, z_o, i_mask, o_mask)\n",
    "\n",
    "    def loss(self, predict, label):\n",
    "        \"\"\"Mean cross-entropy of the (B1, B2) scores against label.\n",
    "\n",
    "        NOTE(review): presumably label holds the index of the correct output per input - confirm.\n",
    "        \"\"\"\n",
    "        prediction_loss = self.loss_func(predict, label)  # (B,)\n",
    "        return torch.mean(prediction_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MatchLSTMForwardDiffRankingModel(BasicNLPModel):\n",
    "    \"\"\"Scores every output sequence against every input sequence (B1 x B2) with a\n",
    "    MatchLSTMEncoder (query tokens attend over output tokens).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, words_lookup, args):\n",
    "        super(MatchLSTMForwardDiffRankingModel, self).__init__(words_lookup, args)\n",
    "\n",
    "        self.num_classes = args.num_classes\n",
    "        self.hidden_dim = args.hidden_dim\n",
    "        self.is_cuda = args.cuda\n",
    "        self.model_type = args.model_type\n",
    "\n",
    "        # hidden_dim -> hidden_dim projections sized from args (`config` is not defined here).\n",
    "        # NOTE(review): neither projection is used in forward().\n",
    "        self.state_linear = nn.Sequential(\n",
    "            nn.Linear(self.hidden_dim, self.hidden_dim),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        self.next_state_linear = nn.Sequential(\n",
    "            nn.Linear(self.hidden_dim, self.hidden_dim),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        # create an encoder layer\n",
    "        self.match_lstm = MatchLSTMEncoder(args)\n",
    "\n",
    "        # per-example losses, averaged in loss(); reduction='none' replaces the deprecated reduce=False\n",
    "        self.loss_func = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    def _create_embed_layer(self, vocab_size, emb_dim):\n",
    "        \"\"\"Trainable nn.Embedding of shape (vocab_size, emb_dim).\"\"\"\n",
    "        embed_layer = nn.Embedding(vocab_size, emb_dim)\n",
    "        embed_layer.weight.requires_grad = True\n",
    "        return embed_layer\n",
    "\n",
    "    def forward(self, i, o, i_mask, o_mask):\n",
    "        \"\"\"Score each of the B2 outputs against each of the B1 inputs.\n",
    "\n",
    "        Args:\n",
    "            i: (B1, L1) input token ids\n",
    "            o: (B2, L2) output token ids, shared by all inputs\n",
    "            i_mask: (B1, L1) input mask\n",
    "            o_mask: (B2, L2) output mask\n",
    "\n",
    "        Returns:\n",
    "            (B1, B2) scores.\n",
    "        \"\"\"\n",
    "        num_i, num_o, o_len = i.size(0), o.size(0), o.size(1)\n",
    "\n",
    "        i_embeddings = self.embedding_layer(i)  # (B1, L1, E)\n",
    "        o_embeddings = self.embedding_layer(o)  # (B2, L2, E)\n",
    "        # pair every input with every output\n",
    "        o_embeddings = o_embeddings.unsqueeze(0).expand(num_i, num_o, o_len,\n",
    "                                                        o_embeddings.size(2)).contiguous()  # (B1, B2, L2, E)\n",
    "\n",
    "        # every token fully weighted; ones_like keeps the inputs' device instead of forcing .cuda()\n",
    "        z_i = torch.ones_like(i).float()  # (B1, L1)\n",
    "        z_o = torch.ones_like(o).float()  # (B2, L2)\n",
    "        z_o = z_o.unsqueeze(0).expand(num_i, num_o, o_len).contiguous()  # (B1, B2, L2)\n",
    "        o_mask = o_mask.unsqueeze(0).expand(num_i, num_o, o_len).contiguous()  # (B1, B2, L2)\n",
    "\n",
    "        return self.match_lstm(i_embeddings, o_embeddings, z_i, z_o, i_mask, o_mask)\n",
    "\n",
    "    def loss(self, predict, label):\n",
    "        \"\"\"Mean cross-entropy of the (B1, B2) scores against label.\n",
    "\n",
    "        NOTE(review): presumably label holds the index of the correct output per input - confirm.\n",
    "        \"\"\"\n",
    "        prediction_loss = self.loss_func(predict, label)  # (B,)\n",
    "        return torch.mean(prediction_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DifferentiateMatchLSTMEncoder(RelationalMatchLSTMEncoder):\n",
    "    def __init__(self, args):\n",
    "        super(DifferentiateMatchLSTMEncoder, self).__init__(args)\n",
    "        # Extra single-layer RNN Encoder (the \"diff\" layer) that re-encodes the\n",
    "        # hidden_dim * 4 p-vs-q matching features produced in forward().\n",
    "        diff_args = copy.deepcopy(args)\n",
    "        diff_args.model_type, diff_args.layer_num = 'RNN', 1\n",
    "        diff_args.embedding_dim = diff_args.hidden_dim * 4\n",
    "        self.diff_match_lstm = Encoder(diff_args)\n",
    "        \n",
    "    def forward(self, q_embeddings, p_embeddings, r_embeddings, z_q, z_p, z_rw, q_mask, p_mask, r_mask):\n",
    "        # q encoding\n",
    "        q_len = torch.sum(q_mask, dim=1).cpu().data.numpy()\n",
    "        sort_idx = np.argsort(-q_len)\n",
    "        idx_dict = {sort_idx[i_]: i_ for i_ in range(q_len.shape[0])}\n",
    "        revert_idx = np.array([idx_dict[i_] for i_ in range(q_len.shape[0])])\n",
    "        \n",
    "        q_hiddens_sort_ = self.encoder(q_embeddings[sort_idx,:,:], z_q[sort_idx,:], q_mask[sort_idx,:]) #(batch_size, hidden_dim, q_length)\n",
    "        q_hiddens = q_hiddens_sort_[revert_idx, :, :].contiguous()\n",
    "        \n",
    "        # p encoding\n",
    "        p_len = torch.sum(p_mask, dim=1).cpu().data.numpy()\n",
    "        p_sort_idx = np.argsort(-p_len)\n",
    "        p_idx_dict = {p_sort_idx[i_]: i_ for i_ in range(p_len.shape[0])}\n",
    "        p_revert_idx = np.array([p_idx_dict[i_] for i_ in range(p_len.shape[0])])\n",
    "        \n",
    "        p_hiddens_sort_ = self.encoder(p_embeddings[p_sort_idx,:,:], z_p[p_sort_idx,:], p_mask[p_sort_idx,:])\n",
    "        p_hiddens = p_hiddens_sort_[p_revert_idx, :, :].contiguous()\n",
    "\n",
    "        # r encoding\n",
    "        r_mask_flat_ = r_mask.view(r_mask.size(0)*r_mask.size(1), -1) #(batch_size*num_r, rw_length)\n",
    "        r_embeddings_flat_ = r_embeddings.view(r_mask.size(0)*r_mask.size(1), r_mask.size(2), -1)\n",
    "        \n",
    "        r_len = torch.sum(r_mask_flat_, dim=1).cpu().data.numpy()\n",
    "        sort_idx = np.argsort(-r_len)\n",
    "        idx_dict = {sort_idx[i_]: i_ for i_ in range(r_len.shape[0])}\n",
    "        revert_idx = np.array([idx_dict[i_] for i_ in range(r_len.shape[0])])\n",
    "        \n",
    "        r_hiddens_sort_ = self.encoder(r_embeddings_flat_[sort_idx,:,:], \n",
    "                                   z_rw.view(z_rw.size(0)*z_rw.size(1), -1)[sort_idx,:],\n",
    "                                   r_mask_flat_[sort_idx,:]) #(batch*num_r, hidden_dim, rw_length)\n",
    "        \n",
    "        r_hiddens_ = r_hiddens_sort_[revert_idx,:,:].contiguous()\n",
    "        \n",
    "        # generate word-by-word similarity between q and p, (batch_size, p_length, q_length)\n",
    "        pq_similarity_matrix = self._calculate_similarity_matrix(p_hiddens, q_hiddens)\n",
    "        pq_mask_matrix = torch.bmm(p_mask.unsqueeze(2), q_mask.unsqueeze(1))\n",
    "        \n",
    "        # (batch_size, p_length, hidden_dim * 4)                                                              \n",
    "        p_matching_states, _ = self._get_r_matching_representations(p_hiddens, q_hiddens,\n",
    "                                                                 pq_similarity_matrix, pq_mask_matrix)\n",
    "        \n",
    "#         p_matching_states = self.linear_proj_qp(p_matching_states).transpose(1, 2).contiguous() # (batch_size, hidden_dim, p_length)\n",
    "#         p_matching_states += p_hiddens\n",
    "        p_matching_states_sort_ = self.diff_match_lstm(p_matching_states[p_sort_idx,:,:],\n",
    "                                                      z_p[p_sort_idx,:], p_mask[p_sort_idx,:]) \n",
    "        p_matching_states = p_matching_states_sort_[p_revert_idx, :, :].contiguous() # (batch_size, hidden_dim, p_length)\n",
    "        p_matching_states += p_hiddens\n",
    "        \n",
    "        # expand to (batch*num_r, hidden_dim, p_length) and (batch*num_r, p_length)\n",
    "        p_matching_expand_ = p_matching_states.unsqueeze(1).expand(p_mask.size(0), r_mask.size(1), -1, p_mask.size(1))\n",
    "        p_matching_expand_ = p_matching_expand_.contiguous().view(p_mask.size(0)*r_mask.size(1), -1, p_mask.size(1))\n",
    "        p_mask_expand_ = p_mask.unsqueeze(1).expand(p_mask.size(0), r_mask.size(1), p_mask.size(1))\n",
    "        p_mask_expand_ = p_mask_expand_.contiguous().view(-1, p_mask.size(1))\n",
    "\n",
    "    \n",
    "        # generate word-by-word similarity between q and p, (batch_size*num_r, rw_length, q_length)            \n",
    "        similarity_matrix = self._calculate_similarity_matrix(r_hiddens_, p_matching_expand_)\n",
    "        mask_matrix = torch.bmm(r_mask_flat_.unsqueeze(2), p_mask_expand_.unsqueeze(1))\n",
    "        \n",
    "        # (batch_size*num_r, rw_length， hidden_dim * 4)                                                              \n",
    "        r_matching_states, _ = self._get_r_matching_representations(r_hiddens_, p_matching_expand_,\n",
    "                                                                 similarity_matrix, mask_matrix)\n",
    "\n",
    "        predict_sort_ = self.match_lstm(r_matching_states[sort_idx,:,:], \n",
    "                                  z_rw.view(z_rw.size(0)*z_rw.size(1), -1)[sort_idx,:],\n",
    "                                  r_mask_flat_[sort_idx,:])\n",
    "        predict = predict_sort_[revert_idx,:].contiguous().view(r_mask.size(0), r_mask.size(1)) # (batch, num_r)\n",
    "        \n",
    "        return predict\n",
    "        \n",
    "class DifferentiateMatchLSTMRankingModel(BasicNLPModel):\n",
    "    # Ranking model: embeds question q, passage p and candidates r with the shared\n",
    "    # embedding layer and scores the candidates with DifferentiateMatchLSTMEncoder.\n",
    "    def __init__(self, words_lookup, args):\n",
    "        super(DifferentiateMatchLSTMRankingModel, self).__init__(words_lookup, args)\n",
    "\n",
    "        self.num_classes = args.num_classes\n",
    "        self.hidden_dim = args.hidden_dim\n",
    "        self.is_cuda = args.cuda\n",
    "        self.model_type = args.model_type\n",
    "\n",
    "        # create an encoder layer\n",
    "        self.match_lstm = DifferentiateMatchLSTMEncoder(args)\n",
    "\n",
    "        # per-example loss; `reduction='none'` replaces the deprecated `reduce=False`\n",
    "        self.loss_func = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    def _create_embed_layer(self, vocab_size, emb_dim):\n",
    "        # Trainable word embedding table of shape (vocab_size, emb_dim).\n",
    "        embed_layer = nn.Embedding(vocab_size, emb_dim)\n",
    "        embed_layer.weight.requires_grad = True\n",
    "        return embed_layer\n",
    "\n",
    "    def forward(self, q, p, r, q_mask, p_mask, r_mask):\n",
    "        # q: (batch, q_length), p: (batch, p_length), r: (batch, num_r, r_length) token ids.\n",
    "        # Returns whatever the encoder returns -- assumed (batch, num_r) candidate scores.\n",
    "        q_embeddings = self.embedding_layer(q) #(batch_size, q_length, embedding_dim)\n",
    "        p_embeddings = self.embedding_layer(p) #(batch_size, p_length, embedding_dim)\n",
    "        rw_embeddings = self.embedding_layer(r) #(batch_size, num_r, r_length, embedding_dim)\n",
    "\n",
    "        # All-ones z tensors, one value per token. They are created on the device of\n",
    "        # the input ids (no forced .cuda()), so the model also runs on CPU.\n",
    "        z_q = torch.ones_like(q).float() # (batch, q_length)\n",
    "        z_p = torch.ones_like(p).float() # (batch, p_length)\n",
    "        z_r = torch.ones_like(r).float() # (batch, num_r, r_length)\n",
    "\n",
    "        predict = self.match_lstm(q_embeddings, p_embeddings, rw_embeddings,\n",
    "                                  z_q, z_p, z_r, q_mask, p_mask, r_mask)\n",
    "\n",
    "        return predict\n",
    "\n",
    "    def loss(self, predict, label):\n",
    "        # Mean cross-entropy of the candidate scores against label, which must hold\n",
    "        # class indices into the last dim of predict (one per example).\n",
    "        prediction_loss = self.loss_func(predict, label) # (batch_size, )\n",
    "        supervised_loss = torch.mean(prediction_loss)\n",
    "\n",
    "        return supervised_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CoMatchLSTMEncoder(RelationalMatchLSTMEncoder):\n",
    "    # Scores every candidate r by attending its words over the passage p (R-P\n",
    "    # matching) and running a 1-layer RNN ClassificationEncoder over the matched\n",
    "    # states. Setting args.use_rq_matching = True additionally attends r over the\n",
    "    # question q and concatenates [R-Q, R-P] states; by default q is not encoded,\n",
    "    # since its matching states would be discarded.\n",
    "    def __init__(self, args):\n",
    "        super(CoMatchLSTMEncoder, self).__init__(args)\n",
    "        # opt-in R-Q attention; off by default to keep the R-P-only model\n",
    "        self.use_rq_matching = getattr(args, 'use_rq_matching', False)\n",
    "\n",
    "        match_lstm_args = copy.deepcopy(args)\n",
    "        match_lstm_args.model_type = 'RNN'\n",
    "        match_lstm_args.layer_num = 1\n",
    "        # each matching representation is hidden_dim * 4 wide; R-Q + R-P doubles it\n",
    "        num_matchings = 2 if self.use_rq_matching else 1\n",
    "        match_lstm_args.embedding_dim = match_lstm_args.hidden_dim * 4 * num_matchings\n",
    "        match_lstm_args.hidden_dim = args.mlstm_hidden_dim\n",
    "        match_lstm_args.num_classes = 1\n",
    "        self.match_lstm = ClassificationEncoder(match_lstm_args)\n",
    "\n",
    "    def _apply_length_sorted(self, module, inputs, z, mask):\n",
    "        # Runs module on the rows ordered by decreasing mask length (presumably the\n",
    "        # encoder packs padded sequences) and restores the original row order;\n",
    "        # np.argsort of a permutation is its inverse permutation.\n",
    "        lengths = torch.sum(mask, dim=1).cpu().data.numpy()\n",
    "        sort_idx = np.argsort(-lengths)\n",
    "        revert_idx = np.argsort(sort_idx)\n",
    "        outputs_sorted = module(inputs[sort_idx], z[sort_idx], mask[sort_idx])\n",
    "        return outputs_sorted[revert_idx].contiguous()\n",
    "\n",
    "    def _match_against(self, r_hiddens, r_mask_flat, other_hiddens, other_mask, num_r):\n",
    "        # Attends each flattened candidate (batch*num_r rows) over `other` (p or q),\n",
    "        # whose batch rows are repeated num_r times to line up with the candidates.\n",
    "        # other_hiddens: (batch, hidden_dim, other_length); returns the matching\n",
    "        # states of shape (batch*num_r, rw_length, hidden_dim * 4).\n",
    "        batch_size, other_length = other_mask.size(0), other_mask.size(1)\n",
    "        other_hiddens_expand = other_hiddens.unsqueeze(1).expand(batch_size, num_r, -1, other_length)\n",
    "        other_hiddens_expand = other_hiddens_expand.contiguous().view(batch_size * num_r, -1, other_length)\n",
    "        other_mask_expand = other_mask.unsqueeze(1).expand(batch_size, num_r, other_length)\n",
    "        other_mask_expand = other_mask_expand.contiguous().view(-1, other_length)\n",
    "\n",
    "        similarity_matrix = self._calculate_similarity_matrix(r_hiddens, other_hiddens_expand)\n",
    "        mask_matrix = torch.bmm(r_mask_flat.unsqueeze(2), other_mask_expand.unsqueeze(1))\n",
    "        matching_states, _ = self._get_r_matching_representations(r_hiddens, other_hiddens_expand,\n",
    "                                                                  similarity_matrix, mask_matrix)\n",
    "        return matching_states\n",
    "\n",
    "    def forward(self, q_embeddings, p_embeddings, r_embeddings, z_q, z_p, z_rw, q_mask, p_mask, r_mask):\n",
    "        # Returns (batch, num_r) scores, one per candidate r.\n",
    "        num_r = r_mask.size(1)\n",
    "\n",
    "        # p encoding: (batch_size, hidden_dim, p_length)\n",
    "        p_hiddens = self._apply_length_sorted(self.encoder, p_embeddings, z_p, p_mask)\n",
    "\n",
    "        # r encoding over the flattened candidates: (batch*num_r, hidden_dim, rw_length)\n",
    "        r_mask_flat_ = r_mask.view(r_mask.size(0) * num_r, -1)\n",
    "        r_embeddings_flat_ = r_embeddings.view(r_mask.size(0) * num_r, r_mask.size(2), -1)\n",
    "        z_rw_flat_ = z_rw.view(z_rw.size(0) * z_rw.size(1), -1)\n",
    "        r_hiddens_ = self._apply_length_sorted(self.encoder, r_embeddings_flat_, z_rw_flat_, r_mask_flat_)\n",
    "\n",
    "        # R-P attention: (batch*num_r, rw_length, hidden_dim * 4)\n",
    "        rp_matching_states = self._match_against(r_hiddens_, r_mask_flat_, p_hiddens, p_mask, num_r)\n",
    "\n",
    "        if self.use_rq_matching:\n",
    "            # q encoding (batch_size, hidden_dim, q_length) and R-Q attention, only\n",
    "            # computed when their result actually feeds the classifier\n",
    "            q_hiddens = self._apply_length_sorted(self.encoder, q_embeddings, z_q, q_mask)\n",
    "            rq_matching_states = self._match_against(r_hiddens_, r_mask_flat_, q_hiddens, q_mask, num_r)\n",
    "            r_matching_states = torch.cat([rq_matching_states, rp_matching_states], dim=2)\n",
    "        else:\n",
    "            r_matching_states = rp_matching_states\n",
    "\n",
    "        predict = self._apply_length_sorted(self.match_lstm, r_matching_states, z_rw_flat_, r_mask_flat_)\n",
    "        predict = predict.view(r_mask.size(0), num_r) # (batch, num_r)\n",
    "\n",
    "        return predict\n",
    "        \n",
    "class CoMatchLSTMRankingModel(BasicNLPModel):\n",
    "    # Ranking model: embeds question q, passage p and candidates r with the shared\n",
    "    # embedding layer and scores the candidates with CoMatchLSTMEncoder.\n",
    "    def __init__(self, words_lookup, args):\n",
    "        super(CoMatchLSTMRankingModel, self).__init__(words_lookup, args)\n",
    "\n",
    "        self.num_classes = args.num_classes\n",
    "        self.hidden_dim = args.hidden_dim\n",
    "        self.is_cuda = args.cuda\n",
    "        self.model_type = args.model_type\n",
    "\n",
    "        # create an encoder layer\n",
    "        self.match_lstm = CoMatchLSTMEncoder(args)\n",
    "\n",
    "        # per-example loss; `reduction='none'` replaces the deprecated `reduce=False`\n",
    "        self.loss_func = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    def _create_embed_layer(self, vocab_size, emb_dim):\n",
    "        # Trainable word embedding table of shape (vocab_size, emb_dim).\n",
    "        embed_layer = nn.Embedding(vocab_size, emb_dim)\n",
    "        embed_layer.weight.requires_grad = True\n",
    "        return embed_layer\n",
    "\n",
    "    def forward(self, q, p, r, q_mask, p_mask, r_mask):\n",
    "        # q: (batch, q_length), p: (batch, p_length), r: (batch, num_r, r_length) token ids.\n",
    "        # Returns the encoder output -- (batch, num_r) candidate scores.\n",
    "        q_embeddings = self.embedding_layer(q) #(batch_size, q_length, embedding_dim)\n",
    "        p_embeddings = self.embedding_layer(p) #(batch_size, p_length, embedding_dim)\n",
    "        rw_embeddings = self.embedding_layer(r) #(batch_size, num_r, r_length, embedding_dim)\n",
    "\n",
    "        # All-ones z tensors, one value per token. They are created on the device of\n",
    "        # the input ids (no forced .cuda()), so the model also runs on CPU.\n",
    "        z_q = torch.ones_like(q).float() # (batch, q_length)\n",
    "        z_p = torch.ones_like(p).float() # (batch, p_length)\n",
    "        z_r = torch.ones_like(r).float() # (batch, num_r, r_length)\n",
    "\n",
    "        predict = self.match_lstm(q_embeddings, p_embeddings, rw_embeddings,\n",
    "                                  z_q, z_p, z_r, q_mask, p_mask, r_mask)\n",
    "\n",
    "        return predict\n",
    "\n",
    "    def loss(self, predict, label):\n",
    "        # Mean cross-entropy of the candidate scores against label, which must hold\n",
    "        # class indices into the last dim of predict (one per example).\n",
    "        prediction_loss = self.loss_func(predict, label) # (batch_size, )\n",
    "        supervised_loss = torch.mean(prediction_loss)\n",
    "\n",
    "        return supervised_loss"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
