{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "15dc5963",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This file explores how Algorithm 4 (MMJ distance by Calculation and Copy) can be accelerated\n",
    "# by parallel computing, which is discussed in Section 6.2 (Using parallel programming).\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2f5578c2",
   "metadata": {},
   "outputs": [],
   "source": [
    " \n",
    "import time\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.metrics import pairwise_distances\n",
    "import networkx as nx\n",
    "import sys\n",
    "from joblib import Parallel, delayed\n",
    "import random\n",
    "from numba import njit\n",
    "import threading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6111765b",
   "metadata": {},
   "outputs": [],
   "source": [
    " \n",
    "# Only unpickle trusted files: pickle.load can execute arbitrary code.\n",
    "with open(\"./data/test_data_145.p\", \"rb\") as fh:\n",
    "    test_data_145 = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "087e1fd2",
   "metadata": {},
   "outputs": [],
   "source": [
    " \n",
    "\n",
    "@njit(parallel=True, cache=True)\n",
    "def primMST_numba(distance_matrix):\n",
    "    \"\"\"Prim's algorithm on a dense (V, V) distance matrix, O(V^2).\n",
    "\n",
    "    Every off-diagonal entry is treated as an edge, i.e. the graph is complete.\n",
    "    The loops use `range`, not `prange`, so they run sequentially.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    distance_matrix : 2-D float array indexed as distance_matrix[u, v].\n",
    "        Presumably symmetric (the notebook passes pairwise_distances output)\n",
    "        - TODO confirm for any other caller.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    parent : int64 array of length V. parent[v] is the tree vertex that\n",
    "        connected v; parent[0] == -1 (root). Any other -1 means v was never\n",
    "        reached (its keys stayed inf/NaN).\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    * fastmath is deliberately off: it enables LLVM's `ninf` flag (assume no\n",
    "      infinities), but np.inf is used here as the sentinel key.\n",
    "    * Zero distances (duplicate points) are valid edges. The earlier `> 0`\n",
    "      filter skipped them, producing a non-minimal tree and, e.g. when all\n",
    "      points coincide, vertices left with parent == -1.\n",
    "    \"\"\"\n",
    "    V = distance_matrix.shape[0]\n",
    "    key = np.full(V, np.inf)\n",
    "    parent = -np.ones(V, dtype=np.int64)\n",
    "    inMST = np.zeros(V, dtype=np.bool_)\n",
    "    if V == 0:\n",
    "        # Nothing to span; also avoids writing key[0] out of bounds.\n",
    "        return parent\n",
    "    key[0] = 0.0\n",
    "\n",
    "    for _ in range(V):\n",
    "        # Pick the cheapest vertex not yet in the tree.\n",
    "        u = -1\n",
    "        min_val = np.inf\n",
    "        for v in range(V):\n",
    "            if (not inMST[v]) and (key[v] < min_val):\n",
    "                min_val = key[v]\n",
    "                u = v\n",
    "        if u == -1:\n",
    "            # Remaining vertices are unreachable (keys still inf or NaN).\n",
    "            break\n",
    "        inMST[u] = True\n",
    "        # Relax edges out of u; u itself is skipped because inMST[u] is now True.\n",
    "        for v in range(V):\n",
    "            d = distance_matrix[u, v]\n",
    "            if (not inMST[v]) and (d < key[v]):\n",
    "                key[v] = d\n",
    "                parent[v] = u\n",
    "    return parent\n",
    "\n",
    "def construct_MST_from_graph(distance_matrix):\n",
    "    \"\"\"Build the minimum spanning tree of a dense distance matrix as a networkx Graph.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    distance_matrix : (V, V) array, passed unchanged to primMST_numba.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    networkx.Graph with nodes 0..V-1 and one edge (parent[i], i) for each\n",
    "    i >= 1, carrying weight=distance_matrix[i, parent[i]].\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if primMST_numba left some vertex i >= 1 with parent -1.\n",
    "    Previously that silently added a bogus node -1 and read the edge weight\n",
    "    from the last column through negative indexing.\n",
    "    \"\"\"\n",
    "    V = distance_matrix.shape[0]\n",
    "    parent = primMST_numba(distance_matrix)\n",
    "\n",
    "    unreached = [i for i in range(1, V) if parent[i] < 0]\n",
    "    if unreached:\n",
    "        raise ValueError(\n",
    "            f'primMST_numba left {len(unreached)} vertices without a parent '\n",
    "            f'(first few: {unreached[:5]}); check the distance matrix, '\n",
    "            'e.g. for NaN/inf entries.'\n",
    "        )\n",
    "\n",
    "    MST = nx.Graph()\n",
    "    MST.add_nodes_from(range(V))\n",
    "    for i in range(1, V):\n",
    "        # Plain int keeps node ids homogeneous (parent holds np.int64).\n",
    "        p = int(parent[i])\n",
    "        MST.add_edge(p, i, weight=distance_matrix[i, p])\n",
    "    return MST\n",
    "\n",
    "\n",
    "\n",
    "def cal_mmj_matrix_by_algo_4_Calculation_and_Copy_parallel_compu(X, round_n=15, n_jobs=-1, num_W=None):\n",
    "    \"\"\"All-pairs MMJ distance matrix of X via Algorithm 4 (Calculation and Copy).\n",
    "\n",
    "    The W largest MST edges are processed by a joblib thread pool\n",
    "    (for_parallel_compu). The remaining edges are processed serially\n",
    "    (not_for_parallel_compu).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : data accepted by sklearn.metrics.pairwise_distances.\n",
    "    round_n : int, decimals kept when rounding the pairwise distances.\n",
    "    n_jobs : int, forwarded to joblib.Parallel (threading backend).\n",
    "    num_W : int or None, number of largest MST edges sent to the thread pool\n",
    "        (clamped to [0, len(X) - 1]). None reads the notebook global `num_W`,\n",
    "        so the original call signature keeps working unchanged.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray of shape (len(X), len(X)). The same array is bound to the global\n",
    "    `mmj_matrix`, which the worker helpers write into.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    `num_MST` (the number of extra MST copies) is still read from the\n",
    "    notebook global.\n",
    "    \"\"\"\n",
    "    global mmj_matrix\n",
    "    if num_W is None:\n",
    "        num_W = globals()['num_W']\n",
    "    lenX = len(X)\n",
    "\n",
    "    distance_matrix = np.round(pairwise_distances(X), round_n)\n",
    "    mmj_matrix = np.zeros((lenX, lenX))\n",
    "\n",
    "    MST = construct_MST_from_graph(distance_matrix)\n",
    "\n",
    "    # Order MST edges from largest to smallest weight. Deriving nodes and\n",
    "    # weights from a single argsort keeps the two lists aligned.\n",
    "    MST_edge_list = list(MST.edges(data='weight'))\n",
    "    edge_weight_list = np.array([w for _, _, w in MST_edge_list], dtype=float)\n",
    "    edge_large_to_small_arg = np.argsort(edge_weight_list)[::-1]\n",
    "    edge_weight_large_to_small = edge_weight_list[edge_large_to_small_arg]\n",
    "    edge_nodes_large_to_small = [MST_edge_list[k][:2] for k in edge_large_to_small_arg]\n",
    "\n",
    "    # MST_list[0] is the MST itself, followed by num_MST independent copies.\n",
    "    MST_list = [MST] + [MST.copy(as_view=False) for _ in range(num_MST)]\n",
    "    lock = threading.Lock()\n",
    "\n",
    "    N = len(edge_nodes_large_to_small)  # lenX - 1 edges in a spanning tree\n",
    "    W = max(0, min(num_W, N))  # edges 0..W-1 go to the thread pool\n",
    "\n",
    "    Parallel(n_jobs=n_jobs, backend=\"threading\")(\n",
    "        delayed(for_parallel_compu)(i, MST_list, edge_nodes_large_to_small, edge_weight_large_to_small, lock)\n",
    "        for i in range(W)\n",
    "    )\n",
    "\n",
    "    # Serial tail (could also run in its own thread on a copy of MST): drop\n",
    "    # the W edges already handled, then cut the remaining edges in order.\n",
    "    for edge_nodes in edge_nodes_large_to_small[:W]:\n",
    "        MST.remove_edge(*edge_nodes)\n",
    "    for i in range(W, N):\n",
    "        not_for_parallel_compu(i, MST, edge_nodes_large_to_small, edge_weight_large_to_small)\n",
    "\n",
    "    return mmj_matrix\n",
    " \n",
    "\n",
    "\n",
    "def not_for_parallel_compu(i, MST, edge_nodes_large_to_small, edge_weight_large_to_small):\n",
    "    \"\"\"Serial step: cut the i-th largest MST edge and assign its weight as the\n",
    "    MMJ distance between the two components it separated.\n",
    "\n",
    "    MST is mutated (the edge stays removed). The caller removes edges\n",
    "    0..i-1 before reaching index i. Writes into the notebook-level global\n",
    "    `mmj_matrix`.\n",
    "    \"\"\"\n",
    "    global mmj_matrix\n",
    "    a, b = edge_nodes_large_to_small[i]\n",
    "    MST.remove_edge(a, b)\n",
    "\n",
    "    side_a = list(nx.dfs_preorder_nodes(MST, source=a))\n",
    "    side_b = list(nx.dfs_preorder_nodes(MST, source=b))\n",
    "    weight = edge_weight_large_to_small[i]\n",
    "\n",
    "    # np.ix_ selects the same side_a-by-side_b block a meshgrid would;\n",
    "    # fill it and its mirror image so the matrix stays symmetric.\n",
    "    mmj_matrix[np.ix_(side_a, side_b)] = weight\n",
    "    mmj_matrix[np.ix_(side_b, side_a)] = weight\n",
    "\n",
    " \n",
    "\n",
    "def for_parallel_compu(i, MST_list, edge_nodes_large_to_small, edge_weight_large_to_small, lock):\n",
    "    \"\"\"Thread-pool worker: fill the MMJ entries decided by the i-th largest MST edge.\n",
    "\n",
    "    Edges 0..i (the i+1 largest) are hidden in a read-only view of one MST\n",
    "    copy. The two components on either side of edge i are then collected,\n",
    "    and every cross pair gets edge i's weight as its MMJ distance.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    i : int, index into the large-to-small edge lists.\n",
    "    MST_list : list of networkx.Graph (the MST and its copies); only read here.\n",
    "    edge_nodes_large_to_small : list of (u, v) MST edges, largest weight first.\n",
    "    edge_weight_large_to_small : weights aligned with edge_nodes_large_to_small.\n",
    "    lock : accepted for interface compatibility but no longer needed - the\n",
    "        graphs are never mutated, and different i write disjoint cells.\n",
    "\n",
    "    Side effects: writes into the notebook-level global `mmj_matrix`.\n",
    "\n",
    "    The previous version removed and re-added edges on a shared copy under one\n",
    "    global lock, which serialised the graph work of all workers. It also\n",
    "    restored the edges with weight=1.0, corrupting the MST's stored weights.\n",
    "    \"\"\"\n",
    "    global mmj_matrix\n",
    "\n",
    "    # Spread tasks deterministically over the available copies.\n",
    "    MST_temp = MST_list[i % len(MST_list)]\n",
    "    forest = nx.restricted_view(MST_temp, [], edge_nodes_large_to_small[: i + 1])\n",
    "\n",
    "    u, v = edge_nodes_large_to_small[i]\n",
    "    edge_weight = edge_weight_large_to_small[i]\n",
    "    tree1_nodes = list(nx.dfs_preorder_nodes(forest, source=u))\n",
    "    tree2_nodes = list(nx.dfs_preorder_nodes(forest, source=v))\n",
    "\n",
    "    idx1, idx2 = np.meshgrid(tree1_nodes, tree2_nodes, indexing=\"ij\")\n",
    "    mmj_matrix[idx1, idx2] = mmj_matrix[idx2, idx1] = edge_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "78f8e900",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of points in data X: 10000\n"
     ]
    }
   ],
   "source": [
    " \n",
    "# Pick one dataset (index into the pickled collection) and report its size.\n",
    "data_id = 136\n",
    "X = test_data_145[data_id]\n",
    "print(f\"Number of points in data X: {len(X)}\")\n",
    "\n",
    "# Tuning knobs read as globals by the MMJ functions:\n",
    "#   num_MST - extra MST copies placed in MST_list (which holds num_MST + 1 graphs)\n",
    "#   num_W   - how many of the largest MST edges are handed to the thread pool\n",
    "num_MST, num_W = 1, 250\n",
    "\n",
    "# Shared output buffer; the MMJ functions rebind/write it via `global mmj_matrix`.\n",
    "mmj_matrix = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8c286952",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time used: 4.072s\n"
     ]
    }
   ],
   "source": [
    "# perf_counter is monotonic and high-resolution, unlike time.time(),\n",
    "# so it is the appropriate clock for measuring elapsed time.\n",
    "start = time.perf_counter()\n",
    "X_mmj_matrix_algo_4_parallel_compu = cal_mmj_matrix_by_algo_4_Calculation_and_Copy_parallel_compu(X)\n",
    "end = time.perf_counter()\n",
    "time_used = np.round(end - start, 3)\n",
    "\n",
    "print(f\"Time used: {time_used}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa578591",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
