{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "327b1057",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import ot \n",
    "import scipy as sp\n",
    "from risgw import risgw_gpu\n",
    "from sgw_numpy import sgw_cpu\n",
    "import torch\n",
    "from SEINT.SEINT_numpy import SEINT\n",
    "\n",
    "      \n",
    "\n",
    "def RISGW(X, Y):\n",
    "    \"\"\"Evaluate `risgw_gpu` on X and Y, passed as float32 CPU tensors.\n",
    "\n",
    "    The projection count is round(10 * log10(n)), where n is the number\n",
    "    of rows of X. X must be 2-D (its shape is unpacked into two values).\n",
    "    \"\"\"\n",
    "    device = 'cpu'\n",
    "    n_samples, _ = np.shape(X)\n",
    "    n_projections = round(np.log10(n_samples) * 10)\n",
    "    X_tensor = torch.from_numpy(X).to(torch.float32).to(device)\n",
    "    Y_tensor = torch.from_numpy(Y).to(torch.float32).to(device)\n",
    "    return risgw_gpu(X_tensor, Y_tensor, device=device, nproj=n_projections)\n",
    "\n",
    "def GW(X, Y):\n",
    "    \"\"\"Gromov-Wasserstein value between the point clouds X and Y.\n",
    "\n",
    "    Builds the pairwise Euclidean distance matrix of each cloud, rescales\n",
    "    each by its largest entry, and returns the \"gw_dist\" entry of the log\n",
    "    produced by `ot.gromov.gromov_wasserstein` with uniform weights and\n",
    "    the \"square_loss\" loss.\n",
    "    \"\"\"\n",
    "    # Import the submodule explicitly: `import scipy as sp` on its own does\n",
    "    # not guarantee that `sp.spatial` has been loaded.\n",
    "    from scipy.spatial.distance import cdist\n",
    "\n",
    "    C3 = cdist(X, X)\n",
    "    C4 = cdist(Y, Y)\n",
    "    # Rescale each matrix to [0, 1]. When all points coincide the maximum\n",
    "    # is 0 and dividing would fill the matrix with NaNs, so skip it.\n",
    "    for C in (C3, C4):\n",
    "        scale = C.max()\n",
    "        if scale > 0:\n",
    "            C /= scale\n",
    "    p = ot.unif(len(X))\n",
    "    q = ot.unif(len(Y))\n",
    "    _, log = ot.gromov.gromov_wasserstein(\n",
    "        C3, C4, p, q, \"square_loss\", verbose=False, log=True)\n",
    "    return log[\"gw_dist\"]\n",
    "\n",
    "def SEINT_Func(X, Y):\n",
    "    \"\"\"Call `SEINT` on X and Y with maxed=True, set_seed=True and acc=True.\n",
    "\n",
    "    The `rep` argument is round(10 * log10(n)), where n is the number of\n",
    "    rows of X. X must be 2-D.\n",
    "    \"\"\"\n",
    "    n_samples, _ = np.shape(X)\n",
    "    repetitions = round(np.log10(n_samples) * 10)\n",
    "    return SEINT(X, Y, maxed=True, set_seed=True, acc=True, rep=repetitions)\n",
    "\n",
    "def ISEINT(X, Y):\n",
    "    \"\"\"Call `SEINT` on X and Y with maxed=False, set_seed=True and acc=True.\n",
    "\n",
    "    The `rep` argument is round(10 * log10(n)), where n is the number of\n",
    "    rows of X. X must be 2-D.\n",
    "    \"\"\"\n",
    "    n_samples, _ = np.shape(X)\n",
    "    repetitions = round(np.log10(n_samples) * 10)\n",
    "    return SEINT(X, Y, maxed=False, set_seed=True, acc=True, rep=repetitions)\n",
    "def sinkhorn(X, Y, reg=1e-5, p=2):\n",
    "    \"\"\"Entropy-regularised transport cost between X and Y via `ot.sinkhorn2`.\n",
    "\n",
    "    Uses uniform weights on both clouds and the Euclidean distance raised\n",
    "    to the power p as ground cost; `reg` is forwarded to `ot.sinkhorn2`.\n",
    "    Returns the resulting value raised to the power 1/p.\n",
    "    \"\"\"\n",
    "    weights_x = ot.unif(len(X))\n",
    "    weights_y = ot.unif(len(Y))\n",
    "    cost = ot.dist(X, Y, metric='euclidean') ** p\n",
    "    regularised_cost = ot.sinkhorn2(weights_x, weights_y, cost, reg)\n",
    "    return regularised_cost ** (1.0 / p)\n",
    "\n",
    "\n",
    "def emd(X, Y, p=2):\n",
    "    \"\"\"Transport cost between X and Y from the plan returned by `ot.emd`.\n",
    "\n",
    "    Uses uniform weights on both clouds and the Euclidean distance raised\n",
    "    to the power p as ground cost. Returns (sum of plan * cost) ** (1/p).\n",
    "    \"\"\"\n",
    "    weights_x = ot.unif(len(X))\n",
    "    weights_y = ot.unif(len(Y))\n",
    "    cost = ot.dist(X, Y, metric='euclidean') ** p\n",
    "    plan = ot.emd(weights_x, weights_y, cost)\n",
    "    return np.sum(plan * cost) ** (1.0 / p)\n",
    "\n",
    "\n",
    "def SGW(X,Y):\n",
    "    \"\"\"Call `sgw_cpu` on X and Y with round(10 * log10(n)) projections.\n",
    "\n",
    "    n is the number of rows of X; X must be 2-D.\n",
    "    \"\"\"\n",
    "    n_samples, _ = np.shape(X)\n",
    "    n_projections = round(np.log10(n_samples) * 10)\n",
    "    return sgw_cpu(X, Y, nproj=n_projections)\n",
    "\n",
    "\n",
    "# Registry of the compared methods: display name -> callable taking (X, Y).\n",
    "# Insertion order is the order in which methods are timed and plotted.\n",
    "METHODS = dict(\n",
    "    GW=GW,\n",
    "    RISGW=RISGW,\n",
    "    SGW=SGW,\n",
    "    SINKHORN=sinkhorn,\n",
    "    EMD=emd,\n",
    "    SEINT=SEINT_Func,\n",
    "    ISEINT=ISEINT,\n",
    ")\n",
    "\n",
    "\n",
    "def make_cov_matrices(theta: float, d: int):\n",
    "    \"\"\"\n",
    "    Σ_X  = diag(3 I₂ , I_{d-2})\n",
    "    Σ_Y  = diag(3 I₂ + 3 θ B₂ , I_{d-2}),  B₂ = [[0,1],[1,0]]\n",
    "    \"\"\"\n",
    "\n",
    "    Sigma_X = np.eye(d)\n",
    "    Sigma_X[:2, :2] = 3 * np.eye(2)\n",
    "\n",
    "    B2 = np.array([[0., 1.],\n",
    "                   [1., 0.]])\n",
    "    Sigma_block = 3 * np.eye(2) + 3 * theta * B2\n",
    "    Sigma_Y = np.eye(d)\n",
    "    Sigma_Y[:2, :2] = Sigma_block\n",
    "    return Sigma_X, Sigma_Y\n",
    "\n",
    "def sample_gaussian(cov: np.ndarray, n: int,seed):\n",
    "    np.random.seed(seed)\n",
    "    return np.random.multivariate_normal(mean=np.zeros(cov.shape[0]), cov=cov, size=n)\n",
    "\n",
    "\n",
    "def experiment_n_time(n_list, theta=0.1, d=4, n_trials=3):\n",
    "    \"\"\"\n",
    "    Time every method in METHODS on Gaussian samples of increasing size.\n",
    "\n",
    "    For each n in n_list, draws n_trials pairs (X, Y) of n samples each\n",
    "    from the covariances given by make_cov_matrices(theta, d), calls each\n",
    "    method on the pair, and records the elapsed time.perf_counter() time.\n",
    "\n",
    "    Return:\n",
    "        times_mean[method], times_std[method]  shape=(len(n_list),)\n",
    "    \"\"\"\n",
    "    times_mean = {m: [] for m in METHODS}\n",
    "    times_std  = {m: [] for m in METHODS}\n",
    "\n",
    "    Σ_X_full, Σ_Y_full = make_cov_matrices(theta, d)\n",
    "\n",
    "    for n in n_list:\n",
    "        trial_times = {m: [] for m in METHODS}\n",
    "\n",
    "        for trial in range(n_trials):\n",
    "            # X and Y get distinct seeds. Sharing one seed would make both\n",
    "            # samples transforms of the same underlying normal draws\n",
    "            # instead of independent samples.\n",
    "            seed_x = trial * 10\n",
    "            seed_y = seed_x + 1\n",
    "            X = sample_gaussian(Σ_X_full, n, seed_x)\n",
    "            Y = sample_gaussian(Σ_Y_full, n, seed_y)\n",
    "\n",
    "            for name, func in METHODS.items():\n",
    "                t0 = time.perf_counter()\n",
    "                func(X, Y)  # result is discarded; only the duration matters\n",
    "                t1 = time.perf_counter()\n",
    "                trial_times[name].append(t1 - t0)\n",
    "\n",
    "        for name in METHODS:\n",
    "            elapsed = np.asarray(trial_times[name])\n",
    "            times_mean[name].append(elapsed.mean())\n",
    "            times_std[name].append(elapsed.std())\n",
    "\n",
    "    return times_mean, times_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db69994e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings, shared by the run and the figure title so the two\n",
    "# cannot drift apart.\n",
    "THETA = 0.5\n",
    "DIM = 4\n",
    "# The lower edge of the shaded band is never drawn below this fraction of\n",
    "# the mean, which keeps log10 finite when the spread reaches the mean.\n",
    "BAND_FLOOR_FRACTION = 0.1\n",
    "\n",
    "# NOTE(review): GW, SINKHORN and EMD build dense n x n matrices; at\n",
    "# n = 10**5 that is about 80 GB in float64 -- confirm the largest sizes are\n",
    "# feasible on the target machine.\n",
    "n_vals = (10 ** np.arange(1.0, 5.1, 0.5)).astype(int)\n",
    "time_mean, time_std = experiment_n_time(n_vals, theta=THETA, d=DIM)\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "log_n = np.log10(n_vals)\n",
    "for name in METHODS:\n",
    "    m  = np.array(time_mean[name])\n",
    "    sd = np.array(time_std[name])\n",
    "    # m - sd can be zero or negative when sd >= m; clamp before taking log10\n",
    "    # so the band has no NaN gaps and no runtime warnings.\n",
    "    lower = np.maximum(m - sd, BAND_FLOOR_FRACTION * m)\n",
    "    plt.plot(log_n, np.log10(m), marker='o', label=name)\n",
    "    plt.fill_between(log_n, np.log10(lower), np.log10(m + sd), alpha=0.2)\n",
    "\n",
    "plt.xlabel(r'$\\log_{10}(n)$')\n",
    "plt.ylabel(r'$\\log_{10}(\\text{time / sec})$')\n",
    "plt.title(f'CPU Time versus $n$ (d={DIM}, $\\\\theta={THETA}$)')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "HW",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
