{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Xo7qeXqAxVxp",
   "metadata": {
    "executionInfo": {
     "elapsed": 8551,
     "status": "ok",
     "timestamp": 1681721630998,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "Xo7qeXqAxVxp",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "# np.random.seed(2024)\n",
    "# torch.manual_seed(2024)\n",
    "import pdb\n",
    "\n",
    "from dataset import load_data\n",
    "from matrix_factorization_coat import MF_KBIPS_Gau,MF_AKBIPS_Gau,MF_WKBIPS_Gau,MF_KBDR_Gau,MF_AKBDR_Gau,MF_WKBDR_Gau\n",
    "from matrix_factorization_coat import MF_KBIPS_Exp,MF_AKBIPS_Exp,MF_WKBIPS_Exp,MF_KBDR_Exp,MF_AKBDR_Exp,MF_WKBDR_Exp\n",
    "\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU,recall_func, precision_func\n",
    "def mse_func(x, y):\n",
    "    \"\"\"Return the mean of the squared element-wise difference of `x` and `y`.\"\"\"\n",
    "    diff = x - y\n",
    "    return np.mean(diff ** 2)\n",
    "\n",
    "\n",
    "def acc_func(x, y):\n",
    "    \"\"\"Return the number of positions where `x == y`, divided by `len(x)`.\"\"\"\n",
    "    n_match = np.sum(x == y)\n",
    "    return n_match / len(x)\n",
    "\n",
    "\n",
    "# Name of the dataset this notebook works on.\n",
    "dataset_name = \"coat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902db9a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1441,
     "status": "ok",
     "timestamp": 1681721635206,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "902db9a6",
    "outputId": "f5254160-9ad6-4c18-d5ac-92aa63d6700d",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the rating matrices for the dataset chosen in the setup cell. Reading\n",
    "# `dataset_name` (instead of repeating the literal) keeps one source of truth.\n",
    "train_mat, test_mat = load_data(dataset_name)\n",
    "\n",
    "# Turn each rating matrix into the x/y sample arrays used by the models.\n",
    "x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "\n",
    "# The two matrix dimensions are passed to every model constructor below.\n",
    "num_user = train_mat.shape[0]\n",
    "num_item = train_mat.shape[1]\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "\n",
    "# Binarize the train and test labels with the project's `binarize` helper.\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b005c0b-047d-4526-b359-48aa3244897a",
   "metadata": {},
   "source": [
    "### KBIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ff218d0-26d2-4766-848d-205af4b22093",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit MF_KBIPS_Gau on the training samples, then report its test-set metrics.\n",
    "mf_kbips_gau = MF_KBIPS_Gau(num_user, num_item)\n",
    "mf_kbips_gau.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr1=3e-2,\n",
    "    lamb1=1e-3,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-4,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    J=3,\n",
    "    C=1e-2,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_kbips_gau.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_kbips_gau.predict(x_test)\n",
    "mse_mfkbips = mse_func(y_test, test_pred)\n",
    "auc_mfkbips = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_kbips_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbips_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbips_gau, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-KBIPS-Gau]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-KBIPS-Gau] test mse:\", mse_mfkbips)\n",
    "print(\"[MF-KBIPS-Gau] test auc:\", auc_mfkbips)\n",
    "print(\"[MF-KBIPS-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-KBIPS-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-KBIPS-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-KBIPS-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f5f1b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit MF_KBIPS_Exp on the training samples, then report its test-set metrics.\n",
    "mf_kbips_exp = MF_KBIPS_Exp(num_user, num_item)\n",
    "mf_kbips_exp.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr1=3e-2,\n",
    "    lamb1=1e-3,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-4,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    J=3,\n",
    "    C=1e-2,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_kbips_exp.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_kbips_exp.predict(x_test)\n",
    "mse_mfkbips = mse_func(y_test, test_pred)\n",
    "auc_mfkbips = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_kbips_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbips_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbips_exp, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-KBIPS-Exp]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-KBIPS-Exp] test mse:\", mse_mfkbips)\n",
    "print(\"[MF-KBIPS-Exp] test auc:\", auc_mfkbips)\n",
    "print(\"[MF-KBIPS-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-KBIPS-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-KBIPS-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-KBIPS-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22341dbb-b46d-4d24-80d8-e011f4619d89",
   "metadata": {},
   "source": [
    "### AKBIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df2284cb-842d-4339-b4e6-ef526eb30715",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit MF_AKBIPS_Gau on the training samples, then report its test-set metrics.\n",
    "mf_akbips_Gau = MF_AKBIPS_Gau(num_user, num_item, embedding_k=4)\n",
    "mf_akbips_Gau.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr1=3e-2,\n",
    "    lamb1=3e-3,\n",
    "    lr2=3e-2,\n",
    "    lamb2=1e-3,\n",
    "    lr3=3e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma=10,\n",
    "    batch_size=128,\n",
    "    num_w_epo=3,\n",
    "    J=3,\n",
    "    C=1e-3,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_akbips_Gau.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_akbips_Gau.predict(x_test)\n",
    "mse_mfakbips = mse_func(y_test, test_pred)\n",
    "auc_mfakbips = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_akbips_Gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbips_Gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbips_Gau, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-AKBIPS-Gau]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-AKBIPS-Gau] test mse:\", mse_mfakbips)\n",
    "print(\"[MF-AKBIPS-Gau] test auc:\", auc_mfakbips)\n",
    "print(\"[MF-AKBIPS-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-AKBIPS-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-AKBIPS-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-AKBIPS-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70644d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit MF_AKBIPS_Exp on the training samples, then report its test-set metrics.\n",
    "mf_akbips_exp = MF_AKBIPS_Exp(num_user, num_item, embedding_k=4)\n",
    "mf_akbips_exp.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr1=3e-2,\n",
    "    lamb1=3e-3,\n",
    "    lr2=3e-2,\n",
    "    lamb2=1e-3,\n",
    "    lr3=3e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma=100,\n",
    "    batch_size=128,\n",
    "    num_w_epo=3,\n",
    "    J=3,\n",
    "    C=1e-3,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_akbips_exp.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_akbips_exp.predict(x_test)\n",
    "mse_mfakbips = mse_func(y_test, test_pred)\n",
    "auc_mfakbips = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_akbips_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbips_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbips_exp, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-AKBIPS-Exp]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-AKBIPS-Exp] test mse:\", mse_mfakbips)\n",
    "print(\"[MF-AKBIPS-Exp] test auc:\", auc_mfakbips)\n",
    "print(\"[MF-AKBIPS-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-AKBIPS-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-AKBIPS-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-AKBIPS-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05f6a82a-e3f0-4004-b899-ed3c6fa0ce2f",
   "metadata": {},
   "source": [
    "### WKBIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d29640ce-1e1f-44ce-a2f5-e89037e8f4af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit MF_WKBIPS_Gau on the training samples, then report its test-set metrics.\n",
    "mf_wkbips_gau = MF_WKBIPS_Gau(num_user, num_item)\n",
    "mf_wkbips_gau.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr1=3e-2,\n",
    "    lamb1=3e-3,\n",
    "    lr2=3e-2,\n",
    "    lamb2=1e-4,\n",
    "    lr3=3e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    num_w_epo=3,\n",
    "    J=3,\n",
    "    C=1e-3,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_wkbips_gau.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_wkbips_gau.predict(x_test)\n",
    "mse_mfwkbips = mse_func(y_test, test_pred)\n",
    "auc_mfwkbips = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_wkbips_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbips_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbips_gau, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-WKBIPS-Gau]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-WKBIPS-Gau] test mse:\", mse_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Gau] test auc:\", auc_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-WKBIPS-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-WKBIPS-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-WKBIPS-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c941a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit MF_WKBIPS_Exp on the training samples, then report its test-set metrics.\n",
    "mf_wkbips_exp = MF_WKBIPS_Exp(num_user, num_item)\n",
    "mf_wkbips_exp.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr1=3e-2,\n",
    "    lamb1=2e-3,\n",
    "    lr2=3e-2,\n",
    "    lamb2=1e-4,\n",
    "    lr3=3e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    num_w_epo=3,\n",
    "    J=3,\n",
    "    C=1e-3,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_wkbips_exp.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_wkbips_exp.predict(x_test)\n",
    "mse_mfwkbips = mse_func(y_test, test_pred)\n",
    "auc_mfwkbips = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_wkbips_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbips_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbips_exp, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-WKBIPS-Exp]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-WKBIPS-Exp] test mse:\", mse_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Exp] test auc:\", auc_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-WKBIPS-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-WKBIPS-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-WKBIPS-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b95ef6a2-35bd-4880-92e8-91fa16a72cf8",
   "metadata": {},
   "source": [
    "### KBDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b3d1a7-bf74-467c-ad16-5bd10c7d4c5e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit MF_KBDR_Gau on the training samples, then report its test-set metrics.\n",
    "mf_kbdr_gau = MF_KBDR_Gau(num_user, num_item)\n",
    "mf_kbdr_gau.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr=1e-2,\n",
    "    lamb=3e-3,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    J=3,\n",
    "    C=1e-4,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_kbdr_gau.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_kbdr_gau.predict(x_test)\n",
    "mse_mfkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfkbdr = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_kbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbdr_gau, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-KBDR-Gau]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-KBDR-Gau] test mse:\", mse_mfkbdr)\n",
    "print(\"[MF-KBDR-Gau] test auc:\", auc_mfkbdr)\n",
    "print(\"[MF-KBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-KBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-KBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-KBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aa05473",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit MF_KBDR_Exp on the training samples, then report its test-set metrics.\n",
    "mf_kbdr_exp = MF_KBDR_Exp(num_user, num_item)\n",
    "mf_kbdr_exp.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr=1e-2,\n",
    "    lamb=3e-3,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    J=3,\n",
    "    C=1e-4,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_kbdr_exp.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_kbdr_exp.predict(x_test)\n",
    "mse_mfkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfkbdr = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_kbdr_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbdr_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbdr_exp, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-KBDR-Exp]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-KBDR-Exp] test mse:\", mse_mfkbdr)\n",
    "print(\"[MF-KBDR-Exp] test auc:\", auc_mfkbdr)\n",
    "print(\"[MF-KBDR-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-KBDR-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-KBDR-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-KBDR-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0dc13a5-74f1-4d54-8cc9-3c9ae078a819",
   "metadata": {},
   "source": [
    "### AKBDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a51b036f-6612-4cd7-a274-ae0cea0b1688",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit MF_AKBDR_Gau on the training samples, then report its test-set metrics.\n",
    "mf_akbdr_gau = MF_AKBDR_Gau(num_user, num_item)\n",
    "mf_akbdr_gau.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr=3e-2,\n",
    "    lamb=5e-3,\n",
    "    lr1=3e-2,\n",
    "    lamb1=3e-3,\n",
    "    lr2=3e-2,\n",
    "    lamb2=1e-3,\n",
    "    lr3=3e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma=10,\n",
    "    batch_size=128,\n",
    "    num_w_epo=3,\n",
    "    J=3,\n",
    "    C=1e-3,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_akbdr_gau.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_akbdr_gau.predict(x_test)\n",
    "mse_mfakbdr = mse_func(y_test, test_pred)\n",
    "auc_mfakbdr = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_akbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbdr_gau, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-AKBDR-Gau]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-AKBDR-Gau] test mse:\", mse_mfakbdr)\n",
    "print(\"[MF-AKBDR-Gau] test auc:\", auc_mfakbdr)\n",
    "print(\"[MF-AKBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-AKBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-AKBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-AKBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3758a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit MF_AKBDR_Exp on the training samples, then report its test-set metrics.\n",
    "mf_akbdr_exp = MF_AKBDR_Exp(num_user, num_item)\n",
    "mf_akbdr_exp.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr=3e-2,\n",
    "    lamb=5e-3,\n",
    "    lr1=3e-2,\n",
    "    lamb1=3e-3,\n",
    "    lr2=3e-2,\n",
    "    lamb2=1e-3,\n",
    "    lr3=3e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma=10,\n",
    "    batch_size=128,\n",
    "    num_w_epo=3,\n",
    "    J=3,\n",
    "    C=1e-3,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_akbdr_exp.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_akbdr_exp.predict(x_test)\n",
    "mse_mfakbdr = mse_func(y_test, test_pred)\n",
    "auc_mfakbdr = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_akbdr_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbdr_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbdr_exp, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-AKBDR-Exp]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-AKBDR-Exp] test mse:\", mse_mfakbdr)\n",
    "print(\"[MF-AKBDR-Exp] test auc:\", auc_mfakbdr)\n",
    "print(\"[MF-AKBDR-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-AKBDR-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-AKBDR-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-AKBDR-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a65ad1c-d58d-4d50-af9c-941f5db98872",
   "metadata": {},
   "source": [
    "### WKBDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8512f0e1-c68c-4b21-9b50-dec063bf129f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fit MF_WKBDR_Gau on the training samples, then report its test-set metrics.\n",
    "mf_wkbdr_gau = MF_WKBDR_Gau(num_user, num_item)\n",
    "mf_wkbdr_gau.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr=3e-2,\n",
    "    lamb=3e-3,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    J=3,\n",
    "    C=1e-4,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_wkbdr_gau.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_wkbdr_gau.predict(x_test)\n",
    "mse_mfwkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfwkbdr = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_wkbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbdr_gau, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-WKBDR-Gau]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-WKBDR-Gau] test mse:\", mse_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Gau] test auc:\", auc_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-WKBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-WKBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-WKBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2045a985",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit MF_WKBDR_Exp on the training samples, then report its test-set metrics.\n",
    "mf_wkbdr_exp = MF_WKBDR_Exp(num_user, num_item)\n",
    "mf_wkbdr_exp.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the first 5%; the matching test labels\n",
    "# are handed to `fit` as `y_ips`.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "n_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:n_ips]]\n",
    "\n",
    "hparams = dict(\n",
    "    lr=3e-2,\n",
    "    lamb=3e-3,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma=1,\n",
    "    batch_size=128,\n",
    "    J=3,\n",
    "    C=1e-4,\n",
    "    tol=1e-5,\n",
    ")\n",
    "mf_wkbdr_exp.fit(x_train, y_train, y_ips, **hparams)\n",
    "\n",
    "# Point-wise error and AUC on the test predictions.\n",
    "test_pred = mf_wkbdr_exp.predict(x_test)\n",
    "mse_mfwkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfwkbdr = roc_auc_score(y_test, test_pred)\n",
    "\n",
    "# Ranking metrics; each result is indexed by \"<metric>_<k>\" below.\n",
    "ndcg_res = ndcg_func(mf_wkbdr_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbdr_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbdr_exp, x_test, y_test)\n",
    "\n",
    "# Average each metric once so the prints and the F1 terms share the values.\n",
    "ndcg_5 = np.mean(ndcg_res[\"ndcg_5\"])\n",
    "ndcg_10 = np.mean(ndcg_res[\"ndcg_10\"])\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "recall_10 = np.mean(recall_res[\"recall_10\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "precision_10 = np.mean(precision_res[\"precision_10\"])\n",
    "# F1 combines the averaged precision and recall at each cutoff.\n",
    "f1_5 = 2 * (precision_5 * recall_5) / (precision_5 + recall_5)\n",
    "f1_10 = 2 * (precision_10 * recall_10) / (precision_10 + recall_10)\n",
    "\n",
    "banner = \"***\"*5 + \"[MF-WKBDR-Exp]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[MF-WKBDR-Exp] test mse:\", mse_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Exp] test auc:\", auc_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(ndcg_5, ndcg_10))\n",
    "print(\"[MF-WKBDR-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(recall_5, recall_10))\n",
    "print(\"[MF-WKBDR-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(precision_5, precision_10))\n",
    "print(\"[MF-WKBDR-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(f1_5, f1_10))\n",
    "print(banner)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
