{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from scipy.special import expit\n",
    "from scipy.stats import gamma, poisson, bernoulli\n",
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "import utils as ut\n",
    "import scipy.sparse as sparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=50):\n",
    "    top = np.argsort(Beta)[-k:]\n",
    "    top_p = np.argsort(Beta_p)[-k:]\n",
    "    return np.intersect1d(top, top_p).shape[0]/np.union1d(top, top_p).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def report_sparsity(A, Y, Z, Gamma, Y_past):\n",
    "    print(\"Sparsity in A:\", A.sum()/(num_users**2))\n",
    "\n",
    "    B = Y.copy()\n",
    "    B[B>1]=1\n",
    "    print(\"Sparsity in Obs:\", B.sum()/(num_users*num_items))\n",
    "    \n",
    "    C = Y_past.copy()\n",
    "    C[C>1]=1\n",
    "    print(\"Sparsity in Past Obs:\", C.sum()/(num_users*num_items))\n",
    "\n",
    "    print(\"Sparsity in Z:\", Z.sum())\n",
    "\n",
    "    print(\"Sparsity in Gamma:\", Gamma.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
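    "### Some useful parameters\n",
    "Alternative Gamma settings (mean = shape * scale = 0.1):\n",
    "- Gamma shape: 0.1\n",
    "- Gamma scale: 1.0\n",
    "\n",
    "The simulation below uses a different setting: mean 0.05 and scale 0.1 (so shape 0.5)."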
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.049589"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_users = 1000\n",
    "num_items = 10000\n",
    "K = 20\n",
    "\n",
    "mean = 0.05\n",
    "scale = 1./10.\n",
    "shape = mean/scale\n",
    "\n",
    "Z = gamma.rvs(shape, scale=scale, size=(num_users, K))\n",
    "A = poisson.rvs(np.dot(Z, Z.transpose()))\n",
    "A[A>1]=1\n",
    "A.sum()/(num_users**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 1000)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Row sums of A and their average.\n",
    "indegrees = np.sum(A, axis=1)\n",
    "(indegrees.mean(), indegrees.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(49.589, (1000,))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column sums of A and their average.\n",
    "outdegrees = np.sum(A, axis=0)\n",
    "(outdegrees.mean(), outdegrees.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 10000)\n"
     ]
    }
   ],
   "source": [
    "# Item factors Gamma and past consumption Y_past ~ Poisson(Z Gamma^T).\n",
    "# NOTE(review): reads the global `shape`/`scale` set in the Z/A cell; re-running\n",
    "# this after another cell rebinds those names changes (or breaks) the draw.\n",
    "Gamma = gamma.rvs(shape, scale=scale, size=(num_items, K))\n",
    "Y_past = poisson.rvs(np.dot(Z, Gamma.transpose()))\n",
    "print(Y_past.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "33.442033515497386"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean number of items consumed (in Y_past) by both endpoints of each edge of A.\n",
    "# Binary consumption indicator; float64 keeps the co-occurrence counts exact.\n",
    "consumed = (Y_past > 0).astype(np.float64)\n",
    "# co_consumed[i, j] = number of items consumed by both user i and user j.\n",
    "co_consumed = consumed @ consumed.T\n",
    "(row_idx, col_idx) = np.nonzero(A)\n",
    "# One overlap per edge (the previous per-edge loop ended with a stray `break`\n",
    "# and therefore averaged only the first edge).\n",
    "overlaps = co_consumed[row_idx, col_idx]\n",
    "np.mean(overlaps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000,)\n",
      "0.010797623932200209\n",
      "10.797623932200208\n"
     ]
    }
   ],
   "source": [
    "# Per-user influence weights Beta ~ Gamma with mean alpha * (row sum of Z).\n",
    "# Distinct beta_* names keep this cell from overwriting the global `mean`,\n",
    "# `scale` and `shape` that the Gamma / Y_past cell reads for the item factors.\n",
    "alpha = 0.01\n",
    "beta_mean = alpha*Z.sum(axis=1)\n",
    "beta_scale = 1./10.\n",
    "beta_shape = beta_mean/beta_scale\n",
    "\n",
    "# The output shape follows beta_shape: one draw per row of Z (per user).\n",
    "Beta = gamma.rvs(beta_shape, scale=beta_scale)\n",
    "\n",
    "print(Beta.shape)\n",
    "print(Beta.mean())\n",
    "print(Beta.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 10000)\n",
      "(1000, 1000)\n"
     ]
    }
   ],
   "source": [
    "# Simulate current observations Y ~ Poisson(influence term + preference term).\n",
    "# Beta*A broadcasts Beta along the last axis: M[i, j] = Beta[j] * A[i, j].\n",
    "M = Beta*A\n",
    "# Influence term: I[i, :] = sum_j M[i, j] * Y_past[j, :].\n",
    "I = np.dot(M,Y_past)\n",
    "print(I.shape)\n",
    "print(M.shape)\n",
    "\n",
    "# Gamma_n = gamma.rvs(shape, scale=scale, size=(num_items, K))\n",
    "# Preference term: the same Z Gamma^T product that generated Y_past.\n",
    "P = np.dot(Z, Gamma.transpose())\n",
    "rate = I + P\n",
    "Y = poisson.rvs(rate)\n",
    "# sparse_Y = sparse.csr_matrix(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sparsity in A: 0.048762\n",
      "Sparsity in Obs: 0.0771662\n",
      "Sparsity in Past Obs: 0.048635\n",
      "Sparsity in Z: 1005.5255286165989\n",
      "Sparsity in Gamma: 10027.933041288756\n"
     ]
    }
   ],
   "source": [
    "# Density diagnostics for the simulated data.\n",
    "report_sparsity(A=A, Y=Y, Z=Z, Gamma=Gamma, Y_past=Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TF1 graph: learn per-user influence (beta_hat) and item factors (gamma_hat);\n",
    "# user factors are supplied through the Z_obs placeholder, not learned here.\n",
    "# NOTE(review): tf.random_normal has no seed, so initial values differ per run.\n",
    "friend_inf_embed= tf.Variable(tf.random_normal([num_users]))\n",
    "# One K-dimensional row per item.\n",
    "item_embed = tf.keras.layers.Embedding(input_dim=num_items, output_dim=K)\n",
    "\n",
    "Y_obs = tf.placeholder(np.float32, name=\"Y_obs\")\n",
    "A_obs = tf.placeholder(np.float32, name=\"A_obs\")\n",
    "Z_obs = tf.placeholder(np.float32, name=\"Z_obs\")\n",
    "\n",
    "# Full past-observation matrix embedded in the graph as a constant.\n",
    "Y_past_obs = tf.constant(Y_past.astype(np.float32))\n",
    "\n",
    "all_items = tf.range(num_items)\n",
    "\n",
    "# softplus maps the unconstrained parameters to positive values.\n",
    "gamma_hat = tf.nn.softplus(item_embed(all_items))\n",
    "beta_hat = tf.nn.softplus(friend_inf_embed)\n",
    "\n",
    "# beta_hat * A_obs broadcasts over the last axis (presumably A_obs is fed as\n",
    "# (batch, num_users) rows of A), scaling column j by beta_hat[j].\n",
    "rate_inf = tf.matmul((beta_hat * A_obs), Y_past_obs)\n",
    "rate_pref = tf.matmul(Z_obs, gamma_hat, transpose_b=True)\n",
    "rates = rate_inf + rate_pref\n",
    "                                       \n",
    "# Poisson log-likelihood summed over axis 1 (items), averaged over batch rows.\n",
    "log_likelihood = tf.reduce_mean(tf.reduce_sum(\n",
    "    tfp.distributions.Poisson(rate=rates).log_prob(Y_obs), axis=1))\n",
    "                                       \n",
    "# log_likelihood = tf.reduce_mean(tf.reduce_sum((rates - Y_obs)**2, axis=-1))\n",
    "\n",
    "# Gamma(1, 1) log-density summed over all item factors.\n",
    "# NOTE(review): log_prior is not combined with log_likelihood in this cell and\n",
    "# beta_hat has no prior; confirm how (or whether) it should enter the loss.\n",
    "log_prior = tf.reduce_sum(\n",
    "    tfp.distributions.Gamma(concentration=1., rate=1.).log_prob(\n",
    "        gamma_hat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning rate is fed at run time so it can change without rebuilding the graph.\n",
    "learning_rate = tf.placeholder(tf.float32, shape=[])\n",
    "# NOTE(review): log_prior is not part of the loss, so this is a pure\n",
    "# maximum-likelihood fit; confirm whether the prior term was meant to be added.\n",
    "loss = -(log_likelihood)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
    "train_op = optimizer.minimize(loss)\n",
    "# Created after minimize() so the initializer also covers Adam's slot variables.\n",
    "init = tf.global_variables_initializer()\n",
    "local_init = tf.local_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Cost: 28021.697\n",
      "****************************************\n",
      "\n",
      "Cost: 3001.5334\n",
      "****************************************\n",
      "\n",
      "Cost: 2730.7576\n",
      "****************************************\n",
      "\n",
      "Cost: 2678.9998\n",
      "****************************************\n",
      "\n",
      "Cost: 2659.5205\n",
      "****************************************\n",
      "\n",
      "Cost: 2650.2979\n",
      "****************************************\n",
      "\n",
      "Cost: 2645.3193\n",
      "****************************************\n",
      "\n",
      "Cost: 2642.4255\n",
      "****************************************\n",
      "\n",
      "Cost: 2640.676\n",
      "****************************************\n",
      "\n",
      "Cost: 2639.5938\n",
      "****************************************\n",
      "\n",
      "Cost: 2638.915\n",
      "****************************************\n",
      "\n",
      "Cost: 2638.4856\n",
      "****************************************\n",
      "\n",
      "Cost: 2638.2124\n",
      "****************************************\n",
      "\n",
      "Cost: 2638.0378\n",
      "****************************************\n",
      "\n",
      "Cost: 2637.926\n",
      "****************************************\n",
      "\n",
      "Cost: 2637.8538\n",
      "****************************************\n",
      "\n",
      "Cost: 2637.807\n",
      "****************************************\n",
      "\n",
      "Cost: 2637.776\n",
      "****************************************\n",
      "\n",
      "Cost: 2637.7559\n",
      "****************************************\n",
      "\n",
      "Cost: 2637.7424\n",
      "****************************************\n"
     ]
    }
   ],
   "source": [
    "steps = 20000\n",
    "batch_size = 50\n",
    "# Users are shuffled once; batches then cycle through this fixed order.\n",
    "users = np.arange(num_users)\n",
    "np.random.shuffle(users)\n",
    "# Ceil division: when batch_size does not divide num_users the last partial\n",
    "# batch is still visited (floor division skipped those users, and raised\n",
    "# ZeroDivisionError when batch_size > num_users).\n",
    "num_batches = -(-num_users // batch_size)\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    sess.run(local_init)\n",
    "    for it in range(steps):\n",
    "        i = it % num_batches\n",
    "        start_index = i*batch_size\n",
    "        end_index = (i+1)*batch_size\n",
    "        batch_users = users[start_index:end_index]\n",
    "        Z_batch = Z[batch_users, :].astype(np.float32)\n",
    "        A_batch = A[batch_users, :].astype(np.float32)\n",
    "        Y_batch = Y[batch_users, :].astype(np.float32)\n",
    "        feed = {Y_obs:Y_batch, A_obs:A_batch, Z_obs:Z_batch, learning_rate:0.01}\n",
    "        sess.run(train_op, feed_dict=feed)\n",
    "        if it%1000==0:\n",
    "            # Loss on the batch just trained on (evaluated after the update).\n",
    "            print(\"\\nCost:\", sess.run(loss, feed_dict=feed))\n",
    "            print(\"*\"*40)\n",
    "    # Read the learned parameters while the session is still open.\n",
    "    Beta_p = beta_hat.eval()\n",
    "    Gamma_p = gamma_hat.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9607843137254902"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaccard overlap of the 50 users with the largest learned vs. simulated influence.\n",
    "get_set_overlap(Beta_p, Beta, k=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[945 146 402 722 899 198 231 268 483 721 260 912 547 157  35 693  12 905\n",
      "  43  49 548 604 491   3 922 926 812 303 123 276 759 979 947 149  10 944\n",
      "   0 842  41 794 153 633 725 291  38 590 413 244 662 685]\n"
     ]
    }
   ],
   "source": [
    "# Indices of the 50 largest simulated influence weights (ascending order).\n",
    "print(Beta.argsort()[-50:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[597 146 945 899 722 721 231 198 268 260 483 547 912 157  35 693 548  43\n",
      "  12 905  49 491 604 922   3 926 303 812 123 276 759 149 947  10 979 944\n",
      "   0  41 794 842 153 633 725 291  38 590 413 244 662 685]\n"
     ]
    }
   ],
   "source": [
    "# Indices of the 50 largest learned influence weights (ascending order).\n",
    "print(Beta_p.argsort()[-50:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-3392665.671068514"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total Poisson log-likelihood of Y under the simulator's true rates I + P.\n",
    "poisson.logpmf(Y, I + P).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_input_pipeline(obs_matrix, batch_size):\n",
    "    \"\"\"Build a tf.data pipeline that yields shuffled batches of user rows.\n",
    "\n",
    "    Args:\n",
    "        obs_matrix: 2-D (users x items) matrix kept outside the graph. Rows are\n",
    "            densified on demand, so it may be a scipy.sparse matrix (anything\n",
    "            with ``toarray()``) or a dense NumPy array.\n",
    "        batch_size: number of users per batch.\n",
    "\n",
    "    Returns:\n",
    "        (items, user, num_users, num_items): ``items`` is a float32 tensor of\n",
    "        shape (batch, num_items) holding the batch's rows, ``user`` holds the\n",
    "        matching row indices, followed by the matrix dimensions.\n",
    "    \"\"\"\n",
    "    num_users, num_items = obs_matrix.shape\n",
    "    # Infinite stream of user indices, reshuffled on every pass.\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(tf.range(num_users))\n",
    "    dataset = dataset.shuffle(num_users).repeat()\n",
    "\n",
    "    # Returns a single user's row as a dense float32 vector; the matrix\n",
    "    # itself stays outside the graph and is read through tf.py_func.\n",
    "    def get_row_py_func(user_idx):\n",
    "        def get_row_python(user_idx_py):\n",
    "            row = obs_matrix[user_idx_py]\n",
    "            if hasattr(row, \"toarray\"):  # scipy.sparse row -> dense (1, num_items)\n",
    "                row = row.toarray()\n",
    "            # reshape(-1) flattens both (1, num_items) and (num_items,) rows.\n",
    "            return np.asarray(row, dtype=np.float32).reshape(-1)\n",
    "        py_func = tf.py_func(get_row_python, [user_idx], tf.float32, stateful=False)\n",
    "        py_func.set_shape((num_items,))\n",
    "        return py_func\n",
    "    dataset = dataset.map(lambda user_idx: (get_row_py_func(user_idx), user_idx))\n",
    "    iterator = dataset.batch(batch_size).make_one_shot_iterator()\n",
    "    items, user = iterator.get_next()\n",
    "    return items, user, num_users, num_items"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
