{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes\n",
    "\n",
    "# Human-readable legend label for each method key looked up in the result pickles.\n",
    "legend_dict = {\n",
    "    'original_RGAS': 'Original (RGAS)',\n",
    "    'original_RAGAS': 'Original (RAGAS)',\n",
    "    'StiefelSGD_ours': 'Stiefel SGD (Ours)',\n",
    "    'StiefelAdam_ours': 'Stiefel Adam (Ours)',\n",
    "    'MomentumlessStiefelSGD': 'Momentumless Stiefel SGD',\n",
    "    'ProjectedStiefelSGD': 'Projected Stiefel SGD',\n",
    "    'ProjectedStiefelAdam': 'Projected Stiefel Adam',\n",
    "}\n",
    "\n",
    "# Methods drawn in each figure, in plotting order.\n",
    "method_list = ['original_RGAS', 'original_RAGAS', 'StiefelSGD_ours', 'MomentumlessStiefelSGD', 'ProjectedStiefelSGD']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST: mean PRW distance over all digit pairs (i, j) with i < j, per iteration\n",
    "fig, ax = plt.subplots(figsize=(7, 2.5))\n",
    "axins = zoomed_inset_axes(ax, 4, loc=7)  # 4x zoomed inset, loc=7 ('center right')\n",
    "# The private ax._get_lines.prop_cycler was removed in Matplotlib 3.8; calling the\n",
    "# public axes.prop_cycle Cycler returns an endless iterator over the same styles.\n",
    "color_cycle = plt.rcParams['axes.prop_cycle']()\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.\n",
    "with open('OT_val_memory_dict_mnist.pkl', 'rb') as handle:\n",
    "    OT_val_memory_dict_mnist = pickle.load(handle)\n",
    "\n",
    "for method in method_list:\n",
    "    # Draw the colour before the membership check so each method keeps a fixed\n",
    "    # colour even when an earlier method is missing from the results.\n",
    "    color = next(color_cycle)['color']\n",
    "    if method not in OT_val_memory_dict_mnist:\n",
    "        continue\n",
    "    # One trace per unordered digit pair, looked up by the string key '(i, j)' with i < j.\n",
    "    OT_list = [OT_val_memory_dict_mnist[method]['({}, {})'.format(i, j)]\n",
    "               for i in range(10) for j in range(i + 1, 10)]\n",
    "    OT_tensor = np.stack(OT_list, axis=1)\n",
    "    mean_OT = np.mean(OT_tensor, axis=1)  # average over digit pairs\n",
    "    # NOTE(review): values are divided by 1000 for display; confirm the scale this implies.\n",
    "    ax.plot(mean_OT / 1000, label=legend_dict[method], color=color)\n",
    "    axins.plot(mean_OT / 1000, label=legend_dict[method], color=color)\n",
    "\n",
    "ax.set_xlabel('iter', fontsize=13)\n",
    "ax.set_ylabel('PRW distance mean', fontsize=13)\n",
    "ax.set_ylim(0.0, 0.92)\n",
    "ax.set_title('PRW distance between MNIST digits', fontsize=15)\n",
    "axins.set_xlim(44, 49)\n",
    "axins.set_xticks([44, 45, 46, 47, 48])\n",
    "axins.set_ylim(0.80, 0.9)\n",
    "axins.set_yticks([0.8, 0.825, 0.85, 0.875, 0.9])\n",
    "for target_ax in (ax, axins):\n",
    "    target_ax.tick_params(axis='both', labelsize=13)\n",
    "\n",
    "# Outline the zoomed region on the main axes and connect it to the inset.\n",
    "mark_inset(ax, axins, loc1=1, loc2=2, fc='none', ec='0.5')\n",
    "plt.savefig('./PRW_mnist.pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shakespeare: mean PRW distance over all play pairs (i, j) with i < j, per iteration\n",
    "fig, ax = plt.subplots(figsize=(7, 2.5))\n",
    "axins = zoomed_inset_axes(ax, 8, loc=7)  # 8x zoomed inset, loc=7 ('center right')\n",
    "# The private ax._get_lines.prop_cycler was removed in Matplotlib 3.8; calling the\n",
    "# public axes.prop_cycle Cycler returns an endless iterator over the same styles.\n",
    "color_cycle = plt.rcParams['axes.prop_cycle']()\n",
    "\n",
    "# Pair keys '(i, j)' in the results are looked up by positions in this list.\n",
    "scripts = ['H5.txt', 'Ham.txt', 'JC.txt', 'MV.txt', 'Oth.txt', 'Rom.txt']\n",
    "n_scripts = len(scripts)\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load result files you produced.\n",
    "with open('OT_val_memory_dict_shakespeare.pkl', 'rb') as handle:\n",
    "    OT_val_memory_dict_shakespeare = pickle.load(handle)\n",
    "\n",
    "for method in method_list:\n",
    "    # Draw the colour before the membership check so each method keeps a fixed\n",
    "    # colour even when an earlier method is missing from the results.\n",
    "    color = next(color_cycle)['color']\n",
    "    if method not in OT_val_memory_dict_shakespeare:\n",
    "        continue\n",
    "    # One trace per unordered pair of plays, iterated in (i, j) order with i < j.\n",
    "    OT_list = [OT_val_memory_dict_shakespeare[method]['({}, {})'.format(i, j)]\n",
    "               for i in range(n_scripts) for j in range(i + 1, n_scripts)]\n",
    "    OT_tensor = np.stack(OT_list, axis=1)\n",
    "    mean_OT = np.mean(OT_tensor, axis=1)  # average over play pairs\n",
    "    ax.plot(mean_OT, label=legend_dict[method], color=color)\n",
    "    axins.plot(mean_OT, label=legend_dict[method], color=color)\n",
    "\n",
    "ax.set_xlabel('iter', fontsize=13)\n",
    "ax.set_ylabel('PRW distance mean', fontsize=13)\n",
    "ax.set_ylim(0.04, 0.195)\n",
    "ax.set_title('PRW distance between Shakespeare plays', fontsize=15)\n",
    "axins.set_xlim(46, 49)\n",
    "axins.set_xticks([46, 47, 48])\n",
    "axins.set_yticks([0.184, 0.186, 0.188, 0.190])\n",
    "axins.set_ylim(0.182, 0.191)  # set after the ticks so these limits are final\n",
    "for target_ax in (ax, axins):\n",
    "    target_ax.tick_params(axis='both', labelsize=13)\n",
    "\n",
    "# Outline the zoomed region on the main axes and connect it to the inset.\n",
    "mark_inset(ax, axins, loc1=1, loc2=2, fc='none', ec='0.5')\n",
    "plt.savefig('./PRW_shakespeare.pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the legend alone to PRW_legend.pdf.\n",
    "# NOTE: reads handles/labels from `ax` as left by the most recently run plotting cell\n",
    "# (the Shakespeare cell under Run All); presumably both figures show the same methods - confirm.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "figlegend = plt.figure(figsize=(3, 2))\n",
    "figlegend.legend(handles, labels)\n",
    "figlegend.savefig('PRW_legend.pdf', bbox_inches='tight')\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "26de051ba29f2982a8de78e945f0abaf191376122a1563185a90213a26c5da77"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
