{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "298276c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "# Imports and global configuration for the MF debiasing benchmark.\n",
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "# Seed numpy and torch before the project modules below are imported,\n",
    "# in case any of them draw random numbers at import time.\n",
    "SEED = 2020\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "from dataset import load_data\n",
    "from MF_NB_Single_Model__MF_IPS_IPSAT_SNIPS_DR_CVIB import MF, MF_IPS, MF_SNIPS, MF_DR, MF_CVIB, MF_IPS_AT\n",
    "from MF_NB_Joint_Learning__DRJL_MRDRJL import MF_DR_JL, MF_MRDR_JL\n",
    "from MF_NB_Cycle_Learning__SDR_SMRDR import MF_Stable_DR, MF_Stable_MRDR\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU\n",
    "\n",
    "# Evaluation helpers: mean squared error and exact-match accuracy.\n",
    "mse_func = lambda x,y: np.mean((x-y)**2)\n",
    "acc_func = lambda x,y: np.sum(x == y) / len(x)\n",
    "\n",
    "# Dataset to run on: \"coat\" or \"yahoo\" (handled by the data-loading cell).\n",
    "dataset_name = \"coat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902db9a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load train/test samples. Column 0 of x_train indexes users, column 1 items.\n",
    "if dataset_name == \"coat\":\n",
    "    train_mat, test_mat = load_data(\"coat\")\n",
    "    x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "    x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "    num_user = train_mat.shape[0]\n",
    "    num_item = train_mat.shape[1]\n",
    "\n",
    "elif dataset_name == \"yahoo\":\n",
    "    x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "    x_train, y_train = shuffle(x_train, y_train)\n",
    "    # count = largest index + 1 (assumes 0-based ids); cast to a plain int like the coat branch\n",
    "    num_user = int(x_train[:,0].max() + 1)\n",
    "    num_item = int(x_train[:,1].max() + 1)\n",
    "\n",
    "else:\n",
    "    # Fail fast instead of a NameError on num_user further down.\n",
    "    raise ValueError(\"unknown dataset_name: {!r} (expected 'coat' or 'yahoo')\".format(dataset_name))\n",
    "\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "# binarize the labels (threshold defined in utils.binarize)\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304cd8bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF naive: matrix-factorisation baseline trained on the training data only (no y_ips).\n",
    "mf = MF(num_user, num_item)\n",
    "\n",
    "mf.fit(x_train, y_train,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = mf.predict(x_test)\n",
    "mse_mf = mse_func(y_test, test_pred)\n",
    "auc_mf = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF]\" + \"***\"*5)\n",
    "print(\"[MF] test mse:\", mse_mf)\n",
    "print(\"[MF] test auc:\", auc_mf)\n",
    "print(\"[MF] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5842bc5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-IPS\n",
    "mf_ips = MF_IPS(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_ips.fit(x_train, y_train, y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_ips.predict(x_test)\n",
    "mse_mfips = mse_func(y_test, test_pred)\n",
    "auc_mfips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_ips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-IPS]\" + \"***\"*5)\n",
    "print(\"[MF-IPS] test mse:\", mse_mfips)\n",
    "print(\"[MF-IPS] test auc:\", auc_mfips)\n",
    "print(\"[MF-IPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-IPS] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-IPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c46142",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-IPS-AT\n",
    "# Uses its own variable names so the MF-IPS cell's model and metrics are not overwritten.\n",
    "mf_ips_at = MF_IPS_AT(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_ips_at.fit(x_train, y_train, y_ips=y_ips,\n",
    "    batch_size=128,\n",
    "    tao=0.1,\n",
    "    lr=0.01,\n",
    "    G=5,\n",
    "    lamb=5*1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_ips_at.predict(x_test)\n",
    "mse_mfipsat = mse_func(y_test, test_pred)\n",
    "auc_mfipsat = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_ips_at, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-IPS_AT]\" + \"***\"*5)\n",
    "print(\"[MF-IPS_AT] test mse:\", mse_mfipsat)\n",
    "print(\"[MF-IPS_AT] test auc:\", auc_mfipsat)\n",
    "print(\"[MF-IPS_AT] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-IPS_AT] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-IPS_AT]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "509a2a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-SNIPS\n",
    "mf_snips = MF_SNIPS(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_snips.fit(x_train, y_train, y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_snips.predict(x_test)\n",
    "mse_mfsnips = mse_func(y_test, test_pred)\n",
    "auc_mfsnips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_snips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-SNIPS]\" + \"***\"*5)\n",
    "print(\"[MF-SNIPS] test mse:\", mse_mfsnips)\n",
    "print(\"[MF-SNIPS] test auc:\", auc_mfsnips)\n",
    "print(\"[MF-SNIPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-SNIPS] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-SNIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fde7613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-CVIB (trained without y_ips)\n",
    "mf_cvib = MF_CVIB(num_user, num_item)\n",
    "mf_cvib.fit(x_train, y_train,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=3*1e-4,\n",
    "    alpha=1.0,\n",
    "    gamma=1e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "\n",
    "test_pred = mf_cvib.predict(x_test)\n",
    "mse_mfcvib = mse_func(y_test, test_pred)\n",
    "auc_mfcvib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_cvib, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-CVIB]\" + \"***\"*5)\n",
    "print(\"[MF-CVIB] test mse:\", mse_mfcvib)\n",
    "print(\"[MF-CVIB] test auc:\", auc_mfcvib)\n",
    "print(\"[MF-CVIB] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-CVIB] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-CVIB]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f67e1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-DR\n",
    "mf_dr = MF_DR(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_dr.fit(x_train, y_train, y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_dr.predict(x_test)\n",
    "mse_mfdr = mse_func(y_test, test_pred)\n",
    "auc_mfdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-DR]\" + \"***\"*5)\n",
    "print(\"[MF-DR] test mse:\", mse_mfdr)\n",
    "print(\"[MF-DR] test auc:\", auc_mfdr)\n",
    "print(\"[MF-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-DR] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-DR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e4f343",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-DR-JL\n",
    "mf_dr_jl = MF_DR_JL(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_dr_jl.fit(x_train, y_train, y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_dr_jl.predict(x_test)\n",
    "mse_mfdrjl = mse_func(y_test, test_pred)\n",
    "auc_mfdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-DR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-DR-JL] test mse:\", mse_mfdrjl)\n",
    "print(\"[MF-DR-JL] test auc:\", auc_mfdrjl)\n",
    "print(\"[MF-DR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-DR-JL] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-DR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be80ea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-Stable-DR\n",
    "mf_stable_dr = MF_Stable_DR(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "# y_ips is the third positional argument of this model's fit().\n",
    "mf_stable_dr.fit(x_train, y_train, y_ips,\n",
    "    eta=100,\n",
    "    lr=0.01,\n",
    "    G=4,\n",
    "    batch_size=128,\n",
    "    lr1=10,\n",
    "    lamb=5*1e-4,\n",
    "    tol=1e-5)\n",
    "\n",
    "test_pred = mf_stable_dr.predict(x_test)\n",
    "mse_mfsdr = mse_func(y_test, test_pred)\n",
    "auc_mfsdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_stable_dr, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-Stable-DR]\" + \"***\"*5)\n",
    "print(\"[MF-Stable-DR] test mse:\", mse_mfsdr)\n",
    "print(\"[MF-Stable-DR] test auc:\", auc_mfsdr)\n",
    "print(\"[MF-Stable-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-Stable-DR] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-Stable-DR]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082fa612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-MRDR-JL\n",
    "mf_mrdr_jl = MF_MRDR_JL(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_mrdr_jl.fit(x_train, y_train, y_ips=y_ips,\n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5)\n",
    "test_pred = mf_mrdr_jl.predict(x_test)\n",
    "mse_mfmrdrjl = mse_func(y_test, test_pred)\n",
    "auc_mfmrdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_mrdr_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-MRDR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-MRDR-JL] test mse:\", mse_mfmrdrjl)\n",
    "print(\"[MF-MRDR-JL] test auc:\", auc_mfmrdrjl)\n",
    "print(\"[MF-MRDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-MRDR-JL] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-MRDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49dc8137",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MF-Stable-MRDR\n",
    "mf_stable_mrdr = MF_Stable_MRDR(num_user, num_item)\n",
    "\n",
    "# Labels of a random 5% subset of the test set, passed to fit() as y_ips\n",
    "# (presumably used for propensity estimation - confirm in the model code).\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "y_ips = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "# y_ips is the third positional argument of this model's fit().\n",
    "mf_stable_mrdr.fit(x_train, y_train, y_ips,\n",
    "    eta=100,\n",
    "    lr=0.01,\n",
    "    G=4,\n",
    "    batch_size=128,\n",
    "    lr1=10,\n",
    "    lamb=5*1e-4,\n",
    "    tol=1e-5)\n",
    "\n",
    "test_pred = mf_stable_mrdr.predict(x_test)\n",
    "mse_mfsmrdr = mse_func(y_test, test_pred)\n",
    "auc_mfsmrdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_stable_mrdr, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-Stable-MRDR]\" + \"***\"*5)\n",
    "print(\"[MF-Stable-MRDR] test mse:\", mse_mfsmrdr)\n",
    "print(\"[MF-Stable-MRDR] test auc:\", auc_mfsmrdr)\n",
    "print(\"[MF-Stable-MRDR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "# gini_index returns (gi, gu); presumably Gini coefficient and global utility - confirm in utils.\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"[MF-Stable-MRDR] gini index: {}, gu: {}\".format(gi, gu))\n",
    "print(\"***\"*5 + \"[MF-Stable-MRDR]\" + \"***\"*5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:sdr_test]",
   "language": "python",
   "name": "conda-env-sdr_test-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
