{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Xo7qeXqAxVxp",
   "metadata": {
    "executionInfo": {
     "elapsed": 8551,
     "status": "ok",
     "timestamp": 1681721630998,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "Xo7qeXqAxVxp",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "np.random.seed(2020)\n",
    "torch.manual_seed(2020)\n",
    "import pdb\n",
    "\n",
    "from dataset import load_data\n",
    "from matrix_factorization import PMF, PMF_IPS, PMF_CVIB, PMF_SNIPS, PMF_ASIPS, PMF_DR, PMF_DR_JL, PMF_MRDR_JL, PMF_DIB, PMF_DR_BIAS, PMF_DR_MSE, PMF_ours_JL\n",
    "from TDR import PMF_TDR, PMF_TDR_JL\n",
    "from StableDR import PMF_Stable_DR\n",
    "\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU,recall_func, precision_func\n",
    "mse_func = lambda x,y: np.mean((x-y)**2)\n",
    "acc_func = lambda x,y: np.sum(x == y) / len(x)\n",
    "\n",
    "dataset_name = \"coat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902db9a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1441,
     "status": "ok",
     "timestamp": 1681721635206,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "902db9a6",
    "outputId": "f5254160-9ad6-4c18-d5ac-92aa63d6700d",
    "tags": []
   },
   "outputs": [],
   "source": [
    "if dataset_name == \"coat\":\n",
    "    train_mat, test_mat = load_data(\"coat\")        \n",
    "    x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "    x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "    num_user = train_mat.shape[0]\n",
    "    num_item = train_mat.shape[1]\n",
    "\n",
    "elif dataset_name == \"yahoo\":\n",
    "    x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "    x_train, y_train = shuffle(x_train, y_train)\n",
    "    num_user = x_train[:,0].max() + 1\n",
    "    num_item = x_train[:,1].max() + 1\n",
    "\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "# binarize\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4016c946",
   "metadata": {
    "id": "4016c946"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f0469f0",
   "metadata": {
    "id": "0f0469f0"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304cd8bb",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5222,
     "status": "ok",
     "timestamp": 1681303744609,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "304cd8bb",
    "outputId": "eab5110a-26f9-4ad6-9e95-c228b7b20a08"
   },
   "outputs": [],
   "source": [
    "\"PMF naive\"\n",
    "PMF = PMF(num_user, num_item, batch_size=128)\n",
    "PMF.cuda()\n",
    "PMF.fit(x_train, y_train, \n",
    "    lr=0.05,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF.predict(x_test)\n",
    "mse_PMF = mse_func(y_test, test_pred)\n",
    "auc_PMF = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF, x_test, y_test)\n",
    "recall_res = recall_func(PMF, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF]\" + \"***\"*5)\n",
    "print(\"[PMF] test mse:\", mse_PMF)\n",
    "print(\"[PMF] test auc:\", auc_PMF)\n",
    "print(\"[PMF] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "PeyU9z1tbB14",
   "metadata": {
    "id": "PeyU9z1tbB14"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5842bc5b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2391,
     "status": "ok",
     "timestamp": 1681721970288,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "5842bc5b",
    "outputId": "accaf3b7-b6fb-40b3-af50-608f531efb64"
   },
   "outputs": [],
   "source": [
    "\"PMF IPS\"\n",
    "PMF_ips = PMF_IPS(num_user, num_item,batch_size=128)\n",
    "PMF_ips.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_ips.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    lamb=5e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_ips.predict(x_test)\n",
    "mse_PMFips = mse_func(y_test, test_pred)\n",
    "auc_PMFips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ips, x_test, y_test)\n",
    "recall_res = recall_func(PMF_ips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-IPS]\" + \"***\"*5)\n",
    "print(\"[PMF-IPS] test mse:\", mse_PMFips)\n",
    "print(\"[PMF-IPS] test auc:\", auc_PMFips)\n",
    "print(\"[PMF-IPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-IPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-IPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Fo3Xuj98bVE6",
   "metadata": {
    "id": "Fo3Xuj98bVE6"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c46142",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 10834,
     "status": "ok",
     "timestamp": 1681722619745,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "02c46142",
    "outputId": "649907c2-57b0-4c7f-af1c-4a8041b9d3af"
   },
   "outputs": [],
   "source": [
    "\"PMF ASIPS\"\n",
    "PMF_ips = PMF_ASIPS(num_user, num_item,batch_size = 128)\n",
    "PMF_ips.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_ips.fit(x_train, y_train,  y_ips=y_ips, tao = 0.1,\n",
    "    batch_size = 128,\n",
    "    lr=0.05,\n",
    "    G = 1,\n",
    "    lamb=5e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_ips.predict(x_test)\n",
    "mse_PMFips = mse_func(y_test, test_pred)\n",
    "auc_PMFips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ips, x_test, y_test)\n",
    "recall_res = recall_func(PMF_ips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-ASIPS]\" + \"***\"*5)\n",
    "print(\"[PMF-ASIPS] test mse:\", mse_PMFips)\n",
    "print(\"[PMF-ASIPS] test auc:\", auc_PMFips)\n",
    "print(\"[PMF-ASIPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-ASIPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-ASIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ePO7jlWfbse3",
   "metadata": {
    "id": "ePO7jlWfbse3"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "509a2a98",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3645,
     "status": "ok",
     "timestamp": 1681722870704,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "509a2a98",
    "outputId": "ffd5725e-5255-46f8-c5c8-ac4e9f058f24"
   },
   "outputs": [],
   "source": [
    "\"PMF-SNIPS\"\n",
    "PMF_snips = PMF_SNIPS(num_user, num_item,batch_size=128)\n",
    "PMF_snips.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_snips.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_snips.predict(x_test)\n",
    "mse_PMFsnips = mse_func(y_test, test_pred)\n",
    "auc_PMFsnips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_snips, x_test, y_test)\n",
    "recall_res = recall_func(PMF_snips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-SNIPS]\" + \"***\"*5)\n",
    "print(\"[PMF-SNIPS] test mse:\", mse_PMFsnips)\n",
    "print(\"[PMF-SNIPS] test auc:\", auc_PMFsnips)\n",
    "print(\"[PMF-SNIPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-SNIPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-SNIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "NbbAQQ7Ee4cu",
   "metadata": {
    "id": "NbbAQQ7Ee4cu"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fde7613",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4165,
     "status": "ok",
     "timestamp": 1681303781377,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "7fde7613",
    "outputId": "008200bc-c761-4333-8791-989ef37ad541"
   },
   "outputs": [],
   "source": [
    "\"PMF CVIB\"\n",
    "PMF_cvib = PMF_CVIB(num_user, num_item,batch_size=128)\n",
    "PMF_cvib.cuda()\n",
    "PMF_cvib.fit(x_train, y_train, \n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=5e-4,\n",
    "    alpha=1.0,\n",
    "    gamma=1e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "\n",
    "test_pred = PMF_cvib.predict(x_test)\n",
    "mse_PMFcvib = mse_func(y_test, test_pred)\n",
    "auc_PMFcvib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_cvib, x_test, y_test)\n",
    "recall_res = recall_func(PMF_cvib, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-CVIB]\" + \"***\"*5)\n",
    "print(\"[PMF-CVIB] test mse:\", mse_PMFcvib)\n",
    "print(\"[PMF-CVIB] test auc:\", auc_PMFcvib)\n",
    "print(\"[PMF-CVIB] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-CVIB] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-CVIB]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "XvIp-YZRgELf",
   "metadata": {
    "id": "XvIp-YZRgELf"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f67e1b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2529,
     "status": "ok",
     "timestamp": 1681723316055,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "d6f67e1b",
    "outputId": "bfe640bc-1bab-4c41-c983-cf0f2e6ed5c7"
   },
   "outputs": [],
   "source": [
    "\"PMF DR\"\n",
    "PMF_dr = PMF_DR(num_user, num_item,batch_size=128)\n",
    "PMF_dr.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.03,\n",
    "    G = 1,\n",
    "    batch_size=128,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr.predict(x_test)\n",
    "mse_PMFdr = mse_func(y_test, test_pred)\n",
    "auc_PMFdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR]\" + \"***\"*5)\n",
    "print(\"[PMF-DR] test mse:\", mse_PMFdr)\n",
    "print(\"[PMF-DR] test auc:\", auc_PMFdr)\n",
    "print(\"[PMF-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-DR] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ANDSxdagsQI",
   "metadata": {
    "id": "2ANDSxdagsQI"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e4f343",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 8032,
     "status": "ok",
     "timestamp": 1681723615717,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "c4e4f343",
    "outputId": "74518481-e0e6-4e8b-cfba-ed54b1043bbd"
   },
   "outputs": [],
   "source": [
    "\"PMF DR JL\"\n",
    "PMF_dr_jl = PMF_DR_JL(num_user, num_item,batch_size=128)\n",
    "PMF_dr_jl.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_jl.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=5e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_jl.predict(x_test)\n",
    "mse_PMFdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_jl, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-DR-JL] test mse:\", mse_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] test auc:\", auc_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-DR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "D0uTw4GyhnUR",
   "metadata": {
    "id": "D0uTw4GyhnUR"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082fa612",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9371,
     "status": "ok",
     "timestamp": 1681723922726,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "082fa612",
    "outputId": "b5512ee8-2399-464c-bc82-d5f71158b35c"
   },
   "outputs": [],
   "source": [
    "\"PMF MRDR JL\"\n",
    "PMF_mrdr_jl = PMF_MRDR_JL(num_user, num_item)\n",
    "PMF_mrdr_jl.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_mrdr_jl.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=5e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_mrdr_jl.predict(x_test)\n",
    "mse_PMFmrdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFmrdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_mrdr_jl, x_test, y_test)\n",
    "recall_res = recall_func(PMF_mrdr_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-MRDR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-MRDR-JL] test mse:\", mse_PMFmrdrjl)\n",
    "print(\"[PMF-MRDR-JL] test auc:\", auc_PMFmrdrjl)\n",
    "print(\"[PMF-MRDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-MRDR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-MRDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "M07Puf8ci7Vi",
   "metadata": {
    "id": "M07Puf8ci7Vi"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da57a5e-cca5-4057-8c61-c3ed247dde1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "PMF_dr_bias = PMF_DR_BIAS(num_user, num_item)\n",
    "PMF_dr_bias.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_bias.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=2e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_bias.predict(x_test)\n",
    "mse_PMF_dr_bias = mse_func(y_test, test_pred)\n",
    "auc_PMF_dr_bias = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_bias, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_bias, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR_BIAS]\" + \"***\"*5)\n",
    "print(\"[PMF-DR_BIAS] test mse:\", mse_PMF_dr_bias)\n",
    "print(\"[PMF-DR_BIAS] test auc:\", auc_PMF_dr_bias)\n",
    "print(\"[PMF-DR_BIAS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-DR_BIAS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR_BIAS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae9e95c-ce51-46df-8019-272f59b36ebe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590c8426-c04d-464f-ba10-d79c9cf8d455",
   "metadata": {},
   "outputs": [],
   "source": [
    "PMF_dr_mse = PMF_DR_MSE(num_user, num_item)\n",
    "PMF_dr_mse.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_mse.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    gamma = 0.5,\n",
    "    batch_size=128,\n",
    "    lamb=5e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_mse.predict(x_test)\n",
    "mse_PMF_dr_mse = mse_func(y_test, test_pred)\n",
    "auc_PMF_dr_mse = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_mse, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_mse, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR_MSE]\" + \"***\"*5)\n",
    "print(\"[PMF-DR_MSE] test mse:\", mse_PMF_dr_mse)\n",
    "print(\"[PMF-DR_MSE] test auc:\", auc_PMF_dr_mse)\n",
    "print(\"[PMF-DR_MSE] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-DR_MSE] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR_MSE]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25f0c15-7507-4c07-b870-be4b770294a5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fl1yR4rmbFas",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 8523,
     "status": "ok",
     "timestamp": 1681724292101,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "fl1yR4rmbFas",
    "outputId": "f98fb1cb-71e3-41ff-fe05-82226820538a"
   },
   "outputs": [],
   "source": [
    "\"PMF-DIB\"\n",
    "PMF_dib = PMF_DIB(num_user, num_item,batch_size=128)\n",
    "PMF_dib.cuda()\n",
    "PMF_dib.fit(x_train, y_train, \n",
    "    lr=0.01,\n",
    "    alpha=0.9, \n",
    "    gamma=0.9,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dib.predict(x_test)\n",
    "mse_PMF_dib = mse_func(y_test, test_pred)\n",
    "auc_PMF_dib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dib, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dib, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DIB]\" + \"***\"*5)\n",
    "print(\"[PMF-DIB] test mse:\", mse_PMF_dib)\n",
    "print(\"[PMF-DIB] test auc:\", auc_PMF_dib)\n",
    "print(\"[PMF-DIB] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-DIB] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DIB]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lPV686bsIIRO",
   "metadata": {
    "id": "lPV686bsIIRO"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0db91e4b-8292-4449-b0c7-cef817b080b8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "PMF_dr_tmle = PMF_TDR(num_user, num_item, batch_size = 128)\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "prior_y = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_tmle.fit(x_train, y_train, prior_y, gamma = 0.1,\n",
    "    lr=0.05,\n",
    "    G = 3,\n",
    "    lamb=5e-5,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = PMF_dr_tmle.predict(x_test)\n",
    "mse_PMFdrtmle = mse_func(y_test, test_pred)\n",
    "auc_PMFdrtmle = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_tmle, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_tmle, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-TDR]\" + \"***\"*5)\n",
    "print(\"[PMF-TDR] test mse:\", mse_PMFdrtmle)\n",
    "print(\"[PMF-TDR] test auc:\", auc_PMFdrtmle)\n",
    "print(\"[PMF-TDR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-TDR] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-TDR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9760e8a8-744f-4bec-aaf5-a3c3ca48113d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad311fbf-cef0-4cf6-af45-4abbe6e8e9a1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"PMF TDR JL\"\n",
    "\n",
    "PMF_dr_tmle_jl = PMF_TDR_JL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "PMF_dr_tmle_jl.fit(x_train, y_train,\n",
    "    lr=0.05,\n",
    "    G = 4,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = PMF_dr_tmle_jl.predict(x_test)\n",
    "mse_PMFdrtmlejl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrtmlejl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_tmle_jl, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_tmle_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-TDR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-TDR-JL] test mse:\", mse_PMFdrtmlejl)\n",
    "print(\"[PMF-TDR-JL] test auc:\", auc_PMFdrtmlejl)\n",
    "print(\"[PMF-TDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-TDR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-TDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e7469e-8375-4b93-9373-79ea29b5d3dd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1fdc25-4947-44c8-b1be-416dbf1c1fd3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"PMF Stable DR\"\n",
    "PMF_stable_dr = PMF_Stable_DR(num_user, num_item)\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_stable_dr.fit(x_train, y_train, y_ips,\n",
    "    eta = 100,\n",
    "    lr=0.01,\n",
    "    G = 3,\n",
    "    batch_size=128,\n",
    "    lr1 = 10,\n",
    "    lamb=5e-4,\n",
    "    tol=1e-5)\n",
    "\n",
    "test_pred = PMF_stable_dr.predict(x_test)\n",
    "mse_PMFsdr = mse_func(y_test, test_pred)\n",
    "auc_PMFsdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_stable_dr, x_test, y_test)\n",
    "recall_res = recall_func(PMF_stable_dr, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-Stable-DR]\" + \"***\"*5)\n",
    "print(\"[PMF-Stable-DR] test mse:\", mse_PMFsdr)\n",
    "print(\"[PMF-Stable-DR] test auc:\", auc_PMFsdr)\n",
    "print(\"[PMF-Stable-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-Stable-DR] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-Stable-DR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3758160a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c97041c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"PMF-OURS\"\n",
    "PMF_ours = PMF_ours_JL(num_user, num_item, batch_size = 128, embedding_k = 8)\n",
    "PMF_ours.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_ours.fit(x_train, y_train, y_ips=y_ips, \n",
    "        lr = 0.05,\n",
    "        lr2 = 0.1,\n",
    "        lamb1 = 0.005,\n",
    "        lamb2 = 0.001,\n",
    "        lamb3 = 0.001,\n",
    "        tol=1e-5)\n",
    "test_pred = PMF_ours.predict(x_test)\n",
    "mse_PMF_ours = mse_func(y_test, test_pred)\n",
    "auc_PMF_ours = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ours, x_test, y_test)\n",
    "recall_res = recall_func(PMF_ours, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-OURS]\" + \"***\"*5)\n",
    "print(\"[PMF-OURS] test mse:\", mse_PMF_ours)\n",
    "print(\"[PMF-OURS] test auc:\", auc_PMF_ours)\n",
    "print(\"[PMF-OURS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-OURS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-OURS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85670cf1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
