{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f31bb1ac-5eaa-4759-a3aa-51aaae6227b8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import ot\n",
    "\n",
    "\n",
    "os.chdir(\".\")\n",
    "\n",
    "from lib.opt import *\n",
    "from lib.gromov import *   \n",
    "\n",
    "import numpy as np \n",
    "import numba as nb\n",
    "import warnings\n",
    "import time\n",
    "from ot.backend import get_backend, NumpyBackend\n",
    "from ot.lp import emd\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4d977d9a-58cd-42ba-b9b9-0abc23487470",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@nb.njit(cache=True)\n",
    "def tensor_dot_param(C1,C2,Lambda=0,loss='square_loss'):\n",
    "    if loss=='square_loss':\n",
    "        def f1(r1):\n",
    "            return r1**2-2*Lambda\n",
    "        def f2(r2):\n",
    "            return r2**2\n",
    "        def h1(r1):\n",
    "            return r1\n",
    "        def h2(r2):\n",
    "            return 2*r2\n",
    "    # else:\n",
    "    #     warnings.warn(\"loss function error\")\n",
    "\n",
    "    fC1=f1(C1)\n",
    "    fC2=f2(C2)\n",
    "    hC1=h1(C1)\n",
    "    hC2=h2(C2)\n",
    "    \n",
    "    return fC1,fC2,hC1,hC2\n",
    "\n",
    "#@nb.njit(cache=True)\n",
    "def tensor_dot_func(fC1,fC2,hC1,hC2,Gamma):\n",
    "    #Gamma=np.ascontiguousarray(Gamma)\n",
    "    n,m=Gamma.shape\n",
    "    #Gamma_1=\n",
    "    #Gamma_2=)\n",
    "    C1=fC1.dot(Gamma.sum(1).reshape((-1,1))) #.dot(np.ones((1,m)))\n",
    "    C2=Gamma.sum(0).dot(fC2.T)\n",
    "    tensor_dot=(C1+C2)-hC1.dot(Gamma).dot(hC2.T) \n",
    "    return tensor_dot\n",
    "\n",
    "#@nb.njit(cache=True)\n",
    "def gwgrad_partial1(C1, C2, T,loss='square'):\n",
    "    \"\"\"Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>`\n",
    "    as the marginals may not sum to 1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    C1: array of shape (n_p,n_p)\n",
    "        intra-source (P) cost matrix\n",
    "\n",
    "    C2: array of shape (n_u,n_u)\n",
    "        intra-target (U) cost matrix\n",
    "\n",
    "    T : array of shape(n_p+nb_dummies, n_u) (default: None)\n",
    "        Transport matrix\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.array of shape (n_p+nb_dummies, n_u)\n",
    "        gradient\n",
    "\n",
    "\n",
    "    .. _references-gwgrad-partial:\n",
    "    References\n",
    "    ----------\n",
    "    .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,\n",
    "        \"Gromov-Wasserstein averaging of kernel and distance matrices.\"\n",
    "        International Conference on Machine Learning (ICML). 2016.\n",
    "    \"\"\"\n",
    "    #T=np.ascontiguousarray(T)\n",
    "    if loss=='square':\n",
    "        cC1 = np.dot(C1 ** 2 , np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1)))\n",
    "        cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2 ** 2 )\n",
    "        constC = cC1 + cC2\n",
    "        A = -2*np.dot(C1, T).dot(C2.T)\n",
    "        tens = constC + A\n",
    "    elif loss=='dot':\n",
    "        constC=0\n",
    "        A = -2*np.dot(C1, T).dot(C2.T)\n",
    "        tens = constC + A\n",
    "    return tens \n",
    "\n",
    "@nb.njit(cache=True,fastmath=True)\n",
    "def tensor_dot_orig(C1,C2,gamma,Lambda=0.0):\n",
    "    n,m=C1.shape[0],C2.shape[0]\n",
    "    tens=np.zeros((n,m))\n",
    "    for i in range(n):\n",
    "        for j in range(m):\n",
    "            for i1 in range(n):\n",
    "                for j1 in range(m):\n",
    "                    M_iji1ji=(C1[i,i1]-C2[j,j1])**2-2*Lambda\n",
    "                    tens[i,j]+=M_iji1ji*gamma[i1,j1]\n",
    "    return tens\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d8538a1-87ce-472f-8ab5-f1bb0c945d8d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 33.3 s, sys: 19.5 s, total: 52.8 s\n",
      "Wall time: 848 ms\n",
      "CPU times: user 30.7 s, sys: 7.29 s, total: 38 s\n",
      "Wall time: 593 ms\n",
      "0.0017386495550175327\n"
     ]
    }
   ],
   "source": [
    "n=5000\n",
    "X1=np.random.rand(n,2)\n",
    "X2=np.random.rand(n,2)\n",
    "C1=cost_matrix_d(X1,X1)\n",
    "C2=cost_matrix_d(X2,X2)\n",
    "gamma=np.random.rand(n,n)\n",
    "\n",
    "deltaG=np.random.rand(n,n)\n",
    "\n",
    "Lambda=10.0\n",
    "fC1,fC2,hC1,hC2=tensor_dot_param(C1,C2,Lambda=Lambda)\n",
    "\n",
    "\n",
    "%time M_circ_G1=gwgrad_partial1(C1, C2, gamma)-2*Lambda*gamma.sum()\n",
    "%time M_circ_G2=tensor_dot_func(fC1,fC2,hC1,hC2,gamma)\n",
    "#%time M_circ_G3=tensor_dot_orig(C1,C2,gamma,Lambda=Lambda)\n",
    "print(np.linalg.norm(M_circ_G1-M_circ_G2))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cf52603e-efc5-4c94-b4f1-66d121aa6695",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n is 5\n",
      "time v2 is 0.003100872039794922\n",
      "time prim is 0.003022432327270508\n",
      "time v1 is 0.0019388198852539062\n",
      "n is 10n\n",
      "time v2 is 0.0011029243469238281\n",
      "time prim is 0.0015728473663330078\n",
      "time v1 is 0.0006184577941894531\n",
      "n is 100\n",
      "time v2 is 0.01756596565246582\n",
      "time prim is 0.014112472534179688\n",
      "time v1 is 0.01593160629272461\n",
      "n is 500\n",
      "time v2 is 0.8414745330810547\n",
      "time prim is 0.43082427978515625\n",
      "time v1 is 0.2852790355682373\n",
      "n is 1000\n",
      "time v2 is 1.8139047622680664\n",
      "time prim is 1.8896005153656006\n",
      "time v1 is 1.5916125774383545\n",
      "sinkhorn\r"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 71\u001b[0m\n\u001b[1;32m     69\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msinkhorn\u001b[39m\u001b[38;5;124m'\u001b[39m,end\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\r\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     70\u001b[0m time1\u001b[38;5;241m=\u001b[39mtime\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m---> 71\u001b[0m ot\u001b[38;5;241m.\u001b[39mpartial\u001b[38;5;241m.\u001b[39mentropic_partial_gromov_wasserstein(C1, C2, p, q, reg\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m, m\u001b[38;5;241m=\u001b[39mmass, G0\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, numItermax\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, tol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-05\u001b[39m, log\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m     72\u001b[0m time2\u001b[38;5;241m=\u001b[39mtime\u001b[38;5;241m.\u001b[39mtime()            \n\u001b[1;32m     73\u001b[0m pgw_sinkhorn[n_idx,Lambda_idx,repeat]\u001b[38;5;241m=\u001b[39mtime2\u001b[38;5;241m-\u001b[39mtime1\n",
      "File \u001b[0;32m~/miniconda3/envs/opt/lib/python3.11/site-packages/ot/partial.py:1055\u001b[0m, in \u001b[0;36mentropic_partial_gromov_wasserstein\u001b[0;34m(C1, C2, p, q, reg, m, G0, numItermax, tol, log, verbose)\u001b[0m\n\u001b[1;32m   1053\u001b[0m Gprev \u001b[38;5;241m=\u001b[39m G0\n\u001b[1;32m   1054\u001b[0m M_entr \u001b[38;5;241m=\u001b[39m gwgrad_partial(C1, C2, G0)\n\u001b[0;32m-> 1055\u001b[0m G0 \u001b[38;5;241m=\u001b[39m entropic_partial_wasserstein(p, q, M_entr, reg, m)\n\u001b[1;32m   1056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cpt \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m10\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:  \u001b[38;5;66;03m# to speed up the computations\u001b[39;00m\n\u001b[1;32m   1057\u001b[0m     err \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mnorm(G0 \u001b[38;5;241m-\u001b[39m Gprev)\n",
      "File \u001b[0;32m~/miniconda3/envs/opt/lib/python3.11/site-packages/ot/partial.py:902\u001b[0m, in \u001b[0;36mentropic_partial_wasserstein\u001b[0;34m(a, b, M, reg, m, numItermax, stopThr, verbose, log)\u001b[0m\n\u001b[1;32m    900\u001b[0m K2prev \u001b[38;5;241m=\u001b[39m K2\n\u001b[1;32m    901\u001b[0m K2 \u001b[38;5;241m=\u001b[39m K2 \u001b[38;5;241m*\u001b[39m q3\n\u001b[0;32m--> 902\u001b[0m K \u001b[38;5;241m=\u001b[39m K2 \u001b[38;5;241m*\u001b[39m (m \u001b[38;5;241m/\u001b[39m nx\u001b[38;5;241m.\u001b[39msum(K2))\n\u001b[1;32m    903\u001b[0m q3 \u001b[38;5;241m=\u001b[39m q3 \u001b[38;5;241m*\u001b[39m K2prev \u001b[38;5;241m/\u001b[39m K\n\u001b[1;32m    905\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m nx\u001b[38;5;241m.\u001b[39many(nx\u001b[38;5;241m.\u001b[39misnan(K)) \u001b[38;5;129;01mor\u001b[39;00m nx\u001b[38;5;241m.\u001b[39many(nx\u001b[38;5;241m.\u001b[39misinf(K)):\n",
      "File \u001b[0;32m~/miniconda3/envs/opt/lib/python3.11/site-packages/ot/backend.py:1012\u001b[0m, in \u001b[0;36mNumpyBackend.sum\u001b[0;34m(self, a, axis, keepdims)\u001b[0m\n\u001b[1;32m   1009\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1010\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39meye(N, M, dtype\u001b[38;5;241m=\u001b[39mtype_as\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[0;32m-> 1012\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msum\u001b[39m(\u001b[38;5;28mself\u001b[39m, a, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, keepdims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m   1013\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39msum(a, axis, keepdims\u001b[38;5;241m=\u001b[39mkeepdims)\n\u001b[1;32m   1015\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcumsum\u001b[39m(\u001b[38;5;28mself\u001b[39m, a, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# run time \n",
    "import time\n",
    "start=0\n",
    "n_list=np.array([5,10,100,500,1000,2000,5000,10000])\n",
    "\n",
    "n_list=np.array(n_list,dtype=np.int64)\n",
    "repeat_num=5\n",
    "Lambda_list=np.array([0.2,1.0,10.0])\n",
    "pgw_v1=np.zeros((len(n_list),len(Lambda_list),repeat_num))\n",
    "pgw_v2=np.zeros((len(n_list),len(Lambda_list),repeat_num))\n",
    "mpgw=np.zeros((len(n_list),len(Lambda_list),repeat_num))\n",
    "pgw_sinkhorn=np.zeros((len(n_list),len(Lambda_list),repeat_num))\n",
    "\n",
    "for (n_idx,n) in enumerate(n_list[start:]):\n",
    "    n_idx+=start\n",
    "    print('n is',n)\n",
    "    for (Lambda_idx,Lambda) in enumerate(Lambda_list):\n",
    "        for repeat in range(repeat_num):\n",
    "            print('repeat',repeat,end='\\r')\n",
    "            #print('repeat',repeat)\n",
    "            np.random.seed(1)\n",
    "            m=n+10\n",
    "            X=np.random.rand(n,1)*2\n",
    "            Y=np.random.rand(m,1)*2\n",
    "            C1=cost_matrix_d(X,X)\n",
    "            C2=cost_matrix_d(Y,Y)\n",
    "            \n",
    "            p=np.ones(n)\n",
    "            p=p/m\n",
    "            q=np.ones(m)\n",
    "            q=q/m\n",
    "            \n",
    "            print('v2',end='\\r')\n",
    "            time1=time.time()\n",
    "            Gamma2=partial_gromov_ver2(C1, C2, p, q, Lambda=Lambda, nb_dummies=1, G0=None,thres=1, numItermax=100*n, tol=1e-5,log=False, verbose=False)\n",
    "            time2=time.time()\n",
    "            pgw_v2[n_idx,Lambda_idx,repeat]=time2-time1\n",
    "            if repeat==0 and Lambda_idx ==0:\n",
    "                print('time v2 is', time2-time1)\n",
    "\n",
    "\n",
    "           \n",
    "            \n",
    "            mass=np.sum(Gamma2)\n",
    "            mass=np.min((mass,p.sum(),q.sum()))\n",
    "\n",
    "            print('mpgw',end='\\r')\n",
    "            time1=time.time()\n",
    "            Gamma3=partial_gromov_wasserstein(C1, C2, p, q, m=mass, nb_dummies=1, G0=None,thres=1, numItermax=100*n, tol=1e-5,log=False, verbose=False)\n",
    "            time2=time.time()\n",
    "            mpgw[n_idx,Lambda_idx,repeat]=time2-time1\n",
    "            if repeat==0 and Lambda_idx==0:\n",
    "                print('time prim is', time2-time1)\n",
    "                #print('iter_num is',iter_num)\n",
    "            \n",
    "            print('v1',end='\\r')\n",
    "            time1=time.time()\n",
    "            Gamma1=partial_gromov_ver1(C1, C2, p, q, Lambda=Lambda, nb_dummies=1, G0=None,thres=1, numItermax=100*n, tol=1e-5,log=False, verbose=False)\n",
    "            time2=time.time()\n",
    "            pgw_v1[n_idx,Lambda_idx,repeat]=time2-time1\n",
    "            if repeat==0 and Lambda_idx ==0:\n",
    "                print('time v1 is', time2-time1)\n",
    "                #print('iter number ',iter_num)\n",
    "\n",
    "\n",
    "\n",
    "            if n<2000:\n",
    "                print('sinkhorn',end='\\r')\n",
    "                time1=time.time()\n",
    "                ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg=0.1, m=mass, G0=None, numItermax=1000, tol=1e-05, log=False, verbose=False)\n",
    "                time2=time.time()            \n",
    "                pgw_sinkhorn[n_idx,Lambda_idx,repeat]=time2-time1\n",
    "            \n",
    "            np.savez(\"run_time/time_list.npz\", pgw_v1=pgw_v1, pgw_v2=pgw_v2, mpgw=mpgw, pgw_sinkhorn=pgw_sinkhorn)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab1fd58-45c6-4a15-ac13-00aa3db0dc63",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# wall clock time test \n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "n_list=np.array([5,10,100,500,1000,2000,5000,10000])\n",
    "#n_list=[5e3,1e4]\n",
    "n_list=np.array(n_list,dtype=np.int64)\n",
    "repeat_num=3\n",
    "Lambda_list=np.array([0.2,1.0,10.0])\n",
    "\n",
    "dists = np.load(\"bone_star_dist_res.npz\")\n",
    "pgw_v1, pgw_v2, mpgw, pgw_sinkhorn = dists.values()\n",
    "    \n",
    "\n",
    "pgw_v1_mean=np.mean(pgw_v1,2)\n",
    "pgw_v2_mean=np.mean(pgw_v2,2)\n",
    "mpgw_mean=np.mean(mpgw,2) \n",
    "pgw_sinkhorn_mean=np.mean(pgw_sinkhorn,2) \n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "# Plotting both error lists\n",
    "for (method,color,time_mean) in zip(['v1','v2','m','s'],['orange','green','blue','pink'],[pgw_v1_mean,pgw_v2_mean,mpgw_mean,pgw_sinkhorn_mean]):\n",
    "    for (Lambda_idx,Lambda,style) in zip(np.array([0,2]),Lambda_list[[0,2]],['-',':']):\n",
    "        if method !='s':\n",
    "            ax.semilogy(n_list[:], time_mean[:,Lambda_idx],color=color,linestyle=style,alpha=0.5,label=method+f',$\\lambda={Lambda}$')\n",
    "        else:\n",
    "            ax.semilogy(n_list[0:5], time_mean[0:5,Lambda_idx],color=color,alpha=0.5,linestyle=style,label=method+f',$\\lambda={Lambda}$')\n",
    "\n",
    "\n",
    "ax.set_xlabel('n: size of $p$',fontsize=20)\n",
    "ax.set_ylabel('wall-clock time',fontsize=20)\n",
    "#ax.set_title('Plot of Relative Error vs Size of $\\mu$ for Different $\\lambda$ Values')\n",
    "ax.legend(bbox_to_anchor=(.75, 0.65),fontsize=15, loc=\"upper left\")\n",
    "plt.savefig('run_time/time.png',dpi=200)\n",
    "plt.savefig('run_time/time.jpg',dpi=200)\n",
    "plt.savefig('rin_time/time.pdf',dpi=200)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad38878a-8f52-4625-8b7c-caffc0833c61",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
