{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from projop.project_bisection import project_sublevel, project\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DotDict(dict):\n",
    "    \"\"\"\n",
    "    a dictionary that supports dot notation \n",
    "    as well as dictionary access notation \n",
    "    usage: d = DotDict() or d = DotDict({'val1':'first'})\n",
    "    set attributes: d.val2 = 'second' or d['val2'] = 'second'\n",
    "    get attributes: d.val2 or d['val2']\n",
    "    \"\"\"\n",
    "    __getattr__ = dict.__getitem__\n",
    "    __setattr__ = dict.__setitem__\n",
    "    __delattr__ = dict.__delitem__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "N, n, F = 20, 20, 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjs = torch.zeros(20, 5, 5)\n",
    "rows, cols = torch.triu_indices (adjs.shape[1], adjs.shape[2], offset=1)\n",
    "adjs[:, rows, cols] = torch.randn (adjs.shape[0], rows.shape[0])\n",
    "adjs = adjs + adjs.transpose (1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = torch.randn (20, 5, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from reg_models import SGCReg, LinearGC\n",
    "model1 = SGCReg (4, 2).float()\n",
    "model2 = LinearGC (4, 2).float()\n",
    "\n",
    "model2.theta = model1.conv.lin.weight\n",
    "model2.bias = model1.conv.lin.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils.sparse import dense_to_sparse\n",
    "\n",
    "edge_ind, edge_attr = dense_to_sparse (adjs[0])\n",
    "x = xs[0]\n",
    "batch = torch.zeros_like (x[:, 0]).long()\n",
    "z = model1(x, edge_ind, edge_attr, batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 5, 1])\n"
     ]
    }
   ],
   "source": [
    "z2 = model2(xs[[0]], adjs[[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 4.0531e-06,  7.1526e-07, -5.2452e-06, -1.9073e-06, -3.2187e-06,\n",
      "        -8.3447e-06, -2.8610e-06,  4.5300e-06, -6.7949e-06,  3.0994e-06,\n",
      "         7.1526e-07,  8.2254e-06, -2.8610e-06, -3.2783e-06, -1.0729e-06,\n",
      "        -6.3181e-06, -2.5034e-06,  7.8678e-06, -1.4305e-06,  9.5367e-06])\n"
     ]
    }
   ],
   "source": [
    "B = 1\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Num-Edges', 'params': ['zeros', B]}))\n",
    "print (proj_adjs.sum(dim=-1).sum(dim=-1)/2 - B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0.4722]) tensor([0.4722])\n",
      "1 tensor([0.4722]) tensor([-0.1283])\n",
      "tensor([-9.9868e-01, -5.0694e-01, -1.0000e+00, -9.2422e-01, -3.8147e-06,\n",
      "        -1.0000e+00, -3.2072e-01, -6.0267e-01, -5.8653e-01, -1.0000e+00,\n",
      "        -9.8967e-01, -1.0000e+00, -7.6515e-01, -1.0000e+00, -3.6230e-01,\n",
      "        -7.3909e-01, -1.0000e+00, -9.7939e-01, -1.0000e+00, -1.0000e+00])\n"
     ]
    }
   ],
   "source": [
    "nT = 1\n",
    "_, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Num-Triangles', 'params': [nT]}))\n",
    "print (1/6 * torch.diagonal(torch.matrix_power(proj_adjs, 3), dim1=1, dim2=2).sum(dim=1) - nT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([1.7389, 1.4342, 2.0206, 2.0206, 1.5664, 1.5508, 1.5664, 2.9302, 1.8055,\n",
      "        2.6887, 1.3006, 1.3006, 2.1728, 1.6217, 2.1728, 0.9458, 0.8881, 1.5470,\n",
      "        1.7200, 1.2898]) tensor([0.8006, 0.6408, 0.3350, 0.4947, 0.7660, 0.7492, 0.4492, 1.0000, 1.0972,\n",
      "        1.0000, 0.9281, 1.2484, 0.4974, 0.4110, 0.4647, 1.0263, 1.1594, 1.2265,\n",
      "        0.4530, 1.3551])\n",
      "1 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([0.8695, 0.7171, 1.0103, 1.0103, 0.7832, 0.7754, 0.7832, 1.4651, 0.9028,\n",
      "        1.3443, 0.6503, 0.6503, 1.0864, 0.8109, 1.0864, 0.4729, 0.4441, 0.7735,\n",
      "        0.8600, 0.6449]) tensor([0.2353, 0.0470, 0.3744, 0.0482, 0.1576, 0.2508, 0.0000, 0.9835, 0.2300,\n",
      "        0.4809, 0.0473, 0.2749, 0.3100, 0.3462, 0.4309, 0.3170, 0.4933, 0.6564,\n",
      "        0.0000, 0.7831])\n",
      "2 tensor([0.0000, 0.3585, 0.5051, 0.5051, 0.0000, 0.3877, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.3252, 0.0000, 0.5432, 0.4054, 0.5432, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000]) tensor([0.4347, 0.7171, 1.0103, 1.0103, 0.3916, 0.7754, 0.7832, 0.7325, 0.4514,\n",
      "        0.6722, 0.6503, 0.3252, 1.0864, 0.8109, 1.0864, 0.2364, 0.2220, 0.3867,\n",
      "        0.8600, 0.3225]) tensor([0.0881, 0.2822, 0.0394, 0.2422, 0.0084, 0.1676, 0.0000, 0.4722, 0.2214,\n",
      "        0.1448, 0.4404, 0.0748, 0.2258, 0.1243, 0.0338, 0.0377, 0.1401, 0.2696,\n",
      "        0.0000, 0.3043])\n",
      "3 tensor([0.2174, 0.3585, 0.7577, 0.5051, 0.1958, 0.3877, 0.0000, 0.0000, 0.2257,\n",
      "        0.0000, 0.3252, 0.1626, 0.5432, 0.6081, 0.5432, 0.1182, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000]) tensor([0.4347, 0.5378, 1.0103, 0.7577, 0.3916, 0.5816, 0.7832, 0.3663, 0.4514,\n",
      "        0.3361, 0.4877, 0.3252, 0.8148, 0.8109, 0.8148, 0.2364, 0.1110, 0.1934,\n",
      "        0.8600, 0.1612]) tensor([0.1266, 0.1029, 0.2087, 0.1159, 0.0597, 0.0570, 0.0000, 0.1060, 0.0043,\n",
      "        0.0000, 0.1965, 0.0878, 0.0209, 0.1662, 0.2378, 0.1396, 0.0820, 0.0762,\n",
      "        0.0000, 0.0625])\n",
      "4 tensor([0.2174, 0.3585, 0.7577, 0.5051, 0.1958, 0.4846, 0.0000, 0.0000, 0.2257,\n",
      "        0.0000, 0.3252, 0.1626, 0.5432, 0.6081, 0.6790, 0.1182, 0.0555, 0.0000,\n",
      "        0.0000, 0.0000]) tensor([0.3261, 0.4482, 0.8840, 0.6314, 0.2937, 0.5816, 0.7832, 0.1831, 0.3385,\n",
      "        0.3361, 0.4064, 0.2439, 0.6790, 0.7095, 0.8148, 0.1773, 0.1110, 0.0967,\n",
      "        0.8600, 0.0806]) tensor([0.0206, 0.0420, 0.0868, 0.0527, 0.0107, 0.0399, 0.0000, 0.0772, 0.1086,\n",
      "        0.0000, 0.0746, 0.0065, 0.1521, 0.0141, 0.1020, 0.0510, 0.0290, 0.0205,\n",
      "        0.0000, 0.0585])\n",
      "5 tensor([0.2174, 0.3585, 0.7577, 0.5051, 0.1958, 0.4846, 0.0000, 0.0916, 0.2821,\n",
      "        0.0000, 0.3252, 0.1626, 0.6111, 0.6081, 0.7469, 0.1182, 0.0555, 0.0483,\n",
      "        0.0000, 0.0403]) tensor([0.2717, 0.4034, 0.8209, 0.5683, 0.2447, 0.5331, 0.7832, 0.1831, 0.3385,\n",
      "        0.3361, 0.3658, 0.2032, 0.6790, 0.6588, 0.8148, 0.1478, 0.0833, 0.0967,\n",
      "        0.8600, 0.0806]) tensor([0.0337, 0.0028, 0.0237, 0.0045, 0.0000, 0.0085, 0.0000, 0.0144, 0.0522,\n",
      "        0.0000, 0.0136, 0.0342, 0.0503, 0.0619, 0.0341, 0.0066, 0.0265, 0.0279,\n",
      "        0.0000, 0.0020])\n",
      "6 tensor([0.2445, 0.3810, 0.7577, 0.5051, 0.1958, 0.5089, 0.0000, 0.0916, 0.3103,\n",
      "        0.0000, 0.3252, 0.1829, 0.6451, 0.6335, 0.7809, 0.1182, 0.0694, 0.0483,\n",
      "        0.0000, 0.0403]) tensor([0.2717, 0.4034, 0.7893, 0.5367, 0.2447, 0.5331, 0.7832, 0.1374, 0.3385,\n",
      "        0.3361, 0.3455, 0.2032, 0.6790, 0.6588, 0.8148, 0.1330, 0.0833, 0.0725,\n",
      "        0.8600, 0.0605]) tensor([0.0066, 0.0196, 0.0079, 0.0271, 0.0000, 0.0157, 0.0000, 0.0314, 0.0239,\n",
      "        0.0000, 0.0168, 0.0138, 0.0130, 0.0239, 0.0001, 0.0155, 0.0013, 0.0037,\n",
      "        0.0000, 0.0282])\n",
      "7 tensor([0.2581, 0.3810, 0.7735, 0.5209, 0.1958, 0.5089, 0.0000, 0.1145, 0.3244,\n",
      "        0.0000, 0.3353, 0.1931, 0.6620, 0.6462, 0.7978, 0.1256, 0.0694, 0.0483,\n",
      "        0.0000, 0.0504]) tensor([0.2717, 0.3922, 0.7893, 0.5367, 0.2447, 0.5210, 0.7832, 0.1374, 0.3385,\n",
      "        0.3361, 0.3455, 0.2032, 0.6790, 0.6588, 0.8148, 0.1330, 0.0763, 0.0604,\n",
      "        0.8600, 0.0605]) tensor([0.0070, 0.0084, 0.0079, 0.0113, 0.0000, 0.0036, 0.0000, 0.0085, 0.0098,\n",
      "        0.0000, 0.0016, 0.0037, 0.0039, 0.0049, 0.0168, 0.0045, 0.0126, 0.0084,\n",
      "        0.0000, 0.0131])\n",
      "8 tensor([0.2581, 0.3810, 0.7735, 0.5288, 0.1958, 0.5089, 0.0000, 0.1259, 0.3315,\n",
      "        0.0000, 0.3404, 0.1981, 0.6620, 0.6525, 0.7978, 0.1293, 0.0729, 0.0544,\n",
      "        0.0000, 0.0554]) tensor([0.2649, 0.3866, 0.7814, 0.5367, 0.2447, 0.5149, 0.7832, 0.1374, 0.3385,\n",
      "        0.3361, 0.3455, 0.2032, 0.6705, 0.6588, 0.8063, 0.1330, 0.0763, 0.0604,\n",
      "        0.8600, 0.0605]) tensor([2.3890e-04, 2.7581e-03, 2.9325e-05, 3.3946e-03, 0.0000e+00, 2.4698e-03,\n",
      "        0.0000e+00, 2.9573e-03, 2.7840e-03, 0.0000e+00, 6.0151e-03, 1.4050e-03,\n",
      "        4.5450e-03, 4.6198e-03, 8.3462e-03, 1.0915e-03, 5.6572e-03, 2.3239e-03,\n",
      "        0.0000e+00, 5.5652e-03])\n",
      "9 tensor([0.2581, 0.3810, 0.7735, 0.5328, 0.1958, 0.5119, 0.0000, 0.1259, 0.3350,\n",
      "        0.0000, 0.3404, 0.1981, 0.6663, 0.6525, 0.7978, 0.1293, 0.0746, 0.0574,\n",
      "        0.0000, 0.0579]) tensor([0.2615, 0.3838, 0.7774, 0.5367, 0.2447, 0.5149, 0.7832, 0.1316, 0.3385,\n",
      "        0.3361, 0.3429, 0.2007, 0.6705, 0.6557, 0.8021, 0.1312, 0.0763, 0.0604,\n",
      "        0.8600, 0.0605]) tensor([3.1574e-03, 4.2915e-05, 3.9170e-03, 5.5194e-04, 0.0000e+00, 5.5909e-04,\n",
      "        0.0000e+00, 2.7657e-03, 7.4232e-04, 0.0000e+00, 2.2049e-03, 1.1351e-03,\n",
      "        3.0112e-04, 1.3113e-04, 4.1023e-03, 1.6794e-03, 2.1877e-03, 6.9761e-04,\n",
      "        0.0000e+00, 1.7865e-03])\n",
      "10 tensor([0.2598, 0.3824, 0.7755, 0.5328, 0.1958, 0.5119, 0.0000, 0.1288, 0.3350,\n",
      "        0.0000, 0.3404, 0.1994, 0.6684, 0.6541, 0.7978, 0.1302, 0.0755, 0.0574,\n",
      "        0.0000, 0.0592]) tensor([0.2615, 0.3838, 0.7774, 0.5347, 0.2447, 0.5134, 0.7832, 0.1316, 0.3368,\n",
      "        0.3361, 0.3417, 0.2007, 0.6705, 0.6557, 0.8000, 0.1312, 0.0763, 0.0589,\n",
      "        0.8600, 0.0605]) tensor([1.4591e-03, 1.3576e-03, 1.9438e-03, 1.4215e-03, 0.0000e+00, 9.5534e-04,\n",
      "        0.0000e+00, 9.5725e-05, 1.0209e-03, 0.0000e+00, 2.9945e-04, 1.3494e-04,\n",
      "        1.8208e-03, 2.2442e-03, 1.9804e-03, 2.9397e-04, 4.5347e-04, 8.1301e-04,\n",
      "        0.0000e+00, 1.0300e-04])\n",
      "11 tensor([0.2607, 0.3824, 0.7765, 0.5338, 0.1958, 0.5127, 0.0000, 0.1288, 0.3359,\n",
      "        0.0000, 0.3404, 0.1994, 0.6684, 0.6541, 0.7978, 0.1307, 0.0759, 0.0582,\n",
      "        0.0000, 0.0592]) tensor([0.2615, 0.3831, 0.7774, 0.5347, 0.2447, 0.5134, 0.7832, 0.1302, 0.3368,\n",
      "        0.3361, 0.3410, 0.2000, 0.6695, 0.6549, 0.7989, 0.1312, 0.0763, 0.0589,\n",
      "        0.8600, 0.0598]) tensor([6.1011e-04, 6.5732e-04, 9.5725e-04, 4.3488e-04, 0.0000e+00, 1.9813e-04,\n",
      "        0.0000e+00, 1.3351e-03, 1.3924e-04, 0.0000e+00, 6.5303e-04, 5.0020e-04,\n",
      "        7.5984e-04, 1.0564e-03, 9.1946e-04, 3.9864e-04, 4.1389e-04, 5.7697e-05,\n",
      "        0.0000e+00, 8.4186e-04])\n",
      "12 tensor([0.2611, 0.3824, 0.7769, 0.5342, 0.1958, 0.5130, 0.0000, 0.1295, 0.3363,\n",
      "        0.0000, 0.3407, 0.1997, 0.6684, 0.6541, 0.7978, 0.1307, 0.0759, 0.0585,\n",
      "        0.0000, 0.0595]) tensor([0.2615, 0.3827, 0.7774, 0.5347, 0.2447, 0.5134, 0.7832, 0.1302, 0.3368,\n",
      "        0.3361, 0.3410, 0.2000, 0.6689, 0.6545, 0.7984, 0.1309, 0.0761, 0.0589,\n",
      "        0.8600, 0.0598]) tensor([1.8549e-04, 3.0720e-04, 4.6396e-04, 5.8532e-05, 0.0000e+00, 1.8048e-04,\n",
      "        0.0000e+00, 6.1965e-04, 3.0160e-04, 0.0000e+00, 1.7667e-04, 1.8263e-04,\n",
      "        2.2936e-04, 4.6253e-04, 3.8886e-04, 5.2452e-05, 1.9550e-05, 3.1984e-04,\n",
      "        0.0000e+00, 3.6955e-04])\n",
      "13 tensor([0.2613, 0.3824, 0.7772, 0.5342, 0.1958, 0.5130, 0.0000, 0.1298, 0.3363,\n",
      "        0.0000, 0.3409, 0.1999, 0.6684, 0.6541, 0.7978, 0.1307, 0.0760, 0.0585,\n",
      "        0.0000, 0.0597]) tensor([0.2615, 0.3825, 0.7774, 0.5345, 0.2447, 0.5132, 0.7832, 0.1302, 0.3366,\n",
      "        0.3361, 0.3410, 0.2000, 0.6687, 0.6543, 0.7981, 0.1308, 0.0761, 0.0587,\n",
      "        0.8600, 0.0598]) tensor([2.6703e-05, 1.3208e-04, 2.1720e-04, 1.8811e-04, 0.0000e+00, 8.8215e-06,\n",
      "        0.0000e+00, 2.6202e-04, 8.1182e-05, 0.0000e+00, 6.1274e-05, 2.3842e-05,\n",
      "        3.5763e-05, 1.6558e-04, 1.2374e-04, 1.2064e-04, 1.9741e-04, 1.3101e-04,\n",
      "        0.0000e+00, 1.3328e-04])\n",
      "14 tensor([0.2613, 0.3824, 0.7773, 0.5344, 0.1958, 0.5130, 0.0000, 0.1300, 0.3363,\n",
      "        0.0000, 0.3409, 0.2000, 0.6685, 0.6541, 0.7978, 0.1307, 0.0760, 0.0585,\n",
      "        0.0000, 0.0598]) tensor([0.2614, 0.3824, 0.7774, 0.5345, 0.2447, 0.5132, 0.7832, 0.1302, 0.3364,\n",
      "        0.3361, 0.3409, 0.2000, 0.6687, 0.6542, 0.7980, 0.1308, 0.0761, 0.0586,\n",
      "        0.8600, 0.0598]) tensor([7.9393e-05, 4.4584e-05, 9.3937e-05, 6.4850e-05, 0.0000e+00, 8.8215e-06,\n",
      "        0.0000e+00, 8.3208e-05, 2.9087e-05, 0.0000e+00, 5.7697e-05, 5.5552e-05,\n",
      "        9.6798e-05, 1.7166e-05, 9.0599e-06, 3.4094e-05, 8.8692e-05, 3.6597e-05,\n",
      "        0.0000e+00, 1.5259e-05])\n",
      "15 tensor([0.2614, 0.3824, 0.7774, 0.5344, 0.1958, 0.5130, 0.0000, 0.1301, 0.3364,\n",
      "        0.0000, 0.3409, 0.2000, 0.6685, 0.6541, 0.7978, 0.1308, 0.0760, 0.0585,\n",
      "        0.0000, 0.0598]) tensor([0.2614, 0.3824, 0.7774, 0.5345, 0.2447, 0.5132, 0.7832, 0.1302, 0.3364,\n",
      "        0.3361, 0.3409, 0.2000, 0.6686, 0.6541, 0.7980, 0.1308, 0.0760, 0.0586,\n",
      "        0.8600, 0.0598]) tensor([2.6226e-05, 8.3447e-07, 3.2187e-05, 3.0994e-06, 0.0000e+00, 8.8215e-06,\n",
      "        0.0000e+00, 6.3181e-06, 2.6107e-05, 0.0000e+00, 1.9073e-06, 1.5736e-05,\n",
      "        3.0518e-05, 5.7220e-05, 9.0599e-06, 9.1791e-06, 3.4809e-05, 1.0490e-05,\n",
      "        0.0000e+00, 4.3869e-05])\n",
      "16 tensor([0.2614, 0.3824, 0.7774, 0.5344, 0.1958, 0.5130, 0.0000, 0.1301, 0.3364,\n",
      "        0.0000, 0.3409, 0.2000, 0.6685, 0.6541, 0.7978, 0.1308, 0.0760, 0.0586,\n",
      "        0.0000, 0.0598]) tensor([0.2614, 0.3824, 0.7774, 0.5345, 0.2447, 0.5132, 0.7832, 0.1302, 0.3364,\n",
      "        0.3361, 0.3409, 0.2000, 0.6686, 0.6541, 0.7980, 0.1308, 0.0760, 0.0586,\n",
      "        0.8600, 0.0598]) tensor([2.3842e-07, 8.3447e-07, 1.4305e-06, 3.0994e-06, 0.0000e+00, 8.8215e-06,\n",
      "        0.0000e+00, 6.3181e-06, 1.4305e-06, 0.0000e+00, 1.9073e-06, 4.0531e-06,\n",
      "        2.6226e-06, 2.0027e-05, 9.0599e-06, 9.1791e-06, 7.6294e-06, 1.2994e-05,\n",
      "        0.0000e+00, 1.4305e-05])\n",
      "17 tensor([0.2614, 0.3824, 0.7774, 0.5344, 0.1958, 0.5130, 0.0000, 0.1301, 0.3364,\n",
      "        0.0000, 0.3409, 0.2000, 0.6685, 0.6541, 0.7978, 0.1308, 0.0760, 0.0586,\n",
      "        0.0000, 0.0598]) tensor([0.2614, 0.3824, 0.7774, 0.5345, 0.2447, 0.5132, 0.7832, 0.1302, 0.3364,\n",
      "        0.3361, 0.3409, 0.2000, 0.6686, 0.6541, 0.7980, 0.1308, 0.0760, 0.0586,\n",
      "        0.8600, 0.0598]) tensor([2.3842e-07, 8.3447e-07, 1.4305e-06, 3.0994e-06, 0.0000e+00, 8.8215e-06,\n",
      "        0.0000e+00, 6.3181e-06, 1.4305e-06, 0.0000e+00, 1.9073e-06, 4.0531e-06,\n",
      "        2.6226e-06, 1.4305e-06, 9.0599e-06, 9.1791e-06, 7.6294e-06, 1.1921e-06,\n",
      "        0.0000e+00, 4.7684e-07])\n",
      "tensor([[-7.1398e-01, -2.3842e-07, -7.8693e-01, -6.4920e-01, -8.4606e-01],\n",
      "        [-1.5452e+00, -1.0231e+00, -5.8585e-01, -1.9825e+00, -8.3447e-07],\n",
      "        [-2.0000e+00, -9.1675e-01, -2.9133e-01, -1.0832e+00, -2.9133e-01],\n",
      "        [-2.0000e+00, -5.1800e-01, -1.0000e+00, -1.5180e+00,  0.0000e+00],\n",
      "        [-5.4486e-01, -2.0000e+00, -5.7415e-01, -3.9939e-01, -1.7209e+00],\n",
      "        [-6.3181e-06,  0.0000e+00, -9.8713e-01, -7.4937e-01, -1.0000e+00],\n",
      "        [-1.0000e+00, -1.0000e+00, -4.8019e-01, -4.8019e-01,  0.0000e+00],\n",
      "        [-1.7447e+00, -1.0250e+00, -1.0000e+00, -1.0818e+00, -6.8784e-01],\n",
      "        [-1.4445e+00, -1.0000e+00, -1.6079e+00, -2.8039e-01, -1.7720e+00],\n",
      "        [ 1.4305e-06,  0.0000e+00, -1.5096e+00,  0.0000e+00, -4.9037e-01],\n",
      "        [-2.0000e+00, -1.2552e+00, -6.2368e-01, -1.3685e+00, -2.0000e+00],\n",
      "        [-1.2403e+00, -1.2403e+00, -1.0272e+00, -1.5639e+00, -1.1239e+00],\n",
      "        [-1.0000e+00, -1.3583e+00, -1.0053e+00, -4.9905e-01, -2.0000e+00],\n",
      "        [-5.8401e-01, -1.5171e+00, -7.6015e-02, -1.3080e-01, -1.4684e+00],\n",
      "        [-1.1921e-06, -1.2202e+00, -6.1748e-01, -1.8377e+00,  0.0000e+00],\n",
      "        [-1.6561e+00, -1.9183e+00, -2.8837e-03, -1.0000e+00, -1.4285e+00],\n",
      "        [-1.8256e+00, -1.3567e-01, -2.6718e-01, -6.1277e-01, -1.3444e+00],\n",
      "        [-1.1758e+00, -1.7026e+00, -4.4272e-03, -1.3768e+00, -3.2235e-01],\n",
      "        [-1.1740e+00, -1.7377e+00,  4.7684e-07, -6.8613e-01, -1.7745e+00],\n",
      "        [-1.7793e+00, -1.4228e+00, -6.6061e-01, -5.9785e-02, -6.0128e-01]])\n"
     ]
    }
   ],
   "source": [
    "D = 2\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Max-Degree', 'params': [D]}))\n",
    "print (proj_adjs.sum(dim=2) - D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0.]) tensor([1.1994, 1.3592, 1.5156, 2.0206, 1.0944, 1.2151, 1.5664, 1.5508, 2.9302,\n",
      "        1.1302, 1.1733, 0.2280, 1.8055, 2.6887, 2.6887, 1.3006, 2.1728, 2.1728,\n",
      "        1.2467, 0.9458, 0.5584, 0.8881, 1.7200, 1.7200, 0.8643, 1.4539, 0.7884,\n",
      "        1.0974, 1.2898, 0.2853, 3.0469]) tensor([0.8701, 0.5467, 0.8154, 1.7805, 0.3713, 1.0564, 0.4041, 0.7013, 0.5584,\n",
      "        1.0033, 0.1048, 0.0858, 0.9028, 1.6294, 1.5845, 1.3026, 1.5890, 2.9699,\n",
      "        0.8213, 0.6409, 0.3869, 0.0646, 1.2529, 0.1732, 0.5187, 0.0776, 0.2430,\n",
      "        0.7140, 0.0288, 0.1022, 2.4071])\n",
      "1 tensor([0.0000, 0.6796, 0.7578, 0.0000, 0.0000, 0.0000, 0.7832, 0.7754, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.9028, 1.3443, 0.0000, 0.0000, 1.0864, 0.0000,\n",
      "        0.0000, 0.4729, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.5487, 0.0000, 0.0000, 0.0000]) tensor([0.5997, 1.3592, 1.5156, 1.0103, 0.5472, 0.6076, 1.5664, 1.5508, 1.4651,\n",
      "        0.5651, 0.5866, 0.1140, 1.8055, 2.6887, 1.3443, 0.6503, 2.1728, 1.0864,\n",
      "        0.6234, 0.9458, 0.2792, 0.4441, 0.8600, 0.8600, 0.4322, 0.7269, 0.3942,\n",
      "        1.0974, 0.6449, 0.1427, 1.5234]) tensor([0.5222, 0.2069, 0.3148, 0.3958, 0.1107, 0.4489, 0.3791, 0.0910, 0.1906,\n",
      "        0.4835, 0.3215, 0.0288, 0.4514, 0.4961, 0.3933, 0.3291, 0.5432, 1.4525,\n",
      "        0.3943, 0.2003, 0.2473, 0.6015, 0.8118, 0.6868, 0.0846, 0.3096, 0.0459,\n",
      "        0.1999, 0.5432, 0.0309, 1.4669])\n",
      "2 tensor([0.0000, 1.0194, 1.1367, 0.0000, 0.2736, 0.0000, 0.7832, 0.7754, 0.7325,\n",
      "        0.0000, 0.2933, 0.0000, 1.3541, 2.0165, 0.0000, 0.0000, 1.6296, 0.0000,\n",
      "        0.0000, 0.7093, 0.0000, 0.2220, 0.0000, 0.4300, 0.2161, 0.3635, 0.0000,\n",
      "        0.8230, 0.3225, 0.0000, 0.0000]) tensor([0.2999, 1.3592, 1.5156, 0.5051, 0.5472, 0.3038, 1.1748, 1.1631, 1.4651,\n",
      "        0.2826, 0.5866, 0.0570, 1.8055, 2.6887, 0.6722, 0.3252, 2.1728, 0.5432,\n",
      "        0.3117, 0.9458, 0.1396, 0.4441, 0.4300, 0.8600, 0.4322, 0.7269, 0.1971,\n",
      "        1.0974, 0.6449, 0.0713, 0.7617]) tensor([2.2230e-01, 3.6964e-02, 1.2531e-01, 3.7001e-01, 1.6292e-01, 1.4510e-01,\n",
      "        1.2477e-02, 2.9675e-01, 1.9216e-01, 2.0093e-01, 4.1847e-02, 2.6543e-04,\n",
      "        2.2569e-01, 1.6002e-01, 2.7889e-01, 1.5860e-01, 2.7160e-01, 6.3773e-01,\n",
      "        8.2657e-02, 8.2095e-02, 1.1603e-01, 2.6847e-01, 3.8183e-01, 2.5676e-01,\n",
      "        2.3955e-01, 1.0413e-01, 5.2689e-02, 6.2771e-02, 2.2073e-01, 4.8099e-03,\n",
      "        7.0518e-01])\n",
      "3 tensor([0.0000, 1.1893, 1.3261, 0.2526, 0.2736, 0.0000, 0.9790, 0.9693, 0.7325,\n",
      "        0.0000, 0.4400, 0.0000, 1.5798, 2.3526, 0.3361, 0.1626, 1.9012, 0.0000,\n",
      "        0.0000, 0.8276, 0.0000, 0.3331, 0.0000, 0.6450, 0.2161, 0.5452, 0.0986,\n",
      "        0.9602, 0.4837, 0.0357, 0.0000]) tensor([0.1499, 1.3592, 1.5156, 0.5051, 0.4104, 0.1519, 1.1748, 1.1631, 1.0988,\n",
      "        0.1413, 0.5866, 0.0285, 1.8055, 2.6887, 0.6722, 0.3252, 2.1728, 0.2716,\n",
      "        0.1558, 0.9458, 0.0698, 0.4441, 0.2150, 0.8600, 0.3241, 0.7269, 0.1971,\n",
      "        1.0974, 0.6449, 0.0713, 0.3809]) tensor([0.0724, 0.0480, 0.0306, 0.0169, 0.0261, 0.0068, 0.1833, 0.1029, 0.0090,\n",
      "        0.0596, 0.0315, 0.0140, 0.1128, 0.0080, 0.0572, 0.0853, 0.1358, 0.2303,\n",
      "        0.0732, 0.0230, 0.0462, 0.1019, 0.1668, 0.0418, 0.0775, 0.0133, 0.0034,\n",
      "        0.0058, 0.0595, 0.0130, 0.3243])\n",
      "4 tensor([0.0000, 1.1893, 1.4208, 0.2526, 0.2736, 0.0759, 0.9790, 1.0662, 0.7325,\n",
      "        0.0000, 0.4400, 0.0143, 1.6927, 2.3526, 0.3361, 0.1626, 2.0370, 0.0000,\n",
      "        0.0779, 0.8867, 0.0000, 0.3886, 0.0000, 0.7525, 0.2161, 0.6361, 0.1478,\n",
      "        0.9602, 0.5643, 0.0357, 0.0000]) tensor([0.0750, 1.2743, 1.5156, 0.3789, 0.3420, 0.1519, 1.0769, 1.1631, 0.9157,\n",
      "        0.0706, 0.5133, 0.0285, 1.8055, 2.5206, 0.5041, 0.2439, 2.1728, 0.1358,\n",
      "        0.1558, 0.9458, 0.0349, 0.4441, 0.1075, 0.8600, 0.2701, 0.7269, 0.1971,\n",
      "        1.0288, 0.6449, 0.0535, 0.1904]) tensor([0.0026, 0.0055, 0.0168, 0.1725, 0.0423, 0.0692, 0.0854, 0.0060, 0.0825,\n",
      "        0.0110, 0.0052, 0.0069, 0.0564, 0.0760, 0.1109, 0.0367, 0.0679, 0.0266,\n",
      "        0.0047, 0.0066, 0.0113, 0.0187, 0.0593, 0.0657, 0.0035, 0.0322, 0.0212,\n",
      "        0.0285, 0.0115, 0.0041, 0.1339])\n",
      "5 tensor([0.0375, 1.1893, 1.4208, 0.3157, 0.3078, 0.0759, 0.9790, 1.1147, 0.8241,\n",
      "        0.0353, 0.4766, 0.0214, 1.7491, 2.4366, 0.4201, 0.2032, 2.1049, 0.0000,\n",
      "        0.0779, 0.8867, 0.0000, 0.4163, 0.0000, 0.7525, 0.2431, 0.6361, 0.1478,\n",
      "        0.9945, 0.6046, 0.0357, 0.0000]) tensor([0.0750, 1.2318, 1.4682, 0.3789, 0.3420, 0.1139, 1.0279, 1.1631, 0.9157,\n",
      "        0.0706, 0.5133, 0.0285, 1.8055, 2.5206, 0.5041, 0.2439, 2.1728, 0.0679,\n",
      "        0.1169, 0.9162, 0.0175, 0.4441, 0.0537, 0.8062, 0.2701, 0.6815, 0.1725,\n",
      "        1.0288, 0.6449, 0.0446, 0.0952]) tensor([0.0349, 0.0157, 0.0069, 0.0778, 0.0081, 0.0312, 0.0365, 0.0425, 0.0368,\n",
      "        0.0243, 0.0131, 0.0033, 0.0282, 0.0340, 0.0268, 0.0243, 0.0340, 0.0752,\n",
      "        0.0342, 0.0082, 0.0061, 0.0230, 0.0056, 0.0120, 0.0370, 0.0095, 0.0089,\n",
      "        0.0113, 0.0086, 0.0004, 0.0387])\n",
      "6 tensor([0.0375, 1.2106, 1.4445, 0.3473, 0.3249, 0.0759, 0.9790, 1.1147, 0.8699,\n",
      "        0.0353, 0.4766, 0.0249, 1.7773, 2.4786, 0.4621, 0.2032, 2.1389, 0.0340,\n",
      "        0.0974, 0.9014, 0.0087, 0.4163, 0.0000, 0.7525, 0.2431, 0.6361, 0.1478,\n",
      "        1.0117, 0.6046, 0.0401, 0.0000]) tensor([0.0562, 1.2318, 1.4682, 0.3789, 0.3420, 0.0949, 1.0035, 1.1389, 0.9157,\n",
      "        0.0530, 0.4950, 0.0285, 1.8055, 2.5206, 0.5041, 0.2235, 2.1728, 0.0679,\n",
      "        0.1169, 0.9162, 0.0175, 0.4302, 0.0269, 0.7794, 0.2566, 0.6588, 0.1602,\n",
      "        1.0288, 0.6248, 0.0446, 0.0476]) tensor([0.0161, 0.0051, 0.0049, 0.0304, 0.0090, 0.0122, 0.0120, 0.0183, 0.0139,\n",
      "        0.0067, 0.0040, 0.0015, 0.0141, 0.0130, 0.0152, 0.0062, 0.0170, 0.0243,\n",
      "        0.0147, 0.0008, 0.0026, 0.0021, 0.0213, 0.0149, 0.0167, 0.0019, 0.0027,\n",
      "        0.0028, 0.0015, 0.0019, 0.0141])\n",
      "7 tensor([0.0375, 1.2212, 1.4445, 0.3631, 0.3249, 0.0759, 0.9790, 1.1147, 0.8928,\n",
      "        0.0353, 0.4766, 0.0267, 1.7914, 2.4996, 0.4621, 0.2134, 2.1559, 0.0509,\n",
      "        0.1071, 0.9088, 0.0087, 0.4163, 0.0134, 0.7659, 0.2431, 0.6474, 0.1478,\n",
      "        1.0202, 0.6147, 0.0401, 0.0000]) tensor([0.0469, 1.2318, 1.4564, 0.3789, 0.3335, 0.0854, 0.9912, 1.1268, 0.9157,\n",
      "        0.0441, 0.4858, 0.0285, 1.8055, 2.5206, 0.4831, 0.2235, 2.1728, 0.0679,\n",
      "        0.1169, 0.9162, 0.0131, 0.4233, 0.0269, 0.7794, 0.2498, 0.6588, 0.1540,\n",
      "        1.0288, 0.6248, 0.0424, 0.0238]) tensor([0.0068, 0.0002, 0.0010, 0.0067, 0.0005, 0.0027, 0.0002, 0.0061, 0.0024,\n",
      "        0.0022, 0.0006, 0.0006, 0.0071, 0.0025, 0.0058, 0.0091, 0.0085, 0.0012,\n",
      "        0.0050, 0.0029, 0.0018, 0.0083, 0.0079, 0.0014, 0.0066, 0.0038, 0.0003,\n",
      "        0.0015, 0.0036, 0.0008, 0.0022])\n",
      "8 tensor([0.0375, 1.2212, 1.4504, 0.3710, 0.3249, 0.0759, 0.9851, 1.1147, 0.9042,\n",
      "        0.0397, 0.4812, 0.0276, 1.7985, 2.5101, 0.4726, 0.2134, 2.1643, 0.0509,\n",
      "        0.1120, 0.9088, 0.0109, 0.4198, 0.0202, 0.7727, 0.2431, 0.6474, 0.1509,\n",
      "        1.0202, 0.6147, 0.0401, 0.0000]) tensor([0.0422, 1.2265, 1.4564, 0.3789, 0.3292, 0.0807, 0.9912, 1.1207, 0.9157,\n",
      "        0.0441, 0.4858, 0.0285, 1.8055, 2.5206, 0.4831, 0.2185, 2.1728, 0.0594,\n",
      "        0.1169, 0.9125, 0.0131, 0.4233, 0.0269, 0.7794, 0.2465, 0.6531, 0.1540,\n",
      "        1.0245, 0.6197, 0.0412, 0.0119]) tensor([2.0874e-03, 2.4520e-03, 1.9766e-03, 5.0960e-03, 3.8105e-03, 2.0456e-03,\n",
      "        5.8790e-03, 9.0659e-05, 3.3016e-03, 2.2554e-03, 1.6915e-03, 1.7995e-04,\n",
      "        3.5264e-03, 2.7734e-03, 4.6787e-03, 1.4405e-03, 4.2439e-03, 1.1574e-02,\n",
      "        1.3363e-04, 1.0299e-03, 4.1902e-04, 3.0670e-03, 1.1418e-03, 5.2712e-03,\n",
      "        1.5176e-03, 9.3770e-04, 1.2093e-03, 6.1347e-04, 1.0467e-03, 2.0564e-04,\n",
      "        3.7215e-03])\n",
      "9 tensor([0.0375, 1.2238, 1.4504, 0.3710, 0.3270, 0.0783, 0.9851, 1.1147, 0.9042,\n",
      "        0.0397, 0.4812, 0.0281, 1.8020, 2.5101, 0.4726, 0.2134, 2.1686, 0.0552,\n",
      "        0.1144, 0.9088, 0.0109, 0.4215, 0.0235, 0.7727, 0.2431, 0.6474, 0.1509,\n",
      "        1.0224, 0.6147, 0.0401, 0.0060]) tensor([0.0398, 1.2265, 1.4534, 0.3749, 0.3292, 0.0807, 0.9882, 1.1177, 0.9100,\n",
      "        0.0419, 0.4835, 0.0285, 1.8055, 2.5154, 0.4779, 0.2159, 2.1728, 0.0594,\n",
      "        0.1169, 0.9107, 0.0120, 0.4233, 0.0269, 0.7760, 0.2448, 0.6503, 0.1525,\n",
      "        1.0245, 0.6172, 0.0407, 0.0119]) tensor([2.5535e-04, 1.1247e-03, 4.9665e-04, 8.2350e-04, 1.6729e-03, 3.2759e-04,\n",
      "        2.8197e-03, 2.9384e-03, 4.4012e-04, 4.7922e-05, 5.4568e-04, 4.2737e-05,\n",
      "        1.7632e-03, 1.4772e-04, 5.7268e-04, 2.3699e-03, 2.1219e-03, 5.2085e-03,\n",
      "        2.3013e-03, 1.0632e-04, 6.7168e-04, 4.6510e-04, 2.2175e-03, 1.9118e-03,\n",
      "        1.0146e-03, 4.8208e-04, 4.3929e-04, 4.5823e-04, 2.1291e-04, 7.3001e-05,\n",
      "        7.4625e-04])\n",
      "10 tensor([0.0387, 1.2252, 1.4504, 0.3729, 0.3281, 0.0783, 0.9851, 1.1162, 0.9042,\n",
      "        0.0397, 0.4812, 0.0281, 1.8038, 2.5101, 0.4752, 0.2147, 2.1707, 0.0573,\n",
      "        0.1144, 0.9088, 0.0115, 0.4224, 0.0235, 0.7727, 0.2439, 0.6488, 0.1509,\n",
      "        1.0224, 0.6159, 0.0404, 0.0089]) tensor([0.0398, 1.2265, 1.4519, 0.3749, 0.3292, 0.0795, 0.9866, 1.1177, 0.9071,\n",
      "        0.0408, 0.4824, 0.0283, 1.8055, 2.5128, 0.4779, 0.2159, 2.1728, 0.0594,\n",
      "        0.1157, 0.9098, 0.0120, 0.4233, 0.0252, 0.7743, 0.2448, 0.6503, 0.1517,\n",
      "        1.0234, 0.6172, 0.0407, 0.0119]) tensor([9.1600e-04, 4.6101e-04, 2.4340e-04, 2.1362e-03, 6.0415e-04, 8.5902e-04,\n",
      "        1.2900e-03, 1.4240e-03, 9.9087e-04, 1.0557e-03, 2.7180e-05, 6.8605e-05,\n",
      "        8.8167e-04, 1.1653e-03, 2.0528e-03, 4.6468e-04, 1.0610e-03, 2.0251e-03,\n",
      "        1.0837e-03, 3.5550e-04, 1.2630e-04, 8.3596e-04, 5.3787e-04, 2.3210e-04,\n",
      "        2.5153e-04, 2.2781e-04, 5.4240e-05, 7.7620e-05, 4.1687e-04, 6.6310e-05,\n",
      "        7.4148e-04])\n",
      "11 tensor([0.0387, 1.2258, 1.4512, 0.3729, 0.3286, 0.0789, 0.9851, 1.1169, 0.9057,\n",
      "        0.0403, 0.4818, 0.0282, 1.8046, 2.5114, 0.4752, 0.2153, 2.1718, 0.0584,\n",
      "        0.1144, 0.9093, 0.0117, 0.4224, 0.0235, 0.7727, 0.2439, 0.6488, 0.1509,\n",
      "        1.0229, 0.6159, 0.0404, 0.0089]) tensor([0.0392, 1.2265, 1.4519, 0.3739, 0.3292, 0.0795, 0.9859, 1.1177, 0.9071,\n",
      "        0.0408, 0.4824, 0.0283, 1.8055, 2.5128, 0.4766, 0.2159, 2.1728, 0.0594,\n",
      "        0.1151, 0.9098, 0.0120, 0.4228, 0.0244, 0.7735, 0.2444, 0.6496, 0.1513,\n",
      "        1.0234, 0.6166, 0.0405, 0.0104]) tensor([3.3033e-04, 1.2913e-04, 1.2662e-04, 6.5637e-04, 6.9797e-05, 2.6584e-04,\n",
      "        5.2518e-04, 6.6680e-04, 2.7537e-04, 5.0390e-04, 2.5922e-04, 1.2934e-05,\n",
      "        4.4084e-04, 5.0865e-04, 7.4005e-04, 4.8804e-04, 5.3048e-04, 4.3392e-04,\n",
      "        4.7505e-04, 1.2459e-04, 1.4639e-04, 1.8543e-04, 3.0208e-04, 6.0773e-04,\n",
      "        3.8159e-04, 1.2714e-04, 1.3822e-04, 1.9024e-04, 1.0198e-04, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "12 tensor([0.0387, 1.2262, 1.4512, 0.3729, 0.3289, 0.0792, 0.9851, 1.1173, 0.9064,\n",
      "        0.0406, 0.4818, 0.0282, 1.8051, 2.5121, 0.4752, 0.2153, 2.1723, 0.0589,\n",
      "        0.1144, 0.9095, 0.0117, 0.4224, 0.0239, 0.7731, 0.2441, 0.6492, 0.1511,\n",
      "        1.0229, 0.6159, 0.0404, 0.0089]) tensor([0.0389, 1.2265, 1.4515, 0.3734, 0.3292, 0.0795, 0.9855, 1.1177, 0.9071,\n",
      "        0.0408, 0.4821, 0.0283, 1.8055, 2.5128, 0.4759, 0.2156, 2.1728, 0.0594,\n",
      "        0.1148, 0.9098, 0.0119, 0.4226, 0.0244, 0.7735, 0.2444, 0.6496, 0.1513,\n",
      "        1.0232, 0.6163, 0.0405, 0.0104]) tensor([3.7432e-05, 3.6806e-05, 5.8390e-05, 8.3447e-05, 1.9741e-04, 3.0994e-05,\n",
      "        1.4275e-04, 2.8819e-04, 8.2254e-05, 2.2793e-04, 1.1599e-04, 1.4901e-05,\n",
      "        2.2042e-04, 1.8059e-04, 8.3923e-05, 1.1683e-05, 2.6512e-04, 3.6192e-04,\n",
      "        1.7059e-04, 9.1381e-06, 1.0014e-05, 1.3983e-04, 1.1778e-04, 1.8775e-04,\n",
      "        6.4969e-05, 5.0306e-05, 4.1962e-05, 5.6252e-05, 5.5432e-05, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "13 tensor([0.0387, 1.2262, 1.4514, 0.3732, 0.3289, 0.0792, 0.9851, 1.1175, 0.9064,\n",
      "        0.0407, 0.4818, 0.0282, 1.8053, 2.5124, 0.4752, 0.2153, 2.1726, 0.0589,\n",
      "        0.1144, 0.9095, 0.0117, 0.4225, 0.0239, 0.7733, 0.2443, 0.6492, 0.1512,\n",
      "        1.0229, 0.6161, 0.0404, 0.0089]) tensor([0.0388, 1.2263, 1.4515, 0.3734, 0.3290, 0.0794, 0.9853, 1.1177, 0.9067,\n",
      "        0.0408, 0.4819, 0.0283, 1.8055, 2.5128, 0.4756, 0.2154, 2.1728, 0.0591,\n",
      "        0.1146, 0.9098, 0.0118, 0.4226, 0.0241, 0.7735, 0.2444, 0.6494, 0.1513,\n",
      "        1.0230, 0.6163, 0.0405, 0.0104]) tensor([1.0896e-04, 4.6164e-05, 3.4116e-05, 2.8658e-04, 6.3837e-05, 1.1730e-04,\n",
      "        4.8459e-05, 9.8884e-05, 9.6560e-05, 9.0003e-05, 4.4405e-05, 9.8348e-07,\n",
      "        1.1015e-04, 1.6555e-05, 2.4462e-04, 2.2650e-04, 1.3256e-04, 3.5763e-05,\n",
      "        1.8477e-05, 9.1381e-06, 5.8174e-05, 2.2769e-05, 9.2030e-05, 2.2173e-05,\n",
      "        9.3222e-05, 3.8385e-05, 6.1393e-06, 1.0744e-05, 2.3246e-05, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "14 tensor([0.0387, 1.2263, 1.4514, 0.3732, 0.3289, 0.0793, 0.9852, 1.1176, 0.9066,\n",
      "        0.0408, 0.4818, 0.0282, 1.8054, 2.5126, 0.4754, 0.2154, 2.1727, 0.0590,\n",
      "        0.1144, 0.9095, 0.0118, 0.4225, 0.0240, 0.7733, 0.2443, 0.6493, 0.1512,\n",
      "        1.0230, 0.6161, 0.0404, 0.0089]) tensor([0.0388, 1.2263, 1.4515, 0.3733, 0.3290, 0.0794, 0.9853, 1.1177, 0.9067,\n",
      "        0.0408, 0.4819, 0.0283, 1.8055, 2.5128, 0.4756, 0.2154, 2.1728, 0.0591,\n",
      "        0.1145, 0.9098, 0.0118, 0.4226, 0.0241, 0.7734, 0.2443, 0.6494, 0.1513,\n",
      "        1.0230, 0.6162, 0.0405, 0.0104]) tensor([3.5763e-05, 4.6790e-06, 1.2137e-05, 1.0133e-04, 2.9206e-06, 4.3154e-05,\n",
      "        4.7147e-05, 3.9935e-06, 7.1526e-06, 2.0981e-05, 8.5831e-06, 9.8348e-07,\n",
      "        5.5075e-05, 6.5461e-05, 8.0109e-05, 1.0753e-04, 6.6280e-05, 1.6308e-04,\n",
      "        5.7578e-05, 9.1381e-06, 2.4080e-05, 5.8472e-05, 1.2875e-05, 8.2850e-05,\n",
      "        1.4067e-05, 5.9605e-06, 6.1393e-06, 2.2754e-05, 1.6093e-05, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "15 tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3289, 0.0793, 0.9852, 1.1176, 0.9066,\n",
      "        0.0408, 0.4818, 0.0282, 1.8055, 2.5126, 0.4755, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9095, 0.0118, 0.4225, 0.0240, 0.7733, 0.2443, 0.6493, 0.1512,\n",
      "        1.0230, 0.6161, 0.0404, 0.0089]) tensor([0.0388, 1.2263, 1.4515, 0.3732, 0.3290, 0.0794, 0.9852, 1.1177, 0.9067,\n",
      "        0.0408, 0.4819, 0.0283, 1.8055, 2.5127, 0.4756, 0.2154, 2.1728, 0.0591,\n",
      "        0.1145, 0.9098, 0.0118, 0.4226, 0.0241, 0.7734, 0.2443, 0.6494, 0.1513,\n",
      "        1.0230, 0.6162, 0.0405, 0.0104]) tensor([8.3447e-07, 4.6790e-06, 1.0990e-05, 9.0599e-06, 2.9206e-06, 6.1989e-06,\n",
      "        6.5565e-07, 3.9935e-06, 7.1526e-06, 1.3471e-05, 8.5831e-06, 9.8348e-07,\n",
      "        2.7537e-05, 2.4453e-05, 1.9073e-06, 4.7922e-05, 3.3140e-05, 6.3419e-05,\n",
      "        1.9550e-05, 9.1381e-06, 7.0333e-06, 1.7881e-05, 3.9577e-05, 3.0398e-05,\n",
      "        2.5511e-05, 5.9605e-06, 6.1393e-06, 5.9456e-06, 3.5763e-06, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "16 tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3289, 0.0793, 0.9852, 1.1176, 0.9066,\n",
      "        0.0408, 0.4818, 0.0282, 1.8055, 2.5126, 0.4755, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9095, 0.0118, 0.4225, 0.0241, 0.7734, 0.2443, 0.6493, 0.1512,\n",
      "        1.0230, 0.6161, 0.0404, 0.0089]) tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3290, 0.0794, 0.9852, 1.1177, 0.9067,\n",
      "        0.0408, 0.4819, 0.0283, 1.8055, 2.5126, 0.4756, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9098, 0.0118, 0.4226, 0.0241, 0.7734, 0.2443, 0.6494, 0.1513,\n",
      "        1.0230, 0.6162, 0.0405, 0.0104]) tensor([8.3447e-07, 4.6790e-06, 5.7369e-07, 9.0599e-06, 2.9206e-06, 6.1989e-06,\n",
      "        6.5565e-07, 3.9935e-06, 7.1526e-06, 3.8147e-06, 8.5831e-06, 9.8348e-07,\n",
      "        1.3828e-05, 3.9488e-06, 1.9073e-06, 1.8120e-05, 1.6689e-05, 1.3828e-05,\n",
      "        5.9605e-07, 9.1381e-06, 7.0333e-06, 2.5034e-06, 1.3351e-05, 4.1723e-06,\n",
      "        5.6028e-06, 5.9605e-06, 6.1393e-06, 5.9456e-06, 3.5763e-06, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "17 tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3289, 0.0793, 0.9852, 1.1176, 0.9066,\n",
      "        0.0408, 0.4818, 0.0282, 1.8055, 2.5126, 0.4755, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9095, 0.0118, 0.4225, 0.0241, 0.7734, 0.2443, 0.6493, 0.1512,\n",
      "        1.0230, 0.6161, 0.0404, 0.0089]) tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3290, 0.0794, 0.9852, 1.1177, 0.9067,\n",
      "        0.0408, 0.4819, 0.0283, 1.8055, 2.5126, 0.4756, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9098, 0.0118, 0.4226, 0.0241, 0.7734, 0.2443, 0.6494, 0.1513,\n",
      "        1.0230, 0.6162, 0.0405, 0.0104]) tensor([8.3447e-07, 4.6790e-06, 5.7369e-07, 9.0599e-06, 2.9206e-06, 6.1989e-06,\n",
      "        6.5565e-07, 3.9935e-06, 7.1526e-06, 3.8147e-06, 8.5831e-06, 9.8348e-07,\n",
      "        6.9141e-06, 3.9488e-06, 1.9073e-06, 3.0994e-06, 8.3447e-06, 1.0967e-05,\n",
      "        5.9605e-07, 9.1381e-06, 7.0333e-06, 2.5034e-06, 2.3842e-07, 4.1723e-06,\n",
      "        5.6028e-06, 5.9605e-06, 6.1393e-06, 5.9456e-06, 3.5763e-06, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "18 tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3289, 0.0793, 0.9852, 1.1176, 0.9066,\n",
      "        0.0408, 0.4818, 0.0282, 1.8055, 2.5126, 0.4755, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9095, 0.0118, 0.4225, 0.0241, 0.7734, 0.2443, 0.6493, 0.1512,\n",
      "        1.0230, 0.6161, 0.0404, 0.0089]) tensor([0.0388, 1.2263, 1.4514, 0.3732, 0.3290, 0.0794, 0.9852, 1.1177, 0.9067,\n",
      "        0.0408, 0.4819, 0.0283, 1.8055, 2.5126, 0.4756, 0.2154, 2.1728, 0.0590,\n",
      "        0.1145, 0.9098, 0.0118, 0.4226, 0.0241, 0.7734, 0.2443, 0.6494, 0.1513,\n",
      "        1.0230, 0.6162, 0.0405, 0.0104]) tensor([8.3447e-07, 4.6790e-06, 5.7369e-07, 9.0599e-06, 2.9206e-06, 6.1989e-06,\n",
      "        6.5565e-07, 3.9935e-06, 7.1526e-06, 3.8147e-06, 8.5831e-06, 9.8348e-07,\n",
      "        6.9141e-06, 3.9488e-06, 1.9073e-06, 3.0994e-06, 8.3447e-06, 1.4305e-06,\n",
      "        5.9605e-07, 9.1381e-06, 7.0333e-06, 2.5034e-06, 2.3842e-07, 4.1723e-06,\n",
      "        5.6028e-06, 5.9605e-06, 6.1393e-06, 5.9456e-06, 3.5763e-06, 3.3528e-06,\n",
      "        2.3842e-06])\n",
      "tensor([[-8.3447e-07, -7.5481e-01, -7.0244e+00, -1.4021e+00, -2.8461e+00],\n",
      "        [-1.6547e+00,  4.6790e-06, -1.4965e+00, -6.1568e+00, -4.9540e+00],\n",
      "        [-4.3722e+00,  5.7369e-07, -5.0638e-01, -5.0427e+00, -4.4425e+00],\n",
      "        [-8.3020e+00, -7.9336e-02, -3.1792e+00, -2.8010e+00, -3.2894e-01],\n",
      "        [-2.4881e-01,  0.0000e+00, -1.1639e-01, -7.6411e+00, -5.0024e+00],\n",
      "        [ 7.1526e-06, -1.3183e+00, -2.7326e+00, -1.6718e+00, -5.8877e-01],\n",
      "        [-4.2016e+00, -1.8324e+00,  3.8147e-06, -2.6977e+00, -3.3056e-01],\n",
      "        [-2.4895e+00, -3.3200e+00, -9.6365e-01, -4.4090e+00, -2.6515e+00],\n",
      "        [-2.8894e+00, -7.5235e+00, -3.3560e+00, -8.5831e-06, -9.8348e-07],\n",
      "        [ 0.0000e+00, -1.7607e-01, -2.5593e+00, -2.2131e+00, -3.5420e+00],\n",
      "        [-5.8853e-01, -2.5518e-01, -3.3200e+00, -4.9517e-01, -4.3943e+00],\n",
      "        [-4.8934e+00, -4.3292e-01,  0.0000e+00, -3.1885e+00, -2.1138e+00],\n",
      "        [-4.5791e+00, -8.9223e+00, -5.8619e+00,  5.9605e-07, -3.0000e+00],\n",
      "        [ 9.1381e-06, -1.1944e-01, -1.4099e+00, -2.5447e-01, -2.4446e+00],\n",
      "        [-5.8631e+00, -3.0699e+00, -7.7337e-01, -8.3246e+00, -2.4082e-02],\n",
      "        [-4.7133e-01, -7.1172e-01, -1.3204e+00, -2.5302e+00, -3.8653e+00],\n",
      "        [-5.2031e+00, -5.5754e+00, -1.4295e-01, -5.9605e-06, -2.2769e+00],\n",
      "        [-6.3720e-01, -1.3122e-01, -5.9456e-06, -3.5375e+00, -2.0687e+00],\n",
      "        [-5.7768e+00, -4.0000e+00, -3.5763e-06, -3.7934e+00, -2.4485e-01],\n",
      "        [-1.5524e+00, -4.7500e+00, -5.1674e-01,  2.3842e-06, -3.6622e+00]])\n"
     ]
    }
   ],
   "source": [
    "valencies = [4., 3., 2., 1.]\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Valency', 'params': [valencies]}))\n",
    "print (proj_adjs.sum(dim=2) - proj_xs @ torch.tensor(valencies, device=proj_xs.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([2.5682, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 1.2715, 1.2243, 0.6909,\n",
      "        1.7386, 1.3722, 2.4106, 0.4530, 0.7466, 1.0107, 1.2039, 1.3088, 1.1813,\n",
      "        0.6102, 1.9400, 1.4674, 1.8218, 2.4602, 1.0999, 1.0745, 0.0113, 0.0523,\n",
      "        1.7546, 2.2136, 0.0801, 1.0620, 1.5489, 1.2275, 1.1978]) tensor([1.0000, 1.0579, 0.9295, 1.0000, 1.2163, 0.0364, 0.8646, 0.9358, 0.9224,\n",
      "        0.8693, 1.1387, 0.8615, 0.2265, 0.3733, 1.1566, 1.1442, 1.0553, 0.5906,\n",
      "        0.3051, 0.9077, 1.0830, 0.9109, 0.5230, 0.5499, 1.2575, 0.0057, 0.0261,\n",
      "        1.1227, 0.7821, 0.0400, 0.8097, 0.6291, 0.6138, 0.4947])\n",
      "1 tensor([0.0000, 0.6537, 0.9295, 1.1224, 0.7791, 0.0364, 0.0000, 0.0000, 0.0000,\n",
      "        0.8693, 0.0000, 1.2053, 0.2265, 0.3733, 0.0000, 0.6020, 0.0000, 0.5906,\n",
      "        0.3051, 0.0000, 0.7337, 0.9109, 0.0000, 0.5499, 0.0000, 0.0057, 0.0261,\n",
      "        0.0000, 0.0000, 0.0400, 0.0000, 0.0000, 0.6138, 0.0000]) tensor([1.2841, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.6357, 0.6121, 0.3454,\n",
      "        1.7386, 0.6861, 2.4106, 0.4530, 0.7466, 0.5053, 1.2039, 0.6544, 1.1813,\n",
      "        0.6102, 0.9700, 1.4674, 1.8218, 1.2301, 1.0999, 0.5373, 0.0113, 0.0523,\n",
      "        0.8773, 1.1068, 0.0801, 0.5310, 0.7744, 1.2275, 0.5989]) tensor([0.8576, 0.4041, 0.4648, 0.5612, 0.3970, 0.0182, 0.4069, 0.0176, 0.0588,\n",
      "        0.4347, 0.0877, 0.7412, 0.1133, 0.1867, 0.6513, 0.5422, 0.2716, 0.2953,\n",
      "        0.1526, 0.3542, 0.3669, 0.4555, 0.4331, 0.2750, 0.4516, 0.0028, 0.0131,\n",
      "        0.5703, 0.2287, 0.0200, 0.1379, 0.6319, 0.3069, 0.4036])\n",
      "2 tensor([0.0000, 0.9806, 1.3943, 1.6836, 1.1687, 0.0545, 0.3179, 0.0000, 0.0000,\n",
      "        1.3040, 0.0000, 1.2053, 0.3398, 0.5600, 0.0000, 0.9029, 0.0000, 0.8859,\n",
      "        0.4577, 0.0000, 1.1006, 1.3664, 0.6150, 0.8249, 0.0000, 0.0085, 0.0392,\n",
      "        0.0000, 0.0000, 0.0601, 0.2655, 0.3872, 0.9206, 0.2994]) tensor([0.6421, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.6357, 0.3061, 0.1727,\n",
      "        1.7386, 0.3430, 1.8079, 0.4530, 0.7466, 0.2527, 1.2039, 0.3272, 1.1813,\n",
      "        0.6102, 0.4850, 1.4674, 1.8218, 1.2301, 1.0999, 0.2686, 0.0113, 0.0523,\n",
      "        0.4386, 0.5534, 0.0801, 0.5310, 0.7744, 1.2275, 0.5989]) tensor([0.5366, 0.1634, 0.2324, 0.2806, 0.1948, 0.0091, 0.2288, 0.3702, 0.3730,\n",
      "        0.2173, 0.4269, 0.1628, 0.0566, 0.0933, 0.2917, 0.2412, 0.0740, 0.1477,\n",
      "        0.0763, 0.1308, 0.1834, 0.2277, 0.0899, 0.1375, 0.0487, 0.0014, 0.0065,\n",
      "        0.3213, 0.0000, 0.0100, 0.3931, 0.0192, 0.1534, 0.0456])\n",
      "3 tensor([0.0000, 1.1441, 1.6267, 1.9642, 1.3634, 0.0636, 0.3179, 0.1530, 0.0864,\n",
      "        1.5213, 0.1715, 1.5066, 0.3964, 0.6533, 0.0000, 1.0534, 0.1636, 1.0336,\n",
      "        0.5339, 0.2425, 1.2840, 1.5941, 0.6150, 0.9624, 0.0000, 0.0099, 0.0457,\n",
      "        0.0000, 0.0000, 0.0701, 0.2655, 0.5808, 1.0741, 0.2994]) tensor([0.3210, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4768, 0.3061, 0.1727,\n",
      "        1.7386, 0.3430, 1.8079, 0.4530, 0.7466, 0.1263, 1.2039, 0.3272, 1.1813,\n",
      "        0.6102, 0.4850, 1.4674, 1.8218, 0.9226, 1.0999, 0.1343, 0.0113, 0.0523,\n",
      "        0.2193, 0.5534, 0.0801, 0.3982, 0.7744, 1.2275, 0.4492]) tensor([0.3195, 0.0817, 0.1162, 0.1403, 0.0974, 0.0045, 0.0890, 0.2119, 0.1571,\n",
      "        0.1087, 0.1696, 0.2892, 0.0283, 0.0467, 0.1022, 0.0907, 0.0896, 0.0738,\n",
      "        0.0381, 0.1117, 0.0917, 0.1139, 0.2177, 0.0687, 0.1454, 0.0007, 0.0033,\n",
      "        0.1020, 0.0000, 0.0050, 0.1276, 0.3387, 0.0767, 0.1790])\n",
      "4 tensor([0.0000, 1.2258, 1.7429, 2.1045, 1.4608, 0.0682, 0.3973, 0.2296, 0.1295,\n",
      "        1.6300, 0.2573, 1.5066, 0.4247, 0.7000, 0.0000, 1.1287, 0.1636, 1.1074,\n",
      "        0.5721, 0.2425, 1.3757, 1.7080, 0.7688, 1.0312, 0.0672, 0.0106, 0.0490,\n",
      "        0.0000, 0.0000, 0.0751, 0.2655, 0.5808, 1.1508, 0.3743]) tensor([0.1605, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4768, 0.3061, 0.1727,\n",
      "        1.7386, 0.3430, 1.6573, 0.4530, 0.7466, 0.0632, 1.2039, 0.2454, 1.1813,\n",
      "        0.6102, 0.3638, 1.4674, 1.8218, 0.9226, 1.0999, 0.1343, 0.0113, 0.0523,\n",
      "        0.1097, 0.5534, 0.0801, 0.3319, 0.6776, 1.2275, 0.4492]) tensor([0.1590, 0.0409, 0.0581, 0.0701, 0.0487, 0.0023, 0.0699, 0.0971, 0.0492,\n",
      "        0.0543, 0.0409, 0.0632, 0.0142, 0.0233, 0.0074, 0.0376, 0.0078, 0.0369,\n",
      "        0.0191, 0.0096, 0.0459, 0.0569, 0.0639, 0.0344, 0.0521, 0.0004, 0.0016,\n",
      "        0.0077, 0.0000, 0.0025, 0.0052, 0.1745, 0.0384, 0.0667])\n",
      "5 tensor([0.0000, 1.2666, 1.8010, 2.1746, 1.5095, 0.0705, 0.3973, 0.2678, 0.1511,\n",
      "        1.6843, 0.3002, 1.5066, 0.4389, 0.7233, 0.0000, 1.1663, 0.1636, 1.1443,\n",
      "        0.5911, 0.3031, 1.4216, 1.7649, 0.8457, 1.0655, 0.1007, 0.0110, 0.0506,\n",
      "        0.0548, 0.0000, 0.0776, 0.2987, 0.5808, 1.1892, 0.4117]) tensor([0.0803, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4371, 0.3061, 0.1727,\n",
      "        1.7386, 0.3430, 1.5819, 0.4530, 0.7466, 0.0316, 1.2039, 0.2045, 1.1813,\n",
      "        0.6102, 0.3638, 1.4674, 1.8218, 0.9226, 1.0999, 0.1343, 0.0113, 0.0523,\n",
      "        0.1097, 0.5534, 0.0801, 0.3319, 0.6292, 1.2275, 0.4492]) tensor([0.0787, 0.0204, 0.0290, 0.0351, 0.0243, 0.0011, 0.0096, 0.0397, 0.0048,\n",
      "        0.0272, 0.0234, 0.0498, 0.0071, 0.0117, 0.0400, 0.0188, 0.0331, 0.0185,\n",
      "        0.0095, 0.0511, 0.0229, 0.0285, 0.0130, 0.0172, 0.0017, 0.0002, 0.0008,\n",
      "        0.0471, 0.0000, 0.0013, 0.0612, 0.0776, 0.0192, 0.0106])\n",
      "6 tensor([0.0000, 1.2871, 1.8300, 2.2097, 1.5339, 0.0716, 0.4172, 0.2869, 0.1511,\n",
      "        1.7115, 0.3002, 1.5443, 0.4459, 0.7350, 0.0158, 1.1851, 0.1840, 1.1628,\n",
      "        0.6007, 0.3031, 1.4445, 1.7934, 0.8457, 1.0827, 0.1175, 0.0111, 0.0514,\n",
      "        0.0548, 0.0000, 0.0788, 0.2987, 0.5808, 1.2083, 0.4304]) tensor([0.0401, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4371, 0.3061, 0.1619,\n",
      "        1.7386, 0.3216, 1.5819, 0.4530, 0.7466, 0.0316, 1.2039, 0.2045, 1.1813,\n",
      "        0.6102, 0.3334, 1.4674, 1.8218, 0.8841, 1.0999, 0.1343, 0.0113, 0.0523,\n",
      "        0.0822, 0.5534, 0.0801, 0.3153, 0.6050, 1.2275, 0.4492]) tensor([3.8611e-02, 1.0215e-02, 1.4524e-02, 1.7537e-02, 1.2173e-02, 5.6819e-04,\n",
      "        3.0181e-02, 1.1052e-02, 2.2193e-02, 1.3583e-02, 8.7678e-03, 6.7012e-03,\n",
      "        3.5392e-03, 5.8331e-03, 1.6263e-02, 9.4057e-03, 1.2629e-02, 9.2286e-03,\n",
      "        4.7672e-03, 2.0749e-02, 1.1464e-02, 1.4233e-02, 2.5467e-02, 8.5930e-03,\n",
      "        2.3492e-02, 8.8390e-05, 4.0825e-04, 1.9713e-02, 0.0000e+00, 6.2569e-04,\n",
      "        2.8020e-02, 2.9247e-02, 9.5900e-03, 1.7500e-02])\n",
      "7 tensor([0.0000, 1.2973, 1.8446, 2.2272, 1.5460, 0.0722, 0.4172, 0.2965, 0.1565,\n",
      "        1.7251, 0.3109, 1.5443, 0.4495, 0.7408, 0.0237, 1.1945, 0.1943, 1.1720,\n",
      "        0.6054, 0.3031, 1.4560, 1.8076, 0.8649, 1.0913, 0.1175, 0.0112, 0.0518,\n",
      "        0.0548, 0.0000, 0.0795, 0.2987, 0.5808, 1.2179, 0.4304]) tensor([0.0201, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4271, 0.3061, 0.1619,\n",
      "        1.7386, 0.3216, 1.5631, 0.4530, 0.7466, 0.0316, 1.2039, 0.2045, 1.1813,\n",
      "        0.6102, 0.3183, 1.4674, 1.8218, 0.8841, 1.0999, 0.1259, 0.0113, 0.0523,\n",
      "        0.0685, 0.5534, 0.0801, 0.3070, 0.5929, 1.2275, 0.4398]) tensor([1.8548e-02, 5.1074e-03, 7.2621e-03, 8.7688e-03, 6.0867e-03, 2.8409e-04,\n",
      "        1.0314e-02, 3.2945e-03, 8.6997e-03, 6.7915e-03, 7.3123e-03, 2.1548e-02,\n",
      "        1.7696e-03, 2.9166e-03, 4.4186e-03, 4.7029e-03, 2.4037e-03, 4.6144e-03,\n",
      "        2.3836e-03, 5.5926e-03, 5.7322e-03, 7.1167e-03, 6.2468e-03, 4.2965e-03,\n",
      "        1.0900e-02, 4.4195e-05, 2.0412e-04, 6.0053e-03, 0.0000e+00, 3.1284e-04,\n",
      "        1.1426e-02, 5.0460e-03, 4.7950e-03, 3.4637e-03])\n",
      "8 tensor([0.0000, 1.3024, 1.8518, 2.2360, 1.5521, 0.0724, 0.4172, 0.2965, 0.1592,\n",
      "        1.7318, 0.3109, 1.5537, 0.4512, 0.7437, 0.0276, 1.1992, 0.1994, 1.1766,\n",
      "        0.6078, 0.3031, 1.4617, 1.8147, 0.8745, 1.0956, 0.1175, 0.0113, 0.0521,\n",
      "        0.0548, 0.0000, 0.0798, 0.2987, 0.5808, 1.2227, 0.4304]) tensor([0.0100, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4222, 0.3013, 0.1619,\n",
      "        1.7386, 0.3162, 1.5631, 0.4530, 0.7466, 0.0316, 1.2039, 0.2045, 1.1813,\n",
      "        0.6102, 0.3107, 1.4674, 1.8218, 0.8841, 1.0999, 0.1217, 0.0113, 0.0523,\n",
      "        0.0617, 0.5534, 0.0801, 0.3028, 0.5869, 1.2275, 0.4351]) tensor([8.5154e-03, 2.5537e-03, 3.6310e-03, 4.3843e-03, 3.0434e-03, 1.4205e-04,\n",
      "        3.8111e-04, 3.8791e-03, 1.9529e-03, 3.3957e-03, 7.2789e-04, 7.4234e-03,\n",
      "        8.8480e-04, 1.4583e-03, 1.5033e-03, 2.3514e-03, 2.7086e-03, 2.3072e-03,\n",
      "        1.1918e-03, 1.9858e-03, 2.8661e-03, 3.5583e-03, 3.3634e-03, 2.1483e-03,\n",
      "        4.6037e-03, 2.2097e-05, 1.0206e-04, 8.4829e-04, 0.0000e+00, 1.5642e-04,\n",
      "        3.1297e-03, 7.0546e-03, 2.3974e-03, 3.5543e-03])\n",
      "9 tensor([0.0000, 1.3049, 1.8555, 2.2404, 1.5552, 0.0726, 0.4172, 0.2989, 0.1606,\n",
      "        1.7352, 0.3136, 1.5584, 0.4521, 0.7452, 0.0276, 1.2016, 0.1994, 1.1790,\n",
      "        0.6090, 0.3069, 1.4646, 1.8183, 0.8745, 1.0977, 0.1175, 0.0113, 0.0522,\n",
      "        0.0583, 0.0000, 0.0799, 0.2987, 0.5839, 1.2251, 0.4328]) tensor([0.0050, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4197, 0.3013, 0.1619,\n",
      "        1.7386, 0.3162, 1.5631, 0.4530, 0.7466, 0.0296, 1.2039, 0.2019, 1.1813,\n",
      "        0.6102, 0.3107, 1.4674, 1.8218, 0.8793, 1.0999, 0.1196, 0.0113, 0.0523,\n",
      "        0.0617, 0.5534, 0.0801, 0.3008, 0.5869, 1.2275, 0.4351]) tensor([3.4995e-03, 1.2769e-03, 1.8154e-03, 2.1923e-03, 1.5217e-03, 7.1019e-05,\n",
      "        4.5855e-03, 2.9230e-04, 1.4207e-03, 1.6979e-03, 3.2923e-03, 3.6120e-04,\n",
      "        4.4242e-04, 7.2914e-04, 1.4577e-03, 1.1758e-03, 1.5235e-04, 1.1536e-03,\n",
      "        5.9593e-04, 1.8034e-03, 1.4330e-03, 1.7792e-03, 1.4417e-03, 1.0741e-03,\n",
      "        1.4558e-03, 1.1049e-05, 5.1029e-05, 2.5785e-03, 0.0000e+00, 7.8209e-05,\n",
      "        1.0188e-03, 1.0045e-03, 1.1986e-03, 4.5300e-05])\n",
      "10 tensor([0.0000, 1.3062, 1.8573, 2.2426, 1.5567, 0.0727, 0.4184, 0.3001, 0.1606,\n",
      "        1.7369, 0.3136, 1.5608, 0.4526, 0.7459, 0.0286, 1.2028, 0.1994, 1.1801,\n",
      "        0.6096, 0.3069, 1.4660, 1.8201, 0.8769, 1.0988, 0.1175, 0.0113, 0.0522,\n",
      "        0.0583, 0.0000, 0.0800, 0.2997, 0.5854, 1.2263, 0.4340]) tensor([0.0025, 1.3075, 1.8591, 2.2448, 1.5582, 0.0727, 0.4197, 0.3013, 0.1612,\n",
      "        1.7386, 0.3149, 1.5631, 0.4530, 0.7466, 0.0296, 1.2039, 0.2007, 1.1813,\n",
      "        0.6102, 0.3088, 1.4674, 1.8218, 0.8793, 1.0999, 0.1186, 0.0113, 0.0523,\n",
      "        0.0600, 0.5534, 0.0801, 0.3008, 0.5869, 1.2275, 0.4351]) tensor([9.9134e-04, 6.3848e-04, 9.0778e-04, 1.0960e-03, 7.6079e-04, 3.5509e-05,\n",
      "        2.1021e-03, 1.5012e-03, 2.6608e-04, 8.4889e-04, 1.2822e-03, 3.1700e-03,\n",
      "        2.2122e-04, 3.6454e-04, 2.2769e-05, 5.8782e-04, 1.1258e-03, 5.7673e-04,\n",
      "        2.9796e-04, 9.1314e-05, 7.1657e-04, 8.8966e-04, 9.6083e-04, 5.3704e-04,\n",
      "        1.1826e-04, 5.5246e-06, 2.5515e-05, 8.6498e-04, 0.0000e+00, 3.9108e-05,\n",
      "        1.0555e-03, 2.0206e-03, 5.9927e-04, 1.7092e-03])\n",
      "11 tensor([0.0000, 1.3069, 1.8582, 2.2437, 1.5574, 0.0727, 0.4191, 0.3001, 0.1609,\n",
      "        1.7378, 0.3136, 1.5608, 0.4528, 0.7463, 0.0286, 1.2033, 0.2000, 1.1807,\n",
      "        0.6099, 0.3079, 1.4667, 1.8210, 0.8769, 1.0994, 0.1180, 0.0113, 0.0522,\n",
      "        0.0583, 0.0000, 0.0800, 0.2997, 0.5854, 1.2269, 0.4340]) tensor([1.2540e-03, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1968e-01, 3.0069e-01, 1.6125e-01, 1.7386e+00, 3.1423e-01, 1.5619e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0882e-01, 1.4674e+00, 1.8218e+00, 8.7813e-01, 1.0999e+00,\n",
      "        1.1857e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        3.0024e-01, 5.8612e-01, 1.2275e+00, 4.3454e-01]) tensor([2.6274e-04, 3.1924e-04, 4.5383e-04, 5.4812e-04, 3.8040e-04, 1.7755e-05,\n",
      "        8.6045e-04, 6.0439e-04, 5.7721e-04, 4.2450e-04, 2.7716e-04, 1.4044e-03,\n",
      "        1.1060e-04, 1.8227e-04, 7.1740e-04, 2.9385e-04, 4.8685e-04, 2.8837e-04,\n",
      "        1.4901e-04, 8.5604e-04, 3.5822e-04, 4.4477e-04, 2.4033e-04, 2.6846e-04,\n",
      "        6.6876e-04, 5.5246e-06, 1.2755e-05, 8.3447e-06, 0.0000e+00, 1.9558e-05,\n",
      "        1.8358e-05, 5.0807e-04, 2.9957e-04, 8.3208e-04])\n",
      "12 tensor([6.2700e-04, 1.3072e+00, 1.8586e+00, 2.2442e+00, 1.5578e+00, 7.2711e-02,\n",
      "        4.1937e-01, 3.0009e-01, 1.6091e-01, 1.7382e+00, 3.1356e-01, 1.5608e+00,\n",
      "        4.5291e-01, 7.4645e-01, 2.8870e-02, 1.2036e+00, 2.0035e-01, 1.1810e+00,\n",
      "        6.1005e-01, 3.0787e-01, 1.4671e+00, 1.8214e+00, 8.7753e-01, 1.0996e+00,\n",
      "        1.1805e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9972e-01, 5.8537e-01, 1.2272e+00, 4.3395e-01]) tensor([1.2540e-03, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1968e-01, 3.0039e-01, 1.6108e-01, 1.7386e+00, 3.1390e-01, 1.5613e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0834e-01, 1.4674e+00, 1.8218e+00, 8.7813e-01, 1.0999e+00,\n",
      "        1.1831e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8574e-01, 1.2275e+00, 4.3424e-01]) tensor([3.6430e-04, 1.5962e-04, 2.2686e-04, 2.7394e-04, 1.9026e-04, 8.8811e-06,\n",
      "        2.3961e-04, 1.5616e-04, 1.5569e-04, 2.1231e-04, 2.2531e-04, 5.2142e-04,\n",
      "        5.5283e-05, 9.1136e-05, 3.4738e-04, 1.4699e-04, 1.6737e-04, 1.4424e-04,\n",
      "        7.4506e-05, 3.8242e-04, 1.7917e-04, 2.2233e-04, 3.6025e-04, 1.3423e-04,\n",
      "        2.7514e-04, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        5.0020e-04, 2.4819e-04, 1.4985e-04, 3.9327e-04])\n",
      "13 tensor([6.2700e-04, 1.3073e+00, 1.8589e+00, 2.2445e+00, 1.5580e+00, 7.2711e-02,\n",
      "        4.1953e-01, 3.0009e-01, 1.6091e-01, 1.7384e+00, 3.1373e-01, 1.5608e+00,\n",
      "        4.5296e-01, 7.4655e-01, 2.8993e-02, 1.2038e+00, 2.0051e-01, 1.1811e+00,\n",
      "        6.1013e-01, 3.0787e-01, 1.4673e+00, 1.8216e+00, 8.7753e-01, 1.0998e+00,\n",
      "        1.1805e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9985e-01, 5.8555e-01, 1.2274e+00, 4.3395e-01]) tensor([9.4050e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1968e-01, 3.0024e-01, 1.6100e-01, 1.7386e+00, 3.1390e-01, 1.5610e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0811e-01, 1.4674e+00, 1.8218e+00, 8.7783e-01, 1.0999e+00,\n",
      "        1.1818e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8574e-01, 1.2275e+00, 4.3410e-01]) tensor([5.1022e-05, 7.9870e-05, 1.1337e-04, 1.3709e-04, 9.5129e-05, 8.8811e-06,\n",
      "        7.0810e-05, 6.8188e-05, 5.5313e-05, 1.0622e-04, 2.5749e-05, 8.0109e-05,\n",
      "        2.7627e-05, 4.5538e-05, 1.6212e-04, 7.3552e-05, 7.3910e-06, 7.2122e-05,\n",
      "        3.7253e-05, 1.4555e-04, 8.9526e-05, 1.1122e-04, 5.9843e-05, 6.7115e-05,\n",
      "        7.8440e-05, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        2.4104e-04, 1.3018e-04, 7.4983e-05, 1.7405e-04])\n",
      "14 tensor([6.2700e-04, 1.3074e+00, 1.8590e+00, 2.2446e+00, 1.5581e+00, 7.2711e-02,\n",
      "        4.1953e-01, 3.0016e-01, 1.6095e-01, 1.7385e+00, 3.1373e-01, 1.5608e+00,\n",
      "        4.5299e-01, 7.4659e-01, 2.9055e-02, 1.2039e+00, 2.0051e-01, 1.1812e+00,\n",
      "        6.1016e-01, 3.0787e-01, 1.4674e+00, 1.8217e+00, 8.7753e-01, 1.0998e+00,\n",
      "        1.1805e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9991e-01, 5.8555e-01, 1.2274e+00, 4.3395e-01]) tensor([7.8375e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1961e-01, 3.0024e-01, 1.6100e-01, 1.7386e+00, 3.1381e-01, 1.5609e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0799e-01, 1.4674e+00, 1.8218e+00, 8.7768e-01, 1.0999e+00,\n",
      "        1.1811e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8565e-01, 1.2275e+00, 4.3402e-01]) tensor([1.0586e-04, 3.9935e-05, 5.6624e-05, 6.8426e-05, 4.7565e-05, 8.8811e-06,\n",
      "        8.4400e-05, 4.3869e-05, 5.0306e-05, 5.3048e-05, 9.9659e-05, 1.4067e-04,\n",
      "        1.3798e-05, 2.2769e-05, 6.9857e-05, 3.6836e-05, 7.3910e-06, 3.6001e-05,\n",
      "        1.8597e-05, 2.7180e-05, 4.4703e-05, 5.5671e-05, 9.0361e-05, 3.3617e-05,\n",
      "        1.9789e-05, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        1.1134e-04, 5.9128e-05, 3.7551e-05, 6.4254e-05])\n",
      "15 tensor([7.0538e-04, 1.3075e+00, 1.8590e+00, 2.2447e+00, 1.5582e+00, 7.2711e-02,\n",
      "        4.1957e-01, 3.0016e-01, 1.6095e-01, 1.7386e+00, 3.1377e-01, 1.5608e+00,\n",
      "        4.5300e-01, 7.4661e-01, 2.9086e-02, 1.2039e+00, 2.0051e-01, 1.1812e+00,\n",
      "        6.1018e-01, 3.0787e-01, 1.4674e+00, 1.8218e+00, 8.7760e-01, 1.0999e+00,\n",
      "        1.1808e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9994e-01, 5.8560e-01, 1.2275e+00, 4.3395e-01]) tensor([7.8375e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1961e-01, 3.0020e-01, 1.6097e-01, 1.7386e+00, 3.1381e-01, 1.5609e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0793e-01, 1.4674e+00, 1.8218e+00, 8.7768e-01, 1.0999e+00,\n",
      "        1.1811e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8565e-01, 1.2275e+00, 4.3399e-01]) tensor([2.7418e-05, 2.0027e-05, 2.8253e-05, 3.4094e-05, 2.3842e-05, 8.8811e-06,\n",
      "        6.9141e-06, 1.1921e-05, 2.6226e-06, 2.6584e-05, 3.6955e-05, 3.0041e-05,\n",
      "        6.8843e-06, 1.1384e-05, 2.3365e-05, 1.8477e-05, 7.3910e-06, 1.8001e-05,\n",
      "        9.2983e-06, 3.1948e-05, 2.2292e-05, 2.7776e-05, 1.5259e-05, 1.6809e-05,\n",
      "        2.9325e-05, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        4.6492e-05, 3.5763e-05, 1.8716e-05, 9.5367e-06])\n",
      "16 tensor([7.4456e-04, 1.3075e+00, 1.8591e+00, 2.2447e+00, 1.5582e+00, 7.2711e-02,\n",
      "        4.1957e-01, 3.0018e-01, 1.6095e-01, 1.7386e+00, 3.1379e-01, 1.5609e+00,\n",
      "        4.5300e-01, 7.4663e-01, 2.9101e-02, 1.2039e+00, 2.0051e-01, 1.1812e+00,\n",
      "        6.1018e-01, 3.0790e-01, 1.4674e+00, 1.8218e+00, 8.7764e-01, 1.0999e+00,\n",
      "        1.1808e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9996e-01, 5.8560e-01, 1.2275e+00, 4.3395e-01]) tensor([7.8375e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1961e-01, 3.0020e-01, 1.6097e-01, 1.7386e+00, 3.1381e-01, 1.5609e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0793e-01, 1.4674e+00, 1.8218e+00, 8.7768e-01, 1.0999e+00,\n",
      "        1.1810e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8563e-01, 1.2275e+00, 4.3399e-01]) tensor([1.1683e-05, 1.0014e-05, 1.4186e-05, 1.6928e-05, 1.1921e-05, 8.8811e-06,\n",
      "        6.9141e-06, 1.5974e-05, 2.6226e-06, 1.3232e-05, 5.4836e-06, 2.5034e-05,\n",
      "        6.8843e-06, 5.7220e-06, 4.7684e-07, 9.1791e-06, 7.3910e-06, 9.0599e-06,\n",
      "        9.2983e-06, 2.3842e-06, 1.1086e-05, 1.3947e-05, 2.2411e-05, 8.3447e-06,\n",
      "        4.6492e-06, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        1.4067e-05, 1.1683e-05, 9.4175e-06, 9.5367e-06])\n",
      "17 tensor([7.4456e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2711e-02,\n",
      "        4.1957e-01, 3.0018e-01, 1.6095e-01, 1.7386e+00, 3.1379e-01, 1.5609e+00,\n",
      "        4.5300e-01, 7.4663e-01, 2.9101e-02, 1.2039e+00, 2.0051e-01, 1.1812e+00,\n",
      "        6.1018e-01, 3.0790e-01, 1.4674e+00, 1.8218e+00, 8.7764e-01, 1.0999e+00,\n",
      "        1.1808e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9997e-01, 5.8561e-01, 1.2275e+00, 4.3395e-01]) tensor([7.6416e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1961e-01, 3.0019e-01, 1.6097e-01, 1.7386e+00, 3.1381e-01, 1.5609e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0793e-01, 1.4674e+00, 1.8218e+00, 8.7766e-01, 1.0999e+00,\n",
      "        1.1810e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8563e-01, 1.2275e+00, 4.3399e-01]) tensor([7.8678e-06, 5.0068e-06, 7.0333e-06, 8.3447e-06, 5.9605e-06, 8.8811e-06,\n",
      "        6.9141e-06, 2.0266e-06, 2.6226e-06, 6.5565e-06, 5.4836e-06, 2.3842e-06,\n",
      "        6.8843e-06, 5.7220e-06, 4.7684e-07, 9.1791e-06, 7.3910e-06, 9.0599e-06,\n",
      "        9.2983e-06, 2.3842e-06, 5.6028e-06, 7.0333e-06, 3.5763e-06, 8.3447e-06,\n",
      "        4.6492e-06, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        2.1458e-06, 1.2040e-05, 9.4175e-06, 9.5367e-06])\n",
      "18 tensor([7.4456e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2711e-02,\n",
      "        4.1957e-01, 3.0018e-01, 1.6095e-01, 1.7386e+00, 3.1379e-01, 1.5609e+00,\n",
      "        4.5300e-01, 7.4663e-01, 2.9101e-02, 1.2039e+00, 2.0051e-01, 1.1812e+00,\n",
      "        6.1018e-01, 3.0790e-01, 1.4674e+00, 1.8218e+00, 8.7764e-01, 1.0999e+00,\n",
      "        1.1808e-01, 1.1303e-02, 5.2243e-02, 5.8257e-02, 0.0000e+00, 8.0069e-02,\n",
      "        2.9997e-01, 5.8561e-01, 1.2275e+00, 4.3395e-01]) tensor([7.6416e-04, 1.3075e+00, 1.8591e+00, 2.2448e+00, 1.5582e+00, 7.2728e-02,\n",
      "        4.1961e-01, 3.0019e-01, 1.6097e-01, 1.7386e+00, 3.1381e-01, 1.5609e+00,\n",
      "        4.5302e-01, 7.4664e-01, 2.9117e-02, 1.2039e+00, 2.0067e-01, 1.1813e+00,\n",
      "        6.1020e-01, 3.0793e-01, 1.4674e+00, 1.8218e+00, 8.7766e-01, 1.0999e+00,\n",
      "        1.1810e-01, 1.1314e-02, 5.2256e-02, 5.9114e-02, 5.5340e-01, 8.0088e-02,\n",
      "        2.9998e-01, 5.8562e-01, 1.2275e+00, 4.3399e-01]) tensor([7.8678e-06, 5.0068e-06, 7.0333e-06, 8.3447e-06, 5.9605e-06, 8.8811e-06,\n",
      "        6.9141e-06, 2.0266e-06, 2.6226e-06, 6.5565e-06, 5.4836e-06, 2.3842e-06,\n",
      "        6.8843e-06, 5.7220e-06, 4.7684e-07, 9.1791e-06, 7.3910e-06, 9.0599e-06,\n",
      "        9.2983e-06, 2.3842e-06, 5.6028e-06, 7.0333e-06, 3.5763e-06, 8.3447e-06,\n",
      "        4.6492e-06, 5.5246e-06, 6.3777e-06, 8.3447e-06, 0.0000e+00, 9.7826e-06,\n",
      "        2.1458e-06, 4.7684e-07, 9.4175e-06, 9.5367e-06])\n",
      "tensor([ 7.8678e-06, -1.8976e+00, -2.0000e+00,  5.0068e-06, -7.3287e-01,\n",
      "        -1.0912e+00, -3.9902e-01,  7.0333e-06, -3.7359e-01, -1.4277e-02,\n",
      "        -1.8847e-02,  8.3447e-06, -1.8495e+00, -6.0888e-01, -1.1156e+00,\n",
      "         5.9605e-06, -1.0130e+00, -1.1877e-01, -2.0000e+00,  8.8811e-06,\n",
      "         6.9141e-06, -2.0266e-06,  2.6226e-06,  6.5565e-06,  5.4836e-06,\n",
      "        -1.5141e+00,  2.3842e-06,  6.8843e-06, -5.1215e-01, -2.7051e-02,\n",
      "        -8.2626e-01,  5.7220e-06, -2.6121e-01, -1.8745e+00,  4.7684e-07,\n",
      "         9.1791e-06,  7.3910e-06, -1.2301e+00, -2.5540e-01,  9.0599e-06,\n",
      "        -8.1847e-01, -3.5109e-01, -1.0000e+00,  9.2983e-06,  2.3842e-06,\n",
      "        -1.5314e+00, -2.4745e-01,  5.6028e-06, -6.7034e-01, -1.0796e+00,\n",
      "        -4.6549e-01,  7.0333e-06,  0.0000e+00, -6.0876e-01, -3.5763e-06,\n",
      "         8.3447e-06, -1.5064e+00, -4.6492e-06, -1.0000e+00,  5.5246e-06,\n",
      "        -8.3817e-01,  0.0000e+00, -1.6740e+00,  6.3777e-06, -1.1792e+00,\n",
      "        -8.3447e-06,  0.0000e+00,  9.7826e-06, -2.1458e-06, -1.2309e+00,\n",
      "        -1.6671e+00,  0.0000e+00, -4.7684e-07, -1.8815e+00, -1.3329e+00,\n",
      "         9.4175e-06, -9.6014e-03, -9.5367e-06, -1.7335e+00,  0.0000e+00])\n",
      "tensor([[ 7.8678e-06, -1.8976e+00, -2.0000e+00,  5.0068e-06],\n",
      "        [-7.3287e-01, -1.0912e+00, -3.9902e-01,  7.0333e-06],\n",
      "        [-3.7359e-01, -1.4277e-02, -1.8847e-02,  8.3447e-06],\n",
      "        [-1.8495e+00, -6.0888e-01, -1.1156e+00,  5.9605e-06],\n",
      "        [-1.0130e+00, -1.1877e-01, -2.0000e+00,  8.8811e-06],\n",
      "        [ 6.9141e-06, -2.0266e-06,  2.6226e-06,  6.5565e-06],\n",
      "        [ 5.4836e-06, -1.5141e+00,  2.3842e-06,  6.8843e-06],\n",
      "        [-5.1215e-01, -2.7051e-02, -8.2626e-01,  5.7220e-06],\n",
      "        [-2.6121e-01, -1.8745e+00,  4.7684e-07,  9.1791e-06],\n",
      "        [ 7.3910e-06, -1.2301e+00, -2.5540e-01,  9.0599e-06],\n",
      "        [-8.1847e-01, -3.5109e-01, -1.0000e+00,  9.2983e-06],\n",
      "        [ 2.3842e-06, -1.5314e+00, -2.4745e-01,  5.6028e-06],\n",
      "        [-6.7034e-01, -1.0796e+00, -4.6549e-01,  7.0333e-06],\n",
      "        [ 0.0000e+00, -6.0876e-01, -3.5763e-06,  8.3447e-06],\n",
      "        [-1.5064e+00, -4.6492e-06, -1.0000e+00,  5.5246e-06],\n",
      "        [-8.3817e-01,  0.0000e+00, -1.6740e+00,  6.3777e-06],\n",
      "        [-1.1792e+00, -8.3447e-06,  0.0000e+00,  9.7826e-06],\n",
      "        [-2.1458e-06, -1.2309e+00, -1.6671e+00,  0.0000e+00],\n",
      "        [-4.7684e-07, -1.8815e+00, -1.3329e+00,  9.4175e-06],\n",
      "        [-9.6014e-03, -9.5367e-06, -1.7335e+00,  0.0000e+00]])\n"
     ]
    }
   ],
   "source": [
    "atomCounts = [2, 2, 2, 0]\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Atom-Count', 'params': [atomCounts]}))\n",
    "print (proj_xs.sum(dim=1) - torch.tensor(atomCounts, dtype=proj_xs.dtype, device=proj_xs.device)[None,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([0.2140, 0.1549, 0.1871, 0.1456, 0.1939, 0.1449, 0.2009, 0.0956, 0.1778,\n",
      "        0.1440, 0.1344, 0.1617, 0.1518, 0.2050, 0.0895, 0.1718, 0.1845, 0.0885,\n",
      "        0.1291, 0.1103]) tensor([ 8.0000,  7.8529,  9.2839,  5.3481,  6.4300,  2.8860, 11.7861,  3.7173,\n",
      "         8.0000,  4.4800,  9.5213,  6.6687, 10.2338,  1.7180,  7.9990,  2.8190,\n",
      "         1.7080,  4.2253,  2.3633,  1.7295])\n",
      "1 tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1004, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0859, 0.0000, 0.0000,\n",
      "        0.0000, 0.0552]) tensor([0.1070, 0.0775, 0.0935, 0.0728, 0.0969, 0.0724, 0.2009, 0.0478, 0.0889,\n",
      "        0.0720, 0.0672, 0.0808, 0.0759, 0.1025, 0.0448, 0.1718, 0.0922, 0.0442,\n",
      "        0.0645, 0.1103]) tensor([ 1.6143, 21.3637, 25.1861, 27.4144,  4.4924, 42.8525, 20.0000, 23.3181,\n",
      "        12.3659, 27.2807, 10.6199, 29.0407, 13.5321, 35.8172, 13.0168, 14.0186,\n",
      "        23.9522, 11.4819, 25.6369, 15.4801])\n",
      "2 tensor([0.0535, 0.0387, 0.0468, 0.0364, 0.0000, 0.0362, 0.1004, 0.0239, 0.0445,\n",
      "        0.0360, 0.0336, 0.0404, 0.0380, 0.0513, 0.0224, 0.0859, 0.0461, 0.0221,\n",
      "        0.0323, 0.0552]) tensor([0.1070, 0.0775, 0.0935, 0.0728, 0.0485, 0.0724, 0.1507, 0.0478, 0.0889,\n",
      "        0.0720, 0.0672, 0.0808, 0.0759, 0.1025, 0.0448, 0.1289, 0.0922, 0.0442,\n",
      "        0.0645, 0.0828]) tensor([ 8.0000,  6.4622,  5.4338,  7.7680,  1.8353, 17.2719,  9.5011,  8.2647,\n",
      "         7.9332,  9.0772,  0.4023,  5.8556,  2.2882, 14.5474,  1.9583,  5.5998,\n",
      "        10.4912,  2.9418, 10.6429,  8.3673])\n",
      "3 tensor([0.0535, 0.0581, 0.0701, 0.0546, 0.0242, 0.0543, 0.1004, 0.0358, 0.0445,\n",
      "        0.0540, 0.0336, 0.0606, 0.0569, 0.0769, 0.0336, 0.0859, 0.0692, 0.0332,\n",
      "        0.0484, 0.0552]) tensor([0.0803, 0.0775, 0.0935, 0.0728, 0.0485, 0.0724, 0.1256, 0.0478, 0.0667,\n",
      "        0.0720, 0.0504, 0.0808, 0.0759, 0.1025, 0.0448, 0.1074, 0.0922, 0.0442,\n",
      "        0.0645, 0.0690]) tensor([6.1343, 1.8851, 3.0309, 0.0856, 2.7477, 6.6155, 0.3358, 1.0138, 0.2057,\n",
      "        1.2638, 5.0774, 1.4600, 4.5817, 7.5769, 3.5709, 1.3904, 4.0374, 0.7518,\n",
      "        3.7354, 3.3189])\n",
      "4 tensor([0.0535, 0.0581, 0.0701, 0.0546, 0.0242, 0.0634, 0.1130, 0.0418, 0.0556,\n",
      "        0.0630, 0.0420, 0.0606, 0.0569, 0.0897, 0.0336, 0.0859, 0.0807, 0.0332,\n",
      "        0.0565, 0.0552]) tensor([0.0669, 0.0678, 0.0818, 0.0637, 0.0363, 0.0724, 0.1256, 0.0478, 0.0667,\n",
      "        0.0720, 0.0504, 0.0707, 0.0664, 0.1025, 0.0392, 0.0966, 0.0922, 0.0387,\n",
      "        0.0645, 0.0621]) tensor([3.2221, 2.0081, 1.2015, 3.7548, 0.7812, 1.6598, 4.6799, 1.3518, 4.8175,\n",
      "        2.0316, 2.1014, 2.1978, 1.1468, 2.8527, 0.8063, 0.7143, 0.8976, 0.9850,\n",
      "        0.6860, 0.7948])\n",
      "5 tensor([0.0535, 0.0629, 0.0760, 0.0592, 0.0242, 0.0679, 0.1130, 0.0418, 0.0556,\n",
      "        0.0630, 0.0462, 0.0657, 0.0569, 0.0961, 0.0336, 0.0913, 0.0865, 0.0360,\n",
      "        0.0605, 0.0552]) tensor([0.0602, 0.0678, 0.0818, 0.0637, 0.0303, 0.0724, 0.1193, 0.0448, 0.0611,\n",
      "        0.0675, 0.0504, 0.0707, 0.0617, 0.1025, 0.0364, 0.0966, 0.0922, 0.0387,\n",
      "        0.0645, 0.0586]) tensor([1.0342, 0.1325, 0.9147, 1.8346, 0.2487, 0.6131, 2.2694, 0.1690, 2.3059,\n",
      "        0.6896, 0.8496, 0.3689, 0.5707, 0.3835, 0.5760, 0.3381, 0.4052, 0.1166,\n",
      "        0.8386, 0.4674])\n",
      "6 tensor([0.0535, 0.0629, 0.0760, 0.0614, 0.0273, 0.0679, 0.1130, 0.0418, 0.0556,\n",
      "        0.0630, 0.0483, 0.0682, 0.0593, 0.0993, 0.0350, 0.0913, 0.0865, 0.0373,\n",
      "        0.0605, 0.0569]) tensor([0.0568, 0.0654, 0.0789, 0.0637, 0.0303, 0.0702, 0.1161, 0.0433, 0.0583,\n",
      "        0.0653, 0.0504, 0.0707, 0.0617, 0.1025, 0.0364, 0.0940, 0.0894, 0.0387,\n",
      "        0.0625, 0.0586]) tensor([0.2901, 0.8946, 0.1434, 0.8745, 0.2662, 0.5234, 1.0641, 0.4224, 1.0501,\n",
      "        0.2871, 0.2237, 0.5455, 0.2880, 0.6673, 0.1151, 0.1881, 0.2462, 0.3176,\n",
      "        0.0763, 0.1637])\n",
      "7 tensor([0.0552, 0.0641, 0.0775, 0.0626, 0.0273, 0.0690, 0.1130, 0.0426, 0.0556,\n",
      "        0.0641, 0.0494, 0.0682, 0.0593, 0.0993, 0.0350, 0.0926, 0.0879, 0.0373,\n",
      "        0.0605, 0.0569]) tensor([0.0568, 0.0654, 0.0789, 0.0637, 0.0288, 0.0702, 0.1146, 0.0433, 0.0570,\n",
      "        0.0653, 0.0504, 0.0695, 0.0605, 0.1009, 0.0357, 0.0940, 0.0894, 0.0380,\n",
      "        0.0615, 0.0578]) tensor([0.3721, 0.3379, 0.3857, 0.3944, 0.0088, 0.0448, 0.3798, 0.1267, 0.4222,\n",
      "        0.2012, 0.0893, 0.0883, 0.1413, 0.1419, 0.2305, 0.0750, 0.0795, 0.1005,\n",
      "        0.3049, 0.1518])\n",
      "8 tensor([0.0552, 0.0648, 0.0775, 0.0631, 0.0273, 0.0690, 0.1130, 0.0429, 0.0556,\n",
      "        0.0641, 0.0494, 0.0682, 0.0599, 0.0993, 0.0353, 0.0926, 0.0879, 0.0373,\n",
      "        0.0610, 0.0573]) tensor([0.0560, 0.0654, 0.0782, 0.0637, 0.0280, 0.0696, 0.1138, 0.0433, 0.0563,\n",
      "        0.0647, 0.0499, 0.0688, 0.0605, 0.1001, 0.0357, 0.0933, 0.0886, 0.0377,\n",
      "        0.0615, 0.0578]) tensor([0.0410, 0.0866, 0.1211, 0.1544, 0.1200, 0.2393, 0.0220, 0.0212, 0.1083,\n",
      "        0.0429, 0.0672, 0.1403, 0.0734, 0.1208, 0.0577, 0.0566, 0.0834, 0.0080,\n",
      "        0.1143, 0.0059])\n",
      "9 tensor([0.0552, 0.0651, 0.0775, 0.0634, 0.0276, 0.0693, 0.1130, 0.0429, 0.0556,\n",
      "        0.0644, 0.0496, 0.0685, 0.0599, 0.0997, 0.0355, 0.0930, 0.0883, 0.0375,\n",
      "        0.0613, 0.0573]) tensor([0.0556, 0.0654, 0.0778, 0.0637, 0.0280, 0.0696, 0.1134, 0.0431, 0.0559,\n",
      "        0.0647, 0.0499, 0.0688, 0.0602, 0.1001, 0.0357, 0.0933, 0.0886, 0.0377,\n",
      "        0.0615, 0.0575]) tensor([0.1245, 0.0230, 0.0111, 0.0344, 0.0556, 0.0972, 0.1569, 0.0528, 0.0487,\n",
      "        0.0791, 0.0111, 0.0260, 0.0340, 0.0106, 0.0287, 0.0092, 0.0019, 0.0462,\n",
      "        0.0190, 0.0729])\n",
      "10 tensor([0.0554, 0.0651, 0.0776, 0.0636, 0.0278, 0.0695, 0.1132, 0.0430, 0.0557,\n",
      "        0.0644, 0.0496, 0.0687, 0.0600, 0.0997, 0.0355, 0.0930, 0.0885, 0.0375,\n",
      "        0.0614, 0.0574]) tensor([0.0556, 0.0652, 0.0778, 0.0637, 0.0280, 0.0696, 0.1134, 0.0431, 0.0559,\n",
      "        0.0646, 0.0498, 0.0688, 0.0602, 0.0999, 0.0356, 0.0931, 0.0886, 0.0376,\n",
      "        0.0615, 0.0575]) tensor([0.0418, 0.0318, 0.0550, 0.0256, 0.0234, 0.0262, 0.0675, 0.0158, 0.0298,\n",
      "        0.0181, 0.0281, 0.0312, 0.0197, 0.0551, 0.0145, 0.0237, 0.0388, 0.0191,\n",
      "        0.0287, 0.0335])\n",
      "11 tensor([0.0555, 0.0651, 0.0776, 0.0636, 0.0279, 0.0695, 0.1133, 0.0431, 0.0557,\n",
      "        0.0644, 0.0497, 0.0687, 0.0600, 0.0998, 0.0355, 0.0930, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0556, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0498, 0.0688, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0376,\n",
      "        0.0614, 0.0575]) tensor([0.0004, 0.0044, 0.0219, 0.0044, 0.0073, 0.0093, 0.0227, 0.0027, 0.0094,\n",
      "        0.0124, 0.0085, 0.0026, 0.0071, 0.0223, 0.0071, 0.0072, 0.0184, 0.0055,\n",
      "        0.0048, 0.0138])\n",
      "12 tensor([0.0555, 0.0652, 0.0776, 0.0636, 0.0280, 0.0695, 0.1133, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0355, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0556, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0498, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0203, 0.0093, 0.0054, 0.0106, 0.0007, 0.0084, 0.0004, 0.0066, 0.0102,\n",
      "        0.0028, 0.0013, 0.0117, 0.0063, 0.0059, 0.0037, 0.0010, 0.0082, 0.0013,\n",
      "        0.0071, 0.0039])\n",
      "13 tensor([0.0555, 0.0652, 0.0776, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0556, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0100, 0.0024, 0.0028, 0.0031, 0.0033, 0.0004, 0.0108, 0.0019, 0.0004,\n",
      "        0.0048, 0.0036, 0.0046, 0.0004, 0.0023, 0.0017, 0.0031, 0.0031, 0.0021,\n",
      "        0.0011, 0.0010])\n",
      "14 tensor([0.0555, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0556, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0048, 0.0010, 0.0013, 0.0006, 0.0013, 0.0040, 0.0052, 0.0004, 0.0045,\n",
      "        0.0010, 0.0012, 0.0010, 0.0029, 0.0018, 0.0010, 0.0011, 0.0006, 0.0004,\n",
      "        0.0019, 0.0015])\n",
      "15 tensor([0.0555, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([0.0555, 0.0652, 0.0777, 0.0636, 0.0280, 0.0696, 0.1134, 0.0431, 0.0558,\n",
      "        0.0645, 0.0497, 0.0687, 0.0601, 0.0999, 0.0356, 0.0931, 0.0885, 0.0375,\n",
      "        0.0614, 0.0575]) tensor([2.1954e-03, 9.9564e-04, 7.8201e-04, 6.4659e-04, 2.7466e-04, 4.0016e-03,\n",
      "        2.4071e-03, 3.7193e-04, 2.0714e-03, 9.7466e-04, 1.1654e-03, 9.8038e-04,\n",
      "        2.9221e-03, 2.9755e-04, 9.6893e-04, 5.3406e-05, 6.7520e-04, 4.4632e-04,\n",
      "        1.8520e-03, 1.4553e-03])\n",
      "tensor([-2.1954e-03,  9.9564e-04,  7.8201e-04,  6.4659e-04,  2.7466e-04,\n",
      "         4.0016e-03, -2.4071e-03, -3.7193e-04,  2.0714e-03,  9.7466e-04,\n",
      "         1.1654e-03,  9.8038e-04, -2.9221e-03, -2.9755e-04,  9.6893e-04,\n",
      "         5.3406e-05,  6.7520e-04, -4.4632e-04, -1.8520e-03,  1.4553e-03])\n"
     ]
    }
   ],
   "source": [
    "atomWeights = [12, 14, 16, 18]\n",
    "max_weight = 20\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Mol-Weight', 'params': [atomWeights, max_weight]}))\n",
    "print ((proj_xs @ torch.tensor(atomWeights, dtype=xs.dtype, device=xs.device)).sum(dim=1) - max_weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,\n",
      "        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 1.0000]) tensor([0.2023, 0.7829, 1.4074, 0.6526, 1.3452, 0.6612, 2.3302, 2.6946, 3.5294,\n",
      "        2.1336, 2.9083, 1.3164, 0.6407, 2.7937, 2.2059, 0.7731])\n",
      "1 tensor([0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.5000]) tensor([0.5000, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,\n",
      "        0.2500, 0.5000, 0.5000, 0.2500, 0.2500, 0.2500, 1.0000]) tensor([0.5536, 1.0475, 0.8046, 1.1432, 0.5674, 1.5004, 0.4097, 0.5441, 1.3336,\n",
      "        0.8318, 0.0945, 0.8176, 1.6054, 0.0434, 0.2141, 3.3413])\n",
      "2 tensor([0.2500, 0.1250, 0.1250, 0.1250, 0.0000, 0.1250, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.3750, 0.2500, 0.1250, 0.0000, 0.0000, 0.5000]) tensor([0.3750, 0.2500, 0.2500, 0.2500, 0.1250, 0.2500, 0.1250, 0.1250, 0.1250,\n",
      "        0.1250, 0.5000, 0.3750, 0.2500, 0.1250, 0.1250, 0.7500]) tensor([0.1427, 0.1468, 0.3810, 0.2424, 0.1514, 0.4135, 0.5848, 0.5322, 0.2562,\n",
      "        0.1863, 1.2841, 0.2958, 0.4802, 1.3396, 0.8469, 1.1630])\n",
      "3 tensor([0.2500, 0.1875, 0.1250, 0.1875, 0.0000, 0.1875, 0.0625, 0.0625, 0.0000,\n",
      "        0.0000, 0.3750, 0.3125, 0.1875, 0.0625, 0.0625, 0.5000]) tensor([0.3125, 0.2500, 0.1875, 0.2500, 0.0625, 0.2500, 0.1250, 0.1250, 0.0625,\n",
      "        0.0625, 0.4375, 0.3750, 0.2500, 0.1250, 0.1250, 0.6250]) tensor([0.0267, 0.3181, 0.1868, 0.2080, 0.0556, 0.1243, 0.0704, 0.0060, 0.2825,\n",
      "        0.1365, 0.5997, 0.2606, 0.0804, 0.6481, 0.3111, 0.0948])\n",
      "4 tensor([0.2812, 0.1875, 0.1562, 0.1875, 0.0312, 0.1875, 0.0938, 0.0625, 0.0312,\n",
      "        0.0312, 0.3750, 0.3125, 0.1875, 0.0938, 0.0938, 0.5000]) tensor([0.3125, 0.2188, 0.1875, 0.2188, 0.0625, 0.2188, 0.1250, 0.0938, 0.0625,\n",
      "        0.0625, 0.4062, 0.3438, 0.2188, 0.1250, 0.1250, 0.5625]) tensor([0.0580, 0.0857, 0.1163, 0.0172, 0.0474, 0.1442, 0.1697, 0.2631, 0.0132,\n",
      "        0.0249, 0.2540, 0.0176, 0.1997, 0.3024, 0.0432, 0.3881])\n",
      "5 tensor([0.2812, 0.1875, 0.1562, 0.2031, 0.0312, 0.2031, 0.0938, 0.0781, 0.0469,\n",
      "        0.0312, 0.3750, 0.3281, 0.2031, 0.1094, 0.1094, 0.5312]) tensor([0.2969, 0.2031, 0.1719, 0.2188, 0.0469, 0.2188, 0.1094, 0.0938, 0.0625,\n",
      "        0.0469, 0.3906, 0.3438, 0.2188, 0.1250, 0.1250, 0.5625]) tensor([0.0156, 0.0306, 0.0324, 0.0954, 0.0044, 0.0099, 0.0496, 0.1286, 0.1215,\n",
      "        0.0558, 0.0797, 0.1215, 0.0596, 0.1295, 0.0896, 0.1723])\n",
      "6 tensor([0.2812, 0.1953, 0.1641, 0.2031, 0.0391, 0.2109, 0.0938, 0.0859, 0.0469,\n",
      "        0.0391, 0.3750, 0.3281, 0.2109, 0.1172, 0.1094, 0.5469]) tensor([0.2891, 0.2031, 0.1719, 0.2109, 0.0469, 0.2188, 0.1016, 0.0938, 0.0547,\n",
      "        0.0469, 0.3828, 0.3359, 0.2188, 0.1250, 0.1172, 0.5625]) tensor([0.0055, 0.0275, 0.0449, 0.0391, 0.0214, 0.0572, 0.0104, 0.0613, 0.0542,\n",
      "        0.0155, 0.0074, 0.0520, 0.0104, 0.0431, 0.0238, 0.0388])\n",
      "7 tensor([0.2852, 0.1953, 0.1641, 0.2031, 0.0391, 0.2109, 0.0977, 0.0898, 0.0469,\n",
      "        0.0430, 0.3789, 0.3281, 0.2109, 0.1211, 0.1094, 0.5547]) tensor([0.2891, 0.1992, 0.1680, 0.2070, 0.0430, 0.2148, 0.1016, 0.0938, 0.0508,\n",
      "        0.0469, 0.3828, 0.3320, 0.2148, 0.1250, 0.1133, 0.5625]) tensor([0.0050, 0.0015, 0.0062, 0.0109, 0.0084, 0.0236, 0.0196, 0.0277, 0.0205,\n",
      "        0.0047, 0.0362, 0.0172, 0.0246, 0.0002, 0.0097, 0.0280])\n",
      "8 tensor([0.2852, 0.1973, 0.1641, 0.2031, 0.0391, 0.2109, 0.0977, 0.0918, 0.0469,\n",
      "        0.0430, 0.3789, 0.3281, 0.2129, 0.1211, 0.1113, 0.5547]) tensor([0.2871, 0.1992, 0.1660, 0.2051, 0.0410, 0.2129, 0.0996, 0.0938, 0.0488,\n",
      "        0.0449, 0.3809, 0.3301, 0.2148, 0.1230, 0.1133, 0.5586]) tensor([0.0003, 0.0130, 0.0131, 0.0032, 0.0020, 0.0068, 0.0046, 0.0109, 0.0037,\n",
      "        0.0054, 0.0144, 0.0002, 0.0071, 0.0214, 0.0070, 0.0054])\n",
      "9 tensor([0.2861, 0.1973, 0.1650, 0.2041, 0.0391, 0.2109, 0.0977, 0.0928, 0.0469,\n",
      "        0.0439, 0.3789, 0.3291, 0.2139, 0.1221, 0.1113, 0.5566]) tensor([0.2871, 0.1982, 0.1660, 0.2051, 0.0400, 0.2119, 0.0986, 0.0938, 0.0479,\n",
      "        0.0449, 0.3799, 0.3301, 0.2148, 0.1230, 0.1123, 0.5586]) tensor([0.0024, 0.0058, 0.0034, 0.0039, 0.0012, 0.0015, 0.0029, 0.0025, 0.0047,\n",
      "        0.0003, 0.0035, 0.0085, 0.0017, 0.0106, 0.0014, 0.0113])\n",
      "10 tensor([0.2861, 0.1973, 0.1655, 0.2041, 0.0396, 0.2114, 0.0981, 0.0933, 0.0474,\n",
      "        0.0444, 0.3789, 0.3291, 0.2139, 0.1226, 0.1118, 0.5566]) tensor([0.2866, 0.1978, 0.1660, 0.2046, 0.0400, 0.2119, 0.0986, 0.0938, 0.0479,\n",
      "        0.0449, 0.3794, 0.3296, 0.2144, 0.1230, 0.1123, 0.5576]) tensor([0.0011, 0.0021, 0.0014, 0.0004, 0.0004, 0.0027, 0.0009, 0.0017, 0.0005,\n",
      "        0.0022, 0.0020, 0.0041, 0.0027, 0.0052, 0.0028, 0.0030])\n",
      "11 tensor([0.2861, 0.1973, 0.1655, 0.2041, 0.0396, 0.2114, 0.0981, 0.0933, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2141, 0.1228, 0.1118, 0.5566]) tensor([0.2864, 0.1975, 0.1658, 0.2043, 0.0398, 0.2117, 0.0984, 0.0935, 0.0479,\n",
      "        0.0447, 0.3794, 0.3293, 0.2144, 0.1230, 0.1121, 0.5571]) tensor([0.0004, 0.0003, 0.0010, 0.0014, 0.0004, 0.0006, 0.0010, 0.0004, 0.0016,\n",
      "        0.0009, 0.0008, 0.0020, 0.0005, 0.0025, 0.0007, 0.0012])\n",
      "12 tensor([0.2861, 0.1973, 0.1656, 0.2042, 0.0397, 0.2114, 0.0983, 0.0934, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2142, 0.1229, 0.1118, 0.5569]) tensor([0.2863, 0.1974, 0.1658, 0.2043, 0.0398, 0.2115, 0.0984, 0.0935, 0.0477,\n",
      "        0.0446, 0.3793, 0.3292, 0.2144, 0.1230, 0.1119, 0.5571]) tensor([8.0377e-05, 5.9745e-04, 2.1031e-04, 5.1287e-04, 4.2707e-05, 4.9284e-04,\n",
      "        6.1184e-05, 6.9669e-04, 5.2795e-04, 3.0747e-04, 5.9479e-04, 8.8063e-04,\n",
      "        5.6377e-04, 1.1841e-03, 3.0968e-04, 8.8444e-04])\n",
      "13 tensor([0.2861, 0.1973, 0.1656, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2142, 0.1230, 0.1119, 0.5569]) tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0984, 0.0934, 0.0477,\n",
      "        0.0445, 0.3793, 0.3292, 0.2143, 0.1230, 0.1119, 0.5570]) tensor([8.4966e-05, 1.4338e-04, 3.9276e-04, 7.2926e-05, 1.5733e-04, 3.1561e-05,\n",
      "        4.0779e-04, 1.7107e-04, 1.9968e-06, 7.6592e-06, 8.6039e-05, 3.3751e-04,\n",
      "        1.6838e-05, 5.0887e-04, 2.1377e-04, 1.5900e-04])\n",
      "14 tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2142, 0.1230, 0.1119, 0.5569]) tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0984, 0.0934, 0.0477,\n",
      "        0.0445, 0.3792, 0.3291, 0.2143, 0.1230, 0.1119, 0.5570]) tensor([2.4736e-06, 8.3715e-05, 9.1404e-05, 1.4684e-04, 5.7310e-05, 2.3082e-04,\n",
      "        1.7342e-04, 9.1732e-05, 1.9968e-06, 7.6592e-06, 2.5445e-04, 6.5714e-05,\n",
      "        2.5734e-04, 1.7139e-04, 4.8012e-05, 3.6249e-04])\n",
      "15 tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2142, 0.1230, 0.1119, 0.5569]) tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0477,\n",
      "        0.0445, 0.3792, 0.3291, 0.2143, 0.1230, 0.1119, 0.5570]) tensor([2.4736e-06, 3.0130e-05, 5.9396e-05, 3.6806e-05, 7.3314e-06, 9.9510e-05,\n",
      "        5.6118e-05, 3.9756e-05, 1.9968e-06, 7.6592e-06, 8.4043e-05, 7.0184e-05,\n",
      "        1.2001e-04, 2.2352e-06, 8.2880e-05, 1.0172e-04])\n",
      "16 tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2143, 0.1230, 0.1119, 0.5569]) tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0477,\n",
      "        0.0445, 0.3792, 0.3291, 0.2143, 0.1230, 0.1119, 0.5570]) tensor([2.4736e-06, 2.7090e-05, 1.5825e-05, 1.8030e-05, 7.3314e-06, 3.3945e-05,\n",
      "        2.2352e-06, 2.5839e-05, 1.9968e-06, 7.6592e-06, 1.1027e-06, 2.2352e-06,\n",
      "        5.2065e-05, 2.2352e-06, 1.7554e-05, 2.8402e-05])\n",
      "17 tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0476,\n",
      "        0.0444, 0.3792, 0.3291, 0.2143, 0.1230, 0.1119, 0.5570]) tensor([0.2862, 0.1974, 0.1657, 0.2043, 0.0397, 0.2115, 0.0983, 0.0934, 0.0477,\n",
      "        0.0445, 0.3792, 0.3291, 0.2143, 0.1230, 0.1119, 0.5570]) tensor([2.4736e-06, 2.7090e-05, 1.5825e-05, 1.8030e-05, 7.3314e-06, 3.3945e-05,\n",
      "        2.2352e-06, 2.5839e-05, 1.9968e-06, 7.6592e-06, 1.1027e-06, 2.2352e-06,\n",
      "        5.2065e-05, 2.2352e-06, 1.7554e-05, 3.6567e-05])\n",
      "tensor([ 2.4736e-06, -2.7090e-05,  1.5825e-05,  1.8030e-05,  7.3314e-06,\n",
      "         3.3945e-05, -2.7278e+00, -6.9747e-02,  2.2352e-06,  2.5839e-05,\n",
      "        -1.9968e-06,  7.6592e-06, -1.1027e-06,  2.2352e-06, -1.2333e+00,\n",
      "         5.2065e-05,  2.2352e-06, -1.7554e-05, -3.6567e-05, -4.9371e+00])\n"
     ]
    }
   ],
   "source": [
    "xtheta_params = torch.randn_like (xs[0].reshape(-1))\n",
    "atheta_params = torch.randn_like (adjs[0].reshape(-1))\n",
    "torch.save(xtheta_params, \"xtheta_params.pt\")\n",
    "torch.save(atheta_params, \"atheta_params.pt\")\n",
    "prop = torch.randn(1).item()\n",
    "\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Property-MLP', 'params': [\"xtheta_params.pt\", \"atheta_params.pt\", prop]}))\n",
    "print (xtheta_params @ proj_xs.reshape(proj_xs.shape[0], -1).T + atheta_params @ proj_adjs.reshape(proj_adjs.shape[0], -1).T - prop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(20)\n",
      "tensor([-0.9819, -3.4398, -3.0221, -2.2747, -4.0857, -4.4855, -0.2926, -3.2032,\n",
      "        -1.3910, -3.9932, -1.4100, -1.0274, -3.1581, -1.9403, -1.7106, -3.6329,\n",
      "        -3.7828, -2.0122, -4.6908, -1.0210])\n"
     ]
    }
   ],
   "source": [
    "theta_params = torch.randn (xs.shape[-1], device=xs.device, dtype=xs.dtype)\n",
    "torch.save(theta_params, \"config/constraints/regmodels/sgc2_temp.pt\")\n",
    "prop = torch.randn(1).item()\n",
    "nlayers = 2\n",
    "torch.save(theta_params, \"config/constraints/regmodels/sgc2_temp.pt\")\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Property-SGC', 'params': [\"sgc2_temp.pt\", prop]}))\n",
    "# \n",
    "# adjs_norm = proj_adjs.clone()\n",
    "# n = proj_adjs.shape[1]\n",
    "# adjs_norm[:, torch.arange(n), torch.arange(n)] = 1\n",
    "# degs_norm = adjs_norm.sum(dim=2)\n",
    "# degs_norm = (degs_norm[:, :, None] @ degs_norm[:, None, :])**0.5\n",
    "# adjs_norm = adjs_norm / degs_norm\n",
    "# adjs_norm_k = torch.matrix_power(adjs_norm, nlayers)\n",
    "# print ((adjs_norm_k @ proj_xs @ theta_params[None, :, None]).squeeze().sum(dim=1) - prop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from reg_models import SGCReg\n",
    "\n",
    "model = SGCReg (xs.shape[-1], 2).float()\n",
    "torch.save (model.state_dict(), \"config/constraints/regmodels/sgc2_temp.pt\")\n",
    "\n",
    "prop2 = torch.rand(1).item()\n",
    "prop1 = torch.rand(1).item()/prop2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.5979, -0.8385, -1.0902, -0.7480, -0.6076, -1.0629, -0.7142, -1.1087,\n",
      "        -0.9371, -0.7992, -0.8625, -0.7978, -0.8470, -0.8183, -0.8629, -0.9081,\n",
      "        -0.7405, -0.8646, -0.8736, -0.6982])\n",
      "tensor([ 3.1888e-06,  2.4438e-06,  6.6459e-06,  2.5332e-06,  9.0599e-06,\n",
      "         3.4273e-06, -6.5565e-06,  2.7120e-06, -9.3877e-06,  5.3644e-07,\n",
      "         8.1062e-06, -6.3479e-06, -2.3842e-06,  8.3745e-06,  2.3842e-06,\n",
      "         3.1888e-06,  1.7583e-06,  8.8215e-06, -2.8908e-06,  2.0862e-07])\n",
      "tensor([-0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906,\n",
      "        -0.2905, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906,\n",
      "        -0.2906, -0.2906, -0.2906, -0.2906])\n",
      "tensor([0.3073, 0.5480, 0.7996, 0.4575, 0.3171, 0.7724, 0.4236, 0.8182, 0.6465,\n",
      "        0.5087, 0.5719, 0.5072, 0.5564, 0.5278, 0.5724, 0.6175, 0.4500, 0.5740,\n",
      "        0.5830, 0.4077])\n"
     ]
    }
   ],
   "source": [
    "nlayers = 2\n",
    "\n",
    "proj_xs, proj_adjs = project (xs, adjs, DotDict({'constraint': 'Property-SGC-In', 'params': [\"sgc2_temp.pt\", prop1, prop2]}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta_params = torch.load (f\"config/constraints/regmodels/sgc2_temp.pt\", map_location=xs.device)\n",
    "theta = theta_params['conv.lin.weight'].squeeze().type(xs.dtype).to(xs.device)\n",
    "bias = theta_params['conv.lin.bias'].item()\n",
    "prop = torch.rand(1).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs.requires_grad = True\n",
    "adjs_norm.requires_grad = True\n",
    "adjs_norm_I = adjs_norm + torch.stack([torch.eye(adjs.shape[1], device=adjs_norm.device, dtype=adjs.dtype, requires_grad=True) \n",
    "                                        for _ in range(adjs.shape[0])])\n",
    "degs_norm = adjs_norm_I.sum(dim=2)\n",
    "degs_norm = ((degs_norm[:, :, None] @ degs_norm[:, None, :]))**0.5\n",
    "adjs_norm_Ideg = adjs_norm_I / degs_norm\n",
    "adjs_norm_k = torch.matrix_power(adjs_norm_Ideg, nlayers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([     nan,      nan,  -2.2476,  -0.3318,   0.5256,  10.1121,      nan,\n",
       "             nan,      nan,      nan,      nan, 200.2227,      nan,   1.4147,\n",
       "             nan,      nan,      nan,      nan,      nan,      nan],\n",
       "       grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(adjs_norm_k @ xs @ theta[None, :, None]).squeeze().sum(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "degs_norm = adjs_norm_I.sum(dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "grad can be implicitly created only for scalar outputs",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_4109364/2742720289.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0madjs_norm_k\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatrix_power\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0madjs_norm_Ideg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnlayers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0merror\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0madjs_norm_k\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mxs\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtheta\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mprop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mgrad_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_adjs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madjs_norm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/home/ksharma323/miniconda3/envs/moltemp/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mgrad\u001b[0;34m(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)\u001b[0m\n\u001b[1;32m    216\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    217\u001b[0m     \u001b[0mgrad_outputs_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_tensor_or_tensors_to_tuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 218\u001b[0;31m     \u001b[0mgrad_outputs_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_make_grads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_outputs_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/home/ksharma323/miniconda3/envs/moltemp/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36m_make_grads\u001b[0;34m(outputs, grads)\u001b[0m\n\u001b[1;32m     48\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m                     \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"grad can be implicitly created only for scalar outputs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     51\u001b[0m                 \u001b[0mnew_grads\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreserve_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: grad can be implicitly created only for scalar outputs"
     ]
    }
   ],
   "source": [
    "nlayers = 2\n",
    "adjs_norm = adjs.clone().float()\n",
    "xs = xs.clone()\n",
    "xs.requires_grad = True\n",
    "adjs_norm.requires_grad = True\n",
    "adjs_norm_I = adjs_norm + torch.stack([torch.eye(adjs.shape[1], device=adjs_norm.device, dtype=adjs.dtype, requires_grad=True) \n",
    "                                        for _ in range(adjs.shape[0])])\n",
    "degs_norm = adjs_norm_I.sum(dim=2)\n",
    "degs_norm = ((degs_norm[:, :, None] @ degs_norm[:, None, :]))**0.5\n",
    "adjs_norm_Ideg = adjs_norm_I / degs_norm\n",
    "adjs_norm_k = torch.matrix_power(adjs_norm_Ideg, nlayers)\n",
    "error = ((adjs_norm_k @ xs @ theta[None, :, None]).squeeze().sum(dim=1) + bias - prop)**2\n",
    "grad_x, grad_adjs = torch.autograd.grad(error, xs), torch.autograd.grad(error, adjs_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2041,\n",
      "        0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040,\n",
      "        0.2040, 0.2040]) tensor([-0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906,\n",
      "        -0.2905, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906, -0.2906,\n",
      "        -0.2906, -0.2906, -0.2906, -0.2906]) tensor([ 3.1888e-06,  2.4438e-06,  6.6459e-06,  2.5332e-06,  9.0599e-06,\n",
      "         3.4273e-06, -6.5565e-06,  2.7120e-06, -9.3877e-06,  5.3644e-07,\n",
      "         8.1062e-06, -6.3479e-06, -2.3842e-06,  8.3745e-06,  2.3842e-06,\n",
      "         3.1888e-06,  1.7583e-06,  8.8215e-06, -2.8908e-06,  2.0862e-07])\n"
     ]
    }
   ],
   "source": [
    "nlayers = int(constraint_config.params[0].split(\"_\")[0][3:])\n",
    "prop_low = constraint_config.params[1]\n",
    "prop_high = constraint_config.params[2]\n",
    "prop = (prop_low + prop_high) / 2.\n",
    "adjs_norm = adjs.clone().float()\n",
    "xs = xs.clone()\n",
    "xs.requires_grad = True\n",
    "adjs_norm.requires_grad = True\n",
    "adjs_norm_I = adjs_norm + torch.stack([torch.eye(adjs.shape[1], device=adjs_norm.device, dtype=adjs.dtype, requires_grad=True) \n",
    "                                        for _ in range(adjs.shape[0])])\n",
    "degs_norm = adjs_norm_I.sum(dim=2)\n",
    "degs_norm = ((degs_norm[:, :, None] @ degs_norm[:, None, :]))**0.5\n",
    "adjs_norm_Ideg = adjs_norm_I / degs_norm\n",
    "adjs_norm_k = torch.matrix_power(adjs_norm_Ideg, nlayers)\n",
    "error = ((adjs_norm_k @ xs @ theta[None, :, None]).squeeze().sum(dim=1) + bias - prop)**2\n",
    "grad_x, grad_adjs = torch.autograd.grad(error, xs), torch.autograd.grad(error, adjs_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils.sparse import dense_to_sparse\n",
    "from torch_geometric.data import Batch, Data\n",
    "\n",
    "batches = []\n",
    "data_list = []\n",
    "for i in range (adjs.shape[0]):\n",
    "    edge_ind, edge_attr = dense_to_sparse (proj_adjs[i])\n",
    "    data_list.append (Data (x=proj_xs[i], edge_index=edge_ind, edge_attr=edge_attr))\n",
    "\n",
    "data_batch = Batch.from_data_list(data_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2041,\n",
       "        0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040, 0.2040,\n",
       "        0.2040, 0.2040], grad_fn=<SqueezeBackward0>)"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model (data_batch.x, data_batch.edge_index, data_batch.edge_attr, data_batch.batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(nan, grad_fn=<SubBackward0>) tensor(nan, grad_fn=<RsubBackward1>)\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.utils.sparse import dense_to_sparse\n",
    "\n",
    "edge_ind, edge_attr = dense_to_sparse (adjs)\n",
    "x = xs[0]\n",
    "batch = torch.zeros_like (x[:, 0]).long()\n",
    "\n",
    "\n",
    "print (r - prop2, prop1 - r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[False, False,  True,  True],\n",
       "         [False, False,  True,  True],\n",
       "         [ True, False,  True, False],\n",
       "         [False, False,  True, False],\n",
       "         [False, False,  True,  True]]),\n",
       " tensor([[ True,  True, False, False],\n",
       "         [ True,  True, False, False],\n",
       "         [False,  True, False,  True],\n",
       "         [ True,  True, False,  True],\n",
       "         [ True,  True, False, False]]),\n",
       " tensor([[True, True, True, True],\n",
       "         [True, True, True, True],\n",
       "         [True, True, True, True],\n",
       "         [True, True, True, True],\n",
       "         [True, True, True, True]]))"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs[0] <= 0.5, xs[0] > 0.5, (xs[0] <= 0.5) | (xs[0] > 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "atoms_exist = torch.any(Xsnorm > 0.5, dim=2, keepdim=True).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_rowvecs = adjs.reshape(-1, adjs.shape[1])[atoms_exist.reshape(-1), :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "mus_lower = adj_rowvecs.reshape (adj_rowvecs.shape[0], -1).min(dim=1)[0] - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "mus = torch.rand (xs.shape[0])\n",
    "atomic_weights = torch.arange (1, xs.shape[2]+1, dtype=mus.dtype)\n",
    "\n",
    "mu_weights = mus[:, None] @ atomic_weights[None, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_rowvecs[~atoms_exist.repeat(1, atoms_exist.shape[1]).reshape(atoms_exist.shape[1]*atoms_exist.shape[0], atoms_exist.shape[1])[atoms_exist.reshape(-1), :]] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3090217/347414840.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mproj_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mAs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmus\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mP01\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mAs\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mmus\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatrix_power\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mAs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmus\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mP01\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mAs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mmus_lower\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0madjs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mfind_muUpper\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0madjs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproj_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconstr_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmus_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/tmp/ipykernel_3090217/1573150747.py\u001b[0m in \u001b[0;36mfind_muUpper\u001b[0;34m(inputs, proj_fn, constr_fn, mus_lower, params)\u001b[0m\n\u001b[1;32m      7\u001b[0m         still_find.scatter_ (dim=0, index=torch.where(still_find)[0], \n\u001b[1;32m      8\u001b[0m                              src=((constr_fn (proj_fn (inputs[still_find], mus=mus_lower[still_find]), params=params) *\n\u001b[0;32m----> 9\u001b[0;31m                                    constr_fn (proj_fn (inputs[still_find], mus=mus_upper[still_find]), params=params)) > 0))\n\u001b[0m\u001b[1;32m     10\u001b[0m         \u001b[0mmus_upper\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstill_find\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mstep_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mmus_upper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "P01 = lambda x: torch.clamp (x, 0, 1)\n",
    "constr_fn = lambda As, params: (1/6 * torch.diagonal(torch.matrix_power(As, 3), dim1=1, dim2=2).sum(dim=1) - params)\n",
    "proj_fn = lambda As, mus=None: P01 (As - mus[:, None, None] * torch.matrix_power(As, 2)) if mus is not None else P01 (As)\n",
    "mus_lower = torch.zeros (adjs.shape[0])\n",
    "find_muUpper (adjs, proj_fn, constr_fn, mus_lower, params=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000e+00, -8.0876e-01,  2.3043e-01, -5.9609e-01,  9.7901e-01],\n",
       "        [-8.0876e-01,  0.0000e+00,  1.8443e+00,  1.7953e+00, -7.3648e-01],\n",
       "        [ 2.3043e-01,  1.8443e+00,  0.0000e+00, -1.1911e+00,  5.3133e-01],\n",
       "        [-5.9609e-01,  1.7953e+00, -1.1911e+00,  0.0000e+00,  1.5411e-01],\n",
       "        [ 9.7901e-01, -7.3648e-01,  5.3133e-01,  1.5411e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00,  0.0000e+00,  1.0390e+00,  7.9877e-01,  1.0261e-01],\n",
       "        [ 1.0390e+00,  0.0000e+00,  0.0000e+00, -9.8562e-01,  2.9623e-01],\n",
       "        [ 7.9877e-01,  0.0000e+00, -9.8562e-01,  0.0000e+00,  7.1313e-01],\n",
       "        [ 1.0261e-01,  0.0000e+00,  2.9623e-01,  7.1313e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00,  1.9031e-01,  8.8241e-01,  1.1338e+00,  7.8011e-01],\n",
       "        [ 1.9031e-01,  0.0000e+00,  1.9461e-01,  9.2036e-01, -1.5031e+00],\n",
       "        [ 8.8241e-01,  1.9461e-01,  0.0000e+00,  4.4297e-01,  3.9634e-01],\n",
       "        [ 1.1338e+00,  9.2036e-01,  4.4297e-01,  0.0000e+00,  3.3096e-02],\n",
       "        [ 7.8011e-01, -1.5031e+00,  3.9634e-01,  3.3096e-02,  0.0000e+00],\n",
       "        [ 0.0000e+00,  0.0000e+00, -3.2957e-01,  0.0000e+00,  1.9927e-01],\n",
       "        [ 0.0000e+00, -3.2957e-01,  0.0000e+00,  0.0000e+00, -2.7309e+00],\n",
       "        [ 0.0000e+00,  1.9927e-01, -2.7309e+00,  0.0000e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00,  1.1278e+00, -3.1702e-01,  5.4876e-01,  9.3097e-01],\n",
       "        [ 1.1278e+00,  0.0000e+00,  2.8105e-01,  1.8529e-01,  7.4979e-01],\n",
       "        [-3.1702e-01,  2.8105e-01,  0.0000e+00,  3.3120e-01,  6.1879e-01],\n",
       "        [ 5.4876e-01,  1.8529e-01,  3.3120e-01,  0.0000e+00, -7.9655e-01],\n",
       "        [ 9.3097e-01,  7.4979e-01,  6.1879e-01, -7.9655e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00,  1.4267e-01, -4.0317e-01, -5.5347e-01, -1.7622e-01],\n",
       "        [ 1.4267e-01,  0.0000e+00,  7.2923e-01, -1.1847e+00, -3.9920e-01],\n",
       "        [-4.0317e-01,  7.2923e-01,  0.0000e+00,  7.3667e-01,  6.4860e-01],\n",
       "        [-5.5347e-01, -1.1847e+00,  7.3667e-01,  0.0000e+00,  2.1438e-01],\n",
       "        [-1.7622e-01, -3.9920e-01,  6.4860e-01,  2.1438e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00, -1.3665e+00, -8.4168e-01, -4.2972e-01, -1.8949e+00],\n",
       "        [-1.3665e+00,  0.0000e+00, -5.6269e-01,  2.2433e-01,  9.2611e-02],\n",
       "        [-8.4168e-01, -5.6269e-01,  0.0000e+00,  5.5144e-01,  6.4422e-01],\n",
       "        [-4.2972e-01,  2.2433e-01,  5.5144e-01,  0.0000e+00,  5.0004e-01],\n",
       "        [-1.8949e+00,  9.2611e-02,  6.4422e-01,  5.0004e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.5691e+00, -3.8954e-01],\n",
       "        [-1.5691e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2742e+00],\n",
       "        [-3.8954e-01,  0.0000e+00,  0.0000e+00,  1.2742e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00, -1.0588e+00, -1.1280e+00,  1.3822e+00,  1.8109e+00],\n",
       "        [-1.0588e+00,  0.0000e+00,  1.5893e+00,  2.2412e+00,  9.1662e-02],\n",
       "        [-1.1280e+00,  1.5893e+00,  0.0000e+00, -7.9173e-01, -1.2564e-01],\n",
       "        [ 1.3822e+00,  2.2412e+00, -7.9173e-01,  0.0000e+00,  1.6212e+00],\n",
       "        [ 1.8109e+00,  9.1662e-02, -1.2564e-01,  1.6212e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00, -4.9859e-01,  1.1785e+00, -2.3531e+00,  2.1405e+00],\n",
       "        [-4.9859e-01,  0.0000e+00,  5.2948e-01, -5.8967e-02, -4.6306e-01],\n",
       "        [ 1.1785e+00,  5.2948e-01,  0.0000e+00, -8.2800e-01, -6.7524e-01],\n",
       "        [-2.3531e+00, -5.8967e-02, -8.2800e-01,  0.0000e+00,  4.4071e-01],\n",
       "        [ 2.1405e+00, -4.6306e-01, -6.7524e-01,  4.4071e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00, -1.1352e+00,  7.7235e-01,  1.1880e+00,  1.3717e-01],\n",
       "        [-1.1352e+00,  0.0000e+00,  4.1249e-01,  1.3453e+00, -1.9925e-01],\n",
       "        [ 7.7235e-01,  4.1249e-01,  0.0000e+00, -1.6082e-01, -1.2524e-01],\n",
       "        [ 1.1880e+00,  1.3453e+00, -1.6082e-01,  0.0000e+00, -9.1411e-01],\n",
       "        [ 1.3717e-01, -1.9925e-01, -1.2524e-01, -9.1411e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00, -1.2588e+00, -8.6566e-01,  7.2765e-01, -3.1021e-01],\n",
       "        [-1.2588e+00,  0.0000e+00,  3.1239e-01, -5.5908e-02, -1.1730e+00],\n",
       "        [-8.6566e-01,  3.1239e-01,  0.0000e+00,  3.6441e-01, -6.6327e-01],\n",
       "        [ 7.2765e-01, -5.5908e-02,  3.6441e-01,  0.0000e+00, -4.7950e-01],\n",
       "        [-3.1021e-01, -1.1730e+00, -6.6327e-01, -4.7950e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00,  0.0000e+00, -1.3832e+00,  1.3584e+00, -2.6976e-01],\n",
       "        [ 0.0000e+00, -1.3832e+00,  0.0000e+00, -1.1158e+00, -8.9250e-01],\n",
       "        [ 0.0000e+00,  1.3584e+00, -1.1158e+00,  0.0000e+00, -2.1357e+00],\n",
       "        [ 0.0000e+00, -2.6976e-01, -8.9250e-01, -2.1357e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00, -1.3976e-01, -1.3861e-01,  2.0236e-02,  8.9989e-01],\n",
       "        [-1.3976e-01,  0.0000e+00, -5.8457e-01, -6.9294e-01, -2.4928e-01],\n",
       "        [-1.3861e-01, -5.8457e-01,  0.0000e+00, -4.9915e-01, -8.1480e-01],\n",
       "        [ 2.0236e-02, -6.9294e-01, -4.9915e-01,  0.0000e+00, -1.4321e+00],\n",
       "        [ 8.9989e-01, -2.4928e-01, -8.1480e-01, -1.4321e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00, -8.4583e-01,  1.2539e+00,  1.8076e-01, -4.2165e-01],\n",
       "        [-8.4583e-01,  0.0000e+00,  9.0766e-02,  2.4930e+00,  6.0680e-01],\n",
       "        [ 1.2539e+00,  9.0766e-02,  0.0000e+00, -3.7567e-01,  2.0802e+00],\n",
       "        [ 1.8076e-01,  2.4930e+00, -3.7567e-01,  0.0000e+00, -2.3600e+00],\n",
       "        [-4.2165e-01,  6.0680e-01,  2.0802e+00, -2.3600e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00, -6.9941e-01, -1.7597e+00,  0.0000e+00, -1.7529e+00],\n",
       "        [-6.9941e-01,  0.0000e+00, -1.6523e+00,  0.0000e+00,  1.4579e-01],\n",
       "        [-1.7597e+00, -1.6523e+00,  0.0000e+00,  0.0000e+00,  6.0141e-01],\n",
       "        [-1.7529e+00,  1.4579e-01,  6.0141e-01,  0.0000e+00,  0.0000e+00],\n",
       "        [ 0.0000e+00,  7.5250e-01,  6.7383e-02, -6.9153e-02, -2.7389e-01],\n",
       "        [ 7.5250e-01,  0.0000e+00, -2.8165e-02, -1.1969e+00,  7.7850e-01],\n",
       "        [ 6.7383e-02, -2.8165e-02,  0.0000e+00,  1.4493e+00,  1.3870e+00],\n",
       "        [-6.9153e-02, -1.1969e+00,  1.4493e+00,  0.0000e+00, -4.4043e-01],\n",
       "        [-2.7389e-01,  7.7850e-01,  1.3870e+00, -4.4043e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00, -7.2092e-01, -3.1775e-01,  8.7612e-01,  6.1657e-02],\n",
       "        [-7.2092e-01,  0.0000e+00, -1.7527e+00,  6.6510e-01, -1.1136e+00],\n",
       "        [-3.1775e-01, -1.7527e+00,  0.0000e+00,  5.9526e-01, -6.3461e-01],\n",
       "        [ 8.7612e-01,  6.6510e-01,  5.9526e-01,  0.0000e+00, -8.9615e-01],\n",
       "        [ 6.1657e-02, -1.1136e+00, -6.3461e-01, -8.9615e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00, -6.3965e-01, -7.3542e-01,  5.4220e-01, -5.9634e-01],\n",
       "        [-6.3965e-01,  0.0000e+00, -7.4788e-01, -5.6646e-01, -2.0134e-03],\n",
       "        [-7.3542e-01, -7.4788e-01,  0.0000e+00, -2.2671e-02, -2.3432e+00],\n",
       "        [ 5.4220e-01, -5.6646e-01, -2.2671e-02,  0.0000e+00,  2.4476e-01],\n",
       "        [-5.9634e-01, -2.0134e-03, -2.3432e+00,  2.4476e-01,  0.0000e+00],\n",
       "        [ 0.0000e+00,  6.7541e-01,  0.0000e+00, -6.6902e-01,  1.2297e+00],\n",
       "        [ 6.7541e-01,  0.0000e+00,  0.0000e+00,  5.1986e-01, -1.4858e-01],\n",
       "        [-6.6902e-01,  5.1986e-01,  0.0000e+00,  0.0000e+00,  1.2099e+00],\n",
       "        [ 1.2297e+00, -1.4858e-01,  0.0000e+00,  1.2099e+00,  0.0000e+00]])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "adj_rowvecs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "a_atoms_exist = torch.einsum ('ij,ik->ijk', atoms_exist, atoms_exist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_atoms_exist = torch.einsum ('ij,ik->ijk', atoms_exist, torch.ones_like(xs[:,0,:], dtype=bool))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 5, 5])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a_atoms_exist.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_rowvecs = adjs.reshape(-1, adjs.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 5])"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "atoms_exist.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 5])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adj_rowvecs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([2.0879, 2.0879, 2.0848, 1.6536, 2.0848, 1.1269, 1.1016, 1.2861, 1.2478,\n",
      "        0.7816, 1.2202, 1.6040, 1.0545, 1.4072, 1.2967, 0.9913, 0.9699, 1.3165]) tensor([0.6549, 1.0000, 0.6534, 0.9916, 0.3888, 0.8984, 0.9917, 0.8140, 0.9895,\n",
      "        0.9572, 1.1343, 1.0022, 0.5926, 0.5956, 1.3517, 1.2255, 1.1950, 0.6460])\n",
      "1 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([1.0440, 1.0440, 1.0424, 0.8268, 1.0424, 0.5634, 0.5508, 0.6430, 0.6239,\n",
      "        0.3908, 0.6101, 0.8020, 0.5273, 0.7036, 0.6483, 0.4957, 0.4849, 0.6582]) tensor([0.1329, 0.3424, 0.1322, 0.4050, 0.1719, 0.3350, 0.3076, 0.1710, 0.2199,\n",
      "        0.3710, 0.2767, 0.4032, 0.1983, 0.2653, 0.6077, 0.4820, 0.4676, 0.3414])\n",
      "2 tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.5212, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "        0.0000, 0.0000, 0.0000, 0.2636, 0.3518, 0.0000, 0.0000, 0.0000, 0.3291]) tensor([0.5220, 0.5220, 0.5212, 0.4134, 1.0424, 0.2817, 0.2754, 0.3215, 0.3119,\n",
      "        0.1954, 0.3050, 0.4010, 0.5273, 0.7036, 0.3242, 0.2478, 0.2425, 0.6582]) tensor([0.0000, 0.1796, 0.0000, 0.1983, 0.1282, 0.0532, 0.2432, 0.0623, 0.1563,\n",
      "        0.0779, 0.1132, 0.1956, 0.1971, 0.1545, 0.2561, 0.1103, 0.1039, 0.1523])\n",
      "3 tensor([0.0000, 0.2610, 0.0000, 0.0000, 0.5212, 0.0000, 0.1377, 0.1608, 0.1560,\n",
      "        0.0000, 0.1525, 0.0000, 0.2636, 0.3518, 0.0000, 0.0000, 0.0000, 0.3291]) tensor([0.5220, 0.5220, 0.5212, 0.2067, 0.7818, 0.1409, 0.2754, 0.3215, 0.3119,\n",
      "        0.0977, 0.3050, 0.2005, 0.3954, 0.5277, 0.1621, 0.1239, 0.1212, 0.4937]) tensor([0.0000, 0.0814, 0.0000, 0.0949, 0.0416, 0.0000, 0.0322, 0.0551, 0.0003,\n",
      "        0.0687, 0.0479, 0.0049, 0.0006, 0.1094, 0.0940, 0.0756, 0.0780, 0.0946])\n",
      "4 tensor([0.0000, 0.2610, 0.0000, 0.0000, 0.6515, 0.0000, 0.1377, 0.1608, 0.2340,\n",
      "        0.0489, 0.1525, 0.1003, 0.3295, 0.4398, 0.0000, 0.0620, 0.0606, 0.4114]) tensor([0.5220, 0.3915, 0.5212, 0.1033, 0.7818, 0.1409, 0.2066, 0.2411, 0.3119,\n",
      "        0.0977, 0.2288, 0.2005, 0.3954, 0.5277, 0.0810, 0.1239, 0.1212, 0.4937]) tensor([0.0000, 0.0491, 0.0000, 0.0433, 0.0631, 0.0000, 0.1055, 0.0072, 0.1029,\n",
      "        0.0046, 0.0369, 0.0954, 0.0983, 0.0226, 0.0129, 0.0173, 0.0130, 0.0289])\n",
      "5 tensor([0.0000, 0.3262, 0.0000, 0.0000, 0.6515, 0.0000, 0.1721, 0.2010, 0.2340,\n",
      "        0.0489, 0.1907, 0.1003, 0.3295, 0.4398, 0.0000, 0.0620, 0.0606, 0.4114]) tensor([0.5220, 0.3915, 0.5212, 0.0517, 0.7167, 0.1409, 0.2066, 0.2411, 0.2730,\n",
      "        0.0733, 0.2288, 0.1504, 0.3625, 0.4837, 0.0405, 0.0929, 0.0909, 0.4525]) tensor([0.0000, 0.0162, 0.0000, 0.0174, 0.0214, 0.0000, 0.0367, 0.0330, 0.0444,\n",
      "        0.0320, 0.0012, 0.0452, 0.0488, 0.0434, 0.0276, 0.0291, 0.0325, 0.0329])\n",
      "6 tensor([0.0000, 0.3262, 0.0000, 0.0000, 0.6515, 0.0000, 0.1893, 0.2010, 0.2340,\n",
      "        0.0611, 0.1907, 0.1003, 0.3295, 0.4617, 0.0203, 0.0774, 0.0758, 0.4320]) tensor([0.5220, 0.3589, 0.5212, 0.0258, 0.6841, 0.1409, 0.2066, 0.2210, 0.2535,\n",
      "        0.0733, 0.2097, 0.1253, 0.3460, 0.4837, 0.0405, 0.0929, 0.0909, 0.4525]) tensor([0.0000, 0.0165, 0.0000, 0.0020, 0.0111, 0.0000, 0.0022, 0.0129, 0.0192,\n",
      "        0.0137, 0.0179, 0.0202, 0.0241, 0.0104, 0.0073, 0.0059, 0.0098, 0.0020])\n",
      "7 tensor([0.0000, 0.3425, 0.0000, 0.0129, 0.6678, 0.0000, 0.1979, 0.2010, 0.2340,\n",
      "        0.0672, 0.2002, 0.1003, 0.3295, 0.4727, 0.0304, 0.0852, 0.0833, 0.4423]) tensor([0.5220, 0.3589, 0.5212, 0.0258, 0.6841, 0.1409, 0.2066, 0.2110, 0.2437,\n",
      "        0.0733, 0.2097, 0.1128, 0.3378, 0.4837, 0.0405, 0.0929, 0.0909, 0.4525]) tensor([0.0000, 0.0002, 0.0000, 0.0110, 0.0052, 0.0000, 0.0150, 0.0029, 0.0095,\n",
      "        0.0046, 0.0083, 0.0076, 0.0118, 0.0061, 0.0028, 0.0057, 0.0016, 0.0134])\n",
      "8 tensor([0.0000, 0.3507, 0.0000, 0.0129, 0.6678, 0.0000, 0.1979, 0.2010, 0.2340,\n",
      "        0.0702, 0.2049, 0.1003, 0.3295, 0.4727, 0.0304, 0.0852, 0.0833, 0.4423]) tensor([0.5220, 0.3589, 0.5212, 0.0194, 0.6759, 0.1409, 0.2022, 0.2060, 0.2388,\n",
      "        0.0733, 0.2097, 0.1065, 0.3337, 0.4782, 0.0355, 0.0891, 0.0871, 0.4474]) tensor([0.0000e+00, 8.0011e-03, 0.0000e+00, 4.4959e-03, 2.9902e-03, 0.0000e+00,\n",
      "        6.3684e-03, 2.1694e-03, 4.5873e-03, 1.1563e-05, 3.5832e-03, 1.3748e-03,\n",
      "        5.5742e-03, 2.1822e-03, 2.2690e-03, 8.3447e-05, 4.0770e-03, 5.7096e-03])\n",
      "9 tensor([0.0000, 0.3507, 0.0000, 0.0129, 0.6719, 0.0000, 0.1979, 0.2035, 0.2340,\n",
      "        0.0702, 0.2073, 0.1003, 0.3295, 0.4755, 0.0329, 0.0871, 0.0852, 0.4423]) tensor([0.5220, 0.3548, 0.5212, 0.0161, 0.6759, 0.1409, 0.2001, 0.2060, 0.2364,\n",
      "        0.0718, 0.2097, 0.1034, 0.3316, 0.4782, 0.0355, 0.0891, 0.0871, 0.4448]) tensor([0.0000, 0.0039, 0.0000, 0.0013, 0.0011, 0.0000, 0.0021, 0.0003, 0.0022,\n",
      "        0.0023, 0.0012, 0.0018, 0.0025, 0.0019, 0.0003, 0.0028, 0.0012, 0.0019])\n",
      "10 tensor([0.0000, 0.3507, 0.0000, 0.0129, 0.6719, 0.0000, 0.1979, 0.2035, 0.2340,\n",
      "        0.0710, 0.2085, 0.1018, 0.3295, 0.4755, 0.0329, 0.0871, 0.0862, 0.4423]) tensor([0.5220, 0.3527, 0.5212, 0.0145, 0.6739, 0.1409, 0.1990, 0.2047, 0.2352,\n",
      "        0.0718, 0.2097, 0.1034, 0.3306, 0.4769, 0.0342, 0.0881, 0.0871, 0.4435]) tensor([0.0000e+00, 1.8842e-03, 0.0000e+00, 3.4857e-04, 9.5415e-04, 0.0000e+00,\n",
      "        8.6546e-05, 9.1338e-04, 9.3174e-04, 1.1332e-03, 8.5831e-06, 1.9169e-04,\n",
      "        9.4032e-04, 1.2088e-04, 1.0026e-03, 1.3688e-03, 1.8501e-04, 7.5817e-05])\n",
      "11 tensor([0.0000, 0.3507, 0.0000, 0.0137, 0.6729, 0.0000, 0.1985, 0.2041, 0.2340,\n",
      "        0.0714, 0.2085, 0.1026, 0.3295, 0.4762, 0.0336, 0.0871, 0.0862, 0.4429]) tensor([0.5220, 0.3517, 0.5212, 0.0145, 0.6739, 0.1409, 0.1990, 0.2047, 0.2346,\n",
      "        0.0718, 0.2097, 0.1034, 0.3300, 0.4769, 0.0342, 0.0876, 0.0867, 0.4435]) tensor([0.0000e+00, 8.6474e-04, 0.0000e+00, 4.5896e-04, 6.3777e-05, 0.0000e+00,\n",
      "        9.8932e-04, 2.8539e-04, 3.2246e-04, 5.6100e-04, 8.5831e-06, 5.9164e-04,\n",
      "        1.6797e-04, 9.0981e-04, 3.6931e-04, 6.4278e-04, 5.2547e-04, 8.8847e-04])\n",
      "12 tensor([0.0000, 0.3507, 0.0000, 0.0137, 0.6729, 0.0000, 0.1985, 0.2044, 0.2340,\n",
      "        0.0716, 0.2085, 0.1026, 0.3295, 0.4762, 0.0339, 0.0871, 0.0864, 0.4429]) tensor([0.5220, 0.3512, 0.5212, 0.0141, 0.6734, 0.1409, 0.1988, 0.2047, 0.2343,\n",
      "        0.0718, 0.2097, 0.1030, 0.3298, 0.4765, 0.0342, 0.0874, 0.0867, 0.4432]) tensor([0.0000e+00, 3.5501e-04, 0.0000e+00, 5.5194e-05, 4.4537e-04, 0.0000e+00,\n",
      "        4.5145e-04, 2.8610e-05, 1.7881e-05, 2.7466e-04, 8.5831e-06, 2.0003e-04,\n",
      "        2.1839e-04, 3.9458e-04, 5.2929e-05, 2.7966e-04, 1.7023e-04, 4.0638e-04])\n",
      "13 tensor([0.0000, 0.3507, 0.0000, 0.0137, 0.6731, 0.0000, 0.1985, 0.2044, 0.2340,\n",
      "        0.0717, 0.2085, 0.1026, 0.3297, 0.4762, 0.0340, 0.0871, 0.0865, 0.4429]) tensor([0.5220, 0.3510, 0.5212, 0.0139, 0.6734, 0.1409, 0.1986, 0.2046, 0.2341,\n",
      "        0.0718, 0.2097, 0.1028, 0.3298, 0.4763, 0.0342, 0.0872, 0.0867, 0.4431]) tensor([0.0000e+00, 1.0014e-04, 0.0000e+00, 1.4663e-04, 1.9073e-04, 0.0000e+00,\n",
      "        1.8251e-04, 1.2827e-04, 1.3447e-04, 1.3161e-04, 8.5831e-06, 4.2915e-06,\n",
      "        2.5272e-05, 1.3685e-04, 1.0538e-04, 9.8228e-05, 7.3910e-06, 1.6534e-04])\n",
      "14 tensor([0.0000, 0.3507, 0.0000, 0.0138, 0.6733, 0.0000, 0.1985, 0.2045, 0.2340,\n",
      "        0.0717, 0.2085, 0.1026, 0.3297, 0.4762, 0.0340, 0.0871, 0.0865, 0.4429]) tensor([0.5220, 0.3508, 0.5212, 0.0139, 0.6734, 0.1409, 0.1985, 0.2046, 0.2341,\n",
      "        0.0718, 0.2097, 0.1028, 0.3298, 0.4763, 0.0341, 0.0872, 0.0867, 0.4430]) tensor([0.0000e+00, 2.7180e-05, 0.0000e+00, 4.5776e-05, 6.3419e-05, 0.0000e+00,\n",
      "        4.8041e-05, 4.9829e-05, 5.8413e-05, 6.0081e-05, 8.5831e-06, 4.2915e-06,\n",
      "        7.1406e-05, 8.1062e-06, 2.6226e-05, 7.3910e-06, 7.3910e-06, 4.4823e-05])\n",
      "15 tensor([0.0000, 0.3508, 0.0000, 0.0139, 0.6733, 0.0000, 0.1985, 0.2045, 0.2341,\n",
      "        0.0717, 0.2085, 0.1026, 0.3297, 0.4762, 0.0340, 0.0871, 0.0865, 0.4429]) tensor([0.5220, 0.3508, 0.5212, 0.0139, 0.6734, 0.1409, 0.1985, 0.2046, 0.2341,\n",
      "        0.0718, 0.2097, 0.1028, 0.3298, 0.4763, 0.0341, 0.0872, 0.0867, 0.4429]) tensor([0.0000e+00, 3.6478e-05, 0.0000e+00, 4.7684e-06, 0.0000e+00, 0.0000e+00,\n",
      "        1.9312e-05, 1.0490e-05, 2.0266e-05, 2.4319e-05, 8.5831e-06, 4.2915e-06,\n",
      "        2.3127e-05, 8.1062e-06, 1.3351e-05, 7.3910e-06, 7.3910e-06, 1.5497e-05])\n",
      "16 tensor([0.0000, 0.3508, 0.0000, 0.0139, 0.6733, 0.0000, 0.1985, 0.2045, 0.2341,\n",
      "        0.0717, 0.2085, 0.1026, 0.3297, 0.4762, 0.0341, 0.0871, 0.0865, 0.4429]) tensor([0.5220, 0.3508, 0.5212, 0.0139, 0.6734, 0.1409, 0.1985, 0.2046, 0.2341,\n",
      "        0.0718, 0.2097, 0.1028, 0.3297, 0.4763, 0.0341, 0.0872, 0.0867, 0.4429]) tensor([0.0000e+00, 4.5300e-06, 0.0000e+00, 4.7684e-06, 0.0000e+00, 0.0000e+00,\n",
      "        1.4424e-05, 9.0599e-06, 1.1921e-06, 6.1989e-06, 8.5831e-06, 4.2915e-06,\n",
      "        1.1921e-06, 8.1062e-06, 6.4373e-06, 7.3910e-06, 7.3910e-06, 1.4663e-05])\n",
      "17 tensor([0.0000, 0.3508, 0.0000, 0.0139, 0.6733, 0.0000, 0.1985, 0.2045, 0.2341,\n",
      "        0.0717, 0.2085, 0.1026, 0.3297, 0.4762, 0.0341, 0.0871, 0.0865, 0.4429]) tensor([0.5220, 0.3508, 0.5212, 0.0139, 0.6734, 0.1409, 0.1985, 0.2046, 0.2341,\n",
      "        0.0718, 0.2097, 0.1028, 0.3297, 0.4763, 0.0341, 0.0872, 0.0867, 0.4429]) tensor([0.0000e+00, 4.5300e-06, 0.0000e+00, 4.7684e-06, 0.0000e+00, 0.0000e+00,\n",
      "        2.3842e-06, 9.0599e-06, 1.1921e-06, 6.1989e-06, 8.5831e-06, 4.2915e-06,\n",
      "        1.1921e-06, 8.1062e-06, 6.4373e-06, 7.3910e-06, 7.3910e-06, 4.7684e-07])\n"
     ]
    }
   ],
   "source": [
    "_, proj_adjs = project (None, adjs, constraint_config=DotDict({'params': [[4, 3, 2, 1]], 'constraint': 'valency'}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 1.4077, 2.5222, 1.0000, 1.9299],\n",
       "        [1.0000, 1.2206, 1.9128, 2.0000, 2.1334],\n",
       "        [0.0000, 1.4648, 1.0000, 0.0000, 0.4648],\n",
       "        [0.6544, 1.5769, 1.1963, 0.2738, 0.0000],\n",
       "        [0.8340, 2.1839, 0.6377, 2.0565, 0.7687],\n",
       "        [1.6457, 0.1503, 0.0000, 0.7960, 1.0000],\n",
       "        [1.7565, 1.5525, 0.6172, 0.3080, 2.1052],\n",
       "        [1.8858, 1.7714, 1.1846, 1.1342, 0.8697],\n",
       "        [0.0000, 2.0000, 0.6821, 1.0000, 1.6821],\n",
       "        [1.6505, 0.7307, 0.0000, 1.3752, 1.0061],\n",
       "        [1.3241, 0.8758, 1.0077, 2.1858, 0.9936],\n",
       "        [1.8030, 2.1756, 1.7707, 1.6018, 1.0000],\n",
       "        [1.4729, 2.2983, 2.4520, 1.8068, 1.4705],\n",
       "        [0.0000, 1.0000, 0.6002, 0.6002, 1.0000],\n",
       "        [2.1086, 1.3598, 0.9680, 0.4348, 1.7843],\n",
       "        [0.0000, 1.0000, 1.0000, 1.6557, 1.6557],\n",
       "        [1.1886, 2.0000, 0.1342, 1.3228, 0.0000],\n",
       "        [1.2496, 0.0299, 1.3856, 2.0000, 0.1659],\n",
       "        [0.7962, 1.7962, 0.9184, 1.8729, 2.0454],\n",
       "        [1.0000, 1.8739, 0.3559, 0.2722, 0.2457]])"
      ]
     },
     "execution_count": 121,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "proj_adjs.sum(dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.6588,  0.8333, -2.0000, -2.0000,  3.7330, -2.0000,  2.8666,  1.0321,\n",
       "        -2.0000, -2.0000, -2.0000,  4.1018, -2.0000, -2.0000,  3.8333, -2.0000,\n",
       "         3.0354,  3.8333,  2.8145, -2.0000])"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# per-graph trace(A^3) / 6 (the triangle count when A is a 0/1 symmetric adjacency), minus 2\n",
    "(1/6 * torch.matrix_power(proj_adjs, 3).diagonal(dim1=1, dim2=2).sum(dim=1) - 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 0.],\n",
       "         [0., 1.]],\n",
       "\n",
       "        [[1., 0.],\n",
       "         [0., 1.]]], dtype=torch.float64)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor (\n",
    "    [[[0, 1], \n",
    "      [1, 0]],\n",
    "     [[1, 0], \n",
    "      [0, 1]]], dtype=float)\n",
    "torch.matrix_power (x, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-4.7987, -1.5312, -2.5639, -2.0592, -0.6620],\n",
       "        [-1.9561, -1.9998, -0.1296,  1.8105,  2.7468],\n",
       "        [-2.7406, -2.7272,  2.6001,  2.9406,  2.0861],\n",
       "        [ 1.2232, -2.4412, -2.5277, -2.4936, -1.1136],\n",
       "        [ 0.6099,  0.2935,  2.0340,  0.0920,  1.8067],\n",
       "        [-0.9919, -3.6611, -0.9305,  2.1824,  0.8618],\n",
       "        [ 2.5044,  5.2939,  3.5612, -4.8793, -5.4360],\n",
       "        [ 3.5309,  1.8916,  1.0537,  0.5965, -3.1797],\n",
       "        [ 3.1252, -2.3439,  2.3019,  1.9706, -0.0110],\n",
       "        [-3.6103,  0.2631,  3.4328, -1.9763,  1.2011],\n",
       "        [ 0.7022, -1.9996, -0.7116, -0.6160, -0.7453],\n",
       "        [-1.2445, -2.1483,  3.4769,  0.6923,  1.8911],\n",
       "        [-0.2467, -2.3674, -2.6322, -1.6952,  2.1274],\n",
       "        [ 2.4438, -1.0605,  1.2864,  4.9133,  3.9178],\n",
       "        [-0.7750, -2.3871, -2.7681,  3.7462,  1.4856],\n",
       "        [-2.2840, -2.9864, -0.6772, -0.9919,  1.0436],\n",
       "        [-1.0775,  2.8517, -1.0981, -2.6815,  0.3869],\n",
       "        [ 0.2351, -1.5022, -1.8544, -1.4144,  1.9526],\n",
       "        [-0.9523, -0.4195, -0.3708, -0.6312, -0.0285],\n",
       "        [ 4.2561,  1.5670,  1.3971,  4.4840, -2.2366]])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.sum(dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-4.7987, -1.5312, -2.5639, -2.0592, -0.6620],\n",
       "        [-1.9561, -1.9998, -0.1296,  1.8105,  2.7468],\n",
       "        [-2.7406, -2.7272,  2.6001,  2.9406,  2.0861],\n",
       "        [ 1.2232, -2.4412, -2.5277, -2.4936, -1.1136],\n",
       "        [ 0.6099,  0.2935,  2.0340,  0.0920,  1.8067],\n",
       "        [-0.9919, -3.6611, -0.9305,  2.1824,  0.8618],\n",
       "        [ 2.5044,  5.2939,  3.5612, -4.8793, -5.4360],\n",
       "        [ 3.5309,  1.8916,  1.0537,  0.5965, -3.1797],\n",
       "        [ 3.1252, -2.3439,  2.3019,  1.9706, -0.0110],\n",
       "        [-3.6103,  0.2631,  3.4328, -1.9763,  1.2011],\n",
       "        [ 0.7022, -1.9996, -0.7116, -0.6160, -0.7453],\n",
       "        [-1.2445, -2.1483,  3.4769,  0.6923,  1.8911],\n",
       "        [-0.2467, -2.3674, -2.6322, -1.6952,  2.1274],\n",
       "        [ 2.4438, -1.0605,  1.2864,  4.9133,  3.9178],\n",
       "        [-0.7750, -2.3871, -2.7681,  3.7462,  1.4856],\n",
       "        [-2.2840, -2.9864, -0.6772, -0.9919,  1.0436],\n",
       "        [-1.0775,  2.8517, -1.0981, -2.6815,  0.3869],\n",
       "        [ 0.2351, -1.5022, -1.8544, -1.4144,  1.9526],\n",
       "        [-0.9523, -0.4195, -0.3708, -0.6312, -0.0285],\n",
       "        [ 4.2561,  1.5670,  1.3971,  4.4840, -2.2366]])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs.reshape(-1, 5).sum(dim=1).reshape(20, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.0000, 0.0000, 0.0000, 0.6857, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 1.0000, 1.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000],\n",
       "         [0.6857, 1.0000, 1.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.2432, 0.0000],\n",
       "         [0.0000, 0.0000, 0.5205, 0.0000, 0.0000],\n",
       "         [0.0000, 0.5205, 0.0000, 0.0000, 0.6612],\n",
       "         [0.2432, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.6612, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.3518, 0.3200, 0.0000, 0.0000],\n",
       "         [0.3518, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.3200, 0.0000, 0.0000, 0.5105, 1.0000],\n",
       "         [0.0000, 0.0000, 0.5105, 0.0000, 0.8918],\n",
       "         [0.0000, 0.0000, 1.0000, 0.8918, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.6362, 0.0000, 0.3154],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.1098],\n",
       "         [0.6362, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.3154, 0.1098, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 1.0000, 0.0000, 0.1220],\n",
       "         [0.0000, 0.0000, 0.0000, 0.8544, 0.1341],\n",
       "         [1.0000, 0.0000, 0.0000, 0.0510, 0.0000],\n",
       "         [0.0000, 0.8544, 0.0510, 0.0000, 0.0000],\n",
       "         [0.1220, 0.1341, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 1.0000, 0.7205, 0.0000, 0.0000],\n",
       "         [1.0000, 0.0000, 0.1282, 0.0000, 0.9125],\n",
       "         [0.7205, 0.1282, 0.0000, 1.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.9125, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.3349, 0.0000, 1.0000],\n",
       "         [0.0000, 0.0000, 1.0000, 0.7150, 0.4718],\n",
       "         [0.3349, 1.0000, 0.0000, 1.0000, 0.0000],\n",
       "         [0.0000, 0.7150, 1.0000, 0.0000, 0.0115],\n",
       "         [1.0000, 0.4718, 0.0000, 0.0115, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.6425, 1.0000, 0.5154, 0.1638],\n",
       "         [0.6425, 0.0000, 1.0000, 0.0000, 0.8106],\n",
       "         [1.0000, 1.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.5154, 0.0000, 0.0000, 0.0000, 0.4239],\n",
       "         [0.1638, 0.8106, 0.0000, 0.4239, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.9066, 0.8572, 0.0000, 0.7639],\n",
       "         [0.9066, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.8572, 0.0000, 0.0000, 1.0000, 0.9325],\n",
       "         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000],\n",
       "         [0.7639, 0.0000, 0.9325, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.3334, 0.0000],\n",
       "         [0.0000, 0.0000, 0.3334, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 1.0000, 0.0000, 1.0000, 0.0000],\n",
       "         [1.0000, 0.0000, 0.2594, 0.5394, 0.0000],\n",
       "         [0.0000, 0.2594, 0.0000, 0.0000, 0.0000],\n",
       "         [1.0000, 0.5394, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.4416, 0.9391, 0.0000, 0.0000],\n",
       "         [0.4416, 0.0000, 0.9380, 0.0000, 0.0000],\n",
       "         [0.9391, 0.9380, 0.0000, 0.0000, 1.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.3543],\n",
       "         [0.0000, 0.0000, 1.0000, 0.3543, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.2802, 0.0000, 0.0000],\n",
       "         [0.0000, 0.2802, 0.0000, 0.0000, 0.5133],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0958],\n",
       "         [0.0000, 0.0000, 0.5133, 0.0958, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0148, 0.0387, 1.0000, 0.5631],\n",
       "         [0.0148, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0387, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000],\n",
       "         [0.5631, 0.0000, 0.0000, 1.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.9372, 0.5778, 0.0000],\n",
       "         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000],\n",
       "         [0.9372, 1.0000, 0.0000, 0.1787, 0.0000],\n",
       "         [0.5778, 0.0000, 0.1787, 0.0000, 1.0000],\n",
       "         [0.0000, 1.0000, 0.0000, 1.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.4880, 0.0000, 0.0000, 0.2085],\n",
       "         [0.4880, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.2085, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.3097, 0.0000, 0.0790, 0.0000],\n",
       "         [0.3097, 0.0000, 0.9157, 0.4220, 0.4603],\n",
       "         [0.0000, 0.9157, 0.0000, 0.0000, 0.2338],\n",
       "         [0.0790, 0.4220, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.4603, 0.2338, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 1.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.1588, 0.2455],\n",
       "         [0.0000, 0.0000, 0.0000, 0.1626, 0.0000],\n",
       "         [1.0000, 0.1588, 0.1626, 0.0000, 0.4468],\n",
       "         [0.0000, 0.2455, 0.0000, 0.4468, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.0000, 0.4343],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.0000, 0.0000, 0.0000, 0.0383, 1.0000],\n",
       "         [0.0000, 0.0000, 0.0383, 0.0000, 0.0000],\n",
       "         [0.4343, 0.0000, 1.0000, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.0000, 0.0000, 0.0000, 0.6867, 1.0000],\n",
       "         [0.0000, 0.0000, 0.7777, 0.0000, 1.0000],\n",
       "         [0.0000, 0.7777, 0.0000, 0.0000, 0.0000],\n",
       "         [0.6867, 0.0000, 0.0000, 0.0000, 1.0000],\n",
       "         [1.0000, 1.0000, 0.0000, 1.0000, 0.0000]]])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "proj_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [17:59:27] Enabling RDKit 2019.09.3 jupyter extensions\n",
      "[17:59:27] Enabling RDKit 2019.09.3 jupyter extensions\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sampling_fn, load_eval_settings\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize, quantize_mol\n",
    "from utils.plot import save_graph_list, plot_graphs_list\n",
    "from evaluation.stats import eval_graph_list\n",
    "from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx\n",
    "import projop\n",
    "import networkx as nx\n",
    "from parsers.config import get_config, get_constraint_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_config = \"sample_community_small\"\n",
    "constr_config = \"numtriangles\"\n",
    "seed = 42\n",
    "device = 'cuda:1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the `seed` set in the configuration cell rather than repeating the literal\n",
    "config = get_config (dataset_config, seed=seed)\n",
    "constraint_config = get_constraint_config (constr_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "# read the master constraint table; the with-block closes the file handle instead of leaking it\n",
    "with open(f\"config/constraints/master_{constr_config}.yaml\", 'r') as yaml_file:\n",
    "    master_constr = yaml.load(yaml_file, Loader=yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_config['add_diff_step'] = -998\n",
    "constraint_config['params'] = [master_constr[dataset_config[len('sample_'):]][1]]\n",
    "constraint_config['schedule']['params'] = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/community_small/gdss_community_small.pth loaded\n"
     ]
    }
   ],
   "source": [
    "from sampler import Sampler, Sampler_mol\n",
    "# use the configured `device` so the sampler and the tensors later moved with .to(device) agree\n",
    "sampler = Sampler(config, constraint_config, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.loader import load_sde\n",
    "\n",
    "max_node_num = sampler.configt.data.max_node_num\n",
    "sde_x = load_sde(sampler.configt.sde.x)\n",
    "sde_adj = load_sde(sampler.configt.sde.adj)\n",
    "\n",
    "# QM9 / ZINC250k draw a fixed batch of 10000 samples; other datasets use the configured batch size\n",
    "if sampler.configt.data.data in ['QM9', 'ZINC250k']:\n",
    "    n_samples = 10000\n",
    "else:\n",
    "    n_samples = sampler.configt.data.batch_size\n",
    "shape_x = (n_samples, max_node_num, sampler.configt.data.max_feat_num)\n",
    "shape_adj = (n_samples, max_node_num, max_node_num)\n",
    "\n",
    "# initial node features / adjacency drawn from the priors of the two SDEs, placed on `device`\n",
    "x = sde_x.prior_sampling(shape_x).to(device)\n",
    "adj = sde_adj.prior_sampling_sym(shape_adj).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_config.params = [10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 tensor(11979.3105, device='cuda:1')\n",
      "1 tensor(0., device='cuda:1') tensor(235.0992, device='cuda:1')\n",
      "2 tensor(11963.4014, device='cuda:1')\n",
      "2 tensor(0., device='cuda:1') tensor(57.7800, device='cuda:1')\n",
      "3 tensor(12024.3701, device='cuda:1')\n",
      "3 tensor(0., device='cuda:1') tensor(55.6578, device='cuda:1')\n",
      "4 tensor(12461.9170, device='cuda:1')\n",
      "4 tensor(0., device='cuda:1') tensor(57.9860, device='cuda:1')\n",
      "tensor([ 82.2185, 156.1621, 152.2067, 126.3166,  58.1924,  61.2550, 334.1614,\n",
      "         38.4233,  37.9830, 159.7070, 529.2715, 291.0439, 144.4437, 432.9718,\n",
      "        159.3355, 183.1663, 153.5680, 148.4521,  54.8694,  59.0382,  14.7854,\n",
      "         71.6933, 264.0844, 226.5238, 162.7674, 561.4009, 195.6279, 278.6697,\n",
      "        345.9526, 104.0098, 318.0472, 163.0284, 407.3328,  97.5226, 111.1300,\n",
      "        319.5346,  26.0094, 201.3893, 368.2536, 244.3105, 157.9790, 148.7738,\n",
      "         66.5313, 104.1375, 364.9820, 448.8322, 116.6736,  66.5017, 444.6759,\n",
      "        274.3890, 182.3180,  37.6162, 485.3823, 210.4088,  89.9142, 194.5498,\n",
      "        232.6603, 410.5986,  85.3284, 360.8667, 194.6328,  75.6328, 259.3267,\n",
      "         10.9616,  37.0351, 226.4734, 126.5987, 200.3584, 327.3850, 371.2053,\n",
      "        235.6426, 362.5858, 351.8002,  59.5653, 151.5585, 290.7868, 352.8215,\n",
      "        296.5204, 191.5351, 232.0007, 401.2167, 188.8650, 322.7955,  45.1182,\n",
      "         30.5101, 330.5652,  46.8660, 341.8208, 307.3408, 151.4611, 270.8486,\n",
      "        117.2805, 639.2899,  50.2729, 149.9730, 366.1269, 313.6704, 287.3549,\n",
      "        153.1784, 473.6024, 101.9081, 146.0165,  74.4970, 450.5396, 156.9429,\n",
      "        153.3126, 145.2298, 154.7684, 340.2103,  70.8153,  23.1570, 172.0048,\n",
      "        243.1979, 247.9174, 343.2309, 109.4863, 556.6619, 526.8278, 217.8053,\n",
      "         99.4630, 280.8288,  46.5637,  81.4293, 654.2053, 171.5161, 145.7405,\n",
      "        537.3351, 189.0386], device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "import projop\n",
    "\n",
    "pxs, padjs = projop.drift_transformProject(x, adj, constraint_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0001, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0001, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        184.9999, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 184.9999,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 184.9999, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 184.9999, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0002, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 184.9999, 185.0000, 185.0000, 185.0000,\n",
       "        184.9998, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0001, 185.0000, 185.0000, 185.0000, 185.0001, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0002, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000, 185.0000,\n",
       "        185.0000, 185.0000], device='cuda:1')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.diagonal(torch.matrix_power(padjs, 3), dim1=1, dim2=2).sum(dim=1)/6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.graph_utils import *\n",
    "import torch\n",
    "\n",
    "adjs = torch.rand (10, 10)\n",
    "xs = torch.rand(10, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from projop.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjs = torch.tensor([[0, 1, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 0, 1], [0, 0, 0, 0, 1, 0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs_mask, adjs_mask = implicitConstr_transform (xs, adjs, \"QM9\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 0, 0, 0, 0],\n",
       "        [1, 0, 1, 0, 0, 0],\n",
       "        [0, 1, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0]])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adjs_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjs_int = quantize (adjs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2.0062, -1.7219, -1.5068, -1.1191, -0.3437,  0.0168,  1.3352,  1.9200,\n",
       "         2.2416,  5.1841])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.linalg.eigvalsh(adjs_int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nx.to_numpy_array returns a float64 ndarray (not an integer np.matrix), which eigvalsh accepts\n",
    "[torch.linalg.eigvalsh(torch.from_numpy(nx.to_numpy_array(nxgraph))) for nxgraph in adjs_to_graphs([adjs_int.numpy()])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the `seed` set in the configuration cell rather than repeating the literal\n",
    "config = get_config (dataset_config, seed=seed)\n",
    "constraint_config = get_constraint_config (constr_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_config['params'] = [6]\n",
    "constraint_config['add_diff_step'] = -998"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/community_small/gdss_community_small.pth loaded\n"
     ]
    }
   ],
   "source": [
    "from sampler import Sampler, Sampler_mol\n",
    "# use the configured `device` instead of a hard-coded GPU id\n",
    "sampler = Sampler(config, constraint_config, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Make Directory community_small/test/Rank in Logs\n",
      "gdss_community_small-proj-1p0-fixed1-0--998-360p0-none\n",
      "----------------------------------------------------------------------------------------------------\n",
      "[community_small]   init=deg (10)   seed=12   batch_size=128\n",
      "----------------------------------------------------------------------------------------------------\n",
      "lr=0.01 schedule=True ema=0.999 epochs=5000 reduce=False eps=1e-05\n",
      "(ScoreNetworkX)+(ScoreNetworkA=GCN,4)   : depth=3 adim=32 nhid=32 layers=5 linears=2 c=(2 8 4)\n",
      "(x:VP)=(0.10, 1.00) N=1000 (adj:VP)=(0.10, 1.00) N=1000\n",
      "----------------------------------------------------------------------------------------------------\n",
      "(Euler)+(Langevin): eps=0.0001 denoise=True ema=False || snr=0.05 seps=0.7 n_steps=10 \n",
      "----------------------------------------------------------------------------------------------------\n",
      "GEN SEED: 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constraint Validity before round: 1.0\n",
      "Round 0 : 6.82s\n",
      "Constraint Validity: 1.0\n",
      "Constraint Validity: 1.0\n",
      "\u001b[91mdegree   \u001b[0m : \u001b[94m0.166190\u001b[0m\n",
      "\u001b[91mcluster  \u001b[0m : \u001b[94m0.760776\u001b[0m\n",
      "\u001b[91morbit    \u001b[0m : \u001b[94m0.354999\u001b[0m\n",
      "\u001b[91mspectral \u001b[0m : \u001b[94m0.220052\u001b[0m\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "sampler.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [16:26:47] Enabling RDKit 2019.09.3 jupyter extensions\n",
      "[16:26:47] Enabling RDKit 2019.09.3 jupyter extensions\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sampling_fn, load_eval_settings\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize, quantize_mol\n",
    "from utils.plot import save_graph_list, plot_graphs_list\n",
    "from evaluation.stats import eval_graph_list\n",
    "from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx\n",
    "import projop\n",
    "import networkx as nx\n",
    "from parsers.config import get_config, get_constraint_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.load('temp_x.pt')\n",
    "samples = torch.load('temp_samples.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `samples` is already a tensor: convert it directly instead of torch.tensor(samples),\n",
    "# which copy-constructs and emits a UserWarning. one_hot requires an int64 (long) input.\n",
    "adj = torch.nn.functional.one_hot(samples.long(), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `with` guarantees the file handle is closed once unpickling finishes.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.\n",
    "with open(\"gen_mols.pkl\", 'rb') as f:\n",
    "    gen_mols = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each line is expected to hold one SMILES string. rstrip('\\n') removes only a trailing\n",
    "# newline, so the last line stays intact even if the file does not end with one.\n",
    "gen_smiles = []\n",
    "with open (\"logs_sample/QM9/test/Valency/gdss_qm9-proj-0p0-0-[4, 3, 2, 1]-none.txt\", 'r') as f:\n",
    "    for line in f:\n",
    "        gen_smiles.append(line.rstrip('\\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.mol_utils import *\n",
    "gen_mols = smiles_to_mols(gen_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "nx_graphs = [mol_to_nx(mol) for mol in gen_mols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "nx_graphs = mols_to_nx (gen_mols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_mols = len(gen_mols)\n",
    "max_nodes = 20  # padded node dimension (same padding as the adjs tensor)\n",
    "# Feature width is derived from the vocabulary so every one-hot column index fits in xs.\n",
    "# Only the QM9 atom types are listed, which keeps xs at 4 features per node.\n",
    "atom_id_map = {'C': 0, 'N': 1, 'O': 2, 'F': 3}\n",
    "xs = torch.zeros(num_mols, max_nodes, len(atom_id_map))\n",
    "for i, G in enumerate(nx_graphs):\n",
    "    n_nodes = G.number_of_nodes()\n",
    "    assert n_nodes <= max_nodes, f'graph {i} has {n_nodes} nodes, more than max_nodes={max_nodes}'\n",
    "    # assumes nodes are labelled 0..n_nodes-1 in iteration order (as list(G.nodes()) shows)\n",
    "    atom_ids = [atom_id_map[attrs['label']] for attrs in G.nodes().values()]\n",
    "    xs[i, torch.arange(n_nodes), atom_ids] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjs = torch.zeros(len(nx_graphs), 20, 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first generated molecule: copy its dense 0/1 adjacency into padded slot 0.\n",
    "G = nx_graphs[0]\n",
    "nG = len(G)\n",
    "dense_adj = nx.to_numpy_array(G)\n",
    "adjs[0, :nG, :nG] = torch.from_numpy(dense_adj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 1., 0., 0., 0., 0., 0., 0.],\n",
       "        [1., 0., 1., 0., 0., 0., 0., 0.],\n",
       "        [0., 1., 0., 1., 0., 0., 0., 0.],\n",
       "        [0., 0., 1., 0., 1., 1., 0., 0.],\n",
       "        [0., 0., 0., 1., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 1., 0., 0., 1., 1.],\n",
       "        [0., 0., 0., 0., 0., 1., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 1., 0., 0.]])"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adjs[0, :nG, :nG]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7]"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(nx_graphs[0].nodes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dataset = 'QM9'\n",
    "atom_id_map = {'QM9': {'C': 0, 'N': 1, 'O': 2, 'F': 3}}\n",
    "graph = nx_graphs[1]\n",
    "# Per-node one-hot atom features: row = node (in graph.nodes() order), column = atom type.\n",
    "x = np.zeros((graph.number_of_nodes(), len(atom_id_map[dataset])))\n",
    "for node_idx, attrs in enumerate(graph.nodes().values()):\n",
    "    x[node_idx, atom_id_map[dataset][attrs['label']]] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1., 0.])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[0, 1, 0, 0, 0, 0, 0],\n",
       "        [1, 0, 2, 0, 0, 0, 0],\n",
       "        [0, 2, 0, 1, 1, 0, 0],\n",
       "        [0, 0, 1, 0, 0, 0, 0],\n",
       "        [0, 0, 1, 0, 0, 1, 0],\n",
       "        [0, 0, 0, 0, 1, 0, 2],\n",
       "        [0, 0, 0, 0, 0, 2, 0]], dtype=int64)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nx.adjacency_matrix(nx_graphs[1], weight='label').todense()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['C', 'O', 'C', 'N', 'N', 'C', 'N']"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[x['atom_symbol'] for x in nx_graphs[1].nodes().values()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'bond_type': rdkit.Chem.rdchem.BondType.SINGLE}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nx_graphs[1].edges()[(0, 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['C', 'O', 'C', 'C', 'C', 'C', 'C', 'N']"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[x['atom_symbol'] for x in nx_graphs[0].nodes().values()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<rdkit.Chem.rdchem.Mol at 0x7f0eb232a3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232acf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232af30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232aaf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ad30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232abf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232adf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ae30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232acb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232aeb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232abb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a2b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ac70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ac30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232a030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cbf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cdb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cf30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cb70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cc70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cbb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ccb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c9b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c1f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ca70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ca30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ccf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232cfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232c7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232dd30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d8b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232dfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232dbb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232de30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232def0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232de70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232df70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232deb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232daf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ddf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232dc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232d4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4016dc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4016dd70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4016dc70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4016daf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb97ab0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014ebb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014eb70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014ee70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014e830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014e8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014edb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0f4014ebf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bc70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bcb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232be70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bcf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bdb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232baf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232ba30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232bf70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cf70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cdb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233ccb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cb30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233ca70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cbf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233caf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233cab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233ca30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233c470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb233ccf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23372f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23375f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23375b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23377f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23371f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23370b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23378f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23370f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23377b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23373b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23374b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23371b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23376b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2337230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d00f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d01f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d02f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d03f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d04f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d05f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d06f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d07f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d08f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d09f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d0f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d70b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d71b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d72b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d73b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d74b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d75b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d76b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d77b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d78b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d79b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22d7fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e50f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e51f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e52f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e53f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e54f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e55f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e56f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e57f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e58f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e59f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22e5f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db0b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db2b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db6b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db8b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22db9b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dba30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbb30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbbb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbcb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbd30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbdb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbe30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbeb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbf30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22dbfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee1f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee2f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee6f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ee9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eea70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eeaf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eeb70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eebf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eec70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eecf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eed70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eedf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eee70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eeef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22eef70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f10b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f11b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f12b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f13b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f14b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f15b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f16b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f17b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f18b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f19b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f1fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f50f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f51f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f52f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f53f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f54f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f55f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f56f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f57f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f58f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f59f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f5f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f70b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f71b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f72b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f73b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f74b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f75b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f76b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f77b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f78b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f79b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22f7fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa1f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa2f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa6f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fa9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22faa70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22faaf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fab70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fabf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fac70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22facf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fad70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fadf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fae70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22faef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22faf70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb0b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb2b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb6b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb8b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fb9b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fba30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbb30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbbb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbcb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbd30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbdb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbe30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbeb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbf30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22fbfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23010f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23011f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23012f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23013f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23014f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23015f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23016f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23017f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23018f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23019f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2301f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23070b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23071b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23072b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23073b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23074b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23075b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23076b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23077b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23078b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb23079b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2307fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a1f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a2f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a6f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230a9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230aa70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230aaf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230ab70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230abf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230ac70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230acf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230ad70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230adf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230ae70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230aef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb230af70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d0b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d2b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d6b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d8b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228d9b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228da30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228dab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231dc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d2f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231dcf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231deb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231db70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231de30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231dbf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231daf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231ddb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231dab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb231d4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b8b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b6f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb232b130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228db30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228dbb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228dc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228dcb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228dd30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228ddb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228de30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228deb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228df30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb228dfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22920f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22921f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22922f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22923f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22924f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22925f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22926f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22927f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22928f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22929f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2292f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22960b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22961b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22962b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22963b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22964b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22965b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22966b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22967b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22968b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22969b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb2296fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a1f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a2f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a6f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229a9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229aa70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229aaf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229ab70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229abf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229ac70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229acf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229ad70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229adf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229ae70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229aef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb229af70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a00b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a01b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a02b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a03b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a04b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a05b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a06b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a07b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a08b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a09b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a0fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a30f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a31f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a32f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a33f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a34f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a35f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a36f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a37f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a38f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a39f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a3f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a40b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a41b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a42b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a43b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a44b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a45b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a46b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a47b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a48b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a49b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a4fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a70f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a71f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a72f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a73f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a74f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a75f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a76f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a77f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a78f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a79f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22a7f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad0b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad2b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad5b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad6b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad7b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad8b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ad9b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ada30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adb30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adbb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adc30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adcb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22add30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22addb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ade30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adeb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adf30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22adfb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af0f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af1f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af2f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af3f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af4f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af5f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af6f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af7f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af8f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22af9f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afa70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afaf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afb70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afbf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afc70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afcf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afd70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afdf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afe70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22afef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22aff70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b30b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b31b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b32b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b33b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b34b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3530>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b35b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3630>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b36b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3730>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b37b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3830>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b38b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3930>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b39b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3a30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3ab0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3b30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3bb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3c30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3cb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3d30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3db0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3e30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3eb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3f30>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b3fb0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7070>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b70f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7170>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b71f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7270>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b72f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7370>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b73f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7470>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b74f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7570>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b75f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7670>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b76f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7770>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b77f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7870>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b78f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7970>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b79f0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7a70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7af0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7b70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7bf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7c70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7cf0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7d70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7df0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7e70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7ef0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22b7f70>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba030>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba0b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba130>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba1b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba230>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba2b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba330>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba3b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba430>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba4b0>,\n",
       " <rdkit.Chem.rdchem.Mol at 0x7f0eb22ba530>,\n",
       " ...]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open(os.path.join(self.log_dir, f'{self.log_name}.txt'), 'a') as f:\n",
    "    for smiles in gen_smiles:\n",
    "        f.write(f'{smiles}\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_name = \"sample_community_small\"\n",
    "constr_config_name = \"numtriangles\"\n",
    "seed = 42\n",
    "device = 'cuda:1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/community_small/gdss_community_small.pth loaded\n"
     ]
    }
   ],
   "source": [
    "config = get_config(config_name, seed)\n",
    "constr_config = get_constraint_config(constr_config_name)\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "constr_config['add_diff_step'] = -998"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_graph_list, test_graph_list = load_data(configt, get_graph_list=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "As = torch.stack([torch.tensor(nx.adjacency_matrix(graph).todense()).float() for graph in train_graph_list if graph.number_of_nodes() == 20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs, Qs = torch.linalg.eigh (As)\n",
    "eigs[torch.abs(eigs) <= 1e-5] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from projop.halfspace import hyperplane_projection_multiple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "proj_eigs3 = hyperplane_projection_multiple (torch.pow(eigs, 3), torch.ones_like(eigs[0]), 6*0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "proj_eigs = torch.sign(proj_eigs3) * torch.pow (torch.abs(proj_eigs3), 1/3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([377.9997, 377.9998, 378.0000, 378.0000, 377.9999, 378.0002, 378.0002,\n",
       "         377.9998, 378.0005, 378.0005]),\n",
       " tensor([378., 378., 378., 378., 378., 378., 378., 378., 378., 378.]))"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(eigs**3).sum(dim=1), torch.diagonal(torch.matrix_power(As, 3), dim1=1, dim2=2).sum(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Make Directory community_small/test/L1-adj in Logs\n"
     ]
    }
   ],
   "source": [
    "log_folder_name, log_dir, _ = set_log(configt, constraint=constr_config.constraint, is_train=False)\n",
    "log_name = f\"{config.ckpt}-{constr_config.method.op}-{constr_config.method.gamma}\"\n",
    "# \"\"\"Just for testing \"\"\"\n",
    "# log_name += f\"-{constr_config.method.solve_order}\"\n",
    "log_name += f\"-{constr_config.schedule.gamma}{','.join(map(str, constr_config.schedule.params))}\"\n",
    "log_name += f\"-{constr_config.burnin}\"\n",
    "log_name += f\"-{constr_config.add_diff_step}\"\n",
    "param_vals = map (str, constr_config.params)\n",
    "log_name += f\"-{','.join(param_vals)}-{constr_config.rounding}\"\n",
    "log_name = log_name.replace(\".\", \"p\")\n",
    "save_dir = './samples/pkl/{}/{}.pkl'.format(log_folder_name, log_name)\n",
    "with open(save_dir, 'rb') as f:\n",
    "    gen_graph_list = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[91mdegree   \u001b[0m : \u001b[94m0.170196\u001b[0m\n",
      "\u001b[91mcluster  \u001b[0m : \u001b[94m0.089773\u001b[0m\n",
      "\u001b[91morbit    \u001b[0m : \u001b[94m0.078871\u001b[0m\n",
      "\u001b[91mspectral \u001b[0m : \u001b[94m0.218928\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "methods, kernels = load_eval_settings(config.data.data)\n",
    "result_dict = eval_graph_list(test_graph_list, gen_graph_list, methods=methods, kernels=kernels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'degree': 0.170196,\n",
       " 'cluster': 0.089773,\n",
       " 'orbit': 0.078871,\n",
       " 'spectral': 0.218928}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs, _ = torch.rand (10, 5).sort(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4.4905e-01, 4.5599e-01, 5.9510e-01, 6.9206e-01, 7.3667e-01],\n",
       "        [1.2223e-02, 4.5413e-01, 7.9859e-01, 8.7098e-01, 9.1767e-01],\n",
       "        [8.1590e-02, 4.1886e-01, 4.2031e-01, 5.4888e-01, 5.8298e-01],\n",
       "        [1.6383e-01, 1.9348e-01, 2.2915e-01, 2.3113e-01, 5.8986e-01],\n",
       "        [8.7901e-02, 1.6594e-01, 2.8093e-01, 6.1142e-01, 9.2036e-01],\n",
       "        [6.2805e-01, 7.3120e-01, 8.9768e-01, 9.1984e-01, 9.7995e-01],\n",
       "        [9.8892e-02, 2.9547e-01, 4.5890e-01, 5.0080e-01, 7.9806e-01],\n",
       "        [2.8898e-01, 4.2522e-01, 5.6801e-01, 6.1271e-01, 9.5051e-01],\n",
       "        [4.8625e-04, 1.9810e-01, 4.2217e-01, 6.0271e-01, 8.9573e-01],\n",
       "        [2.0517e-01, 2.0530e-01, 2.4582e-01, 6.5316e-01, 8.0154e-01]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nnz_inds = eigs > 0.5\n",
    "nnz_srtd = nnz_inds * torch.arange(start=eigs.shape[1], end=0, step=-1, device=eigs.device, dtype=eigs.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "atoms = [\"C\", \"N\", \"O\", \"F\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C4N3O2F1'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "''.join([(x + y) for x, y in zip(atoms, \"4v3v2v1\".split(\"v\"))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ruamel.yaml\n",
    "import itertools\n",
    "import subprocess\n",
    "constraint = \"cheeger\"\n",
    "dataset = \"community_small\"\n",
    "constraint_config = \"cheeger\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "yaml = ruamel.yaml.YAML()\n",
    "with open(f\"config/constraints/master_{constraint}.yaml\", 'r') as master_file:\n",
    "    constraint_vals = yaml.load(master_file)\n",
    "all_params = constraint_vals[f\"{dataset}\"]\n",
    "with open(f\"config/constraints/{constraint_config}.yaml\", 'r') as constr_file:\n",
    "    constr_config = yaml.load(constr_file)\n",
    "# default to plain projection when the constraint config names no method\n",
    "if \"method\" not in constr_config:\n",
    "    constr_config[\"method\"] = {'op':'proj', 'gamma': 0}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0.05, 10.0]]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_vals = all_params[0]\n",
    "\n",
    "# ruamel.yaml loads sequences as CommentedSeq, a list subclass, for which `type(...) is list` is False\n",
    "param_vals if isinstance(param_vals, list) else [param_vals]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Combining results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import argparse\n",
    "from parsers.config import get_config, get_constraint_config\n",
    "import ast\n",
    "import re\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "from projop.lp_balls import l0_ball_vec\n",
    "import torch\n",
    "from projop.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.randint(2, (100,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(False)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm (a.float(), p=0) <= 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "vmin, vmax = 0, torch.max(torch.abs(v))**2 /2 + 1\n",
    "bound = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "func = lambda mu: torch.sum(torch.abs(v) >= (2*mu)**0.5) - bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.5000)"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vmin, vmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(99), tensor(-1))"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "func(vmin), func(vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "a, b = 0, torch.max(torch.abs(v)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "func=lambda mu: plus_fn(torch.abs(v) - mu).sum() - bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(50.), tensor(-1.))"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "func(a), func(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9804)"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bisection(v, func=lambda mu: plus_fn(torch.abs(v) - mu).sum() - bound, a=a, b=b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-1.)"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "func(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, tensor(1.5000))"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def bisection(v, func, a=0, b=None, epsilon=1e-5, iter_max=1e5):\n",
    "    \"\"\"Find a root of the scalar function `func` on [a, b] by bisection.\n",
    "\n",
    "    `func(a)` and `func(b)` should have opposite signs. Iteration stops when\n",
    "    `func(mid) == 0`, when the bracket width drops to `epsilon`, or after\n",
    "    `iter_max` iterations; the last midpoint is returned (`a` if no iteration ran).\n",
    "    `v` is not used by the search itself and is kept for call compatibility.\n",
    "    Raises ValueError if no upper bracket `b` is given.\n",
    "    \"\"\"\n",
    "    if b is None:\n",
    "        raise ValueError(\"bisection needs an explicit upper bracket b\")\n",
    "    miu = a\n",
    "    f_a = func(a)  # cached so func is evaluated once per iteration\n",
    "    for _ in range(int(iter_max)):\n",
    "        miu = (a + b) / 2\n",
    "        f_miu = func(miu)\n",
    "        # Middle point is an exact root\n",
    "        if f_miu == 0.0:\n",
    "            break\n",
    "        # Keep the half whose endpoints still have opposite signs\n",
    "        if f_miu * f_a < 0:\n",
    "            b = miu\n",
    "        else:\n",
    "            a, f_a = miu, f_miu\n",
    "        if (b - a) <= epsilon:\n",
    "            break\n",
    "    return miu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.5000)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bisection (v, func=func, a=vmin, b=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(51)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sum(torch.abs(v) >= (2*0.5)**0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(99), tensor(-1))"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "func(vmin), func(vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0.])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v = a.float()\n",
    "torch.sum(torch.abs(v) >= (2*0)**0.5) - bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = get_config (dataset_config, seed=42)\n",
    "constraint_config = get_constraint_config (constr_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_config['params'] = [6]\n",
    "constraint_config['add_diff_step'] = -998"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/community_small/gdss_community_small.pth loaded\n"
     ]
    }
   ],
   "source": [
    "from sampler import Sampler, Sampler_mol\n",
    "sampler = Sampler(config, constraint_config, device='cuda:1') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Make Directory community_small/test/Rank in Logs\n",
      "gdss_community_small-proj-1p0-fixed1-0--998-360p0-none\n",
      "----------------------------------------------------------------------------------------------------\n",
      "[community_small]   init=deg (10)   seed=12   batch_size=128\n",
      "----------------------------------------------------------------------------------------------------\n",
      "lr=0.01 schedule=True ema=0.999 epochs=5000 reduce=False eps=1e-05\n",
      "(ScoreNetworkX)+(ScoreNetworkA=GCN,4)   : depth=3 adim=32 nhid=32 layers=5 linears=2 c=(2 8 4)\n",
      "(x:VP)=(0.10, 1.00) N=1000 (adj:VP)=(0.10, 1.00) N=1000\n",
      "----------------------------------------------------------------------------------------------------\n",
      "(Euler)+(Langevin): eps=0.0001 denoise=True ema=False || snr=0.05 seps=0.7 n_steps=10 \n",
      "----------------------------------------------------------------------------------------------------\n",
      "GEN SEED: 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constraint Validity before round: 1.0\n",
      "Round 0 : 6.82s\n",
      "Constraint Validity: 1.0\n",
      "Constraint Validity: 1.0\n",
      "\u001b[91mdegree   \u001b[0m : \u001b[94m0.166190\u001b[0m\n",
      "\u001b[91mcluster  \u001b[0m : \u001b[94m0.760776\u001b[0m\n",
      "\u001b[91morbit    \u001b[0m : \u001b[94m0.354999\u001b[0m\n",
      "\u001b[91mspectral \u001b[0m : \u001b[94m0.220052\u001b[0m\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "sampler.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dir = f\"logs_sample/{config.data.data}/test/{constr_config.constraint}/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_file = [f for f in os.listdir (log_dir) if f.endswith(\".log\")][0]\n",
    "\n",
    "results_df = {\"method_op\": [], \"method_gamma\": [], \"add_diff_steps\": [], \"params\": [],\n",
    "              \"rounding\": [], \"constraint_val\": [], \"constraint_val_preround\": [], #time: []\n",
    "              \"validity_wo_corr\": [], \"valid\": [], \"unique\": [], \"fcd_test\": [], \"novelty\": [],\n",
    "              \"nspdk_mmd\": [], \"burnin\": [], \"schedule\": [], \"schedule_params\": []}\n",
    "log_df = {res_key: None for res_key in results_df}\n",
    "\n",
    "\n",
    "def split_log_name(name):\n",
    "    \"\"\"Split a '-'-joined log name into its fields, keeping the sign of negative fields.\n",
    "\n",
    "    A negative field such as add_diff_step=-998 yields '--' in the name, which a plain\n",
    "    str.split('-') would turn into an empty field followed by '998'.\n",
    "    \"\"\"\n",
    "    fields, negative = [], False\n",
    "    for token in name.split(\"-\"):\n",
    "        if token == \"\":\n",
    "            negative = True  # the next field started with '-'\n",
    "            continue\n",
    "        fields.append(\"-\" + token if negative else token)\n",
    "        negative = False\n",
    "    return fields\n",
    "\n",
    "\n",
    "method_op = None  # stays None when the file name cannot be parsed\n",
    "try:\n",
    "    attrs = split_log_name(log_file[:-4])  # drop the \".log\" suffix\n",
    "    # _, dataset = attrs[0].split (\"_\", 1)\n",
    "    method_op, method_gamma, schedule, burnin, add_diff_steps, params, rounding = attrs[1:]\n",
    "    if method_op == \"None\":\n",
    "        log_df[\"method_op\"] = \"proj\"\n",
    "        log_df[\"method_gamma\"] = 0.0\n",
    "    else:\n",
    "        log_df[\"method_op\"] = method_op\n",
    "        log_df[\"method_gamma\"] = float(method_gamma.replace(\"p\", \".\"))\n",
    "    log_df[\"add_diff_steps\"] = int(add_diff_steps)\n",
    "    log_df[\"params\"] = params.replace(\"p\", \".\") #.replace(\",\", \"|\")\n",
    "    log_df[\"rounding\"] = rounding\n",
    "    log_df[\"burnin\"] = int(burnin)\n",
    "    split_id = re.search(r'\\d', schedule).start()  # schedule name, then its numeric params\n",
    "    sch_name, sch_params = schedule[:split_id], schedule[split_id:]\n",
    "    log_df[\"schedule\"] = sch_name\n",
    "    log_df[\"schedule_params\"] = sch_params.replace(\"p\", \".\").replace(\",\", \"|\")\n",
    "except (ValueError, AttributeError) as err:\n",
    "    # best effort: report the unparsable name; its fields stay None and are skipped below\n",
    "    print (log_dir, log_file, \"has errors.\", err)\n",
    "\n",
    "# log-line prefix -> results key; the first matching prefix wins, as in an if/elif chain\n",
    "metric_prefixes = [\n",
    "    (\"Constraint Validity:\", \"constraint_val\"),\n",
    "    (\"Constraint Validity before round:\", \"constraint_val_preround\"),\n",
    "    (\"validity w/o correction:\", \"validity_wo_corr\"),\n",
    "    (\"valid:\", \"valid\"),\n",
    "    (\"unique@\", \"unique\"),\n",
    "    (\"FCD/Test:\", \"fcd_test\"),\n",
    "    (\"Novelty:\", \"novelty\"),\n",
    "    (\"NSPDK MMD:\", \"nspdk_mmd\"),\n",
    "]\n",
    "with open (f\"{log_dir}/{log_file}\", 'r') as log_f:\n",
    "    for line in log_f:\n",
    "        line = line.rstrip(\"\\n\")  # line[:-1] would cut a real character off an unterminated last line\n",
    "        for prefix, key in metric_prefixes:\n",
    "            if line.startswith(prefix):\n",
    "                log_df[key] = float(line.split(\": \")[1])\n",
    "                break\n",
    "\n",
    "if all(res_log is not None for res_log in log_df.values()):\n",
    "    # a log whose method op is \"None\" is recorded once per rounding scheme\n",
    "    roundings = [\"none\", \"randomized\", \"repeated\"] if method_op == \"None\" else [log_df[\"rounding\"]]\n",
    "    for rounding in roundings:\n",
    "        log_df[\"rounding\"] = rounding\n",
    "        for res_key in results_df:\n",
    "            results_df[res_key].append(log_df[res_key])\n",
    "else:\n",
    "    print (log_dir, log_file, \"is not complete.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'12.0106v14.006855v15.9994v18.998403163'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'v'.join(map(str, ast.literal_eval(param.split(\"],\")[0] + ']')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.8095770152180073"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(results_df)[\"fcd_test\"].max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rough2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle can execute arbitrary code: only load trusted files\n",
    "with open('gen_mols.pkl', 'rb') as gen_mols_file:\n",
    "    gen_mols = pkl.load(gen_mols_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples_temp = torch.load('samples_temp.pt')\n",
    "x = torch.load('x_temp.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_mols[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.all(xs.sum(dim=1) <= atomCounts[None, :], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.load('x_temp.pt')\n",
    "samples = torch.load('samples_temp.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "all_params = [[1,2,3], [4,5,6]]\n",
    "for param in itertools.product(*all_params):\n",
    "    print (param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import json\n",
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import argparse\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.rand(2, 4, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs, Qs = torch.linalg.eigh(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs.shape, Qs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs, torch.clamp(eigs[:], min=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cat((torch.clamp(eigs[:, :2], max=0), torch.clamp(eigs[:, 2:], min=1)), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs = torch.rand (10, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs, _ = eigs.sort(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.clamp (torch.tensor([3, 0]), min=1, max=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs[torch.arange(eigs.shape[0]), indices.squeeze()] = torch.clamp(eigs[torch.arange(eigs.shape[0]), indices.squeeze()], min=0.5, max=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs[nnz_inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs[:, -1] <= 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs[nnz_inds] = torch.rand(nnz_inds.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eigs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = torch.arange(eigs.shape[1], 0, -1)\n",
    "tmp2= (eigs > 0.5) * idx\n",
    "indices = torch.argmax(tmp2, 1, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# value of eigs at the column that `indices` selects for each row\n",
    "eigs[torch.arange(eigs.shape[0]), indices.squeeze(1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Qs @ torch.stack([torch.diag(x) for x in eigs]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def n_community(num_communities, max_nodes, p_inter=0.05, seed=None):\n",
    "    \"\"\"Build a graph of `num_communities` G(n, 0.7) communities linked by random bridges.\n",
    "\n",
    "    -------- From Niu et al. (2020) --------\n",
    "    Each community has max_nodes // num_communities nodes and is generated with\n",
    "    seed = its index. Each cross-community node pair is bridged with probability\n",
    "    p_inter * 2 / ((num_communities - 1) * community_size); two communities that got\n",
    "    no bridge are joined through their first nodes.\n",
    "\n",
    "    seed: optional seed for the bridge draws; None keeps drawing from the global\n",
    "    np.random state.\n",
    "    \"\"\"\n",
    "    assert num_communities > 1\n",
    "    rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "\n",
    "    one_community_size = max_nodes // num_communities\n",
    "    c_sizes = [one_community_size] * num_communities\n",
    "    total_nodes = one_community_size * num_communities\n",
    "    p_make_a_bridge = p_inter * 2 / ((num_communities - 1) * one_community_size)\n",
    "\n",
    "    print(num_communities, total_nodes, end=' ')\n",
    "    graphs = [nx.gnp_random_graph(c_sizes[i], 0.7, seed=i) for i in range(len(c_sizes))]\n",
    "\n",
    "    G = nx.disjoint_union_all(graphs)\n",
    "    # disjoint_union_all relabels the graphs to consecutive integers in order, so community i\n",
    "    # is a contiguous node range; connected components would split a disconnected community\n",
    "    communities = [list(range(i * one_community_size, (i + 1) * one_community_size))\n",
    "                   for i in range(num_communities)]\n",
    "    add_edge = 0\n",
    "    for i, nodes1 in enumerate(communities):\n",
    "        for nodes2 in communities[i + 1:]:  # loop for C_M^2 times\n",
    "            has_inter_edge = False\n",
    "            for n1 in nodes1:  # loop for N times\n",
    "                for n2 in nodes2:  # loop for N times\n",
    "                    if rng.rand() < p_make_a_bridge:\n",
    "                        G.add_edge(n1, n2)\n",
    "                        has_inter_edge = True\n",
    "                        add_edge += 1\n",
    "            if not has_inter_edge:\n",
    "                G.add_edge(nodes1[0], nodes2[0])\n",
    "                add_edge += 1\n",
    "    print('connected comp: ', nx.number_connected_components(G), 'add edges: ', add_edge)\n",
    "    print(G.number_of_edges())\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = n_community (2, 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A = np.array(nx.adjacency_matrix(g).todense())\n",
    "L = np.diag (np.sum(A, axis=1)) - A\n",
    "eigs = np.linalg.eigvalsh (L)\n",
    "eigs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A = np.array([[0, 1, 1, 0, 1],\n",
    "              [1, 0, 0, 1, 1],\n",
    "              [1, 0, 0, 0, 0],\n",
    "              [0, 1, 0, 0, 1],\n",
    "              [1, 1, 0, 1, 0]])\n",
    "L = np.diag (np.sum(A, axis=1)) - A\n",
    "eigs = np.linalg.eigvalsh (L)\n",
    "eigs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "a = torch.randn(3, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a[0, torch.triu_indices(3, 3)[0], torch.triu_indices(3, 3)[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "row_inds, col_inds = torch.triu_indices(3, 3, offset=1)\n",
    "b = torch.zeros_like(a)\n",
    "b[row_inds, col_inds] = a[row_inds, col_inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = b + b.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj = torch.rand(2, 5, 5)\n",
    "adj0 = torch.zeros_like(adj)\n",
    "budget = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from projop import lp_balls\n",
    "# strict upper triangle of each (n, n) adjacency, sized from adj itself\n",
    "row_inds, col_inds = torch.triu_indices(adj.shape[1], adj.shape[2], offset=1)\n",
    "adj_vec = (adj - adj0)[:, row_inds, col_inds]\n",
    "adj_proj = lp_balls.l2_ball_vecs(adj_vec, budget)\n",
    "proj_adj = torch.zeros_like (adj)\n",
    "proj_adj[:, row_inds, col_inds] = adj0[:, row_inds, col_inds] + adj_proj\n",
    "# symmetrise each graph; .T on a 3-D tensor would reverse all axes, including the batch one\n",
    "proj_adj = proj_adj + proj_adj.transpose(1, 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rebuild A from the community graph g (A was reassigned to a 5-node example in between)\n",
    "A = np.array(nx.adjacency_matrix(g).todense())\n",
    "A[7,12] = 0\n",
    "A[12,7] = 0\n",
    "L = np.diag (np.sum(A, axis=1)) - A\n",
    "eigs = np.linalg.eigvalsh (L)\n",
    "eigs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valency_projection_multiple (As: torch.Tensor, Xs: torch.Tensor, valencies: torch.Tensor, hidden_Hs=True):\n",
    "    \"\"\"Adjust a batch of adjacency matrices towards per-node valency limits.\n",
    "\n",
    "    For molecular data only, with given atom valencies.\n",
    "\n",
    "    Args:\n",
    "        As: (N, n, n) batch of adjacency matrices.\n",
    "        Xs: (N, n, F) node features, with F == len(valencies).\n",
    "        valencies: (F,) valency of each atom type.\n",
    "        hidden_Hs: if True, only nodes whose normalised feature row has an entry > 0.5\n",
    "            count as atoms, only entries between two atoms are adjusted, and only the\n",
    "            positive part (plus_fn) of each row-sum excess is removed. Otherwise every\n",
    "            row sum is projected onto its weighted valency.\n",
    "\n",
    "    Returns:\n",
    "        (As_proj, Xsnorm): adjusted (N, n, n) adjacencies and row-normalised Xs.\n",
    "    \"\"\"\n",
    "    N, n = As.shape[0], As.shape[1]\n",
    "    # Xs needs one feature row per node, i.e. Xs.shape[1] == n\n",
    "    assert ((As.shape[0] == Xs.shape[0]) and (As.shape[1] == As.shape[2]) and\n",
    "            (Xs.shape[1] == n) and (Xs.shape[2] == valencies.shape[0])), \"shape mismatch\"\n",
    "    Xsnorm = Xs / Xs.sum(keepdim=True, dim=2)\n",
    "    Xsnorm = torch.nan_to_num(Xsnorm, nan=0)  # all-zero feature rows give 0/0\n",
    "    wtd_vals = torch.matmul(Xs, valencies[:, None])  # (N, n, 1) feature-weighted valency per node\n",
    "    a = As.reshape(N, -1, 1)  # (N, n*n, 1) row-major flattened adjacency\n",
    "    # Row i of M_val sums adjacency row i: ones in columns i*n .. (i+1)*n - 1\n",
    "    M_val = torch.eye(n, dtype=a.dtype, device=a.device).repeat_interleave(n, dim=1).expand(N, n, n * n)\n",
    "    if hidden_Hs:\n",
    "        atoms_exist = torch.any(Xsnorm > 0.5, dim=2)  # (N, n); no squeeze, so N == 1 keeps its batch dim\n",
    "        # Entry (i, j) is adjustable only if both endpoints are atoms\n",
    "        a_atoms_exist = (atoms_exist[:, :, None] & atoms_exist[:, None, :]).reshape(N, -1)\n",
    "        M_val = M_val * atoms_exist[:, :, None] * a_atoms_exist[:, None, :]\n",
    "        # On atom rows M_val @ M_val^T = k * I with k = number of atoms, hence the division by k;\n",
    "        # clamp(min=1) avoids 0/0 for a graph without atoms (its M_val is all zeros)\n",
    "        n_atoms = atoms_exist.sum(dim=1).clamp(min=1)[:, None, None]\n",
    "        as_proj = a - M_val.transpose(1, 2) @ plus_fn(M_val @ a - wtd_vals) / n_atoms\n",
    "    else:\n",
    "        # M_val @ M_val^T = n * I, so its inverse is I / n; transpose(1, 2) keeps the batch axis first\n",
    "        as_proj = a - M_val.transpose(1, 2) @ (M_val @ a - wtd_vals) / n\n",
    "    As_proj = as_proj.reshape(N, n, n)\n",
    "    return As_proj, Xsnorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plus_fn (x):\n",
    "    \"\"\"Positive part of `x`: entries > 0 are kept, all others (NaN included) become 0.\"\"\"\n",
    "    keep = x > 0\n",
    "    return torch.where(keep, x, torch.zeros_like(x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "As, Xs = torch.rand(10, 10, 10), torch.rand(10, 10, 4)\n",
    "valencies = torch.tensor([4., 3., 2., 1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.randint(2, (10, 100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = torch.randint(1, 5, (10, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valency_projection_multiple (As, Xs, valencies, hidden_Hs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rough"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from utils.mol_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjs = torch.load(\"temp_samples.pt\")\n",
    "xs = torch.load(\"temp_x.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_mols = torch.nn.functional.one_hot(torch.tensor(adjs), num_classes=4).permute(0, 3, 1, 2).detach().cpu().numpy()\n",
    "\n",
    "xs_dis = torch.where(xs > 0.5, 1., 0.).to(xs.device)\n",
    "xs_dish = torch.cat([xs_dis, 1 - xs_dis.sum(dim=-1, keepdim=True)], dim=-1).detach().cpu().numpy()\n",
    "\n",
    "atomic_num_list = torch.tensor([6, 7, 8, 9, 0], device=xs.device, dtype=xs.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols = [construct_mol(x, adj, atomic_num_list) for x, adj in zip(xs_dish, adj_mols)]\n",
    "val_corr = [check_valency(mol) for mol in mols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mol_graphs = mols_to_nx(mols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjs[adjs == 3] = -1\n",
    "adjs = adjs + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valencies = torch.tensor([4, 3, 2, 1], device=xs.device, dtype=xs.dtype)\n",
    "\n",
    "def constr_sat (x, adj):\n",
    "    atoms_exist = torch.cat([torch.any(x[i] > 0.5).ravel() for i in range(x.shape[0])])\n",
    "    adj_n = adj[atoms_exist][:, atoms_exist]\n",
    "    x_n = x[atoms_exist]\n",
    "    return torch.all (adj_n.sum(dim=1) <= x_n @ valencies)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "constr_sats = [constr_sat(x, adj) for x, adj in zip(xs_dis, adjs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum(constr_sats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum(map (lambda x: x[0], val_corr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "not_same = [(i, x.item(), y[0]) for i, (x, y) in enumerate(zip (constr_sats, val_corr)) if (x != y[0])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "not_same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols[not_same[0][0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols[not_same[100][0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[i for i, x in enumerate(val_corr) if not x[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Molecular valency check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [16:39:16] Enabling RDKit 2019.09.3 jupyter extensions\n",
      "[16:39:16] Enabling RDKit 2019.09.3 jupyter extensions\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import argparse\n",
    "import time\n",
    "from parsers.parser import Parser\n",
    "from parsers.config import get_config, get_constraint_config\n",
    "from trainer import Trainer\n",
    "from sampler import Sampler, Sampler_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"nedges\" \"nedgesl2\" \"nconn\" \"specradius\" \"cheeger\")\n",
    "config = 'sample_community_small'\n",
    "constr_config = 'cheeger'\n",
    "device = 'cuda:1'\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = get_config(config, seed)\n",
    "constraint_config = get_constraint_config(constr_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_config.add_diff_step = -998\n",
    "constraint_config.schedule.gamma = 'fixed'\n",
    "constraint_config.schedule.params = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./checkpoints/community_small/gdss_community_small.pth loaded\n"
     ]
    }
   ],
   "source": [
    "if 'qm9' in constr_config or 'zinc250k' in constr_config:\n",
    "    sampler = Sampler_mol(config, constraint_config, device=device)\n",
    "else:\n",
    "    sampler = Sampler(config, constraint_config, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.005, 0.1)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "constraint_config.params[0]*2/(2 *constraint_config.params[1]), 2 * constraint_config.params[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Make Directory community_small/test/Cheeger-bound in Logs\n",
      "(Euler)+(Langevin): eps=0.0001 denoise=True ema=False || snr=0.05 seps=0.7 n_steps=10 \n",
      "----------------------------------------------------------------------------------------------------\n",
      "GEN SEED: 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(128, device='cuda:1')\n",
      "tensor(128, device='cuda:1')\n",
      "tensor(128, device='cuda:1')\n",
      "tensor(128, device='cuda:1')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(128, device='cuda:1')\n",
      "tensor(128, device='cuda:1')\n",
      "tensor(128, device='cuda:1')\n",
      "tensor(128, device='cuda:1')\n",
      "Constraint Validity before round: 0.3046875\n"
     ]
    }
   ],
   "source": [
    "x, adj = sampler.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "Ls = torch.load(\"laplacians.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 20])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "U, S, Vh = torch.linalg.svd(Ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.tensor([[0., 1., 1.], [1, 0, 1], [1, 1, 0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "L = torch.diag(A.sum(dim=1)) - A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "Ls = torch.stack ((L, L, L))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.0000, -1.0000, -1.0000],\n",
       "        [-1.0000,  2.0000, -1.0000],\n",
       "        [-1.0000, -1.0000,  2.0000]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "U, S, Vh = torch.linalg.svd(L)\n",
    "eigs, Qs = torch.linalg.eigh(L)\n",
    "U @ torch.diag(S) @ Vh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-1.0000, -1.0000,  2.0000]), tensor([2.0000, 1.0000, 1.0000]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "U, S, Vh = torch.linalg.svd(A)\n",
    "eigs, Qs = torch.linalg.eigh(A)\n",
    "eigs, S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.0000, 1.0000, 2.0000])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_tol = 1e-6\n",
    "# graph Laplacians L = D - A, degrees taken as column sums of adj\n",
    "degrees = torch.sum(adj, dim=1)\n",
    "Ls = torch.diag_embed(degrees) - adj\n",
    "eigs = torch.linalg.eigvalsh(Ls)\n",
    "eigs[eigs.abs() <= zero_tol] = 0  # snap numerically-zero eigenvalues to exactly 0\n",
    "cheeger_chi = constraint_config.params[0]\n",
    "max_degree = adj.sum(dim=2).max(dim=1).values\n",
    "lbound = cheeger_chi**2 / 2\n",
    "ubound = 2 * cheeger_chi / max_degree\n",
    "# index of the first strictly positive eigenvalue per graph: weight the positive\n",
    "# entries with a strictly decreasing ramp so that argmax picks the earliest one\n",
    "nnz_inds = eigs > 0\n",
    "ramp = torch.arange(start=eigs.shape[1], end=0, step=-1, device=adj.device, dtype=adj.dtype)\n",
    "nnz_srtd = nnz_inds * ramp\n",
    "first_nnz_inds = torch.argmax(nnz_srtd, 1, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "later_nnz_inds = nnz_inds.scatter(1, first_nnz_inds, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "valencies = torch.tensor([4, 3, 2, 1], device=x.device, dtype=x.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, A = x[0].clone(), adj[0].clone()\n",
    "atoms_Hs = ~torch.cat([torch.any(X[i] > 0.5).ravel() for i in range(X.shape[0])])\n",
    "atoms_others = ~atoms_Hs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X[atoms_Hs] = 0\n",
    "smpld_x = torch.zeros_like(X)\n",
    "smpld_x[atoms_others, X.argmax(dim=1)[atoms_others]] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smpld_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.tensor([[-0.1300, -0.1300, -0.1300,  1.2660, -0.1300, -0.1300, -0.1300, -0.1300,\n",
    "          0.6437],\n",
    "        [-0.2377, -0.2377, -0.2377,  1.1238,  0.3107,  0.3053, -0.2377,  0.5324,\n",
    "          0.1868],\n",
    "        [-0.2104, -0.2104, -0.2104,  1.8970, -0.2104,  0.4229,  1.1051, -0.2104,\n",
    "         -0.2104],\n",
    "        [ 0.8530,  0.8186,  1.5644, -0.5429,  0.0145, -0.5226,  1.2761, -0.5429,\n",
    "         -0.5429],\n",
    "        [-0.1232,  0.4252, -0.1232,  0.4342, -0.1232, -0.1232, -0.1232,  0.4771,\n",
    "          2.4699],\n",
    "        [ 0.0000,  0.5430,  0.6333,  0.0203,  0.0000,  0.0000,  0.4612,  0.0000,\n",
    "          0.1132],\n",
    "        [-0.2882, -0.2882,  1.0272,  1.5308, -0.2882,  0.1730, -0.2882,  0.4321,\n",
    "          0.0529],\n",
    "        [ 0.0000,  0.7701,  0.0000,  0.0000,  0.6003,  0.0000,  0.7203,  0.0000,\n",
    "          0.0000],\n",
    "        [ 0.4130,  0.0639, -0.3606, -0.3606,  2.2324, -0.2474, -0.0195, -0.3606,\n",
    "         -0.3606]], device='cuda:0')\n",
    "\n",
    "smpld_x = torch.tensor([[0., 0., 0., 1.],\n",
    "        [0., 0., 1., 0.],\n",
    "        [0., 1., 0., 0.],\n",
    "        [0., 0., 1., 0.],\n",
    "        [1., 0., 0., 0.],\n",
    "        [1., 0., 0., 0.],\n",
    "        [0., 1., 0., 0.],\n",
    "        [0., 0., 0., 1.],\n",
    "        [0., 0., 0., 1.]], device='cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "atom_inds = torch.arange(A.shape[0])\n",
    "atom_inds = atom_inds[torch.randperm(A.shape[0])]\n",
    "satisfied = False\n",
    "valencies = torch.tensor([4., 3., 2., 1.], device='cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedily round the real-valued matrix A into an integer matrix smpld_a (entries\n",
    "# are floor(A[j, k])) whose row sums respect each sampled atom's valency\n",
    "# (smpld_x[j] @ valencies). Atoms are visited in a random order; each one takes\n",
    "# entries in decreasing order of the fractional part of A[j] (a greedy knapsack).\n",
    "# Entries are written symmetrically, but the receiving atom's remaining valency is\n",
    "# not checked, so an atom may be over-filled by earlier atoms; in that case the\n",
    "# whole pass is retried with a new random order, up to max_attempts times.\n",
    "# TODO: how to maintain undirected bonds and still remove atoms_Hs?\n",
    "max_attempts = 1000\n",
    "for attempt in range(max_attempts):\n",
    "    satisfied = True\n",
    "    smpld_a = torch.zeros_like(A)\n",
    "    atom_inds = atom_inds[torch.randperm(len(atom_inds))]\n",
    "    for j in atom_inds.tolist():\n",
    "        valency_j = smpld_x[j].dot(valencies)\n",
    "        # knapsack: maximize sum_k a_probs[k] such that sum_k a_vals[k] <= valency_j\n",
    "        a_vals, a_probs = torch.floor(A[j]), torch.frac(A[j])\n",
    "        current_valency = smpld_a[j].sum()\n",
    "        if current_valency > valency_j:\n",
    "            satisfied = False\n",
    "            break\n",
    "        _, psorted_nodes = torch.sort(a_probs, descending=True)\n",
    "        selected_nodes = []\n",
    "        for node in psorted_nodes.tolist():\n",
    "            # skip self-bonds and atoms already bonded to j: re-selecting them would\n",
    "            # double-count their bond order and overwrite the existing entry\n",
    "            eligible = node != j and smpld_a[j, node] == 0 and a_vals[node] > 0\n",
    "            if eligible and current_valency + a_vals[node] <= valency_j:\n",
    "                selected_nodes.append(node)\n",
    "                current_valency = current_valency + a_vals[node]\n",
    "            if (current_valency == valency_j) or (a_probs[node] == 0):\n",
    "                break\n",
    "        smpld_a[j, selected_nodes] = a_vals[selected_nodes]\n",
    "        smpld_a[selected_nodes, j] = a_vals[selected_nodes]  # undirected\n",
    "    if satisfied:\n",
    "        break\n",
    "else:\n",
    "    # a bounded retry replaces `while not satisfied`, which could spin forever and\n",
    "    # did nothing at all when the cell was re-run with `satisfied` left True\n",
    "    raise RuntimeError(f\"valency rounding not satisfied after {max_attempts} attempts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smpld_a.sum(dim=1), smpld_x @ valencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A[atoms_Hs, :] = 0\n",
    "A[:, atoms_Hs] = 0\n",
    "smpld_a = torch.zeros_like(A)\n",
    "j = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valency_j = smpld_x[j].dot(valencies)\n",
    "a_vals, a_probs = torch.floor(A[j]), torch.frac(A[j])\n",
    "_, psorted_nodes = torch.sort(a_probs, descending=True)\n",
    "# drop j itself and nodes already bonded to j. The mask must be evaluated in sorted\n",
    "# order (smpld_a[j, psorted_nodes]); masking the sorted indices with smpld_a[j] == 0,\n",
    "# which is in node order, would keep the wrong nodes.\n",
    "psorted_nodes = psorted_nodes[(smpld_a[j, psorted_nodes] == 0) & (psorted_nodes != j)]\n",
    "selected_nodes, current_valency = [], smpld_a[j].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for node in psorted_nodes:\n",
    "    if current_valency + a_vals[node] <= valency_j:\n",
    "        selected_nodes.append(node.item())\n",
    "        current_valency += a_vals[node]\n",
    "    if (current_valency == valency_j) or (a_probs[node] == 0):\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smpld_a[j, selected_nodes] = a_vals[selected_nodes]\n",
    "smpld_a[selected_nodes, j] = a_vals[selected_nodes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "atom_inds = torch.where(atoms_others)[0]\n",
    "atom_inds[torch.randperm(len(atom_inds))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.rand(5, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x.reshape(-1).reshape(5,5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_flat = x.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, x_flat.reshape(5, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = torch.rand(10, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y[:5][:, :5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y[:5][:, :5] = x_flat.reshape(5, 5)\n",
    "y[:5][:, :5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = torch.rand(10, 10)\n",
    "y = z.clone()\n",
    "idx = torch.tensor([1, 4, 3, 2, 5])\n",
    "# y[idx][:, idx] = ... assigns into a temporary copy (advanced indexing copies),\n",
    "# leaving y unchanged; index both axes in a single subscript to write in place\n",
    "y[idx[:, None], idx] = x_flat.reshape(5, 5)\n",
    "y[idx[:, None], idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.zeros(5, 5)\n",
    "mask = torch.tensor([[1,2,3,4], [1,2,3,4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = torch.ones(4, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.scatter(x, 0, mask, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# index rows and columns separately (they broadcast to a 4x4 grid) so the indexed\n",
    "# region matches the 4x4 values; two 1-D index tensors would select only 4 entries\n",
    "torch.index_put(x, (mask[0][:, None], mask[1]), torch.ones(4, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xnum = x.numpy()\n",
    "xnum[mask[0]][:, mask[1]] = 1\n",
    "xnum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = torch.tensor([0, 1, 1, 1, 1], dtype=bool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.tensor([[ 0.0000,  0.0000,  0.0000,  1.1546,  0.2104,  0.0000],\n",
    "        [-0.1378, -0.1378,  2.2667, -0.1378, -0.1378, -0.1378],\n",
    "        [-0.3719,  2.0326, -0.3719,  1.1780,  1.1446, -0.3719],\n",
    "        [ 0.9730, -0.1816,  1.3684, -0.1816, -0.1816,  0.4872],\n",
    "        [ 0.2104,  0.0000,  1.5165,  0.0000,  0.0000,  0.9480],\n",
    "        [ 0.0000,  0.0000,  0.0000,  0.6688,  0.9480,  0.0000]],\n",
    "       device='cuda:2')\n",
    "\n",
    "n = a.shape[0]       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# row i holds ones over columns i*n .. (i+1)*n - 1, so that\n",
    "# M_val @ a.reshape(-1) sums the i-th length-n chunk (row i of a)\n",
    "M_val = torch.eye(n, dtype=a.dtype, device=a.device).repeat_interleave(n, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "M_val @ a.reshape(-1), a.sum(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "44dc670dcdd6ffb1ba23034ae072504999a2c20bd6cc686fd82920ca8c3f3b47"
  },
  "kernelspec": {
   "display_name": "Python 3.7.15 64-bit ('moltemp': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
