{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import pickle5\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import itertools\n",
    "from utils.loader import load_data, load_ckpt\n",
    "from collections import defaultdict\n",
    "from parsers.config import get_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = ['community_small', 'ego_small', 'grid', 'enzymes']\n",
    "mol_datasets = ['qm9', 'zinc250k']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraints = ['nedges', 'nedgesl2', 'nconn', 'cheeger', 'specradius']\n",
    "mol_constraints = ['valency', 'molWeight']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraints = ['nedges', 'ntriangles', 'maxDegree']\n",
    "mol_constraints = ['valency', 'molWeight']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.5       ,  1.88888889,  3.27777778,  4.66666667,  6.05555556,\n",
       "        7.44444444,  8.83333333, 10.22222222, 11.61111111, 13.        ])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linspace (min(Es)/2, max(Es), num=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([1, 4, 5], {0, 2, 3, 6, 7})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[x for x in graphs[0].neighbors(3)], set(list(graphs[0].nodes)).difference((x for x in graphs[0].neighbors(3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cheeger_const (G):\n",
    "    node_set = set(list(G.nodes))\n",
    "    minh = np.inf\n",
    "    for A in itertools.combinations(G.nodes, G.number_of_nodes()//2):\n",
    "        V_A = node_set.difference (A)\n",
    "        dA = 0\n",
    "        for n in A:\n",
    "            dA += len([x for x in G.neighbors(n) if x in V_A])\n",
    "        hA = dA/len(A)\n",
    "        minh = hA if (hA < minh) else minh\n",
    "    return minh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cheeger_grid (G):\n",
    "    # cheeger ~ 2/max{m, n}\n",
    "    # m . n = v, m*(n-1) + n*(m-1) = e\n",
    "    v = G.number_of_nodes() \n",
    "    e = G.number_of_edges() \n",
    "    return 4 / (2*v - e + ((2*v - e)**2 - 4*v)**0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "atomic_weights = {\n",
    "    'C': (12.0096 + 12.0116)/2,\n",
    "    'N': (14.00643 + 14.00728)/2,\n",
    "    'O': (15.99903 + 15.99977)/2,\n",
    "    'F': 18.998403163,\n",
    "    'P': 30.973761998,\n",
    "    'S': (32.059 + 32.076)/2,\n",
    "    'Cl': (35.446 + 35.457)/2,\n",
    "    'Br': (79.901 + 79.907)/2,\n",
    "    'I': 126.90447\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_constraint_params (dataset, constraint, nsamples=10):\n",
    "    config = get_config (f'sample_{dataset}', seed=42)\n",
    "    ckpt_dict = load_ckpt(config, 'cuda:0')\n",
    "    configt = ckpt_dict['config']\n",
    "    if dataset in ['qm9', 'zinc250k']:\n",
    "        with open(f'data/{configt.data.data.lower()}_test_nx.pkl', 'rb') as f:\n",
    "            graphs = pickle5.load(f)\n",
    "    else:\n",
    "        _, graphs = load_data(configt, get_graph_list=True)\n",
    "    if constraint == 'nedges':\n",
    "        Es = [graph.number_of_edges() for graph in graphs]\n",
    "        return np.linspace (min(Es), max(Es), num=nsamples)\n",
    "    elif constraint == 'nedgesl2':\n",
    "        Es = [graph.number_of_edges() for graph in graphs]\n",
    "        Vmax = max([graph.number_of_nodes() for graph in graphs])\n",
    "        return np.linspace (min(Es), max(Es), num=nsamples)/((Vmax*(Vmax -1)/2)**0.5)\n",
    "    elif constraint == 'nconn':\n",
    "        Vs = [graph.number_of_nodes() for graph in graphs]\n",
    "        return list(range (1, min(Vs), max(1, int((min(Vs) - 1)/nsamples))))\n",
    "        # return np.linspace (1, min(Vs)/2, num=nsamples)\n",
    "    elif constraint == 'cheeger':\n",
    "        if dataset == 'grid':\n",
    "            print (\"inside grid\")\n",
    "            chis = [cheeger_grid(graph) for graph in graphs]\n",
    "        else:\n",
    "            chis = [cheeger_const(graph) for graph in graphs]\n",
    "        return np.linspace (min(chis)/2, 2*max(chis), num=nsamples)\n",
    "    elif constraint == 'specradius':\n",
    "        d_avs = [np.mean(list(dict(graph.degree).values())) for graph in graphs]\n",
    "        d_maxs = [np.max(list(dict(graph.degree).values())) for graph in graphs]\n",
    "        return np.linspace (min(d_avs), max(d_maxs), num=nsamples)\n",
    "    elif constraint == 'ntriangles':\n",
    "        ntriangles = [sum([nx.triangles(graph, node) for node in graph.nodes()]) for graph in graphs]\n",
    "        return np.linspace (min(ntriangles), max(ntriangles), num=nsamples)\n",
    "    elif constraint == 'maxDegree':\n",
    "        d_maxs = [np.max(list(dict(graph.degree).values())) for graph in graphs]\n",
    "        return np.linspace (min(d_maxs), max(d_maxs), num=nsamples)\n",
    "    elif constraint == 'diameter':\n",
    "        diameters = [nx.diameter(graph) for graph in graphs]\n",
    "        return list (range (min(diameters), max(diameters)+1,  max(1, (max(diameters) - min(diameters)//2) // nsamples)))\n",
    "    elif constraint == 'valency':\n",
    "        if dataset == 'qm9':\n",
    "            # [6, 7, 8, 9]\n",
    "            return [[4, 3, 2, 1], [4, 5, 2, 1]]\n",
    "        elif dataset == 'zinc250k':\n",
    "            # [6, 7, 8, 9, 15, 16, 17, 35, 53]\n",
    "            val_phos = [3, 5]\n",
    "            val_sulf = [2, 4, 6]\n",
    "            return [[4, 3, 2, 1] + [x] + [y] + [1, 1, 1] for x, y in itertools.product(val_phos, val_sulf)]\n",
    "    elif constraint == 'atomCount':\n",
    "        def get_counts (mol):\n",
    "            counts = []\n",
    "            for atom in acount:\n",
    "                n, mol = mol.split(atom, 1)\n",
    "                try:\n",
    "                    counts.append (int(n))\n",
    "                except:\n",
    "                    pass\n",
    "            return counts[:4] if dataset == 'qm9' else counts \n",
    "        mols_counts = defaultdict (lambda: 0)\n",
    "        for graph in graphs:\n",
    "            acount = {'C': 0, 'N': 0, 'O': 0, 'F': 0, 'P': 0, 'S': 0, 'Cl': 0, 'Br': 0, 'I': 0}\n",
    "            nodes = graph.nodes()\n",
    "            for n in nodes:\n",
    "                acount[nodes[n]['label']] += 1\n",
    "            mols_counts[''.join([f'{k}{v}' for k, v in acount.items()])] += 1\n",
    "        l = sorted (mols_counts, key=lambda x: mols_counts[x], reverse=True)[:nsamples]\n",
    "        return [get_counts(x) for x in sorted (mols_counts, key=lambda x: mols_counts[x], reverse=True)[:nsamples]]\n",
    "    elif constraint == 'molWeight':\n",
    "        mol_wts = []\n",
    "        for G in graphs:\n",
    "            mol_wt = 0\n",
    "            for n in G.nodes:\n",
    "                mol_wt += atomic_weights[G.nodes[n]['label']]\n",
    "            mol_wts.append(mol_wt)\n",
    "        return np.linspace(min(mol_wts), max(mol_wts), num=nsamples)\n",
    "    elif constraint.startswith('prop'):\n",
    "        prop_name = constraint.split (\"-\")[1]\n",
    "        mol_file = pd.read_csv(f\"data/{dataset}.csv\")\n",
    "        all_props = mol_file[prop_name]\n",
    "        return np.linspace (min(all_props), max(all_props), num=nsamples)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_name = {\n",
    "    \"nedges\": 'Num-Edges',\n",
    "    \"ntriangles\": \"Num-Triangles\", \n",
    "    \"maxDegree\": \"Max-Degree\",\n",
    "    \"valency\": 'Valency',\n",
    "    \"atomCount\": 'Atom-Count',\n",
    "    \"molWeight\": 'Mol-Weight',\n",
    "}\n",
    "    # \"regression\": 'Property-MLP',\n",
    "    # \"diameter\": \"Diameter\",\n",
    "    # \"nedgesl2\": 'L2-adj',\n",
    "    # \"nconn\": 'Nconn_atleast',\n",
    "    # \"cheeger\": 'Cheeger-bound',\n",
    "    # \"specradius\": 'Spectral-radius',"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/community_small/gdss_community_small.pth loaded\n",
      "./checkpoints/ego_small/gdss_ego_small.pth loaded\n",
      "./checkpoints/grid/gdss_grid.pth loaded\n",
      "./checkpoints/ENZYMES/gdss_enzymes.pth loaded\n",
      "{'community_small': [21.0, 25.67, 30.33, 35.0, 39.67, 44.33, 49.0, 53.67, 58.33, 63.0], 'ego_small': [3.0, 8.78, 14.56, 20.33, 26.11, 31.89, 37.67, 43.44, 49.22, 55.0], 'grid': [218.0, 269.78, 321.56, 373.33, 425.11, 476.89, 528.67, 580.44, 632.22, 684.0], 'enzymes': [16.0, 29.0, 42.0, 55.0, 68.0, 81.0, 94.0, 107.0, 120.0, 133.0]}\n",
      "\n",
      "./checkpoints/community_small/gdss_community_small.pth loaded\n",
      "./checkpoints/ego_small/gdss_ego_small.pth loaded\n",
      "./checkpoints/grid/gdss_grid.pth loaded\n",
      "./checkpoints/ENZYMES/gdss_enzymes.pth loaded\n",
      "{'community_small': [30.0, 47.67, 65.33, 83.0, 100.67, 118.33, 136.0, 153.67, 171.33, 189.0], 'ego_small': [0.0, 22.0, 44.0, 66.0, 88.0, 110.0, 132.0, 154.0, 176.0, 198.0], 'grid': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'enzymes': [0.0, 18.67, 37.33, 56.0, 74.67, 93.33, 112.0, 130.67, 149.33, 168.0]}\n",
      "\n",
      "./checkpoints/community_small/gdss_community_small.pth loaded\n",
      "./checkpoints/ego_small/gdss_ego_small.pth loaded\n",
      "./checkpoints/grid/gdss_grid.pth loaded\n",
      "./checkpoints/ENZYMES/gdss_enzymes.pth loaded\n",
      "{'community_small': [5.0, 5.56, 6.11, 6.67, 7.22, 7.78, 8.33, 8.89, 9.44, 10.0], 'ego_small': [3.0, 4.44, 5.89, 7.33, 8.78, 10.22, 11.67, 13.11, 14.56, 16.0], 'grid': [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0], 'enzymes': [2.0, 2.78, 3.56, 4.33, 5.11, 5.89, 6.67, 7.44, 8.22, 9.0]}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import yaml\n",
    "\n",
    "for constraint in constraints:\n",
    "    master_constraint = {}\n",
    "    # print (constraint)\n",
    "    for dataset in datasets:\n",
    "        params = get_constraint_params (dataset, constraint, nsamples=10)\n",
    "        master_constraint[dataset] = params.tolist() if type(params) is not list else params\n",
    "        master_constraint[dataset] = [round(x, 2) for x in master_constraint[dataset]]\n",
    "        # print (dataset)\n",
    "    print (master_constraint)\n",
    "    yaml.dump(master_constraint, open(f\"config/constraints/master_{constraint}.yaml\", \"w\"))\n",
    "    constraint_config = {\n",
    "        \"constraint\": constraint_name[constraint],\n",
    "        \"method\": {\"op\": \"proj\", \"gamma\": 1.0, \"solve_order\": \"cpj\"},\n",
    "        \"burnin\": 0, \"rounding\": \"none\", \"max_samples\": 10000, \n",
    "        \"add_diff_step\": 0, \"schedule\": {\"gamma\": \"poly\", \"params\": [0, 1]},\n",
    "        \"params\": [\"zeros\", master_constraint[datasets[0]][0]] if constraint.startswith('nedges') else [master_constraint[datasets[0]][0]]\n",
    "    }\n",
    "    yaml.dump(constraint_config, open(f\"config/constraints/{constraint}.yaml\", \"w\"))\n",
    "    print ()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "mol_constraints = ['atomCount']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/QM9/gdss_qm9.pth loaded\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[[7, 1, 1, 0],\n",
       " [7, 0, 2, 0],\n",
       " [6, 1, 2, 0],\n",
       " [8, 0, 1, 0],\n",
       " [6, 2, 1, 0],\n",
       " [6, 0, 3, 0],\n",
       " [5, 2, 2, 0],\n",
       " [8, 1, 0, 0],\n",
       " [7, 2, 0, 0],\n",
       " [5, 3, 1, 0]]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = 'qm9'\n",
    "constraint = 'atomCount'\n",
    "get_constraint_params (dataset, constraint, nsamples=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/ZINC250k/gdss_zinc250k.pth loaded\n",
      "{'qm9': [[4, 3, 2, 1], [4, 5, 2, 1]], 'zinc250k': [[4, 3, 2, 1, 3, 2, 1, 1, 1], [4, 3, 2, 1, 3, 4, 1, 1, 1], [4, 3, 2, 1, 3, 6, 1, 1, 1], [4, 3, 2, 1, 5, 2, 1, 1, 1], [4, 3, 2, 1, 5, 4, 1, 1, 1], [4, 3, 2, 1, 5, 6, 1, 1, 1]]}\n",
      "\n",
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/ZINC250k/gdss_zinc250k.pth loaded\n",
      "{'qm9': [24.02, 36.58, 49.14, 61.69, 74.25, 86.81, 99.37, 111.92, 124.48, 137.04], 'zinc250k': [134.11, 172.97, 211.83, 250.69, 289.55, 328.41, 367.26, 406.12, 444.98, 483.84]}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import yaml\n",
    "\n",
    "for constraint in mol_constraints:\n",
    "    master_constraint = {}\n",
    "    for dataset in mol_datasets:\n",
    "        params = get_constraint_params (dataset, constraint, nsamples=10)\n",
    "        master_constraint[dataset] = params.tolist() if type(params) is not list else params\n",
    "        try:\n",
    "            master_constraint[dataset] = [round(x, 2) for x in master_constraint[dataset]]\n",
    "        except:\n",
    "            pass\n",
    "    print (master_constraint)\n",
    "    yaml.dump(master_constraint, open(f\"config/constraints/master_{constraint}.yaml\", \"w\"))\n",
    "    constraint_config = {\n",
    "        \"constraint\": constraint_name[constraint],\n",
    "        \"method\": {\"op\": \"proj\", \"gamma\": 1.0, \"solve_order\": \"cpj\"},\n",
    "        \"burnin\": 0, \"rounding\": \"none\", \"max_samples\": 10000, \n",
    "        \"add_diff_step\": 0, \"schedule\": {\"gamma\": \"poly\", \"params\": [0, 1]},\n",
    "        \"params\": [list(atomic_weights.values()), master_constraint[dataset][0]] if constraint.startswith('molWeight') else [master_constraint[dataset][0]]\n",
    "    }\n",
    "    yaml.dump(constraint_config, open(f\"config/constraints/{constraint}_{dataset}.yaml\", \"w\"))\n",
    "    print ()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/ZINC250k/gdss_zinc250k.pth loaded\n",
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/QM9/gdss_qm9.pth loaded\n",
      "./checkpoints/ZINC250k/gdss_zinc250k.pth loaded\n",
      "./checkpoints/ZINC250k/gdss_zinc250k.pth loaded\n"
     ]
    }
   ],
   "source": [
    "import yaml\n",
    "import pandas as pd\n",
    "\n",
    "saved_models = os.listdir (\"config/constraints/regmodels\")\n",
    "\n",
    "for saved_model in saved_models:\n",
    "    _, dataset, prop_name = saved_model[:-3].split(\"_\")\n",
    "    constraint = f'prop-{prop_name}'\n",
    "    params = get_constraint_params (dataset, constraint, nsamples=10)\n",
    "    master_constraint = {}\n",
    "    master_constraint[dataset] = params.tolist() if type(params) is not list else params\n",
    "    yaml.dump(master_constraint, open(f\"config/constraints/master_{constraint}.yaml\", \"w\"))\n",
    "    constraint_config = {\n",
    "        \"constraint\": \"Property-SGC\",\n",
    "        \"method\": {\"op\": \"proj\", \"gamma\": 1.0, \"solve_order\": \"cpj\"},\n",
    "        \"burnin\": 0, \"rounding\": \"none\", \"max_samples\": 10000, \n",
    "        \"add_diff_step\": 0, \"schedule\": {\"gamma\": \"poly\", \"params\": [0, 1]},\n",
    "        \"params\": [saved_model, master_constraint[dataset][0]]\n",
    "    }\n",
    "    yaml.dump(constraint_config, open(f\"config/constraints/{constraint}_{dataset}.yaml\", \"w\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "\n",
    "prop_name = \"homo\"\n",
    "dataset = \"qm9\"\n",
    "\n",
    "mol_file = pd.read_csv(f\"data/{dataset}.csv\")\n",
    "from find_metrics import find_metric\n",
    "self.ys = torch.tensor(df[prop_name]) if prop_name in df else torch.tensor (find_metric(self.smiles, prop_name))\n",
    "\n",
    "all_props = mol_file[prop_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.2605"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_props[:100].median()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "44dc670dcdd6ffb1ba23034ae072504999a2c20bd6cc686fd82920ca8c3f3b47"
  },
  "kernelspec": {
   "display_name": "Python 3.7.15 64-bit ('moltemp': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
