{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b64a836b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Local helpers; the notebook uses generate_dist, rw2_sk, sk and exact_emd from here.\n",
    "from funcs import *\n",
    "\n",
    "# Extended-precision dtype for the marginal weights. np.longdouble is float128 on\n",
    "# typical x86-64 Linux builds, whereas the string \"float128\" is rejected by NumPy on\n",
    "# platforms without that type (e.g. Windows), which would crash the experiment cell.\n",
    "dtype_prec = np.longdouble\n",
    "element_max = np.vectorize(max)  # element-wise builtin max (not used further below)\n",
    "name_dist1 = \"Gaussian\"\n",
    "name_dist2 = \"Uniform\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "33cea2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment parameters\n",
    "std_dev = 1  # Gaussian standard deviation (not passed to generate_dist below)\n",
    "dim = 10  # dimension of every support point\n",
    "vmin = 0  # Uniform bounds (likewise not passed to generate_dist below)\n",
    "vmax = 1\n",
    "source_size = target_size = 1000  # support points per distribution\n",
    "stopThr = 1e-3  # Sinkhorn stopping threshold passed to rw2_sk / sk\n",
    "reg = 1e-3  # entropic regularisation passed to rw2_sk / sk\n",
    "# Translation lengths 0.0, 0.5, ..., 2.5\n",
    "gap_list = 0.5 * np.arange(6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0130f2b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing Sinkhorn/ROT/EMD:   0%|                                                                                                                                                                                  | 0/12 [05:51<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 32\u001b[0m\n\u001b[1;32m     28\u001b[0m \u001b[38;5;66;03m# sk\u001b[39;00m\n\u001b[1;32m     29\u001b[0m \n\u001b[1;32m     30\u001b[0m \u001b[38;5;66;03m# sk_ROT\u001b[39;00m\n\u001b[1;32m     31\u001b[0m tic \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m---> 32\u001b[0m W2_value \u001b[38;5;241m=\u001b[39m \u001b[43mrw2_sk\u001b[49m\u001b[43m(\u001b[49m\u001b[43msource_supports\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_supports\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msource_masses\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_masses\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstopThr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     33\u001b[0m toc \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m     34\u001b[0m time_sk_ROT[repeat_idx, gap_idx] \u001b[38;5;241m=\u001b[39m (toc \u001b[38;5;241m-\u001b[39m tic)\n",
      "File \u001b[0;32m~/PycharmProjects/RW_project/Robust_Wasserstein_Metric/Exp1/funcs.py:23\u001b[0m, in \u001b[0;36mrw2_sk\u001b[0;34m(src_supp, tgt_supp, src_mass, tgt_mass, reg, stopThr, numItermax)\u001b[0m\n\u001b[1;32m     21\u001b[0m tgt_mean \u001b[38;5;241m=\u001b[39m tgt_supp\u001b[38;5;241m.\u001b[39mT \u001b[38;5;241m@\u001b[39m tgt_mass\n\u001b[1;32m     22\u001b[0m C \u001b[38;5;241m=\u001b[39m ot\u001b[38;5;241m.\u001b[39mdist(src_supp \u001b[38;5;241m-\u001b[39m src_mean, tgt_supp \u001b[38;5;241m-\u001b[39m tgt_mean, metric\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msqeuclidean\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 23\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[43mot\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msinkhorn2\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     24\u001b[0m \u001b[43m    \u001b[49m\u001b[43msrc_mass\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_mass\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     25\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     26\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msinkhorn_log\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     27\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstopThr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstopThr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     28\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnumItermax\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnumItermax\u001b[49m\n\u001b[1;32m     29\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     30\u001b[0m value \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m 
np\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mnorm(src_mean \u001b[38;5;241m-\u001b[39m tgt_mean) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n",
      "File \u001b[0;32m~/PycharmProjects/RW_project/lib/python3.10/site-packages/ot/bregman/_sinkhorn.py:329\u001b[0m, in \u001b[0;36msinkhorn2\u001b[0;34m(a, b, M, reg, method, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)\u001b[0m\n\u001b[1;32m    324\u001b[0m     res \u001b[38;5;241m=\u001b[39m sinkhorn_knopp(a, b, M, reg, numItermax\u001b[38;5;241m=\u001b[39mnumItermax,\n\u001b[1;32m    325\u001b[0m                          stopThr\u001b[38;5;241m=\u001b[39mstopThr, verbose\u001b[38;5;241m=\u001b[39mverbose,\n\u001b[1;32m    326\u001b[0m                          log\u001b[38;5;241m=\u001b[39mlog, warn\u001b[38;5;241m=\u001b[39mwarn, warmstart\u001b[38;5;241m=\u001b[39mwarmstart,\n\u001b[1;32m    327\u001b[0m                          \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    328\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m method\u001b[38;5;241m.\u001b[39mlower() \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msinkhorn_log\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m--> 329\u001b[0m     res \u001b[38;5;241m=\u001b[39m \u001b[43msinkhorn_log\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mM\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnumItermax\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnumItermax\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    330\u001b[0m \u001b[43m                       \u001b[49m\u001b[43mstopThr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstopThr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    331\u001b[0m \u001b[43m                       
\u001b[49m\u001b[43mlog\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwarn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarmstart\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwarmstart\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    332\u001b[0m \u001b[43m                       \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    333\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m method\u001b[38;5;241m.\u001b[39mlower() \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msinkhorn_stabilized\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m    334\u001b[0m     res \u001b[38;5;241m=\u001b[39m sinkhorn_stabilized(a, b, M, reg, numItermax\u001b[38;5;241m=\u001b[39mnumItermax,\n\u001b[1;32m    335\u001b[0m                               stopThr\u001b[38;5;241m=\u001b[39mstopThr, warmstart\u001b[38;5;241m=\u001b[39mwarmstart,\n\u001b[1;32m    336\u001b[0m                               verbose\u001b[38;5;241m=\u001b[39mverbose, log\u001b[38;5;241m=\u001b[39mlog, warn\u001b[38;5;241m=\u001b[39mwarn,\n\u001b[1;32m    337\u001b[0m                               \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m~/PycharmProjects/RW_project/lib/python3.10/site-packages/ot/bregman/_sinkhorn.py:726\u001b[0m, in \u001b[0;36msinkhorn_log\u001b[0;34m(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)\u001b[0m\n\u001b[1;32m    723\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m ii \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(numItermax):\n\u001b[1;32m    725\u001b[0m     v \u001b[38;5;241m=\u001b[39m logb \u001b[38;5;241m-\u001b[39m nx\u001b[38;5;241m.\u001b[39mlogsumexp(Mr \u001b[38;5;241m+\u001b[39m u[:, \u001b[38;5;28;01mNone\u001b[39;00m], \u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m--> 726\u001b[0m     u \u001b[38;5;241m=\u001b[39m loga \u001b[38;5;241m-\u001b[39m \u001b[43mnx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlogsumexp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    728\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m ii \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m10\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m    729\u001b[0m         \u001b[38;5;66;03m# we can speed up the process by checking for the error only all\u001b[39;00m\n\u001b[1;32m    730\u001b[0m         \u001b[38;5;66;03m# the 10th iterations\u001b[39;00m\n\u001b[1;32m    731\u001b[0m \n\u001b[1;32m    732\u001b[0m         \u001b[38;5;66;03m# compute right marginal tmp2= (diag(u)Kdiag(v))^T1\u001b[39;00m\n\u001b[1;32m    733\u001b[0m         tmp2 \u001b[38;5;241m=\u001b[39m nx\u001b[38;5;241m.\u001b[39msum(nx\u001b[38;5;241m.\u001b[39mexp(get_logT(u, v)), \u001b[38;5;241m0\u001b[39m)\n",
      "File \u001b[0;32m~/PycharmProjects/RW_project/lib/python3.10/site-packages/ot/backend.py:1235\u001b[0m, in \u001b[0;36mNumpyBackend.logsumexp\u001b[0;34m(self, a, axis)\u001b[0m\n\u001b[1;32m   1234\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlogsumexp\u001b[39m(\u001b[38;5;28mself\u001b[39m, a, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m-> 1235\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mspecial\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlogsumexp\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/PycharmProjects/RW_project/lib/python3.10/site-packages/scipy/special/_logsumexp.py:111\u001b[0m, in \u001b[0;36mlogsumexp\u001b[0;34m(a, axis, b, keepdims, return_sign)\u001b[0m\n\u001b[1;32m    109\u001b[0m     tmp \u001b[38;5;241m=\u001b[39m b \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mexp(a \u001b[38;5;241m-\u001b[39m a_max)\n\u001b[1;32m    110\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 111\u001b[0m     tmp \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexp\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43ma_max\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;66;03m# suppress warnings about log of zero\u001b[39;00m\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m np\u001b[38;5;241m.\u001b[39merrstate(divide\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m):\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Benchmark: for every repeat and translation length, sample fresh supports and record\n",
    "# wall-clock time and value of RW2-Sinkhorn (rw2_sk), Sinkhorn (sk) and exact EMD.\n",
    "# All result arrays are indexed [repeat_idx, gap_idx].\n",
    "# NOTE(review): assumes generate_dist samples from np.random's global state - confirm in funcs.py.\n",
    "np.random.seed(0)  # re-seed here so re-running this cell reproduces the same samples\n",
    "\n",
    "repeated_times = 2\n",
    "result_shape = (repeated_times, len(gap_list))\n",
    "time_sk = np.zeros(result_shape)\n",
    "time_sk_ROT = np.zeros(result_shape)\n",
    "time_emd = np.zeros(result_shape)\n",
    "\n",
    "w2_sk = np.zeros(result_shape)\n",
    "w2_sk_ROT = np.zeros(result_shape)\n",
    "w2_emd = np.zeros(result_shape)\n",
    "\n",
    "\n",
    "def _timed(solver, *args):\n",
    "    \"\"\"Call solver(*args); return (elapsed wall-clock seconds, solver's return value).\"\"\"\n",
    "    # perf_counter is monotonic and high-resolution, unlike time.time().\n",
    "    start = time.perf_counter()\n",
    "    value = solver(*args)\n",
    "    return time.perf_counter() - start, value\n",
    "\n",
    "\n",
    "total_iters = repeated_times * len(gap_list)\n",
    "\n",
    "with tqdm(total=total_iters, desc=\"Testing Sinkhorn/ROT/EMD\") as pbar:\n",
    "    for repeat_idx in range(repeated_times):\n",
    "        for gap_idx, gap in enumerate(gap_list):\n",
    "            # Source: Gaussian samples. Target: Uniform samples translated by `gap`\n",
    "            # along the last coordinate only, so the translation has length `gap`.\n",
    "            source_supports = generate_dist(\"Gaussian\", dim, source_size)\n",
    "            target_supports = generate_dist(\"Uniform\", dim, target_size)\n",
    "            # Previously vec had shape (1, dim), so `vec[-1] = gap` filled the whole row and\n",
    "            # shifted every coordinate (translation length gap * sqrt(dim)).\n",
    "            shift = np.zeros(dim)\n",
    "            shift[-1] = gap\n",
    "            # assumes supports are (n_points, dim); the shift broadcasts over the points\n",
    "            target_supports = target_supports + shift\n",
    "\n",
    "            # Uniform marginal weights in extended precision.\n",
    "            source_masses = np.ones(source_size, dtype=dtype_prec) / source_size\n",
    "            target_masses = np.ones(target_size, dtype=dtype_prec) / target_size\n",
    "            ot_inputs = (source_supports, target_supports, source_masses, target_masses)\n",
    "\n",
    "            # RW2 Sinkhorn (rw2_sk)\n",
    "            time_sk_ROT[repeat_idx, gap_idx], w2_sk_ROT[repeat_idx, gap_idx] = _timed(\n",
    "                rw2_sk, *ot_inputs, reg, stopThr)\n",
    "\n",
    "            # plain Sinkhorn (sk)\n",
    "            time_sk[repeat_idx, gap_idx], w2_sk[repeat_idx, gap_idx] = _timed(\n",
    "                sk, *ot_inputs, reg, stopThr)\n",
    "\n",
    "            # exact EMD (reference value for the error plots)\n",
    "            time_emd[repeat_idx, gap_idx], w2_emd[repeat_idx, gap_idx] = _timed(\n",
    "                exact_emd, *ot_inputs)\n",
    "\n",
    "            pbar.update(1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272aad3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Base name for the result pickle and the saved figures (the plotting cells need it).\n",
    "# `lamda` was never defined; the regularisation `reg` is assumed to be what the\n",
    "# \"lambda\" tag refers to - NOTE(review): confirm against existing .pkl file names.\n",
    "file_name = f\"{name_dist1}_vs_{name_dist2}_size:{source_size}_dim:{dim}_lambda:{reg}\"\n",
    "\n",
    "# Toggles for persisting results; both off by default so the notebook uses the arrays\n",
    "# computed above. Only load pickles you produced yourself (unpickling can run code).\n",
    "SAVE_RESULTS = False\n",
    "LOAD_RESULTS = False\n",
    "\n",
    "if SAVE_RESULTS:\n",
    "    data = [time_sk, time_sk_ROT, time_emd, w2_sk, w2_sk_ROT, w2_emd]\n",
    "    with open(file_name + \".pkl\", \"wb\") as file:\n",
    "        pickle.dump(data, file)\n",
    "\n",
    "if LOAD_RESULTS:\n",
    "    with open(file_name + \".pkl\", \"rb\") as file:\n",
    "        data = pickle.load(file)\n",
    "    [time_sk, time_sk_ROT, time_emd, w2_sk, w2_sk_ROT, w2_emd] = data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "850a58ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_mean_std_by_col(arr):\n",
    "    \"\"\"Return the per-column mean and standard deviation of a 2-D array.\n",
    "\n",
    "    Args:\n",
    "        arr: array-like of shape (n_rows, n_cols); in this notebook rows are\n",
    "            repeats and columns are translation lengths.\n",
    "\n",
    "    Returns:\n",
    "        (mean, std): two 1-D arrays of length n_cols. std uses NumPy's default\n",
    "        population formula (ddof=0).\n",
    "    \"\"\"\n",
    "    return np.mean(arr, axis=0), np.std(arr, axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d2f3efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean / std of wall-clock time across repeats, one value per translation length.\n",
    "(running_time_OT_mean, running_time_OT_std), (running_time_ROT_mean, running_time_ROT_std) = (\n",
    "    compute_mean_std_by_col(times) for times in (time_sk, time_sk_ROT)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5178c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "def element_wise_add(arr1, arr2):\n",
    "    \"\"\"Return arr1 + arr2 (element-wise with broadcasting for NumPy arrays).\n",
    "\n",
    "    Used below to build the mean +/- std band limits passed to fill_between.\n",
    "    Has no side effects: the former debug prints of both inputs were removed.\n",
    "    \"\"\"\n",
    "    return arr1 + arr2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85cd2495",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the global font size before creating the figure so every text element uses it.\n",
    "plt.rcParams['font.size'] = 25\n",
    "fontsize = 25\n",
    "fig, ax = plt.subplots(figsize=(8,8))\n",
    "x = gap_list\n",
    "\n",
    "# One line per method; the shaded band spans mean +/- one std across repeats.\n",
    "for series_mean, series_std, color, label, marker in [\n",
    "    (running_time_OT_mean, running_time_OT_std, 'r', \"Sinkhorn\", \"^\"),\n",
    "    (running_time_ROT_mean, running_time_ROT_std, 'b', \"RW2 Sinkhorn\", \"D\"),\n",
    "]:\n",
    "    band_upper = element_wise_add(series_mean, series_std)\n",
    "    band_lower = element_wise_add(series_mean, -series_std)\n",
    "    ax.plot(x, series_mean, color=color, label=label, marker=marker)\n",
    "    ax.fill_between(x, band_upper, band_lower, color=color, alpha=0.5)\n",
    "ax.legend(loc=\"lower left\")\n",
    "\n",
    "ax.set_xlabel(\"Length of translation s\", fontsize=fontsize)\n",
    "ax.set_ylabel(\"Running time(s)\", fontsize=fontsize)\n",
    "\n",
    "fig.savefig(f\"{file_name}_running_time.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51e21eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute deviation of each Sinkhorn estimate from the exact EMD value, per [repeat, gap].\n",
    "error_sk, error_sk_ROT = (np.abs(w2 - w2_emd) for w2 in (w2_sk, w2_sk_ROT))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "016a19cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean / std of the errors across repeats, one value per translation length.\n",
    "(error_OT_mean, error_OT_std), (error_ROT_mean, error_ROT_std) = (\n",
    "    compute_mean_std_by_col(err) for err in (error_sk, error_sk_ROT)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e4c29d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the global font size before creating the figure so every text element uses it.\n",
    "plt.rcParams['font.size'] = 25\n",
    "fontsize = 25\n",
    "fig, ax = plt.subplots(figsize=(8,8))\n",
    "x = gap_list\n",
    "\n",
    "# One line per method; the shaded band spans mean +/- one std of the error across repeats.\n",
    "for series_mean, series_std, color, label, marker in [\n",
    "    (error_OT_mean, error_OT_std, 'r', \"Sinkhorn\", \"^\"),\n",
    "    (error_ROT_mean, error_ROT_std, 'b', \"RW2 Sinkhorn\", \"D\"),\n",
    "]:\n",
    "    band_upper = element_wise_add(series_mean, series_std)\n",
    "    band_lower = element_wise_add(series_mean, -series_std)\n",
    "    ax.plot(x, series_mean, color=color, label=label, marker=marker)\n",
    "    ax.fill_between(x, band_upper, band_lower, color=color, alpha=0.5)\n",
    "ax.legend(loc=\"upper left\")\n",
    "\n",
    "ax.set_xlabel(\"Length of translation s\", fontsize=fontsize)\n",
    "ax.set_ylabel(\"Computation errors\", fontsize=fontsize)\n",
    "\n",
    "fig.savefig(f\"{file_name}_errors.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c53e33b9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "563d87a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "\n",
    "# Standalone illustration: density of the standard normal N(0, 1) on [-4, 4].\n",
    "x = np.linspace(-4, 4, 1000)\n",
    "pdf = norm.pdf(x, loc=0, scale=1)\n",
    "\n",
    "fig_pdf, ax_pdf = plt.subplots(figsize=(10, 6))\n",
    "# NOTE: the label only appears if a legend is drawn (e.g. ax_pdf.legend()); none is.\n",
    "ax_pdf.plot(x, pdf, label='Standard Gaussian Distribution', color='blue', linewidth=5)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba01806c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "\n",
    "# Same standard-normal density illustration as the previous cell, drawn in red.\n",
    "x = np.linspace(-4, 4, 1000)\n",
    "pdf = norm.pdf(x, loc=0, scale=1)\n",
    "\n",
    "fig_pdf, ax_pdf = plt.subplots(figsize=(10, 6))\n",
    "# NOTE: the label only appears if a legend is drawn (e.g. ax_pdf.legend()); none is.\n",
    "ax_pdf.plot(x, pdf, label='Standard Gaussian Distribution', color='red', linewidth=5)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7c3584f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a82a237-f733-48b6-8c87-f833193f2659",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
