{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicNLPModel(nn.Module):\n",
    "    \n",
    "    def __init__(self, words_lookup, args):\n",
    "        super(BasicNLPModel, self).__init__()\n",
    "\n",
    "        self.fine_tuning = args.fine_tuning\n",
    "        self.vocab_size, self.embedding_dim = words_lookup.shape        \n",
    "        \n",
    "        # create a embedding layer\n",
    "        self.embedding_layer = self._create_embedding_layer(words_lookup)\n",
    "        \n",
    "        # create bias for calculating similarity\n",
    "        self.bias_layer = self._create_bias_layer()\n",
    "                \n",
    "    def _create_embedding_layer(self, words_lookup):\n",
    "        embedding_layer = nn.Embedding(self.vocab_size, self.embedding_dim)\n",
    "        embedding_layer.weight.data = torch.from_numpy(words_lookup)\n",
    "        embedding_layer.weight.requires_grad = self.fine_tuning        \n",
    "        \n",
    "        return embedding_layer \n",
    "        \n",
    "    def _create_bias_layer(self):    \n",
    "        bias_layer = nn.Embedding(self.vocab_size, 1)\n",
    "        bias_layer.weight.requires_grad = True\n",
    "        \n",
    "        return bias_layer               \n",
    "                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
