{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from MF import MF\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "# Load the ratings file and drop its last column, keeping (user, item, rating).\n",
    "# NOTE(review): the dropped column is presumably a timestamp - confirm against the data file.\n",
    "matrix = np.loadtxt(\"./data/u.data\", dtype=int)[:, :-1]\n",
    "# Shift ids down by one so that (max id + 1) can be used as the number of users/items.\n",
    "user = matrix[:, 0] - 1\n",
    "item = matrix[:, 1] - 1\n",
    "rating = matrix[:, 2]\n",
    "user_num = np.max(user)+1\n",
    "item_num = np.max(item)+1\n",
    "print(user_num, item_num)\n",
    "total_num = user.shape[0]\n",
    "print(total_num)\n",
    "\n",
    "# Sequential 90/10 split in file order (rows are not shuffled here).\n",
    "# The boundary is computed once so the six slices cannot drift apart.\n",
    "split_idx = int(total_num*0.9)\n",
    "user_train, item_train, rating_train = user[:split_idx], item[:split_idx], rating[:split_idx]\n",
    "user_test, item_test, rating_test = user[split_idx:], item[split_idx:], rating[split_idx:]\n",
    "train_num = user_train.shape[0]\n",
    "\n",
    "batch_size = 1024\n",
    "gamma = 1e-4    # Validated by grid-search\n",
    "\n",
    "mf = MF(num_users=user_num, num_items=item_num, embedding_size=64)\n",
    "mf.fit(user_train, item_train, rating_train, user_test, item_test, rating_test, batch_size = batch_size, lamb=1e-3, gamma = gamma)\n",
    "\n",
    "# Enumerate every (user, item) pair in user-major order: user 0 with all items,\n",
    "# then user 1 with all items, and so on. Built with vectorized numpy calls instead\n",
    "# of a nested Python comprehension, which creates one Python list per pair.\n",
    "user_all = np.repeat(np.arange(user_num), item_num)\n",
    "item_all = np.tile(np.arange(item_num), user_num)\n",
    "all_matrix = np.column_stack((user_all, item_all))\n",
    "\n",
    "prediction = mf.predict(user_all, item_all)\n",
    "print(prediction[:50])\n",
    "print(prediction.shape)\n",
    "\n",
    "# Three objects are pickled back to back into one file, so a reader must call\n",
    "# pickle.load three times in the same order: prediction, user_num, item_num.\n",
    "# The with-block closes the file even if one of the dumps raises.\n",
    "with open(\"data/predicted_matrix\", \"wb\") as file:\n",
    "    pickle.dump(prediction, file)\n",
    "    pickle.dump(user_num, file)\n",
    "    pickle.dump(item_num, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch-gpu]",
   "language": "python",
   "name": "conda-env-pytorch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
