{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6b9587f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.io as sio\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from utils_knn import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b76c028b",
   "metadata": {},
   "source": [
    "# Goodreads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "afd9ecc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory the precomputed result files are read from, and the dataset tag\n",
    "# that is passed to get_acc_knn_cv_rho in the cross-validation cells.\n",
    "path = \"./results_goodreads/\"\n",
    "dataset = \"goodreads\"\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "# execute arbitrary code; only point these calls at trusted files.\n",
    "X = np.load(\"./data/X_goodread.npy\", allow_pickle=True)\n",
    "w = np.load(\"./data/w_goodread.npy\", allow_pickle=True)\n",
    "\n",
    "# Label vectors for the two classification tasks; both are later split\n",
    "# using positions in range(len(X)).\n",
    "y_genre = np.loadtxt(\"./data/y_genre_goodread.csv\")\n",
    "y_likability = np.loadtxt(\"./data/y_likability_goodread.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ba7c50c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_train_likability, idx_test_likability = [], []\n",
    "idx_train_genre, idx_test_genre = [], []\n",
    "\n",
    "# Fix the global NumPy seed so the five split seeds are identical on every run.\n",
    "np.random.seed(2023)\n",
    "seeds = np.random.randint(0, 1000, 5)\n",
    "\n",
    "# Items are split by position, so each split is stored as lists of indices.\n",
    "sample_ids = list(range(len(X)))\n",
    "\n",
    "# (labels, train-index store, test-index store) for each task.\n",
    "split_targets = (\n",
    "    (y_likability, idx_train_likability, idx_test_likability),\n",
    "    (y_genre, idx_train_genre, idx_test_genre),\n",
    ")\n",
    "\n",
    "for seed in seeds:\n",
    "    # Both tasks use the same seed for a given split, each stratified on its own labels.\n",
    "    for labels, train_store, test_store in split_targets:\n",
    "        # Only the index halves are kept; the label halves are discarded.\n",
    "        _y_train, _y_test, id_train, id_test = train_test_split(\n",
    "            labels, sample_ids, shuffle=True, random_state=seed, stratify=labels\n",
    "        )\n",
    "        train_store.append(id_train)\n",
    "        test_store.append(id_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd5099a3",
   "metadata": {},
   "source": [
    "### Results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d2c1314",
   "metadata": {},
   "source": [
    "#### OT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e45dc2b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7099601593625497\n",
      "0.5521912350597609\n"
     ]
    }
   ],
   "source": [
    "# Matrix read from the \"d_w\" result file. The variable keeps the name `d_sw`\n",
    "# because the following cell reads it under that name.\n",
    "d_sw = np.loadtxt(path + \"d_w_goodreads_k0\")\n",
    "\n",
    "# get_acc_knn returns several values (presumably one per split -- confirm in utils_knn);\n",
    "# they are averaged into a single score per task.\n",
    "split_acc_likability = get_acc_knn(d_sw, y_likability, idx_train_likability, idx_test_likability)\n",
    "acc_w_likability = np.mean(split_acc_likability)\n",
    "\n",
    "split_acc_genre = get_acc_knn(d_sw, y_genre, idx_train_genre, idx_test_genre)\n",
    "acc_w_genre = np.mean(split_acc_genre)\n",
    "\n",
    "print(acc_w_likability)\n",
    "print(acc_w_genre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e34ff4ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.5099601593625498,\n",
       " 0.5816733067729084,\n",
       " 0.549800796812749,\n",
       " 0.5697211155378487,\n",
       " 0.549800796812749]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the individual values get_acc_knn returns for the genre labels,\n",
    "# using whichever matrix `d_sw` currently holds.\n",
    "get_acc_knn(\n",
    "    d_sw, y_genre, idx_train_genre, idx_test_genre\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "429ffdb1",
   "metadata": {},
   "source": [
    "#### SOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9709fdba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6560424966799469 0.001987600205457609\n",
      "0.5009296148738379 0.005053465495468909\n"
     ]
    }
   ],
   "source": [
    "# One result file per index k = 0..2; one averaged score per file and task.\n",
    "n_runs = 3\n",
    "L_sw_likability = np.zeros(n_runs)\n",
    "L_sw_genre = np.zeros(n_runs)\n",
    "\n",
    "for k in range(n_runs):\n",
    "    d_sw = np.loadtxt(path + \"d_projs500_sw_goodreads_k\" + str(k))\n",
    "\n",
    "    acc_sw_likability = np.mean(get_acc_knn(d_sw, y_likability, idx_train_likability, idx_test_likability))\n",
    "    acc_sw_genre = np.mean(get_acc_knn(d_sw, y_genre, idx_train_genre, idx_test_genre))\n",
    "\n",
    "    L_sw_likability[k] = acc_sw_likability\n",
    "    L_sw_genre[k] = acc_sw_genre\n",
    "\n",
    "# Mean and standard deviation over the three files.\n",
    "print(L_sw_likability.mean(), L_sw_likability.std())\n",
    "print(L_sw_genre.mean(), L_sw_genre.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d51633f",
   "metadata": {},
   "source": [
    "#### Sinkhorn UOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "68d9001d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.67808765]\n",
      "[0.53545817]\n"
     ]
    }
   ],
   "source": [
    "# Only one rho value and one file (k0) are evaluated here.\n",
    "rhos = [0.1]\n",
    "n_rhos = len(rhos)\n",
    "\n",
    "L_mean_sinkhorn_likability = np.zeros(n_rhos)\n",
    "L_mean_sinkhorn_genre = np.zeros(n_rhos)\n",
    "\n",
    "for i, rho in enumerate(rhos):\n",
    "    fname = path + \"d_sinkhorn_goodreads_rho\" + str(rho) + \"_reg0.001_k0\"\n",
    "    d_sw = np.loadtxt(fname)\n",
    "\n",
    "    # Average the values returned by get_acc_knn, first for likability, then for genre.\n",
    "    L_acc = get_acc_knn(d_sw, y_likability, idx_train_likability, idx_test_likability)\n",
    "    L_mean_sinkhorn_likability[i] = np.mean(L_acc)\n",
    "\n",
    "    L_acc = get_acc_knn(d_sw, y_genre, idx_train_genre, idx_test_genre)\n",
    "    L_mean_sinkhorn_genre[i] = np.mean(L_acc)\n",
    "\n",
    "print(L_mean_sinkhorn_likability)\n",
    "print(L_mean_sinkhorn_genre)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "552faf0b",
   "metadata": {},
   "source": [
    "#### SUOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6f08c131",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.47835325 0.50491368 0.55192563 0.43930943 0.65790173 0.66401062\n",
      " 0.66719788 0.66480744] [0.01203747 0.01166242 0.00716639 0.01439173 0.0039752  0.00187811\n",
      " 0.00375621 0.00270864]\n",
      "[0.17795485 0.16095618 0.17742364 0.19096946 0.50146082 0.50066401\n",
      " 0.49694555 0.49853918] [0.00369944 0.00945044 0.0039216  0.0019876  0.00037562 0.00560922\n",
      " 0.00776183 0.00641862]\n"
     ]
    }
   ],
   "source": [
    "# Sweep over rho; the file name uses the same value for rho1 and rho2.\n",
    "rhos = [0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0]\n",
    "n_runs = 3\n",
    "\n",
    "# Rows index rho, columns index the file number k.\n",
    "L_mean_suw_likability = np.zeros((len(rhos), n_runs))\n",
    "L_mean_suw_genre = np.zeros((len(rhos), n_runs))\n",
    "\n",
    "for i, rho in enumerate(rhos):\n",
    "    prefix = path + \"d_projs500_suw_goodreads_rho1\" + str(rho) + \"_rho2\" + str(rho) + \"_k\"\n",
    "    for k in range(n_runs):\n",
    "        d_sw = np.loadtxt(prefix + str(k))\n",
    "\n",
    "        L_acc = get_acc_knn(d_sw, y_likability, idx_train_likability, idx_test_likability)\n",
    "        L_mean_suw_likability[i, k] = np.mean(L_acc)\n",
    "\n",
    "        L_acc = get_acc_knn(d_sw, y_genre, idx_train_genre, idx_test_genre)\n",
    "        L_mean_suw_genre[i, k] = np.mean(L_acc)\n",
    "\n",
    "# Mean and standard deviation across the files, one entry per rho.\n",
    "print(L_mean_suw_likability.mean(axis=-1), L_mean_suw_likability.std(axis=-1))\n",
    "print(L_mean_suw_genre.mean(axis=-1), L_mean_suw_genre.std(axis=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8b33287",
   "metadata": {},
   "source": [
    "#### USOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ce1d8286",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.6063745  0.67330677 0.662417   0.67675963 0.67782205 0.67410359\n",
      " 0.67357238 0.67622842 0.6687915  0.65763612 0.65976096] [0.00640762 0.00703726 0.004969   0.00300497 0.0039216  0.0106507\n",
      " 0.0052587  0.00433187 0.00300497 0.00209137 0.00660282]\n",
      "[0.19017264 0.41752988 0.5189907  0.52669323 0.51633466 0.5187251\n",
      " 0.52217795 0.51606906 0.51527224 0.50491368 0.49296149] [0.00163729 0.00532535 0.00262935 0.00620628 0.00395741 0.00555869\n",
      " 0.00307459 0.00762428 0.0035832  0.00604505 0.00529879]\n"
     ]
    }
   ],
   "source": [
    "# Sweep over rho; the file name uses the same value for rho1 and rho2.\n",
    "rhos = [0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.01, 0.1, 1.0]\n",
    "n_runs = 3\n",
    "\n",
    "# Rows index rho, columns index the file number k.\n",
    "L_mean_rsw_likability = np.zeros((len(rhos), n_runs))\n",
    "L_mean_rsw_genre = np.zeros((len(rhos), n_runs))\n",
    "\n",
    "for i, rho in enumerate(rhos):\n",
    "    prefix = path + \"d_projs500_rsw_goodreads_rho1\" + str(rho) + \"_rho2\" + str(rho) + \"_k\"\n",
    "    for k in range(n_runs):\n",
    "        d_sw = np.loadtxt(prefix + str(k))\n",
    "\n",
    "        L_acc = get_acc_knn(d_sw, y_likability, idx_train_likability, idx_test_likability)\n",
    "        L_mean_rsw_likability[i, k] = np.mean(L_acc)\n",
    "\n",
    "        L_acc = get_acc_knn(d_sw, y_genre, idx_train_genre, idx_test_genre)\n",
    "        L_mean_rsw_genre[i, k] = np.mean(L_acc)\n",
    "\n",
    "# Mean and standard deviation across the files, one entry per rho.\n",
    "print(L_mean_rsw_likability.mean(axis=-1), L_mean_rsw_likability.std(axis=-1))\n",
    "print(L_mean_rsw_genre.mean(axis=-1), L_mean_rsw_genre.std(axis=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f338299d",
   "metadata": {},
   "source": [
    "#### SUSOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e5c838a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.67330677] [0.00260238]\n",
      "[0.51925631] [0.0052587]\n"
     ]
    }
   ],
   "source": [
    "# Single rho value for the \"stochastic_rsw\" result files (same value for rho1 and rho2).\n",
    "rhos = [0.005]\n",
    "\n",
    "# Dedicated arrays for this cell: it previously wrote into L_mean_rsw_likability /\n",
    "# L_mean_rsw_genre, silently overwriting the results of the \"rsw\" rho sweep.\n",
    "# Rows index rho, columns index the file number k.\n",
    "L_mean_srsw_likability = np.zeros((len(rhos), 3))\n",
    "L_mean_srsw_genre = np.zeros((len(rhos), 3))\n",
    "\n",
    "for i, rho in enumerate(rhos):\n",
    "    for k in range(3):\n",
    "        d_sw = np.loadtxt(path + \"d_projs500_stochastic_rsw_goodreads_rho1\"+str(rho)+\"_rho2\"+str(rho)+\"_k\"+str(k))\n",
    "        L_acc = get_acc_knn(d_sw, y_likability, idx_train_likability, idx_test_likability)\n",
    "        L_mean_srsw_likability[i, k] = np.mean(L_acc)\n",
    "        \n",
    "        L_acc = get_acc_knn(d_sw, y_genre, idx_train_genre, idx_test_genre)\n",
    "        L_mean_srsw_genre[i, k] = np.mean(L_acc)\n",
    "        \n",
    "# Mean and standard deviation across the three files, one entry per rho.\n",
    "print(np.mean(L_mean_srsw_likability, axis=-1), np.std(L_mean_srsw_likability, axis=-1))\n",
    "print(np.mean(L_mean_srsw_genre, axis=-1), np.std(L_mean_srsw_genre, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "599a2407",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45273d7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "70230612",
   "metadata": {},
   "source": [
    "#### USOT + Cross Val on $\\rho$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "49c27557",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.5147410358565737 [0.47410358565737054, 0.5338645418326693, 0.4860557768924303, 0.5219123505976095, 0.5577689243027888]\n",
      "1 0.5051792828685259 [0.46215139442231074, 0.5099601593625498, 0.50199203187251, 0.5219123505976095, 0.5298804780876494]\n",
      "2 0.5099601593625498 [0.4860557768924303, 0.5338645418326693, 0.46215139442231074, 0.5378486055776892, 0.5298804780876494]\n",
      "0.5099601593625498 0.0039035693112082666\n"
     ]
    }
   ],
   "source": [
    "# Candidate rho values handed to get_acc_knn_cv_rho. The next cell reuses this\n",
    "# same list, so run this cell first.\n",
    "rhos = [0.0005, 0.001, 0.005, 0.01, 0.1, 1.0]\n",
    "\n",
    "L_mean_rsw = []\n",
    "\n",
    "for k in range(3):\n",
    "    # One call per file index k, for the \"rsw\" method on the genre labels.\n",
    "    L_acc = get_acc_knn_cv_rho(dataset, \"rsw\", k, rhos, y_genre, idx_train_genre, idx_test_genre)\n",
    "    print(k, np.mean(L_acc), L_acc)\n",
    "    L_mean_rsw.append(np.mean(L_acc))\n",
    "    \n",
    "# Mean and standard deviation of the three per-k averages.\n",
    "print(np.mean(L_mean_rsw, axis=-1), np.std(L_mean_rsw, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "350fb241",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.6573705179282868 [0.6653386454183267, 0.6334661354581673, 0.6772908366533864, 0.6533864541832669, 0.6573705179282868]\n",
      "1 0.6661354581673307 [0.7290836653386454, 0.6454183266932271, 0.6334661354581673, 0.6533864541832669, 0.6693227091633466]\n",
      "2 0.6749003984063745 [0.6573705179282868, 0.6613545816733067, 0.6733067729083665, 0.701195219123506, 0.6812749003984063]\n",
      "0.6661354581673308 0.007156543737215141\n"
     ]
    }
   ],
   "source": [
    "# Same evaluation as the previous cell, on the likability labels.\n",
    "# `rhos` is not set here: it is whatever the preceding cell defined.\n",
    "L_mean_rsw = []\n",
    "\n",
    "for k in range(3):\n",
    "    L_acc = get_acc_knn_cv_rho(\n",
    "        dataset, \"rsw\", k, rhos, y_likability, idx_train_likability, idx_test_likability\n",
    "    )\n",
    "    run_mean = np.mean(L_acc)\n",
    "    print(k, run_mean, L_acc)\n",
    "    L_mean_rsw.append(run_mean)\n",
    "\n",
    "# Mean and standard deviation of the three per-k averages.\n",
    "print(np.mean(L_mean_rsw, axis=-1), np.std(L_mean_rsw, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28206068",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "23c324dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.5298804780876494 [0.5059760956175299, 0.5219123505976095, 0.5338645418326693, 0.5338645418326693, 0.5537848605577689]\n",
      "1 0.5123505976095617 [0.46215139442231074, 0.5338645418326693, 0.5059760956175299, 0.5219123505976095, 0.5378486055776892]\n",
      "2 0.5195219123505976 [0.4860557768924303, 0.5378486055776892, 0.49800796812749004, 0.5378486055776892, 0.5378486055776892]\n",
      "0.5205843293492696 0.007195865702068569\n"
     ]
    }
   ],
   "source": [
    "# Finer candidate grid for rho, with extra values between 0.001 and 0.005.\n",
    "# The next cell reuses this same list.\n",
    "rhos = [0.0005, 0.001, 0.002, 0.003, 0.004, 0.005, 0.01, 0.1, 1.0]\n",
    "\n",
    "L_mean_rsw = []\n",
    "\n",
    "for k in range(3):\n",
    "    L_acc = get_acc_knn_cv_rho(\n",
    "        dataset, \"rsw\", k, rhos, y_genre, idx_train_genre, idx_test_genre\n",
    "    )\n",
    "    run_mean = np.mean(L_acc)\n",
    "    print(k, run_mean, L_acc)\n",
    "    L_mean_rsw.append(run_mean)\n",
    "\n",
    "# Mean and standard deviation of the three per-k averages.\n",
    "print(np.mean(L_mean_rsw, axis=-1), np.std(L_mean_rsw, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e79e9ab2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.6629482071713146 [0.6653386454183267, 0.6613545816733067, 0.6772908366533864, 0.6533864541832669, 0.6573705179282868]\n",
      "1 0.6629482071713146 [0.6733067729083665, 0.6454183266932271, 0.6334661354581673, 0.6852589641434262, 0.6772908366533864]\n",
      "2 0.6709163346613546 [0.6573705179282868, 0.6653386454183267, 0.6733067729083665, 0.701195219123506, 0.6573705179282868]\n",
      "0.6656042496679947 0.003756211321044126\n"
     ]
    }
   ],
   "source": [
    "# Likability counterpart of the previous cell.\n",
    "# `rhos` is not set here: it is whatever the preceding cell defined.\n",
    "L_mean_rsw = []\n",
    "\n",
    "for k in range(3):\n",
    "    L_acc = get_acc_knn_cv_rho(\n",
    "        dataset, \"rsw\", k, rhos, y_likability, idx_train_likability, idx_test_likability\n",
    "    )\n",
    "    run_mean = np.mean(L_acc)\n",
    "    print(k, run_mean, L_acc)\n",
    "    L_mean_rsw.append(run_mean)\n",
    "\n",
    "# Mean and standard deviation of the three per-k averages.\n",
    "print(np.mean(L_mean_rsw, axis=-1), np.std(L_mean_rsw, axis=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af7a274",
   "metadata": {},
   "source": [
    "#### SUOT + Cross val on $\\rho$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e8df210e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.4932270916334661 [0.4581673306772908, 0.5059760956175299, 0.47808764940239046, 0.5219123505976095, 0.50199203187251]\n",
      "1 0.5075697211155379 [0.47410358565737054, 0.5258964143426295, 0.4860557768924303, 0.5219123505976095, 0.5298804780876494]\n",
      "2 0.4892430278884462 [0.46613545816733065, 0.49800796812749004, 0.450199203187251, 0.50199203187251, 0.5298804780876494]\n",
      "0.49667994687915007 0.007870136727771102\n"
     ]
    }
   ],
   "source": [
    "# Candidate rho values handed to get_acc_knn_cv_rho. The next cell reuses this\n",
    "# same list, so run this cell first.\n",
    "rhos = [0.0005, 0.001, 0.005, 0.01, 0.1, 1.0]\n",
    "\n",
    "L_mean_suw = []\n",
    "\n",
    "for k in range(3):\n",
    "    # One call per file index k, for the \"suw\" method on the genre labels.\n",
    "    L_acc = get_acc_knn_cv_rho(dataset, \"suw\", k, rhos, y_genre, idx_train_genre, idx_test_genre)\n",
    "    print(k, np.mean(L_acc), L_acc)\n",
    "    L_mean_suw.append(np.mean(L_acc))\n",
    "    \n",
    "# Mean and standard deviation of the three per-k averages.\n",
    "print(np.mean(L_mean_suw, axis=-1), np.std(L_mean_suw, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "51702857",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.6661354581673307 [0.6653386454183267, 0.6653386454183267, 0.6733067729083665, 0.6653386454183267, 0.6613545816733067]\n",
      "1 0.6637450199203186 [0.6693227091633466, 0.6573705179282868, 0.6653386454183267, 0.6693227091633466, 0.6573705179282868]\n",
      "2 0.6709163346613545 [0.6613545816733067, 0.6693227091633466, 0.6892430278884463, 0.6693227091633466, 0.6653386454183267]\n",
      "0.6669322709163344 0.002981400308186389\n"
     ]
    }
   ],
   "source": [
    "# Likability counterpart of the previous cell.\n",
    "# `rhos` is not set here: it is whatever the preceding cell defined.\n",
    "L_mean_suw = []\n",
    "\n",
    "for k in range(3):\n",
    "    L_acc = get_acc_knn_cv_rho(\n",
    "        dataset, \"suw\", k, rhos, y_likability, idx_train_likability, idx_test_likability\n",
    "    )\n",
    "    run_mean = np.mean(L_acc)\n",
    "    print(k, run_mean, L_acc)\n",
    "    L_mean_suw.append(run_mean)\n",
    "\n",
    "# Mean and standard deviation of the three per-k averages.\n",
    "print(np.mean(L_mean_suw, axis=-1), np.std(L_mean_suw, axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ded87d5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
