{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "from qm9.data.dataset_class import ProcessedDataset\n",
    "from qm9.data.prepare import prepare_dataset\n",
    "from qm9.data.utils import _get_species"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "datafiles = prepare_dataset(datadir=\"/home/gongjingjing/tmp/data\",\n",
    "                            dataset='qm9',\n",
    "                            subset=None,\n",
    "                            splits=None,\n",
    "                            force_download=False)\n",
    "\n",
    "datasets = {}\n",
    "for split, datafile in datafiles.items():\n",
    "    with np.load(datafile) as f:\n",
    "        datasets[split] = {\n",
    "            key: val\n",
    "            for key, val in f.items()\n",
    "        }\n",
    "\n",
    "keys = [list(data.keys()) for data in datasets.values()]\n",
    "assert all([key == keys[0]\n",
    "            for key in keys]), 'Datasets must have same set of keys!'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.2698136e-02,  1.0858041e+00,  8.0009960e-03],\n",
       "       [ 2.1504159e-03, -6.0313176e-03,  1.9761203e-03],\n",
       "       [ 1.0117308e+00,  1.4637512e+00,  2.7657481e-04],\n",
       "       [-5.4081506e-01,  1.4475266e+00, -8.7664372e-01],\n",
       "       [-5.2381361e-01,  1.4379326e+00,  9.0639728e-01]], dtype=float32)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"][\"positions\"][0][:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[6 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [7 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n",
      "num_atoms (100000,)\n",
      "charges (100000, 29)\n",
      "positions (100000, 29, 3)\n",
      "index (100000,)\n",
      "A (100000,)\n",
      "B (100000,)\n",
      "C (100000,)\n",
      "mu (100000,)\n",
      "alpha (100000,)\n",
      "homo (100000,)\n",
      "lumo (100000,)\n",
      "gap (100000,)\n",
      "r2 (100000,)\n",
      "zpve (100000,)\n",
      "U0 (100000,)\n",
      "U (100000,)\n",
      "H (100000,)\n",
      "G (100000,)\n",
      "Cv (100000,)\n",
      "omega1 (100000,)\n",
      "zpve_thermo (100000,)\n",
      "U0_thermo (100000,)\n",
      "U_thermo (100000,)\n",
      "H_thermo (100000,)\n",
      "G_thermo (100000,)\n",
      "Cv_thermo (100000,)\n"
     ]
    }
   ],
   "source": [
    "print(datasets[\"train\"]['charges'][:2])\n",
    "for k in datasets[\"train\"]:\n",
    "    print(k, datasets[\"train\"][k].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1786773/3508418254.py:7: UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).\n",
      "  row = torch.range(0, n_nodes-1, dtype=torch.long).reshape(1, -1, 1).repeat(1, 1, n_nodes)\n",
      "/tmp/ipykernel_1786773/3508418254.py:8: UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).\n",
      "  col = torch.range(0, n_nodes-1, dtype=torch.long).reshape(1, 1, -1).repeat(1, n_nodes, 1)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],\n",
       "        [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_ajacency_matrix(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import QM9\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.data import Batch\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.data.separate import separate\n",
    "import torch\n",
    "\n",
    "def make_adjacency_matrix(n_nodes):\n",
    "    row = torch.arange(0, n_nodes, dtype=torch.long).reshape(1, -1, 1).repeat(1, 1, n_nodes)\n",
    "    col = torch.arange(0, n_nodes, dtype=torch.long).reshape(1, 1, -1).repeat(1, n_nodes, 1)\n",
    "    full_adj = torch.concat([row, col], dim=0).reshape(2, -1)\n",
    "    diag_bool = torch.eye(n_nodes, dtype=torch.bool).reshape(-1)\n",
    "    return full_adj[:, ~diag_bool]\n",
    "\n",
    "def transform(data):\n",
    "    data.x = data.x[:, :6]\n",
    "    data.edge_index = make_adjacency_matrix(data.x.shape[0])\n",
    "    data.edge_attr = None\n",
    "    data.zx = torch.randn_like(data.x)\n",
    "    data.zpos = torch.randn_like(data.pos)\n",
    "    data.y = None\n",
    "    data.z = None\n",
    "    return data\n",
    "\n",
    "ds = QM9(root=\"/home/gongjingjing/tmp/data/qm9\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(x=[33, 6], edge_index=[2, 512], pos=[33, 3], idx=[2], name=[2], zx=[33, 6], zpos=[33, 3], batch=[33], ptr=[3])\n",
      "defaultdict(<class 'dict'>, {'x': tensor([ 0, 17, 33]), 'edge_index': tensor([  0, 272, 512]), 'pos': tensor([ 0, 17, 33]), 'idx': tensor([0, 1, 2]), 'name': tensor([0, 1, 2]), 'zx': tensor([ 0, 17, 33]), 'zpos': tensor([ 0, 17, 33])})\n",
      "defaultdict(<class 'dict'>, {'x': tensor([0, 0]), 'edge_index': tensor([ 0, 17]), 'pos': tensor([0, 0]), 'idx': tensor([0, 0]), 'name': None, 'zx': tensor([0, 0]), 'zpos': tensor([0, 0])})\n",
      "tensor([ 0, 17, 33])\n",
      "DataBatch(x=[36, 6], edge_index=[2, 630], pos=[36, 3], idx=[2], name=[2], zx=[36, 6], zpos=[36, 3], batch=[36], ptr=[3])\n",
      "defaultdict(<class 'dict'>, {'x': tensor([ 0, 15, 36]), 'edge_index': tensor([  0, 210, 630]), 'pos': tensor([ 0, 15, 36]), 'idx': tensor([0, 1, 2]), 'name': tensor([0, 1, 2]), 'zx': tensor([ 0, 15, 36]), 'zpos': tensor([ 0, 15, 36])})\n",
      "defaultdict(<class 'dict'>, {'x': tensor([0, 0]), 'edge_index': tensor([ 0, 15]), 'pos': tensor([0, 0]), 'idx': tensor([0, 0]), 'name': None, 'zx': tensor([0, 0]), 'zpos': tensor([0, 0])})\n",
      "tensor([ 0, 15, 36])\n",
      "DataBatch(x=[29, 6], edge_index=[2, 396], pos=[29, 3], idx=[2], name=[2], zx=[29, 6], zpos=[29, 3], batch=[29], ptr=[3])\n",
      "defaultdict(<class 'dict'>, {'x': tensor([ 0, 16, 29]), 'edge_index': tensor([  0, 240, 396]), 'pos': tensor([ 0, 16, 29]), 'idx': tensor([0, 1, 2]), 'name': tensor([0, 1, 2]), 'zx': tensor([ 0, 16, 29]), 'zpos': tensor([ 0, 16, 29])})\n",
      "defaultdict(<class 'dict'>, {'x': tensor([0, 0]), 'edge_index': tensor([ 0, 16]), 'pos': tensor([0, 0]), 'idx': tensor([0, 0]), 'name': None, 'zx': tensor([0, 0]), 'zpos': tensor([0, 0])})\n",
      "tensor([ 0, 16, 29])\n",
      "DataBatch(x=[37, 6], edge_index=[2, 688], pos=[37, 3], idx=[2], name=[2], zx=[37, 6], zpos=[37, 3], batch=[37], ptr=[3])\n",
      "defaultdict(<class 'dict'>, {'x': tensor([ 0, 23, 37]), 'edge_index': tensor([  0, 506, 688]), 'pos': tensor([ 0, 23, 37]), 'idx': tensor([0, 1, 2]), 'name': tensor([0, 1, 2]), 'zx': tensor([ 0, 23, 37]), 'zpos': tensor([ 0, 23, 37])})\n",
      "defaultdict(<class 'dict'>, {'x': tensor([0, 0]), 'edge_index': tensor([ 0, 23]), 'pos': tensor([0, 0]), 'idx': tensor([0, 0]), 'name': None, 'zx': tensor([0, 0]), 'zpos': tensor([0, 0])})\n",
      "tensor([ 0, 23, 37])\n",
      "DataBatch(x=[34, 6], edge_index=[2, 546], pos=[34, 3], idx=[2], name=[2], zx=[34, 6], zpos=[34, 3], batch=[34], ptr=[3])\n",
      "defaultdict(<class 'dict'>, {'x': tensor([ 0, 18, 34]), 'edge_index': tensor([  0, 306, 546]), 'pos': tensor([ 0, 18, 34]), 'idx': tensor([0, 1, 2]), 'name': tensor([0, 1, 2]), 'zx': tensor([ 0, 18, 34]), 'zpos': tensor([ 0, 18, 34])})\n",
      "defaultdict(<class 'dict'>, {'x': tensor([0, 0]), 'edge_index': tensor([ 0, 18]), 'pos': tensor([0, 0]), 'idx': tensor([0, 0]), 'name': None, 'zx': tensor([0, 0]), 'zpos': tensor([0, 0])})\n",
      "tensor([ 0, 18, 34])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gongjingjing/.local/lib/python3.8/site-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    }
   ],
   "source": [
    "_ds = DataLoader(ds, batch_size=2, shuffle=True)\n",
    "for _i, _d in enumerate(_ds):\n",
    "    if _i > 4:\n",
    "        break\n",
    "    # separate(_d)\n",
    "    print(_d)\n",
    "    print(_d._slice_dict)\n",
    "    print(_d._inc_dict)\n",
    "    print(_d.ptr)\n",
    "    # print(_d.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0,  1],\n",
       "         [ 5,  6],\n",
       "         [10, 11],\n",
       "         [15, 16]]),\n",
       " tensor([[ 2,  3,  4],\n",
       "         [ 7,  8,  9],\n",
       "         [12, 13, 14],\n",
       "         [17, 18, 19]]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "ch = torch.arange(20).reshape(4, 5)\n",
    "torch.split(ch, [2, 3], dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100000, 29])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"][\"charges\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "974"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "17*16+(44-17)*(44-17-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
