{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {"provenance": [], "gpuType": "A100", "machine_shape": "hm"},
  "kernelspec": {"name": "python3", "display_name": "Python 3"},
  "language_info": {"name": "python"},
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Low-rank optimal transport for slice-to-slice expression and label transfer\n",
    "\n",
    "Aligns slice 1 to slice 2 with three low-rank OTT solvers. For each solver, `(rank, epsilon, tau)` is grid-searched on validation genes, and the notebook then reports:\n",
    "\n",
    "- the mean Spearman correlation between observed and transferred expression on held-out test genes;\n",
    "- the ARI / AMI of cell-type labels transferred from slice 1 to slice 2.\n",
    "\n",
    "Sections: **Low-rank W** (feature arrays only), **Low-rank GW** (coordinates only), **Low-rank FGW** (both)."
   ],
   "metadata": {"id": "introMd00001"}
  },
  {
   "cell_type": "code",
   "source": [
    "# Colab setup. NOTE: pin the ott-jax version used for the reported results\n",
    "# (e.g. ott-jax==<version>) so the solver APIs used below stay stable.\n",
    "%pip install -q ott-jax"
   ],
   "metadata": {"id": "VukzksGycdb2"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import jax\n",
    "import numpy as np\n",
    "from ott.geometry import geometry, pointcloud\n",
    "from ott.problems.linear import linear_problem\n",
    "from ott.problems.quadratic import quadratic_problem\n",
    "from ott.solvers.linear import sinkhorn, sinkhorn_lr\n",
    "from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr\n",
    "from scipy.spatial import distance\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "from sklearn.metrics.cluster import adjusted_mutual_info_score, adjusted_rand_score\n",
    "from sklearn.preprocessing import LabelBinarizer"
   ],
   "metadata": {"id": "J8Z4eOQybIJr"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ],
   "metadata": {"id": "Rmu7J-PFbMMr"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Shared configuration and helpers\n",
    "\n",
    "All three experiments load the same annotations, search the same hyperparameter grid and evaluate the low-rank coupling in the same way. OTT's low-rank solvers return the coupling in factored form, `P = q @ diag(1/g) @ r.T`. A slice-1 signal `f` is therefore transferred to slice 2 as `m * r @ diag(1/g) @ q.T @ f`, where `m` is the number of slice-2 cells."
   ],
   "metadata": {"id": "sharedMd0001"}
  },
  {
   "cell_type": "code",
   "source": [
    "# Location of the preprocessed slice data on Google Drive.\n",
    "DATA_DIR = '/content/drive/MyDrive/Research/SemiRelaxedLowRank/data'\n",
    "\n",
    "# Genes used to choose hyperparameters, and disjoint genes used to report held-out performance.\n",
    "validation_gene_list = ['Ckb', 'Fabp7', 'Myl4', 'Tnnt2', 'Apoa2', 'Hba-x', 'Tubb3', 'Epha7', 'Ldha', 'Col1a2']\n",
    "test_gene_list = ['Tubb2b', 'Pantr1', 'Actc1', 'Tnni1', 'Afp', 'Hbb-bh1', 'Fez1', 'Crabp1', 'Crabp2', 'Col3a1']\n",
    "\n",
    "# Hyperparameter grid shared by all three experiments.\n",
    "epsilon_list = [0.001, 0.01, 0.1]\n",
    "rank_list = [50, 100, 200]\n",
    "tau_list = [0.99, 0.9, 0.7]"
   ],
   "metadata": {"id": "configCell01"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def load_json(filename):\n",
    "  \"\"\"Load and return the JSON file `filename` from DATA_DIR.\"\"\"\n",
    "  with open(os.path.join(DATA_DIR, filename), 'r') as file:\n",
    "    return json.load(file)\n",
    "\n",
    "\n",
    "def load_npy(filename):\n",
    "  \"\"\"Load and return the NumPy array `filename` from DATA_DIR.\"\"\"\n",
    "  return np.load(os.path.join(DATA_DIR, filename))\n",
    "\n",
    "\n",
    "def load_slice_annotations():\n",
    "  \"\"\"Load marker expression and cell-type labels for both slices.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (slice1_marker_expr, slice2_marker_expr, slice1_types, slice2_types).\n",
    "    The two expression objects are indexed by gene name.\n",
    "  \"\"\"\n",
    "  return (\n",
    "      load_json('slice1_marker_expr.json'),\n",
    "      load_json('slice2_marker_expr.json'),\n",
    "      load_json('slice1_types.json'),\n",
    "      load_json('slice2_types.json'),\n",
    "  )\n",
    "\n",
    "\n",
    "def pred_expression(q, g, r, f, m):\n",
    "  \"\"\"Transfer values on slice-1 cells to slice-2 cells through a low-rank coupling.\n",
    "\n",
    "  Computes m * r @ diag(1/g) @ q.T @ f without building the dense diag(1/g).\n",
    "\n",
    "  Args:\n",
    "    q: left coupling factor; its rows index slice-1 cells.\n",
    "    g: 1-D vector of length rank from the low-rank solver output.\n",
    "    r: right coupling factor; its rows index slice-2 cells.\n",
    "    f: vector (one value per slice-1 cell) or matrix (one row per slice-1 cell).\n",
    "    m: number of slice-2 cells; rescales the transported mass.\n",
    "\n",
    "  Returns:\n",
    "    NumPy array with one entry (or row) per slice-2 cell.\n",
    "  \"\"\"\n",
    "  # Work in NumPy; the solver returns JAX arrays.\n",
    "  q, g, r, f = np.asarray(q), np.asarray(g), np.asarray(r), np.asarray(f)\n",
    "  temp = q.T @ f\n",
    "  # Divide row i of temp by g[i]; equivalent to np.diag(1 / g) @ temp.\n",
    "  temp = temp / g.reshape((-1,) + (1,) * (temp.ndim - 1))\n",
    "  return m * (r @ temp)\n",
    "\n",
    "\n",
    "def pred_type(q, g, r, matrix, m):\n",
    "  \"\"\"Transfer a label-indicator matrix (one row per slice-1 cell); see pred_expression.\"\"\"\n",
    "  return pred_expression(q, g, r, matrix, m)\n",
    "\n",
    "\n",
    "def scale_matrix_rows(matrix):\n",
    "  \"\"\"Divide the whole matrix by its largest row L2 norm.\n",
    "\n",
    "  Every row is divided by the same factor, so relative distances are\n",
    "  preserved and the largest row ends up with unit norm. If every row has\n",
    "  zero norm, the matrix is returned unchanged.\n",
    "  \"\"\"\n",
    "  max_norm = np.max(np.linalg.norm(matrix, axis=1))\n",
    "  if max_norm == 0:\n",
    "    return matrix\n",
    "  return matrix / max_norm\n",
    "\n",
    "\n",
    "def mean_spearman(q, g, r, m, source_expr, target_expr, genes):\n",
    "  \"\"\"Mean Spearman correlation between observed and transferred expression over `genes`.\n",
    "\n",
    "  An undefined per-gene correlation (NaN) propagates to the mean.\n",
    "  \"\"\"\n",
    "  correlations = []\n",
    "  for gene in genes:\n",
    "    predicted = pred_expression(q, g, r, np.array(source_expr[gene]), m)\n",
    "    correlation, _ = spearmanr(np.array(target_expr[gene]), predicted)\n",
    "    correlations.append(correlation)\n",
    "  return np.mean(np.array(correlations))\n",
    "\n",
    "\n",
    "def select_best_params(validation_param):\n",
    "  \"\"\"Return the hyperparameter key with the highest validation score.\n",
    "\n",
    "  NaN scores (e.g. a constant prediction) are ranked last. A plain\n",
    "  `max(d, key=d.get)` is unsafe here: every comparison against NaN is\n",
    "  False, so a NaN score seen first would be returned as the best.\n",
    "\n",
    "  Raises:\n",
    "    ValueError: if every score is NaN.\n",
    "  \"\"\"\n",
    "  def score(key):\n",
    "    value = validation_param[key]\n",
    "    return -np.inf if np.isnan(value) else value\n",
    "\n",
    "  best = max(validation_param, key=score)\n",
    "  if np.isnan(validation_param[best]):\n",
    "    raise ValueError('All hyperparameter combinations produced NaN validation scores.')\n",
    "  return best\n",
    "\n",
    "\n",
    "def label_transfer_scores(q, g, r, m, slice1_types, slice2_types):\n",
    "  \"\"\"Transfer slice-1 cell-type labels to slice 2 and score them against slice2_types.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (ari, ami, classes), where classes are the label names seen in slice 1.\n",
    "  \"\"\"\n",
    "  lb = LabelBinarizer()\n",
    "  slice1_label_onehot = lb.fit_transform(slice1_types)\n",
    "  if slice1_label_onehot.shape[1] == 1:\n",
    "    # LabelBinarizer encodes two classes as one 0/1 column; expand it so the\n",
    "    # argmax below can actually choose between both classes.\n",
    "    slice1_label_onehot = np.hstack([1 - slice1_label_onehot, slice1_label_onehot])\n",
    "  pred_slice2_label_onehot = pred_type(q, g, r, slice1_label_onehot, m)\n",
    "  # Each slice-2 cell gets the label receiving the most transported mass.\n",
    "  pred_slice2_label_index = np.argmax(pred_slice2_label_onehot, axis=1)\n",
    "  pred_slice2_label = [lb.classes_[index] for index in pred_slice2_label_index]\n",
    "  ari = adjusted_rand_score(slice2_types, pred_slice2_label)\n",
    "  ami = adjusted_mutual_info_score(slice2_types, pred_slice2_label)\n",
    "  return ari, ami, lb.classes_\n",
    "\n",
    "\n",
    "def tune_and_evaluate(build_problem, solver_cls, m, epsilon_list, rank_list, tau_list,\n",
    "                      validation_gene_list, test_gene_list, annotations):\n",
    "  \"\"\"Grid-search (rank, epsilon, tau) on validation genes, then report test metrics.\n",
    "\n",
    "  Args:\n",
    "    build_problem: callable mapping tau to an OTT problem instance.\n",
    "    solver_cls: low-rank OTT solver class constructed with `rank` and `epsilon`.\n",
    "    m: number of slice-2 cells.\n",
    "    epsilon_list, rank_list, tau_list: hyperparameter grid.\n",
    "    validation_gene_list: genes used to select hyperparameters.\n",
    "    test_gene_list: held-out genes used for the reported correlation.\n",
    "    annotations: tuple returned by load_slice_annotations().\n",
    "\n",
    "  Returns:\n",
    "    Dict with the best parameters, all validation scores, the test Spearman\n",
    "    correlation, and the label-transfer ARI / AMI.\n",
    "  \"\"\"\n",
    "  slice1_marker_expr, slice2_marker_expr, slice1_types, slice2_types = annotations\n",
    "\n",
    "  def solve(rank, epsilon, tau):\n",
    "    solver = jax.jit(solver_cls(rank=rank, epsilon=epsilon))\n",
    "    ot_lr = solver(build_problem(tau))\n",
    "    return ot_lr.q, ot_lr.g, ot_lr.r\n",
    "\n",
    "  # Validation: score every grid point on the validation genes.\n",
    "  validation_param = {}\n",
    "  for rank in rank_list:\n",
    "    for epsilon in epsilon_list:\n",
    "      for tau in tau_list:\n",
    "        print('Evaluating (rank, epsilon, tau) =', (rank, epsilon, tau))\n",
    "        q, g, r = solve(rank, epsilon, tau)\n",
    "        validation_param[(rank, epsilon, tau)] = mean_spearman(\n",
    "            q, g, r, m, slice1_marker_expr, slice2_marker_expr, validation_gene_list)\n",
    "        # Drop references to the factors before the next configuration.\n",
    "        del q, g, r\n",
    "  rank, epsilon, tau = select_best_params(validation_param)\n",
    "  print('The best parameter combination is: ', (rank, epsilon, tau))\n",
    "  print('The best validation spearman correlation is: ', validation_param[(rank, epsilon, tau)])\n",
    "\n",
    "  # Test: refit with the selected hyperparameters; Spearman correlation on held-out genes.\n",
    "  q, g, r = solve(rank, epsilon, tau)\n",
    "  test_corr = mean_spearman(q, g, r, m, slice1_marker_expr, slice2_marker_expr, test_gene_list)\n",
    "  print('The test spearman correlation is: ', test_corr)\n",
    "\n",
    "  # Cell-type label transfer.\n",
    "  ari, ami, classes = label_transfer_scores(q, g, r, m, slice1_types, slice2_types)\n",
    "  print('Cell types:', classes)\n",
    "  print('The ARI is: ', ari)\n",
    "  print('The AMI is: ', ami)\n",
    "  return {\n",
    "      'best_params': (rank, epsilon, tau),\n",
    "      'validation_spearman': validation_param,\n",
    "      'test_spearman': test_corr,\n",
    "      'ari': ari,\n",
    "      'ami': ami,\n",
    "  }"
   ],
   "metadata": {"id": "helpersCell1"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Low rank W\n",
    "\n",
    "Low-rank Wasserstein alignment using only the per-cell feature arrays (`slice*_feature.npy`)."
   ],
   "metadata": {"id": "-QYX8pIB9Qq7"}
  },
  {
   "cell_type": "code",
   "source": [
    "def LRW(epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list):\n",
    "  \"\"\"Low-rank Wasserstein alignment on the slice feature arrays.\n",
    "\n",
    "  The linear cost is the default PointCloud cost between the row-scaled\n",
    "  slice1_feature.npy and slice2_feature.npy arrays, with tau_a = tau_b = tau\n",
    "  (tau < 1 relaxes the marginals).\n",
    "\n",
    "  Returns:\n",
    "    Result dict from tune_and_evaluate.\n",
    "  \"\"\"\n",
    "  annotations = load_slice_annotations()\n",
    "  data_t1 = scale_matrix_rows(load_npy('slice1_feature.npy'))\n",
    "  data_t2 = scale_matrix_rows(load_npy('slice2_feature.npy'))\n",
    "  m = data_t2.shape[0]\n",
    "  geom = pointcloud.PointCloud(x=data_t1, y=data_t2)\n",
    "\n",
    "  def build_problem(tau):\n",
    "    return linear_problem.LinearProblem(geom, tau_a=tau, tau_b=tau)\n",
    "\n",
    "  return tune_and_evaluate(\n",
    "      build_problem, sinkhorn_lr.LRSinkhorn, m, epsilon_list, rank_list, tau_list,\n",
    "      validation_gene_list, test_gene_list, annotations)"
   ],
   "metadata": {"id": "dTNju5i6Ke_H"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "lrw_results = LRW(epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list)"
   ],
   "metadata": {"id": "runLRWCell01"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Low rank GW\n",
    "\n",
    "Low-rank Gromov-Wasserstein alignment using only the spatial coordinate arrays (`slice*_coordinates.npy`)."
   ],
   "metadata": {"id": "a-HAiwWE9Wct"}
  },
  {
   "cell_type": "code",
   "source": [
    "def LRGW(epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list):\n",
    "  \"\"\"Low-rank Gromov-Wasserstein alignment on the slice coordinate arrays.\n",
    "\n",
    "  The intra-slice geometries are built from the row-scaled\n",
    "  slice1_coordinates.npy and slice2_coordinates.npy arrays, with\n",
    "  tau_a = tau_b = tau.\n",
    "\n",
    "  Returns:\n",
    "    Result dict from tune_and_evaluate.\n",
    "  \"\"\"\n",
    "  annotations = load_slice_annotations()\n",
    "  coordinate_t1 = scale_matrix_rows(load_npy('slice1_coordinates.npy'))\n",
    "  coordinate_t2 = scale_matrix_rows(load_npy('slice2_coordinates.npy'))\n",
    "  m = coordinate_t2.shape[0]\n",
    "  geom_xx = pointcloud.PointCloud(x=coordinate_t1, y=coordinate_t1)\n",
    "  geom_yy = pointcloud.PointCloud(x=coordinate_t2, y=coordinate_t2)\n",
    "\n",
    "  def build_problem(tau):\n",
    "    return quadratic_problem.QuadraticProblem(\n",
    "        geom_xx=geom_xx, geom_yy=geom_yy, tau_a=tau, tau_b=tau)\n",
    "\n",
    "  return tune_and_evaluate(\n",
    "      build_problem, gromov_wasserstein_lr.LRGromovWasserstein, m,\n",
    "      epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list,\n",
    "      annotations)"
   ],
   "metadata": {"id": "g45z_PteMMMl"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "lrgw_results = LRGW(epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list)"
   ],
   "metadata": {"id": "runLRGWCell1"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Low rank FGW\n",
    "\n",
    "Low-rank fused Gromov-Wasserstein alignment. The coordinates define the intra-slice geometries, and the feature arrays define the cross-slice cost."
   ],
   "metadata": {"id": "2gCU5A5iBPW5"}
  },
  {
   "cell_type": "code",
   "source": [
    "def LRFGW(epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list,\n",
    "          fused_penalty=1.0):\n",
    "  \"\"\"Low-rank fused Gromov-Wasserstein alignment on coordinates plus features.\n",
    "\n",
    "  Args:\n",
    "    fused_penalty: weight of the feature (geom_xy) term, passed to\n",
    "      QuadraticProblem. The default 1.0 is OTT's default. An earlier exploratory\n",
    "      setting used alpha = 0.1, i.e. fused_penalty = (1 - alpha) / alpha = 9.\n",
    "\n",
    "  Returns:\n",
    "    Result dict from tune_and_evaluate.\n",
    "  \"\"\"\n",
    "  annotations = load_slice_annotations()\n",
    "  data_t1 = scale_matrix_rows(load_npy('slice1_feature.npy'))\n",
    "  data_t2 = scale_matrix_rows(load_npy('slice2_feature.npy'))\n",
    "  coordinate_t1 = scale_matrix_rows(load_npy('slice1_coordinates.npy'))\n",
    "  coordinate_t2 = scale_matrix_rows(load_npy('slice2_coordinates.npy'))\n",
    "  m = data_t2.shape[0]\n",
    "  geom_xx = pointcloud.PointCloud(x=coordinate_t1, y=coordinate_t1)\n",
    "  geom_yy = pointcloud.PointCloud(x=coordinate_t2, y=coordinate_t2)\n",
    "  geom_xy = pointcloud.PointCloud(x=data_t1, y=data_t2)\n",
    "\n",
    "  def build_problem(tau):\n",
    "    return quadratic_problem.QuadraticProblem(\n",
    "        geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy,\n",
    "        fused_penalty=fused_penalty, tau_a=tau, tau_b=tau)\n",
    "\n",
    "  return tune_and_evaluate(\n",
    "      build_problem, gromov_wasserstein_lr.LRGromovWasserstein, m,\n",
    "      epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list,\n",
    "      annotations)"
   ],
   "metadata": {"id": "pPBCmVhoBQpy"},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "lrfgw_results = LRFGW(epsilon_list, rank_list, tau_list, validation_gene_list, test_gene_list)"
   ],
   "metadata": {"id": "runLRFGWCel1"},
   "execution_count": null,
   "outputs": []
  }
 ]
}