{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data2/haiyang/anaconda3/envs/QHBench/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, os.path.dirname(os.getcwd()))\n",
    "\n",
    "import torch\n",
    "from datasets import QH9Stable, QH9Dynamic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Here are the statistics of the datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hamiltonian_size(molecule_atoms, light_atom_orbitals=2, heavy_atom_orbitals=14):\n",
    "    \"\"\"Count the orbitals of one molecule from its atomic numbers.\n",
    "\n",
    "    Every atom with atomic number <= 2 contributes ``light_atom_orbitals``\n",
    "    and every other atom contributes ``heavy_atom_orbitals``.\n",
    "\n",
    "    Args:\n",
    "        molecule_atoms: atomic numbers of the molecule; assumed to be a\n",
    "            torch tensor (the elementwise comparisons are followed by ``.sum()``).\n",
    "        light_atom_orbitals: orbitals counted per atom with atomic number <= 2.\n",
    "        heavy_atom_orbitals: orbitals counted per atom with atomic number > 2.\n",
    "\n",
    "    Returns:\n",
    "        The total orbital count, as returned by the tensor reductions\n",
    "        (a 0-dim tensor when ``molecule_atoms`` is a torch tensor).\n",
    "    \"\"\"\n",
    "    # NOTE(review): the defaults 2 and 14 are the constants this notebook has\n",
    "    # always used; confirm they match the basis set of the stored Hamiltonians.\n",
    "    num_light_atoms = (molecule_atoms <= 2).sum()\n",
    "    num_heavy_atoms = (molecule_atoms > 2).sum()\n",
    "    return num_heavy_atoms * heavy_atom_orbitals + num_light_atoms * light_atom_orbitals\n",
    "\n",
    "\n",
    "def _summarize_values(values):\n",
    "    \"\"\"Return mean/min/max/median of ``values`` as a dict of Python floats.\"\"\"\n",
    "    if len(values) == 0:\n",
    "        # Tensor.min()/max() raise on empty input, so an empty split is\n",
    "        # reported as NaN for every statistic instead of crashing.\n",
    "        nan = float('nan')\n",
    "        return {'mean': nan, 'min': nan, 'max': nan, 'median': nan}\n",
    "    tensor = torch.tensor(values).float()\n",
    "    return {\n",
    "        'mean': tensor.mean().item(),\n",
    "        'min': tensor.min().item(),\n",
    "        'max': tensor.max().item(),\n",
    "        'median': tensor.median().item(),\n",
    "    }\n",
    "\n",
    "\n",
    "def get_dataset_statistic(ori_dataset):\n",
    "    \"\"\"Compute per-split summary statistics of a dataset.\n",
    "\n",
    "    For each of the 'train', 'val' and 'test' splits (selected with\n",
    "    ``ori_dataset[ori_dataset.<split>_mask]``) three per-sample quantities are\n",
    "    collected: ``data.num_nodes``, ``data.atoms.sum()`` and\n",
    "    ``get_hamiltonian_size(data.atoms)``. Each quantity is summarized by its\n",
    "    mean, min, max and median.\n",
    "\n",
    "    Args:\n",
    "        ori_dataset: dataset exposing ``train_mask``, ``val_mask`` and\n",
    "            ``test_mask`` attributes and supporting indexing by those masks.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping split name to a dict with the keys\n",
    "        ``num_node_*``, ``num_electronics_*`` and ``hamiltonian_size_*``\n",
    "        (``* in {mean, min, max, median}``), all Python floats. A split with\n",
    "        no samples yields NaN for every entry.\n",
    "    \"\"\"\n",
    "    statistic_info = {}\n",
    "    for split_name in ('train', 'val', 'test'):\n",
    "        dataset = ori_dataset[getattr(ori_dataset, f'{split_name}_mask')]\n",
    "\n",
    "        # One pass over the split gathers all three quantities.\n",
    "        per_quantity = {'num_node': [], 'num_electronics': [], 'hamiltonian_size': []}\n",
    "        for data in dataset:\n",
    "            per_quantity['num_node'].append(data.num_nodes)\n",
    "            per_quantity['num_electronics'].append(data.atoms.sum())\n",
    "            per_quantity['hamiltonian_size'].append(get_hamiltonian_size(data.atoms))\n",
    "\n",
    "        split_info = {}\n",
    "        for quantity, values in per_quantity.items():\n",
    "            for stat_name, stat_value in _summarize_values(values).items():\n",
    "                split_info[f'{quantity}_{stat_name}'] = stat_value\n",
    "        statistic_info[split_name] = split_info\n",
    "\n",
    "    return statistic_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QH9Stable with the 'random' split; the data root is the 'datasets' folder\n",
    "# next to the current working directory's parent (same parent as added to sys.path).\n",
    "dataset_stable_random = QH9Stable(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), split='random')\n",
    "dataset_stable_random_statistic = get_dataset_statistic(dataset_stable_random)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing...\n",
      "  0%|          | 113M/30.5G [00:19<22:44, 22.3MB/s]"
     ]
    }
   ],
   "source": [
    "# QH9Stable with the 'size_ood' split, rooted at <parent of cwd>/datasets.\n",
    "dataset_stable_ood = QH9Stable(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), split='size_ood')\n",
    "dataset_stable_ood_statistic = get_dataset_statistic(dataset_stable_ood)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QH9Dynamic, version '100k', 'geometry' split, rooted at <parent of cwd>/datasets.\n",
    "dataset_dynamic_geo_100k = QH9Dynamic(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), version='100k', split='geometry')\n",
    "dataset_dynamic_geo_100k_statistic = get_dataset_statistic(dataset_dynamic_geo_100k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QH9Dynamic, version '100k', 'mol' split, rooted at <parent of cwd>/datasets.\n",
    "dataset_dynamic_mol_100k = QH9Dynamic(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), version='100k', split='mol')\n",
    "dataset_dynamic_mol_100k_statistic = get_dataset_statistic(dataset_dynamic_mol_100k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': {'num_node_mean': 18.03936004638672,\n",
       "  'num_node_min': 7.0,\n",
       "  'num_node_max': 27.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.87992095947266,\n",
       "  'num_electronics_min': 24.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.5890655517578,\n",
       "  'hamiltonian_size_min': 54.0,\n",
       "  'hamiltonian_size_max': 162.0,\n",
       "  'hamiltonian_size_median': 144.0},\n",
       " 'val': {'num_node_mean': 18.03936004638672,\n",
       "  'num_node_min': 7.0,\n",
       "  'num_node_max': 27.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.87992095947266,\n",
       "  'num_electronics_min': 24.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.5890655517578,\n",
       "  'hamiltonian_size_min': 54.0,\n",
       "  'hamiltonian_size_max': 162.0,\n",
       "  'hamiltonian_size_median': 144.0},\n",
       " 'test': {'num_node_mean': 18.03936004638672,\n",
       "  'num_node_min': 7.0,\n",
       "  'num_node_max': 27.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.87992095947266,\n",
       "  'num_electronics_min': 24.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.5890655517578,\n",
       "  'hamiltonian_size_min': 54.0,\n",
       "  'hamiltonian_size_max': 162.0,\n",
       "  'hamiltonian_size_median': 144.0}}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# QH9Dynamic, version '300k', 'geometry' split, rooted at <parent of cwd>/datasets.\n",
    "dataset_dynamic_geo_300k = QH9Dynamic(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), version='300k', split='geometry')\n",
    "dataset_dynamic_geo_300k_statistic = get_dataset_statistic(dataset_dynamic_geo_300k)\n",
    "dataset_dynamic_geo_300k_statistic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': {'num_node_mean': 18.015846252441406,\n",
       "  'num_node_min': 7.0,\n",
       "  'num_node_max': 27.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.91242980957031,\n",
       "  'num_electronics_min': 24.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.58465576171875,\n",
       "  'hamiltonian_size_min': 54.0,\n",
       "  'hamiltonian_size_max': 162.0,\n",
       "  'hamiltonian_size_median': 144.0},\n",
       " 'val': {'num_node_mean': 18.153846740722656,\n",
       "  'num_node_min': 10.0,\n",
       "  'num_node_max': 25.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.71237182617188,\n",
       "  'num_electronics_min': 34.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.17726135253906,\n",
       "  'hamiltonian_size_min': 72.0,\n",
       "  'hamiltonian_size_max': 158.0,\n",
       "  'hamiltonian_size_median': 144.0},\n",
       " 'test': {'num_node_mean': 18.112957000732422,\n",
       "  'num_node_min': 9.0,\n",
       "  'num_node_max': 25.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.7873764038086,\n",
       "  'num_electronics_min': 50.0,\n",
       "  'num_electronics_max': 72.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 142.03321838378906,\n",
       "  'hamiltonian_size_min': 102.0,\n",
       "  'hamiltonian_size_max': 158.0,\n",
       "  'hamiltonian_size_median': 144.0}}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# QH9Dynamic, version '300k', 'mol' split, rooted at <parent of cwd>/datasets.\n",
    "dataset_dynamic_mol_300k = QH9Dynamic(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), version='300k', split='mol')\n",
    "dataset_dynamic_mol_300k_statistic = get_dataset_statistic(dataset_dynamic_mol_300k)\n",
    "dataset_dynamic_mol_300k_statistic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': {'num_node_mean': 18.023332595825195,\n",
       "  'num_node_min': 3.0,\n",
       "  'num_node_max': 29.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.89612579345703,\n",
       "  'num_electronics_min': 10.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.5970001220703,\n",
       "  'hamiltonian_size_min': 18.0,\n",
       "  'hamiltonian_size_max': 166.0,\n",
       "  'hamiltonian_size_median': 144.0},\n",
       " 'val': {'num_node_mean': 18.026752471923828,\n",
       "  'num_node_min': 6.0,\n",
       "  'num_node_max': 29.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.90185546875,\n",
       "  'num_electronics_min': 18.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.6219482421875,\n",
       "  'hamiltonian_size_min': 36.0,\n",
       "  'hamiltonian_size_max': 166.0,\n",
       "  'hamiltonian_size_median': 144.0},\n",
       " 'test': {'num_node_mean': 18.035158157348633,\n",
       "  'num_node_min': 4.0,\n",
       "  'num_node_max': 29.0,\n",
       "  'num_node_median': 18.0,\n",
       "  'num_electronics_mean': 65.8647232055664,\n",
       "  'num_electronics_min': 24.0,\n",
       "  'num_electronics_max': 74.0,\n",
       "  'num_electronics_median': 66.0,\n",
       "  'hamiltonian_size_mean': 141.55824279785156,\n",
       "  'hamiltonian_size_min': 48.0,\n",
       "  'hamiltonian_size_max': 166.0,\n",
       "  'hamiltonian_size_median': 144.0}}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the per-split statistics stored in dataset_stable_random_statistic.\n",
    "dataset_stable_random_statistic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QH9Stable with the 'size_ood' split, rooted at <parent of cwd>/datasets;\n",
    "# the resulting statistics dict is displayed as the cell output.\n",
    "dataset_stable_ood = QH9Stable(root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'), split='size_ood')\n",
    "dataset_stable_ood_statistic = get_dataset_statistic(dataset_stable_ood)\n",
    "dataset_stable_ood_statistic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([], dtype=int64)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the train mask of the size_ood dataset (used to select the train split).\n",
    "dataset_stable_ood.train_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "QHBench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
