{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da4b81ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "from MF import MF\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "\n",
    "# Train a matrix-factorisation model on the ratings file, report the MSE on a\n",
    "# held-out split after every epoch, then score every (user, item) pair and\n",
    "# pickle the scores together with the user/item counts.\n",
    "\n",
    "# Fixed seed so the shuffle (and TensorFlow's graph-level randomness) is repeatable.\n",
    "seed = 0\n",
    "\n",
    "# The '::' separator is multi-character, hence engine='python'.\n",
    "matrix = pd.read_table(\"data/ratings.dat\", sep='::', header=None, engine='python')\n",
    "matrix.columns = ['user_id', 'item_id', \"rating\", \"timestamp\"]\n",
    "\n",
    "# Shift ids down by one so they can be used directly as zero-based indices.\n",
    "user = np.array(matrix['user_id'] - 1)\n",
    "item = np.array(matrix['item_id'] - 1)\n",
    "rating = np.array(matrix['rating'])\n",
    "user_num = int(np.max(user) + 1)\n",
    "item_num = int(np.max(item) + 1)\n",
    "print(user_num, item_num)\n",
    "total_num = user.shape[0]\n",
    "\n",
    "# Shuffle once before splitting. An unshuffled split holds out the tail of the\n",
    "# file verbatim; if the file is sorted by user (NOTE(review): confirm for this\n",
    "# dataset) the held-out users would never be seen during training.\n",
    "shuffle_idx = np.random.RandomState(seed).permutation(total_num)\n",
    "user, item, rating = user[shuffle_idx], item[shuffle_idx], rating[shuffle_idx]\n",
    "\n",
    "# 90% of the rows are used for training, the remaining 10% for evaluation.\n",
    "split = int(total_num * 0.9)\n",
    "user_train, item_train, rating_train = user[:split], item[:split], rating[:split]\n",
    "user_test, item_test, rating_test = user[split:], item[split:], rating[split:]\n",
    "train_num = user_train.shape[0]\n",
    "\n",
    "batch_size = 16384\n",
    "l2_reg_lambda = 1e-3    # Validated by grid-search\n",
    "\n",
    "# Must be set before the graph is built so it applies to the ops MF creates.\n",
    "tf.set_random_seed(seed)\n",
    "mf = MF(user_num=user_num, item_num=item_num, embedding_size=64, l2_reg_lambda=l2_reg_lambda)\n",
    "train_op = tf.train.AdamOptimizer().minimize(mf.loss)\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "test_dict = {\n",
    "    mf.user_id: user_test,\n",
    "    mf.item_id: item_test,\n",
    "    mf.y: rating_test\n",
    "}\n",
    "\n",
    "# Early stopping: quit after `patience` consecutive epochs without a new best MSE.\n",
    "# NOTE(review): the same split is used for early stopping and for the reported\n",
    "# MSE, so the reported number is optimistic. The weights used further down are\n",
    "# those of the final epoch, not of the best epoch.\n",
    "patience = 5\n",
    "early_stop = 0\n",
    "best_mse = np.inf\n",
    "epoch = 0\n",
    "# Ceil division so the final, shorter batch is trained on instead of dropped.\n",
    "n_batch = (train_num + batch_size - 1) // batch_size\n",
    "while early_stop < patience:\n",
    "    epoch += 1\n",
    "    for batch in range(n_batch):\n",
    "        rows = slice(batch * batch_size, (batch + 1) * batch_size)\n",
    "        feed_dict = {\n",
    "            mf.user_id: user_train[rows],\n",
    "            mf.item_id: item_train[rows],\n",
    "            mf.y: rating_train[rows]\n",
    "        }\n",
    "        sess.run(train_op, feed_dict)\n",
    "    mse = sess.run(mf.mse, test_dict)\n",
    "    if mse < best_mse:\n",
    "        best_mse = mse\n",
    "        early_stop = 0\n",
    "    else:\n",
    "        early_stop += 1\n",
    "    print(\"Epoch:\", epoch, \"MSE:\", mse)\n",
    "\n",
    "# Enumerate every (user, item) pair in user-major order (all items of user 0,\n",
    "# then all items of user 1, ...) without building a Python list of pairs.\n",
    "user_all = np.repeat(np.arange(user_num), item_num)\n",
    "item_all = np.tile(np.arange(item_num), user_num)\n",
    "# mf.y is fed zeros as a placeholder value; only mf.prediction is fetched here.\n",
    "rating_all = np.zeros(user_all.shape)\n",
    "feed_dict = {\n",
    "    mf.user_id: user_all,\n",
    "    mf.item_id: item_all,\n",
    "    mf.y: rating_all\n",
    "}\n",
    "prediction = sess.run(mf.prediction, feed_dict)\n",
    "\n",
    "# Three consecutive pickles in one file: predictions, user count, item count.\n",
    "# Readers must call pickle.load three times in the same order.\n",
    "with open(\"data/predicted_matrix\", \"wb\") as file:\n",
    "    pickle.dump(prediction, file)\n",
    "    pickle.dump(user_num, file)\n",
    "    pickle.dump(item_num, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf1-5-0]",
   "language": "python",
   "name": "conda-env-tf1-5-0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
