{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Xo7qeXqAxVxp",
   "metadata": {
    "executionInfo": {
     "elapsed": 8551,
     "status": "ok",
     "timestamp": 1681721630998,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "Xo7qeXqAxVxp",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "# np.random.seed(2024)\n",
    "# torch.manual_seed(2024)\n",
    "import pdb\n",
    "\n",
    "from dataset import load_data\n",
    "from matrix_factorization_music import MF_KBIPS_Gau,MF_AKBIPS_Gau,MF_WKBIPS_Gau,MF_KBDR_Gau,MF_AKBDR_Gau,MF_WKBDR_Gau\n",
    "from matrix_factorization_music import MF_KBIPS_Exp,MF_AKBIPS_Exp,MF_WKBIPS_Exp,MF_KBDR_Exp,MF_AKBDR_Exp,MF_WKBDR_Exp\n",
    "\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU,recall_func, precision_func\n",
    "mse_func = lambda x,y: np.mean((x-y)**2)\n",
    "acc_func = lambda x,y: np.sum(x == y) / len(x)\n",
    "\n",
    "dataset_name = \"yahoo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902db9a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1441,
     "status": "ok",
     "timestamp": 1681721635206,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "902db9a6",
    "outputId": "f5254160-9ad6-4c18-d5ac-92aa63d6700d",
    "tags": []
   },
   "outputs": [],
   "source": [
    "x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "x_train, y_train = shuffle(x_train, y_train)\n",
    "num_user = x_train[:,0].max() + 1\n",
    "num_item = x_train[:,1].max() + 1\n",
    "\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "# binarize\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b005c0b-047d-4526-b359-48aa3244897a",
   "metadata": {},
   "source": [
    "### KBIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5184ba50",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_kbips_gau = MF_KBIPS_Gau(num_user, num_item)\n",
    "mf_kbips_gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_kbips_gau.fit(x_train, y_train,  y_ips,\n",
    "    lr1=5e-2,\n",
    "    lamb1=3e-5,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-4,\n",
    "    gamma = 1,\n",
    "    batch_size=2048,\n",
    "    J = 3,\n",
    "    C = 1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_kbips_gau.predict(x_test)\n",
    "mse_mfkbips = mse_func(y_test, test_pred)\n",
    "auc_mfkbips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_kbips_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbips_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbips_gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-KBIPS-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-KBIPS-Gau] test mse:\", mse_mfkbips)\n",
    "print(\"[MF-KBIPS-Gau] test auc:\", auc_mfkbips)\n",
    "print(\"[MF-KBIPS-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-KBIPS-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-KBIPS-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-KBIPS-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "\n",
    "print(\"***\"*5 + \"[MF-KBIPS-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d6816e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_kbips_exp = MF_KBIPS_Exp(num_user, num_item)\n",
    "mf_kbips_exp.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_kbips_exp.fit(x_train, y_train,  y_ips,\n",
    "    lr1=5e-2,\n",
    "    lamb1=3e-5,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-4,\n",
    "    gamma = 1,\n",
    "    batch_size=2048,\n",
    "    J = 3,\n",
    "    C = 1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_kbips_exp.predict(x_test)\n",
    "mse_mfkbips = mse_func(y_test, test_pred)\n",
    "auc_mfkbips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_kbips_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbips_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbips_exp, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-KBIPS-Exp]\" + \"***\"*5)\n",
    "print(\"[MF-KBIPS-Exp] test mse:\", mse_mfkbips)\n",
    "print(\"[MF-KBIPS-Exp] test auc:\", auc_mfkbips)\n",
    "print(\"[MF-KBIPS-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-KBIPS-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-KBIPS-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-KBIPS-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-KBIPS-Exp]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22341dbb-b46d-4d24-80d8-e011f4619d89",
   "metadata": {},
   "source": [
    "### AKBIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e387ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_akbips_Gau = MF_AKBIPS_Gau(num_user, num_item,embedding_k = 4)\n",
    "mf_akbips_Gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_akbips_Gau.fit(x_train, y_train,  y_ips,\n",
    "    lr1=3e-2,\n",
    "    lamb1=1e-4,\n",
    "    lr2=5e-2,\n",
    "    lamb2=1e-4,\n",
    "    lr3=5e-2,\n",
    "    lamb3=1e-4,\n",
    "    gamma = 20,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_akbips_Gau.predict(x_test)\n",
    "mse_mfakbips = mse_func(y_test, test_pred)\n",
    "auc_mfakbips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_akbips_Gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbips_Gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbips_Gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-AKBIPS-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-AKBIPS-Gau] test mse:\", mse_mfakbips)\n",
    "print(\"[MF-AKBIPS-Gau] test auc:\", auc_mfakbips)\n",
    "print(\"[MF-AKBIPS-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-AKBIPS-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-AKBIPS-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-AKBIPS-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-AKBIPS-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ad630b",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_akbips_exp = MF_AKBIPS_Exp(num_user, num_item,embedding_k = 4)\n",
    "mf_akbips_exp.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_akbips_exp.fit(x_train, y_train,  y_ips,\n",
    "    lr1=5e-2,\n",
    "    lamb1=1e-4,\n",
    "    lr2=5e-2,\n",
    "    lamb2=1e-4,\n",
    "    lr3=5e-2,\n",
    "    lamb3=1e-4,\n",
    "    gamma = 10,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_akbips_exp.predict(x_test)\n",
    "mse_mfakbips = mse_func(y_test, test_pred)\n",
    "auc_mfakbips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_akbips_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbips_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbips_exp, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-AKBIPS-Exp]\" + \"***\"*5)\n",
    "print(\"[MF-AKBIPS-Exp] test mse:\", mse_mfakbips)\n",
    "print(\"[MF-AKBIPS-Exp] test auc:\", auc_mfakbips)\n",
    "print(\"[MF-AKBIPS-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-AKBIPS-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-AKBIPS-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-AKBIPS-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-AKBIPS-Exp]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05f6a82a-e3f0-4004-b899-ed3c6fa0ce2f",
   "metadata": {},
   "source": [
    "### WKBIPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d03a91d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_wkbips_gau = MF_WKBIPS_Gau(num_user, num_item)\n",
    "mf_wkbips_gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_wkbips_gau.fit(x_train, y_train,  y_ips,\n",
    "    lr1=3e-2,\n",
    "    lamb1=1e-4,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-4,\n",
    "    lr3=5e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma = 10,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_wkbips_gau.predict(x_test)\n",
    "mse_mfwkbips = mse_func(y_test, test_pred)\n",
    "auc_mfwkbips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_wkbips_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbips_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbips_gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-WKBIPS-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-WKBIPS-Gau] test mse:\", mse_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Gau] test auc:\", auc_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-WKBIPS-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-WKBIPS-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-WKBIPS-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-WKBIPS-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4399bff",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_wkbips_exp = MF_WKBIPS_Exp(num_user, num_item)\n",
    "mf_wkbips_exp.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_wkbips_exp.fit(x_train, y_train,  y_ips,\n",
    "    lr1=3e-2,\n",
    "    lamb1=1e-4,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-4,\n",
    "    lr3=5e-2,\n",
    "    lamb3=1e-3,\n",
    "    gamma = 10,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_wkbips_exp.predict(x_test)\n",
    "mse_mfwkbips = mse_func(y_test, test_pred)\n",
    "auc_mfwkbips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_wkbips_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbips_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbips_exp, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-WKBIPS-Exp]\" + \"***\"*5)\n",
    "print(\"[MF-WKBIPS-Exp] test mse:\", mse_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Exp] test auc:\", auc_mfwkbips)\n",
    "print(\"[MF-WKBIPS-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-WKBIPS-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-WKBIPS-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-WKBIPS-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-WKBIPS-Exp]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b95ef6a2-35bd-4880-92e8-91fa16a72cf8",
   "metadata": {},
   "source": [
    "### KBDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e426176a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_kbdr_gau = MF_KBDR_Gau(num_user, num_item)\n",
    "mf_kbdr_gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_kbdr_gau.fit(x_train, y_train, y_ips, \n",
    "    lr=3e-2,\n",
    "    lamb=1e-5,\n",
    "    lr2=5e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma = 100,\n",
    "    batch_size=2048,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_kbdr_gau.predict(x_test)\n",
    "mse_mfkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfkbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_kbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbdr_gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-KBDR-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-KBDR-Gau] test mse:\", mse_mfkbdr)\n",
    "print(\"[MF-KBDR-Gau] test auc:\", auc_mfkbdr)\n",
    "print(\"[MF-KBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-KBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-KBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-KBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-KBDR-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c43f70c",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_kbdr_exp = MF_KBDR_Exp(num_user, num_item)\n",
    "mf_kbdr_exp.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_kbdr_exp.fit(x_train, y_train, y_ips, \n",
    "    lr=3e-2,\n",
    "    lamb=5e-5,\n",
    "    lr2=5e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma = 100,\n",
    "    batch_size=2048,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_kbdr_exp.predict(x_test)\n",
    "mse_mfkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfkbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_kbdr_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_kbdr_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_kbdr_exp, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-KBDR-Exp]\" + \"***\"*5)\n",
    "print(\"[MF-KBDR-Exp] test mse:\", mse_mfkbdr)\n",
    "print(\"[MF-KBDR-Exp] test auc:\", auc_mfkbdr)\n",
    "print(\"[MF-KBDR-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-KBDR-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-KBDR-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-KBDR-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-KBDR-Exp]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0dc13a5-74f1-4d54-8cc9-3c9ae078a819",
   "metadata": {},
   "source": [
    "### AKBDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a995691",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_akbdr_gau = MF_AKBDR_Gau(num_user, num_item)\n",
    "mf_akbdr_gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_akbdr_gau.fit(x_train, y_train, y_ips, \n",
    "    lr=0.01,\n",
    "    lamb=1e-4,\n",
    "    lr1=1e-2,\n",
    "    lamb1=1e-5,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    lr3=1e-2,\n",
    "    lamb3=1e-6,\n",
    "    gamma=10,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 5,\n",
    "    C = 1e-1,\n",
    "    tol=1e-5)\n",
    "\n",
    "test_pred = mf_akbdr_gau.predict(x_test)\n",
    "mse_mfakbdr = mse_func(y_test, test_pred)\n",
    "auc_mfakbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_akbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbdr_gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-AKBDR-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-AKBDR-Gau] test mse:\", mse_mfakbdr)\n",
    "print(\"[MF-AKBDR-Gau] test auc:\", auc_mfakbdr)\n",
    "print(\"[MF-AKBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-AKBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-AKBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-AKBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-AKBDR-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4270f870",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_akbdr_gau = MF_AKBDR_Gau(num_user, num_item)\n",
    "mf_akbdr_gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_akbdr_gau.fit(x_train, y_train, y_ips, \n",
    "    lr=0.01,\n",
    "    lamb=1e-4,\n",
    "    lr1=1e-2,\n",
    "    lamb1=1e-5,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    lr3=1e-2,\n",
    "    lamb3=1e-6,\n",
    "    gamma=10,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 3,\n",
    "    C = 5e-1,\n",
    "    tol=1e-5)\n",
    "\n",
    "test_pred = mf_akbdr_gau.predict(x_test)\n",
    "mse_mfakbdr = mse_func(y_test, test_pred)\n",
    "auc_mfakbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_akbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbdr_gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-AKBDR-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-AKBDR-Gau] test mse:\", mse_mfakbdr)\n",
    "print(\"[MF-AKBDR-Gau] test auc:\", auc_mfakbdr)\n",
    "print(\"[MF-AKBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-AKBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-AKBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-AKBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-AKBDR-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02aef59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_akbdr_exp = MF_AKBDR_Exp(num_user, num_item)\n",
    "mf_akbdr_exp.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_akbdr_exp.fit(x_train, y_train, y_ips, \n",
    "    lr=1e-2,\n",
    "    lamb=1e-4,\n",
    "    lr1=1e-2,\n",
    "    lamb1=1e-5,\n",
    "    lr2=1e-2,\n",
    "    lamb2=1e-5,\n",
    "    lr3=1e-2,\n",
    "    lamb3=1e-6,\n",
    "    gamma=10,\n",
    "    batch_size=2048,\n",
    "    num_w_epo = 3,\n",
    "    J = 5,\n",
    "    C = 1e-6,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_akbdr_exp.predict(x_test)\n",
    "mse_mfakbdr = mse_func(y_test, test_pred)\n",
    "auc_mfakbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_akbdr_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_akbdr_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_akbdr_exp, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-AKBDR-Exp]\" + \"***\"*5)\n",
    "print(\"[MF-AKBDR-Exp] test mse:\", mse_mfakbdr)\n",
    "print(\"[MF-AKBDR-Exp] test auc:\", auc_mfakbdr)\n",
    "print(\"[MF-AKBDR-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-AKBDR-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-AKBDR-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-AKBDR-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-AKBDR-Exp]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a65ad1c-d58d-4d50-af9c-941f5db98872",
   "metadata": {},
   "source": [
    "### WKBDR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5ab00b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_wkbdr_gau = MF_WKBDR_Gau(num_user, num_item)\n",
    "mf_wkbdr_gau.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "mf_wkbdr_gau.fit(x_train, y_train,  y_ips,\n",
    "    lr=1e-2,\n",
    "    lamb=1e-4,\n",
    "    lr2=5e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma = 100,\n",
    "    batch_size=2048,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_wkbdr_gau.predict(x_test)\n",
    "mse_mfwkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfwkbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_wkbdr_gau, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbdr_gau, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbdr_gau, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-WKBDR-Gau]\" + \"***\"*5)\n",
    "print(\"[MF-WKBDR-Gau] test mse:\", mse_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Gau] test auc:\", auc_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Gau] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-WKBDR-Gau] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-WKBDR-Gau] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-WKBDR-Gau] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-WKBDR-Gau]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d1ed2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mf_wkbdr_exp = MF_WKBDR_Exp(num_user, num_item)\n",
    "mf_wkbdr_exp.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "mf_wkbdr_exp.fit(x_train, y_train,  y_ips,\n",
    "    lr=3e-2,\n",
    "    lamb=1e-4,\n",
    "    lr2=5e-2,\n",
    "    lamb2=1e-5,\n",
    "    gamma = 100,\n",
    "    batch_size=2048,\n",
    "    J = 3,\n",
    "    C = 1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_wkbdr_exp.predict(x_test)\n",
    "mse_mfwkbdr = mse_func(y_test, test_pred)\n",
    "auc_mfwkbdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_wkbdr_exp, x_test, y_test)\n",
    "recall_res = recall_func(mf_wkbdr_exp, x_test, y_test)\n",
    "precision_res = precision_func(mf_wkbdr_exp, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-WKBDR-Exp]\" + \"***\"*5)\n",
    "print(\"[MF-WKBDR-Exp] test mse:\", mse_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Exp] test auc:\", auc_mfwkbdr)\n",
    "print(\"[MF-WKBDR-Exp] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-WKBDR-Exp] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-WKBDR-Exp] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print(\"[MF-WKBDR-Exp] f1@5:{:.6f}, f1@10:{:.6f}\".format(\n",
    "        2 * (np.mean(precision_res[\"precision_5\"]) * np.mean(recall_res[\"recall_5\"])) / (np.mean(precision_res[\"precision_5\"]) + np.mean(recall_res[\"recall_5\"])), \n",
    "        2 * (np.mean(precision_res[\"precision_10\"]) * np.mean(recall_res[\"recall_10\"])) / (np.mean(precision_res[\"precision_10\"]) + np.mean(recall_res[\"recall_10\"]))))\n",
    "print(\"***\"*5 + \"[MF-WKBDR-Exp]\" + \"***\"*5)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
