{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a4abb30",
   "metadata": {
    "executionInfo": {
     "elapsed": 6246,
     "status": "ok",
     "timestamp": 1681717210355,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "0a4abb30",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "np.random.seed(2020)\n",
    "torch.manual_seed(2020)\n",
    "import pdb\n",
    "\n",
    "from dataset import load_data\n",
    "from matrix_factorization import PMF, PMF_IPS, PMF_CVIB, PMF_SNIPS, PMF_ASIPS, PMF_DR, PMF_DR_JL, PMF_MRDR_JL, PMF_DIB, PMF_DR_BIAS, PMF_DR_MSE, PMF_ours_JL\n",
    "from TDR import PMF_TDR, PMF_TDR_JL\n",
    "from StableDR import PMF_Stable_DR\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU,recall_func, precision_func\n",
    "\n",
    "\n",
    "def mse_func(x, y):\n",
    "    \"\"\"Return the mean of the element-wise squared difference (x - y) ** 2.\"\"\"\n",
    "    return np.mean((x - y) ** 2)\n",
    "\n",
    "\n",
    "def acc_func(x, y):\n",
    "    \"\"\"Return the number of positions where x == y divided by len(x).\"\"\"\n",
    "    return np.sum(x == y) / len(x)\n",
    "\n",
    "\n",
    "# Compared against \"coat\" and \"yahoo\" by the data-loading cell.\n",
    "dataset_name = \"yahoo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902db9a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2411,
     "status": "ok",
     "timestamp": 1681717215335,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "902db9a6",
    "outputId": "91ea78f9-f263-4389-98b4-cff9721c3c63",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the train/test samples for the dataset selected by dataset_name.\n",
    "if dataset_name == \"coat\":\n",
    "    # load_data returns two rating matrices here; flatten each into (x, y) samples.\n",
    "    train_mat, test_mat = load_data(\"coat\")\n",
    "    x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "    x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "    num_user = train_mat.shape[0]\n",
    "    num_item = train_mat.shape[1]\n",
    "\n",
    "elif dataset_name == \"yahoo\":\n",
    "    x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "    x_train, y_train = shuffle(x_train, y_train)\n",
    "    # Sizes come from the largest index in each training column\n",
    "    # (assumes column 0 holds user ids and column 1 item ids -- TODO confirm).\n",
    "    num_user = x_train[:,0].max() + 1\n",
    "    num_item = x_train[:,1].max() + 1\n",
    "\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on num_user below.\n",
    "    raise ValueError(\n",
    "        \"Unsupported dataset_name {!r}; expected 'coat' or 'yahoo'\".format(dataset_name))\n",
    "\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "# binarize\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb872c5-d591-4288-98d3-c1c5961b0b25",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304cd8bb",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5222,
     "status": "ok",
     "timestamp": 1681303744609,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "304cd8bb",
    "outputId": "eab5110a-26f9-4ad6-9e95-c228b7b20a08"
   },
   "outputs": [],
   "source": [
    "\"PMF naive\"\n",
    "# The instance is bound to PMF_naive, not PMF: rebinding PMF would shadow the\n",
    "# imported class, so a second run of this cell would call the fitted instance\n",
    "# instead of constructing a fresh model.\n",
    "PMF_naive = PMF(num_user, num_item, batch_size=2048)\n",
    "PMF_naive.cuda()\n",
    "PMF_naive.fit(x_train, y_train, \n",
    "    lr=0.05,\n",
    "    lamb=5e-5,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_naive.predict(x_test)\n",
    "mse_PMF = mse_func(y_test, test_pred)\n",
    "auc_PMF = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_naive, x_test, y_test)\n",
    "recall_res = recall_func(PMF_naive, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF]\" + \"***\"*5)\n",
    "print(\"[PMF] test mse:\", mse_PMF)\n",
    "print(\"[PMF] test auc:\", auc_PMF)\n",
    "print(\"[PMF] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3635663c-148c-4150-94ba-c9b6144c4289",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5842bc5b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2515,
     "status": "ok",
     "timestamp": 1681303749161,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "5842bc5b",
    "outputId": "68c877f4-781e-4c1e-c2c8-d1fa097c577b"
   },
   "outputs": [],
   "source": [
    "\"PMF IPS\"\n",
    "# Build the PMF_IPS model and call .cuda() on it.\n",
    "PMF_ips = PMF_IPS(num_user, num_item, batch_size=2048)\n",
    "PMF_ips.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.01, lamb=1e-4, tol=1e-5, verbose=False)\n",
    "PMF_ips.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_ips.predict(x_test)\n",
    "mse_PMFips = mse_func(y_test, test_pred)\n",
    "auc_PMFips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ips, x_test, y_test)\n",
    "recall_res = recall_func(PMF_ips, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-IPS]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-IPS] test mse:\", mse_PMFips)\n",
    "print(\"[PMF-IPS] test auc:\", auc_PMFips)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-IPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-IPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebfeb807-1988-4400-bfbe-e1be246280ba",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c46142",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 11033,
     "status": "ok",
     "timestamp": 1681303762345,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "02c46142",
    "outputId": "1e530923-e591-40b2-fb3d-cff0505f2235"
   },
   "outputs": [],
   "source": [
    "\"PMF ASIPS\"\n",
    "# ASIPS-specific names are used so this cell no longer overwrites PMF_ips,\n",
    "# mse_PMFips and auc_PMFips, which the PMF-IPS cell assigns.\n",
    "PMF_asips = PMF_ASIPS(num_user, num_item, batch_size=2048)\n",
    "PMF_asips.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_asips.fit(x_train, y_train,  y_ips=y_ips, tao = 0.1,\n",
    "    batch_size = 2048,\n",
    "    stop = 20,\n",
    "    lr=0.05,\n",
    "    G = 5,\n",
    "    lamb=5e-6,\n",
    "    tol=1e-5)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_asips.predict(x_test)\n",
    "mse_PMFasips = mse_func(y_test, test_pred)\n",
    "auc_PMFasips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_asips, x_test, y_test)\n",
    "recall_res = recall_func(PMF_asips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-ASIPS]\" + \"***\"*5)\n",
    "print(\"[PMF-ASIPS] test mse:\", mse_PMFasips)\n",
    "print(\"[PMF-ASIPS] test auc:\", auc_PMFasips)\n",
    "print(\"[PMF-ASIPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[PMF-ASIPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-ASIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c1c88ae-a1e8-4690-8f8d-40a440af4e6d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "509a2a98",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4097,
     "status": "ok",
     "timestamp": 1681303773183,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "509a2a98",
    "outputId": "3e3855bc-47e8-488e-e575-658ba85f24ea"
   },
   "outputs": [],
   "source": [
    "\"PMF-SNIPS\"\n",
    "# Build the PMF_SNIPS model and call .cuda() on it.\n",
    "PMF_snips = PMF_SNIPS(num_user, num_item, batch_size=8192)\n",
    "PMF_snips.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.01, lamb=1e-5, tol=1e-5, verbose=False)\n",
    "PMF_snips.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_snips.predict(x_test)\n",
    "mse_PMFsnips = mse_func(y_test, test_pred)\n",
    "auc_PMFsnips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_snips, x_test, y_test)\n",
    "recall_res = recall_func(PMF_snips, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-SNIPS]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-SNIPS] test mse:\", mse_PMFsnips)\n",
    "print(\"[PMF-SNIPS] test auc:\", auc_PMFsnips)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-SNIPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-SNIPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4e09783-c157-47b3-aa42-f9c53a026a74",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fde7613",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 4165,
     "status": "ok",
     "timestamp": 1681303781377,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "7fde7613",
    "outputId": "008200bc-c761-4333-8791-989ef37ad541"
   },
   "outputs": [],
   "source": [
    "\"PMF CVIB\"\n",
    "# Build the PMF_CVIB model and call .cuda() on it.\n",
    "PMF_cvib = PMF_CVIB(num_user, num_item, batch_size=2048)\n",
    "PMF_cvib.cuda()\n",
    "\n",
    "# This cell fits on the training samples only; no y_ips subset is passed.\n",
    "fit_kwargs = dict(lr=0.01, lamb=5e-5, alpha=1, gamma=1e-3, tol=1e-5, verbose=False)\n",
    "PMF_cvib.fit(x_train, y_train, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_cvib.predict(x_test)\n",
    "mse_PMFcvib = mse_func(y_test, test_pred)\n",
    "auc_PMFcvib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_cvib, x_test, y_test)\n",
    "recall_res = recall_func(PMF_cvib, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-CVIB]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-CVIB] test mse:\", mse_PMFcvib)\n",
    "print(\"[PMF-CVIB] test auc:\", auc_PMFcvib)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-CVIB] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-CVIB] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3648dc5e-c055-4ec9-aafd-5e65d4a0fdc8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f67e1b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1656,
     "status": "ok",
     "timestamp": 1681303790112,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "d6f67e1b",
    "outputId": "1ef702c8-7ed0-4ee5-d0b0-3b9c376c94fb"
   },
   "outputs": [],
   "source": [
    "\"PMF DR\"\n",
    "# Build the PMF_DR model and call .cuda() on it.\n",
    "PMF_dr = PMF_DR(num_user, num_item, batch_size=2048)\n",
    "PMF_dr.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.001, batch_size=2048, lamb=1e-6, tol=1e-5, verbose=False)\n",
    "PMF_dr.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dr.predict(x_test)\n",
    "mse_PMFdr = mse_func(y_test, test_pred)\n",
    "auc_PMFdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-DR]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-DR] test mse:\", mse_PMFdr)\n",
    "print(\"[PMF-DR] test auc:\", auc_PMFdr)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-DR] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8502acab-7218-479c-8c8d-563de37f82ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e4f343",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12103,
     "status": "ok",
     "timestamp": 1681303805661,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "c4e4f343",
    "outputId": "ca9cbdcd-0c55-491b-8228-450500019c47"
   },
   "outputs": [],
   "source": [
    "\"PMF DR JL\"\n",
    "# Build the PMF_DR_JL model and call .cuda() on it.\n",
    "PMF_dr_jl = PMF_DR_JL(num_user, num_item, batch_size=2048)\n",
    "PMF_dr_jl.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.05, batch_size=2048, lamb=1e-4, tol=1e-5, verbose=False)\n",
    "PMF_dr_jl.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dr_jl.predict(x_test)\n",
    "mse_PMFdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_jl, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_jl, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-DR-JL] test mse:\", mse_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] test auc:\", auc_PMFdrjl)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-DR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-DR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a5dc8ce-246c-4743-b737-870be2571b22",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082fa612",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 8969,
     "status": "ok",
     "timestamp": 1681303818908,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "082fa612",
    "outputId": "0a6e3270-022a-41db-c44a-2f8c07033508"
   },
   "outputs": [],
   "source": [
    "\"PMF MRDR JL\"\n",
    "# Build the PMF_MRDR_JL model (no batch_size passed to the constructor) and call .cuda() on it.\n",
    "PMF_mrdr_jl = PMF_MRDR_JL(num_user, num_item)\n",
    "PMF_mrdr_jl.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.05, batch_size=2048, lamb=1e-4, tol=1e-5, verbose=False)\n",
    "PMF_mrdr_jl.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_mrdr_jl.predict(x_test)\n",
    "mse_PMFmrdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFmrdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_mrdr_jl, x_test, y_test)\n",
    "recall_res = recall_func(PMF_mrdr_jl, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-MRDR-JL]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-MRDR-JL] test mse:\", mse_PMFmrdrjl)\n",
    "print(\"[PMF-MRDR-JL] test auc:\", auc_PMFmrdrjl)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-MRDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-MRDR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "917fa337-a82c-47ba-b626-d882c0514b69",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a256ec2-6418-4b2c-87a9-8f48cc2c4ab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the PMF_DR_BIAS model and call .cuda() on it.\n",
    "PMF_dr_bias = PMF_DR_BIAS(num_user, num_item)\n",
    "PMF_dr_bias.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.05, batch_size=2048, lamb=1e-4, tol=1e-5, verbose=False)\n",
    "PMF_dr_bias.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dr_bias.predict(x_test)\n",
    "mse_PMF_dr_bias = mse_func(y_test, test_pred)\n",
    "auc_PMF_dr_bias = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_bias, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_bias, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-DR_BIAS]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-DR_BIAS] test mse:\", mse_PMF_dr_bias)\n",
    "print(\"[PMF-DR_BIAS] test auc:\", auc_PMF_dr_bias)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-DR_BIAS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-DR_BIAS] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1204e86-b3d8-48a9-964b-e9d677dae5e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ebb5f0-9818-4ede-a3f2-475ffd50c200",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the PMF_DR_MSE model and call .cuda() on it.\n",
    "PMF_dr_mse = PMF_DR_MSE(num_user, num_item)\n",
    "PMF_dr_mse.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.03, gamma=0.5, batch_size=2048, lamb=1e-4, tol=1e-5, verbose=False)\n",
    "PMF_dr_mse.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dr_mse.predict(x_test)\n",
    "mse_PMF_dr_mse = mse_func(y_test, test_pred)\n",
    "auc_PMF_dr_mse = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_mse, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_mse, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-DR_MSE]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-DR_MSE] test mse:\", mse_PMF_dr_mse)\n",
    "print(\"[PMF-DR_MSE] test auc:\", auc_PMF_dr_mse)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-DR_MSE] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-DR_MSE] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01b4053b-402e-4e1d-a590-39072d5f3f89",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fl1yR4rmbFas",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6413,
     "status": "ok",
     "timestamp": 1681303829039,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "fl1yR4rmbFas",
    "outputId": "99a4a84a-6cce-42b2-86f0-de664cedd96b"
   },
   "outputs": [],
   "source": [
    "\"PMF-DIB\"\n",
    "# Build the PMF_DIB model and call .cuda() on it.\n",
    "PMF_dib = PMF_DIB(num_user, num_item, batch_size=2048)\n",
    "PMF_dib.cuda()\n",
    "\n",
    "# This cell fits on the training samples only; no y_ips subset and no verbose flag are passed.\n",
    "fit_kwargs = dict(lr=0.01, alpha=0.5, gamma=0.5, lamb=1e-4, tol=1e-5)\n",
    "PMF_dib.fit(x_train, y_train, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dib.predict(x_test)\n",
    "mse_PMF_dib = mse_func(y_test, test_pred)\n",
    "auc_PMF_dib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dib, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dib, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-DIB]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-DIB] test mse:\", mse_PMF_dib)\n",
    "print(\"[PMF-DIB] test auc:\", auc_PMF_dib)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-DIB] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-DIB] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34aa7be3-27a1-4eca-ab58-93603cc942a8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d86ad0-131f-4179-86db-ee82a1d11295",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"PMF TDR\"\n",
    "# NOTE(review): unlike most other model cells, .cuda() is not called here --\n",
    "# confirm PMF_TDR handles device placement itself.\n",
    "PMF_dr_tmle = PMF_TDR(num_user, num_item, batch_size=2048, embedding_k=8)\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as prior_y.\n",
    "# NOTE(review): prior_y is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_prior = int(0.05 * len(ips_idxs))\n",
    "prior_y = y_test[ips_idxs[:num_prior]]\n",
    "\n",
    "# prior_y is passed positionally as the third argument of fit().\n",
    "fit_kwargs = dict(gamma=0.15, lr=0.05, G=3, lamb=5e-4, tol=1e-5, verbose=False)\n",
    "PMF_dr_tmle.fit(x_train, y_train, prior_y, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dr_tmle.predict(x_test)\n",
    "mse_PMFdrtmle = mse_func(y_test, test_pred)\n",
    "auc_PMFdrtmle = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_tmle, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_tmle, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-TDR]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-TDR] test mse:\", mse_PMFdrtmle)\n",
    "print(\"[PMF-TDR] test auc:\", auc_PMFdrtmle)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-TDR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-TDR] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52baeebf-2a28-4601-af0d-ce4309c9f141",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c97d88-023e-4adc-81ea-821bea2fa31e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\"PMF TDR JL\"\n",
    "# NOTE(review): unlike most other model cells, .cuda() is not called here and no\n",
    "# test-label subset is passed to fit() -- confirm both match PMF_TDR_JL's API.\n",
    "PMF_dr_tmle_jl = PMF_TDR_JL(num_user, num_item, batch_size=2048, embedding_k=8)\n",
    "\n",
    "fit_kwargs = dict(gamma=0.1, lr=0.05, G=3, lamb=1e-4, tol=1e-5, verbose=False)\n",
    "PMF_dr_tmle_jl.fit(x_train, y_train, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_dr_tmle_jl.predict(x_test)\n",
    "mse_PMFdrtmlejl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrtmlejl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_tmle_jl, x_test, y_test)\n",
    "recall_res = recall_func(PMF_dr_tmle_jl, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-TDR-JL]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-TDR-JL] test mse:\", mse_PMFdrtmlejl)\n",
    "print(\"[PMF-TDR-JL] test auc:\", auc_PMFdrtmlejl)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-TDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-TDR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a201274-a736-439c-8965-48ac92c9e7d8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4821c7c-1c8c-4103-9a3b-35ff0ca284f9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): unlike most other model cells, .cuda() is not called here --\n",
    "# confirm PMF_Stable_DR handles device placement itself.\n",
    "PMF_stable_dr = PMF_Stable_DR(num_user, num_item, embedding_k=8)\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "# y_ips is passed positionally as the third argument of fit().\n",
    "fit_kwargs = dict(\n",
    "    stop=1, eta=1000, lr=0.05, G=5, batch_size=8192,\n",
    "    lr1=150, lamb=1e-5, tol=1e-5, verbose=False)\n",
    "PMF_stable_dr.fit(x_train, y_train, y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_stable_dr.predict(x_test)\n",
    "mse_PMFsdr = mse_func(y_test, test_pred)\n",
    "auc_PMFsdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_stable_dr, x_test, y_test)\n",
    "recall_res = recall_func(PMF_stable_dr, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-Stable-DR]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-Stable-DR] test mse:\", mse_PMFsdr)\n",
    "print(\"[PMF-Stable-DR] test auc:\", auc_PMFsdr)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-Stable-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-Stable-DR] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e016ac9-1956-436a-932d-90294d4a0d5e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede22b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"PMF-OURS\"\n",
    "# Build the PMF_ours_JL model and call .cuda() on it.\n",
    "PMF_ours = PMF_ours_JL(num_user, num_item, batch_size=2048)\n",
    "PMF_ours.cuda()\n",
    "\n",
    "# Shuffle the test indices and keep the labels of the first 5% as y_ips.\n",
    "# NOTE(review): y_ips is drawn from y_test, which is also the evaluation target below -- confirm this is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "num_ips = int(0.05 * len(ips_idxs))\n",
    "y_ips = y_test[ips_idxs[:num_ips]]\n",
    "\n",
    "fit_kwargs = dict(lr=0.01, lr2=1e-5, lamb1=1e-4, lamb2=1e-4, lamb3=1e-4, tol=1e-5)\n",
    "PMF_ours.fit(x_train, y_train, y_ips=y_ips, **fit_kwargs)\n",
    "\n",
    "# Score the fitted model on the test split.\n",
    "test_pred = PMF_ours.predict(x_test)\n",
    "mse_PMF_ours = mse_func(y_test, test_pred)\n",
    "auc_PMF_ours = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ours, x_test, y_test)\n",
    "recall_res = recall_func(PMF_ours, x_test, y_test)\n",
    "\n",
    "banner = \"***\"*5 + \"[PMF-OURS]\" + \"***\"*5\n",
    "print(banner)\n",
    "print(\"[PMF-OURS] test mse:\", mse_PMF_ours)\n",
    "print(\"[PMF-OURS] test auc:\", auc_PMF_ours)\n",
    "mean_ndcg = [np.mean(ndcg_res[k]) for k in (\"ndcg_5\", \"ndcg_10\")]\n",
    "print(\"[PMF-OURS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(*mean_ndcg))\n",
    "mean_recall = [np.mean(recall_res[k]) for k in (\"recall_5\", \"recall_10\")]\n",
    "print(\"[PMF-OURS] recall@5:{:.6f}, recall@10:{:.6f}\".format(*mean_recall))\n",
    "\n",
    "# gi/gu are assigned but not printed in this cell.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test, y_test, test_pred)\n",
    "gi, gu = gini_index(user_wise_ctr)\n",
    "print(banner)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
