{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from transformers import BertModel, BertTokenizer\n",
    "\n",
    "import transformers\n",
    "print('transformers version:', transformers.__version__)"
   ]
  },
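  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal input-building sketch (not part of the original notebook):\n",
    "# the models below all consume (token_ids, attention_mask) LongTensor\n",
    "# pairs. 'bert-base-uncased' is an assumption; swap in your own\n",
    "# pretrain_path. Manual padding is used to stay portable across\n",
    "# transformers versions, whose padding kwargs differ.\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "\n",
    "def encode(texts, max_length=32):\n",
    "    \"\"\"Return (token_ids, attention_mask), both (B, max_length).\"\"\"\n",
    "    ids, masks = [], []\n",
    "    for text in texts:\n",
    "        tokens = tokenizer.encode(text, add_special_tokens=True)[:max_length]  # naive truncation\n",
    "        pad = max_length - len(tokens)\n",
    "        ids.append(tokens + [tokenizer.pad_token_id] * pad)\n",
    "        masks.append([1] * len(tokens) + [0] * pad)\n",
    "    return torch.tensor(ids), torch.tensor(masks)\n",
    "\n",
    "token, att_mask = encode(['how do I sort a list in python?',\n",
    "                          'what does torch.cat do?'])\n",
    "print(token.shape, att_mask.shape)  # torch.Size([2, 32]) torch.Size([2, 32])"
   ]
  },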
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class BERTRanker(nn.Module):\n",
    "#     def __init__(self, max_length, num_class, pretrain_path, blank_padding=True):\n",
    "#         \"\"\"\n",
    "#         Args:\n",
    "#             max_length: max length of sentence\n",
    "#             pretrain_path: path of pretrain model\n",
    "#         \"\"\"\n",
    "#         super().__init__()\n",
    "#         self.max_length = max_length\n",
    "#         self.blank_padding = blank_padding\n",
    "#         self.hidden_size = 768\n",
    "# #         self.mask_entity = mask_entity\n",
    "#         logging.info('Loading BERT pre-trained checkpoint.')\n",
    "#         self.bert = BertModel.from_pretrained(pretrain_path)\n",
    "        \n",
    "#         self.fc = nn.Linear(self.hidden_size, num_class)\n",
    "\n",
    "#     def forward(self, token, att_mask):\n",
    "#         \"\"\"\n",
    "#         Args:\n",
    "#             token: (B, L), index of tokens\n",
    "#             att_mask: (B, L), attention mask (1 for contents and 0 for padding)\n",
    "#         Return:\n",
    "#             x -- (B, H), representations for sentences\n",
    "#             return (B, 1) scores\n",
    "#         \"\"\"\n",
    "#         _, x = self.bert(token, attention_mask=att_mask)\n",
    "#         return self.fc(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class SiameseBertRanking(nn.Module):\n",
    "#     def __init__(self, max_length, num_class, pretrain_path, blank_padding=True):\n",
    "#         super().__init__()\n",
    "        \n",
    "#         self.max_length = max_length\n",
    "#         self.blank_padding = blank_padding\n",
    "#         self.hidden_size = 768\n",
    "\n",
    "#         logging.info('Loading BERT pre-trained checkpoint.')\n",
    "#         self.bert = BertModel.from_pretrained(pretrain_path)\n",
    "        \n",
    "#         self.num_classes = num_class\n",
    "# #         self.loss_func = nn.CrossEntropyLoss()\n",
    "        \n",
    "        \n",
    "#     def forward(self, q, p, q_mask, p_mask):\n",
    "#         \"\"\"\n",
    "#         Args:\n",
    "#             token: (B, L), index of tokens\n",
    "#             att_mask: (B, L), attention mask (1 for contents and 0 for padding)\n",
    "#         Return:\n",
    "#             return (B, 1) scores\n",
    "#         \"\"\"\n",
    "#         _, q_hiddens = self.bert(q.unsqueeze(0), attention_mask=q_mask.unsqueeze(0)) # (1, H)\n",
    "#         q_hiddens = q_hiddens.squeeze(0) # (H,)\n",
    "        \n",
    "#         _, p_hiddens = self.bert(p, attention_mask=p_mask) # (B, H)\n",
    "        \n",
    "# #         print(p_hiddens.size())\n",
    "# #         print(q_hiddens.size())\n",
    "\n",
    "#         return F.cosine_similarity(p_hiddens, q_hiddens.unsqueeze(0), dim=1).squeeze(-1) # (B,)\n",
    "        \n",
    "# #         return torch.mm(p_hiddens, q_hiddens.unsqueeze(-1)).squeeze(-1) # (B,)\n"
   ]
  },
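  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick illustration (not from the original notebook) of the cosine\n",
    "# scoring used by the commented-out SiameseBertRanking above: one (H,)\n",
    "# query vector against a (B, H) batch of passage vectors, with random\n",
    "# tensors standing in for BERT pooled outputs.\n",
    "B, H = 4, 768\n",
    "q_hiddens = torch.randn(H)     # pooled query representation\n",
    "p_hiddens = torch.randn(B, H)  # pooled passage representations\n",
    "\n",
    "# unsqueeze(0) broadcasts the query across the batch dimension\n",
    "scores = F.cosine_similarity(p_hiddens, q_hiddens.unsqueeze(0), dim=1)\n",
    "print(scores.shape)  # torch.Size([4]), one score in [-1, 1] per passage"
   ]
  },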
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SiameseBertClassification(nn.Module):\n",
    "    def __init__(self, max_length, num_class, pretrain_path, blank_padding=True):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.max_length = max_length\n",
    "        self.blank_padding = blank_padding\n",
    "        self.hidden_size = 768\n",
    "\n",
    "        logging.info('Loading BERT pre-trained checkpoint.')\n",
    "        self.bert = BertModel.from_pretrained(pretrain_path)\n",
    "        \n",
    "        self.num_classes = num_class\n",
    "        self.output_layer = nn.Linear(self.hidden_size*4, self.num_classes)\n",
    "        self.loss = nn.CrossEntropyLoss()\n",
    "        \n",
    "        \n",
    "    def forward(self, q, r, q_mask, r_mask):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            q: (B, L), index of tokens\n",
    "            r: (B, K, L'), index of tokens\n",
    "        Return:\n",
    "            return (B, L) scores\n",
    "        \"\"\"\n",
    "        _, q_hiddens = self.bert(q, attention_mask=q_mask) # (B, H)\n",
    "        \n",
    "        r_flat_ = r.view(r.size(0)*r.size(1), -1) #(batch_size*num_r, r_length)\n",
    "        r_mask_flat_ = r_mask.view(r_mask.size(0)*r_mask.size(1), -1) #(batch_size*num_r, r_length)\n",
    "        \n",
    "        _, r_hiddens = self.bert(r_flat_, attention_mask=r_mask_flat_) # (BK, H)\n",
    "        \n",
    "        q_hiddens_expand_ = q_hiddens.unsqueeze(1).expand(q_mask.size(0), r_mask.size(1), -1)\n",
    "        q_hiddens_expand_ = q_hiddens_expand_.contiguous().view(q_mask.size(0)*r_mask.size(1), -1) # (BK, H)\n",
    "\n",
    "        matching_state = torch.cat([q_hiddens_expand_, r_hiddens,\n",
    "                                    q_hiddens_expand_ - r_hiddens,\n",
    "                                    q_hiddens_expand_ * r_hiddens], dim=1) #(BK, 4H)\n",
    "        \n",
    "        # (batch_size,)                                                                                                 \n",
    "        predict = self.output_layer(matching_state).view(r_mask.size(0), -1) #(B, K)\n",
    "        \n",
    "        return predict\n",
    "    "
   ]
  },
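  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged usage sketch for SiameseBertClassification; 'bert-base-uncased'\n",
    "# and the toy shapes are assumptions, not from the original notebook.\n",
    "# num_class=1 so the (B, K) view in forward() holds. Assumes\n",
    "# transformers < 4.x to match the tuple unpacking above; weights are\n",
    "# downloaded on first use.\n",
    "model = SiameseBertClassification(max_length=32, num_class=1,\n",
    "                                  pretrain_path='bert-base-uncased')\n",
    "model.eval()\n",
    "\n",
    "B, K, L = 2, 3, 8\n",
    "q = torch.randint(0, model.bert.config.vocab_size, (B, L))\n",
    "r = torch.randint(0, model.bert.config.vocab_size, (B, K, L))\n",
    "q_mask = torch.ones(B, L, dtype=torch.long)\n",
    "r_mask = torch.ones(B, K, L, dtype=torch.long)\n",
    "\n",
    "with torch.no_grad():\n",
    "    scores = model(q, r, q_mask, r_mask)\n",
    "print(scores.shape)  # torch.Size([2, 3]): one score per candidate"
   ]
  },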
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SiameseBertSASRanking(nn.Module):\n",
    "    def __init__(self, max_length, num_class, pretrain_path, blank_padding=True):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.max_length = max_length\n",
    "        self.blank_padding = blank_padding\n",
    "        self.hidden_size = 768\n",
    "\n",
    "        logging.info('Loading BERT pre-trained checkpoint.')\n",
    "        self.bert = BertModel.from_pretrained(pretrain_path)\n",
    "        \n",
    "        self.num_classes = num_class\n",
    "        self.output_layer = nn.Linear(self.hidden_size*4, self.num_classes)\n",
    "        self.loss = nn.CrossEntropyLoss()\n",
    "        \n",
    "        \n",
    "    def forward(self, i, o, i_mask, o_mask):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            i: (B, L), index of tokens\n",
    "            o: (B, L'), index of tokens\n",
    "        Return:\n",
    "            return (B, B) scores\n",
    "        \"\"\"\n",
    "        _, i_hiddens = self.bert(i, attention_mask=i_mask) # (B, H)\n",
    "        _, o_hiddens = self.bert(o, attention_mask=o_mask) # (B, H)\n",
    "        \n",
    "        i_hiddens_expand_ = i_hiddens.unsqueeze(1).expand(i_mask.size(0), o_mask.size(0), -1) # (B, B, H)\n",
    "        o_hiddens_expand_ = o_hiddens.unsqueeze(0).expand(i_mask.size(0), o_mask.size(0), -1) # (B, B, H)\n",
    "        i_hiddens_expand_ = i_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), -1) # (BB, H)\n",
    "        o_hiddens_expand_ = o_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), -1) # (BB, H)\n",
    "\n",
    "        matching_state = torch.cat([i_hiddens_expand_, o_hiddens_expand_,\n",
    "                                    i_hiddens_expand_ - o_hiddens_expand_,\n",
    "                                    i_hiddens_expand_ * o_hiddens_expand_], dim=1) #(BB, 4H)\n",
    "        \n",
    "        predict = self.output_layer(matching_state).view(i_mask.size(0), -1) #(B, B)\n",
    "        \n",
    "#         matching_state = torch.cat([i_hiddens_expand_, o_hiddens_expand_,\n",
    "#                                     i_hiddens_expand_ - o_hiddens_expand_,\n",
    "#                                     i_hiddens_expand_ * o_hiddens_expand_], dim=2) #(B, B, 4H)\n",
    "        \n",
    "#         predict = self.output_layer(matching_state) #(B, B)\n",
    "        \n",
    "        return predict\n",
    "    "
   ]
  },
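  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of in-batch-negative training with SiameseBertSASRanking\n",
    "# ('bert-base-uncased' and the toy shapes are assumptions, and\n",
    "# transformers < 4.x is assumed as above): row b of the (B, B) score\n",
    "# matrix treats column b as the positive pair, so the CrossEntropyLoss\n",
    "# targets are simply arange(B).\n",
    "model = SiameseBertSASRanking(max_length=32, num_class=1,\n",
    "                              pretrain_path='bert-base-uncased')\n",
    "\n",
    "B, L = 4, 8\n",
    "i = torch.randint(0, model.bert.config.vocab_size, (B, L))\n",
    "o = torch.randint(0, model.bert.config.vocab_size, (B, L))\n",
    "i_mask = torch.ones(B, L, dtype=torch.long)\n",
    "o_mask = torch.ones(B, L, dtype=torch.long)\n",
    "\n",
    "scores = model(i, o, i_mask, o_mask)  # (B, B)\n",
    "target = torch.arange(B)              # diagonal entries are the positives\n",
    "loss = model.loss(scores, target)\n",
    "print(scores.shape, float(loss))"
   ]
  },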
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConcatBertRanking(nn.Module):\n",
    "    def __init__(self, max_length, num_class, pretrain_path, blank_padding=True):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.max_length = max_length\n",
    "        self.blank_padding = blank_padding\n",
    "        self.hidden_size = 768\n",
    "\n",
    "        logging.info('Loading BERT pre-trained checkpoint.')\n",
    "        self.bert = BertModel.from_pretrained(pretrain_path)\n",
    "        \n",
    "        self.num_classes = num_class\n",
    "        self.output_layer = nn.Linear(self.hidden_size, self.num_classes)\n",
    "        self.loss = nn.CrossEntropyLoss()\n",
    "        \n",
    "        \n",
    "    def forward(self, seq, mask):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            i: (B, L), index of tokens\n",
    "        Return:\n",
    "            return (B, B) scores\n",
    "        \"\"\"\n",
    "        _, hiddens = self.bert(seq, attention_mask=mask) # (B, H)\n",
    "        \n",
    "        predict = self.output_layer(hiddens) #(B, 1)\n",
    "        \n",
    "        return predict\n",
    "    "
   ]
  },
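  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Usage sketch for ConcatBertRanking, the cross-encoder variant that\n",
    "# scores a single '[CLS] query [SEP] passage [SEP]' sequence.\n",
    "# 'bert-base-uncased', the example texts, and transformers < 4.x are\n",
    "# assumptions, as above.\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = ConcatBertRanking(max_length=32, num_class=1,\n",
    "                          pretrain_path='bert-base-uncased')\n",
    "model.eval()\n",
    "\n",
    "ids = tokenizer.encode('how do I sort a list?',\n",
    "                       'use the sorted() builtin.',\n",
    "                       add_special_tokens=True)\n",
    "seq = torch.tensor([ids])    # (1, L)\n",
    "mask = torch.ones_like(seq)  # no padding in this toy batch\n",
    "\n",
    "with torch.no_grad():\n",
    "    score = model(seq, mask)\n",
    "print(score.shape)  # torch.Size([1, 1]): one relevance score per pair"
   ]
  }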
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
