{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 3512,
     "status": "ok",
     "timestamp": 1706238155589,
     "user": {
      "displayName": "",
      "userId": "08293807456834214233"
     },
     "user_tz": 300
    },
    "id": "J8Z4eOQybIJr"
   },
   "outputs": [],
   "source": [
    "#import scanpy as sc\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "from scipy.spatial import distance\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score\n",
    "import numpy as np\n",
    "# import anndata as ad\n",
    "import json\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 231,
     "status": "ok",
     "timestamp": 1706238156635,
     "user": {
      "displayName": "",
      "userId": "08293807456834214233"
     },
     "user_tz": 300
    },
    "id": "Rmu7J-PFbMMr"
   },
   "outputs": [],
   "source": [
    "device='cpu'\n",
    "dtype=torch.float64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 199,
     "status": "ok",
     "timestamp": 1706238320242,
     "user": {
      "displayName": "",
      "userId": "08293807456834214233"
     },
     "user_tz": 300
    },
    "id": "9BdgviyppIIM"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "def random_simplex_sample(N, device='cpu', dtype=torch.float64):\n",
    "    d = torch.exp(torch.randn(N, device=device, dtype=dtype))\n",
    "    return d / torch.sum(d)\n",
    "\n",
    "def compute_grad_A(C, Q, R, Lambda, gamma, semiRelaxedLeft, semiRelaxedRight, device, Wasserstein=True, FGW=False, A=None, B=None, alpha=0.5):\n",
    "\n",
    "    r = Lambda.shape[0]\n",
    "    one_r = torch.ones((r), device=device, dtype=dtype)\n",
    "    One_rr = torch.outer(one_r, one_r).to(device)\n",
    "\n",
    "    if Wasserstein:\n",
    "        gradQ = C @ R @ Lambda.T\n",
    "        gradR = C.T @ Q @ Lambda\n",
    "\n",
    "    elif A is not None and B is not None:\n",
    "        if not semiRelaxedLeft and not semiRelaxedRight:\n",
    "            # Balanced gradient (Q1_r = a AND R1_r = b)\n",
    "            gradQ = - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T\n",
    "            gradR = - 4 * (B@R@Lambda.T)@(Q.T@A@Q)@Lambda\n",
    "        elif semiRelaxedRight:\n",
    "            # Semi-relaxed right marginal gradient (Q1_r = a)\n",
    "            gradQ = - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T\n",
    "            gradR = 2*B**2 @ R @ One_rr - 4*(B@R@Lambda.T)@(Q.T@A@Q)@Lambda\n",
    "        elif semiRelaxedLeft:\n",
    "            # Semi-relaxed left marginal gradient (R1_r = b)\n",
    "            gradQ = 2*A**2 @ Q @ One_rr - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T\n",
    "            gradR = - 4 * (B@R@Lambda.T)@(Q.T@A@Q)@Lambda\n",
    "        else:\n",
    "            # Fully unbalanced with no marginal constraints\n",
    "            gradQ = 2*A**2 @ Q @ One_rr - 4 * (A@Q)@Lambda@(R.T@B@R)@Lambda.T\n",
    "            gradR = 2*B**2 @ R @ One_rr - 4 * (B@R@Lambda.T)@(Q.T@A@Q)@Lambda\n",
    "        # Readjust cost for FGW problem\n",
    "        if FGW:\n",
    "            gradQ = (1-alpha)*(C @ R @ Lambda.T) + alpha*gradQ\n",
    "            gradR = (1-alpha)*(C.T @ Q @ Lambda) + alpha*gradR\n",
    "    else:\n",
    "        print(\"---Input either Wasserstein=True or provide distance matrices A and B for GW problem---\")\n",
    "\n",
    "    normalizer = torch.max(torch.tensor([torch.max(torch.abs(gradQ)) , torch.max(torch.abs(gradR))]))\n",
    "    gamma_k = gamma / normalizer\n",
    "\n",
    "    return gradQ, gradR, gamma_k\n",
    "\n",
    "def compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, device, Wasserstein=True, FGW=False, A=None, B=None, alpha=0.5):\n",
    "    if Wasserstein:\n",
    "        gradLambda = Q.T @ C @ R\n",
    "    else:\n",
    "        gradLambda = -4 * Q.T @ A @ Q @ Lambda @ R.T @ B @ R\n",
    "        if FGW:\n",
    "            gradLambda = (1-alpha)*(Q.T @ C @ R) + alpha*gradLambda\n",
    "    gradT = torch.diag(1/gQ) @ gradLambda @ torch.diag(1/gR) # (mass-reweighted form)\n",
    "    gamma_T = gamma / torch.max(torch.abs(gradT))\n",
    "    return gradT, gamma_T\n",
    "\n",
    "def semi_project_Left(xi1, g1, g, N1, r, gamma_k, tau, max_iter = 100, delta = 1e-9):\n",
    "\n",
    "    u = torch.ones((N1), device=device, dtype=dtype)\n",
    "    v = torch.ones((r), device=device, dtype=dtype)\n",
    "    u_tild = u\n",
    "    v_tild = v\n",
    "    i = 0\n",
    "    while i == 0 or (i < max_iter and\n",
    "                     gamma_k**-1 * torch.max(torch.tensor([torch.max(torch.log(u/u_tild)),torch.max(torch.log(v/v_tild))])) > delta ):\n",
    "        u_tild = u\n",
    "        v_tild = v\n",
    "        u = (g1 / (xi1 @ v))**(tau/(tau + gamma_k**-1 ))\n",
    "        v = (g / (xi1.T @ u))\n",
    "        i+=1\n",
    "\n",
    "    return u, v\n",
    "\n",
    "def semi_project_Right(xi2, g2, g, N2, r, gamma_k, tau, max_iter = 100, delta = 1e-9):\n",
    "\n",
    "    u = torch.ones((N2), device=device, dtype=dtype)\n",
    "    v = torch.ones((r), device=device, dtype=dtype)\n",
    "    u_tild = u\n",
    "    v_tild = v\n",
    "    i = 0\n",
    "    while i == 0 or (i < max_iter and\n",
    "                     gamma_k**-1 * torch.max(torch.tensor([torch.max(torch.log(u/u_tild)),torch.max(torch.log(v/v_tild))])) > delta ):\n",
    "        u_tild = u\n",
    "        v_tild = v\n",
    "        u = (g2 / (xi2 @ v))**(tau/(tau + gamma_k**-1 ))\n",
    "        v = (g / (xi2.T @ u))\n",
    "        i+=1\n",
    "\n",
    "    return u, v\n",
    "\n",
    "def semi_project_Balanced(xi1, g1, g, N1, r, gamma_k, tau, max_iter = 100, delta = 1e-9):\n",
    "    # Lax-inner marginal\n",
    "    u = torch.ones((N1), device=device, dtype=dtype)\n",
    "    v = torch.ones((r), device=device, dtype=dtype)\n",
    "    u_tild = u\n",
    "    v_tild = v\n",
    "    i = 0\n",
    "    while i == 0 or (i < max_iter and\n",
    "                     gamma_k**-1 * torch.max(torch.tensor([torch.max(torch.log(u/u_tild)),torch.max(torch.log(v/v_tild))])) > delta ):\n",
    "        u_tild = u\n",
    "        v_tild = v\n",
    "        v = (g / (xi1.T @ u))**(tau/(tau + gamma_k**-1 ))\n",
    "        u = (g1 / (xi1 @ v))\n",
    "        i+=1\n",
    "\n",
    "    return u, v\n",
    "\n",
    "def project_Unbalanced(xi1, g1, g, N1, r, gamma_k, tau, max_iter = 100, delta = 1e-9):\n",
    "    # Lax-inner marginal\n",
    "    u = torch.ones((N1), device=device, dtype=dtype)\n",
    "    v = torch.ones((r), device=device, dtype=dtype)\n",
    "    u_tild = u\n",
    "    v_tild = v\n",
    "    i = 0\n",
    "    while i == 0 or (i < max_iter and\n",
    "                     gamma_k**-1 * torch.max(torch.tensor([torch.max(torch.log(u/u_tild)),torch.max(torch.log(v/v_tild))])) > delta ):\n",
    "        u_tild = u\n",
    "        v_tild = v\n",
    "        v = (g / (xi1.T @ u))**(tau/(tau + gamma_k**-1 ))\n",
    "        u = (g1 / (xi1 @ v))**(tau/(tau + gamma_k**-1 ))\n",
    "        i+=1\n",
    "\n",
    "    return u, v\n",
    "\n",
    "def Sinkhorn(xi, a, b, N1, r, gamma_k, tau, max_iter = 300, delta = 1e-9):\n",
    "    u = torch.ones((N1), device=device, dtype=dtype)\n",
    "    v = torch.ones((r), device=device, dtype=dtype)\n",
    "    u_tild = u\n",
    "    v_tild = v\n",
    "    i = 0\n",
    "    '''\n",
    "    while i == 0 or (i < max_iter and\n",
    "                     gamma_k**-1 * torch.max(torch.tensor([torch.max(torch.log(u/u_tild)),torch.max(torch.log(v/v_tild))])) > delta ):\n",
    "    '''\n",
    "    while i < max_iter:\n",
    "        u_tild = u\n",
    "        v_tild = v\n",
    "        u = (a / (xi @ v))\n",
    "        v = (b / (xi.T @ u))\n",
    "        i+=1\n",
    "    return u, v\n",
    "\n",
    "def Cost(f, g, Grad, epsilon, device='cpu'):\n",
    "    '''\n",
    "    A matrix which is using for the broadcasted log-domain log-sum-exp trick-based updates.\n",
    "    ------Parameters------\n",
    "    f: torch.tensor (N1)\n",
    "        First dual variable of semi-unbalanced Sinkhorn\n",
    "    g: torch.tensor (N2)\n",
    "        Second dual variable of semi-unbalanced Sinkhorn\n",
    "    Grad: torch.tensor (N1 x N2)\n",
    "        A collection of terms in our gradient for the update\n",
    "    epsilon: float\n",
    "        Entropic regularization for Sinkhorn\n",
    "    device: 'str'\n",
    "        Device tensors placed on\n",
    "    '''\n",
    "    return -( Grad - torch.outer(f, torch.ones(Grad.size(dim=1), device=device)) - torch.outer(torch.ones(Grad.size(dim=0), device=device), g) ) / epsilon\n",
    "\n",
    "def logSinkhorn(grad, a, b, N1, r, gamma_k, tau, device='cpu', max_iter = 100, delta = 1e-9, epsilon=1):\n",
    "\n",
    "    f_k = torch.zeros((N1), device=device)\n",
    "    g_k = torch.zeros((r), device=device)\n",
    "\n",
    "    log_a = torch.log(a)\n",
    "    log_b = torch.log(b)\n",
    "\n",
    "    i = 0\n",
    "    '''\n",
    "    while i == 0 or (i < max_iter and\n",
    "                     gamma_k**-1 * torch.max(torch.tensor([torch.max(torch.log(u/u_tild)),torch.max(torch.log(v/v_tild))])) > delta ):\n",
    "    '''\n",
    "    while i < max_iter:\n",
    "        f_k = f_k + epsilon*(log_a - torch.logsumexp(Cost(f_k, g_k, grad, epsilon, device=device), axis=1))\n",
    "        g_k = g_k + epsilon*(log_b - torch.logsumexp(Cost(f_k, g_k, grad, epsilon, device=device), axis=0))\n",
    "        i+=1\n",
    "\n",
    "    return torch.exp(Cost(f_k, g_k, grad, epsilon, device=device))\n",
    "\n",
    "def LOT_iteration(C, A=None, B=None, tau=10, gamma=5, r = 10, max_iter=100, device='cpu', dtype=torch.float64, \\\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, changeOfMass=False, Wasserstein=True,\n",
    "                 printCost=True, swap=True, initA=False, FGW=False, alpha=0.5):\n",
    "    '''\n",
    "    REWRITE DESCRIPTION BELOW!\n",
    "\n",
    "    ------Parameters------\n",
    "    C: torch.tensor (N1 x N2)\n",
    "        A matrix of pairwise feature distances between transcript vectors in slice 1 and slice 2 (interslice).\n",
    "    A: torch.tensor (N1 x N1)\n",
    "        A matrix of pairwise distances between points in slice 1.\n",
    "    B: torch.tensor (N2 x N2)\n",
    "        A matrix of pairwise distances between points in slice 2.\n",
    "    Pi_0: torch.tensor (N1 x N2)\n",
    "        An initialization for the alignment matrix Pi. Should respect marginals of semi-relaxed or balanced.\n",
    "    alpha: float\n",
    "        A balance parameter between the interslice feature term of the objective and the merged feature-spatial term.\n",
    "    beta: float\n",
    "        A balance parameter between the GW (quartet) term and the triplet term of the merged feature-spatial term.\n",
    "    gamma: float\n",
    "        A hyperparameter controlling the strength of the KL-divergence term in the objective.\n",
    "    epsilon: float\n",
    "        A hyperparameter controlling the strength of the entropic regularization in the Sinkhorn algorithm.\n",
    "    max_iter: int\n",
    "        The maximal number of iterations DeST-OT is run.\n",
    "    balanced: bool\n",
    "        Boolean for whether to default to a balanced OT or to use semi-relaxed routine. Default set to False.\n",
    "    device: str\n",
    "        Device that torch tensors are placed on. Using GPU/'cuda' is much faster.\n",
    "    dtype: torch.type\n",
    "        The default datatype that the alignment and other tensors are in. Ideally torch.float64 or torch.float32\n",
    "    override_EDM: bool\n",
    "        Whether to override the merged feature-spatial matrix with a standard Euclidean distance matrix.\n",
    "    override_FDM:\n",
    "        Whether to override the merged feature-spatial matrix with an intraslice feature distance matrix.\n",
    "    '''\n",
    "    if C is not None:\n",
    "        N1, N2 = C.size(dim=0), C.size(dim=1)\n",
    "    else:\n",
    "        N1, N2 = A.size(dim=0), B.size(dim=0)\n",
    "\n",
    "    k = 0\n",
    "    stationarity_gap = torch.inf\n",
    "\n",
    "    one_N1 = torch.ones((N1), device=device, dtype=dtype)\n",
    "    one_N2 = torch.ones((N2), device=device, dtype=dtype)\n",
    "    one_r = torch.ones((r), device=device, dtype=dtype)\n",
    "\n",
    "    g1 = one_N1 / N1\n",
    "\n",
    "    if not changeOfMass:\n",
    "        g2 = one_N2 / N2\n",
    "    else:\n",
    "        g2 = one_N2 / N1\n",
    "\n",
    "    g = (1/r)*one_r\n",
    "    lambd = torch.min(torch.tensor([torch.min(g1), torch.min(g2), torch.min(g)])) / 2\n",
    "\n",
    "    a1 = random_simplex_sample(N1, device=device, dtype=dtype)\n",
    "    a2 = (g1 - lambd*a1)/(1 - lambd)\n",
    "    b1 = random_simplex_sample(N2, device=device, dtype=dtype)\n",
    "    b2 = (g2 - lambd*b1)/(1 - lambd)\n",
    "    g_1 = random_simplex_sample(r, device=device, dtype=dtype)\n",
    "    g_2 = (g - lambd*g_1)/(1 - lambd)\n",
    "\n",
    "    #Q = lambd*torch.outer(a1, g_1).to(device) + (1 - lambd)*torch.outer(a2, g_2).to(device)\n",
    "    #R = lambd*torch.outer(b1, g_1).to(device) + (1 - lambd)*torch.outer(b2, g_2).to(device)\n",
    "\n",
    "    if initA:\n",
    "        Q = lambd*torch.outer(a1, g_1).to(device) + (1 - lambd)*torch.outer(a2, g_2).to(device)\n",
    "        R = lambd*torch.outer(b1, g_1).to(device) + (1 - lambd)*torch.outer(b2, g_2).to(device)\n",
    "\n",
    "        gR, gQ = R.T @ one_N2, Q.T @ one_N1\n",
    "\n",
    "        T = (1-lambd)*torch.diag(g) + lambd*torch.outer(gR, gQ).to(device)\n",
    "        Lambda = torch.linalg.inv(T) # torch.diag(1/gQ) @ T @ torch.diag(1/gR)\n",
    "\n",
    "    else:\n",
    "        C_random = torch.rand((N1,r), dtype=torch.float64)\n",
    "        xi_random = torch.exp( -C_random )\n",
    "        u, v = Sinkhorn(xi_random, g1, g, r, r, gamma, tau)\n",
    "        Q = torch.diag(u) @ xi_random @ torch.diag(v)\n",
    "\n",
    "        C_random = torch.rand((N2,r), dtype=torch.float64)\n",
    "        xi_random = torch.exp( -C_random )\n",
    "        u, v = Sinkhorn(xi_random, g2, g, r, r, gamma, tau)\n",
    "        R = torch.diag(u) @ xi_random @ torch.diag(v)\n",
    "\n",
    "        gR, gQ = R.T @ one_N2, Q.T @ one_N1\n",
    "\n",
    "        # Generate a random (full-rank) matrix\n",
    "        C_random = torch.rand((r,r), dtype=torch.float64)\n",
    "        # Generate a random Kernel\n",
    "        xi_random = torch.exp( -C_random )\n",
    "        # Generate a random coupling between gQ and gR\n",
    "        u, v = Sinkhorn(xi_random, gQ, gR, r, r, gamma, tau)\n",
    "        T = torch.diag(u) @ xi_random @ torch.diag(v)\n",
    "        # Use this to form the inner inverse coupling\n",
    "        Lambda = torch.linalg.inv(T) #torch.diag(1/gQ) @ T @ torch.diag(1/gR)\n",
    "\n",
    "    #log_g1 = torch.log(g1)\n",
    "    #log_g2 = torch.log(g2)\n",
    "    errs = []\n",
    "\n",
    "    grad = torch.inf\n",
    "    gamma_k = gamma\n",
    "\n",
    "    while k < max_iter:\n",
    "        print(f'Iteration: {k}')\n",
    "\n",
    "#         if k % 50 == 0:\n",
    "#             print(f'Iteration: {k}')\n",
    "\n",
    "        gradQ, gradR, gamma_k = compute_grad_A(C, Q, R, Lambda, gamma, semiRelaxedLeft, semiRelaxedRight, device, Wasserstein=Wasserstein, A=A, B=B, FGW=FGW, alpha=alpha)\n",
    "        if semiRelaxedLeft:\n",
    "            xi1 = Q * torch.exp( -gamma_k * gradQ )\n",
    "            xi2 = R * torch.exp( -gamma_k * gradR )\n",
    "\n",
    "            _u, _v = semi_project_Balanced(xi2, g2, gR, N2, r, gamma_k, tau)\n",
    "            R = torch.diag(_u) @ xi2 @ torch.diag(_v)\n",
    "\n",
    "            u, v = project_Unbalanced(xi1, g1, gQ, N1, r, gamma_k, tau)\n",
    "            Q = torch.diag(u) @ xi1 @ torch.diag(v)\n",
    "\n",
    "            gQ, gR = Q.T @ one_N1, R.T @ one_N2\n",
    "            gradT, gamma_T = compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, device, Wasserstein=Wasserstein, A=A, B=B, FGW=FGW, alpha=alpha)\n",
    "            xi3 = T*torch.exp(- gamma_T * gradT )\n",
    "\n",
    "            if not swap:\n",
    "                # Lambda = diag(gQ)^-1 T diag(gR)^-1 form\n",
    "                u, v = Sinkhorn(xi3, gQ, gR, r, r, gamma_k, tau)\n",
    "                v = gR / (xi3.T @ u) # Last tightening of marginal\n",
    "            else:\n",
    "                # Swap marginals; Lambda = T^-1 form\n",
    "                u, v = Sinkhorn(xi3, gR, gQ, r, r, gamma_k, tau)\n",
    "                v = gR / (xi3.T @ u)\n",
    "        elif semiRelaxedRight:\n",
    "            xi1 = Q * torch.exp(- gamma_k * gradQ )\n",
    "            xi2 = R * torch.exp(- gamma_k * gradR )\n",
    "\n",
    "            u, v = semi_project_Balanced(xi1, g1, gQ, N1, r, gamma_k, tau)\n",
    "            Q = torch.diag(u) @ xi1 @ torch.diag(v)\n",
    "\n",
    "            _u, _v = project_Unbalanced(xi2, g2, gR, N2, r, gamma_k, tau)\n",
    "            R = torch.diag(_u) @ xi2 @ torch.diag(_v)\n",
    "\n",
    "            gQ, gR = Q.T @ one_N1, R.T @ one_N2\n",
    "            gradT, gamma_T = compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, device, Wasserstein=Wasserstein, A=A, B=B, FGW=FGW, alpha=alpha)\n",
    "            xi3 = T*torch.exp(- gamma_T * gradT )\n",
    "\n",
    "            if not swap:\n",
    "                 # Lambda = diag(gQ)^-1 T diag(gR)^-1 form\n",
    "                u, v = Sinkhorn(xi3, gQ, gR, r, r, gamma_k, tau)\n",
    "                u = gQ / (xi3 @ v) # Last tightening of marginal\n",
    "            else:\n",
    "                # Swap marginals; Lambda = T^-1 form\n",
    "                u, v = Sinkhorn(xi3, gR, gQ, r, r, gamma_k, tau)\n",
    "                u = gQ / (xi3 @ v)\n",
    "        else:\n",
    "            xi1 = Q * torch.exp(- gamma_k * gradQ )\n",
    "            xi2 = R * torch.exp(- gamma_k * gradR )\n",
    "\n",
    "            u, v = semi_project_Balanced(xi1, g1, gQ, N1, r, gamma_k, tau)\n",
    "            Q = torch.diag(u) @ xi1 @ torch.diag(v)\n",
    "\n",
    "            _u, _v = semi_project_Balanced(xi2, g2, gR, N2, r, gamma_k, tau)\n",
    "            R = torch.diag(_u) @ xi2 @ torch.diag(_v)\n",
    "\n",
    "            gQ, gR = Q.T @ one_N1, R.T @ one_N2\n",
    "            gradT, gamma_T = compute_grad_B(C, Q, R, Lambda, gQ, gR, gamma, device, Wasserstein=Wasserstein, A=A, B=B, FGW=FGW, alpha=alpha)\n",
    "            xi3 = T * torch.exp(- gamma_T * gradT )\n",
    "\n",
    "            if not swap:\n",
    "                # Lambda = diag(gQ)^-1 T diag(gR)^-1 form\n",
    "                u, v = Sinkhorn(xi3, gQ, gR, r, r, gamma_k, tau)\n",
    "            else:\n",
    "                # Swap marginals; Lambda = T^-1 form\n",
    "                u, v = Sinkhorn(xi3, gR, gQ, r, r, gamma_k, tau)\n",
    "        # Construct latent transition matrix\n",
    "        T = torch.diag(u) @ xi3 @ torch.diag(v)\n",
    "        # Inner latent transition-inverse matrix\n",
    "        Lambda = torch.diag(1/gQ) @ T @ torch.diag(1/gR)\n",
    "#         if Wasserstein:\n",
    "#             P = Q @ Lambda @ R.T\n",
    "#             cost = torch.sum(C * P)\n",
    "#         else:\n",
    "#             P = Q @ Lambda @ R.T\n",
    "#             M1 = Q.T @ A**2 @ Q\n",
    "#             M2 = R.T @ B**2 @ R\n",
    "#             cost = one_r.T @ M1 @ one_r + one_r.T @ M2 @ one_r -2*torch.trace((A @ P @ B).T @ P)\n",
    "#             #cost = one_N2.T @ P.T @ A**2 @ P @ one_N2 + one_N1.T @ P @ B**2 @ P.T @ one_N1 - 2*torch.trace((A @ P @ B).T @ P)\n",
    "#             if FGW:\n",
    "#                 cost = (1-alpha)*torch.sum(C * P) + alpha*cost\n",
    "#         errs.append(cost)\n",
    "        #print(f'OT cost (T): {cost},diff from g1:{torch.linalg.norm(P @ one_N2 - g1)},diff from g2:{torch.linalg.norm(P.T @ one_N1 - g2)}')\n",
    "        k+=1\n",
    "        #print(\"Iteration finished\")\n",
    "\n",
    "    # Plotting OT objective value across iterations\n",
    "#     plt.plot(range(len(errs)), errs)\n",
    "#     plt.show()\n",
    "\n",
    "    # plt.imshow(T)\n",
    "    # plt.show()\n",
    "\n",
    "    # if semiRelaxedLeft or semiRelaxedRight:\n",
    "    #     P = Q @ Lambda @ R.T\n",
    "    #     # return P, errs\n",
    "    # else:\n",
    "    #     P = Q @ Lambda @ R.T\n",
    "    #     # return P, errs\n",
    "    return Q, Lambda, R"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "# Low-rank Wasserstein (W) alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 436
    },
    "executionInfo": {
     "elapsed": 158381,
     "status": "error",
     "timestamp": 1706238500274,
     "user": {
      "displayName": "",
      "userId": "08293807456834214233"
     },
     "user_tz": 300
    },
    "id": "dTNju5i6Ke_H",
    "outputId": "1e0309b5-1b4c-42e9-dceb-75efb0c1134b",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def pred_expression(Q, Lambda, R, f, m):\n",
    "    temp = Q.T @ f\n",
    "    temp = Lambda.T @ temp\n",
    "    temp = R @ temp\n",
    "    return m * temp\n",
    "\n",
    "\n",
    "def pred_type(Q, Lambda, R, matrix, m):\n",
    "    temp = Q.T @ matrix\n",
    "    temp = Lambda.T @ temp\n",
    "    temp = R @ temp\n",
    "    return m * temp\n",
    "\n",
    "def scale_matrix_rows(matrix):\n",
    "    # Calculate the L2 norm for each row\n",
    "    norms = np.linalg.norm(matrix, axis=1)\n",
    "\n",
    "    # Find the maximum norm\n",
    "    max_norm = np.max(norms)\n",
    "\n",
    "    # Avoid division by zero\n",
    "    if max_norm == 0:\n",
    "        return matrix\n",
    "\n",
    "    # Scale each row\n",
    "    matrix_scaled = matrix / max_norm\n",
    "\n",
    "    return matrix_scaled\n",
    "\n",
    "gamma = 50\n",
    "max_iter = 50\n",
    "def LRW(tau_list, rank_list, validation_gene_list, test_gene_list):\n",
    "    with open('/PATH/TO/mouse_embryo/slice1_marker_expr.json', 'r') as file:\n",
    "        slice1_marker_expr = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice2_marker_expr.json', 'r') as file:\n",
    "        slice2_marker_expr = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice1_types.json', 'r') as file:\n",
    "        slice1_types = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice2_types.json', 'r') as file:\n",
    "        slice2_types = json.load(file)\n",
    "    data_t1 = np.load('/PATH/TO/mouse_embryo/slice1_feature.npy')\n",
    "    data_t2 = np.load('/PATH/TO/mouse_embryo/slice2_feature.npy')\n",
    "    n, m = data_t1.shape[0], data_t2.shape[0]\n",
    "    C = distance.cdist(data_t1, data_t2)\n",
    "    C /= C.max()\n",
    "    C = torch.from_numpy(C).to(device)\n",
    "    # Validation\n",
    "    validation_param = {}\n",
    "    for rank in rank_list:\n",
    "        for tau in tau_list:\n",
    "            Q, Lambda, R = LOT_iteration(C, A=None, B=None, tau=tau, gamma=gamma, r=rank, max_iter=max_iter,\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, Wasserstein=True,\n",
    "                 swap=False, initA=False, FGW=False, alpha=0.5)\n",
    "            Q, Lambda, R = Q.numpy(),Lambda.numpy(),R.numpy()\n",
    "            validation_corr_list = []\n",
    "            for gene in validation_gene_list:\n",
    "                expression_t1 = np.array(slice1_marker_expr[gene])\n",
    "                expression_t2 = np.array(slice2_marker_expr[gene])\n",
    "                pred_expression_t2 = pred_expression(Q, Lambda, R, expression_t1, m)\n",
    "                correlation, _ = spearmanr(expression_t2, pred_expression_t2)\n",
    "                validation_corr_list.append(correlation)\n",
    "            validation_param[(rank, tau)] = np.mean(np.array(validation_corr_list))\n",
    "    rank, tau = max(validation_param, key=validation_param.get)\n",
    "    print(\"The best parameter combination is: \", (rank, tau))\n",
    "    print(\"The best validation spearman correlation is: \", validation_param[(rank, tau)])\n",
    "\n",
    "    # Test\n",
    "    Q, Lambda, R = LOT_iteration(C, A=None, B=None, tau=tau, gamma=gamma, r=rank, max_iter=max_iter,\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, Wasserstein=True,\n",
    "                 swap=False, initA=False, FGW=False, alpha=0.5)\n",
    "    Q, Lambda, R = Q.numpy(),Lambda.numpy(),R.numpy()\n",
    "\n",
    "    # Pearson coorelation\n",
    "    test_corr_list = []\n",
    "    for gene in test_gene_list:\n",
    "        expression_t1 = np.array(slice1_marker_expr[gene])\n",
    "        expression_t2 = np.array(slice2_marker_expr[gene])\n",
    "        pred_expression_t2 = pred_expression(Q, Lambda, R, expression_t1, m)\n",
    "        correlation, _ = spearmanr(expression_t2, pred_expression_t2)\n",
    "        test_corr_list.append(correlation)\n",
    "    print(\"The test spearman correlation is: \", np.mean(np.array(test_corr_list)))\n",
    "\n",
    "    # Clustering prediction\n",
    "    # Instantiate the LabelBinarizer\n",
    "    lb = LabelBinarizer()\n",
    "    # Perform one-hot encoding\n",
    "    slice1_label_onehot = lb.fit_transform(slice1_types) # one_hot_encoded_matrix, lb.classes_\n",
    "    print(slice1_label_onehot)\n",
    "    print(lb.classes_)\n",
    "    # pred_slice2_label_onehot = m * r @ np.diag(1/g) @ q.T @ slice1_label_onehot\n",
    "    pred_slice2_label_onehot = pred_type(Q, Lambda, R, slice1_label_onehot, m)\n",
    "    # Finding the index of the max value in each row\n",
    "    pred_slice2_label_index = np.argmax(pred_slice2_label_onehot, axis=1)\n",
    "    pred_slice2_label = [lb.classes_[index] for index in pred_slice2_label_index]\n",
    "    ari = adjusted_rand_score(slice2_types, pred_slice2_label)\n",
    "    ami = adjusted_mutual_info_score(slice2_types, pred_slice2_label)\n",
    "    print(\"The ARI is: \", ari)\n",
    "    print(\"The AMI is: \", ami)\n",
    "\n",
    "\n",
    "test_gene_list = ['Tubb2b', 'Pantr1', 'Actc1', 'Tnni1', 'Afp', 'Hbb-bh1', 'Fez1', 'Crabp1', 'Crabp2', 'Col3a1']\n",
    "validation_gene_list = ['Ckb', 'Fabp7', 'Myl4', 'Tnnt2', 'Apoa2', 'Hba-x', 'Tubb3', 'Epha7', 'Ldha', 'Col1a2']\n",
    "tau_list = [30, 50, 100]\n",
    "rank_list = [50, 100, 200]\n",
    "\n",
    "LRW(tau_list, rank_list, validation_gene_list, test_gene_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "# Low-rank Gromov-Wasserstein (GW) alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred_expression(Q, Lambda, R, f, m):\n",
    "    temp = Q.T @ f\n",
    "    temp = Lambda.T @ temp\n",
    "    temp = R @ temp\n",
    "    return m * temp\n",
    "\n",
    "\n",
    "def pred_type(Q, Lambda, R, matrix, m):\n",
    "    temp = Q.T @ matrix\n",
    "    temp = Lambda.T @ temp\n",
    "    temp = R @ temp\n",
    "    return m * temp\n",
    "\n",
    "gamma = 50\n",
    "max_iter = 50\n",
    "def LRGW(tau_list, rank_list, validation_gene_list, test_gene_list):\n",
    "    with open('/PATH/TO/mouse_embryo/slice1_marker_expr.json', 'r') as file:\n",
    "        slice1_marker_expr = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice2_marker_expr.json', 'r') as file:\n",
    "        slice2_marker_expr = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice1_types.json', 'r') as file:\n",
    "        slice1_types = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice2_types.json', 'r') as file:\n",
    "        slice2_types = json.load(file)\n",
    "    data_t1 = np.load('/PATH/TO/mouse_embryo/slice1_coordinates.npy')\n",
    "    data_t2 = np.load('/PATH/TO/mouse_embryo/slice2_coordinates.npy')\n",
    "    n, m = data_t1.shape[0], data_t2.shape[0]\n",
    "    A = distance.cdist(data_t1, data_t1)\n",
    "    B = distance.cdist(data_t2, data_t2)\n",
    "    A /= A.max()\n",
    "    B /= B.max()\n",
    "    A = torch.from_numpy(A).to(device)\n",
    "    B = torch.from_numpy(B).to(device)\n",
    "    # Validation\n",
    "    validation_param = {}\n",
    "    for rank in rank_list:\n",
    "        for tau in tau_list:\n",
    "            Q, Lambda, R = LOT_iteration(C=None, A=A, B=B, tau=tau, gamma=gamma, r=rank, max_iter=max_iter,\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, Wasserstein=False,\n",
    "                 swap=False, initA=False, FGW=False, alpha=0.5)\n",
    "            Q, Lambda, R = Q.numpy(),Lambda.numpy(),R.numpy()\n",
    "            validation_corr_list = []\n",
    "            for gene in validation_gene_list:\n",
    "                expression_t1 = np.array(slice1_marker_expr[gene])\n",
    "                expression_t2 = np.array(slice2_marker_expr[gene])\n",
    "                pred_expression_t2 = pred_expression(Q, Lambda, R, expression_t1, m)\n",
    "                correlation, _ = spearmanr(expression_t2, pred_expression_t2)\n",
    "                validation_corr_list.append(correlation)\n",
    "            validation_param[(rank, tau)] = np.mean(np.array(validation_corr_list))\n",
    "    rank, tau = max(validation_param, key=validation_param.get)\n",
    "    print(\"The best parameter combination is: \", (rank, tau))\n",
    "    print(\"The best validation spearman correlation is: \", validation_param[(rank, tau)])\n",
    "\n",
    "    # Test\n",
    "    Q, Lambda, R = LOT_iteration(C=None, A=A, B=B, tau=tau, gamma=gamma, r=rank, max_iter=max_iter,\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, Wasserstein=False,\n",
    "                 swap=False, initA=False, FGW=False, alpha=0.5)\n",
    "    Q, Lambda, R = Q.numpy(),Lambda.numpy(),R.numpy()\n",
    "\n",
    "    # Pearson coorelation\n",
    "    test_corr_list = []\n",
    "    for gene in test_gene_list:\n",
    "        expression_t1 = np.array(slice1_marker_expr[gene])\n",
    "        expression_t2 = np.array(slice2_marker_expr[gene])\n",
    "        pred_expression_t2 = pred_expression(Q, Lambda, R, expression_t1, m)\n",
    "        correlation, _ = spearmanr(expression_t2, pred_expression_t2)\n",
    "        test_corr_list.append(correlation)\n",
    "    print(\"The test spearman correlation is: \", np.mean(np.array(test_corr_list)))\n",
    "\n",
    "    # Clustering prediction\n",
    "    # Instantiate the LabelBinarizer\n",
    "    lb = LabelBinarizer()\n",
    "    # Perform one-hot encoding\n",
    "    slice1_label_onehot = lb.fit_transform(slice1_types) # one_hot_encoded_matrix, lb.classes_\n",
    "    print(slice1_label_onehot)\n",
    "    print(lb.classes_)\n",
    "    # pred_slice2_label_onehot = m * r @ np.diag(1/g) @ q.T @ slice1_label_onehot\n",
    "    pred_slice2_label_onehot = pred_type(Q, Lambda, R, slice1_label_onehot, m)\n",
    "    # Finding the index of the max value in each row\n",
    "    pred_slice2_label_index = np.argmax(pred_slice2_label_onehot, axis=1)\n",
    "    pred_slice2_label = [lb.classes_[index] for index in pred_slice2_label_index]\n",
    "    ari = adjusted_rand_score(slice2_types, pred_slice2_label)\n",
    "    ami = adjusted_mutual_info_score(slice2_types, pred_slice2_label)\n",
    "    print(\"The ARI is: \", ari)\n",
    "    print(\"The AMI is: \", ami)\n",
    "\n",
    "\n",
    "# Held-out genes, used only for the final test score.\n",
    "test_gene_list = ['Tubb2b', 'Pantr1', 'Actc1', 'Tnni1', 'Afp',\n",
    "                  'Hbb-bh1', 'Fez1', 'Crabp1', 'Crabp2', 'Col3a1']\n",
    "# Genes used to pick the best (rank, tau) combination.\n",
    "validation_gene_list = ['Ckb', 'Fabp7', 'Myl4', 'Tnnt2', 'Apoa2',\n",
    "                        'Hba-x', 'Tubb3', 'Epha7', 'Ldha', 'Col1a2']\n",
    "# Hyperparameter grid searched on the validation genes.\n",
    "rank_list = [50, 100, 200]\n",
    "tau_list = [30, 50, 100]\n",
    "\n",
    "LRGW(tau_list, rank_list, validation_gene_list, test_gene_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Low-rank FGW (fused Gromov-Wasserstein) alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def pred_expression(Q, Lambda, R, f, m):\n",
    "    temp = Q.T @ f\n",
    "    temp = Lambda.T @ temp\n",
    "    temp = R @ temp\n",
    "    return m * temp\n",
    "\n",
    "\n",
    "def pred_type(Q, Lambda, R, matrix, m):\n",
    "    temp = Q.T @ matrix\n",
    "    temp = Lambda.T @ temp\n",
    "    temp = R @ temp\n",
    "    return m * temp\n",
    "\n",
    "gamma = 50\n",
    "max_iter = 50\n",
    "def LRFGW(tau_list, rank_list, validation_gene_list, test_gene_list):\n",
    "    with open('/PATH/TO/mouse_embryo/slice1_marker_expr.json', 'r') as file:\n",
    "        slice1_marker_expr = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice2_marker_expr.json', 'r') as file:\n",
    "        slice2_marker_expr = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice1_types.json', 'r') as file:\n",
    "        slice1_types = json.load(file)\n",
    "    with open('/PATH/TO/mouse_embryo/slice2_types.json', 'r') as file:\n",
    "        slice2_types = json.load(file)\n",
    "    data_t1 = np.load('/PATH/TO/mouse_embryo/slice1_feature.npy')\n",
    "    data_t2 = np.load('/PATH/TO/mouse_embryo/slice2_feature.npy')\n",
    "    coordinate_t1 = np.load('/PATH/TO/mouse_embryo/slice1_coordinates.npy')\n",
    "    coordinate_t2 = np.load('/PATH/TO/mouse_embryo/slice2_coordinates.npy')\n",
    "    n, m = data_t1.shape[0], data_t2.shape[0]\n",
    "    A = distance.cdist(coordinate_t1, coordinate_t1)\n",
    "    B = distance.cdist(coordinate_t2, coordinate_t2)\n",
    "    C = distance.cdist(data_t1, data_t2)\n",
    "    C /= C.max()\n",
    "    A /= A.max()\n",
    "    B /= B.max()\n",
    "    A = torch.from_numpy(A).to(device)\n",
    "    B = torch.from_numpy(B).to(device)\n",
    "    C = torch.from_numpy(C).to(device)\n",
    "    print(\"data loaded\")\n",
    "    # Validation\n",
    "    validation_param = {}\n",
    "    for rank in rank_list:\n",
    "        for tau in tau_list:\n",
    "            Q, Lambda, R = LOT_iteration(C=C, A=A, B=B, tau=tau, gamma=gamma, r=rank, max_iter=max_iter,\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, Wasserstein=False,\n",
    "                 swap=False, initA=False, FGW=True, alpha=0.1)\n",
    "            Q, Lambda, R = Q.numpy(),Lambda.numpy(),R.numpy()\n",
    "            validation_corr_list = []\n",
    "            for gene in validation_gene_list:\n",
    "                expression_t1 = np.array(slice1_marker_expr[gene])\n",
    "                expression_t2 = np.array(slice2_marker_expr[gene])\n",
    "                pred_expression_t2 = pred_expression(Q, Lambda, R, expression_t1, m)\n",
    "                correlation, _ = spearmanr(expression_t2, pred_expression_t2)\n",
    "                validation_corr_list.append(correlation)\n",
    "            validation_param[(rank, tau)] = np.mean(np.array(validation_corr_list))\n",
    "    rank, tau = max(validation_param, key=validation_param.get)\n",
    "    print(\"The best parameter combination is: \", (rank, tau))\n",
    "    print(\"The best validation spearman correlation is: \", validation_param[(rank, tau)])\n",
    "\n",
    "    # Test\n",
    "    Q, Lambda, R = LOT_iteration(C=C, A=A, B=B, tau=tau, gamma=gamma, r=rank, max_iter=max_iter,\n",
    "                  semiRelaxedLeft=True, semiRelaxedRight=False, Wasserstein=False,\n",
    "                 swap=False, initA=False, FGW=True, alpha=0.1)\n",
    "    Q, Lambda, R = Q.numpy(),Lambda.numpy(),R.numpy()\n",
    "\n",
    "    # Pearson coorelation\n",
    "    test_corr_list = []\n",
    "    for gene in test_gene_list:\n",
    "        expression_t1 = np.array(slice1_marker_expr[gene])\n",
    "        expression_t2 = np.array(slice2_marker_expr[gene])\n",
    "        pred_expression_t2 = pred_expression(Q, Lambda, R, expression_t1, m)\n",
    "        correlation, _ = spearmanr(expression_t2, pred_expression_t2)\n",
    "        test_corr_list.append(correlation)\n",
    "    print(\"The test spearman correlation is: \", np.mean(np.array(test_corr_list)))\n",
    "\n",
    "    # Clustering prediction\n",
    "    # Instantiate the LabelBinarizer\n",
    "    lb = LabelBinarizer()\n",
    "    # Perform one-hot encoding\n",
    "    slice1_label_onehot = lb.fit_transform(slice1_types) # one_hot_encoded_matrix, lb.classes_\n",
    "    print(slice1_label_onehot)\n",
    "    print(lb.classes_)\n",
    "    # pred_slice2_label_onehot = m * r @ np.diag(1/g) @ q.T @ slice1_label_onehot\n",
    "    pred_slice2_label_onehot = pred_type(Q, Lambda, R, slice1_label_onehot, m)\n",
    "    # Finding the index of the max value in each row\n",
    "    pred_slice2_label_index = np.argmax(pred_slice2_label_onehot, axis=1)\n",
    "    pred_slice2_label = [lb.classes_[index] for index in pred_slice2_label_index]\n",
    "    ari = adjusted_rand_score(slice2_types, pred_slice2_label)\n",
    "    ami = adjusted_mutual_info_score(slice2_types, pred_slice2_label)\n",
    "    print(\"The ARI is: \", ari)\n",
    "    print(\"The AMI is: \", ami)\n",
    "\n",
    "\n",
    "# Held-out genes, used only for the final test score.\n",
    "test_gene_list = ['Tubb2b', 'Pantr1', 'Actc1', 'Tnni1', 'Afp',\n",
    "                  'Hbb-bh1', 'Fez1', 'Crabp1', 'Crabp2', 'Col3a1']\n",
    "# Genes used to pick the best (rank, tau) combination.\n",
    "validation_gene_list = ['Ckb', 'Fabp7', 'Myl4', 'Tnnt2', 'Apoa2',\n",
    "                        'Hba-x', 'Tubb3', 'Epha7', 'Ldha', 'Col1a2']\n",
    "# Hyperparameter grid searched on the validation genes.\n",
    "rank_list = [50, 100, 200]\n",
    "tau_list = [30, 50, 100]\n",
    "\n",
    "LRFGW(tau_list=tau_list, rank_list=rank_list,\n",
    "      validation_gene_list=validation_gene_list, test_gene_list=test_gene_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
