{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Xo7qeXqAxVxp",
   "metadata": {
    "executionInfo": {
     "elapsed": 6326,
     "status": "ok",
     "timestamp": 1681716626478,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "Xo7qeXqAxVxp"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "np.random.seed(2020)\n",
    "torch.manual_seed(2020)\n",
    "import pdb\n",
    "\n",
    "from matrix_factorization import PMF, PMF_IPS, PMF_CVIB, PMF_SNIPS, PMF_ASIPS, PMF_DR, PMF_DR_JL, PMF_MRDR_JL, PMF_DIB, PMF_DR_BIAS, PMF_DR_MSE, PMF_ours_JL\n",
    "from TDR import PMF_TDR, PMF_TDR_JL\n",
    "from StableDR import PMF_Stable_DR\n",
    "\n",
    "from dataset import load_data\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU, recall_func, precision_func\n",
    "mse_func = lambda x,y: np.mean((x-y)**2)\n",
    "acc_func = lambda x,y: np.sum(x == y) / len(x)\n",
    "import pandas as pd\n",
    "dataset_name = \"kuai\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902db9a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2311,
     "status": "ok",
     "timestamp": 1681716633111,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "902db9a6",
    "outputId": "bd68316b-67c5-4d61-ea3b-dd912943ab10"
   },
   "outputs": [],
   "source": [
    "if dataset_name == \"coat\":\n",
    "    train_mat, test_mat = load_data(\"coat\")        \n",
    "    x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "    x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "    num_user = train_mat.shape[0]\n",
    "    num_item = train_mat.shape[1]\n",
    "\n",
    "elif dataset_name == \"yahoo\":\n",
    "    x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "    x_train, y_train = shuffle(x_train, y_train)\n",
    "    num_user = x_train[:,0].max() + 1\n",
    "    num_item = x_train[:,1].max() + 1\n",
    "\n",
    "if dataset_name == \"kuai\":\n",
    "    rdf_train = np.array(pd.read_table(\"./data/kuai/user.txt\", header = None, sep = ','))     \n",
    "    rdf_test = np.array(pd.read_table(\"./data/kuai/random.txt\", header = None, sep = ','))\n",
    "    rdf_train_new = np.c_[rdf_train, np.ones(rdf_train.shape[0])]\n",
    "    rdf_test_new = np.c_[rdf_test, np.zeros(rdf_test.shape[0])]\n",
    "    rdf = np.r_[rdf_train_new, rdf_test_new]\n",
    "    \n",
    "    rdf = rdf[np.argsort(rdf[:, 0])]\n",
    "    c = rdf.copy()\n",
    "    for i in range(rdf.shape[0]):\n",
    "        if i == 0:\n",
    "            c[:, 0][i] = i\n",
    "            temp = rdf[:, 0][0]\n",
    "        else:\n",
    "            if c[:, 0][i] == temp:\n",
    "                c[:, 0][i] = c[:, 0][i-1]\n",
    "            else:\n",
    "                c[:, 0][i] = c[:, 0][i-1] + 1\n",
    "            temp = rdf[:, 0][i]\n",
    "    \n",
    "    c = c[np.argsort(c[:, 1])]\n",
    "    d = c.copy()\n",
    "    for i in range(rdf.shape[0]):\n",
    "        if i == 0:\n",
    "            d[:, 1][i] = i\n",
    "            temp = c[:, 1][0]\n",
    "        else:\n",
    "            if d[:, 1][i] == temp:\n",
    "                d[:, 1][i] = d[:, 1][i-1]\n",
    "            else:\n",
    "                d[:, 1][i] = d[:, 1][i-1] + 1\n",
    "            temp = c[:, 1][i]\n",
    "\n",
    "    y_train = d[:, 2][d[:, 3] == 1]\n",
    "    y_test = d[:, 2][d[:, 3] == 0]\n",
    "    x_train = d[:, :2][d[:, 3] == 1]\n",
    "    x_test = d[:, :2][d[:, 3] == 0]\n",
    "    \n",
    "    num_user = x_train[:,0].max() + 1\n",
    "    num_item = x_train[:,1].max() + 1\n",
    "\n",
    "y_train = binarize(y_train, 2)\n",
    "y_test = binarize(y_test, 2)\n",
    "num_user = int(num_user)\n",
    "num_item = int(num_item)\n",
    "    \n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "print(sum(y_train)/len(y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b005339c-f453-4265-b344-949e2f14b078",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304cd8bb",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9196,
     "status": "ok",
     "timestamp": 1681654325865,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "304cd8bb",
    "outputId": "dd3be69c-feb6-4bad-9ef3-4ad15f573f83"
   },
   "outputs": [],
   "source": [
    "\"PMF naive\"\n",
    "PMF = PMF(num_user, num_item, batch_size=2048)\n",
    "PMF.cuda()\n",
    "PMF.fit(x_train, y_train, \n",
    "    lr=0.01,\n",
    "    lamb=5e-5,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF.predict(x_test)\n",
    "mse_PMF = mse_func(y_test, test_pred)\n",
    "auc_PMF = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF]\" + \"***\"*5)\n",
    "print(\"[PMF] test mse:\", mse_PMF)\n",
    "print(\"[PMF] test auc:\", auc_PMF)\n",
    "print(\"[PMF] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "C2vYZX-v1AmS",
   "metadata": {
    "id": "C2vYZX-v1AmS"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5842bc5b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 34827,
     "status": "ok",
     "timestamp": 1681654501974,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "5842bc5b",
    "outputId": "19a218ac-72c3-4fc9-f8b7-bfd489f5a5fb"
   },
   "outputs": [],
   "source": [
    "\"PMF IPS\"\n",
    "PMF_ips = PMF_IPS(num_user, num_item,batch_size=2048)\n",
    "PMF_ips.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_ips.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_ips.predict(x_test)\n",
    "mse_PMFips = mse_func(y_test, test_pred)\n",
    "auc_PMFips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ips, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_ips, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-IPS]\" + \"***\"*5)\n",
    "print(\"[PMF-IPS] test mse:\", mse_PMFips)\n",
    "print(\"[PMF-IPS] test auc:\", auc_PMFips)\n",
    "print(\"[PMF-IPS] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-IPS] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-IPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ZWDPejaY1lWJ",
   "metadata": {
    "id": "ZWDPejaY1lWJ"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c46142",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 48046,
     "status": "ok",
     "timestamp": 1681653055896,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "02c46142",
    "outputId": "339e8597-7c50-4c51-9a1a-63da12562e1e"
   },
   "outputs": [],
   "source": [
    "\"PMF ASIPS\"\n",
    "PMF_ips = PMF_ASIPS(num_user, num_item,batch_size = 2048)\n",
    "PMF_ips.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_ips.fit(x_train, y_train,  y_ips=y_ips, tao = 0.1,\n",
    "    batch_size = 2048,\n",
    "    stop = 5,\n",
    "    lr=0.01,\n",
    "    G = 1,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_ips.predict(x_test)\n",
    "mse_PMFips = mse_func(y_test, test_pred)\n",
    "auc_PMFips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ips, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_ips, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-ASIPS]\" + \"***\"*5)\n",
    "print(\"[PMF-ASIPS] test mse:\", mse_PMFips)\n",
    "print(\"[PMF-ASIPS] test auc:\", auc_PMFips)\n",
    "print(\"[PMF-ASIPS] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-ASIPS] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-ASIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "O_GhaWY68NvM",
   "metadata": {
    "id": "O_GhaWY68NvM"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "509a2a98",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12004,
     "status": "ok",
     "timestamp": 1681652878268,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "509a2a98",
    "outputId": "a4c85db6-36e8-4d23-fb87-66576441643e"
   },
   "outputs": [],
   "source": [
    "\"PMF-SNIPS\"\n",
    "PMF_snips = PMF_SNIPS(num_user, num_item,batch_size=2048)\n",
    "PMF_snips.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_snips.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.01,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-5,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_snips.predict(x_test)\n",
    "mse_PMFsnips = mse_func(y_test, test_pred)\n",
    "auc_PMFsnips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_snips, x_test, y_test,top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_snips, x_test, y_test,top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-SNIPS]\" + \"***\"*5)\n",
    "print(\"[PMF-SNIPS] test mse:\", mse_PMFsnips)\n",
    "print(\"[PMF-SNIPS] test auc:\", auc_PMFsnips)\n",
    "print(\"[PMF-SNIPS] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-SNIPS] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-SNIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "m1g5jxub9FWX",
   "metadata": {
    "id": "m1g5jxub9FWX"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fde7613",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 34570,
     "status": "ok",
     "timestamp": 1681652718916,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "7fde7613",
    "outputId": "dceeb86f-78bd-4b8e-e723-02313066406c"
   },
   "outputs": [],
   "source": [
    "\"PMF CVIB\"\n",
    "PMF_cvib = PMF_CVIB(num_user, num_item,batch_size=2048)\n",
    "PMF_cvib.cuda()\n",
    "PMF_cvib.fit(x_train, y_train, \n",
    "    lr=0.01,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-5,\n",
    "    alpha=0.4,\n",
    "    gamma=1e-2,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "\n",
    "test_pred = PMF_cvib.predict(x_test)\n",
    "mse_PMFcvib = mse_func(y_test, test_pred)\n",
    "auc_PMFcvib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_cvib, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_cvib, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-CVIB]\" + \"***\"*5)\n",
    "print(\"[PMF-CVIB] test mse:\", mse_PMFcvib)\n",
    "print(\"[PMF-CVIB] test auc:\", auc_PMFcvib)\n",
    "print(\"[PMF-CVIB] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-CVIB] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-CVIB]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "p6jRpCAz9_7Y",
   "metadata": {
    "id": "p6jRpCAz9_7Y"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f67e1b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 20560,
     "status": "ok",
     "timestamp": 1681652275875,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "d6f67e1b",
    "outputId": "ede9c715-c39a-442d-8dbf-c755cbea9731"
   },
   "outputs": [],
   "source": [
    "\"PMF DR\"\n",
    "PMF_dr = PMF_DR(num_user, num_item,batch_size=2048)\n",
    "PMF_dr.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.01,\n",
    "    G = 1,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-6,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr.predict(x_test)\n",
    "mse_PMFdr = mse_func(y_test, test_pred)\n",
    "auc_PMFdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR]\" + \"***\"*5)\n",
    "print(\"[PMF-DR] test mse:\", mse_PMFdr)\n",
    "print(\"[PMF-DR] test auc:\", auc_PMFdr)\n",
    "print(\"[PMF-DR] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-DR] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Vm_mf75Q-qAC",
   "metadata": {
    "id": "Vm_mf75Q-qAC"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e4f343",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 69388,
     "status": "ok",
     "timestamp": 1681610323594,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "c4e4f343",
    "outputId": "83d6f284-a94b-4d69-c0d8-ef47f7df16e2"
   },
   "outputs": [],
   "source": [
    "\"PMF DR JL\"\n",
    "PMF_dr_jl = PMF_DR_JL(num_user, num_item,batch_size=2048)\n",
    "PMF_dr_jl.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_jl.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.01,\n",
    "    G = 2,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_jl.predict(x_test)\n",
    "mse_PMFdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-DR-JL] test mse:\", mse_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] test auc:\", auc_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-DR-JL] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baR9yEmg_kpw",
   "metadata": {
    "id": "baR9yEmg_kpw"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082fa612",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 22560,
     "status": "ok",
     "timestamp": 1681651216295,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "082fa612",
    "outputId": "fe74b461-64b1-4b41-8a9b-b2eebcd2015a"
   },
   "outputs": [],
   "source": [
    "\"PMF MRDR JL\"\n",
    "PMF_mrdr_jl = PMF_MRDR_JL(num_user, num_item)\n",
    "PMF_mrdr_jl.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_mrdr_jl.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_mrdr_jl.predict(x_test)\n",
    "mse_PMFmrdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFmrdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_mrdr_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_mrdr_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-MRDR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-MRDR-JL] test mse:\", mse_PMFmrdrjl)\n",
    "print(\"[PMF-MRDR-JL] test auc:\", auc_PMFmrdrjl)\n",
    "print(\"[PMF-MRDR-JL] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-MRDR-JL] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-MRDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aDDZ-iNDA4O",
   "metadata": {
    "id": "7aDDZ-iNDA4O"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93a8d44-e9fe-4f49-b86d-67d174f16fd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "PMF_dr_bias = PMF_DR_BIAS(num_user, num_item)\n",
    "PMF_dr_bias.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_bias.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_bias.predict(x_test)\n",
    "mse_PMF_dr_bias = mse_func(y_test, test_pred)\n",
    "auc_PMF_dr_bias = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_bias, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr_bias, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR_BIAS]\" + \"***\"*5)\n",
    "print(\"[PMF-DR_BIAS] test mse:\", mse_PMF_dr_bias)\n",
    "print(\"[PMF-DR_BIAS] test auc:\", auc_PMF_dr_bias)\n",
    "print(\"[PMF-DR_BIAS] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-DR_BIAS] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR_BIAS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af600bb-831c-4def-a09d-0c0c0517a287",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e5edad8-f551-4711-8d84-b9f73951fbc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "PMF_dr_mse = PMF_DR_MSE(num_user, num_item)\n",
    "PMF_dr_mse.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_mse.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    gamma = 0.25,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_mse.predict(x_test)\n",
    "mse_PMF_dr_mse = mse_func(y_test, test_pred)\n",
    "auc_PMF_dr_mse = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_mse, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr_mse, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR_MSE]\" + \"***\"*5)\n",
    "print(\"[PMF-DR_MSE] test mse:\", mse_PMF_dr_mse)\n",
    "print(\"[PMF-DR_MSE] test auc:\", auc_PMF_dr_mse)\n",
    "print(\"[PMF-DR_MSE] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-DR_MSE] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR_MSE]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b54b42d-b80b-4422-9c08-38798753d035",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fl1yR4rmbFas",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13580,
     "status": "ok",
     "timestamp": 1681654660581,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "fl1yR4rmbFas",
    "outputId": "28d6a038-e8f7-4170-b7e0-3e41c326b20c"
   },
   "outputs": [],
   "source": [
    "\"PMF-DIB\"\n",
    "PMF_dib = PMF_DIB(num_user, num_item,batch_size=2048)\n",
    "PMF_dib.cuda()\n",
    "PMF_dib.fit(x_train, y_train, \n",
    "    lr=0.01,\n",
    "    alpha=0.1, \n",
    "    gamma=0.9,\n",
    "    lamb=5e-5,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dib.predict(x_test)\n",
    "mse_PMF_dib = mse_func(y_test, test_pred)\n",
    "auc_PMF_dib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dib, x_test, y_test,top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dib, x_test, y_test,top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DIB]\" + \"***\"*5)\n",
    "print(\"[PMF-DIB] test mse:\", mse_PMF_dib)\n",
    "print(\"[PMF-DIB] test auc:\", auc_PMF_dib)\n",
    "print(\"[PMF-DIB] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-DIB] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DIB]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "-lPyOALi0G7a",
   "metadata": {
    "id": "-lPyOALi0G7a"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "iydDcGuPoztc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 62245,
     "status": "ok",
     "timestamp": 1681650710459,
     "user": {
      "displayName": "Jaqueline Noonan",
      "userId": "14082822236352942107"
     },
     "user_tz": -480
    },
    "id": "iydDcGuPoztc",
    "outputId": "926f84be-5ab3-46fc-9c04-1bbc997168be"
   },
   "outputs": [],
   "source": [
    "\"PMF DR JL\"\n",
    "PMF_dr_jl = PMF_DR_JL(num_user, num_item,batch_size=2048)\n",
    "PMF_dr_jl.cuda()\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_jl.fit(x_train, y_train,  y_ips=y_ips,\n",
    "    lr=0.01,\n",
    "    G = 2,\n",
    "    batch_size=2048,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = PMF_dr_jl.predict(x_test)\n",
    "mse_PMFdrjl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-DR-JL] test mse:\", mse_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] test auc:\", auc_PMFdrjl)\n",
    "print(\"[PMF-DR-JL] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-DR-JL] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-DR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b4ded59-332b-42ad-9c29-7aebc929ff90",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "065a62e6-4cd1-4d04-b9ea-8eb8ef312dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "PMF_dr_tmle = PMF_TDR(num_user, num_item, batch_size=2048, embedding_k = 8)\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "prior_y = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_dr_tmle.fit(x_train, y_train, prior_y, gamma = 0.1,\n",
    "    lr=0.05,\n",
    "    G = 1,\n",
    "    lamb=5e-4,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = PMF_dr_tmle.predict(x_test)\n",
    "mse_PMFdrtmle = mse_func(y_test, test_pred)\n",
    "auc_PMFdrtmle = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_tmle, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr_tmle, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-TDR]\" + \"***\"*5)\n",
    "print(\"[PMF-TDR] test mse:\", mse_PMFdrtmle)\n",
    "print(\"[PMF-TDR] test auc:\", auc_PMFdrtmle)\n",
    "print(\"[PMF-TDR] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-TDR] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-TDR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e8e344-7878-4aaa-88ca-95b1f65c0579",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b2e98ad-d02d-44af-8c4a-c6eb5b9bc913",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"PMF TDR JL\"\n",
    "PMF_dr_tmle_jl = PMF_TDR_JL(num_user, num_item, batch_size=2048, embedding_k = 8)\n",
    "PMF_dr_tmle_jl.fit(x_train, y_train, gamma = 0.1,\n",
    "    lr=0.05,\n",
    "    G = 3,\n",
    "    lamb= 1e-4,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = PMF_dr_tmle_jl.predict(x_test)\n",
    "mse_PMFdrtmlejl = mse_func(y_test, test_pred)\n",
    "auc_PMFdrtmlejl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_dr_tmle_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_dr_tmle_jl, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-TDR-JL]\" + \"***\"*5)\n",
    "print(\"[PMF-TDR-JL] test mse:\", mse_PMFdrtmlejl)\n",
    "print(\"[PMF-TDR-JL] test auc:\", auc_PMFdrtmlejl)\n",
    "print(\"[PMF-TDR-JL] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-TDR-JL] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-TDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ca77caa-e671-4539-98a6-936080bf6c6b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e5c02f6-fafc-4b44-8c95-8bf000cb9b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "PMF_stable_dr = PMF_Stable_DR(num_user, num_item, embedding_k = 4)\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_stable_dr.fit(x_train, y_train, y_ips, stop = 1,\n",
    "    eta = 1000,\n",
    "    lr=0.05,\n",
    "    G = 5,\n",
    "    batch_size=8192,\n",
    "    lr1 = 150,\n",
    "    lamb=5e-5,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "\n",
    "test_pred = PMF_stable_dr.predict(x_test)\n",
    "mse_PMFsdr = mse_func(y_test, test_pred)\n",
    "auc_PMFsdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_stable_dr, x_test, y_test, top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_stable_dr, x_test, y_test, top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-Stable-DR]\" + \"***\"*5)\n",
    "print(\"[PMF-Stable-DR] test mse:\", mse_PMFsdr)\n",
    "print(\"[PMF-Stable-DR] test auc:\", auc_PMFsdr)\n",
    "print(\"[PMF-Stable-DR] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-Stable-DR] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-Stable-DR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd5ed33-3bb2-42e9-8893-2a6076c1b1f4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31c648d9-53f9-42c8-8e21-cf7d162a5d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"PMF-OURS\"\n",
    "PMF_ours = PMF_ours_JL(num_user, num_item, batch_size = 2048)\n",
    "PMF_ours.cuda()\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "PMF_ours.fit(x_train, y_train, y_ips=y_ips, \n",
    "        num_epoch=25,\n",
    "        stop = 1,\n",
    "        lr = 0.01,\n",
    "        lr2 = 1e-4,\n",
    "        lamb1 = 1e-5,\n",
    "        lamb2 = 1e-5,\n",
    "        lamb3 = 1e-4,\n",
    "        G = 5,\n",
    "        tol=5e-5)\n",
    "test_pred = PMF_ours.predict(x_test)\n",
    "mse_PMF_ours = mse_func(y_test, test_pred)\n",
    "auc_PMF_ours = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(PMF_ours, x_test, y_test,top_k_list = [20, 50])\n",
    "recall_res = recall_func(PMF_ours, x_test, y_test,top_k_list = [20, 50])\n",
    "\n",
    "print(\"***\"*5 + \"[PMF-OURS]\" + \"***\"*5)\n",
    "print(\"[PMF-OURS] test mse:\", mse_PMF_ours)\n",
    "print(\"[PMF-OURS] test auc:\", auc_PMF_ours)\n",
    "print(\"[PMF-OURS] ndcg@20:{:.6f}, ndcg@50:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_20\"]), np.mean(ndcg_res[\"ndcg_50\"])))\n",
    "print(\"[PMF-OURS] recall@20:{:.6f}, recall@50:{:.6f}\".format(\n",
    "        np.mean(recall_res[\"recall_20\"]), np.mean(recall_res[\"recall_50\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[PMF-OURS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46d0985-c9f7-436e-af31-b298ef7d0e71",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
