{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "J8Z4eOQybIJr"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "from typing import List, Dict, Tuple\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Rmu7J-PFbMMr"
   },
   "outputs": [],
   "source": [
    "from FRLC import FRLC_iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3hI1olmbs6Bd"
   },
   "outputs": [],
   "source": [
    "# Create heat kernel matrix from precomputed eigenvalues/eigenvectors\n",
    "def heat_kernel(lam,phi,t):\n",
    "    \"\"\"\n",
    "    Build the heat kernel matrix from a precomputed eigendecomposition.\n",
    "\n",
    "    Args:\n",
    "        lam: eigenvalues of the Laplacian (assumed to be a 1-D array, as returned by np.linalg.eigh)\n",
    "        phi: matrix whose columns are the matching eigenvectors\n",
    "        t: time parameter\n",
    "    Returns:\n",
    "        the matrix phi @ diag(exp(-t * lam)) @ phi.T\n",
    "    \"\"\"\n",
    "    # Scaling column k of phi by exp(-t * lam[k]) equals phi @ diag(exp(-t * lam)),\n",
    "    # without materialising the dense diagonal matrix or paying for a second matmul.\n",
    "    decay = np.exp(-t * lam)\n",
    "    return np.matmul(np.multiply(phi, decay), phi.T)\n",
    "\n",
    "def directed_heat_kernel(G,t):\n",
    "    \"\"\"\n",
    "    Heat kernel of a directed graph.\n",
    "\n",
    "    Args:\n",
    "        G: networkx DiGraph\n",
    "        t: time parameter\n",
    "    Returns:\n",
    "        heat kernel matrix obtained from the eigendecomposition of\n",
    "        nx.directed_laplacian_matrix(G)\n",
    "    \"\"\"\n",
    "    laplacian = np.asarray(nx.directed_laplacian_matrix(G))\n",
    "    eigenvalues, eigenvectors = np.linalg.eigh(laplacian)\n",
    "    return heat_kernel(eigenvalues, eigenvectors, t)\n",
    "\n",
    "def undirected_heat_kernel(G,t):\n",
    "    \"\"\"\n",
    "    Heat kernel of an undirected graph.\n",
    "\n",
    "    Args:\n",
    "        G: networkx Graph\n",
    "        t: time parameter\n",
    "    Returns:\n",
    "        heat kernel matrix obtained from the eigendecomposition of\n",
    "        nx.laplacian_matrix(G)\n",
    "    \"\"\"\n",
    "    laplacian = nx.laplacian_matrix(G).toarray()\n",
    "    eigenvalues, eigenvectors = np.linalg.eigh(laplacian)\n",
    "    return heat_kernel(eigenvalues, eigenvectors, t)\n",
    "\n",
    "def undirected_normalized_heat_kernel(G,t):\n",
    "    \"\"\"\n",
    "    Heat kernel of an undirected graph, based on its normalized Laplacian.\n",
    "\n",
    "    Args:\n",
    "        G: networkx Graph\n",
    "        t: time parameter\n",
    "    Returns:\n",
    "        heat kernel matrix obtained from the eigendecomposition of\n",
    "        nx.normalized_laplacian_matrix(G)\n",
    "    \"\"\"\n",
    "    laplacian = nx.normalized_laplacian_matrix(G).toarray()\n",
    "    eigenvalues, eigenvectors = np.linalg.eigh(laplacian)\n",
    "    return heat_kernel(eigenvalues, eigenvectors, t)\n",
    "\n",
    "def estimate_target_distribution(probs: Dict, dim_t: int = 2) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Estimate the target distribution by averaging the sorted source distributions.\n",
    "\n",
    "    Every source distribution is sorted in descending order, resampled onto\n",
    "    dim_t evenly spaced points of [0, 1] by linear interpolation, and the\n",
    "    resampled curves are summed and normalised to sum to one.\n",
    "\n",
    "    Args:\n",
    "        probs: a dictionary of graphs {key: graph idx,\n",
    "                                       value: 1-D array with the distribution of source nodes}\n",
    "        dim_t: the dimension of target distribution\n",
    "    Returns:\n",
    "        1-D array of length dim_t representing a distribution\n",
    "    \"\"\"\n",
    "    grid_t = np.linspace(0, 1, dim_t)\n",
    "    accumulated = np.zeros((dim_t, ))\n",
    "    for source_probs in probs.values():\n",
    "        descending = np.sort(source_probs)[::-1]\n",
    "        grid_s = np.linspace(0, 1, descending.shape[0])\n",
    "        accumulated += np.interp(grid_t, grid_s, descending)\n",
    "    return accumulated / np.sum(accumulated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vt61kHuybMO_"
   },
   "outputs": [],
   "source": [
    "def Ours(cost,database,num_partitions,rank,device=None):\n",
    "    \"\"\"\n",
    "    Cluster graph nodes with FRLC and score the clustering against the ground-truth labels.\n",
    "\n",
    "    Args:\n",
    "        cost: square array describing the source graph (e.g. adjacency matrix or heat kernel)\n",
    "        database: dict whose 'labels' entry holds the ground-truth label of every node\n",
    "        num_partitions: number of clusters, i.e. the size of the target distribution\n",
    "        rank: value forwarded to FRLC_iteration as r\n",
    "        device: device forwarded to FRLC_iteration; when None, CUDA is used if it is\n",
    "            available and the CPU otherwise\n",
    "    Returns:\n",
    "        adjusted mutual information between database['labels'] and the estimated\n",
    "        assignment, averaged over the 10 FRLC runs\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    cost = cost.astype(np.float64)\n",
    "    # Source node distribution: row sums raised to a tiny power, then normalised\n",
    "    p_s = np.sum(cost, axis=1) ** 0.001\n",
    "    p_s /= np.sum(p_s)\n",
    "\n",
    "    p_t = estimate_target_distribution({0: p_s}, dim_t=num_partitions)\n",
    "    cost_t = np.diag(p_t)\n",
    "\n",
    "    A = torch.from_numpy(cost)\n",
    "    B = torch.from_numpy(cost_t)\n",
    "    # Optional normalisation, currently disabled:\n",
    "    # norm = torch.max(torch.tensor([torch.max(A), torch.max(B)]))\n",
    "    # A, B = A/norm, B/norm\n",
    "\n",
    "    mutual_info_list = []\n",
    "    for _ in range(10):\n",
    "        Pi, errs = FRLC_iteration(\n",
    "            torch.zeros(cost.shape[0], cost_t.shape[0]),\n",
    "            A=A, B=B, device=device,\n",
    "            tau=0.01, gamma=10, r=rank, max_iter=1000,\n",
    "            semiRelaxedRight=True, semiRelaxedLeft=False, Wasserstein=False,\n",
    "        )\n",
    "        Pi = Pi.cpu().numpy()\n",
    "\n",
    "        # Each node goes to the target column holding the largest entry of its row\n",
    "        est_idx = np.argmax(Pi, axis=1)\n",
    "\n",
    "        mutual_info = metrics.adjusted_mutual_info_score(database['labels'], est_idx)\n",
    "        mutual_info_list.append(mutual_info)\n",
    "\n",
    "    # Average over every run; the last run's score alone is not the mean\n",
    "    return np.mean(mutual_info_list)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZosrkAQnDdHB"
   },
   "source": [
    "# Wikipedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o06m4TOtGq5e"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
    "# The context manager closes the file even if unpickling raises.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/wikicats.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "dG = database['G']\n",
    "labels = database['labels']\n",
    "num_nodes = dG.number_of_nodes()\n",
    "num_partitions = len(np.unique(labels))\n",
    "\n",
    "# Symmetrized copy of the directed graph\n",
    "G = dG.to_undirected()\n",
    "\n",
    "# Load precomputed noisy version\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/wiki_sym_noise.txt\"\n",
    "\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n",
    "\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/wiki_asym_noise.txt\"\n",
    "\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    ndG = pickle.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CU9dlRcza00u"
   },
   "outputs": [],
   "source": [
    "rank = 15\n",
    "\n",
    "###########################################################\n",
    "# Method: Ours on the adjacency matrix.\n",
    "# Each entry is (banner, raw graph, noisy graph); the raw graph is scored first.\n",
    "###########################################################\n",
    "for method_name, raw_graph, noisy_graph in (\n",
    "    (\"Method: Ours, symmetrized\", G, nG),\n",
    "    (\"Method: Ours, asymmetric\", dG, ndG),\n",
    "):\n",
    "    print(method_name)\n",
    "    for graph in (raw_graph, noisy_graph):\n",
    "        cost = nx.adjacency_matrix(graph).toarray()\n",
    "        mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "        print(mutual_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "soemA3z-a1oN"
   },
   "outputs": [],
   "source": [
    "t = 10\n",
    "rank = 15\n",
    "\n",
    "# Disabled (commented out): \"Method: SpecOurs, symmetrized\", i.e. the same loop as\n",
    "# below with undirected_normalized_heat_kernel applied to G and nG.\n",
    "\n",
    "###########################################################\n",
    "# Method: SpecOurs, asymmetric (raw graph first, then noisy)\n",
    "###########################################################\n",
    "print(\"Method: SpecOurs, asymmetric\")\n",
    "for graph in (dG, ndG):\n",
    "    cost = directed_heat_kernel(graph, t)\n",
    "    mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "    print(mutual_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p8EZBzgEIHQJ"
   },
   "source": [
    "# Amazon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xtM0QV5TIIC9"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
    "# The context manager closes the file even if unpickling raises.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/amazon.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "G = database['G']\n",
    "labels = database['labels']\n",
    "num_nodes = G.number_of_nodes()\n",
    "num_partitions = len(np.unique(labels))\n",
    "\n",
    "\n",
    "# Load precomputed noisy version\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/amazon_noise.txt\"\n",
    "\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "vhWUpU6WIi8j"
   },
   "outputs": [],
   "source": [
    "rank = 12\n",
    "\n",
    "###########################################################\n",
    "# Method: Ours, symmetrized (raw graph first, then noisy)\n",
    "###########################################################\n",
    "print(\"Method: Ours, symmetrized\")\n",
    "for graph in (G, nG):\n",
    "    cost = nx.adjacency_matrix(graph).toarray()\n",
    "    mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "    print(mutual_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wvMsh4-erfZh"
   },
   "outputs": [],
   "source": [
    "t=10\n",
    "rank=12\n",
    "###########################################################\n",
    "###########################################################\n",
    "# Method: SpecOurs, symmetrized\n",
    "###########################################################\n",
    "print(\"Method: SpecOurs, symmetrized\")\n",
    "# Raw\n",
    "cost = undirected_normalized_heat_kernel(G, t)\n",
    "mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "print(mutual_info)\n",
    "\n",
    "# Noisy: must use the noisy graph nG, not the raw graph G\n",
    "cost = undirected_normalized_heat_kernel(nG, t)\n",
    "mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "print(mutual_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g8Tc6mv-Jok8"
   },
   "source": [
    "# Village"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uuh1vcLnJqkz"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "num_nodes = 1991\n",
    "num_partitions = 12\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
    "with open('/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/India_database.p', 'rb') as f:\n",
    "    database = pickle.load(f)\n",
    "\n",
    "# Clean graph: nodes 0..num_nodes-1 plus the edges listed in the database.\n",
    "# nG is not built here because it is loaded from disk below.\n",
    "G = nx.Graph()\n",
    "G.add_nodes_from(range(num_nodes))\n",
    "G.add_edges_from((edge[0], edge[1]) for edge in database['edges'])\n",
    "\n",
    "start_edges = nx.number_of_edges(G)\n",
    "\n",
    "\n",
    "# Load precomputed noisy version\n",
    "save_name = \"/content/drive/MyDrive/Research/SemiRelaxedLowRank/data/village_noise.txt\"\n",
    "\n",
    "with open(save_name, \"rb\") as fp:\n",
    "    nG = pickle.load(fp)\n",
    "\n",
    "# Ours() reads the ground truth from database['labels']\n",
    "database['labels'] = database['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "I8rUSMrxKKro"
   },
   "outputs": [],
   "source": [
    "rank = 12\n",
    "\n",
    "###########################################################\n",
    "# Method: Ours, symmetrized (raw graph first, then noisy)\n",
    "###########################################################\n",
    "print(\"Method: Ours, symmetrized\")\n",
    "for graph in (G, nG):\n",
    "    cost = nx.adjacency_matrix(graph).toarray()\n",
    "    mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "    print(mutual_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zBhdzK7wv81E"
   },
   "outputs": [],
   "source": [
    "t=10\n",
    "# Set rank explicitly so this cell does not depend on the previous cell having run\n",
    "rank=12\n",
    "###########################################################\n",
    "###########################################################\n",
    "# Method: SpecOurs, symmetrized\n",
    "###########################################################\n",
    "print(\"Method: SpecOurs, symmetrized\")\n",
    "# Raw\n",
    "cost = undirected_normalized_heat_kernel(G, t)\n",
    "mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "print(mutual_info)\n",
    "\n",
    "# Noisy: must use the noisy graph nG, not the raw graph G\n",
    "cost = undirected_normalized_heat_kernel(nG, t)\n",
    "mutual_info = Ours(cost,database,num_partitions,rank=rank)\n",
    "print(mutual_info)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
