{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "55760f17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.cluster.hierarchy import dendrogram, linkage\n",
    "import scipy.spatial.distance as ssd\n",
    "import random\n",
    "import pickle\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e8d755bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "import treerep\n",
    "import hccfit\n",
    "import rootedtreefit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b0587f75",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ROOTED_FIT_METHODS = ('hcc', 'gromov', 'average', 'complete')\n",
    "_SUPPORTED_METHODS = ('TR', 'NJ') + _ROOTED_FIT_METHODS\n",
    "# For method 'TR' the tree is refit once per seed in range(_TREEREP_TRIALS).\n",
    "_TREEREP_TRIALS = 50\n",
    "\n",
    "\n",
    "def _lengths_to_matrix(lengths, n_nodes):\n",
    "    \"\"\"Copy a nested {source: {target: length}} mapping into an n_nodes x n_nodes array.\"\"\"\n",
    "    dist = np.zeros((n_nodes, n_nodes))\n",
    "    for x in range(n_nodes):\n",
    "        row = lengths[x]\n",
    "        for y in range(n_nodes):\n",
    "            dist[x, y] = row[y]\n",
    "    return dist\n",
    "\n",
    "\n",
    "def _clamp_negative_weights(graph):\n",
    "    \"\"\"Set every negative 'weight' edge attribute of graph to 0, in place.\"\"\"\n",
    "    for u, v in graph.edges():\n",
    "        if graph[u][v]['weight'] < 0:\n",
    "            graph[u][v]['weight'] = 0\n",
    "\n",
    "\n",
    "def _mean_abs_error(d_fit, d_ref, n_nodes):\n",
    "    \"\"\"Sum of |d_fit - d_ref| divided by the number of ordered node pairs.\"\"\"\n",
    "    return np.sum(abs(d_fit - d_ref)) / (n_nodes * (n_nodes - 1))\n",
    "\n",
    "\n",
    "def _fit_errors(d, n_nodes, method):\n",
    "    \"\"\"Fit a tree metric to d with the given method and return the list of errors.\"\"\"\n",
    "    if method == 'TR':\n",
    "        errors = []\n",
    "        for tr_seed in range(_TREEREP_TRIALS):\n",
    "            # NOTE: this reseeds the global NumPy RNG on every trial.\n",
    "            np.random.seed(tr_seed)\n",
    "            T = treerep.TreeRep(d)\n",
    "            T.learn_tree()\n",
    "            _clamp_negative_weights(T.G)\n",
    "            lengths = dict(nx.shortest_path_length(T.G, method = 'bellman-ford', weight = 'weight'))\n",
    "            errors.append(_mean_abs_error(_lengths_to_matrix(lengths, n_nodes), d, n_nodes))\n",
    "        return errors\n",
    "    if method in _ROOTED_FIT_METHODS:\n",
    "        # The method name is passed straight through to RootedTreeFit.fit_treeM.\n",
    "        RT = rootedtreefit.RootedTreeFit(d)\n",
    "        RT.fit_treeM(pivot_idx = 0, method = method)\n",
    "        return [_mean_abs_error(RT.d_T, d, n_nodes)]\n",
    "    if method == 'NJ':\n",
    "        NJtree = utils.NeighborJoin(d)\n",
    "        _clamp_negative_weights(NJtree)\n",
    "        lengths = dict(nx.all_pairs_dijkstra_path_length(NJtree))\n",
    "        return [_mean_abs_error(_lengths_to_matrix(lengths, n_nodes), d, n_nodes)]\n",
    "    raise ValueError(f'unknown method {method!r}; expected one of {_SUPPORTED_METHODS}')\n",
    "\n",
    "\n",
    "def job(G, D_original, l_steps = 10, n_seed = -1, n_edges = 500, u_delta = 0.1, method = 'hcc'):\n",
    "    \"\"\"Add random shortcut edges to a copy of G and record the tree-fitting error.\n",
    "\n",
    "    n_edges distinct node pairs whose D_original entry exceeds 2.0 are drawn at\n",
    "    random and added, one at a time, to a copy of G with weight\n",
    "    D_original[i, j] - 2 * u_delta. After every l_steps additions the weighted\n",
    "    shortest-path matrix of the perturbed graph is recomputed, a tree is fit to\n",
    "    it with `method`, and sum(|d_T - d|) / (N * (N - 1)) is recorded.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G : networkx graph; it is copied, not modified. Its nodes are looked up as\n",
    "        0..N-1 when the distance matrix is filled in.\n",
    "    D_original : (N, N) array of pairwise distances.\n",
    "    l_steps : an error is recorded after every l_steps-th added edge.\n",
    "    n_seed : seed for the global NumPy RNG; a negative value leaves it unseeded.\n",
    "    n_edges : number of shortcut edges to add.\n",
    "    u_delta : each added edge is 2 * u_delta shorter than the original distance.\n",
    "    method : 'TR', 'NJ', 'hcc', 'gromov', 'average' or 'complete'.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of lists, one inner list per checkpoint. 'TR' yields one error per\n",
    "    trial seed (50 of them); every other method yields a single error.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If method is not supported (previously this silently produced empty\n",
    "        error lists), or if fewer than n_edges node pairs are more than 2.0\n",
    "        apart (previously the sampling loop never terminated).\n",
    "    \"\"\"\n",
    "    if method not in _SUPPORTED_METHODS:\n",
    "        raise ValueError(f'unknown method {method!r}; expected one of {_SUPPORTED_METHODS}')\n",
    "    dist = np.asarray(D_original)\n",
    "    N = dist.shape[0]\n",
    "    # A pair can be accepted in either draw order, so it is eligible when\n",
    "    # either direction of the distance exceeds the threshold.\n",
    "    far_apart = (dist > 2.0) | (dist.T > 2.0)\n",
    "    n_eligible = int(np.count_nonzero(np.triu(far_apart, k = 1)))\n",
    "    if n_edges > n_eligible:\n",
    "        raise ValueError(f'n_edges={n_edges} exceeds the {n_eligible} node pairs with distance > 2.0')\n",
    "\n",
    "    tree = G.copy()\n",
    "    # If n_seed is negative the seed is not fixed; otherwise it seeds the global NumPy RNG.\n",
    "    if n_seed >= 0:\n",
    "        np.random.seed(n_seed)\n",
    "\n",
    "    # Rejection-sample the shortcut edges. The RNG calls are the same as before,\n",
    "    # so a given seed still selects the same edges; only the duplicate check is\n",
    "    # now a set lookup instead of two list scans.\n",
    "    seen_pairs = set()\n",
    "    edge_list = []\n",
    "    while len(edge_list) < n_edges:\n",
    "        pi = np.random.permutation(N)\n",
    "        i, j = int(pi[0]), int(pi[1])\n",
    "        pair = (i, j) if i < j else (j, i)\n",
    "        if pair not in seen_pairs and dist[i, j] > 2.0:\n",
    "            seen_pairs.add(pair)\n",
    "            edge_list.append((i, j))\n",
    "\n",
    "    error_tape = []\n",
    "    for idx, (i, j) in enumerate(edge_list):\n",
    "        tree.add_edge(i, j, weight = dist[i, j] - 2 * u_delta)\n",
    "        if idx % l_steps != l_steps - 1:\n",
    "            continue\n",
    "        lengths = dict(nx.shortest_path_length(tree, method = 'bellman-ford', weight = 'weight'))\n",
    "        d = _lengths_to_matrix(lengths, N)\n",
    "        error_tape.append(_fit_errors(d, N, method))\n",
    "    return error_tape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "da326920",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic inputs used in the experiments: balanced trees with branching\n",
    "# factor r and height h, named b_tree_<r>_<h>.\n",
    "b_tree_8_3, b_tree_5_4, b_tree_3_5, b_tree_2_8 = (\n",
    "    nx.balanced_tree(r=branching, h=height)\n",
    "    for branching, height in [(8, 3), (5, 4), (3, 5), (2, 8)]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "54c94069",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Disease link-prediction graph. The header line was removed from the CSV\n",
    "# beforehand, so every non-empty row is read as one edge: source, target.\n",
    "# newline='' is what the csv module asks for when it is given a file object.\n",
    "with open('dataset/disease_lp/disease_lp.edges.csv', mode='r', newline='') as file:\n",
    "    csvFile = csv.reader(file)\n",
    "    # csv.reader yields [] for blank lines; drop them so a stray empty line\n",
    "    # cannot raise IndexError when the endpoints are read below.\n",
    "    disease_lp_edges = [row for row in csvFile if row]\n",
    "n_lines = len(disease_lp_edges)\n",
    "\n",
    "disease_lp_G = nx.Graph()\n",
    "# The old loop added nodes by hand and had a copy-paste slip (it tested e[1]\n",
    "# but added e[0]). Adding the edges creates both endpoints, so no separate\n",
    "# node bookkeeping is needed.\n",
    "disease_lp_G.add_edges_from((row[0], row[1]) for row in disease_lp_edges)\n",
    "disease_lp_G = nx.convert_node_labels_to_integers(disease_lp_G)\n",
    "\n",
    "# All-pairs shortest-path distances; start/end bracket the computation for timing.\n",
    "start = time.time()\n",
    "D_disease_lp = nx.floyd_warshall_numpy(disease_lp_G)\n",
    "end = time.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cfaa8403",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.0010537407797681778], [0.0023767708699215565], [0.0029680365296803676], [0.004025289778714439], [0.00543964406978106]]\n"
     ]
    }
   ],
   "source": [
    "# Ground-truth metric: all-pairs shortest-path distances of the 8-ary, height-3 tree.\n",
    "D = nx.floyd_warshall_numpy(b_tree_8_3)\n",
    "\n",
    "# Add 500 random shortcut edges and record the 'hcc' fitting error after every\n",
    "# 100 of them. n_seed = -1 leaves the RNG unseeded, so the numbers vary per run.\n",
    "output = job(\n",
    "    b_tree_8_3,\n",
    "    D,\n",
    "    l_steps = 100,\n",
    "    n_seed = -1,\n",
    "    n_edges = 500,\n",
    "    u_delta = 0.1,\n",
    "    method = 'hcc',\n",
    ")\n",
    "print(output)\n",
    "\n",
    "# The same experiment on the disease graph:\n",
    "# output = job(disease_lp_G, D_disease_lp, l_steps = 100, n_seed = -1, n_edges = 500, u_delta = 0.1, method = 'hcc')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ds_nx_torch_notebook",
   "language": "python",
   "name": "ds_nx_torch_notebook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
