{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e6b7975",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "# Setup: imports, global seeds, metric helpers and the dataset selector.\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "# Seed the global NumPy / torch RNGs before the project modules below are imported.\n",
    "SEED = 2020\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "import pdb  # debugging aid; not used elsewhere in this notebook\n",
    "\n",
    "from dataset import load_data\n",
    "from matrix_factorization import MF, MF_N_IPS, MF_N_DR_JL, MF_N_MRDR_JL\n",
    "from utils import (gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample,\n",
    "                   binarize, shuffle, minU, recall_func, precision_func)\n",
    "\n",
    "\n",
    "def mse_func(x, y):\n",
    "    \"\"\"Mean squared error between arrays `x` and `y`.\"\"\"\n",
    "    return np.mean((x - y) ** 2)\n",
    "\n",
    "\n",
    "def acc_func(x, y):\n",
    "    \"\"\"Fraction of positions at which `x` and `y` are equal.\"\"\"\n",
    "    return np.sum(x == y) / len(x)\n",
    "\n",
    "\n",
    "# Dataset used by every cell below: \"coat\" or \"yahoo\".\n",
    "dataset_name = \"coat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9418350",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected dataset as (user, item) index pairs with labels.\n",
    "if dataset_name == \"coat\":\n",
    "    train_mat, test_mat = load_data(\"coat\")\n",
    "    x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "    x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "    num_user = train_mat.shape[0]\n",
    "    num_item = train_mat.shape[1]\n",
    "\n",
    "elif dataset_name == \"yahoo\":\n",
    "    x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "    x_train, y_train = shuffle(x_train, y_train)\n",
    "    # Size the ID spaces from both splits so an index that only occurs in the\n",
    "    # test split still falls inside the model's range.\n",
    "    num_user = max(x_train[:, 0].max(), x_test[:, 0].max()) + 1\n",
    "    num_item = max(x_train[:, 1].max(), x_test[:, 1].max()) + 1\n",
    "\n",
    "else:\n",
    "    # Fail loudly here instead of with a NameError on num_user further down.\n",
    "    raise ValueError(\"Unknown dataset_name: {!r} (expected 'coat' or 'yahoo')\".format(dataset_name))\n",
    "\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "# binarize the labels (see utils.binarize)\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304cd8bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF naive: baseline matrix factorisation trained on the binarized labels.\n",
    "mf = MF(num_user, num_item, batch_size=128)\n",
    "mf.cuda()\n",
    "mf.fit(x_train, y_train,\n",
    "    lr=0.05,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = mf.predict(x_test)\n",
    "mse_mf = mse_func(y_test, test_pred)\n",
    "auc_mf = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf, x_test, y_test)\n",
    "recall_res = recall_func(mf, x_test, y_test)\n",
    "precision_res = precision_func(mf, x_test, y_test)\n",
    "\n",
    "# F1@5 = 2PR / (P + R) from mean precision@5 and recall@5; 0 when both are 0.\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "f1_5 = 2 * precision_5 * recall_5 / (precision_5 + recall_5) if precision_5 + recall_5 > 0 else 0.0\n",
    "\n",
    "print(\"***\"*5 + \"[MF]\" + \"***\"*5)\n",
    "print(\"[MF] test mse:\", mse_mf)\n",
    "print(\"[MF] test auc:\", auc_mf)\n",
    "print(\"[MF] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print('f1@5', f1_5)\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "PeyU9z1tbB14",
   "metadata": {
    "id": "PeyU9z1tbB14"
   },
   "outputs": [],
   "source": [
    "# MF N IPS\n",
    "mf_interference_ips = MF_N_IPS(num_user, num_item, low = 0.05, up = 0.95, c = 1)\n",
    "mf_interference_ips.cuda()\n",
    "\n",
    "# y_ips: labels of a random 5% subset of the test rows.\n",
    "# NOTE(review): these rows are also scored in the evaluation below; confirm the\n",
    "# overlap is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_interference_ips.propensity_model.fit(x_train, thr = 1, lr = 0.01, lamb = 1e-3)\n",
    "\n",
    "mf_interference_ips.fit(x_train, y_train, y_ips, thr = 0.8, g_value = [0],\n",
    "    lr=0.01,\n",
    "    g = 50,\n",
    "    h = 45,\n",
    "    batch_size=128,\n",
    "    lamb1 = 5e-3,\n",
    "    lamb2 = 5e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_interference_ips.predict(x_test)\n",
    "mse_mfips = mse_func(y_test, test_pred)\n",
    "auc_mfips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_interference_ips, x_test, y_test)\n",
    "recall_res = recall_func(mf_interference_ips, x_test, y_test)\n",
    "precision_res = precision_func(mf_interference_ips, x_test, y_test)\n",
    "\n",
    "# F1@5 = 2PR / (P + R) from mean precision@5 and recall@5; 0 when both are 0.\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "f1_5 = 2 * precision_5 * recall_5 / (precision_5 + recall_5) if precision_5 + recall_5 > 0 else 0.0\n",
    "\n",
    "print(\"***\"*5 + \"[MF-Interference-IPS]\" + \"***\"*5)\n",
    "print(\"[MF-Interference-IPS] test mse:\", mse_mfips)\n",
    "print(\"[MF-Interference-IPS] test auc:\", auc_mfips)\n",
    "print(\"[MF-Interference-IPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-Interference-IPS] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-Interference-IPS] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print('f1@5', f1_5)\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-Interference-IPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb337a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF N DR JL\n",
    "mf_interference_dr_jl = MF_N_DR_JL(num_user, num_item, low = 0.05, up = 0.95, c = 1)\n",
    "mf_interference_dr_jl.cuda()\n",
    "\n",
    "# y_ips: labels of a random 5% subset of the test rows.\n",
    "# NOTE(review): these rows are also scored in the evaluation below; confirm the\n",
    "# overlap is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_interference_dr_jl.propensity_model.fit(x_train, thr = 1, lr = 0.01, lamb = 1e-3)\n",
    "\n",
    "mf_interference_dr_jl.fit(x_train, y_train, y_ips, g_value = [0],\n",
    "    lr=0.01,\n",
    "    g = 50,\n",
    "    h = 50,\n",
    "    G = 4,\n",
    "    batch_size=128,\n",
    "    lamb1 = 1e-4,\n",
    "    lamb2 = 1e-4,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_interference_dr_jl.predict(x_test)\n",
    "mse_mfdrjl = mse_func(y_test, test_pred)\n",
    "auc_mfdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_interference_dr_jl, x_test, y_test)\n",
    "recall_res = recall_func(mf_interference_dr_jl, x_test, y_test)\n",
    "precision_res = precision_func(mf_interference_dr_jl, x_test, y_test)\n",
    "\n",
    "# F1@5 = 2PR / (P + R) from mean precision@5 and recall@5; 0 when both are 0.\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "f1_5 = 2 * precision_5 * recall_5 / (precision_5 + recall_5) if precision_5 + recall_5 > 0 else 0.0\n",
    "\n",
    "print(\"***\"*5 + \"[MF-Interference-DR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-Interference-DR-JL] test mse:\", mse_mfdrjl)\n",
    "print(\"[MF-Interference-DR-JL] test auc:\", auc_mfdrjl)\n",
    "print(\"[MF-Interference-DR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-Interference-DR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-Interference-DR-JL] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print('f1@5', f1_5)\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-Interference-DR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c8055b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF N MRDR JL\n",
    "mf_interference_mrdr_jl = MF_N_MRDR_JL(num_user, num_item, low = 0.05, up = 0.95)\n",
    "mf_interference_mrdr_jl.cuda()\n",
    "\n",
    "# y_ips: labels of a random 5% subset of the test rows.\n",
    "# NOTE(review): these rows are also scored in the evaluation below; confirm the\n",
    "# overlap is intended.\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_interference_mrdr_jl.propensity_model.fit(x_train, thr = 1, lamb = 1e-5)\n",
    "\n",
    "mf_interference_mrdr_jl.fit(x_train, y_train, y_ips, g_value = [0],\n",
    "    lr=0.01,\n",
    "    g = 50,\n",
    "    h = 40,\n",
    "    G = 1,\n",
    "    batch_size=128,\n",
    "    lamb1 = 1e-4,\n",
    "    lamb2 = 1e-4,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_interference_mrdr_jl.predict(x_test)\n",
    "# Own result names so the IPS cell's mse_mfips / auc_mfips are not overwritten.\n",
    "mse_mfmrdrjl = mse_func(y_test, test_pred)\n",
    "auc_mfmrdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_interference_mrdr_jl, x_test, y_test)\n",
    "recall_res = recall_func(mf_interference_mrdr_jl, x_test, y_test)\n",
    "precision_res = precision_func(mf_interference_mrdr_jl, x_test, y_test)\n",
    "\n",
    "# F1@5 = 2PR / (P + R) from mean precision@5 and recall@5; 0 when both are 0.\n",
    "recall_5 = np.mean(recall_res[\"recall_5\"])\n",
    "precision_5 = np.mean(precision_res[\"precision_5\"])\n",
    "f1_5 = 2 * precision_5 * recall_5 / (precision_5 + recall_5) if precision_5 + recall_5 > 0 else 0.0\n",
    "\n",
    "print(\"***\"*5 + \"[MF-Interference-MRDR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-Interference-MRDR-JL] test mse:\", mse_mfmrdrjl)\n",
    "print(\"[MF-Interference-MRDR-JL] test auc:\", auc_mfmrdrjl)\n",
    "print(\"[MF-Interference-MRDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "print(\"[MF-Interference-MRDR-JL] recall@5:{:.6f}, recall@10:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_5\"]), np.mean(recall_res[\"recall_10\"])))\n",
    "print(\"[MF-Interference-MRDR-JL] precision@5:{:.6f}, precision@10:{:.6f}\".format(\n",
    "        np.mean(precision_res[\"precision_5\"]), np.mean(precision_res[\"precision_10\"])))\n",
    "print('f1@5', f1_5)\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-Interference-MRDR-JL]\" + \"***\"*5)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python [conda env:pytorch-gpu]",
   "language": "python",
   "name": "conda-env-pytorch-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
