{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d0f72e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 196  242    3]\n",
      " [ 186  302    3]\n",
      " [  22  377    1]\n",
      " ...\n",
      " [ 276 1090    1]\n",
      " [  13  225    2]\n",
      " [  12  203    3]]\n",
      "943 1682\n",
      "Epoch: 1 MSE: 75.31257\n",
      "Epoch: 2 MSE: 70.56687\n",
      "Epoch: 3 MSE: 66.17172\n",
      "Epoch: 4 MSE: 62.079414\n",
      "Epoch: 5 MSE: 58.261173\n",
      "Epoch: 6 MSE: 54.69375\n",
      "Epoch: 7 MSE: 51.35691\n",
      "Epoch: 8 MSE: 48.2329\n",
      "Epoch: 9 MSE: 45.305805\n",
      "Epoch: 10 MSE: 42.561325\n",
      "Epoch: 11 MSE: 39.986404\n",
      "Epoch: 12 MSE: 37.569336\n",
      "Epoch: 13 MSE: 35.299225\n",
      "Epoch: 14 MSE: 33.166294\n",
      "Epoch: 15 MSE: 31.161512\n",
      "Epoch: 16 MSE: 29.276457\n",
      "Epoch: 17 MSE: 27.50346\n",
      "Epoch: 18 MSE: 25.835398\n",
      "Epoch: 19 MSE: 24.265741\n",
      "Epoch: 20 MSE: 22.788332\n",
      "Epoch: 21 MSE: 21.397545\n",
      "Epoch: 22 MSE: 20.088099\n",
      "Epoch: 23 MSE: 18.855072\n",
      "Epoch: 24 MSE: 17.693933\n",
      "Epoch: 25 MSE: 16.600437\n",
      "Epoch: 26 MSE: 15.570606\n",
      "Epoch: 27 MSE: 14.60075\n",
      "Epoch: 28 MSE: 13.687428\n",
      "Epoch: 29 MSE: 12.82742\n",
      "Epoch: 30 MSE: 12.017735\n",
      "Epoch: 31 MSE: 11.25555\n",
      "Epoch: 32 MSE: 10.538267\n",
      "Epoch: 33 MSE: 9.863429\n",
      "Epoch: 34 MSE: 9.228738\n",
      "Epoch: 35 MSE: 8.632039\n",
      "Epoch: 36 MSE: 8.0713215\n",
      "Epoch: 37 MSE: 7.5447063\n",
      "Epoch: 38 MSE: 7.0504\n",
      "Epoch: 39 MSE: 6.5867267\n",
      "Epoch: 40 MSE: 6.1520987\n",
      "Epoch: 41 MSE: 5.7450223\n",
      "Epoch: 42 MSE: 5.3640795\n",
      "Epoch: 43 MSE: 5.0079303\n",
      "Epoch: 44 MSE: 4.675288\n",
      "Epoch: 45 MSE: 4.3649373\n",
      "Epoch: 46 MSE: 4.0757194\n",
      "Epoch: 47 MSE: 3.8065145\n",
      "Epoch: 48 MSE: 3.5562704\n",
      "Epoch: 49 MSE: 3.3239539\n",
      "Epoch: 50 MSE: 3.108599\n",
      "Epoch: 51 MSE: 2.9092572\n",
      "Epoch: 52 MSE: 2.725029\n",
      "Epoch: 53 MSE: 2.5550456\n",
      "Epoch: 54 MSE: 2.3984683\n",
      "Epoch: 55 MSE: 2.2544947\n",
      "Epoch: 56 MSE: 2.1223462\n",
      "Epoch: 57 MSE: 2.001281\n",
      "Epoch: 58 MSE: 1.890581\n",
      "Epoch: 59 MSE: 1.789563\n",
      "Epoch: 60 MSE: 1.6975617\n",
      "Epoch: 61 MSE: 1.613948\n",
      "Epoch: 62 MSE: 1.5381196\n",
      "Epoch: 63 MSE: 1.4694989\n",
      "Epoch: 64 MSE: 1.4075358\n",
      "Epoch: 65 MSE: 1.35171\n",
      "Epoch: 66 MSE: 1.3015281\n",
      "Epoch: 67 MSE: 1.2565203\n",
      "Epoch: 68 MSE: 1.2162473\n",
      "Epoch: 69 MSE: 1.1802949\n",
      "Epoch: 70 MSE: 1.1482729\n",
      "Epoch: 71 MSE: 1.1198208\n",
      "Epoch: 72 MSE: 1.0945985\n",
      "Epoch: 73 MSE: 1.0722947\n",
      "Epoch: 74 MSE: 1.0526154\n",
      "Epoch: 75 MSE: 1.0352963\n",
      "Epoch: 76 MSE: 1.0200887\n",
      "Epoch: 77 MSE: 1.0067686\n",
      "Epoch: 78 MSE: 0.99512833\n",
      "Epoch: 79 MSE: 0.9849817\n",
      "Epoch: 80 MSE: 0.97616005\n",
      "Epoch: 81 MSE: 0.9685082\n",
      "Epoch: 82 MSE: 0.9618875\n",
      "Epoch: 83 MSE: 0.9561752\n",
      "Epoch: 84 MSE: 0.9512585\n",
      "Epoch: 85 MSE: 0.94704044\n",
      "Epoch: 86 MSE: 0.9434311\n",
      "Epoch: 87 MSE: 0.9403518\n",
      "Epoch: 88 MSE: 0.93773377\n",
      "Epoch: 89 MSE: 0.9355164\n",
      "Epoch: 90 MSE: 0.9336453\n",
      "Epoch: 91 MSE: 0.93207365\n",
      "Epoch: 92 MSE: 0.9307594\n",
      "Epoch: 93 MSE: 0.9296676\n",
      "Epoch: 94 MSE: 0.928766\n",
      "Epoch: 95 MSE: 0.9280289\n",
      "Epoch: 96 MSE: 0.9274301\n",
      "Epoch: 97 MSE: 0.9269504\n",
      "Epoch: 98 MSE: 0.92657304\n",
      "Epoch: 99 MSE: 0.9262836\n",
      "Epoch: 100 MSE: 0.9260673\n",
      "Epoch: 101 MSE: 0.92591417\n",
      "Epoch: 102 MSE: 0.9258138\n",
      "Epoch: 103 MSE: 0.9257588\n",
      "Epoch: 104 MSE: 0.9257417\n",
      "Epoch: 105 MSE: 0.92575705\n",
      "Epoch: 106 MSE: 0.9257995\n",
      "Epoch: 107 MSE: 0.9258643\n",
      "Epoch: 108 MSE: 0.9259488\n",
      "Epoch: 109 MSE: 0.9260479\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Using standard matrix factorization to generate the complete rating matrix.\n",
    "\n",
    "Pipeline: load the ratings from data/u.data, split them 90/10 (in file order)\n",
    "into train and held-out sets, train the MF model with Adam under patience-based\n",
    "early stopping on the held-out MSE, restore the best weights seen, predict a\n",
    "rating for every (user, item) pair and pickle the result to data/predicted_matrix.\n",
    "\"\"\"\n",
    "import pickle\n",
    "import numpy as np\n",
    "from MF_complete import MF\n",
    "import tensorflow as tf\n",
    "\n",
    "# The trailing column of each row is dropped; the remaining three columns are\n",
    "# read below as (user id, item id, rating).\n",
    "matrix = np.loadtxt(\"data/u.data\", dtype=int)[:, :-1]\n",
    "print(matrix)\n",
    "# Ids are shifted down by one so they can serve as 0-based indices.\n",
    "user = matrix[:, 0] - 1\n",
    "item = matrix[:, 1] - 1\n",
    "rating = matrix[:, 2]\n",
    "user_num = np.max(user) + 1\n",
    "item_num = np.max(item) + 1\n",
    "print(user_num, item_num)\n",
    "\n",
    "# The first 90% of the rows train the model; the remaining rows are held out.\n",
    "total_num = user.shape[0]\n",
    "split = int(total_num * 0.9)\n",
    "user_train, item_train, rating_train = user[:split], item[:split], rating[:split]\n",
    "user_test, item_test, rating_test = user[split:], item[split:], rating[split:]\n",
    "train_num = user_train.shape[0]\n",
    "\n",
    "batch_size = 1024\n",
    "l2_reg_lambda = 1e-3    # Validated by grid-search\n",
    "patience = 5            # Stop after this many consecutive epochs without improvement\n",
    "mf = MF(user_num=user_num, item_num=item_num, embedding_size=64, l2_reg_lambda=l2_reg_lambda)\n",
    "train_op = tf.train.AdamOptimizer().minimize(mf.loss)\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "test_dict = {\n",
    "    mf.user_id: user_test,\n",
    "    mf.item_id: item_test,\n",
    "    mf.y: rating_test\n",
    "}\n",
    "\n",
    "# NOTE(review): the held-out split both drives early stopping and is the MSE\n",
    "# that gets printed, so the reported figure is optimistic. A separate\n",
    "# validation split would be needed for an unbiased estimate.\n",
    "model_vars = tf.trainable_variables()\n",
    "best_weights = sess.run(model_vars)   # in-memory snapshot of the best epoch so far\n",
    "early_stop = 0                        # consecutive epochs without improvement\n",
    "best_mse = np.inf\n",
    "epoch = 0\n",
    "while early_stop < patience:\n",
    "    epoch += 1\n",
    "    # Walk the training set in order. The final batch may be shorter than\n",
    "    # batch_size, so the tail of the data is trained on instead of being dropped.\n",
    "    for start in range(0, train_num, batch_size):\n",
    "        stop = start + batch_size\n",
    "        feed_dict = {\n",
    "            mf.user_id: user_train[start:stop],\n",
    "            mf.item_id: item_train[start:stop],\n",
    "            mf.y: rating_train[start:stop]\n",
    "        }\n",
    "        sess.run(train_op, feed_dict)\n",
    "    mse = sess.run(mf.mse, test_dict)\n",
    "    if mse < best_mse:\n",
    "        best_mse = mse\n",
    "        best_weights = sess.run(model_vars)\n",
    "        early_stop = 0\n",
    "    else:\n",
    "        early_stop += 1\n",
    "    print(\"Epoch:\", epoch, \"MSE:\", mse)\n",
    "\n",
    "# The loop exits `patience` epochs after the best one, so roll the model back\n",
    "# to the best weights before generating the complete matrix.\n",
    "for variable, value in zip(model_vars, best_weights):\n",
    "    variable.load(value, sess)\n",
    "\n",
    "# Enumerate every (user, item) pair in user-major order: all items for user 0,\n",
    "# then all items for user 1, and so on.\n",
    "user_all = np.repeat(np.arange(user_num), item_num)\n",
    "item_all = np.tile(np.arange(item_num), user_num)\n",
    "rating_all = np.zeros(user_all.shape)   # dummy labels; only mf.prediction is fetched\n",
    "feed_dict = {\n",
    "    mf.user_id: user_all,\n",
    "    mf.item_id: item_all,\n",
    "    mf.y: rating_all\n",
    "}\n",
    "prediction = sess.run(mf.prediction, feed_dict)\n",
    "\n",
    "# Three objects are pickled back to back; readers must unpickle them in the\n",
    "# same order: prediction, user_num, item_num.\n",
    "with open(\"data/predicted_matrix\", \"wb\") as file:\n",
    "    pickle.dump(prediction, file)\n",
    "    pickle.dump(user_num, file)\n",
    "    pickle.dump(item_num, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf1-5-0]",
   "language": "python",
   "name": "conda-env-tf1-5-0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
