{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6e4194f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook introduces a new algorithm that accelerates Algorithm 4 (MMJ distance by Calculation and Copy)\n",
    "# with parallel computing; the new algorithm is called Algorithm 13 (APPD accelerated by parallel computing).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe7711f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "81c5661d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.metrics import pairwise_distances\n",
    "import networkx as nx\n",
    "import sys\n",
    "from joblib import Parallel, delayed\n",
    "import random\n",
    "from numba import njit\n",
    "import threading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d16ead2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): test_data_145 is not used anywhere else in this notebook.\n",
    "# pickle.load can execute arbitrary code -- only load this file from a trusted source.\n",
    "# The with-block closes the file handle, which the previous one-liner leaked.\n",
    "with open(\"./data/test_data_145.p\", \"rb\") as pickle_file:\n",
    "    test_data_145 = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0b1d9482",
   "metadata": {},
   "outputs": [],
   "source": [
    "# No parallel=True: the loops below are sequential (no prange), so Numba would only\n",
    "# emit a performance warning. No fastmath: it lets LLVM assume values are never\n",
    "# +/-inf, but `key` and `min_val` use np.inf as their \"not reached yet\" sentinel.\n",
    "@njit(cache=True)\n",
    "def primMST_numba(distance_matrix):\n",
    "    \"\"\"Run Prim's algorithm on a dense V x V weight matrix and return the parent array.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    distance_matrix : 2-D array of shape (V, V)\n",
    "        Edge weights. Only entries > 0 are used as edges.\n",
    "        NOTE(review): a distance of exactly 0 between two distinct points is\n",
    "        therefore ignored -- confirm that is intended for duplicate points.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    parent : int64 array of shape (V,)\n",
    "        parent[v] is the vertex through which v joined the tree. parent[0] is -1\n",
    "        (vertex 0 is the root); a vertex that was never reached also keeps -1.\n",
    "    \"\"\"\n",
    "    V = distance_matrix.shape[0]\n",
    "    key = np.full(V, np.inf)  # cheapest known edge weight into the tree\n",
    "    parent = -np.ones(V, dtype=np.int64)\n",
    "    inMST = np.zeros(V, dtype=np.bool_)\n",
    "    key[0] = 0.0\n",
    "\n",
    "    for _ in range(V):\n",
    "        # Pick the cheapest vertex not yet in the tree (O(V) scan, dense Prim).\n",
    "        u = -1\n",
    "        min_val = np.inf\n",
    "        for v in range(V):\n",
    "            if (not inMST[v]) and (key[v] < min_val):\n",
    "                min_val = key[v]\n",
    "                u = v\n",
    "        if u == -1:\n",
    "            # Every remaining vertex is unreachable.\n",
    "            break\n",
    "        inMST[u] = True\n",
    "        # Relax the edges out of u. inMST[u] is already True, so the diagonal\n",
    "        # entry distance_matrix[u, u] is never used.\n",
    "        for v in range(V):\n",
    "            w = distance_matrix[u, v]\n",
    "            if (w > 0) and (not inMST[v]) and (w < key[v]):\n",
    "                key[v] = w\n",
    "                parent[v] = u\n",
    "    return parent\n",
    "\n",
    "def construct_MST_from_graph(distance_matrix):\n",
    "    \"\"\"Build the minimum spanning tree of `distance_matrix` as a networkx Graph.\n",
    "\n",
    "    Nodes are 0..V-1. Every tree edge (parent, child) stores the matrix entry\n",
    "    distance_matrix[child, parent] as its `weight` attribute. A vertex that\n",
    "    primMST_numba could not attach (parent == -1) is left as an isolated node\n",
    "    rather than being wired to a spurious node -1 with the weight of the last\n",
    "    matrix column.\n",
    "    \"\"\"\n",
    "    V = distance_matrix.shape[0]\n",
    "    parent = primMST_numba(distance_matrix)\n",
    "\n",
    "    MST = nx.Graph()\n",
    "    MST.add_nodes_from(range(V))\n",
    "    # Vertex 0 is the root of Prim's tree, so its parent entry is skipped.\n",
    "    for child in range(1, V):\n",
    "        p = int(parent[child])\n",
    "        if p < 0:\n",
    "            continue\n",
    "        MST.add_edge(p, child, weight=distance_matrix[child, p])\n",
    "    return MST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56ae7678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_processors is the number of joblib worker threads, and also the number of MST copies.\n",
    "\n",
    "def cal_mmj_matrix_by_algo_4_Calculation_and_Copy_parallel_compu_algo_13(distance_matrix, n_processors):\n",
    "    \"\"\"Compute the all-pairs MMJ distance matrix with Algorithm 13.\n",
    "\n",
    "    The MST of `distance_matrix` is built once and its edges are sorted from\n",
    "    heaviest to lightest. `for_parallel_compu` then runs once per edge index in a\n",
    "    joblib thread pool. Each task removes edges from one of `n_processors` copies\n",
    "    of the MST and writes the MMJ distances of the pairs split by that edge.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    distance_matrix : square 2-D numpy array\n",
    "        Pairwise distances between the points.\n",
    "    n_processors : int\n",
    "        Number of worker threads and of MST copies. Must be >= 1 (joblib-style\n",
    "        negative values are not supported because one copy is needed per thread).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of shape (n, n)\n",
    "        The MMJ matrix. The same array is also stored in the module-level\n",
    "        `mmj_matrix`, which the worker tasks write into.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If n_processors < 1.\n",
    "    \"\"\"\n",
    "    global mmj_matrix\n",
    "    if n_processors < 1:\n",
    "        raise ValueError(f\"n_processors must be >= 1, got {n_processors}\")\n",
    "\n",
    "    lenX = len(distance_matrix)\n",
    "    mmj_matrix = np.zeros((lenX, lenX))\n",
    "\n",
    "    MST = construct_MST_from_graph(distance_matrix)\n",
    "    MST_edge_list = list(MST.edges(data='weight'))\n",
    "\n",
    "    # Sort once; weights and endpoints are taken through the same permutation.\n",
    "    edge_weights = np.array([weight for _, _, weight in MST_edge_list])\n",
    "    edge_large_to_small_arg = np.argsort(edge_weights)[::-1]\n",
    "    edge_weight_large_to_small = edge_weights[edge_large_to_small_arg]\n",
    "    edge_nodes_large_to_small = [MST_edge_list[k][:2] for k in edge_large_to_small_arg]\n",
    "\n",
    "    # The original tree plus n_processors - 1 independent copies.\n",
    "    MST_list = [MST] + [MST.copy(as_view=False) for _ in range(n_processors - 1)]\n",
    "\n",
    "    # Per copy: index of the last edge removed from it (-1 = none), and whether a task is using it.\n",
    "    current_removed_edge_index = [-1] * n_processors\n",
    "    mst_locked = [False] * n_processors\n",
    "\n",
    "    # One task per MST edge (lenX - 1 of them when the tree spans every point).\n",
    "    N = len(edge_nodes_large_to_small)\n",
    "\n",
    "    Parallel(n_jobs=n_processors, backend=\"threading\")(\n",
    "        delayed(for_parallel_compu)(i, MST_list, edge_nodes_large_to_small, edge_weight_large_to_small, current_removed_edge_index, mst_locked)\n",
    "        for i in range(N)\n",
    "    )\n",
    "\n",
    "    return mmj_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3cc11fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guards claiming and releasing MST copies. Testing mst_locked[j] and then setting it\n",
    "# are two separate steps, so without a lock two threads can claim the same copy and both\n",
    "# try to remove the same edges (networkx raises when an edge is already gone).\n",
    "_MST_POOL_LOCK = threading.Lock()\n",
    "\n",
    "\n",
    "def _claim_mst_copy(i, current_removed_edge_index, mst_locked):\n",
    "    \"\"\"Atomically claim a free MST copy that has not yet removed edge i.\n",
    "\n",
    "    Returns the index of the claimed copy (now marked as locked), or None if every\n",
    "    copy is in use or has already removed edge i or a later one.\n",
    "    \"\"\"\n",
    "    with _MST_POOL_LOCK:\n",
    "        for j, (locked, last_removed) in enumerate(zip(mst_locked, current_removed_edge_index)):\n",
    "            if not locked and last_removed < i:\n",
    "                mst_locked[j] = True\n",
    "                return j\n",
    "    return None\n",
    "\n",
    "\n",
    "def _release_mst_copy(j, mst_locked):\n",
    "    \"\"\"Mark MST copy j as free again.\"\"\"\n",
    "    with _MST_POOL_LOCK:\n",
    "        mst_locked[j] = False\n",
    "\n",
    "\n",
    "def _trees_of_edge(graph, edge_nodes):\n",
    "    \"\"\"Return the lists of nodes reachable from each endpoint of `edge_nodes` in `graph`.\"\"\"\n",
    "    tree1_nodes = list(nx.dfs_preorder_nodes(graph, source=edge_nodes[0]))\n",
    "    tree2_nodes = list(nx.dfs_preorder_nodes(graph, source=edge_nodes[1]))\n",
    "    return tree1_nodes, tree2_nodes\n",
    "\n",
    "\n",
    "def for_parallel_compu(i, MST_list, edge_nodes_large_to_small, edge_weight_large_to_small, current_removed_edge_index, mst_locked):\n",
    "    \"\"\"Write the MMJ distance of every pair separated by the i-th heaviest MST edge.\n",
    "\n",
    "    After removing MST edges 0..i (heaviest first), the endpoints of edge i lie in\n",
    "    two different trees. Every pair (a, b) with a in one tree and b in the other gets\n",
    "    edge_weight_large_to_small[i] written into the module-level `mmj_matrix`.\n",
    "\n",
    "    Edge removal normally continues on a free copy from `MST_list` that has not yet\n",
    "    removed edge i. The number of copies is len(mst_locked), not a global. If no such\n",
    "    copy is free, the remaining forest is rebuilt from edges i+1.. instead.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    i : int\n",
    "        Index of the edge in the heaviest-first ordering.\n",
    "    MST_list : list of networkx.Graph\n",
    "        Copies of the MST; mutated in place (edges are removed).\n",
    "    edge_nodes_large_to_small : list of (node, node)\n",
    "        MST edges sorted heaviest first.\n",
    "    edge_weight_large_to_small : sequence of float\n",
    "        Weights matching `edge_nodes_large_to_small`.\n",
    "    current_removed_edge_index : list of int\n",
    "        Per copy, index of the last edge removed from it (-1 = none); mutated.\n",
    "    mst_locked : list of bool\n",
    "        Per copy, whether a task is currently using it; mutated.\n",
    "    \"\"\"\n",
    "    global mmj_matrix\n",
    "\n",
    "    edge_nodes = edge_nodes_large_to_small[i]\n",
    "    edge_weight = edge_weight_large_to_small[i]\n",
    "\n",
    "    slot = _claim_mst_copy(i, current_removed_edge_index, mst_locked)\n",
    "    if slot is None:\n",
    "        # No reusable copy: the forest left after removing edges 0..i is exactly\n",
    "        # the edges i+1.. of the heaviest-first list.\n",
    "        forest = nx.Graph()\n",
    "        forest.add_nodes_from(edge_nodes)\n",
    "        forest.add_edges_from(edge_nodes_large_to_small[i + 1:])\n",
    "        tree1_nodes, tree2_nodes = _trees_of_edge(forest, edge_nodes)\n",
    "    else:\n",
    "        MST_temp = MST_list[slot]\n",
    "        try:\n",
    "            # Continue removing edges from where this copy stopped, up to edge i.\n",
    "            for kk in range(current_removed_edge_index[slot] + 1, i + 1):\n",
    "                MST_temp.remove_edge(*edge_nodes_large_to_small[kk])\n",
    "                current_removed_edge_index[slot] = kk\n",
    "            tree1_nodes, tree2_nodes = _trees_of_edge(MST_temp, edge_nodes)\n",
    "        finally:\n",
    "            _release_mst_copy(slot, mst_locked)\n",
    "\n",
    "    # Every remaining edge inside either tree is no heavier than edge i, so edge i is\n",
    "    # the heaviest edge on the MST path of each cross pair.\n",
    "    idx1, idx2 = np.meshgrid(tree1_nodes, tree2_nodes, indexing=\"ij\")\n",
    "    mmj_matrix[idx1, idx2] = mmj_matrix[idx2, idx1] = edge_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5dde8f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_symmetric_matrix(n, low=0, high=1000, seed=222):\n",
    "    \"\"\"Return a random symmetric n x n matrix of whole numbers as float64.\n",
    "\n",
    "    An integer matrix is drawn uniformly from [low, high) and added to its\n",
    "    transpose; the sum is halved with floor division. The diagonal is not zeroed.\n",
    "    Passing seed=None skips re-seeding NumPy's global random state.\n",
    "    \"\"\"\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "    draws = np.random.randint(low, high, size=(n, n))\n",
    "    halved_sum = np.floor_divide(draws + draws.T, 2)\n",
    "    return halved_sum.astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "15dc5963",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of points in data X: 10000\n",
      "Time used: 28.921s\n"
     ]
    }
   ],
   "source": [
    "# n_processors is the number of worker threads (and MST copies) used by Algorithm 13.\n",
    "n_processors = 4\n",
    "\n",
    "# Number of random points in the synthetic distance matrix.\n",
    "NN = 10000\n",
    "distance_matrix = generate_symmetric_matrix(NN)\n",
    "\n",
    "print(f\"Number of points in data X: {NN}\")\n",
    "\n",
    "# Module-level buffer that the Algorithm 13 functions assign and fill.\n",
    "mmj_matrix = None\n",
    "\n",
    "# perf_counter is monotonic, so the measured interval cannot be skewed by\n",
    "# system clock adjustments the way time.time() can.\n",
    "start = time.perf_counter()\n",
    "X_mmj_matrix_algo_4_parallel_compu = cal_mmj_matrix_by_algo_4_Calculation_and_Copy_parallel_compu_algo_13(distance_matrix, n_processors)\n",
    "end = time.perf_counter()\n",
    "time_used = np.round(end - start, 3)\n",
    "\n",
    "print(f\"Time used: {time_used}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "738756ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 7.,  7.,  8.,  9.,  8.,  9.,  7., 13.,  7., 10.,  7., 12.,  8.,\n",
       "        7.,  8.,  9.,  8.,  8.,  7.,  8., 10.,  7.,  9.,  7.,  7.,  7.,\n",
       "        8.,  7.,  7.,  8.])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MMJ distances between point 0 and the last 30 points.\n",
    "X_mmj_matrix_algo_4_parallel_compu[0][-30:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d55db2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ed3df0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
