{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J8Z4eOQybIJr"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "from typing import List, Dict, Tuple\n",
    "from scipy.sparse import csr_matrix\n",
    "import ott\n",
    "import jax\n",
    "from ott.geometry import pointcloud\n",
    "from ott.problems.quadratic import quadratic_problem\n",
    "from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JZd3Sz3TFdjP"
   },
   "source": [
    "# Method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3hI1olmbs6Bd"
   },
   "outputs": [],
   "source": [
    "# Create heat kernel matrix from precomputed eigenvalues/tangent_vectors\n",
    "def heat_kernel(lam,phi,t):\n",
    "    # Input: eigenvalues and eigenvectors for normalized Laplacian, time parameter t\n",
    "    # Output: heat kernel matrix\n",
    "\n",
    "    u = np.matmul(phi,np.matmul(np.diag(np.exp(-t*lam)),phi.T))\n",
    "\n",
    "    return u\n",
    "\n",
    "def directed_heat_kernel(G,t):\n",
    "    \"\"\"Heat kernel of a directed graph at time t.\n",
    "\n",
    "    The Laplacian is taken from nx.directed_laplacian_matrix, diagonalised with\n",
    "    np.linalg.eigh and exponentiated by heat_kernel.\n",
    "\n",
    "    Args:\n",
    "        G: networkx DiGraph.\n",
    "        t: time parameter forwarded to heat_kernel.\n",
    "\n",
    "    Returns:\n",
    "        The heat kernel matrix.\n",
    "    \"\"\"\n",
    "    laplacian = np.asarray(nx.directed_laplacian_matrix(G))\n",
    "    # NOTE(review): eigh assumes a symmetric input -- presumably guaranteed by\n",
    "    # networkx's directed Laplacian; confirm for the installed version.\n",
    "    eigenvalues, eigenvectors = np.linalg.eigh(laplacian)\n",
    "    return heat_kernel(lam=eigenvalues, phi=eigenvectors, t=t)\n",
    "\n",
    "def undirected_heat_kernel(G,t):\n",
    "    \"\"\"Heat kernel of an undirected graph at time t.\n",
    "\n",
    "    The Laplacian is taken from nx.laplacian_matrix (unnormalized), densified,\n",
    "    diagonalised with np.linalg.eigh and exponentiated by heat_kernel.\n",
    "\n",
    "    Args:\n",
    "        G: networkx Graph.\n",
    "        t: time parameter forwarded to heat_kernel.\n",
    "\n",
    "    Returns:\n",
    "        The heat kernel matrix.\n",
    "    \"\"\"\n",
    "    dense_laplacian = nx.laplacian_matrix(G).toarray()\n",
    "    eigenvalues, eigenvectors = np.linalg.eigh(dense_laplacian)\n",
    "    return heat_kernel(lam=eigenvalues, phi=eigenvectors, t=t)\n",
    "\n",
    "def undirected_normalized_heat_kernel(G,t):\n",
    "    \"\"\"Heat kernel of an undirected graph at time t, using the normalized Laplacian.\n",
    "\n",
    "    The Laplacian is taken from nx.normalized_laplacian_matrix, densified,\n",
    "    diagonalised with np.linalg.eigh and exponentiated by heat_kernel.\n",
    "\n",
    "    Args:\n",
    "        G: networkx Graph.\n",
    "        t: time parameter forwarded to heat_kernel.\n",
    "\n",
    "    Returns:\n",
    "        The heat kernel matrix.\n",
    "    \"\"\"\n",
    "    dense_laplacian = nx.normalized_laplacian_matrix(G).toarray()\n",
    "    eigenvalues, eigenvectors = np.linalg.eigh(dense_laplacian)\n",
    "    return heat_kernel(lam=eigenvalues, phi=eigenvectors, t=t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vt61kHuybMO_"
   },
   "outputs": [],
   "source": [
    "def estimate_target_distribution(probs: Dict, dim_t: int = 2) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Estimate target distribution via the average of sorted source probabilities\n",
    "    Args:\n",
    "        probs: a dictionary of graphs {key: graph idx,\n",
    "                                       value: (n_s, 1) the distribution of source nodes}\n",
    "        dim_t: the dimension of target distribution\n",
    "    Returns:\n",
    "        p_t: (dim_t, 1) vector representing a distribution\n",
    "    \"\"\"\n",
    "    p_t = np.zeros((dim_t, ))\n",
    "    x_t = np.linspace(0, 1, p_t.shape[0])\n",
    "    for n in probs.keys():\n",
    "        p_s = probs[n]\n",
    "        p_s = np.sort(p_s)[::-1]\n",
    "        x_s = np.linspace(0, 1, p_s.shape[0])\n",
    "        p_t_n = np.interp(x_t, x_s, p_s)\n",
    "        p_t += p_t_n\n",
    "    p_t /= np.sum(p_t)\n",
    "    return p_t\n",
    "\n",
    "\n",
    "def GWL(cost,database,num_partitions,epsilon):\n",
    "    \"\"\"Partition nodes with entropic Gromov-Wasserstein and score the result.\n",
    "\n",
    "    Args:\n",
    "        cost: node-by-node matrix used as the source geometry (the notebook passes\n",
    "            adjacency matrices and heat kernels).\n",
    "        database: dict whose 'labels' entry holds the reference label of every node.\n",
    "        num_partitions: number of target points, i.e. clusters to assign.\n",
    "        epsilon: value handed to GromovWasserstein(epsilon=...).\n",
    "\n",
    "    Returns:\n",
    "        Adjusted mutual information between database['labels'] and the estimated\n",
    "        partition.\n",
    "    \"\"\"\n",
    "    cost = cost.astype(np.float64)\n",
    "\n",
    "    # Source weights: row sums raised to a tiny power (so they end up nearly\n",
    "    # uniform), then normalised to sum to one.\n",
    "    node_weights = np.power(np.sum(cost, axis=1), 0.001)\n",
    "    p_s = node_weights / np.sum(node_weights)\n",
    "\n",
    "    # Target side: one point per partition with a diagonal cost built from p_t.\n",
    "    p_t = estimate_target_distribution({0: p_s}, dim_t=num_partitions)\n",
    "    source_geom = ott.geometry.geometry.Geometry(cost_matrix=cost)\n",
    "    target_geom = ott.geometry.geometry.Geometry(cost_matrix=np.diag(p_t))\n",
    "\n",
    "    problem = quadratic_problem.QuadraticProblem(source_geom, target_geom, a=p_s, b=p_t)\n",
    "    gw_solver = jax.jit(gromov_wasserstein.GromovWasserstein(epsilon=epsilon))\n",
    "\n",
    "    # NOTE(review): `out` presumably also exposes n_iters, converged and\n",
    "    # reg_gw_cost for convergence diagnostics -- confirm against the ott version.\n",
    "    out = gw_solver(problem)\n",
    "    coupling = np.array(out.matrix)\n",
    "\n",
    "    # Every node joins the partition that receives most of its mass.\n",
    "    est_idx = coupling.argmax(axis=1)\n",
    "    return metrics.adjusted_mutual_info_score(database['labels'], est_idx)\n",
    "\n",
    "\n",
    "def LRGWL(cost,database,num_partitions,rank):\n",
    "    \"\"\"Partition nodes with low-rank Gromov-Wasserstein and score the result.\n",
    "\n",
    "    Args:\n",
    "        cost: node-by-node matrix used as the source geometry (the notebook passes\n",
    "            adjacency matrices and heat kernels).\n",
    "        database: dict whose 'labels' entry holds the reference label of every node.\n",
    "        num_partitions: number of target points, i.e. clusters to assign.\n",
    "        rank: value handed to LRGromovWasserstein(rank=...).\n",
    "\n",
    "    Returns:\n",
    "        Adjusted mutual information between database['labels'] and the estimated\n",
    "        partition.\n",
    "    \"\"\"\n",
    "    cost = cost.astype(np.float64)\n",
    "\n",
    "    # Source weights: row sums raised to a tiny power (so they end up nearly\n",
    "    # uniform), then normalised to sum to one.\n",
    "    node_weights = np.power(np.sum(cost, axis=1), 0.001)\n",
    "    p_s = node_weights / np.sum(node_weights)\n",
    "\n",
    "    # Target side: one point per partition with a diagonal cost built from p_t.\n",
    "    p_t = estimate_target_distribution({0: p_s}, dim_t=num_partitions)\n",
    "    source_geom = ott.geometry.geometry.Geometry(cost_matrix=cost)\n",
    "    target_geom = ott.geometry.geometry.Geometry(cost_matrix=np.diag(p_t))\n",
    "\n",
    "    problem = quadratic_problem.QuadraticProblem(source_geom, target_geom, a=p_s, b=p_t)\n",
    "    lr_solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=rank)\n",
    "    coupling = np.array(lr_solver(problem).matrix)\n",
    "\n",
    "    # Every node joins the partition that receives most of its mass.\n",
    "    est_idx = coupling.argmax(axis=1)\n",
    "    return metrics.adjusted_mutual_info_score(database['labels'], est_idx)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZosrkAQnDdHB"
   },
   "source": [
    "# Wikipedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o06m4TOtGq5e"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.\n",
    "# The context manager closes the file even when unpickling fails.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/wikicats.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "dG = database['G']\n",
    "labels = database['labels']\n",
    "num_nodes = dG.number_of_nodes()\n",
    "num_partitions = len(np.unique(labels))\n",
    "\n",
    "# Symmetrized counterpart of the directed graph\n",
    "G = dG.to_undirected()\n",
    "\n",
    "# Load precomputed noisy version (symmetrized)\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/wiki_sym_noise.txt\"\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n",
    "\n",
    "# Load precomputed noisy version (asymmetric)\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/wiki_asym_noise.txt\"\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    ndG = pickle.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WB1RKmvzDap8"
   },
   "outputs": [],
   "source": [
    "epsilon = 1e-6  # passed to GWL (GromovWasserstein epsilon)\n",
    "rank = 15       # passed to LRGWL (LRGromovWasserstein rank)\n",
    "t = 10          # heat-kernel time parameter for the Spec* methods\n",
    "\n",
    "# (raw, noisy) graph pair per variant, listed in reporting order.\n",
    "graph_pairs = {\n",
    "    \"symmetrized\": (G, nG),\n",
    "    \"asymmetric\": (dG, ndG),\n",
    "}\n",
    "heat_kernels = {\n",
    "    \"symmetrized\": undirected_normalized_heat_kernel,\n",
    "    \"asymmetric\": directed_heat_kernel,\n",
    "}\n",
    "# Each stage pairs a cost builder with the (full, low-rank) method names that\n",
    "# consume it. Costs are built once per stage and shared by both solvers, so\n",
    "# every heat kernel (a full eigendecomposition) is computed a single time.\n",
    "stages = [\n",
    "    (lambda variant, graph: nx.adjacency_matrix(graph).toarray(), (\"GWL\", \"LRGWL\")),\n",
    "    (lambda variant, graph: heat_kernels[variant](graph, t), (\"SpecGWL\", \"LRSpecGWL\")),\n",
    "]\n",
    "\n",
    "section = 0\n",
    "for build_cost, (full_name, low_rank_name) in stages:\n",
    "    costs = {\n",
    "        variant: [build_cost(variant, graph) for graph in pair]\n",
    "        for variant, pair in graph_pairs.items()\n",
    "    }\n",
    "    # GWL and LRGWL copy the cost via astype, so sharing the arrays is safe.\n",
    "    runs = [\n",
    "        (full_name, lambda cost: GWL(cost, database, num_partitions, epsilon=epsilon)),\n",
    "        (low_rank_name, lambda cost: LRGWL(cost, database, num_partitions, rank=rank)),\n",
    "    ]\n",
    "    for method, solve in runs:\n",
    "        if section:\n",
    "            print('================')\n",
    "        section += 1\n",
    "        for variant, (raw_cost, noisy_cost) in costs.items():\n",
    "            print(f\"Method: {method}, {variant}\")\n",
    "            print(solve(raw_cost))    # Raw\n",
    "            print(solve(noisy_cost))  # Noisy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hh0gJr19DiCV"
   },
   "source": [
    "# EU-email"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XKk0kKNjDgU9"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.\n",
    "# The context manager closes the file even when unpickling fails.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/eu-email.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "dG = database['G']\n",
    "labels = database['labels']\n",
    "num_nodes = dG.number_of_nodes()\n",
    "num_partitions = len(np.unique(labels))\n",
    "\n",
    "# Symmetrized counterpart of the directed graph\n",
    "G = dG.to_undirected()\n",
    "\n",
    "# Load precomputed noisy version (symmetrized)\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/eu_sym_noise.txt\"\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n",
    "\n",
    "# Load precomputed noisy version (asymmetric)\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/eu_asym_noise.txt\"\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    ndG = pickle.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mhZ3XlmbDbbO"
   },
   "outputs": [],
   "source": [
    "epsilon = 1e-6  # passed to GWL (GromovWasserstein epsilon)\n",
    "rank = 42       # passed to LRGWL (LRGromovWasserstein rank)\n",
    "t = 10          # heat-kernel time parameter for the Spec* methods\n",
    "\n",
    "# (raw, noisy) graph pair per variant, listed in reporting order.\n",
    "graph_pairs = {\n",
    "    \"symmetrized\": (G, nG),\n",
    "    \"asymmetric\": (dG, ndG),\n",
    "}\n",
    "heat_kernels = {\n",
    "    \"symmetrized\": undirected_normalized_heat_kernel,\n",
    "    \"asymmetric\": directed_heat_kernel,\n",
    "}\n",
    "# Each stage pairs a cost builder with the (full, low-rank) method names that\n",
    "# consume it. Costs are built once per stage and shared by both solvers, so\n",
    "# every heat kernel (a full eigendecomposition) is computed a single time.\n",
    "stages = [\n",
    "    (lambda variant, graph: nx.adjacency_matrix(graph).toarray(), (\"GWL\", \"LRGWL\")),\n",
    "    (lambda variant, graph: heat_kernels[variant](graph, t), (\"SpecGWL\", \"LRSpecGWL\")),\n",
    "]\n",
    "\n",
    "section = 0\n",
    "for build_cost, (full_name, low_rank_name) in stages:\n",
    "    costs = {\n",
    "        variant: [build_cost(variant, graph) for graph in pair]\n",
    "        for variant, pair in graph_pairs.items()\n",
    "    }\n",
    "    # GWL and LRGWL copy the cost via astype, so sharing the arrays is safe.\n",
    "    runs = [\n",
    "        (full_name, lambda cost: GWL(cost, database, num_partitions, epsilon=epsilon)),\n",
    "        (low_rank_name, lambda cost: LRGWL(cost, database, num_partitions, rank=rank)),\n",
    "    ]\n",
    "    for method, solve in runs:\n",
    "        if section:\n",
    "            print('================')\n",
    "        section += 1\n",
    "        for variant, (raw_cost, noisy_cost) in costs.items():\n",
    "            print(f\"Method: {method}, {variant}\")\n",
    "            print(solve(raw_cost))    # Raw\n",
    "            print(solve(noisy_cost))  # Noisy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p8EZBzgEIHQJ"
   },
   "source": [
    "# Amazon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xtM0QV5TIIC9"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.\n",
    "# The context manager closes the file even when unpickling fails.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/amazon.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "G = database['G']\n",
    "labels = database['labels']\n",
    "num_nodes = G.number_of_nodes()\n",
    "num_partitions = len(np.unique(labels))\n",
    "\n",
    "# Load precomputed noisy version\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/amazon_noise.txt\"\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vhWUpU6WIi8j"
   },
   "outputs": [],
   "source": [
    "epsilon = 1e-6  # passed to GWL (GromovWasserstein epsilon)\n",
    "rank = 10       # passed to LRGWL (LRGromovWasserstein rank)\n",
    "t = 10          # heat-kernel time parameter for the Spec* methods\n",
    "\n",
    "# Each stage pairs a cost builder with the (full, low-rank) method names that\n",
    "# consume it. Costs are built once per stage and shared by both solvers, so\n",
    "# every heat kernel (a full eigendecomposition) is computed a single time.\n",
    "stages = [\n",
    "    (lambda graph: nx.adjacency_matrix(graph).toarray(), (\"GWL\", \"LRGWL\")),\n",
    "    (lambda graph: undirected_normalized_heat_kernel(graph, t), (\"SpecGWL\", \"LRSpecGWL\")),\n",
    "]\n",
    "\n",
    "section = 0\n",
    "for build_cost, (full_name, low_rank_name) in stages:\n",
    "    raw_cost = build_cost(G)\n",
    "    noisy_cost = build_cost(nG)\n",
    "    # GWL and LRGWL copy the cost via astype, so sharing the arrays is safe.\n",
    "    runs = [\n",
    "        (full_name, lambda cost: GWL(cost, database, num_partitions, epsilon=epsilon)),\n",
    "        (low_rank_name, lambda cost: LRGWL(cost, database, num_partitions, rank=rank)),\n",
    "    ]\n",
    "    for method, solve in runs:\n",
    "        if section:\n",
    "            print('================')\n",
    "        section += 1\n",
    "        print(f\"Method: {method}\")\n",
    "        print(solve(raw_cost))    # Raw\n",
    "        print(solve(noisy_cost))  # Noisy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g8Tc6mv-Jok8"
   },
   "source": [
    "# Village"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uuh1vcLnJqkz"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "num_nodes = 1991\n",
    "num_partitions = 12\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/India_database.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "\n",
    "# Register nodes 0..num_nodes-1 up front so nodes without edges are kept,\n",
    "# then add the edge list in bulk.\n",
    "G = nx.Graph()\n",
    "G.add_nodes_from(range(num_nodes))\n",
    "G.add_edges_from((edge[0], edge[1]) for edge in database['edges'])\n",
    "\n",
    "start_edges = nx.number_of_edges(G)\n",
    "\n",
    "# Load precomputed noisy version (nG comes straight from disk, so no copy of G\n",
    "# needs to be built for it)\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/village_noise.txt\"\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n",
    "\n",
    "# GWL/LRGWL read database['labels']; this dataset stores them under 'label'.\n",
    "database['labels'] = database['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "I8rUSMrxKKro"
   },
   "outputs": [],
   "source": [
    "epsilon = 1e-6  # passed to GWL (GromovWasserstein epsilon)\n",
    "rank = 12       # passed to LRGWL (LRGromovWasserstein rank)\n",
    "t = 10          # heat-kernel time parameter for the Spec* methods\n",
    "\n",
    "# Each stage pairs a cost builder with the (full, low-rank) method names that\n",
    "# consume it. Costs are built once per stage and shared by both solvers, so\n",
    "# every heat kernel (a full eigendecomposition) is computed a single time.\n",
    "stages = [\n",
    "    (lambda graph: nx.adjacency_matrix(graph).toarray(), (\"GWL\", \"LRGWL\")),\n",
    "    (lambda graph: undirected_normalized_heat_kernel(graph, t), (\"SpecGWL\", \"LRSpecGWL\")),\n",
    "]\n",
    "\n",
    "section = 0\n",
    "for build_cost, (full_name, low_rank_name) in stages:\n",
    "    raw_cost = build_cost(G)\n",
    "    noisy_cost = build_cost(nG)\n",
    "    # GWL and LRGWL copy the cost via astype, so sharing the arrays is safe.\n",
    "    runs = [\n",
    "        (full_name, lambda cost: GWL(cost, database, num_partitions, epsilon=epsilon)),\n",
    "        (low_rank_name, lambda cost: LRGWL(cost, database, num_partitions, rank=rank)),\n",
    "    ]\n",
    "    for method, solve in runs:\n",
    "        if section:\n",
    "            print('================')\n",
    "        section += 1\n",
    "        print(f\"Method: {method}\")\n",
    "        print(solve(raw_cost))    # Raw\n",
    "        print(solve(noisy_cost))  # Noisy"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
