{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f268900a-9f9a-474f-9ef4-dd7a6eaf20cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################################################################\n",
    "# Script for evaluating configurations contained in an xyz file with a trained model\n",
    "# Authors: Ilyes Batatia, Gregor Simm\n",
    "# This program is distributed under the MIT License (see MIT.md)\n",
    "###########################################################################################\n",
    "\n",
    "import argparse\n",
    "\n",
    "import ase.data\n",
    "import ase.io\n",
    "import numpy as np\n",
    "import torch\n",
    "# from e3nn import o3\n",
    "\n",
    "from mace import data\n",
    "from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq\n",
    "from mace.modules.utils import extract_invariant\n",
    "from mace.tools import torch_geometric, torch_tools, utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b71deb4b-3011-4b3d-a40a-d585ab7dbf57",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch_tools.set_default_dtype('float64')\n",
    "model = './checkpoints/mace_3BPA_2_run-123.model'\n",
    "device = 'cuda:4'\n",
    "model = torch.load(f=model, map_location=device)\n",
    "configs = \"datasets/3BPA/test_600K.xyz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1aadcdd-66e0-4f71-86f7-03900a432272",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.num_interactions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "464ea6c9-76e1-4e69-a8ac-c8496c9e3eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    heads = model.heads\n",
    "except AttributeError:\n",
    "    heads = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "267da43d-8f8f-40a5-8cee-9187ea52a9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for atoms, config in zip(atoms_list, configs):\n",
    "    config.properties['energy']=torch.tensor(atoms.get_potential_energy(), dtype=torch.get_default_dtype())\n",
    "    config.properties['forces']=torch.tensor(atoms.get_forces(), dtype=torch.get_default_dtype())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30dcd14-1f53-4258-8294-0b727dbc20e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 2\n",
    "data_loader = torch_geometric.dataloader.DataLoader(\n",
    "    dataset=[\n",
    "        data.AtomicData.from_config(\n",
    "            config, z_table=z_table, cutoff=float(model.r_max), heads=heads\n",
    "        )\n",
    "        for config in configs\n",
    "    ],\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b1f2fc-5a5d-45ab-95e6-549c9cdf5efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.functional import l1_loss, mse_loss\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Collect data\n",
    "energies_mae_list = []\n",
    "forces_mae_list = []\n",
    "contributions_list = []\n",
    "descriptors_list = []\n",
    "node_energies_list = []\n",
    "stresses_list = []\n",
    "forces_collection = []\n",
    "\n",
    "for batch in tqdm(data_loader):\n",
    "    batch = batch.to(device)\n",
    "    output = model(batch.to_dict(), compute_stress=False)\n",
    "    forces = np.split(\n",
    "        torch_tools.to_numpy(output[\"forces\"]),\n",
    "        indices_or_sections=batch.ptr[1:],\n",
    "        axis=0,\n",
    "    )\n",
    "    forces_collection.append(forces[:-1])  # drop last as its empty\n",
    "    energy_loss = l1_loss(batch.energy, output['energy'])\n",
    "    forces_loss = l1_loss(batch.forces, output['forces'])\n",
    "    print(energy_loss, forces_loss)\n",
    "    energies_list.append(energy_loss)\n",
    "    forces_list.append(forces_loss)\n",
    "    break\n",
    "\n",
    "print(sum(energies_mae_list)/len(energies_mae_list))\n",
    "print(sum(forces_mae_list)/len(forces_mae_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2307016-b235-4779-b3b4-7be2cb860031",
   "metadata": {},
   "outputs": [],
   "source": [
    "a.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "202ea376-aa55-4728-8463-520c6acbbad6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:51:16.187046Z",
     "iopub.status.busy": "2025-09-24T13:51:16.186109Z",
     "iopub.status.idle": "2025-09-24T13:51:17.492957Z",
     "shell.execute_reply": "2025-09-24T13:51:17.491275Z",
     "shell.execute_reply.started": "2025-09-24T13:51:16.186948Z"
    }
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import ase.data\n",
    "import ase.io\n",
    "import numpy as np\n",
    "import torch\n",
    "from mace import data\n",
    "from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq\n",
    "from mace.modules.utils import extract_invariant\n",
    "from mace.tools import torch_geometric, torch_tools, utils\n",
    "import torch.nn as nn\n",
    "from torch.nn.functional import l1_loss\n",
    "from tqdm import tqdm\n",
    "\n",
    "torch_tools.set_default_dtype('float64')\n",
    "config_path = \"datasets/semi/SiN/Testset.xyz\"\n",
    "atoms_list = ase.io.read(config_path, index=\":\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "142351be-f1a0-48cd-9015-96015754a083",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:51:18.755378Z",
     "iopub.status.busy": "2025-09-24T13:51:18.754622Z",
     "iopub.status.idle": "2025-09-24T13:51:18.864493Z",
     "shell.execute_reply": "2025-09-24T13:51:18.863113Z",
     "shell.execute_reply.started": "2025-09-24T13:51:18.755308Z"
    }
   },
   "outputs": [],
   "source": [
    "configs = [data.config_from_atoms(atoms) for atoms in atoms_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2b3f2e95-ea3e-45f6-8209-a6310edd1d57",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:51:20.727558Z",
     "iopub.status.busy": "2025-09-24T13:51:20.726879Z",
     "iopub.status.idle": "2025-09-24T13:51:21.950615Z",
     "shell.execute_reply": "2025-09-24T13:51:21.949431Z",
     "shell.execute_reply.started": "2025-09-24T13:51:20.727495Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_664375/1697944886.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model = torch.load(model_path)\n"
     ]
    }
   ],
   "source": [
    "model_path = f'./checkpoints/DMACE_SiN_run-123_stagetwo.model'\n",
    "\n",
    "# Load model\n",
    "model = torch.load(model_path)\n",
    "z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])\n",
    "try:\n",
    "    heads = model.heads\n",
    "except AttributeError:\n",
    "    heads = None\n",
    "\n",
    "for atoms, config in zip(atoms_list, configs):\n",
    "    config.properties['energy']=torch.tensor(atoms.get_potential_energy(), dtype=torch.get_default_dtype())\n",
    "    config.properties['forces']=torch.tensor(atoms.get_forces(), dtype=torch.get_default_dtype())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f0e4a3b5-262e-4bf1-ac89-53557084efe5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:51:23.344765Z",
     "iopub.status.busy": "2025-09-24T13:51:23.344211Z",
     "iopub.status.idle": "2025-09-24T13:51:59.887262Z",
     "shell.execute_reply": "2025-09-24T13:51:59.885531Z",
     "shell.execute_reply.started": "2025-09-24T13:51:23.344736Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hkhong/Research/MLIP/diffmace/mace/data/atomic_data.py:245: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.tensor(\n",
      "/home/hkhong/Research/MLIP/diffmace/mace/data/atomic_data.py:252: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.tensor(\n"
     ]
    }
   ],
   "source": [
    "batch_size = 2\n",
    "data_loader = torch_geometric.dataloader.DataLoader(\n",
    "    dataset=[\n",
    "        data.AtomicData.from_config(\n",
    "            config, z_table=z_table, cutoff=float(model.r_max), heads=heads\n",
    "        )\n",
    "        for config in configs\n",
    "    ],\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2b7ecc80-4173-4387-bbe0-3baeccd9aebd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:52:05.336964Z",
     "iopub.status.busy": "2025-09-24T13:52:05.336645Z",
     "iopub.status.idle": "2025-09-24T13:52:07.184514Z",
     "shell.execute_reply": "2025-09-24T13:52:07.183756Z",
     "shell.execute_reply.started": "2025-09-24T13:52:05.336937Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1433 [00:01<?, ?it/s]\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "for batch in tqdm(data_loader):\n",
    "    # batch = batch.to(device)\n",
    "    output = model(batch.to_dict(), compute_stress=False)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "c564c5cd-739f-4848-a787-5e11b65c666a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:04:08.567792Z",
     "iopub.status.busy": "2025-09-24T14:04:08.567111Z",
     "iopub.status.idle": "2025-09-24T14:04:08.578115Z",
     "shell.execute_reply": "2025-09-24T14:04:08.577054Z",
     "shell.execute_reply.started": "2025-09-24T14:04:08.567730Z"
    }
   },
   "outputs": [],
   "source": [
    "energy_loss = l1_loss(batch.energy, output['energy'])\n",
    "el = abs(output['energy'] - batch.energy)/(batch.ptr[1:]-batch.ptr[:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "5fddc21e-4750-41f9-aa7e-0f037526426e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:04:11.195124Z",
     "iopub.status.busy": "2025-09-24T14:04:11.194681Z",
     "iopub.status.idle": "2025-09-24T14:04:11.199712Z",
     "shell.execute_reply": "2025-09-24T14:04:11.199249Z",
     "shell.execute_reply.started": "2025-09-24T14:04:11.195100Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0002, 0.0148], grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "el"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "dd4969cd-d061-45c0-a71b-098e9f751061",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:54:55.419360Z",
     "iopub.status.busy": "2025-09-24T13:54:55.419072Z",
     "iopub.status.idle": "2025-09-24T13:54:55.427820Z",
     "shell.execute_reply": "2025-09-24T13:54:55.426917Z",
     "shell.execute_reply.started": "2025-09-24T13:54:55.419337Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0088, 0.7121], grad_fn=<AbsBackward0>)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "abs(output['energy'] - batch.energy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "cb9a5708-4162-439f-9dbd-911a373a7056",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:56:52.143961Z",
     "iopub.status.busy": "2025-09-24T13:56:52.143278Z",
     "iopub.status.idle": "2025-09-24T13:56:52.153676Z",
     "shell.execute_reply": "2025-09-24T13:56:52.152624Z",
     "shell.execute_reply.started": "2025-09-24T13:56:52.143897Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([96, 3])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output['forces'].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "208312e7-8536-4823-aec7-077ed7b88b94",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:58:55.995497Z",
     "iopub.status.busy": "2025-09-24T13:58:55.994886Z",
     "iopub.status.idle": "2025-09-24T13:58:56.002783Z",
     "shell.execute_reply": "2025-09-24T13:58:56.001472Z",
     "shell.execute_reply.started": "2025-09-24T13:58:55.995441Z"
    }
   },
   "outputs": [],
   "source": [
    "mse_loss_fn = nn.MSELoss()\n",
    "f_loss = mse_loss_fn(batch.forces, output['forces'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "723be374-95f9-494f-ab22-4719348b8493",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T13:59:00.208657Z",
     "iopub.status.busy": "2025-09-24T13:59:00.207966Z",
     "iopub.status.idle": "2025-09-24T13:59:00.221745Z",
     "shell.execute_reply": "2025-09-24T13:59:00.220513Z",
     "shell.execute_reply.started": "2025-09-24T13:59:00.208593Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6269)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "329a0702-6c5a-4cb0-a98a-19a3d81b2e60",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:01:20.706845Z",
     "iopub.status.busy": "2025-09-24T14:01:20.706162Z",
     "iopub.status.idle": "2025-09-24T14:01:20.715048Z",
     "shell.execute_reply": "2025-09-24T14:01:20.712922Z",
     "shell.execute_reply.started": "2025-09-24T14:01:20.706780Z"
    }
   },
   "outputs": [],
   "source": [
    "bf = batch.forces\n",
    "of = output['forces']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "09e4e96b-39f7-4955-8528-ac5107d5635e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:07:10.484055Z",
     "iopub.status.busy": "2025-09-24T14:07:10.483031Z",
     "iopub.status.idle": "2025-09-24T14:07:10.493937Z",
     "shell.execute_reply": "2025-09-24T14:07:10.492276Z",
     "shell.execute_reply.started": "2025-09-24T14:07:10.484028Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.3869e-01, 3.1054e-01, 1.0598e+00, 9.4753e-01, 6.8385e-02, 1.1678e-02,\n",
       "        2.2708e-01, 2.1223e-04, 1.3242e-01, 1.3848e-01, 3.1138e-01, 9.3219e-01,\n",
       "        1.7299e-01, 8.5236e-01, 4.8234e-02, 1.7238e-01, 7.3588e-01, 5.4389e-02,\n",
       "        6.1958e-01, 9.5182e-02, 6.4095e-02, 3.3128e-01, 6.7441e-02, 1.0695e+00,\n",
       "        5.9222e-05, 4.1010e-01, 1.9186e+00, 1.3400e+00, 2.8073e-02, 1.7802e+00,\n",
       "        8.4291e-01, 2.4644e-01, 5.9737e-01, 1.5719e+00, 1.9774e+00, 4.4056e-01,\n",
       "        3.8699e-01, 3.7515e-02, 3.9306e-01, 1.2927e+00, 9.3178e-02, 1.6268e+00,\n",
       "        6.7935e-01, 1.2011e-01, 2.7635e-01, 1.8361e+00, 2.5081e-02, 3.4711e+00,\n",
       "        4.8718e-01, 1.7102e-01, 4.5206e-03, 1.7820e+00, 9.4035e-01, 2.3417e-02,\n",
       "        5.3944e-04, 8.1636e-02, 1.4151e+00, 1.5660e-02, 1.4036e+00, 1.3454e+00,\n",
       "        3.2718e-01, 2.0292e-01, 5.8102e-01, 3.8440e-02, 1.3737e-05, 5.3737e-01,\n",
       "        5.8745e-01, 4.3885e-02, 9.1922e-01, 5.1017e-01, 2.0184e+00, 7.0754e-03,\n",
       "        1.5316e-01, 1.5294e-01, 7.0643e-02, 8.2163e-01, 3.7608e-02, 1.6929e-03,\n",
       "        8.7310e-02, 1.1005e+00, 2.5407e-01, 1.1405e+00, 1.4928e-02, 3.0937e-03,\n",
       "        1.1075e+00, 1.3891e-01, 6.4928e+00, 2.7444e-01, 6.1310e-01, 6.9359e-01,\n",
       "        9.1412e-01, 8.8199e-01, 1.1385e+00, 1.0900e+00, 9.2418e-01, 1.6246e-01])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.square(bf[:, 0]-of[:, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "efb0f838-46c1-40c4-9bf6-c8041cd27905",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:12:25.372261Z",
     "iopub.status.busy": "2025-09-24T14:12:25.371790Z",
     "iopub.status.idle": "2025-09-24T14:12:25.380907Z",
     "shell.execute_reply": "2025-09-24T14:12:25.379519Z",
     "shell.execute_reply.started": "2025-09-24T14:12:25.372213Z"
    }
   },
   "outputs": [],
   "source": [
    "energy_mse = mse_loss_fn(batch.energy, output['energy'])\n",
    "forces_mse = mse_loss_fn(batch.forces, output['forces'])\n",
    "energy_loss = torch.sqrt(energy_mse)\n",
    "forces_loss = torch.sqrt(forces_mse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "017b2559-abd9-4de0-9155-e31cd6f3ee24",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:12:56.786709Z",
     "iopub.status.busy": "2025-09-24T14:12:56.786422Z",
     "iopub.status.idle": "2025-09-24T14:12:56.794057Z",
     "shell.execute_reply": "2025-09-24T14:12:56.793095Z",
     "shell.execute_reply.started": "2025-09-24T14:12:56.786686Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3605, grad_fn=<MeanBackward0>)"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "energy_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "3f462751-1759-4672-b809-976bd10e9de7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:12:57.947781Z",
     "iopub.status.busy": "2025-09-24T14:12:57.947500Z",
     "iopub.status.idle": "2025-09-24T14:12:57.955766Z",
     "shell.execute_reply": "2025-09-24T14:12:57.954735Z",
     "shell.execute_reply.started": "2025-09-24T14:12:57.947759Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6403)"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forces_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "bcadfad6-487f-4b99-8ad0-2d43f8e16344",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:12:55.092623Z",
     "iopub.status.busy": "2025-09-24T14:12:55.092337Z",
     "iopub.status.idle": "2025-09-24T14:12:55.100446Z",
     "shell.execute_reply": "2025-09-24T14:12:55.099078Z",
     "shell.execute_reply.started": "2025-09-24T14:12:55.092600Z"
    }
   },
   "outputs": [],
   "source": [
    "energy_loss = l1_loss(batch.energy, output['energy'])\n",
    "forces_loss = l1_loss(batch.forces, output['forces'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "26d2dd51-7dee-469b-90e5-a8414b92ed80",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:16:11.665030Z",
     "iopub.status.busy": "2025-09-24T14:16:11.664742Z",
     "iopub.status.idle": "2025-09-24T14:16:11.676951Z",
     "shell.execute_reply": "2025-09-24T14:16:11.675584Z",
     "shell.execute_reply.started": "2025-09-24T14:16:11.665001Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0105, grad_fn=<SqrtBackward0>)"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sqrt(sum(torch.square((batch.energy-output['energy'])/(batch.ptr[1:]-batch.ptr[:-1])))/batch.energy.size(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "a465d751-b701-4f1c-b18a-4634b1ade4ef",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:15:37.768237Z",
     "iopub.status.busy": "2025-09-24T14:15:37.767398Z",
     "iopub.status.idle": "2025-09-24T14:15:37.779564Z",
     "shell.execute_reply": "2025-09-24T14:15:37.778158Z",
     "shell.execute_reply.started": "2025-09-24T14:15:37.768187Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0002, -0.0148], grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sqrt(sum(torch.square((batch.forces-output['forces'])/(batch.ptr[1:]-batch.ptr[:-1])))/batch.energy.size(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "88ca350f-d445-49b3-a3f1-fe9b0c47ff49",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:20:56.902076Z",
     "iopub.status.busy": "2025-09-24T14:20:56.901716Z",
     "iopub.status.idle": "2025-09-24T14:20:56.908977Z",
     "shell.execute_reply": "2025-09-24T14:20:56.908148Z",
     "shell.execute_reply.started": "2025-09-24T14:20:56.902052Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.7918)"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sqrt(sum(sum(torch.square(batch.forces-output['forces'])))/(batch.ptr[-1]*3))\n",
    "# torch.sqrt(sum(torch.square(/(batch.ptr[1:]-batch.ptr[:-1])))/batch.energy.size(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "3ad92f75-6d72-4ef4-a940-225506603279",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:21:34.181134Z",
     "iopub.status.busy": "2025-09-24T14:21:34.180397Z",
     "iopub.status.idle": "2025-09-24T14:21:34.195056Z",
     "shell.execute_reply": "2025-09-24T14:21:34.193334Z",
     "shell.execute_reply.started": "2025-09-24T14:21:34.181068Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3605, grad_fn=<MeanBackward0>)"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l1_loss(batch.energy, output['energy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "091b867e-4c34-4280-a298-be4dd97935df",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T14:22:34.337242Z",
     "iopub.status.busy": "2025-09-24T14:22:34.336331Z",
     "iopub.status.idle": "2025-09-24T14:22:34.351327Z",
     "shell.execute_reply": "2025-09-24T14:22:34.349339Z",
     "shell.execute_reply.started": "2025-09-24T14:22:34.337175Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0075, grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(abs(batch.energy-output['energy'])/(batch.ptr[1:]-batch.ptr[:-1]))/batch.energy.size(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "aa068680-4a05-466a-9b5d-12b12b263db5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T15:17:01.456989Z",
     "iopub.status.busy": "2025-09-24T15:17:01.456548Z",
     "iopub.status.idle": "2025-09-24T15:17:22.383752Z",
     "shell.execute_reply": "2025-09-24T15:17:22.382626Z",
     "shell.execute_reply.started": "2025-09-24T15:17:01.456950Z"
    }
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "from ase.io import read\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# def parse_arguments():\n",
    "#     parser = argparse.ArgumentParser(description='Analyze bond stability in molecular dynamics trajectory')\n",
    "#     parser.add_argument('--traj_file', type=str, required=True, help='Path to trajectory file (.traj)')\n",
    "#     parser.add_argument('--ref_file', type=str, required=True, help='Path to reference file (.xyz)')\n",
    "#     parser.add_argument('--bond_threshold', type=float, default=1.8, help='Bond length threshold in Å (default: 1.8)')\n",
    "#     parser.add_argument('--tolerance', type=float, default=0.5, help='Bond length tolerance (default: 0.5)')\n",
    "#     parser.add_argument('--timestep_fs', type=float, default=1.0, help='Timestep in femtoseconds (default: 1.0)')\n",
    "#     return parser.parse_args()\n",
    "\n",
    "def analyze_bond_lengths(ref_file, bond_threshold):\n",
    "    # Read all molecules from the reference file\n",
    "    molecules = read(ref_file, index=':', format='xyz')\n",
    "    \n",
    "    # Dictionary to store bond lengths by bond type\n",
    "    bond_lengths = defaultdict(list)\n",
    "    # List to store edge indices (atom index pairs for bonds)\n",
    "    edge_indices = []\n",
    "    \n",
    "    bond_index = {}\n",
    "    for mol_idx, mol in enumerate(molecules):\n",
    "        # Get atomic symbols and positions\n",
    "        symbols = mol.get_chemical_symbols()\n",
    "        positions = mol.get_positions()\n",
    "        \n",
    "        # Analyze bonds\n",
    "        for i in range(len(mol)):\n",
    "            for j in range(i + 1, len(mol)):\n",
    "                # Calculate distance between atoms i and j\n",
    "                dist = np.linalg.norm(positions[i] - positions[j])\n",
    "                \n",
    "                # Check if distance is within bonding threshold\n",
    "                if dist < bond_threshold:\n",
    "                    # Store edge index: (molecule_index, atom_i, atom_j)\n",
    "                    edge_indices.append((mol_idx, i, j))\n",
    "                    # Store bond length with bond type\n",
    "                    bond_type = f\"{i}-{j}\"\n",
    "                    bond_lengths[bond_type].append(dist)\n",
    "    \n",
    "    # Calculate average bond lengths for consistent bonds\n",
    "    mol_len = len(molecules)\n",
    "    true_bonds = {}\n",
    "    true_lengths = []\n",
    "    edge_index = []\n",
    "    for bond_type, lengths in bond_lengths.items():\n",
    "        if len(lengths) == mol_len:\n",
    "            a, b = map(int, bond_type.split('-'))\n",
    "            true_bonds[bond_type] = lengths\n",
    "            edge_index.append([a, b])\n",
    "            true_lengths.append(np.mean(lengths))\n",
    "    \n",
    "    return np.array(edge_index), true_lengths\n",
    "\n",
    "def check_edge_lengths(reference_lengths, calculated_lengths, tolerance):\n",
    "    if len(reference_lengths) != len(calculated_lengths):\n",
    "        raise ValueError(\"Reference and calculated lengths lists must have the same length\")\n",
    "    \n",
    "    for ref, calc in zip(reference_lengths, calculated_lengths):\n",
    "        if not (ref - tolerance <= calc <= ref + tolerance):\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "def analyze_trajectory(traj_file, edge_index, true_lengths, tolerance, timestep_fs):\n",
    "    # Read trajectory\n",
    "    traj = read(traj_file, index=':')\n",
    "    \n",
    "    stable_frames = len(traj)\n",
    "    for k, atoms in enumerate(tqdm(traj)):\n",
    "        pos = atoms.get_positions()\n",
    "        p1 = pos[edge_index[:, 0]]  # Start points\n",
    "        p2 = pos[edge_index[:, 1]]  # End points\n",
    "        edges = np.sqrt(np.sum((p2 - p1) ** 2, axis=1))\n",
    "        result = check_edge_lengths(true_lengths, edges, tolerance)\n",
    "        if not result:\n",
    "            stable_frames = k\n",
    "            break\n",
    "    \n",
    "    # Calculate stable time\n",
    "    stable_time_ps = (stable_frames - 1) * timestep_fs / 1000.0 if stable_frames > 0 else 0.0\n",
    "    stable_time_fs = (stable_frames - 1) * timestep_fs if stable_frames > 0 else 0.0\n",
    "    \n",
    "    return stable_time_ps, stable_time_fs\n",
    "\n",
    "\n",
    "# args = parse_arguments()\n",
    "\n",
    "traj_file = 'real_md/sin/md_nh.traj'\n",
    "ref_file = 'datasets/semi/SiN/OOD.xyz'\n",
    "# Analyze reference file for bond lengths\n",
    "edge_index, true_lengths = analyze_bond_lengths(ref_file, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "9b69c8ab-b281-4c51-b658-88ca916c7151",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T15:21:44.896772Z",
     "iopub.status.busy": "2025-09-24T15:21:44.895138Z",
     "iopub.status.idle": "2025-09-24T15:21:45.646280Z",
     "shell.execute_reply": "2025-09-24T15:21:45.644936Z",
     "shell.execute_reply.started": "2025-09-24T15:21:44.896694Z"
    }
   },
   "outputs": [],
   "source": [
    "molecules = read(ref_file, index=':', format='extxyz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "db93f801-e689-4e09-9ecd-3fae8eff61ec",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T15:21:46.420825Z",
     "iopub.status.busy": "2025-09-24T15:21:46.420532Z",
     "iopub.status.idle": "2025-09-24T15:21:46.427823Z",
     "shell.execute_reply": "2025-09-24T15:21:46.426888Z",
     "shell.execute_reply.started": "2025-09-24T15:21:46.420801Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Atoms(symbols='N60Si45', pbc=True, cell=[[10.56110287, 0.0, 0.0], [0.35189969, 10.09798385, 0.0], [0.09579951, 0.81674066, 10.95030383]], calculator=SinglePointCalculator(...))"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "molecules[0]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "655f9c62-644c-4be4-bbd8-3be658696dcd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T15:18:28.779737Z",
     "iopub.status.busy": "2025-09-24T15:18:28.778837Z",
     "iopub.status.idle": "2025-09-24T15:18:51.858751Z",
     "shell.execute_reply": "2025-09-24T15:18:51.857144Z",
     "shell.execute_reply.started": "2025-09-24T15:18:28.779672Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bond lengths collected per atom-index pair, keyed \"i-j\" (one entry per frame in which the pair is bonded)\n",
    "bond_lengths = defaultdict(list)\n",
    "# (frame_index, atom_i, atom_j) for every bonded pair found\n",
    "edge_indices = []\n",
    "# Two atoms closer than this distance count as bonded\n",
    "bond_threshold = 1.8\n",
    "# Not filled in by this cell; kept so the name stays defined for any other cell that reads it\n",
    "bond_index = {}\n",
    "for mol_idx, mol in enumerate(molecules):\n",
    "    # One vectorised distance matrix per frame instead of a Python loop over all pairs.\n",
    "    # Periodic frames use the minimum-image convention so bonds that cross a cell\n",
    "    # boundary are not missed; non-periodic frames keep plain Cartesian distances.\n",
    "    dist_matrix = mol.get_all_distances(mic=bool(mol.pbc.any()))\n",
    "\n",
    "    # Upper triangle (i < j) in row-major order, i.e. the same order as nested i/j loops\n",
    "    atom_i, atom_j = np.triu_indices(len(mol), k=1)\n",
    "    pair_dists = dist_matrix[atom_i, atom_j]\n",
    "    bonded = pair_dists < bond_threshold\n",
    "\n",
    "    # .tolist() yields plain Python ints so the stored tuples and keys look as before\n",
    "    for i, j, dist in zip(atom_i[bonded].tolist(), atom_j[bonded].tolist(), pair_dists[bonded]):\n",
    "        # Store edge index: (molecule_index, atom_i, atom_j)\n",
    "        edge_indices.append((mol_idx, i, j))\n",
    "        # Store bond length under its atom-index pair\n",
    "        bond_lengths[f\"{i}-{j}\"].append(dist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "d703fb3f-eb91-497f-906a-62a438d71949",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T15:42:50.608357Z",
     "iopub.status.busy": "2025-09-24T15:42:50.607862Z",
     "iopub.status.idle": "2025-09-24T15:43:40.894907Z",
     "shell.execute_reply": "2025-09-24T15:43:40.893843Z",
     "shell.execute_reply.started": "2025-09-24T15:42:50.608307Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_stability_criterion(list_of_atoms):\n",
    "    \"\"\"\n",
    "    Calculate the stability criterion (maximum RDF deviation in the reference dataset)\n",
    "    for systems with periodic boundary conditions. This can be used to set the threshold theta\n",
    "    for detecting unstable simulations, e.g., theta = max_deviation (or with a relaxed margin).\n",
    "\n",
    "    Input:\n",
    "        list_of_atoms: list of ASE Atoms objects representing the reference trajectory frames.\n",
    "    \n",
    "    Output:\n",
    "        max_deviation: the maximum integral |RDF_frame - <RDF>| dr over all frames.\n",
    "    \"\"\"\n",
    "    if not list_of_atoms:\n",
    "        raise ValueError(\"Empty list of atoms\")\n",
    "\n",
    "    # Assume all frames have the same cell and PBC\n",
    "    atoms0 = list_of_atoms[0]\n",
    "    if not atoms0.pbc.all():\n",
    "        raise ValueError(\"System must have periodic boundary conditions\")\n",
    "\n",
    "    # Determine r_max as half the minimum cell length\n",
    "    cell_lengths = np.linalg.norm(atoms0.cell, axis=1)\n",
    "    r_max = np.min(cell_lengths) / 2.0\n",
    "\n",
    "    # Binning parameters (bins=200 is arbitrary; adjust based on system if needed)\n",
    "    bins = 200\n",
    "    bin_edges = np.linspace(0, r_max, bins + 1)\n",
    "    dr = bin_edges[1] - bin_edges[0]\n",
    "    r = bin_edges[:-1] + dr / 2.0\n",
    "\n",
    "    rdf_list = []\n",
    "    for atoms in list_of_atoms:\n",
    "        # Get pairwise distances using minimum image convention\n",
    "        dists = atoms.get_all_distances(mic=True)\n",
    "        # Upper triangle to count undirected pairs\n",
    "        indices = np.triu_indices(len(atoms), k=1)\n",
    "        d = dists[indices]\n",
    "\n",
    "        # Histogram\n",
    "        hist, _ = np.histogram(d, bins=bin_edges)\n",
    "\n",
    "        # RDF calculation\n",
    "        N = len(atoms)\n",
    "        V = atoms.get_volume()\n",
    "        rho = N / V\n",
    "        rdf = 2 * hist / (N * rho * 4 * np.pi * r**2 * dr)\n",
    "\n",
    "        rdf_list.append(rdf)\n",
    "\n",
    "    if not rdf_list:\n",
    "        raise ValueError(\"No RDFs computed\")\n",
    "\n",
    "    # Average RDF\n",
    "    rdf_ref = np.mean(rdf_list, axis=0)\n",
    "\n",
    "    # Compute deviations\n",
    "    deviations = []\n",
    "    for rdf in rdf_list:\n",
    "        dev = np.sum(np.abs(rdf - rdf_ref)) * dr\n",
    "        deviations.append(dev)\n",
    "\n",
    "    max_dev = np.max(deviations)\n",
    "\n",
    "    return max_dev\n",
    "\n",
    "# Largest per-frame RDF deviation from the mean RDF over the frames in `molecules`\n",
    "max_dev = compute_stability_criterion(molecules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "4806df88-0549-4ca4-8fb9-0232b46fecce",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-24T15:47:16.086786Z",
     "iopub.status.busy": "2025-09-24T15:47:16.086498Z",
     "iopub.status.idle": "2025-09-24T15:48:23.865883Z",
     "shell.execute_reply": "2025-09-24T15:48:23.865046Z",
     "shell.execute_reply.started": "2025-09-24T15:47:16.086763Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step-by-step RDF deviation analysis of the reference frames; intermediates\n",
    "# (r, rdf_list, rdf_ref, deviations) stay available for inspection afterwards.\n",
    "list_of_atoms = molecules\n",
    "if not list_of_atoms:\n",
    "    raise ValueError(\"Empty list of atoms\")\n",
    "\n",
    "# Assume all frames have the same cell and PBC\n",
    "atoms0 = list_of_atoms[0]\n",
    "if not atoms0.pbc.all():\n",
    "    raise ValueError(\"System must have periodic boundary conditions\")\n",
    "\n",
    "# r_max is half the smallest perpendicular cell height (volume / opposite face area).\n",
    "# For a skewed cell the plain vector lengths overestimate the radius within which\n",
    "# minimum-image distances are complete; for an orthogonal cell both measures agree.\n",
    "cell = np.asarray(atoms0.cell)\n",
    "cell_volume = abs(np.linalg.det(cell))\n",
    "if cell_volume == 0.0:\n",
    "    raise ValueError(\"Cell volume must be non-zero\")\n",
    "face_normals = np.cross(cell[[1, 2, 0]], cell[[2, 0, 1]])\n",
    "cell_heights = cell_volume / np.linalg.norm(face_normals, axis=1)\n",
    "r_max = np.min(cell_heights) / 2.0\n",
    "\n",
    "# Binning parameters (bins=200 is arbitrary; adjust based on system if needed)\n",
    "bins = 200\n",
    "bin_edges = np.linspace(0, r_max, bins + 1)\n",
    "dr = bin_edges[1] - bin_edges[0]\n",
    "r = bin_edges[:-1] + dr / 2.0  # bin centres\n",
    "# Spherical-shell volume per bin; the same for every frame, so computed once\n",
    "shell_volumes = 4 * np.pi * r**2 * dr\n",
    "\n",
    "rdf_list = []\n",
    "for atoms in list_of_atoms:\n",
    "    # Pairwise distances using the minimum image convention\n",
    "    dists = atoms.get_all_distances(mic=True)\n",
    "    # Upper triangle only, so each undirected pair is counted once\n",
    "    d = dists[np.triu_indices(len(atoms), k=1)]\n",
    "\n",
    "    hist, _ = np.histogram(d, bins=bin_edges)\n",
    "\n",
    "    # RDF calculation; factor 2 turns the undirected pair count into a per-atom count\n",
    "    N = len(atoms)\n",
    "    V = atoms.get_volume()\n",
    "    rho = N / V\n",
    "    rdf_list.append(2 * hist / (N * rho * shell_volumes))\n",
    "\n",
    "# Average RDF over all frames\n",
    "rdf_ref = np.mean(rdf_list, axis=0)\n",
    "\n",
    "# Integrated absolute deviation of each frame's RDF from the average\n",
    "deviations = [np.sum(np.abs(rdf - rdf_ref)) * dr for rdf in rdf_list]\n",
    "\n",
    "max_dev = np.max(deviations)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.22"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
