{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import dill\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from tabulate import tabulate\n",
    "from torch.utils.data import DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from datasets.MD17.MD17Dataset import MD17SingleDataset\n",
    "from datasets.RMD17.RMD17Dataset import RMD17SingleDataset\n",
    "from datasets.SMD17.SMD17Dataset import SMD17SingleDataset\n",
    "from scripts.Chemistry.losses import EnergyLoss, PosForceLoss\n",
    "from scripts.commom_util import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_configs(model_path):\n",
    "    configs = dict()\n",
    "    with open(f\"{model_path}/info.txt\") as fp:\n",
    "        for line in fp:\n",
    "            line = line.strip().split(\" \")\n",
    "            if len(line)>1:\n",
    "                configs[line[0]] = line[1]\n",
    "\n",
    "    return configs\n",
    "\n",
    "def predict(model_path, epoch, configs, device, root):\n",
    "    # initialize\n",
    "    if configs[\"dataset\"]==\"MD17SingleDataset\":\n",
    "        dataset = MD17SingleDataset(configs[\"style\"], configs[\"molecule\"], \"test\", configs[\"split\"], root)\n",
    "    elif configs[\"dataset\"]==\"SMD17SingleDataset\":\n",
    "        dataset = SMD17SingleDataset(configs[\"style\"], configs[\"molecule\"], \"test\", configs[\"split\"], root)\n",
    "    else:\n",
    "        dataset = RMD17SingleDataset(configs[\"style\"], configs[\"molecule\"], \"test\", configs[\"split\"], root)\n",
    "    identifier = dataset.identifier\n",
    "    test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate)\n",
    "    \n",
    "    path = glob.glob(f\"{model_path}/{epoch:03d}_*.pth\")\n",
    "    if not path: raise ValueError(\"model not found\")\n",
    "    model = torch.load(path[0], map_location=torch.device(torch.cuda.current_device()), pickle_module=dill)\n",
    "    print(f\"Using model {path[0]}\")\n",
    "    model.eval()\n",
    "    \n",
    "    # test\n",
    "    preds = []\n",
    "    losses = []\n",
    "    tq = tqdm(test_dataloader)\n",
    "    for data, label in tq:\n",
    "        data = {i:v.to(device) for i, v in data.items()}\n",
    "        label = {i:v.to(device) for i, v in label.items()}\n",
    "        pred = model(data)\n",
    "        preds.append( ((torch.cat((pred[\"E\"], pred[\"F\"].reshape(1, -1)), axis=1)).squeeze()).tolist() )\n",
    "     \n",
    "        loss_E = (EnergyLoss(pred, label).to(\"cpu\").item())\n",
    "        loss_F = PosForceLoss(pred, label)\n",
    "        loss_F = [ (l.to(\"cpu\").item()) for l in loss_F]\n",
    "        losses.append([loss_E]+loss_F)\n",
    "    \n",
    "    # save result\n",
    "    save_reult(model_path, identifier, preds, losses)\n",
    "\n",
    "def metric(model_path, configs, if_print=False):\n",
    "    # Every dataset has different unit, unify them to Energy:eV and Force: eV/A\n",
    "    toEv = {\"RMD17SingleDataset\":0.0433634, \"MD17SingleDataset\":0.0433634, \"SMD17SingleDataset\":0.0433634}\n",
    "    ds = configs[\"dataset\"]\n",
    "    scale = toEv[configs[\"dataset\"]]\n",
    "\n",
    "    # load and calculate\n",
    "    style, molecule, split = configs[\"style\"], configs[\"molecule\"], configs[\"split\"]\n",
    "    with open(f\"{model_path}/loss_{ds}_{style}_{molecule}{split}test.csv\", newline='') as fp:\n",
    "        cdata = list(csv.reader(fp, quoting=csv.QUOTE_NONNUMERIC))\n",
    "        cdata = [[c for c in row] for row in cdata]\n",
    "        loss_Emole , loss_Fmole= [], []\n",
    "        hit, total = 0, len(cdata)\n",
    "        failE, failF = 0, 0\n",
    "\n",
    "        #for row in tqdm(cdata):\n",
    "        for row in cdata:\n",
    "            loss_Emole.append(row[0]) \n",
    "            loss_Fmole.append(row[1:])\n",
    "            m = max(row[1:])\n",
    "            if row[0]<=0.02/scale and m<=0.03/scale: hit += 1\n",
    "            else:\n",
    "                if row[0]>0.02/scale: failE += 1\n",
    "                if m>0.03/scale: failF += 1\n",
    "        loss_Emole , loss_Fmole= np.array(loss_Emole), np.array(loss_Fmole)\n",
    "\n",
    "        EMAE = np.mean(loss_Emole)\n",
    "        FMAE = np.mean(loss_Fmole)\n",
    "        EFWT = hit/total\n",
    "        if if_print:\n",
    "            print(f\"Energy MAE: {EMAE:.3f},\\tForce MAE: {FMAE:.3f},\\tEFwT: {EFWT:.3f} (failE:{(failE/total):.3f}, failF:{(failF/total):.3f})\\n\")\n",
    "        return EMAE, FMAE, EFWT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Energy Loss\n",
      "+--------+------------+----------+----------+----------+-----------+-----------+----------+-------------+\n",
      "|        | a          |        b |        e |        m |         n |         s |        t | u           |\n",
      "+========+============+==========+==========+==========+===========+===========+==========+=============+\n",
      "|  50000 | 240.38868  | 0.633092 | 8.2273   | 7.1516   | 14.9496   | 14.9326   | 0.502077 | 120.8090325 |\n",
      "+--------+------------+----------+----------+----------+-----------+-----------+----------+-------------+\n",
      "| 100000 | 15.5846025 | 0.113403 | 0.566786 | 3.97339  |  0.517235 | 21.3781   | 0.74044  | 9.7107275   |\n",
      "+--------+------------+----------+----------+----------+-----------+-----------+----------+-------------+\n",
      "| 150000 | 0.60048    | 0.103456 | 0.774109 | 0.376115 |  0.464251 |  0.272862 | 0.512492 | N/A         |\n",
      "+--------+------------+----------+----------+----------+-----------+-----------+----------+-------------+\n",
      "| 200000 | N/A        | 0.16004  | 0.37205  | 0.359689 |  0.613781 |  0.319805 | 0.233156 | N/A         |\n",
      "+--------+------------+----------+----------+----------+-----------+-----------+----------+-------------+\n",
      "\n",
      "Force Loss\n",
      "+--------+--------------------+----------+---------+---------+----------+----------+----------+--------------------+\n",
      "|        | a                  |        b |       e |       m |        n |        s |        t | u                  |\n",
      "+========+====================+==========+=========+=========+==========+==========+==========+====================+\n",
      "|  50000 | 1.7122653425161836 | 0.366256 | 1.54174 | 1.66467 | 1.01126  | 1.13166  | 1.02488  | 1.2269875143703364 |\n",
      "+--------+--------------------+----------+---------+---------+----------+----------+----------+--------------------+\n",
      "| 100000 | 1.3056154041689445 | 0.31895  | 1.37861 | 1.4175  | 0.94432  | 0.865706 | 0.873079 | 0.9570960620620671 |\n",
      "+--------+--------------------+----------+---------+---------+----------+----------+----------+--------------------+\n",
      "| 150000 | 1.1591457283788158 | 0.313517 | 1.32536 | 1.33295 | 0.83129  | 0.771269 | 0.79715  | N/A                |\n",
      "+--------+--------------------+----------+---------+---------+----------+----------+----------+--------------------+\n",
      "| 200000 | N/A                | 0.301985 | 1.30287 | 1.30206 | 0.723337 | 0.727618 | 0.749723 | N/A                |\n",
      "+--------+--------------------+----------+---------+---------+----------+----------+----------+--------------------+\n"
     ]
    }
   ],
   "source": [
    "# Scalability test of our proposed methods \n",
    "# Model: SchNet\n",
    "# Dataset: MD17\n",
    "\n",
    "need_predict = False\n",
    "EMAE_table = [[\"a\", \"b\", \"e\", \"m\", \"n\", \"s\", \"t\", \"u\"]]\n",
    "FMAE_table = [[\"a\", \"b\", \"e\", \"m\", \"n\", \"s\", \"t\", \"u\"]]\n",
    "EFWT_table = [[\"a\", \"b\", \"e\", \"m\", \"n\", \"s\", \"t\", \"u\"]]\n",
    "\n",
    "for num_sample in [\"50000\", \"100000\", \"150000\", \"200000\"]:\n",
    "    EMAE_row = [num_sample]\n",
    "    FMAE_row = [num_sample]\n",
    "    EFWT_row = [num_sample]\n",
    "    \n",
    "    for mole in [\"a\", \"b\", \"e\", \"m\", \"n\", \"s\", \"t\", \"u\"]:\n",
    "        #print(f\"act:{act}, molecule:{mole}\")\n",
    "\n",
    "        if (num_sample==\"150000\" and mole==\"u\") or (num_sample==\"200000\" and mole==\"a\") or (num_sample==\"200000\" and mole==\"u\"):\n",
    "            EMAE_row.append(\"N/A\")\n",
    "            FMAE_row.append(\"N/A\")\n",
    "            EFWT_row.append(\"N/A\")\n",
    "            continue\n",
    " \n",
    "        model_path = f\"./checkpoints/schnet_SMD{num_sample}_IR_{mole}\"\n",
    "        cfg = get_configs(model_path)\n",
    "\n",
    "        if need_predict:\n",
    "            root = \"../../datasets/SMD17/datas\"\n",
    "            predict(model_path, 49, cfg, \"cuda\", root)\n",
    "        EMAE, FMAE, EFWT = metric(model_path, cfg)\n",
    "        EMAE_row.append(EMAE)\n",
    "        FMAE_row.append(FMAE)\n",
    "        EFWT_row.append(EFWT)\n",
    "    \n",
    "    EMAE_table.append(EMAE_row)\n",
    "    FMAE_table.append(FMAE_row)\n",
    "    EFWT_table.append(EFWT_row)\n",
    "\n",
    "tableE = tabulate(EMAE_table, headers='firstrow', tablefmt='grid')\n",
    "tableF = tabulate(FMAE_table, headers='firstrow', tablefmt='grid')\n",
    "\n",
    "print(\"Energy Loss\")\n",
    "print(tableE)\n",
    "print(\"\")\n",
    "print(\"Force Loss\")\n",
    "print(tableF)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ocp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
