{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Poisson matrix factorization with TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "np.random.seed(0)\n",
    "from scipy.special import expit\n",
    "from scipy.stats import gamma, poisson, bernoulli\n",
    "import scipy.sparse as sparse\n",
    "from sklearn.metrics import roc_auc_score, mean_squared_error as mse\n",
    "import utils as ut\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_performance(holdout_pairs, user_pref_embed, item_attr_embed):\n",
    "    \"\"\"Return the ROC AUC of model scores on held-out (user, item, label) triples.\n",
    "\n",
    "    Args:\n",
    "        holdout_pairs: iterable of (user_index, item_index, label) triples.\n",
    "        user_pref_embed: 2-D array of user factors, indexed [user, factor].\n",
    "        item_attr_embed: 2-D array of item factors, indexed [item, factor].\n",
    "\n",
    "    Returns:\n",
    "        sklearn's roc_auc_score of the labels against expit(<user row, item row>).\n",
    "    \"\"\"\n",
    "    # expit() is monotone, so it preserves the score ranking that AUC depends on\n",
    "    # (apart from ties where it saturates in floating point).\n",
    "    scored = [\n",
    "        (label, expit((user_pref_embed[u, :] * item_attr_embed[i, :]).sum()))\n",
    "        for u, i, label in holdout_pairs\n",
    "    ]\n",
    "    labels = [label for label, _ in scored]\n",
    "    scores = [score for _, score in scored]\n",
    "    return roc_auc_score(labels, scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. heldout pairs: 1892\n",
      "1000\n"
     ]
    }
   ],
   "source": [
    "# Load the simulated Last.fm data: latent factor files plus observed (user, item) rows.\n",
    "datadir = os.path.join('..', 'lastfm', 'simulated')\n",
    "user_latent_pref = np.loadtxt(os.path.join(datadir, 'user_latent_pref'))\n",
    "item_latent_pref = np.loadtxt(os.path.join(datadir, 'item_latent_pref'))\n",
    "ratings_raw = pd.read_csv(os.path.join(datadir, 'rating.csv'), header=None)\n",
    "\n",
    "num_users = user_latent_pref.shape[0]\n",
    "# NOTE(review): shape[1] assumes item_latent_pref is stored as (rank, num_items) -- TODO confirm.\n",
    "num_items = item_latent_pref.shape[1]\n",
    "\n",
    "# Binary user x item matrix: 1 wherever (user, item) appears in rating.csv.\n",
    "# Vectorised fancy indexing replaces the per-row Python loop; columns 0 and 1 are\n",
    "# read separately so an extra non-integer column cannot upcast the indices to float.\n",
    "rating_mat = np.zeros((num_users, num_items))\n",
    "rating_mat[ratings_raw[0].values, ratings_raw[1].values] = 1\n",
    "\n",
    "holdout_pairs, rating_mat = ut.get_holdout_pairs(rating_mat)\n",
    "ratings = pd.DataFrame(rating_mat, dtype=np.float32)\n",
    "print(\"Num. heldout pairs:\", len(holdout_pairs))\n",
    "print(num_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.26534\n"
     ]
    }
   ],
   "source": [
    "# Simulate a binary interaction matrix from a Gamma-Poisson factor model.\n",
    "# NOTE(review): this overwrites num_users / num_items / holdout_pairs / ratings set by the\n",
    "# data-loading cell, so only the simulated data reaches the model cells below.\n",
    "num_users, num_items, rank = 1000, 100, 5\n",
    "size_user = (num_users, rank)\n",
    "size_item = (num_items, rank)\n",
    "\n",
    "# Gamma(shape=0.3, scale=1.0) latent factors; keep the draw order for the seeded RNG.\n",
    "user_latent_pref = gamma.rvs(0.3, scale=1.0, size=size_user)\n",
    "item_latent_attr = gamma.rvs(0.3, scale=1.0, size=size_item)\n",
    "rates = user_latent_pref @ item_latent_attr.T\n",
    "\n",
    "# Poisson counts clipped to {0, 1}.\n",
    "counts_mat = np.minimum(poisson.rvs(rates), 1)\n",
    "print(counts_mat.sum() / counts_mat.size)  # fraction of nonzero entries\n",
    "holdout_pairs, counts_mat = ut.get_holdout_pairs(counts_mat)\n",
    "ratings = pd.DataFrame(counts_mat, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rank: 5\n"
     ]
    }
   ],
   "source": [
    "rank = user_latent_pref.shape[1]\n",
    "print(\"Rank:\", rank)\n",
    "row_embeddings = tf.keras.layers.Embedding(input_dim=num_users, output_dim=rank)\n",
    "column_embeddings = tf.keras.layers.Embedding(input_dim=num_items, output_dim=rank)\n",
    "\n",
    "# Alternative (disabled): item embeddings initialised from Gamma(1, 1) draws.\n",
    "# column_embeddings = tf.keras.layers.Embedding(input_dim=num_items, \n",
    "#                                             output_dim=rank,\n",
    "#                                             embeddings_initializer=tf.initializers.constant(\n",
    "#                                             gamma.rvs(1.0, scale=1.0, size=(num_items, rank))))\n",
    "\n",
    "# Convert the observed matrix to a constant tensor. Guarded so the cell can be re-run:\n",
    "# after the first run `ratings` is already a Tensor, and calling `.values` on it again\n",
    "# would fail.\n",
    "if isinstance(ratings, pd.DataFrame):\n",
    "    ratings = tf.constant(ratings.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Earlier draft (disabled): learns both user and item factors, with Gamma(1, 1)\n",
    "# priors on the softplus-transformed embeddings.\n",
    "# NOTE(review): the original draft read an undefined `counts`; it now uses `ratings`.\n",
    "# Its likelihood is reduce_mean over users of per-user sums, whereas the active model\n",
    "# cell uses a plain reduce_sum, so the two objectives are scaled differently.\n",
    "# all_users = tf.range(size_user[0])\n",
    "# all_items = tf.range(size_item[0])\n",
    "# positive_attrs = tf.nn.softplus(column_embeddings(all_items))\n",
    "# positive_prefs = tf.nn.softplus(row_embeddings(all_users))\n",
    "# rates = tf.matmul(positive_prefs, positive_attrs, transpose_b=True)\n",
    "# attrs_log_prior = tf.reduce_sum(\n",
    "#     tfp.distributions.Gamma(concentration=1., rate=1.).log_prob(positive_attrs))\n",
    "# prefs_log_prior = tf.reduce_sum(\n",
    "#     tfp.distributions.Gamma(concentration=1., rate=1.).log_prob(positive_prefs))\n",
    "# log_likelihood = tf.reduce_mean(tf.reduce_sum(\n",
    "#     tfp.distributions.Poisson(rate=rates).log_prob(ratings),\n",
    "#     axis=1))\n",
    "# log_joint = log_likelihood + attrs_log_prior + prefs_log_prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Item-only Poisson factorisation: user factors are held fixed at user_latent_pref\n",
    "# (a tf.constant), and only the softplus-positive item embeddings are trainable.\n",
    "user_embed = pd.DataFrame(user_latent_pref, dtype=np.float32)\n",
    "user_embed_fixed = tf.constant(user_embed.values)\n",
    "# NOTE(review): bias variable is created but not used below (the `- b` term is disabled).\n",
    "b = tf.Variable(tf.ones([1]))\n",
    "\n",
    "all_items = tf.range(num_items)\n",
    "positive_item_embeds = tf.nn.softplus(column_embeddings(all_items))\n",
    "# rates[u, i] = <user_embed_fixed[u], positive_item_embeds[i]>\n",
    "rates = tf.matmul(user_embed_fixed, positive_item_embeds, transpose_b=True)\n",
    "\n",
    "# Candidate regularisers on the item embeddings. NOTE(review): neither is added to\n",
    "# log_joint below, so training currently maximises the likelihood alone -- confirm intent.\n",
    "item_log_prior = tf.reduce_sum(\n",
    "    tfp.distributions.Gamma(concentration=1., rate=1.)\n",
    "    .log_prob(positive_item_embeds))\n",
    "item_l2_penalty = tf.reduce_sum(positive_item_embeds ** 2)\n",
    "\n",
    "# Poisson log-likelihood of the observed matrix, summed over every entry.\n",
    "# (Disabled alternative: tfp.distributions.Bernoulli(logits=rates).log_prob(ratings).)\n",
    "log_likelihood = tf.reduce_sum(\n",
    "    tfp.distributions.Poisson(rate=rates).log_prob(ratings))\n",
    "log_joint = log_likelihood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximise log_joint by gradient descent on its negation. The step size is a scalar\n",
    "# placeholder fed at run time (a fixed 0.05 was used previously).\n",
    "loss = -log_joint\n",
    "learning_rate = tf.placeholder(tf.float32, shape=[])\n",
    "optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
    "train_op = optimizer.minimize(loss)\n",
    "# Created last so it initialises every global variable defined up to this point.\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Cost: 71440.22\n",
      "****************************************\n",
      "\n",
      "Cost: 67021.6\n",
      "****************************************\n",
      "\n",
      "Cost: 66986.53\n",
      "****************************************\n",
      "\n",
      "Cost: 66834.9\n",
      "****************************************\n",
      "\n",
      "Cost: 66829.5\n",
      "****************************************\n",
      "\n",
      "Cost: 66827.16\n",
      "****************************************\n",
      "\n",
      "Cost: 66826.516\n",
      "****************************************\n",
      "\n",
      "Cost: 66826.03\n",
      "****************************************\n",
      "\n",
      "Cost: 66825.445\n",
      "****************************************\n",
      "\n",
      "Cost: 66824.68\n",
      "****************************************\n",
      "\n",
      "Cost: 66823.516\n",
      "****************************************\n",
      "\n",
      "Cost: 66821.625\n",
      "****************************************\n",
      "\n",
      "Cost: 66818.05\n",
      "****************************************\n",
      "\n",
      "Cost: 66810.83\n",
      "****************************************\n",
      "\n",
      "Cost: 66800.28\n",
      "****************************************\n",
      "\n",
      "Cost: 66794.36\n",
      "****************************************\n",
      "\n",
      "Cost: 66792.766\n",
      "****************************************\n",
      "\n",
      "Cost: 66792.33\n",
      "****************************************\n",
      "\n",
      "Cost: 66792.16\n",
      "****************************************\n",
      "\n",
      "Cost: 66792.0\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.86\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.75\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.61\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.47\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.33\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.19\n",
      "****************************************\n",
      "\n",
      "Cost: 66791.03\n",
      "****************************************\n",
      "\n",
      "Cost: 66790.86\n",
      "****************************************\n",
      "\n",
      "Cost: 66790.7\n",
      "****************************************\n",
      "\n",
      "Cost: 66790.53\n",
      "****************************************\n",
      "\n",
      "Cost: 66790.36\n",
      "****************************************\n",
      "\n",
      "Cost: 66790.18\n",
      "****************************************\n",
      "\n",
      "Cost: 66789.99\n",
      "****************************************\n",
      "\n",
      "Cost: 66789.8\n",
      "****************************************\n",
      "\n",
      "Cost: 66789.61\n",
      "****************************************\n",
      "\n",
      "Cost: 66789.41\n",
      "****************************************\n",
      "\n",
      "Cost: 66789.19\n",
      "****************************************\n",
      "\n",
      "Cost: 66788.984\n",
      "****************************************\n",
      "\n",
      "Cost: 66788.78\n",
      "****************************************\n",
      "\n",
      "Cost: 66788.55\n",
      "****************************************\n",
      "\n",
      "Cost: 66788.33\n",
      "****************************************\n",
      "\n",
      "Cost: 66788.1\n",
      "****************************************\n",
      "\n",
      "Cost: 66787.89\n",
      "****************************************\n",
      "\n",
      "Cost: 66787.65\n",
      "****************************************\n",
      "\n",
      "Cost: 66787.43\n",
      "****************************************\n",
      "\n",
      "Cost: 66787.2\n",
      "****************************************\n",
      "\n",
      "Cost: 66786.97\n",
      "****************************************\n",
      "\n",
      "Cost: 66786.73\n",
      "****************************************\n",
      "\n",
      "Cost: 66786.47\n",
      "****************************************\n",
      "\n",
      "Cost: 66786.2\n",
      "****************************************\n",
      "\n",
      "Cost: 66785.94\n",
      "****************************************\n",
      "\n",
      "Cost: 66785.664\n",
      "****************************************\n",
      "\n",
      "Cost: 66785.375\n",
      "****************************************\n",
      "\n",
      "Cost: 66785.05\n",
      "****************************************\n",
      "\n",
      "Cost: 66784.7\n",
      "****************************************\n",
      "\n",
      "Cost: 66784.305\n",
      "****************************************\n",
      "\n",
      "Cost: 66783.89\n",
      "****************************************\n",
      "\n",
      "Cost: 66783.41\n",
      "****************************************\n",
      "\n",
      "Cost: 66782.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66782.234\n",
      "****************************************\n",
      "\n",
      "Cost: 66781.51\n",
      "****************************************\n",
      "\n",
      "Cost: 66780.62\n",
      "****************************************\n",
      "\n",
      "Cost: 66779.555\n",
      "****************************************\n",
      "\n",
      "Cost: 66778.24\n",
      "****************************************\n",
      "\n",
      "Cost: 66776.58\n",
      "****************************************\n",
      "\n",
      "Cost: 66774.44\n",
      "****************************************\n",
      "\n",
      "Cost: 66771.61\n",
      "****************************************\n",
      "\n",
      "Cost: 66767.79\n",
      "****************************************\n",
      "\n",
      "Cost: 66762.516\n",
      "****************************************\n",
      "\n",
      "Cost: 66755.28\n",
      "****************************************\n",
      "\n",
      "Cost: 66745.68\n",
      "****************************************\n",
      "\n",
      "Cost: 66734.08\n",
      "****************************************\n",
      "\n",
      "Cost: 66722.06\n",
      "****************************************\n",
      "\n",
      "Cost: 66711.57\n",
      "****************************************\n",
      "\n",
      "Cost: 66703.59\n",
      "****************************************\n",
      "\n",
      "Cost: 66697.99\n",
      "****************************************\n",
      "\n",
      "Cost: 66694.23\n",
      "****************************************\n",
      "\n",
      "Cost: 66691.7\n",
      "****************************************\n",
      "\n",
      "Cost: 66690.02\n",
      "****************************************\n",
      "\n",
      "Cost: 66688.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66688.1\n",
      "****************************************\n",
      "\n",
      "Cost: 66687.55\n",
      "****************************************\n",
      "\n",
      "Cost: 66687.16\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.69\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.55\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.44\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.35\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.305\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.26\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.23\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.19\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.164\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.13\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.125\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.11\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.09\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.09\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.07\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.06\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.06\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.04\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.03\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.02\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.0\n",
      "****************************************\n",
      "\n",
      "Cost: 66686.01\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.99\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.97\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.97\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.96\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.96\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.94\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.92\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.91\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.91\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.9\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.89\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.87\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.86\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.85\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.84\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.84\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.82\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.805\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.8\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.8\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.79\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.78\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.766\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.76\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.75\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.734\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.72\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.72\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.72\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.7\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.69\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.68\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.67\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.66\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.65\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.64\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.625\n",
      "****************************************\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Cost: 66685.62\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.6\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.59\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.586\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.58\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.57\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.56\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.55\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.53\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.516\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.51\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.5\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.484\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.484\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.46\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.45\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.445\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.44\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.42\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.41\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.41\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.39\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.375\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.36\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.35\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.336\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.33\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.31\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.305\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.3\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.28\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.266\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.25\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.234\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.22\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.2\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.195\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.19\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.17\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.164\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.16\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.14\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.125\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.11\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.09\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.08\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.06\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.05\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.03\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.02\n",
      "****************************************\n",
      "\n",
      "Cost: 66685.01\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.984\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.98\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.97\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.95\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.93\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.92\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.91\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.89\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.86\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.84\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.83\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.81\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.8\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.78\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.766\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.75\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.734\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.72\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.695\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.68\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.664\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.64\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.625\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.61\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.59\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.56\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.55\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.53\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.51\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.484\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.47\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.45\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.44\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.41\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.39\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.375\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.35\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.33\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.31\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.28\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.266\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.24\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.22\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.195\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.17\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.15\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.125\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.09\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.08\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.05\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.02\n",
      "****************************************\n",
      "\n",
      "Cost: 66684.0\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.984\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.95\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.92\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.89\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.84\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.81\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.8\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.766\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.734\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.7\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.67\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.64\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.61\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.58\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.55\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.516\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.484\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.45\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.42\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.39\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.36\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.33\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.28\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.25\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.22\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.17\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.14\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.1\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.06\n",
      "****************************************\n",
      "\n",
      "Cost: 66683.02\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.984\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.95\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.914\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.875\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.83\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.78\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.74\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.695\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.66\n",
      "****************************************\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Cost: 66682.61\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.56\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.51\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.45\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.41\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.36\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.305\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.26\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.2\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.15\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.08\n",
      "****************************************\n",
      "\n",
      "Cost: 66682.04\n",
      "****************************************\n",
      "Log likelihood of held-out data: 0.5888196402367644\n",
      "Log likelihood under random embeddings: 0.4916200887042258\n"
     ]
    }
   ],
   "source": [
    "# Train the item embeddings, then score the held-out pairs.\n",
    "# NOTE: evaluate_performance returns roc_auc_score(...), so the metric printed\n",
    "# below is a ROC AUC in [0, 1] (0.5 = chance), not a log likelihood.\n",
    "N_STEPS = 300000\n",
    "LOG_EVERY = 1000  # print the training loss every LOG_EVERY steps\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for step in range(N_STEPS):\n",
    "        # decaying step size 1/(t+1)\n",
    "        sess.run(train_op, feed_dict={learning_rate: 1.0 / (step + 1)})\n",
    "        if step % LOG_EVERY == 0:\n",
    "            print(\"\\nCost:\", sess.run(loss))\n",
    "            print(\"*\" * 40)\n",
    "    # Materialise the learned embeddings as a NumPy array while the session is open.\n",
    "    item_embed = positive_item_embeds.eval()\n",
    "\n",
    "score = evaluate_performance(holdout_pairs, user_latent_pref, item_embed)\n",
    "print(\"Held-out ROC AUC (learned item embeddings):\", score)\n",
    "\n",
    "# Baseline: same metric with random Gamma(shape=0.5, scale=1) embeddings for users and items.\n",
    "random_beta = gamma.rvs(0.5, scale=1.0, size=(num_users, rank))\n",
    "random_theta = gamma.rvs(0.5, scale=1.0, size=(num_items, rank))\n",
    "print(\"Held-out ROC AUC (random embeddings):\", evaluate_performance(holdout_pairs, random_beta, random_theta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log likelihood under random embeddings: 0.5024772880908807\n"
     ]
    }
   ],
   "source": [
    "# Baseline: true user preferences paired with random Gamma(shape=0.1, scale=1) item embeddings.\n",
    "# evaluate_performance returns a ROC AUC over the held-out pairs, not a log likelihood.\n",
    "random_theta = gamma.rvs(0.1, scale=1.0, size=(num_items, rank))\n",
    "print(\"Held-out ROC AUC (true users, random items):\", evaluate_performance(holdout_pairs, user_latent_pref, random_theta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Truth: [17 46  1 73 25 37 98 79  0 57]\n",
      "Predicted: [97 73 17 46 98 25 37 79  0 57]\n",
      "****************************************\n",
      "Truth: [88 47 14  6 34 92 91 13 31 74]\n",
      "Predicted: [61  1 81 18 14 16 74 92 91 31]\n",
      "****************************************\n"
     ]
    }
   ],
   "source": [
    "# For each of the first N_FACTORS_TO_SHOW latent factors, list the TOP_K items with\n",
    "# the largest weight in the true item attributes and in the learned item embeddings\n",
    "# (largest weight first), plus how many items the two lists share.\n",
    "# NOTE(review): comparing factor i to factor i assumes the learned factors are aligned\n",
    "# with the true ones -- presumably because evaluation pairs item_embed with the true\n",
    "# user_latent_pref; confirm.\n",
    "N_FACTORS_TO_SHOW = 2\n",
    "TOP_K = 10\n",
    "true_item_embed = item_latent_attr\n",
    "for factor in range(N_FACTORS_TO_SHOW):\n",
    "    true_top = np.argsort(true_item_embed[:, factor])[::-1][:TOP_K]\n",
    "    pred_top = np.argsort(item_embed[:, factor])[::-1][:TOP_K]\n",
    "    print(\"Truth:\", true_top)\n",
    "    print(\"Predicted:\", pred_top)\n",
    "    print(\"Overlap:\", len(set(true_top) & set(pred_top)), \"of\", TOP_K)\n",
    "    print(\"*\" * 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
