{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import gzip\n",
    "import dill\n",
    "from scipy.special import expit\n",
    "from scipy.stats import gamma, poisson, bernoulli\n",
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "import utils as ut\n",
    "import scipy.sparse as sparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the semi-synthetic Last.fm data. Sparse matrices come from .npz files,\n",
    "# dense arrays from gzipped text files (np.loadtxt decompresses .gz transparently).\n",
    "datadir = os.path.join('..', 'dat', 'lastfm-semisynthetic-testing')\n",
    "\n",
    "\n",
    "def data_path(filename):\n",
    "    \"\"\"Return the full path of `filename` inside `datadir`.\"\"\"\n",
    "    return os.path.join(datadir, filename)\n",
    "\n",
    "\n",
    "A = sparse.load_npz(data_path('adj.gz.npz'))            # adjacency matrix\n",
    "Beta = np.loadtxt(data_path('influence.gz'))            # reference influence values\n",
    "Gamma = np.loadtxt(data_path('item_embed.gz'))          # item embeddings\n",
    "Y_past = sparse.load_npz(data_path('past_obs.gz.npz'))  # past observations\n",
    "Y = sparse.load_npz(data_path('obs.gz.npz'))            # current observations\n",
    "Z = np.loadtxt(data_path('user_embed.gz'))              # user embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Y_past as a tf.SparseTensor, so the influence term can be computed without densifying it.\n",
    "# NOTE: every stored (non-zero) entry gets the value 1.0, i.e. past observations are\n",
    "# binarised here; any value other than 1 in Y_past is not preserved.\n",
    "past_rows, past_cols = Y_past.nonzero()\n",
    "past_indices = np.column_stack([past_rows, past_cols]).astype(np.int64)\n",
    "Y_past_sp_tensor = tf.SparseTensor(\n",
    "    indices=past_indices,\n",
    "    values=tf.ones(past_indices.shape[0], dtype=tf.float32),\n",
    "    dense_shape=np.array(Y_past.shape, dtype=np.int64))\n",
    "# Sparse ops expect indices in canonical row-major order; the order returned by\n",
    "# nonzero() depends on the scipy storage format, so reorder explicitly.\n",
    "Y_past_sp_tensor = tf.sparse_reorder(Y_past_sp_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Densify the adjacency and current observations so rows can be sliced per batch.\n",
    "# Guarded so the cell can be re-run: a plain ndarray has no .toarray().\n",
    "if sparse.issparse(A):\n",
    "    A = A.toarray()\n",
    "if sparse.issparse(Y):\n",
    "    Y = Y.toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably the user and item identifiers, serialised with dill.\n",
    "# NOTE: dill (like pickle) can run arbitrary code on load -- only use trusted files.\n",
    "uids_path = os.path.join(datadir, 'uids')\n",
    "itemids_path = os.path.join(datadir, 'itemids')\n",
    "with open(uids_path, 'rb') as uids_file, open(itemids_path, 'rb') as itemids_file:\n",
    "    uids = dill.load(uids_file)\n",
    "    itemids = dill.load(itemids_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No. users, items and K: 5032 1482 50\n"
     ]
    }
   ],
   "source": [
    "# Problem sizes: users x items from the observation matrix, K from the user-embedding width.\n",
    "num_users, num_items = Y.shape\n",
    "K = Z.shape[1]\n",
    "print(\"No. users, items and K:\", num_users, num_items, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def density(x):\n",
    "    \"\"\"Fraction of non-zero entries in a dense array or a scipy sparse matrix.\"\"\"\n",
    "    nnz = x.count_nonzero() if sparse.issparse(x) else np.count_nonzero(x)\n",
    "    return nnz / np.prod(x.shape)\n",
    "\n",
    "\n",
    "# Density = fraction of non-zero entries (= 1 - sparsity). For 0/1 or non-negative\n",
    "# count matrices this equals (number of entries > 0) / (number of entries).\n",
    "print(\"Density of A:\", density(A))\n",
    "print(\"Density of Obs:\", density(Y))\n",
    "print(\"Density of Z:\", density(Z))\n",
    "print(\"Density of Gamma:\", density(Gamma))\n",
    "print(\"Density of Past Obs:\", density(Y_past))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Poisson model for the current observations (Yp = Y_past_sp_tensor):\n",
    "#   Y[u, i] ~ Poisson(sum_v beta_v * A[u, v] * Yp[v, i] + Z[u] . gamma_i)\n",
    "# with positive parameters obtained through softplus.\n",
    "friend_inf_embed = tf.Variable(tf.random_normal([num_users]))  # one unconstrained value per user\n",
    "item_embed = tf.keras.layers.Embedding(input_dim=num_items, output_dim=K)\n",
    "\n",
    "# Per-minibatch inputs: the rows of Y, A and Z belonging to the batch users.\n",
    "Y_obs = tf.placeholder(tf.float32, shape=[None, num_items], name=\"Y_obs\")\n",
    "A_obs = tf.placeholder(tf.float32, shape=[None, num_users], name=\"A_obs\")\n",
    "Z_obs = tf.placeholder(tf.float32, shape=[None, K], name=\"Z_obs\")\n",
    "\n",
    "all_items = tf.range(num_items)\n",
    "gamma_hat = tf.nn.softplus(item_embed(all_items))  # (num_items, K)\n",
    "beta_hat = tf.nn.softplus(friend_inf_embed)        # (num_users,)\n",
    "\n",
    "# Influence rate (batch, num_items) = weighted_adj @ Yp. sparse_tensor_dense_matmul\n",
    "# takes the sparse operand first, so compute (Yp^T @ weighted_adj^T)^T instead.\n",
    "# (tf.sparse.sparse_dense_matmul is not available in the installed TF 1.x.)\n",
    "weighted_adj = beta_hat * A_obs  # scales column v of A_obs by beta_v\n",
    "rate_inf = tf.transpose(tf.sparse_tensor_dense_matmul(\n",
    "    Y_past_sp_tensor, weighted_adj, adjoint_a=True, adjoint_b=True))\n",
    "rate_pref = tf.matmul(Z_obs, gamma_hat, transpose_b=True)  # (batch, num_items)\n",
    "rates = rate_inf + rate_pref\n",
    "\n",
    "log_likelihood = tf.reduce_mean(tf.reduce_sum(\n",
    "    tfp.distributions.Poisson(rate=rates).log_prob(Y_obs), axis=1))\n",
    "\n",
    "gamma_log_prior = tf.reduce_sum(\n",
    "    tfp.distributions.Gamma(concentration=1., rate=1.).log_prob(gamma_hat))\n",
    "\n",
    "beta_log_prior = tf.reduce_sum(\n",
    "    tfp.distributions.Gamma(concentration=1., rate=1.).log_prob(beta_hat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = tf.placeholder(tf.float32, shape=[])\n",
    "# Full-data MAP maximises sum_u loglik_u + log p(params). log_likelihood is a minibatch\n",
    "# estimate of the per-user *mean* of loglik_u, so the log-priors (sums over all\n",
    "# parameters) are divided by num_users to optimise the same objective per user.\n",
    "# Both the beta and the gamma priors enter the objective.\n",
    "log_prior = beta_log_prior + gamma_log_prior\n",
    "loss = -(log_likelihood + log_prior / num_users)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
    "train_op = optimizer.minimize(loss)\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace exact zeros in the user embedding Z (in place) with a tiny positive value,\n",
    "# presumably so the preference rate Z @ gamma^T fed to the Poisson is never exactly 0.\n",
    "np.putmask(Z, Z == 0, 1e-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "steps = 10000\n",
    "batch_size = 10\n",
    "lr_value = 0.005\n",
    "# Full batches per epoch. The num_users % batch_size leftover users sit out each epoch,\n",
    "# but the per-epoch reshuffle below gives every user a turn over time.\n",
    "num_batches = max(1, num_users // batch_size)\n",
    "users = np.arange(num_users)\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init)\n",
    "    for it in range(steps):\n",
    "        i = it % num_batches\n",
    "        if i == 0:\n",
    "            np.random.shuffle(users)  # new batch composition every epoch\n",
    "        batch_users = users[i * batch_size:(i + 1) * batch_size]\n",
    "        feed = {\n",
    "            Y_obs: Y[batch_users, :].astype(np.float32),\n",
    "            A_obs: A[batch_users, :].astype(np.float32),\n",
    "            Z_obs: Z[batch_users, :].astype(np.float32),\n",
    "            learning_rate: lr_value,\n",
    "        }\n",
    "        sess.run(train_op, feed_dict=feed)\n",
    "        if it % 1000 == 0:\n",
    "            print(\"\\nCost:\", sess.run(loss, feed_dict=feed))\n",
    "            print(\"*\" * 40)\n",
    "    # Pull the fitted (softplus-transformed) parameters out before the session closes.\n",
    "    Beta_p = beta_hat.eval()\n",
    "    Gamma_p = gamma_hat.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=100):\n",
    "    \"\"\"Jaccard similarity between the top-k index sets of two score vectors.\n",
    "\n",
    "    Args:\n",
    "        Beta_p: estimated scores (e.g. learned influence), 1-D.\n",
    "        Beta: reference scores, same length as Beta_p.\n",
    "        k: number of largest entries taken from each vector; a k larger than\n",
    "            the vector length uses every index.\n",
    "\n",
    "    Returns:\n",
    "        len(intersection) / len(union) of the two top-k index sets, in [0, 1].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if k < 1 or the two vectors differ in length.\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        # Guard: the slice [-0:] would silently select *every* index.\n",
    "        raise ValueError(\"k must be >= 1, got {}\".format(k))\n",
    "    if len(Beta_p) != len(Beta):\n",
    "        raise ValueError(\"Beta_p and Beta must have the same length\")\n",
    "    top = np.argsort(Beta)[-k:]\n",
    "    top_p = np.argsort(Beta_p)[-k:]\n",
    "    return np.intersect1d(top, top_p).shape[0] / np.union1d(top, top_p).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaccard overlap of the top-50 indices of the learned (Beta_p) and reference (Beta) values.\n",
    "top_k = 50\n",
    "get_set_overlap(Beta_p, Beta, k=top_k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2.0990553e-04, 3.8056285e-04, 1.3064484e-04, ..., 9.3694114e-05,\n",
       "       1.4268333e-04, 1.2099011e-04], dtype=float32)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Learned influence values: softplus(friend_inf_embed) evaluated after training.\n",
    "Beta_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5.73879603e-01, 1.92421576e-05, 1.91097819e-17, ...,\n",
       "       9.02254658e-06, 2.44095425e-07, 3.90924265e-01])"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference influence values loaded from influence.gz, for comparison with Beta_p.\n",
    "Beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
