{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3a363e24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from pymatgen.io.xyz import XYZ\n",
    "import pandas as pd\n",
    "from pymatgen.core.bonds import CovalentBond\n",
    "import numpy as np\n",
    "import math\n",
    "import heapq\n",
    "import time\n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1e78c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "## First Compute Graph structure \n",
    "def constructGraphMatrix(mole):\n",
    "    returnable = np.zeros([len(mole), len(mole)])\n",
    "    for i in range(len(mole)):\n",
    "        for j in range(len(mole)):\n",
    "            lol = CovalentBond.is_bonded(mole[i],mole[j])\n",
    "            if lol:\n",
    "                returnable[i][j] = 1\n",
    "            else:\n",
    "                returnable[i][j] = 0\n",
    "    return returnable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "181add87",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Load mols:\n",
    "graphs = []\n",
    "mols = []\n",
    "#path of dataset here:\n",
    "ss = \"/home/yury/Downloads/xyzanotherbranch/\"\n",
    "dirr = os.listdir(ss)\n",
    "xD = sorted(dirr)\n",
    "uuu = xD[2:]\n",
    "for i in range(2,len(xD)):\n",
    "    print(i)\n",
    "    u = os.path.join(ss, xD[i])\n",
    "    print(u)\n",
    "    xyzT = XYZ.from_file(u)\n",
    "    mole = xyzT.molecule\n",
    "    mols.append(mole)\n",
    "    distMat = constructGraphMatrix(mole) \n",
    "    graphs.append(distMat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "746bbf23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def FindCentroid(molz):\n",
    "    tmp_coord = molz.cart_coords\n",
    "    summ = np.array([0,0,0])\n",
    "    for i in range(tmp_coord.shape[0]):\n",
    "        summ = summ + tmp_coord[i]\n",
    "    \n",
    "    return summ / tmp_coord.shape[0]\n",
    "\n",
    "def SuperMetricCalculator(mole):\n",
    "    ctr = FindCentroid(mole)\n",
    "    G = nx.Graph()\n",
    "    edges = []\n",
    "    for i in range(len(mole)):\n",
    "        distToCentroid = np.linalg.norm(ctr - mole.cart_coords[i])\n",
    "        edges.append((i, len(mole), distToCentroid))\n",
    "        for j in range(i):\n",
    "            lol = CovalentBond.is_bonded(mole[i],mole[j])\n",
    "            if lol:\n",
    "                locDist = mole.get_distance(i,j)\n",
    "                edges.append((i,j,locDist))\n",
    "    \n",
    "    G.add_weighted_edges_from(edges)\n",
    "    #print(G)\n",
    "    length = dict(nx.all_pairs_dijkstra_path_length(G))\n",
    "    abc = np.zeros([len(mole)+1,len(mole)+1])\n",
    "    #print(abc)\n",
    "    for i in range(len(mole)+1):\n",
    "        for j in range(len(mole)+1):\n",
    "            abc[i][j] = length[i][j]\n",
    "    return abc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "324fa1e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def toSortedVector(mole, distMat):\n",
    "    vec = []\n",
    "    for i in range(len(mole)):\n",
    "        for j in range(i):\n",
    "            locDist = distMat[i][j]\n",
    "            vec.append(locDist)\n",
    "    uu = sorted(vec)\n",
    "    return uu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2a9753c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ComputePDD(mole, distMat):\n",
    "    vec = []\n",
    "    for i in range(len(mole)):\n",
    "        tmpvec = []\n",
    "        for j in range(len(mole)):    \n",
    "            locDist = distMat[i][j]\n",
    "            tmpvec.append(locDist)\n",
    "        uu = sorted(tmpvec)\n",
    "        vec.append(uu)\n",
    "        \n",
    "    #print(vec)\n",
    "    return vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a407ce3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lInftyMetric(vecA, vecB):\n",
    "    maxx = 0\n",
    "    for i in range(vecA.shape[0]):\n",
    "        maxx = max(maxx, abs(vecA[i]-vecB[i]))\n",
    "    return maxx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5723b3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from amd import _emd\n",
    "from amd import network_simplex\n",
    "import os\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import wasserstein\n",
    "\n",
    "def EmdMetricPyPy(vecA, vecB,emd):\n",
    "    distMat = np.zeros([vecA.shape[0], vecB.shape[0]])\n",
    "   # print(\"Matrix initilized\")\n",
    "    for i in range(vecA.shape[0]):\n",
    "        for j in range(vecB.shape[0]):\n",
    "          #  print(\"I: \" + str(i) + \" | \" + \"J: \" + str(j))\n",
    "            dist = lInftyMetric(vecA[i], vecB[j])\n",
    "            distMat[i][j] = dist\n",
    "   # print(\"MAT FINNISHED PREPARING\")\n",
    "    simpleMat = np.ones([vecA.shape[0]])\n",
    "    simpleMatTwo = np.ones([vecB.shape[0]])\n",
    "    p = emd(simpleMat,simpleMatTwo,distMat)\n",
    "    return p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "289a593e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Crystal37Env",
   "language": "python",
   "name": "crystal37env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
