{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9d0f72e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/haoxuanli/opt/anaconda3/envs/mrdr/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/Users/haoxuanli/opt/anaconda3/envs/mrdr/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:524: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/Users/haoxuanli/opt/anaconda3/envs/mrdr/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/Users/haoxuanli/opt/anaconda3/envs/mrdr/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/Users/haoxuanli/opt/anaconda3/envs/mrdr/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/Users/haoxuanli/opt/anaconda3/envs/mrdr/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:532: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 196  242    3]\n",
      " [ 186  302    3]\n",
      " [  22  377    1]\n",
      " ...\n",
      " [ 276 1090    1]\n",
      " [  13  225    2]\n",
      " [  12  203    3]]\n",
      "943 1682\n",
      "Epoch: 1 MSE: 60.049713\n",
      "Epoch: 2 MSE: 56.029274\n",
      "Epoch: 3 MSE: 52.343998\n",
      "Epoch: 4 MSE: 48.9447\n",
      "Epoch: 5 MSE: 45.800495\n",
      "Epoch: 6 MSE: 42.886383\n",
      "Epoch: 7 MSE: 40.18101\n",
      "Epoch: 8 MSE: 37.665348\n",
      "Epoch: 9 MSE: 35.322826\n",
      "Epoch: 10 MSE: 33.138523\n",
      "Epoch: 11 MSE: 31.09905\n",
      "Epoch: 12 MSE: 29.1925\n",
      "Epoch: 13 MSE: 27.408207\n",
      "Epoch: 14 MSE: 25.736462\n",
      "Epoch: 15 MSE: 24.168633\n",
      "Epoch: 16 MSE: 22.696953\n",
      "Epoch: 17 MSE: 21.31439\n",
      "Epoch: 18 MSE: 20.014584\n",
      "Epoch: 19 MSE: 18.791842\n",
      "Epoch: 20 MSE: 17.640942\n",
      "Epoch: 21 MSE: 16.557215\n",
      "Epoch: 22 MSE: 15.536382\n",
      "Epoch: 23 MSE: 14.574537\n",
      "Epoch: 24 MSE: 13.668177\n",
      "Epoch: 25 MSE: 12.8139925\n",
      "Epoch: 26 MSE: 12.009042\n",
      "Epoch: 27 MSE: 11.250558\n",
      "Epoch: 28 MSE: 10.535996\n",
      "Epoch: 29 MSE: 9.863016\n",
      "Epoch: 30 MSE: 9.229422\n",
      "Epoch: 31 MSE: 8.633182\n",
      "Epoch: 32 MSE: 8.072375\n",
      "Epoch: 33 MSE: 7.545225\n",
      "Epoch: 34 MSE: 7.0500236\n",
      "Epoch: 35 MSE: 6.5852146\n",
      "Epoch: 36 MSE: 6.149277\n",
      "Epoch: 37 MSE: 5.740788\n",
      "Epoch: 38 MSE: 5.3583884\n",
      "Epoch: 39 MSE: 5.0007915\n",
      "Epoch: 40 MSE: 4.6667585\n",
      "Epoch: 41 MSE: 4.3551073\n",
      "Epoch: 42 MSE: 4.0647025\n",
      "Epoch: 43 MSE: 3.794461\n",
      "Epoch: 44 MSE: 3.5433269\n",
      "Epoch: 45 MSE: 3.310296\n",
      "Epoch: 46 MSE: 3.0943944\n",
      "Epoch: 47 MSE: 2.8946831\n",
      "Epoch: 48 MSE: 2.7102535\n",
      "Epoch: 49 MSE: 2.5402353\n",
      "Epoch: 50 MSE: 2.3837821\n",
      "Epoch: 51 MSE: 2.2400782\n",
      "Epoch: 52 MSE: 2.1083384\n",
      "Epoch: 53 MSE: 1.9878043\n",
      "Epoch: 54 MSE: 1.8777514\n",
      "Epoch: 55 MSE: 1.7774726\n",
      "Epoch: 56 MSE: 1.6862942\n",
      "Epoch: 57 MSE: 1.6035751\n",
      "Epoch: 58 MSE: 1.5286933\n",
      "Epoch: 59 MSE: 1.4610617\n",
      "Epoch: 60 MSE: 1.4001166\n",
      "Epoch: 61 MSE: 1.3453258\n",
      "Epoch: 62 MSE: 1.2961818\n",
      "Epoch: 63 MSE: 1.2522051\n",
      "Epoch: 64 MSE: 1.2129465\n",
      "Epoch: 65 MSE: 1.1779828\n",
      "Epoch: 66 MSE: 1.1469162\n",
      "Epoch: 67 MSE: 1.1193758\n",
      "Epoch: 68 MSE: 1.0950181\n",
      "Epoch: 69 MSE: 1.073524\n",
      "Epoch: 70 MSE: 1.0545996\n",
      "Epoch: 71 MSE: 1.0379713\n",
      "Epoch: 72 MSE: 1.0233918\n",
      "Epoch: 73 MSE: 1.0106353\n",
      "Epoch: 74 MSE: 0.99949336\n",
      "Epoch: 75 MSE: 0.98978025\n",
      "Epoch: 76 MSE: 0.98132616\n",
      "Epoch: 77 MSE: 0.9739805\n",
      "Epoch: 78 MSE: 0.9676072\n",
      "Epoch: 79 MSE: 0.96208495\n",
      "Epoch: 80 MSE: 0.95730585\n",
      "Epoch: 81 MSE: 0.95317423\n",
      "Epoch: 82 MSE: 0.9496064\n",
      "Epoch: 83 MSE: 0.9465291\n",
      "Epoch: 84 MSE: 0.94387615\n",
      "Epoch: 85 MSE: 0.94159305\n",
      "Epoch: 86 MSE: 0.93962693\n",
      "Epoch: 87 MSE: 0.9379373\n",
      "Epoch: 88 MSE: 0.93648714\n",
      "Epoch: 89 MSE: 0.9352416\n",
      "Epoch: 90 MSE: 0.9341758\n",
      "Epoch: 91 MSE: 0.93326396\n",
      "Epoch: 92 MSE: 0.9324861\n",
      "Epoch: 93 MSE: 0.93182343\n",
      "Epoch: 94 MSE: 0.93126076\n",
      "Epoch: 95 MSE: 0.93078536\n",
      "Epoch: 96 MSE: 0.9303852\n",
      "Epoch: 97 MSE: 0.9300502\n",
      "Epoch: 98 MSE: 0.92977273\n",
      "Epoch: 99 MSE: 0.92954373\n",
      "Epoch: 100 MSE: 0.9293579\n",
      "Epoch: 101 MSE: 0.929208\n",
      "Epoch: 102 MSE: 0.9290932\n",
      "Epoch: 103 MSE: 0.9290039\n",
      "Epoch: 104 MSE: 0.92894006\n",
      "Epoch: 105 MSE: 0.9288968\n",
      "Epoch: 106 MSE: 0.92887205\n",
      "Epoch: 107 MSE: 0.928863\n",
      "Epoch: 108 MSE: 0.9288674\n",
      "Epoch: 109 MSE: 0.9288838\n",
      "Epoch: 110 MSE: 0.92891055\n",
      "Epoch: 111 MSE: 0.9289453\n",
      "Epoch: 112 MSE: 0.9289875\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Train a standard matrix-factorization (MF) model on the ratings in data/u.data and use it to predict the complete user-item rating matrix.\n",
    "\"\"\"\n",
    "import pickle\n",
    "import numpy as np\n",
    "from MF import MF\n",
    "import tensorflow as tf\n",
    "\n",
    "# Load the ratings file and drop its last column, keeping (user, item, rating).\n",
    "matrix = np.loadtxt(\"data/u.data\", dtype=int)[:, :-1]\n",
    "print(matrix)\n",
    "# Shift ids by -1 to get 0-based embedding indices (file ids are assumed 1-based).\n",
    "user, item = matrix[:, 0] - 1, matrix[:, 1] - 1\n",
    "rating = matrix[:, 2]\n",
    "# Embedding table sizes are one past the largest 0-based id observed.\n",
    "user_num, item_num = user.max() + 1, item.max() + 1\n",
    "print(user_num, item_num)\n",
    "# Deterministic 90/10 train/test split by row order (no shuffling).\n",
    "total_num = len(user)\n",
    "split_idx = int(total_num * 0.9)\n",
    "user_train, item_train, rating_train = user[:split_idx], item[:split_idx], rating[:split_idx]\n",
    "user_test, item_test, rating_test = user[split_idx:], item[split_idx:], rating[split_idx:]\n",
    "train_num = len(user_train)\n",
    "\n",
    "batch_size = 1024\n",
    "l2_reg_lambda = 1e-3    # Validated by grid-search\n",
    "patience = 5            # epochs without test-MSE improvement before stopping\n",
    "seed = 2021             # arbitrary; fixes the TF graph-level seed so runs are repeatable\n",
    "\n",
    "# Build the model in a fresh graph so re-running this cell does not add a second copy of it.\n",
    "tf.reset_default_graph()\n",
    "tf.set_random_seed(seed)\n",
    "mf = MF(user_num=user_num, item_num=item_num, embedding_size=64, l2_reg_lambda=l2_reg_lambda)\n",
    "train_op = tf.train.AdamOptimizer().minimize(mf.loss)\n",
    "# NOTE(review): assumes every weight mf.prediction depends on is trainable - confirm in MF.\n",
    "model_vars = tf.trainable_variables()\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# NOTE(review): the held-out split is also the early-stopping set, so best_mse is a\n",
    "# validation score rather than an unbiased test score.\n",
    "test_dict = {\n",
    "    mf.user_id: user_test,\n",
    "    mf.item_id: item_test,\n",
    "    mf.y: rating_test\n",
    "}\n",
    "# Ceil division so the trailing partial batch is trained on too instead of being dropped.\n",
    "n_batch = (train_num + batch_size - 1) // batch_size\n",
    "early_stop = 0          # consecutive epochs without improvement\n",
    "best_mse = float(\"inf\")\n",
    "best_weights = None\n",
    "epoch = 0\n",
    "while early_stop < patience:\n",
    "    epoch += 1\n",
    "    for batch in range(n_batch):\n",
    "        start, end = batch * batch_size, (batch + 1) * batch_size\n",
    "        feed_dict = {\n",
    "            mf.user_id: user_train[start:end],\n",
    "            mf.item_id: item_train[start:end],\n",
    "            mf.y: rating_train[start:end]\n",
    "        }\n",
    "        sess.run(train_op, feed_dict)\n",
    "    mse = sess.run(mf.mse, test_dict)\n",
    "    if mse < best_mse:\n",
    "        best_mse = mse\n",
    "        # Snapshot the weights of the best epoch seen so far.\n",
    "        best_weights = sess.run(model_vars)\n",
    "        early_stop = 0\n",
    "    else:\n",
    "        early_stop += 1\n",
    "    print(\"Epoch:\", epoch, \"MSE:\", mse)\n",
    "\n",
    "# Roll back to the best epoch; otherwise the downstream prediction would use the\n",
    "# weights from the last (non-improving) epoch.\n",
    "if best_weights is not None:\n",
    "    for var, value in zip(model_vars, best_weights):\n",
    "        var.load(value, sess)\n",
    "print(\"Best MSE:\", best_mse)\n",
    "\n",
    "# Every (user, item) pair in user-major order: user 0 with items 0..item_num-1, then user 1, ...\n",
    "# Vectorized with repeat/tile instead of a Python loop over user_num * item_num pairs.\n",
    "user_all = np.repeat(np.arange(int(user_num)), int(item_num))\n",
    "item_all = np.tile(np.arange(int(item_num)), int(user_num))\n",
    "# mf.y is fed only to satisfy the placeholder (presumably unused by mf.prediction; confirm in MF).\n",
    "rating_all = np.zeros(user_all.shape)\n",
    "feed_dict = {\n",
    "    mf.user_id: user_all,\n",
    "    mf.item_id: item_all,\n",
    "    mf.y: rating_all\n",
    "}\n",
    "prediction = sess.run(mf.prediction, feed_dict)\n",
    "\n",
    "# Three objects are pickled back to back (predictions, user_num, item_num), so a reader\n",
    "# needs three successive pickle.load calls in this order. `with` closes the file even on error.\n",
    "with open(\"data/predicted_matrix\", \"wb\") as file:\n",
    "    pickle.dump(prediction, file)\n",
    "    pickle.dump(user_num, file)\n",
    "    pickle.dump(item_num, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52ec1901",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a4439b7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch-gpu]",
   "language": "python",
   "name": "conda-env-pytorch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
