{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:29.302687Z",
     "start_time": "2023-05-24T18:36:28.996655Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import random\n",
    "from IPython.display import clear_output\n",
    "from copy import deepcopy\n",
    "from scipy.sparse import diags\n",
    "\n",
    "from utils import get_data, ndcg, recall\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:29.306745Z",
     "start_time": "2023-05-24T18:36:29.304459Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed both the stdlib and the NumPy global RNGs so that random draws in later\n",
    "# cells (e.g. the np.random.randn initialisation inside run) are repeatable.\n",
    "seed = 1337\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.822535Z",
     "start_time": "2023-05-24T18:36:29.308071Z"
    }
   },
   "outputs": [],
   "source": [
    "# get_data comes from the project-local utils module; its five return values are\n",
    "# unpacked below as the train split plus (in, out) pairs for validation and test.\n",
    "data = get_data('ml20m/')\n",
    "train_data, valid_in_data, valid_out_data, test_in_data, test_out_data = data\n",
    "# train_data is 2-D: its first axis is taken as users, its second as items.\n",
    "n_users, n_items = train_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.882002Z",
     "start_time": "2023-05-24T18:36:30.823852Z"
    }
   },
   "outputs": [],
   "source": [
    "# X is the training matrix that run() passes to ImplicitSLIM, update_P and update_Q.\n",
    "X = train_data\n",
    "# b is the per-column (per-item) mean of X, transposed; update_P, update_Q and\n",
    "# predict read it as a global.\n",
    "# NOTE(review): assumes X.mean(0) is 2-D (as for a scipy sparse matrix), so that b\n",
    "# has shape (n_items, 1) -- confirm against utils.get_data.\n",
    "b = np.array(X.mean(0)).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.888598Z",
     "start_time": "2023-05-24T18:36:30.883234Z"
    }
   },
   "outputs": [],
   "source": [
    "def ImplicitSLIM(W, X, λ, α, thr):\n",
    "    \"\"\"Return a new embedding matrix computed from W and the interaction matrix X.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    W : ndarray with one column per column of X\n",
    "        Current embeddings (run() passes an (L, n_items) array). Not modified.\n",
    "    X : 2-D matrix, sparse or dense\n",
    "        Only X, X.T and the column sums of X are used.\n",
    "    λ : float\n",
    "        Added to every column sum of X before taking the reciprocal D.\n",
    "    α : float\n",
    "        Scales the result and enters the inner square system as identity / α.\n",
    "    thr : float\n",
    "        Columns of X whose sum is below thr get their embedding column zeroed\n",
    "        in the working copy A before anything else is computed.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray with the same shape as W.\n",
    "    \"\"\"\n",
    "    # Column sums are needed for both D and the threshold mask: compute them once.\n",
    "    item_counts = np.array(X.sum(0))\n",
    "    D = 1 / (item_counts + λ)\n",
    "\n",
    "    A = W.copy()\n",
    "    A[:, (item_counts < thr).squeeze().nonzero()[0]] = 0\n",
    "\n",
    "    AXt = A @ X.T\n",
    "\n",
    "    # `*` and `@` share one precedence level and associate left to right; the\n",
    "    # parentheses only spell out the grouping the unparenthesised chain already had.\n",
    "    AinvC = (λ**2 * A * D * D\n",
    "             + (λ * A * D * D) @ X.T @ X\n",
    "             + (λ * AXt) @ X * D * D\n",
    "             + (AXt @ X * D * D) @ X.T @ X)\n",
    "\n",
    "    # Square (A.shape[0] x A.shape[0]) product that appears twice below.\n",
    "    M = AinvC @ A.T\n",
    "\n",
    "    # Solve the square system instead of forming an explicit inverse and multiplying.\n",
    "    AC = AinvC - M @ np.linalg.solve(np.eye(A.shape[0]) / α + M, AinvC)\n",
    "\n",
    "    return α * W @ A.T @ AC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.899942Z",
     "start_time": "2023-05-24T18:36:30.889751Z"
    },
    "run_control": {
     "marked": false
    }
   },
   "outputs": [],
   "source": [
    "def update_P(X, Q, model_kwargs):\n",
    "    \"\"\"Compute the factor matrix P for the rows of X given the item factors Q.\n",
    "\n",
    "    Solves (r_p * I + Q Q^T) P = Q X^T - Q b for P, where b is the module-level\n",
    "    column-mean vector and I is the L x L identity.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2-D matrix with one column per column of Q\n",
    "    Q : ndarray with L rows\n",
    "    model_kwargs : dict\n",
    "        Must provide 'r_p' (diagonal regulariser) and 'L' (number of rows of Q).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray with L rows and one column per row of X.\n",
    "    \"\"\"\n",
    "    r_p = model_kwargs['r_p']\n",
    "    L = model_kwargs['L']\n",
    "    lhs = r_p * np.eye(L) + Q @ Q.T\n",
    "    rhs = Q @ X.T - Q @ b\n",
    "    # Solve the linear system directly rather than inverting lhs and multiplying.\n",
    "    return np.linalg.solve(lhs, rhs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.915909Z",
     "start_time": "2023-05-24T18:36:30.902892Z"
    }
   },
   "outputs": [],
   "source": [
    "def update_Q(X, P, Qs, model_kwargs):\n",
    "    \"\"\"Compute the item factor matrix Q given the factors P and a target Qs.\n",
    "\n",
    "    Solves ((r_q + s_q) * I + P P^T) Q = P X - rowsum(P) b^T + s_q * Qs for Q,\n",
    "    where b is the module-level column-mean vector and I is the L x L identity.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2-D matrix with one row per column of P\n",
    "    P : ndarray with L rows\n",
    "    Qs : ndarray\n",
    "        Added to the right-hand side with weight s_q.\n",
    "    model_kwargs : dict\n",
    "        Must provide 'r_q', 's_q' and 'L' (number of rows of P).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray with L rows and one column per column of X.\n",
    "    \"\"\"\n",
    "    r_q = model_kwargs['r_q']\n",
    "    s_q = model_kwargs['s_q']\n",
    "    L = model_kwargs['L']\n",
    "    lhs = (r_q + s_q) * np.eye(L) + P @ P.T\n",
    "    rhs = P @ X - P.sum(1, keepdims=True) @ b.T + s_q * Qs\n",
    "    # Solve the linear system directly rather than inverting lhs and multiplying.\n",
    "    return np.linalg.solve(lhs, rhs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.926462Z",
     "start_time": "2023-05-24T18:36:30.917924Z"
    }
   },
   "outputs": [],
   "source": [
    "def predict(X, Q, model_kwargs):\n",
    "    \"\"\"Score every item for the rows of X.\n",
    "\n",
    "    The row factors are obtained from update_P, multiplied by Q, and the\n",
    "    module-level vector b (transposed) is added back to every row.\n",
    "    \"\"\"\n",
    "    row_factors = update_P(X, Q, model_kwargs)\n",
    "    scores = row_factors.T @ Q\n",
    "    return scores + b.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.937851Z",
     "start_time": "2023-05-24T18:36:30.927854Z"
    }
   },
   "outputs": [],
   "source": [
    "def generate(batch_size, data_in, data_out=None, shuffle=False, samples_perc_per_epoch=1):\n",
    "    \"\"\"Yield Batch objects that walk over a leading fraction of the rows of data_in.\n",
    "\n",
    "    batch_size             -- maximum number of row indices per Batch\n",
    "    data_in / data_out     -- handed unchanged to every Batch\n",
    "    shuffle                -- if true, the row order is permuted with np.random.shuffle\n",
    "                              before the first int(n_rows * fraction) indices are kept\n",
    "    samples_perc_per_epoch -- fraction of rows to visit; must satisfy 0 < fraction <= 1\n",
    "    \"\"\"\n",
    "    assert 0 < samples_perc_per_epoch <= 1\n",
    "\n",
    "    n_rows = data_in.shape[0]\n",
    "    n_selected = int(n_rows * samples_perc_per_epoch)\n",
    "\n",
    "    if not shuffle:\n",
    "        order = np.arange(n_selected)\n",
    "    else:\n",
    "        order = np.arange(n_rows)\n",
    "        np.random.shuffle(order)\n",
    "        order = order[:n_selected]\n",
    "\n",
    "    # order holds exactly n_selected entries, so slicing clamps the final batch.\n",
    "    for start in range(0, n_selected, batch_size):\n",
    "        yield Batch(order[start:start + batch_size], data_in, data_out)\n",
    "\n",
    "\n",
    "class Batch:\n",
    "    \"\"\"A set of row indices together with the input/output data they index into.\"\"\"\n",
    "\n",
    "    def __init__(self, idx, data_in, data_out=None):\n",
    "        self._idx, self._data_in, self._data_out = idx, data_in, data_out\n",
    "\n",
    "    def get_idx(self):\n",
    "        \"\"\"Return the row indices this batch was built with.\"\"\"\n",
    "        return self._idx\n",
    "\n",
    "    def get_ratings(self, is_out=False):\n",
    "        \"\"\"Return the batch rows of data_out when is_out is true, else of data_in.\"\"\"\n",
    "        if is_out:\n",
    "            return self._data_out[self._idx]\n",
    "        return self._data_in[self._idx]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.954853Z",
     "start_time": "2023-05-24T18:36:30.939199Z"
    }
   },
   "outputs": [],
   "source": [
    "# Metrics reported on the test split by run(): each entry pairs a metric function\n",
    "# imported from utils with the cutoff k it is called with.\n",
    "test_metrics = [{'metric': ndcg, 'k': 100}, {'metric': recall, 'k': 20}, {'metric': recall, 'k': 50}]\n",
    "\n",
    "def evaluate(Q, data_in, data_out, metrics, model_kwargs, samples_perc_per_epoch=1, batch_size=2000):\n",
    "    \"\"\"Return the mean of every requested metric for item factors Q.\n",
    "\n",
    "    Rows are processed in batches from generate(). Predictions come from\n",
    "    predict(data_in rows, Q, model_kwargs) and are scored against the matching\n",
    "    data_out rows. When data_in and data_out are different objects, entries that\n",
    "    are non-zero in the input rows are set to -inf before scoring.\n",
    "\n",
    "    metrics is a list of dicts with keys 'metric' (callable) and 'k'; the caller's\n",
    "    list is left untouched. The result is a list of scalars in the same order.\n",
    "    \"\"\"\n",
    "    metrics = deepcopy(metrics)\n",
    "    for entry in metrics:\n",
    "        entry['score'] = []\n",
    "\n",
    "    mask_seen = data_in is not data_out\n",
    "\n",
    "    batches = generate(batch_size=batch_size,\n",
    "                       data_in=data_in,\n",
    "                       data_out=data_out,\n",
    "                       samples_perc_per_epoch=samples_perc_per_epoch)\n",
    "\n",
    "    for batch in batches:\n",
    "        ratings_in = batch.get_ratings()\n",
    "        ratings_out = batch.get_ratings(is_out=True)\n",
    "\n",
    "        ratings_pred = predict(ratings_in, Q, model_kwargs)\n",
    "\n",
    "        if mask_seen:\n",
    "            ratings_pred[ratings_in.nonzero()] = -np.inf\n",
    "\n",
    "        for entry in metrics:\n",
    "            metric_fn, k = entry['metric'], entry['k']\n",
    "            entry['score'].append(metric_fn(ratings_pred, ratings_out, k=k))\n",
    "\n",
    "    # Each metric call contributes one array per batch; pool them, then average.\n",
    "    return [np.concatenate(entry['score']).mean() for entry in metrics]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:36:30.971784Z",
     "start_time": "2023-05-24T18:36:30.956253Z"
    },
    "run_control": {
     "marked": true
    }
   },
   "outputs": [],
   "source": [
    "def run(*, max_epochs=5, **model_kwargs):\n",
    "    \"\"\"Train the item factors on X, early-stop on validation NDCG@100, report test metrics.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    max_epochs : int, keyword-only, default 5\n",
    "        Upper bound on the number of alternating P/Q updates.\n",
    "    **model_kwargs\n",
    "        Must contain 'L', 'item_λ', 'item_α', 'item_thr' plus the keys read by\n",
    "        update_P ('r_p') and update_Q ('r_q', 's_q').\n",
    "\n",
    "    The loop stops at the first epoch whose validation score does not improve on\n",
    "    the best one so far; the best Q is then scored on the test split with\n",
    "    test_metrics and the results are printed. Nothing is returned.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    RuntimeError\n",
    "        If no epoch produced a validation score that compares greater than -inf\n",
    "        (for example a NaN score), so there is no model to test.\n",
    "    \"\"\"\n",
    "    print(model_kwargs)\n",
    "\n",
    "    L = model_kwargs['L']\n",
    "    slim_args = (model_kwargs['item_λ'], model_kwargs['item_α'], model_kwargs['item_thr'])\n",
    "\n",
    "    best_score = -np.inf\n",
    "    Q_best = None  # stays None until some epoch yields a usable score\n",
    "\n",
    "    # Random start, passed through ImplicitSLIM once before the first update.\n",
    "    Q = ImplicitSLIM(np.random.randn(L, n_items), X, *slim_args)\n",
    "\n",
    "    # learning\n",
    "    for epoch in range(max_epochs):\n",
    "        P = update_P(X, Q, model_kwargs)\n",
    "\n",
    "        # In the first epoch Q has just come out of ImplicitSLIM, so it is reused as\n",
    "        # the target; afterwards the target is recomputed from the current Q.\n",
    "        Qs = Q.copy() if epoch == 0 else ImplicitSLIM(Q, X, *slim_args)\n",
    "\n",
    "        Q = update_Q(X, P, Qs, model_kwargs)\n",
    "\n",
    "        score = evaluate(Q, valid_in_data, valid_out_data, [{'metric': ndcg, 'k': 100}], model_kwargs, 1)[0]\n",
    "\n",
    "        print(score)\n",
    "\n",
    "        if score > best_score:\n",
    "            best_score = score\n",
    "            Q_best = Q.copy()\n",
    "        else:\n",
    "            break\n",
    "\n",
    "    if Q_best is None:\n",
    "        raise RuntimeError('no epoch produced a usable validation score (NaN score or max_epochs < 1)')\n",
    "\n",
    "    final_scores = evaluate(Q_best, test_in_data, test_out_data, test_metrics, model_kwargs)\n",
    "    for metric, score in zip(test_metrics, final_scores):\n",
    "        print(f\"{metric['metric'].__name__}@{metric['k']}:\\t{score:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-24T18:43:32.555162Z",
     "start_time": "2023-05-24T18:36:30.973533Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'L': 1000, 'item_thr': 0, 'item_α': 0.12662155620150833, 'item_λ': 744.966655411396, 'r_p': 999.9999999999998, 'r_q': 1.0, 's_q': 102.1953781322435}\n",
      "0.4153677389627772\n",
      "0.42421278417158564\n",
      "0.4257338309941782\n",
      "0.4258351972193235\n",
      "0.426129678321776\n",
      "ndcg@100:\t0.4199\n",
      "recall@20:\t0.3908\n",
      "recall@50:\t0.5199\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for the run below: L is the number of factor rows,\n",
    "# item_thr / item_α / item_λ are forwarded to ImplicitSLIM, r_p is read by\n",
    "# update_P, and r_q / s_q are read by update_Q.\n",
    "model_kwargs = {'L': 1000,\n",
    "                'item_thr': 0,\n",
    "                'item_α': 0.12662155620150833,\n",
    "                'item_λ': 744.966655411396,\n",
    "                'r_p': 999.9999999999998,\n",
    "                'r_q': 1.0,\n",
    "                's_q': 102.1953781322435}\n",
    "\n",
    "run(**model_kwargs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
