{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os.path as osp\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from joblib.parallel import Parallel, delayed\n",
    "from matplotlib import pyplot as plt\n",
    "from copy import deepcopy\n",
    "\n",
    "from molecules.parse_sdf import get_num_atoms_from_smiles\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed of the random generator used when sampling the split.\n",
    "SEED = 42\n",
    "# Base directory for the input dataset and the saved split file.\n",
    "ROOT = \".\"\n",
    "# CSV file the SMILES strings are read from (only its `smiles` column is loaded).\n",
    "DATASET = osp.join(ROOT, \"pcqm4m-v2-geometry-normed.csv.gz\")\n",
    "# Fractions of the sampled molecules assigned to validation and test;\n",
    "# the remainder (1 - VAL_RATIO - TEST_RATIO) goes to training.\n",
    "VAL_RATIO = 0.1\n",
    "TEST_RATIO = 0.1\n",
    "# Size threshold: molecules with this many atoms or more are never sampled for\n",
    "# training and form the validation/test pool instead (the bound is exclusive).\n",
    "MAX_ATOMS_IN_TRAIN = 15\n",
    "# Path of the saved split dict (a file path, despite the `_DIR` suffix).\n",
    "NUMATOMS_500k_DIR = osp.join(ROOT, \"pcqm500k_num_atoms_split_dict.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0          O=C1[N]c2ccncc2[CH][C@@H]1c1ccc(cc1)C\n",
       "1                  COc1cc(OC)ccc1/C=C/N(C(=O)C)C\n",
       "2                    C=CCN(C(=O)C)/C=C/c1ccccc1C\n",
       "3                    C=CCN(C(=O)C)/C=C/c1ccccc1F\n",
       "4                   C=CCN(C(=O)C)/C=C/c1ccccc1Cl\n",
       "                           ...                  \n",
       "3378521    Cc1ccc(c(c1)C)N[C@H](/C(=N\\C1CC1)/O)C\n",
       "3378522     C[C@@H](/C(=N\\C1CC1)/O)Nc1cccc(c1C)C\n",
       "3378523     C[C@H](/C(=N\\C(=N)O)/O)Nc1cccc(c1C)C\n",
       "3378524    C[C@@H](/C(=N\\C(=N)O)/O)Nc1cccc(c1C)C\n",
       "3378525              CCOc1ccccc1NC/C(=N\\C1CC1)/O\n",
       "Name: smiles, Length: 3378526, dtype: object"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load only the SMILES column; nrows=None reads every row of the dataset.\n",
    "smiles = pd.read_csv(DATASET, nrows=None, usecols=[\"smiles\"])[\"smiles\"]\n",
    "smiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3378526/3378526 [06:08<00:00, 9156.81it/s] \n"
     ]
    }
   ],
   "source": [
    "# Atom count of every molecule, computed in parallel with joblib (n_jobs=-1).\n",
    "# The stored run took about 6 minutes for the ~3.4M SMILES strings.\n",
    "# NOTE(review): tqdm wraps the input generator, so the bar presumably tracks\n",
    "# task dispatch rather than task completion -- confirm.\n",
    "num_atoms = Parallel(n_jobs=-1)(delayed(get_num_atoms_from_smiles)(s) for s in tqdm(smiles))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_numatoms_split_500k(num_atoms_list, val_ratio, test_ratio, max_atoms_in_train,\n",
    "                               n_total=int(5e5), seed=None, save_path=None):\n",
    "    \"\"\" Create split by the size of molecules, testing on the largest ones.\n",
    "\n",
    "    Molecules are sampled without replacement, each with a probability inversely\n",
    "    proportional to the number of molecules that share its atom count. Train\n",
    "    indices are drawn only from molecules with fewer than `max_atoms_in_train`\n",
    "    atoms, validation and test indices only from molecules with at least that\n",
    "    many. The resulting dict {'train', 'val', 'test'} of index arrays is saved\n",
    "    with `torch.save`.\n",
    "\n",
    "    Args:\n",
    "        num_atoms_list: List with molecule size per each graph.\n",
    "        val_ratio: Fraction of `n_total` assigned to the validation set.\n",
    "        test_ratio: Fraction of `n_total` assigned to the test set.\n",
    "        max_atoms_in_train: Exclusive size bound; molecules with this many atoms\n",
    "            or more are never sampled for training.\n",
    "        n_total: Total number of molecules sampled over the three sets.\n",
    "        seed: Seed of the random generator; defaults to the notebook-level SEED.\n",
    "        save_path: Where the split dict is saved; defaults to NUMATOMS_500k_DIR.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If too few molecules lie on either side of the size bound to\n",
    "            sample the requested set sizes, or if the sampled splits overlap.\n",
    "    \"\"\"\n",
    "    N = int(n_total)\n",
    "    train_ratio = 1 - val_ratio - test_ratio\n",
    "    N_train = int(N * train_ratio)\n",
    "    N_val = int(N * val_ratio)\n",
    "    N_test = N - N_train - N_val\n",
    "    rng = np.random.default_rng(seed=SEED if seed is None else seed)\n",
    "    if save_path is None:\n",
    "        # NUMATOMS_500k_DIR already includes ROOT; joining it with ROOT a second\n",
    "        # time would save to a different file than the one loaded afterwards.\n",
    "        save_path = NUMATOMS_500k_DIR\n",
    "\n",
    "    num_atoms_uniq, num_atoms_inv, num_atoms_count = np.unique(\n",
    "                num_atoms_list, return_inverse=True, return_counts=True)\n",
    "    # Per-molecule weight: inverse of the number of molecules with the same size.\n",
    "    weights = num_atoms_count.astype(float)[num_atoms_inv] ** -1\n",
    "\n",
    "    # Train pool: zero out every molecule at or above the size bound.\n",
    "    weights_train = weights.copy()\n",
    "    weights_train[(num_atoms_uniq >= max_atoms_in_train)[num_atoms_inv]] = 0\n",
    "    # Val/test pool: zero out every molecule below the size bound.\n",
    "    weights_valtest = weights.copy()\n",
    "    weights_valtest[(num_atoms_uniq < max_atoms_in_train)[num_atoms_inv]] = 0\n",
    "\n",
    "    # Sampling without replacement needs at least as many candidates as draws,\n",
    "    # and a non-empty pool so that the normalization below is well defined.\n",
    "    pools = ((\"train\", weights_train, N_train), (\"val/test\", weights_valtest, N_val + N_test))\n",
    "    for pool_name, pool_weights, needed in pools:\n",
    "        available = int((pool_weights > 0).sum())\n",
    "        if available < max(needed, 1):\n",
    "            raise ValueError(\n",
    "                f\"Cannot sample {needed} {pool_name} molecules without replacement: \"\n",
    "                f\"only {available} molecules fall on that side of \"\n",
    "                f\"max_atoms_in_train={max_atoms_in_train}\"\n",
    "            )\n",
    "    weights_train /= np.sum(weights_train)\n",
    "    weights_valtest /= np.sum(weights_valtest)\n",
    "\n",
    "    # Draw the train set from the small molecules, then draw validation and test\n",
    "    # together from the large ones; the draw order is random, so cutting the\n",
    "    # joint sample in two yields a random val/test partition.\n",
    "    all_ind = np.arange(len(num_atoms_list))\n",
    "    train_ind = rng.choice(all_ind, size=N_train, replace=False, p=weights_train)\n",
    "    val_test_ind = rng.choice(all_ind, size=N_val+N_test, replace=False, p=weights_valtest)\n",
    "    val_ind = val_test_ind[:N_val]\n",
    "    test_ind = val_test_ind[N_val:]\n",
    "    # Called directly rather than inside `assert`, so the validation also runs\n",
    "    # when Python is started with -O.\n",
    "    check_splits(N, [train_ind, val_ind, test_ind], [train_ratio, val_ratio, test_ratio])\n",
    "\n",
    "    size_split = {'train': train_ind, 'val': val_ind, 'test': test_ind}\n",
    "    torch.save(size_split, save_path)\n",
    "\n",
    "def check_splits(N, splits, ratios):\n",
    "    \"\"\" Check whether splits intersect and raise error if so.\n",
    "    \"\"\"\n",
    "    assert sum([len(split) for split in splits]) == N\n",
    "    for ii, split in enumerate(splits):\n",
    "        true_ratio = len(split) / N\n",
    "        assert abs(true_ratio - ratios[ii]) < 3/N\n",
    "    for i in range(len(splits) - 1):\n",
    "        for j in range(i + 1, len(splits)):\n",
    "            n_intersect = len(set(splits[i]) & set(splits[j]))\n",
    "            if n_intersect != 0:\n",
    "                raise ValueError(\n",
    "                    f\"Splits must not have intersecting indices: \"\n",
    "                    f\"split #{i} (n = {len(splits[i])}) and \"\n",
    "                    f\"split #{j} (n = {len(splits[j])}) have \"\n",
    "                    f\"{n_intersect} intersecting indices\"\n",
    "                )\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample the size-based split with the notebook-level settings and save it to disk.\n",
    "create_numatoms_split_500k(num_atoms, VAL_RATIO, TEST_RATIO, MAX_ATOMS_IN_TRAIN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved split dict from disk to inspect it below.\n",
    "# NOTE(review): the dict holds numpy index arrays; newer PyTorch releases default\n",
    "# torch.load to weights_only=True, which may refuse them -- confirm against the\n",
    "# installed version and pass weights_only=False if needed.\n",
    "num_atoms_split = torch.load(NUMATOMS_500k_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x277de21a880>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAYhElEQVR4nO3df5BV5Z3n8fdXYED8QZQfhtAmTVYmUSDB0Lqk3JrScaMYJ8FU1PQYVypDSSpxEpOKWWG2kjFboQr/WE2crKZIdEUjUYqMIzXKRkUI2Sqi2zjsgKIFRqItBFoEhqTEEfLdP+4DNm1z6F9wG/v9qrp1zv3e85x+zuXSn3rOc+7pyEwkSTqcE+rdAUlS/2ZQSJIqGRSSpEoGhSSpkkEhSao0uN4d6KlRo0ZlY2NjvbshSceVNWvWvJ6Zo7vT5rgNisbGRlpaWurdDUk6rkTE77rbxlNPkqRKBoUkqZJBIUmqdNzOUUhST7z99tu0trayd+/eenflqBo2bBgNDQ0MGTKk1/vqUlBExPuAnwKTgAT+BngReAhoBDYDV2fmzrL9XGAWsB/4emb+stSnAvcCJwKPATdmZkbEUOA+YCqwA/hCZm7u9dFJUgetra2ccsopNDY2EhH17s5RkZns2LGD1tZWxo8f3+v9dfXU0w+B/52ZHwU+DmwA5gDLM3MCsLw8JyLOAZqBicB04M6IGFT2cxcwG5hQHtNLfRawMzPPAm4Hbu3lcUlSp/bu3cvIkSPfsyEBEBGMHDmyz0ZNRwyKiDgV+AvgboDM/PfM3AXMABaWzRYCV5T1GcCDmflWZr4MbALOj4ixwKmZuTprt6y9r0ObA/taAlwc7+V/RUl1NRB+vfTlMXZlRPFhoA34XxHxLxHx04g4CTgjM7cClOWYsv044NV27VtLbVxZ71g/pE1m7gN2AyM7diQiZkdES0S0tLW1dfEQJUm90ZU5isHAJ4CvZebTEfFDymmmw+gsxrKiXtXm0ELmAmABQFNTk39IQ1KvNc55tE/3t3n+5ZWv79q1i0WLFvHVr361W/v99Kc/zaJFi3jf+97Xi971TFeCohVozcyny/Ml1IJiW0SMzcyt5bTS9nbbn9mufQOwpdQbOqm3b9MaEYOBEcAbPTgeSf1cb38xH+kXcX+3a9cu7rzzzncFxf79+xk0aNBhWsFjjz12tLt2WEc89ZSZvwdejYiPlNLFwPPAUmBmqc0EHinrS4HmiBgaEeOpTVo/U05P7YmIaWX+4boObQ7s60rgqfRP70l6D5ozZw4vvfQSU6ZM4bzzzuOiiy7immuuYfLkyQBcccUVTJ06lYkTJ7JgwYKD7RobG3n99dfZvHkzZ599Ntdffz0TJ07kkksu4c033zyqfe7q9yi+BjwQEX8G/Bb4ErWQWRwRs4BXgKsAMvO5iFhMLUz2ATdk5v6yn6/wzuWxy8oDahPl90fEJmojieZeHpck9Uvz589n/fr1rF27lpUrV3L55Zezfv36g5ex3nPPPZx++um8+eabnHfeeXz+859n5MhDp2w3btzIz3/+c37yk59w9dVX84tf/IJrr732qPW5S0GRmWuBpk5euvgw288D5nVSb6H2XYyO9b2UoJGkgeT8888/5LsOd9xxBw8//DAAr776Khs3bnxXUIwfP54pU6YAMHXqVDZv3nxU++g3syV1S19P/g50J5100sH1lStX8uSTT7J69WqGDx/OhRde2Ol3IYYOHXpwfdCgQUf91JP3epKkY+iUU05hz549nb62e/duTjvtNIYPH84LL7zAb37zm2Pcu845opA0oB3rq6hGjhzJBRdcwKRJkzjxxBM544wzDr42ffp0fvzjH/Oxj32Mj3zkI0ybNu2Y9u1wDApJOsYWLVrUaX3o0KEsW7as09cOzEOMGjWK9evXH6zfdNNNfd6/jjz1JEmqZFBIkioZFJKkSs5RSAOMl7equxxRSJIqGRSSpEqeepI0sN0yoo/3
t7tPd3fyySfzhz/8oU/32V2OKCRJlRxRSNIxdPPNN/OhD33o4N+juOWWW4gIVq1axc6dO3n77bf5/ve/z4wZM+rc03c4opCkY6i5uZmHHnro4PPFixfzpS99iYcffphnn32WFStW8K1vfYv+9Cd5HFFI0jF07rnnsn37drZs2UJbWxunnXYaY8eO5Zvf/CarVq3ihBNO4LXXXmPbtm28//3vr3d3AYNCOu74PYjj35VXXsmSJUv4/e9/T3NzMw888ABtbW2sWbOGIUOG0NjY2OntxevFoJCkY6y5uZnrr7+e119/nV/96lcsXryYMWPGMGTIEFasWMHvfve7enfxEAaFpIGtjy9n7YqJEyeyZ88exo0bx9ixY/niF7/IZz7zGZqampgyZQof/ehHj3mfqhgUklQH69atO7g+atQoVq9e3el29f4OBXjVkyTpCAwKSVIlg0KSVMmgkCRVMigkSZUMCklSJS+PlTSgTV44uU/3t27musrXd+3axaJFiw7eFLA7fvCDHzB79myGDx/e0+71SJeCIiI2A3uA/cC+zGyKiNOBh4BGYDNwdWbuLNvPBWaV7b+emb8s9anAvcCJwGPAjZmZETEUuA+YCuwAvpCZm/vkCKV+xltwDGy7du3izjvv7HFQXHvttf0zKIqLMvP1ds/nAMszc35EzCnPb46Ic4BmYCLwAeDJiPjzzNwP3AXMBn5DLSimA8uohcrOzDwrIpqBW4Ev9PLYJKnfmTNnDi+99BJTpkzhU5/6FGPGjGHx4sW89dZbfO5zn+N73/sef/zjH7n66qtpbW1l//79fOc732Hbtm1s2bKFiy66iFGjRrFixYpj1ufenHqaAVxY1hcCK4GbS/3BzHwLeDkiNgHnl1HJqZm5GiAi7gOuoBYUM4Bbyr6WAD+KiMj+dJ9dSeoD8+fPZ/369axdu5bHH3+cJUuW8Mwzz5CZfPazn2XVqlW0tbXxgQ98gEcfrY0+d+/ezYgRI7jttttYsWIFo0aNOqZ97upkdgKPR8SaiJhdamdk5laAshxT6uOAV9u1bS21cWW9Y/2QNpm5D9gNjOzYiYiYHREtEdHS1tbWxa5LUv/0+OOP8/jjj3PuuefyiU98ghdeeIGNGzcyefJknnzySW6++WZ+/etfM2JEH/+51m7q6ojigszcEhFjgCci4oWKbaOTWlbUq9ocWshcACwAaGpqcrQh6biWmcydO5cvf/nL73ptzZo1PPbYY8ydO5dLLrmE7373u3XoYU2XRhSZuaUstwMPA+cD2yJiLEBZbi+btwJntmveAGwp9YZO6oe0iYjBwAjgje4fjiT1b6eccgp79uwB4NJLL+Wee+45eOO/11577eAfNRo+fDjXXnstN910E88+++y72h5LRxxRRMRJwAmZuaesXwL8d2ApMBOYX5aPlCZLgUURcRu1yewJwDOZuT8i9kTENOBp4DrgH9q1mQmsBq4EnnJ+QtKxcKTLWfvayJEjueCCC5g0aRKXXXYZ11xzDZ/85CcBOPnkk/nZz37Gpk2b+Pa3v80JJ5zAkCFDuOuuuwCYPXs2l112GWPHjj2mk9lxpN/HEfFhaqMIqAXLosycFxEjgcXAB4FXgKsy843S5r8BfwPsA76RmctKvYl3Lo9dBnytXB47DLgfOJfaSKI5M39b1a+mpqZsaWnp/hFLdeblsb2zef7lvWq/YcMGzj777D7qTf/W2bFGxJrMbOrOfo44oii/sD/eSX0HcPFh2swD5nVSbwEmdVLfC1zVhf5Kko4xb+EhSapkUEgacAbCFGhfHqNBIWlAGTZsGDt27HhPh0VmsmPHDoYNG9Yn+/OmgJIGlIaGBlpbW3mvf2l32LBhNDQ0HHnDLjAoJA0oQ4YMYfz48fXuxnHFU0+SpEoGhSSpkkEhSapkUEiSKhkUkqRKBoUkqZKXx0rd5E39NNA4opAkVTIoJEmVDApJUiWDQpJUyaCQJFUyKCRJlQwKSVIlg0KSVMmgkCRVMigk
SZUMCklSJYNCklTJoJAkVepyUETEoIj4l4j45/L89Ih4IiI2luVp7badGxGbIuLFiLi0XX1qRKwrr90REVHqQyPioVJ/OiIa+/AYJUm90J0RxY3AhnbP5wDLM3MCsLw8JyLOAZqBicB04M6IGFTa3AXMBiaUx/RSnwXszMyzgNuBW3t0NJKkPteloIiIBuBy4KftyjOAhWV9IXBFu/qDmflWZr4MbALOj4ixwKmZuTozE7ivQ5sD+1oCXHxgtCFJqq+ujih+APxX4E/tamdk5laAshxT6uOAV9tt11pq48p6x/ohbTJzH7AbGNnVg5AkHT1HDIqI+Ctge2au6eI+OxsJZEW9qk3HvsyOiJaIaGlra+tidyRJvdGVEcUFwGcjYjPwIPCXEfEzYFs5nURZbi/btwJntmvfAGwp9YZO6oe0iYjBwAjgjY4dycwFmdmUmU2jR4/u0gFKknrniEGRmXMzsyEzG6lNUj+VmdcCS4GZZbOZwCNlfSnQXK5kGk9t0vqZcnpqT0RMK/MP13Voc2BfV5af8a4RhSTp2Bvci7bzgcURMQt4BbgKIDOfi4jFwPPAPuCGzNxf2nwFuBc4EVhWHgB3A/dHxCZqI4nmXvRLktSHuhUUmbkSWFnWdwAXH2a7ecC8TuotwKRO6nspQSNJ6l/8ZrYkqZJBIUmqZFBIkioZFJKkSgaFJKmSQSFJqmRQSJIqGRSSpEoGhSSpkkEhSarUm3s9ScelxjmP1rsL0nHFEYUkqZJBIUmqZFBIkioZFJKkSgaFJKmSQSFJqmRQSJIqGRSSpEoGhSSpkkEhSapkUEiSKhkUkqRKBoUkqZJBIUmqZFBIkiodMSgiYlhEPBMR/y8inouI75X66RHxRERsLMvT2rWZGxGbIuLFiLi0XX1qRKwrr90REVHqQyPioVJ/OiIaj8KxSpJ6oCsjireAv8zMjwNTgOkRMQ2YAyzPzAnA8vKciDgHaAYmAtOBOyNiUNnXXcBsYEJ5TC/1WcDOzDwLuB24tfeHJknqC0cMiqz5Q3k6pDwSmAEsLPWFwBVlfQbwYGa+lZkvA5uA8yNiLHBqZq7OzATu69DmwL6WABcfGG1IkuqrS3MUETEoItYC24EnMvNp4IzM3ApQlmPK5uOAV9s1by21cWW9Y/2QNpm5D9gNjOykH7MjoiUiWtra2rp0gJKk3ulSUGTm/sycAjRQGx1Mqti8s5FAVtSr2nTsx4LMbMrMptGjRx+h15KkvtCtq54ycxewktrcwrZyOomy3F42awXObNesAdhS6g2d1A9pExGDgRHAG93pmyTp6OjKVU+jI+J9Zf1E4D8DLwBLgZlls5nAI2V9KdBcrmQaT23S+plyempPREwr8w/XdWhzYF9XAk+VeQxJUp0N7sI2Y4GF5cqlE4DFmfnPEbEaWBwRs4BXgKsAMvO5iFgMPA/sA27IzP1lX18B7gVOBJaVB8DdwP0RsYnaSKK5Lw5OktR7RwyKzPxX4NxO6juAiw/TZh4wr5N6C/Cu+Y3M3EsJGklS/+I3syVJlQwKSVIlg0KSVMmgkCRVMigkSZUMCklSJYNCklTJoJAkVTIoJEmVDApJUiWDQpJUyaCQJFUyKCRJlbpym3FJes+YvHByXX/+upnr6vrze8KgkHRcqfcv+oHIU0+SpEoGhSSpkkEhSapkUEiSKhkUkqRKBoUkqZKXx+q40zjn0Xp3QRpQHFFIkioZFJKkSgaFJKmSQSFJqnTEoIiIMyNiRURsiIjnIuLGUj89Ip6IiI1leVq7NnMjYlNEvBgRl7arT42IdeW1OyIiSn1oRDxU6k9HRONROFZJUg90ZUSxD/hWZp4NTANuiIhzgDnA8sycACwvzymvNQMTgenAnRExqOzrLmA2MKE8ppf6LGBnZp4F3A7c2gfHJknqA0e8PDYztwJby/qeiNgAjANmABeWzRYCK4GbS/3BzHwLeDkiNgHnR8Rm4NTMXA0QEfcBVwDLSptbyr6WAD+KiMjM
7PURSlI/0tu739bjNuXdmqMop4TOBZ4GzighciBMxpTNxgGvtmvWWmrjynrH+iFtMnMfsBsY2cnPnx0RLRHR0tbW1p2uS5J6qMtBEREnA78AvpGZ/1a1aSe1rKhXtTm0kLkgM5sys2n06NFH6rIkqQ90KSgiYgi1kHggM/+xlLdFxNjy+lhge6m3Ame2a94AbCn1hk7qh7SJiMHACOCN7h6MJKnvdeWqpwDuBjZk5m3tXloKzCzrM4FH2tWby5VM46lNWj9TTk/tiYhpZZ/XdWhzYF9XAk85PyFJ/UNX7vV0AfBfgHURsbbU/g6YDyyOiFnAK8BVAJn5XEQsBp6ndsXUDZm5v7T7CnAvcCK1SexlpX43cH+Z+H6D2lVTkqR+oCtXPf0fOp9DALj4MG3mAfM6qbcAkzqp76UEjSSpf/Gb2ZKkSgaFJKmSQSFJqmRQSJIqGRSSpEoGhSSpkkEhSapkUEiSKhkUkqRKBoUkqZJBIUmqZFBIkioZFJKkSgaFJKmSQSFJqmRQSJIqGRSSpEoGhSSpkkEhSapkUEiSKhkUkqRKBoUkqZJBIUmqZFBIkioNrncHJKk71r38Sr270CuTx3+w3l3oNkcUkqRKRxxRRMQ9wF8B2zNzUqmdDjwENAKbgaszc2d5bS4wC9gPfD0zf1nqU4F7gROBx4AbMzMjYihwHzAV2AF8ITM399kRqt9pnPNovbsgqRu6MqK4F5jeoTYHWJ6ZE4Dl5TkRcQ7QDEwsbe6MiEGlzV3AbGBCeRzY5yxgZ2aeBdwO3NrTg5Ek9b0jjigyc1VENHYozwAuLOsLgZXAzaX+YGa+BbwcEZuA8yNiM3BqZq4GiIj7gCuAZaXNLWVfS4AfRURkZvb0oCSpvzoe51h6OkdxRmZuBSjLMaU+Dni13XatpTaurHesH9ImM/cBu4GRnf3QiJgdES0R0dLW1tbDrkuSuqOvJ7Ojk1pW1KvavLuYuSAzmzKzafTo0T3soiSpO3oaFNsiYixAWW4v9VbgzHbbNQBbSr2hk/ohbSJiMDACeKOH/ZIk9bGeBsVSYGZZnwk80q7eHBFDI2I8tUnrZ8rpqT0RMS0iAriuQ5sD+7oSeMr5CUnqP7pyeezPqU1cj4qIVuDvgfnA4oiYBbwCXAWQmc9FxGLgeWAfcENm7i+7+grvXB67rDwA7gbuLxPfb1C7akqS1E905aqnvz7MSxcfZvt5wLxO6i3ApE7qeylBI0nqf/xmtiSpkkEhSapkUEiSKhkUkqRKBoUkqZJBIUmqZFBIkioZFJKkSgaFJKmSQSFJqmRQSJIqGRSSpEoGhSSpkkEhSap0xNuMSx01znm03l2QdAw5opAkVTIoJEmVDApJUiWDQpJUyaCQJFUyKCRJlQwKSVIlv0cxAPk9CEnd4YhCklTJoJAkVTIoJEmV+s0cRURMB34IDAJ+mpnz69ylfss5BknHUr8YUUTEIOB/ApcB5wB/HRHn1LdXkiToPyOK84FNmflbgIh4EJgBPF/XXh0ljggkHU/6S1CMA15t97wV+I8dN4qI2cDs8vQPEfFiD3/eKOD1HraV719v+f71Qgz09+970ds9fKS7DfpLUHR25PmuQuYCYEGvf1hES2Y29XY/A5XvX+/4/vWO71/vRERLd9v0izkKaiOIM9s9bwC21KkvkqR2+ktQ/F9gQkSMj4g/A5qBpXXukySJfnLqKTP3RcTfAr+kdnnsPZn53FH8kb0+fTXA+f71ju9f7/j+9U6337/IfNdUgCRJB/WXU0+SpH7KoJAkVRpwQRER0yPixYjYFBFz6t2f401EbI6IdRGxtieX2Q00EXFPRGyPiPXtaqdHxBMRsbEsT6tnH/uzw7x/t0TEa+UzuDYiPl3PPvZXEXFmRKyIiA0R8VxE3Fjq3f78Daig8FYhfeaizJzitexdci8wvUNtDrA8MycAy8tzde5e3v3+AdxePoNTMvOxY9yn48U+4FuZeTYwDbih/L7r9udvQAUF7W4V
kpn/Dhy4VYh0VGTmKuCNDuUZwMKyvhC44lj26XhymPdPXZCZWzPz2bK+B9hA7S4Y3f78DbSg6OxWIePq1JfjVQKPR8SacksVdd8ZmbkVav+ZgTF17s/x6G8j4l/LqSlP3R1BRDQC5wJP04PP30ALii7dKkSVLsjMT1A7fXdDRPxFvTukAecu4D8AU4CtwP+oa2/6uYg4GfgF8I3M/Lee7GOgBYW3CumlzNxSltuBh6mdzlP3bIuIsQBlub3O/TmuZOa2zNyfmX8CfoKfwcOKiCHUQuKBzPzHUu7252+gBYW3CumFiDgpIk45sA5cAqyvbqVOLAVmlvWZwCN17Mtx58AvueJz+BnsVEQEcDewITNva/dStz9/A+6b2eVSuh/wzq1C5tW3R8ePiPgwtVEE1G7/ssj3r1pE/By4kNqtsbcBfw/8E7AY+CDwCnBVZjph24nDvH8XUjvtlMBm4MsHzrnrHRHxn4BfA+uAP5Xy31Gbp+jW52/ABYUkqXsG2qknSVI3GRSSpEoGhSSpkkEhSapkUEiSKhkUkqRKBoUkqdL/B0ZsgqdBANAPAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Atom counts of the molecules in each split; the array is built once, not per split.\n",
    "num_atoms_arr = np.array(num_atoms)\n",
    "sizes_per_split = [num_atoms_arr[idx] for idx in num_atoms_split.values()]\n",
    "\n",
    "# One unit-wide bin per atom count. np.arange excludes its stop value, so the edges\n",
    "# must run up to max + 1; otherwise the largest molecules fall outside the last bin.\n",
    "max_num_atoms = max((int(sizes.max()) for sizes in sizes_per_split if len(sizes)), default=0)\n",
    "bin_edges = np.arange(0, max_num_atoms + 2)\n",
    "\n",
    "plt.figure()\n",
    "# hist accepts datasets of different lengths, so no NaN padding is needed.\n",
    "plt.hist(sizes_per_split, stacked=True, bins=bin_edges, log=False, label=list(num_atoms_split.keys()))\n",
    "plt.xticks(np.arange(0, max_num_atoms + 5, 5))\n",
    "plt.xlabel(\"Number of atoms\")\n",
    "plt.ylabel(\"Number of molecules\")\n",
    "plt.title(\"Molecule size distribution per split\")\n",
    "# Trailing semicolon keeps the Legend repr out of the cell output.\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "96a7dde498302288fd342a0dd7e7c4e017c4083a8e5fac7570734c90ace78c7f"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 ('goli')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
