{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f1758d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.cluster.hierarchy import dendrogram, linkage\n",
    "import scipy.spatial.distance as ssd\n",
    "import random\n",
    "import pickle\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ced020f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "import treerep\n",
    "import hccfit\n",
    "import rootedtreefit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bae3629a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import distance matrix for Celegan and CSphd\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted dataset files.\n",
    "# Context managers close each file handle instead of leaking it.\n",
    "with open('dataset/D_celegan.pkl', 'rb') as fh:\n",
    "    D_celegans = pickle.load(fh)\n",
    "with open('dataset/D_csphd.pkl', 'rb') as fh:\n",
    "    D_csphd = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "76887232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute distance matrix for CORA\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted dataset files.\n",
    "# The context manager closes the file handle instead of leaking it.\n",
    "with open('dataset/cora/ind.cora.graph', 'rb') as fh:\n",
    "    cora = pickle.load(fh)\n",
    "# The pickle is passed to from_dict_of_lists, i.e. it is used as a node -> neighbour-list mapping.\n",
    "G_cora = nx.from_dict_of_lists(cora)\n",
    "# Work on a single connected component of the graph.\n",
    "# NOTE(review): li[0] is the first component networkx yields, which is not guaranteed\n",
    "# to be the largest one -- confirm this is the intended component.\n",
    "li = [G_cora.subgraph(c) for c in nx.connected_components(G_cora)]\n",
    "# Copy the subgraph into a standalone graph before removing self-loops from it.\n",
    "connected_G = nx.Graph(li[0])\n",
    "connected_G.remove_edges_from(nx.selfloop_edges(connected_G))\n",
    "# All-pairs shortest-path distances of the selected component.\n",
    "D_cora = nx.floyd_warshall_numpy(connected_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d1f88892",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute distance matrix for Airport\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted dataset files.\n",
    "# The context manager closes the file handle instead of leaking it.\n",
    "with open('dataset/airport/airport.p', 'rb') as fh:\n",
    "    airport = pickle.load(fh)\n",
    "# Work on a single connected component of the graph.\n",
    "# NOTE(review): li[0] is the first component networkx yields, which is not guaranteed\n",
    "# to be the largest one -- confirm this is the intended component.\n",
    "li = [airport.subgraph(c) for c in nx.connected_components(airport)]\n",
    "# Copy the subgraph into a standalone graph before removing self-loops from it.\n",
    "connected_G = nx.Graph(li[0])\n",
    "connected_G.remove_edges_from(nx.selfloop_edges(connected_G))\n",
    "# All-pairs shortest-path distances of the selected component.\n",
    "D_airport = nx.floyd_warshall_numpy(connected_G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e0a8a6a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------\n",
      "0.893 0.200 0.594 1.235\n",
      "----------------\n",
      "6.125 0.839 5.000 8.000\n",
      "----------------\n",
      "0.078 0.007 0.066 0.092\n"
     ]
    }
   ],
   "source": [
    "# TreeRep on the Celegan distance matrix: repeat the fit n_trials times and\n",
    "# report mean / std / min / max of the average error, the max error and the runtime.\n",
    "D = D_celegans\n",
    "N = D.shape[0]\n",
    "# D = D.astype('float64')\n",
    "n_trials = 10\n",
    "d_max = D.max()\n",
    "TR_error, TR_ellinf_error, TR_time = [], [], []\n",
    "for trial in range(n_trials):\n",
    "    # np.random.seed(trial)  # uncomment to fix the seed per trial\n",
    "    start = time.time()\n",
    "    T = treerep.TreeRep(D)\n",
    "    T.learn_tree()\n",
    "    # Clamp negative edge weights to zero; this is part of the timed section.\n",
    "    for _, _, attrs in T.G.edges(data=True):\n",
    "        if attrs['weight'] < 0:\n",
    "            attrs['weight'] = 0\n",
    "    end = time.time()\n",
    "\n",
    "    # Copy the weighted shortest-path distances between nodes 0..N-1 into a dense matrix.\n",
    "    tree_dist = dict(nx.shortest_path_length(T.G, method='dijkstra', weight='weight'))\n",
    "    D_T = np.zeros((N, N))\n",
    "    for row in range(N):\n",
    "        D_T[row, :] = [tree_dist[row][col] for col in range(N)]\n",
    "\n",
    "    abs_err = np.abs(D_T - D)\n",
    "    # Sum of absolute errors divided by N*(N-1), and the largest absolute error.\n",
    "    TR_error.append(np.sum(abs_err) / (N * (N - 1)))\n",
    "    TR_ellinf_error.append(np.max(abs_err))\n",
    "    TR_time.append(end - start)\n",
    "\n",
    "# One separator plus one summary line (mean, std, min, max) per metric.\n",
    "for series in (TR_error, TR_ellinf_error, TR_time):\n",
    "    print('----------------')\n",
    "    print(' '.join('{:.3f}'.format(stat(series)) for stat in (np.mean, np.std, np.min, np.max)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cf40ef02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------\n",
      "1.158 0.023 1.111 1.211\n",
      "----------------\n",
      "3.400 0.490 3.000 4.000\n",
      "----------------\n",
      "0.066 0.016 0.058 0.114\n"
     ]
    }
   ],
   "source": [
    "# Gromov method (rootedtreefit) on the Celegan distance matrix: repeat n_trials times\n",
    "# with a random pivot index and report mean / std / min / max of each metric.\n",
    "D = D_celegans\n",
    "N = D.shape[0]\n",
    "# D = D.astype('float64')\n",
    "n_trials = 10\n",
    "GR_error, GR_ellinf_error, GR_time = [], [], []\n",
    "for trial in range(n_trials):\n",
    "    # np.random.seed(trial)  # uncomment to fix the seed per trial\n",
    "    start = time.time()\n",
    "    pivot_idx = np.random.randint(N)\n",
    "    RT = rootedtreefit.RootedTreeFit(D)\n",
    "    RT.fit_treeM(pivot_idx=pivot_idx, method='gromov')\n",
    "    D_T = RT.d_T\n",
    "    end = time.time()\n",
    "\n",
    "    abs_err = np.abs(D_T - D)\n",
    "    # Sum of absolute errors divided by N*(N-1), and the largest absolute error.\n",
    "    GR_error.append(np.sum(abs_err) / (N * (N - 1)))\n",
    "    GR_ellinf_error.append(np.max(abs_err))\n",
    "    GR_time.append(end - start)\n",
    "\n",
    "# One separator plus one summary line (mean, std, min, max) per metric.\n",
    "for series in (GR_error, GR_ellinf_error, GR_time):\n",
    "    print('----------------')\n",
    "    print(' '.join('{:.3f}'.format(stat(series)) for stat in (np.mean, np.std, np.min, np.max)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3d9bec2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------\n",
      "0.853 0.238 0.339 1.225\n",
      "----------------\n",
      "4.200 0.400 4.000 5.000\n",
      "----------------\n",
      "0.690 0.018 0.660 0.709\n"
     ]
    }
   ],
   "source": [
    "# HCC method (rootedtreefit) on the Celegan distance matrix: repeat n_trials times\n",
    "# with a random pivot index and report mean / std / min / max of each metric.\n",
    "D = D_celegans\n",
    "N = D.shape[0]\n",
    "# D = D.astype('float64')\n",
    "n_trials = 10\n",
    "HCC_error, HCC_ellinf_error, HCC_time = [], [], []\n",
    "for trial in range(n_trials):\n",
    "    # np.random.seed(trial)  # uncomment to fix the seed per trial\n",
    "    start = time.time()\n",
    "    pivot_idx = np.random.randint(N)\n",
    "    RT = rootedtreefit.RootedTreeFit(D)\n",
    "    RT.fit_treeM(pivot_idx=pivot_idx, method='hcc')\n",
    "    D_T = RT.d_T\n",
    "    end = time.time()\n",
    "\n",
    "    abs_err = np.abs(D_T - D)\n",
    "    # Sum of absolute errors divided by N*(N-1), and the largest absolute error.\n",
    "    HCC_error.append(np.sum(abs_err) / (N * (N - 1)))\n",
    "    HCC_ellinf_error.append(np.max(abs_err))\n",
    "    HCC_time.append(end - start)\n",
    "\n",
    "# One separator plus one summary line (mean, std, min, max) per metric.\n",
    "for series in (HCC_error, HCC_ellinf_error, HCC_time):\n",
    "    print('----------------')\n",
    "    print(' '.join('{:.3f}'.format(stat(series)) for stat in (np.mean, np.std, np.min, np.max)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f1790fbe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time takes  1.0201072692871094\n",
      "ell_1 norm =  0.298\n",
      "ell_infty norm =  2.974\n"
     ]
    }
   ],
   "source": [
    "# Neighbor joining (utils.NeighborJoin) on the Celegan distance matrix:\n",
    "# report the runtime, the average absolute error and the largest absolute error.\n",
    "D = D_celegans\n",
    "# Start the clock in this cell; without this the elapsed time would be measured from\n",
    "# a `start` left behind by another cell (or fail if this cell is run on its own).\n",
    "start = time.time()\n",
    "tree = utils.NeighborJoin(D)\n",
    "# Clamp negative edge weights to zero, reporting each one that is found.\n",
    "for u, v in tree.edges():\n",
    "    if tree[u][v]['weight'] < 0:\n",
    "        tree[u][v]['weight'] = 0\n",
    "        print('negative edge has been found')\n",
    "print('Time takes ', time.time() - start)\n",
    "\n",
    "# Copy the weighted shortest-path distances between nodes 0..N-1 into a dense matrix.\n",
    "length = dict(nx.all_pairs_dijkstra_path_length(tree))\n",
    "N = len(D)\n",
    "n_tree = len(tree)\n",
    "D_tree = np.zeros((N, N))\n",
    "for i in range(N):\n",
    "    D_tree[i, :] = [length[i][j] for j in range(N)]\n",
    "\n",
    "# Sum of absolute errors divided by N*(N-1), and the largest absolute error.\n",
    "abs_err = np.abs(D_tree - D)\n",
    "print('ell_1 norm = ', '{:.3f}'.format(np.sum(abs_err) / (N * (N - 1))))\n",
    "print('ell_infty norm = ', '{:.3f}'.format(np.max(abs_err)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ds_nx_torch_notebook",
   "language": "python",
   "name": "ds_nx_torch_notebook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
