{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.OC22.OC22Dataset import OC22LmdbDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [00:16<00:00, 590.55it/s]\n"
     ]
    }
   ],
   "source": [
    "ds = OC22LmdbDataset(task=\"train\", split=\"200000\", root=\"/usr/data1/OC22/s2ef_total_train_val_test_lmdbs/data/oc22/s2ef-total\")\n",
    "print(len(ds))\n",
    "energy = []\n",
    "force = []\n",
    "\n",
    "dataloader = DataLoader(ds, batch_size=20, shuffle=True, collate_fn=ds.collate, num_workers=16)\n",
    "for data, label in tqdm(dataloader):\n",
    "    #print(label[\"E\"].shape)\n",
    "    #print(label[\"F\"].shape)\n",
    "    energy.append(label[\"E\"])\n",
    "    force.append(label[\"F\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Energy\n",
      "\tmin: -1932.1065673828125\n",
      "\tmax: 717.4010620117188\n",
      "\tmean: -495.7373962402344\n",
      "\tvar: 45306.06640625\n",
      "Force\n",
      "\tmin: -49.94717025756836\n",
      "\tmax: 49.96593475341797\n",
      "\tmean: 2.626641308969835e-12\n",
      "\tvar: 0.06426408141851425\n"
     ]
    }
   ],
   "source": [
    "energy = torch.stack((energy,), dim=0)\n",
    "force = torch.cat(force[:-1], dim=0)\n",
    "print(\"Energy\")\n",
    "print(f\"\\tmin: {torch.min(energy)}\")\n",
    "print(f\"\\tmax: {torch.max(energy)}\")\n",
    "print(f\"\\tmean: {torch.mean(energy)}\")\n",
    "print(f\"\\tvar: {torch.var(energy)}\")\n",
    "print(\"Force\")\n",
    "print(f\"\\tmin: {torch.min(force)}\")\n",
    "print(f\"\\tmax: {torch.max(force)}\")\n",
    "print(f\"\\tmean: {torch.mean(force)}\")\n",
    "print(f\"\\tvar: {torch.var(force)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ocp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
