{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.metrics import roc_auc_score, f1_score\n",
    "import utils as ut\n",
    "from importlib import reload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. of preference latent dimensions 10\n"
     ]
    }
   ],
   "source": [
    "simulation_type = 'simulated-det-influence'\n",
    "\n",
    "datadir = os.path.join('..', 'lastfm', simulation_type)\n",
    "\n",
    "# Preference latents are only loaded (and only used) when the ratings were not\n",
    "# simulated from influence alone.\n",
    "if simulation_type != 'simulated-bernoulli-influence-only':\n",
    "    user_latent_pref = np.loadtxt(os.path.join(datadir, 'user_latent_pref'))\n",
    "    item_latent_pref = np.loadtxt(os.path.join(datadir, 'item_latent_pref'))\n",
    "    # Width of the preference embeddings; the joint model's weight vector is\n",
    "    # split at num_items * pref_dim when extracting the learned embeddings.\n",
    "    pref_dim = user_latent_pref.shape[1]\n",
    "    print(\"Num. of preference latent dimensions\", pref_dim)\n",
    "\n",
    "friend_latent_influence = np.loadtxt(os.path.join(datadir, 'friend_latent_influence'))\n",
    "item_latent_influence = np.loadtxt(os.path.join(datadir, 'item_latent_influence'))\n",
    "\n",
    "print(\"Num. of influence latent dimensions\", friend_latent_influence.shape[1])\n",
    "\n",
    "ratings = pd.read_csv(os.path.join(datadir, 'rating.csv'), header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. users: 1892 \n",
      "Num. items: 10000\n",
      "Num. heldout pairs 1892\n"
     ]
    }
   ],
   "source": [
    "# Problem sizes come from the simulated latent factors.\n",
    "num_users, influence_dim = friend_latent_influence.shape\n",
    "num_items = item_latent_influence.shape[1]\n",
    "\n",
    "print(\"Num. users: %d \\nNum. items: %d\" % (num_users, num_items))\n",
    "\n",
    "# Binary implicit-feedback matrix: cell (user, item) is 1 for every row of rating.csv,\n",
    "# whose first two columns hold the user and item indices.\n",
    "rating_mat = np.zeros((num_users, num_items))\n",
    "rating_mat[ratings.iloc[:, 0].values, ratings.iloc[:, 1].values] = 1\n",
    "\n",
    "# Split off held-out (user, item, label) triples for evaluation; the returned\n",
    "# matrix is the one used for training below.\n",
    "holdout_pairs, rating_mat = ut.get_holdout_pairs(rating_mat)\n",
    "print(\"Num. heldout pairs\", len(holdout_pairs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Social graph: tab-separated friendship pairs; ut.to_mat is asked for a num_users x num_users matrix.\n",
    "friends_path = os.path.join('..', 'lastfm', 'trueratings', 'friends.txt')\n",
    "friends = pd.read_csv(friends_path, header=None, sep='\\t')\n",
    "adj = ut.to_mat(friends, num_users, num_users)\n",
    "\n",
    "# NOTE(review): presumably get_n_most_popular_users yields (user_id, ...) entries with\n",
    "# 1-based ids; subtracting 1 maps them to 0-based row indices of adj -- confirm in utils.\n",
    "most_popular = ut.get_n_most_popular_users(friends, n=100)\n",
    "top_users = [entry[0] - 1 for entry in most_popular]"
   ]
  },
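  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fitting both the user preference and influence parameters\n",
    "- Skip this model when the ratings were simulated from influence only (`simulated-bernoulli-influence-only`): user preference latents are not loaded for that simulation, so there are no preference features to fit."
   ]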
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if simulation_type != 'simulated-bernoulli-influence-only':\n",
    "\n",
    "    batch_size = 20\n",
    "    # Ceiling division so the last, partial batch of users is also trained on.\n",
    "    num_batches = (num_users + batch_size - 1) // batch_size\n",
    "    # Logistic regression fit by SGD. NOTE: scikit-learn 1.1 renamed this loss to\n",
    "    # 'log_loss' (the old name 'log' was removed in 1.3).\n",
    "    predictor = SGDClassifier(loss='log', alpha=1.0)\n",
    "\n",
    "    # A single pass over the users in contiguous mini-batches.\n",
    "    for i in range(num_batches):\n",
    "        print(\"Working on batch\", i)\n",
    "\n",
    "        start = i * batch_size\n",
    "        end = start + batch_size  # slicing clamps this for the final batch\n",
    "        b_user_pref = user_latent_pref[start:end, :]\n",
    "        b_rating_mat = rating_mat[start:end, :]\n",
    "        b_adj = adj[start:end, :]\n",
    "\n",
    "        # One label per (user, item) cell of the batch, flattened row-major.\n",
    "        labels = b_rating_mat.flatten()\n",
    "        # Pass the real batch length: the final batch can be shorter than batch_size.\n",
    "        features = ut.get_user_pref_features(b_user_pref, b_user_pref.shape[0], num_items)\n",
    "        inf_features = ut.generate_social_features(item_latent_influence.transpose(), b_adj)\n",
    "        # Column layout: [preference features | influence features].\n",
    "        features = np.hstack([features, inf_features])\n",
    "        predictor.partial_fit(features, labels, classes=[0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The joint model's weights follow the feature layout [preference | influence]\n",
    "# (see the np.hstack in the fitting cell), so split them at num_items * pref_dim.\n",
    "pref_dim = user_latent_pref.shape[1]\n",
    "pref_features_len = num_items * pref_dim\n",
    "\n",
    "print(\"Fit intercept term:\", predictor.intercept_)\n",
    "learned_influence = ut.get_learned_embeddings(predictor, influence_dim, num_users, offset=pref_features_len)\n",
    "learned_item_embed = ut.get_learned_embeddings(predictor, pref_dim, num_items, end=pref_features_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5137503927744123"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.special import expit\n",
    "\n",
    "# Score each held-out (user, item) pair with the joint model:\n",
    "#   logit = <user pref, learned item embedding> + <adj[u] @ learned influence, item influence> + intercept\n",
    "# The intercept matches the influence-only evaluation; being a constant shift it does\n",
    "# not change the ROC AUC, but it makes y_pred the model's actual probabilities.\n",
    "intercept = predictor.intercept_[0]\n",
    "y_true = []\n",
    "y_pred = []\n",
    "for (u, i, val) in holdout_pairs:\n",
    "    y_true.append(val)\n",
    "    pref_score = np.dot(user_latent_pref[u, :], learned_item_embed[i, :])\n",
    "    inf_score = np.dot(np.dot(adj[u, :], learned_influence), item_latent_influence[:, i])\n",
    "    y_pred.append(expit(pref_score + inf_score + intercept))\n",
    "\n",
    "roc_auc_score(y_true, y_pred)"
   ]
  },
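  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fitting only the influence parameters\n",
    "- Only friendships to the 100 most popular users (`top_users`) are used as influence features."
   ]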
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Considering only friendships to most popular users; adj. shape: (1892, 100)\n",
      "Working on batch 0\n",
      "Working on batch 1\n",
      "Working on batch 2\n",
      "Working on batch 3\n",
      "Working on batch 4\n",
      "Working on batch 5\n",
      "Working on batch 6\n",
      "Working on batch 7\n",
      "Working on batch 8\n",
      "Working on batch 9\n",
      "Working on batch 10\n",
      "Working on batch 11\n",
      "Working on batch 12\n",
      "Working on batch 13\n",
      "Working on batch 14\n",
      "Working on batch 15\n",
      "Working on batch 16\n",
      "Working on batch 17\n",
      "Working on batch 18\n",
      "Working on batch 19\n",
      "Working on batch 20\n",
      "Working on batch 21\n",
      "Working on batch 22\n",
      "Working on batch 23\n",
      "Working on batch 24\n",
      "Working on batch 25\n",
      "Working on batch 26\n",
      "Working on batch 27\n",
      "Working on batch 28\n",
      "Working on batch 29\n",
      "Working on batch 30\n",
      "Working on batch 31\n",
      "Working on batch 32\n",
      "Working on batch 33\n",
      "Working on batch 34\n",
      "Working on batch 35\n",
      "Working on batch 36\n",
      "Working on batch 37\n",
      "Working on batch 38\n",
      "Working on batch 39\n",
      "Working on batch 40\n",
      "Working on batch 41\n",
      "Working on batch 42\n",
      "Working on batch 43\n",
      "Working on batch 44\n",
      "Working on batch 45\n",
      "Working on batch 46\n",
      "Working on batch 47\n",
      "Working on batch 48\n",
      "Working on batch 49\n",
      "Working on batch 50\n",
      "Working on batch 51\n",
      "Working on batch 52\n",
      "Working on batch 53\n",
      "Working on batch 54\n",
      "Working on batch 55\n",
      "Working on batch 56\n",
      "Working on batch 57\n",
      "Working on batch 58\n",
      "Working on batch 59\n",
      "Working on batch 60\n",
      "Working on batch 61\n",
      "Working on batch 62\n",
      "Working on batch 63\n",
      "Working on batch 64\n",
      "Working on batch 65\n",
      "Working on batch 66\n",
      "Working on batch 67\n",
      "Working on batch 68\n",
      "Working on batch 69\n",
      "Working on batch 70\n",
      "Working on batch 71\n",
      "Working on batch 72\n",
      "Working on batch 73\n",
      "Working on batch 74\n",
      "Working on batch 75\n",
      "Working on batch 76\n",
      "Working on batch 77\n",
      "Working on batch 78\n",
      "Working on batch 79\n",
      "Working on batch 80\n",
      "Working on batch 81\n",
      "Working on batch 82\n",
      "Working on batch 83\n",
      "Working on batch 84\n",
      "Working on batch 85\n",
      "Working on batch 86\n",
      "Working on batch 87\n",
      "Working on batch 88\n",
      "Working on batch 89\n",
      "Working on batch 90\n",
      "Working on batch 91\n",
      "Working on batch 92\n",
      "Working on batch 93\n"
     ]
    }
   ],
   "source": [
    "batch_size = 20\n",
    "# Ceiling division so the last, partial batch of users is also trained on.\n",
    "num_batches = (num_users + batch_size - 1) // batch_size\n",
    "seed = 0  # fixes both the user order and SGD's internal shuffling\n",
    "predictor = SGDClassifier(loss='log', tol=1e-08, random_state=seed)\n",
    "\n",
    "# Visit users in a random but reproducible order.\n",
    "users = np.arange(num_users)\n",
    "np.random.RandomState(seed).shuffle(users)\n",
    "\n",
    "# Influence features use only friendships to the most popular users; the learned\n",
    "# embeddings and the evaluation below are indexed by `top_users`.\n",
    "adj_top_users = adj[:, top_users]\n",
    "print(\"Considering only friendships to most popular users; adj. shape:\", adj_top_users.shape)\n",
    "\n",
    "for i in range(num_batches):\n",
    "    print(\"Working on batch\", i)\n",
    "\n",
    "    rand_users = users[i * batch_size:(i + 1) * batch_size]\n",
    "    b_rating_mat = rating_mat[rand_users, :]\n",
    "    b_adj = adj_top_users[rand_users, :]\n",
    "\n",
    "    labels = b_rating_mat.flatten()\n",
    "    features = ut.get_influence_features(item_latent_influence.transpose(), b_adj)\n",
    "    predictor.partial_fit(features, labels, classes=[0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intercept term: [-0.11119023]\n",
      "Learned influence shape: (100, 10)\n"
     ]
    }
   ],
   "source": [
    "print(\"Intercept term:\", predictor.intercept_)\n",
    "# Presumably one learned influence vector per popular user, in `top_users` order.\n",
    "learned_influence = ut.get_learned_embeddings(predictor, influence_dim, len(top_users))\n",
    "print(\"Learned influence shape:\", learned_influence.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.13180793 -0.01246609  0.01058611  0.02281041 -0.01508805 -0.00870629\n",
      " -0.00799467  0.02113989 -0.01771266  0.13638686]\n",
      "[2.23775305e-01 4.70506288e-02 1.02304612e-01 1.13086303e-06\n",
      " 3.11108465e-04 2.49990288e-14 1.16772963e-02 2.59877616e+00\n",
      " 1.07077898e-03 6.30186238e-06]\n",
      "[-0.12857417 -0.04040205  0.00752409  0.02628762 -0.00779636 -0.0056934\n",
      " -0.01332569  0.01593751 -0.00915753  0.12736557]\n",
      "[2.10308046e-02 1.27807349e-01 5.51912030e-05 1.68833581e-03\n",
      " 7.79423980e-05 1.25260717e+00 8.13101554e-02 7.03738347e-04\n",
      " 6.06191749e-01 4.52940274e-07]\n",
      "[-0.12210894 -0.00977519  0.00324274  0.02341844 -0.01598761 -0.01230258\n",
      " -0.00897275  0.01649319 -0.01910393  0.12078078]\n",
      "[4.70135307e-02 3.30757020e-09 3.72226096e-01 1.99260184e-07\n",
      " 2.03190002e-08 1.68420110e-06 8.12308184e-03 3.91292664e-11\n",
      " 1.87579370e-03 1.27831565e-06]\n",
      "[-0.14716075 -0.0192761   0.00451736  0.02276491 -0.02631949 -0.00156641\n",
      " -0.01284016  0.00918411 -0.01416944  0.13174733]\n",
      "[4.66294992e-06 8.07008699e-13 1.26826225e-01 8.34231342e-01\n",
      " 2.11576722e-08 4.53278487e-02 1.40665234e-05 1.63507029e-08\n",
      " 1.10694285e-04 1.28168849e-04]\n",
      "[-0.1298498  -0.02456099  0.00831528  0.01922914 -0.01763654 -0.00566751\n",
      " -0.00804436  0.01858784 -0.01448849  0.1276616 ]\n",
      "[5.22110773e-04 4.63624355e-05 5.93904987e-02 4.32048743e-04\n",
      " 4.09956749e-03 1.47893689e-09 4.11611936e-07 1.55899506e-07\n",
      " 1.03909051e+00 2.46607101e-02]\n",
      "[-0.16922    -0.02484984  0.00625122  0.01957331 -0.01540813  0.00050549\n",
      " -0.00642675  0.01027429 -0.01402426  0.12107809]\n",
      "[1.73894880e-20 9.31353066e-03 5.68888287e-06 1.39071124e-06\n",
      " 8.60247123e-07 1.90163974e-13 9.81553573e-06 1.81654820e-07\n",
      " 1.20988969e-06 5.56666927e-03]\n",
      "[-0.13196926 -0.02354825  0.00136408  0.01894568 -0.02766881 -0.00968222\n",
      " -0.00098563  0.00530458 -0.01584791  0.13312838]\n",
      "[8.29328184e-05 6.46773793e-12 1.75846464e-02 1.42317251e-04\n",
      " 1.54821003e-02 3.82099273e-03 1.51487789e-05 9.87280255e-05\n",
      " 3.80061625e-05 1.63505698e-06]\n",
      "[-0.1232999  -0.01751289  0.00402535  0.01348414 -0.02597337 -0.01628302\n",
      " -0.01273872  0.01945231 -0.00594317  0.1187169 ]\n",
      "[5.09797118e-02 7.84149937e-17 9.71710653e-04 7.32330201e-02\n",
      " 3.15018962e-01 3.05715694e-03 7.21602733e-05 9.06957898e-08\n",
      " 6.76325761e-02 4.77753547e-06]\n",
      "[-0.15842269 -0.01667575 -0.010627    0.01518014 -0.01930542 -0.01755917\n",
      " -0.01152408  0.01384417 -0.01725676  0.13904306]\n",
      "[4.82523287e-05 3.57840650e-04 5.94579100e-03 3.08251848e-20\n",
      " 3.69249019e-06 3.21702118e-03 2.68910810e-25 5.52673490e-13\n",
      " 2.65250472e-06 5.75577698e-03]\n",
      "[-0.1328404  -0.00549413  0.00350655  0.01046416 -0.02441032 -0.00053402\n",
      "  0.00057945  0.01419572 -0.00065649  0.12562894]\n",
      "[1.75391183e-01 2.34465929e-09 4.56367141e-01 3.25663072e-03\n",
      " 3.61381786e-06 1.49969592e-04 1.67671578e-17 6.81855076e-01\n",
      " 1.26783963e-01 1.49436291e-06]\n"
     ]
    }
   ],
   "source": [
    "from scipy.stats import spearmanr, pearsonr\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# Compare each learned influence dimension with the simulated influence of the same\n",
    "# popular users. The learned weights need not share the simulation's scale, so the\n",
    "# correlations are the main signal; MSE is reported for reference.\n",
    "true_top_influence = friend_latent_influence[top_users, :]\n",
    "influence_comparison = pd.DataFrame(\n",
    "    [\n",
    "        {\n",
    "            'dim': d,\n",
    "            'pearson': pearsonr(true_top_influence[:, d], learned_influence[:, d])[0],\n",
    "            'spearman': spearmanr(true_top_influence[:, d], learned_influence[:, d])[0],\n",
    "            'mse': mean_squared_error(true_top_influence[:, d], learned_influence[:, d]),\n",
    "        }\n",
    "        for d in range(influence_dim)\n",
    "    ]\n",
    ").set_index('dim')\n",
    "influence_comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5651717304855796"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.special import expit\n",
    "\n",
    "# Score held-out pairs with the influence-only model, using only the friendships to\n",
    "# `top_users` that it was trained on. Taking intercept_[0] keeps each prediction a\n",
    "# scalar instead of a 1-element array (which would give roc_auc_score a 2-D input).\n",
    "intercept = predictor.intercept_[0]\n",
    "y_true = []\n",
    "y_pred = []\n",
    "for (u, i, val) in holdout_pairs:\n",
    "    y_true.append(val)\n",
    "    friend_influence = np.dot(adj[u, top_users], learned_influence)\n",
    "    logit = np.dot(friend_influence, item_latent_influence[:, i]) + intercept\n",
    "    y_pred.append(expit(logit))\n",
    "\n",
    "roc_auc_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
