{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import sys\n",
    "sys.path.append(\"/home2/gridsan/dzhao/DLNN/updated/spt/spotlight\")\n",
    "from layers import ScaledEmbedding, ZeroEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class BilinearNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Bilinear factorization representation.\n",
    "\n",
    "    Encodes both users and items as an embedding layer; the score\n",
    "    for a user-item pair is given by the dot product of the item\n",
    "    and user latent vectors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "\n",
    "    num_users: int\n",
    "        Number of users in the model.\n",
    "    num_items: int\n",
    "        Number of items in the model.\n",
    "    embedding_dim: int, optional\n",
    "        Dimensionality of the latent representations.\n",
    "    user_embedding_layer: an embedding layer, optional\n",
    "        If supplied, will be used as the user embedding layer\n",
    "        of the network.\n",
    "    item_embedding_layer: an embedding layer, optional\n",
    "        If supplied, will be used as the item embedding layer\n",
    "        of the network.\n",
    "    sparse: boolean, optional\n",
    "        Use sparse gradients.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_users, num_items, embedding_dim=32,\n",
    "                 user_embedding_layer=None, item_embedding_layer=None, sparse=False):\n",
    "\n",
    "        super(BilinearNet, self).__init__()\n",
    "\n",
    "        self.embedding_dim = embedding_dim\n",
    "\n",
    "        if user_embedding_layer is not None:\n",
    "            self.user_embeddings = user_embedding_layer\n",
    "        else:\n",
    "            self.user_embeddings = ScaledEmbedding(num_users, embedding_dim,\n",
    "                                                   sparse=sparse)\n",
    "\n",
    "        if item_embedding_layer is not None:\n",
    "            self.item_embeddings = item_embedding_layer\n",
    "        else:\n",
    "            self.item_embeddings = ScaledEmbedding(num_items, embedding_dim,\n",
    "                                                   sparse=sparse)\n",
    "\n",
    "        self.user_biases = ZeroEmbedding(num_users, 1, sparse=sparse)\n",
    "        self.item_biases = ZeroEmbedding(num_items, 1, sparse=sparse)\n",
    "\n",
    "    def forward(self, user_ids, item_ids):\n",
    "        \"\"\"\n",
    "        Compute the forward pass of the representation.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "\n",
    "        user_ids: tensor\n",
    "            Tensor of user indices.\n",
    "        item_ids: tensor\n",
    "            Tensor of item indices.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "\n",
    "        predictions: tensor\n",
    "            Tensor of predictions.\n",
    "        \"\"\"\n",
    "\n",
    "        user_embedding = self.user_embeddings(user_ids)\n",
    "        item_embedding = self.item_embeddings(item_ids)\n",
    "\n",
    "        user_embedding = user_embedding.squeeze()\n",
    "        item_embedding = item_embedding.squeeze()\n",
    "\n",
    "        user_bias = self.user_biases(user_ids).squeeze()\n",
    "        item_bias = self.item_biases(item_ids).squeeze()\n",
    "\n",
    "        dot = (user_embedding * item_embedding).sum(1)\n",
    "\n",
    "        return dot + user_bias + item_bias\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
