{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import numpy as np\n",
    "import torch\n",
    "import pdb\n",
    "from sklearn.metrics import roc_auc_score\n",
    "np.random.seed(2020)\n",
    "torch.manual_seed(2020)\n",
    "import pdb\n",
    "\n",
    "from TDR_Based import MF, MF_CVIB, MF_IPS, MF_SNIPS, MF_DR, MF_DR_JL, MF_TDR, MF_TDR_JL, MF_TDR_CL, MF_DR_CL, MF_IPS_AT\n",
    "from TMRDR_Based import MF_MRDR_JL, MF_TMRDR_JL, MF_MRDR_CL, MF_TMRDR_CL \n",
    "from dataset import load_data\n",
    "from utils import gini_index, ndcg_func, get_user_wise_ctr, rating_mat_to_sample, binarize, shuffle, minU\n",
    "mse_func = lambda x,y: np.mean((x-y)**2)\n",
    "acc_func = lambda x,y: np.sum(x == y) / len(x)\n",
    "\n",
    "dataset_name = \"coat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name == \"coat\":\n",
    "    train_mat, test_mat = load_data(\"coat\")        \n",
    "    x_train, y_train = rating_mat_to_sample(train_mat)\n",
    "    x_test, y_test = rating_mat_to_sample(test_mat)\n",
    "    num_user = train_mat.shape[0]\n",
    "    num_item = train_mat.shape[1]\n",
    "\n",
    "elif dataset_name == \"yahoo\":\n",
    "    x_train, y_train, x_test, y_test = load_data(\"yahoo\")\n",
    "    x_train, y_train = shuffle(x_train, y_train)\n",
    "    num_user = x_train[:,0].max() + 1\n",
    "    num_item = x_train[:,1].max() + 1\n",
    "\n",
    "print(\"# user: {}, # item: {}\".format(num_user, num_item))\n",
    "# binarize\n",
    "y_train = binarize(y_train)\n",
    "y_test = binarize(y_test)\n",
    "print(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF naive\"\n",
    "mf = MF(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf.fit(x_train, y_train,\n",
    "    lr=0.01,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf.predict(x_test)\n",
    "mse_mf = mse_func(y_test, test_pred)\n",
    "auc_mf = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF]\" + \"***\"*5)\n",
    "print(\"[MF] test mse:\", mse_mf)\n",
    "print(\"[MF] test auc:\", auc_mf)\n",
    "print(\"[MF] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF IPS\"\n",
    "mf_ips = MF_IPS(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_ips.fit(x_train, y_train, gamma = 0.1,\n",
    "    lr=0.01,\n",
    "    lamb=(1e-5),\n",
    "    tol=1e-4,\n",
    "    verbose=False)\n",
    "test_pred = mf_ips.predict(x_test)\n",
    "mse_mfips = mse_func(y_test, test_pred)\n",
    "auc_mfips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_ips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-IPS]\" + \"***\"*5)\n",
    "print(\"[MF-IPS] test mse:\", mse_mfips)\n",
    "print(\"[MF-IPS] test auc:\", auc_mfips)\n",
    "print(\"[MF-IPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-IPS]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF IPS_AT\"\n",
    "mf_ips = MF_IPS_AT(num_user, num_item, batch_size=128)\n",
    "\n",
    "mf_ips.fit(x_train, y_train, tao = 0.05,\n",
    "    lr=0.01,\n",
    "    gamma = 0.05,\n",
    "    G = 4,\n",
    "    lamb = 4*1e-5,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_ips.predict(x_test)\n",
    "mse_mfips = mse_func(y_test, test_pred)\n",
    "auc_mfips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_ips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-IPS-AT]\" + \"***\"*5)\n",
    "print(\"[MF-IPS-AT] test mse:\", mse_func(y_test, test_pred))\n",
    "print(\"[MF-IPS-AT] test auc:\", auc_mfips)\n",
    "print(\"[MF-IPS-AT] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-IPS-AT]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF-SNIPS\"\n",
    "\n",
    "mf_snips = MF_SNIPS(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_snips.fit(x_train, y_train, gamma = 0.1,\n",
    "    lr=0.01,\n",
    "    lamb=1e-4,\n",
    "    tol=1e-3,\n",
    "    verbose=False)\n",
    "test_pred = mf_snips.predict(x_test)\n",
    "mse_mfsnips = mse_func(y_test, test_pred)\n",
    "auc_mfsnips = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_snips, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-SNIPS]\" + \"***\"*5)\n",
    "print(\"[MF-SNIPS] test mse:\", mse_mfsnips)\n",
    "print(\"[MF-SNIPS] test auc:\", auc_mfsnips)\n",
    "print(\"[MF-SNIPS] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-SNIPS]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF CVIB\"\n",
    "\n",
    "mf_cvib = MF_CVIB(num_user, num_item)\n",
    "mf_cvib.fit(x_train, y_train, \n",
    "    lr=0.05,\n",
    "    batch_size=128,\n",
    "    lamb=3*1e-4,\n",
    "    alpha=1.0,\n",
    "    gamma=1e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "\n",
    "test_pred = mf_cvib.predict(x_test)\n",
    "mse_mfcvib = mse_func(y_test, test_pred)\n",
    "auc_mfcvib = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_cvib, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-CVIB]\" + \"***\"*5)\n",
    "print(\"[MF-CVIB] test mse:\", mse_mfcvib)\n",
    "print(\"[MF-CVIB] test auc:\", auc_mfcvib)\n",
    "print(\"[MF-CVIB] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-CVIB]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF DR\"\n",
    "mf_dr = MF_DR(num_user, num_item, batch_size = 128)\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "prior_y = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_dr.fit(x_train, y_train, prior_y, gamma = 0.1,\n",
    "    lr=0.05,\n",
    "    G = 1,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_dr.predict(x_test)\n",
    "mse_mfdr = mse_func(y_test, test_pred)\n",
    "auc_mfdr = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-DR]\" + \"***\"*5)\n",
    "print(\"[MF-DR] test mse:\", mse_mfdr)\n",
    "print(\"[MF-DR] test auc:\", auc_mfdr)\n",
    "print(\"[MF-DR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-DR]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF DR JL\"\n",
    "\n",
    "mf_dr_jl = MF_DR_JL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_dr_jl.fit(x_train, y_train, x_test, y_test,\n",
    "    lr=0.05,\n",
    "    G = 6,\n",
    "    lamb=1e-3,\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_dr_jl.predict(x_test)\n",
    "mse_mfdrjl = mse_func(y_test, test_pred)\n",
    "auc_mfdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-DR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-DR-JL] test mse:\", mse_mfdrjl)\n",
    "print(\"[MF-DR-JL] test auc:\", auc_mfdrjl)\n",
    "print(\"[MF-DR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-DR-JL]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF DR CL\"\n",
    "\n",
    "mf_dr_tl = MF_DR_CL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_dr_tl.fit(x_train, y_train, x_test, y_test,\n",
    "    lr=0.05,\n",
    "    G = 6,\n",
    "    lamb=(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_dr_tl.predict(x_test)\n",
    "mse_mfdrtl = mse_func(y_test, test_pred)\n",
    "auc_mfdrtl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr_tl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-DR-CL]\" + \"***\"*5)\n",
    "print(\"[MF-DR-CL] test mse:\", mse_mfdrtl)\n",
    "print(\"[MF-DR-CL] test auc:\", auc_mfdrtl)\n",
    "print(\"[MF-DR-CL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-DR-CL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF TDR\"\n",
    "\n",
    "mf_dr_tmle = MF_TDR(num_user, num_item, batch_size = 128)\n",
    "\n",
    "ips_idxs = np.arange(len(y_test))\n",
    "np.random.shuffle(ips_idxs)\n",
    "prior_y = y_test[ips_idxs[:int(0.05 * len(ips_idxs))]]\n",
    "\n",
    "mf_dr_tmle.fit(x_train, y_train, prior_y, gamma = 0.1,\n",
    "    lr=0.05,\n",
    "    G = 3,\n",
    "    lamb=5*(1e-5),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_dr_tmle.predict(x_test)\n",
    "mse_mfdrtmle = mse_func(y_test, test_pred)\n",
    "auc_mfdrtmle = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr_tmle, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-TDR]\" + \"***\"*5)\n",
    "print(\"[MF-TDR] test mse:\", mse_mfdrtmle)\n",
    "print(\"[MF-TDR] test auc:\", auc_mfdrtmle)\n",
    "print(\"[MF-TDR] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-TDR]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF TDR JL\"\n",
    "\n",
    "mf_dr_tmle_jl = MF_TDR_JL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_dr_tmle_jl.fit(x_train, y_train,\n",
    "    lr=0.05,\n",
    "    G = 4,\n",
    "    lamb=(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_dr_tmle_jl.predict(x_test)\n",
    "mse_mfdrtmlejl = mse_func(y_test, test_pred)\n",
    "auc_mfdrtmlejl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr_tmle_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-TDR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-TDR-JL] test mse:\", mse_mfdrtmlejl)\n",
    "print(\"[MF-TDR-JL] test auc:\", auc_mfdrtmlejl)\n",
    "print(\"[MF-TDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-TDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF TDR CL\"\n",
    "\n",
    "mf_dr_tmle_tl = MF_TDR_CL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_dr_tmle_tl.fit(x_train, y_train, x_test, y_test,\n",
    "    lr=0.05,\n",
    "    G = 1,\n",
    "    lamb=3*(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_dr_tmle_tl.predict(x_test)\n",
    "mse_mfdrtmletl = mse_func(y_test, test_pred)\n",
    "auc_mfdrtmletl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_dr_tmle_tl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-DR-TMLE-TL]\" + \"***\"*5)\n",
    "print(\"[MF-DR-TMLE-TL] test mse:\", mse_mfdrtmletl)\n",
    "print(\"[MF-DR-TMLE-TL] test auc:\", auc_mfdrtmletl)\n",
    "print(\"[MF-DR-TMLE-TL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-DR-TMLE-TL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF MRDR JL\"\n",
    "\n",
    "mf_mrdr_jl = MF_MRDR_JL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_mrdr_jl.fit(x_train, y_train, x_test, y_test, gamma = 0.1,\n",
    "    lr=0.05,\n",
    "    G = 2,\n",
    "    lamb=(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_mrdr_jl.predict(x_test)\n",
    "mse_mfmrdrjl = mse_func(y_test, test_pred)\n",
    "auc_mfmrdrjl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_mrdr_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-MRDR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-MRDR-JL] test mse:\", mse_mfmrdrjl)\n",
    "print(\"[MF-MRDR-JL] test auc:\", auc_mfmrdrjl)\n",
    "print(\"[MF-MRDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-MRDR-JL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF MRDR CL\"\n",
    "\n",
    "mf_mrdr_tl = MF_MRDR_CL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_mrdr_tl.fit(x_train, y_train, x_test, y_test,\n",
    "    lr=0.05,\n",
    "    G = 1,\n",
    "    lamb=2*(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_mrdr_tl.predict(x_test)\n",
    "mse_mfmrdrtl = mse_func(y_test, test_pred)\n",
    "auc_mfmrdrtl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_mrdr_tl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-MRDR-CL]\" + \"***\"*5)\n",
    "print(\"[MF-MRDR-CL] test mse:\", mse_mfmrdrtl)\n",
    "print(\"[MF-MRDR-CL] test auc:\", auc_mfmrdrtl)\n",
    "print(\"[MF-MRDR-CL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-MRDR-CL]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF TMRDR JL\"\n",
    "\n",
    "mf_mrdr_tmle_jl = MF_TMRDR_JL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_mrdr_tmle_jl.fit(x_train, y_train,\n",
    "    lr=0.05,\n",
    "    G = 1,\n",
    "    lamb=3*(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_mrdr_tmle_jl.predict(x_test)\n",
    "mse_mfmrdrtmlejl = mse_func(y_test, test_pred)\n",
    "auc_mfmrdrtmlejl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_mrdr_tmle_jl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-TMRDR-JL]\" + \"***\"*5)\n",
    "print(\"[MF-TMRDR-JL] test mse:\", mse_mfmrdrtmlejl)\n",
    "print(\"[MF-TMRDR-JL] test auc:\", auc_mfmrdrtmlejl)\n",
    "print(\"[MF-TMRDR-JL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-TMRDR-JL]\" + \"***\"*5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"MF TMRDR CL\"\n",
    "\n",
    "mf_mrdr_tmle_tl = MF_TMRDR_CL(num_user, num_item, batch_size = 128)\n",
    "\n",
    "mf_mrdr_tmle_tl.fit(x_train, y_train, x_test, y_test,\n",
    "    lr=0.05,\n",
    "    G = 2,\n",
    "    lamb=3*(1e-3),\n",
    "    tol=1e-5,\n",
    "    verbose=False)\n",
    "test_pred = mf_mrdr_tmle_tl.predict(x_test)\n",
    "mse_mfmrdrtmletl = mse_func(y_test, test_pred)\n",
    "auc_mfmrdrtmletl = roc_auc_score(y_test, test_pred)\n",
    "ndcg_res = ndcg_func(mf_mrdr_tmle_tl, x_test, y_test)\n",
    "\n",
    "print(\"***\"*5 + \"[MF-MRDR-TMLE-TL]\" + \"***\"*5)\n",
    "print(\"[MF-MRDR-TMLE-TL] test mse:\", mse_mfmrdrtmletl)\n",
    "print(\"[MF-MRDR-TMLE-TL] test auc:\", auc_mfmrdrtmletl)\n",
    "print(\"[MF-MRDR-TMLE-TL] ndcg@5:{:.6f}, ndcg@10:{:.6f}\".format(\n",
    "        np.mean(ndcg_res[\"ndcg_5\"]), np.mean(ndcg_res[\"ndcg_10\"])))\n",
    "user_wise_ctr = get_user_wise_ctr(x_test,y_test,test_pred)\n",
    "gi,gu = gini_index(user_wise_ctr)\n",
    "print(\"***\"*5 + \"[MF-MRDR-TMLE-TL]\" + \"***\"*5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:sdr_test]",
   "language": "python",
   "name": "conda-env-sdr_test-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
