{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import os\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "from ase.neighborlist import natural_cutoffs, NeighborList\n",
    "from ase.io import read, Trajectory\n",
    "\n",
    "# optional. nglview for visualization\n",
    "import nglview as nv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "functions for loading simulated trajectories and computing observables.\n",
    "\"\"\"\n",
    "\n",
    "def get_thermo(filename):\n",
    "    \"\"\"\n",
    "    read thermo logs.\n",
    "    \"\"\"\n",
    "    with open(filename, 'r') as f:\n",
    "        thermo = f.read().splitlines()\n",
    "        sim_time, Et, Ep, Ek, T = [], [], [], [], []\n",
    "        for i in range(1, len(thermo)):\n",
    "            try:\n",
    "                t, Etot, Epot, Ekin, Temp = [float(x) for x in thermo[i].split(' ') if x]\n",
    "                sim_time.append(t)\n",
    "                Et.append(Etot)\n",
    "                Ep.append(Epot)\n",
    "                Ek.append(Ekin)\n",
    "                T.append(Temp)\n",
    "            except:\n",
    "                sim_time, Et, Ep, Ek, T = [], [], [], [], []\n",
    "    thermo = {\n",
    "        'time': sim_time,\n",
    "        'Et': Et,\n",
    "        'Ep': Ep,\n",
    "        'Ek': Ek,\n",
    "        'T': T\n",
    "    }\n",
    "    return thermo\n",
    "\n",
    "def get_test_metrics(md_dir):\n",
    "    \"\"\"\n",
    "    read test metrics such as force error.\n",
    "    \"\"\"\n",
    "    run_metrics = {}\n",
    "    with open(md_dir / 'test_metric.json', 'r') as f:\n",
    "        test_metric = json.load(f)\n",
    "        if 'mae_f' in test_metric:\n",
    "            fmae = test_metric['mae_f']\n",
    "            emae = test_metric['mae_e']\n",
    "            run_metrics['fmae'] = fmae\n",
    "            run_metrics['emae'] = emae\n",
    "        elif 'f_mae' in test_metric:\n",
    "            fmae = test_metric['f_mae']\n",
    "            emae = test_metric['e_mae']\n",
    "            run_metrics['fmae'] = fmae\n",
    "            run_metrics['emae'] = emae\n",
    "        elif 'forces_mae' in test_metric:\n",
    "            fmae = test_metric['forces_mae']['metric']\n",
    "            emae = test_metric['energy_mae']['metric']\n",
    "            run_metrics['fmae'] = fmae\n",
    "            run_metrics['emae'] = emae\n",
    "        if 'num_params' in test_metric:\n",
    "            run_metrics['n_params'] = test_metric['num_params']\n",
    "        if 'running_time' in test_metric:\n",
    "            run_metrics['running_time'] = test_metric['running_time']\n",
    "    return run_metrics\n",
    "\n",
    "def mae(x, y, factor):\n",
    "    return np.abs(x-y).mean() * factor\n",
    "\n",
    "def distance_pbc(x0, x1, lattices):\n",
    "    delta = torch.abs(x0 - x1)\n",
    "    lattices = lattices.view(-1,1,3)\n",
    "    delta = torch.where(delta > 0.5 * lattices, delta - lattices, delta)\n",
    "    return torch.sqrt((delta ** 2).sum(dim=-1))\n",
    "\n",
    "def get_diffusivity_traj(pos_seq, dilation=1):\n",
    "    \"\"\"\n",
    "    Input: B x N x T x 3\n",
    "    Output: B x T\n",
    "    \"\"\"\n",
    "    # substract CoM\n",
    "    bsize, time_steps = pos_seq.shape[0], pos_seq.shape[2]\n",
    "    pos_seq = pos_seq - pos_seq.mean(1, keepdims=True)\n",
    "    msd = (pos_seq[:, :, 1:] - pos_seq[:, :, 0].unsqueeze(2)).pow(2).sum(dim=-1).mean(dim=1)\n",
    "    diff = msd / (torch.arange(1, time_steps)*dilation) / 6\n",
    "    return diff.view(bsize, time_steps-1)\n",
    "\n",
    "def get_smoothed_diff(xyz):\n",
    "    seq_len = xyz.shape[0] - 1\n",
    "    diff = torch.zeros(seq_len)\n",
    "    for i in range(seq_len):\n",
    "        diff[:seq_len-i] += get_diffusivity_traj(xyz[i:].transpose(0, 1).unsqueeze(0)).flatten()\n",
    "    diff = diff / torch.flip(torch.arange(seq_len),dims=[0])\n",
    "    return diff"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MD17"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hr(traj, bins):\n",
    "    \"\"\"\n",
    "    compute h(r) for MD17 simulations.\n",
    "    traj: T x N_atoms x 3\n",
    "    \"\"\"\n",
    "    pdist = torch.cdist(traj, traj).flatten()\n",
    "    hist, _ = np.histogram(pdist[:].flatten().numpy(), bins, density=True)\n",
    "    return hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_run(md_dir, xlim, bins, stability_threshold, gt_traj, hist_gt):\n",
    "    \"\"\"\n",
    "    md_dir: directory to the finished MD simulation.\n",
    "    \"\"\"\n",
    "    if not isinstance(md_dir, Path):\n",
    "        md_dir = Path(md_dir)\n",
    "    \n",
    "    model_name = md_dir.parts[-2]\n",
    "    seed = md_dir.parts[-1][-1]\n",
    "    run = {'name': (model_name + f'_seed_{seed}'),}\n",
    "    \n",
    "    # get bonds\n",
    "    traj = Trajectory(md_dir / 'atoms.traj')\n",
    "    atoms = traj[0]\n",
    "    NL = NeighborList(natural_cutoffs(atoms), self_interaction=False)\n",
    "    NL.update(atoms)\n",
    "    bonds = NL.get_connectivity_matrix().todense().nonzero()\n",
    "    bonds = torch.tensor(bonds)\n",
    "    \n",
    "    # process trajectory\n",
    "    traj = [x.positions for x in traj]\n",
    "    run['traj'] = torch.from_numpy(np.stack(traj))\n",
    "    run['traj'] = torch.unique(run['traj'], dim=0) # remove repeated frames from restarting.\n",
    "\n",
    "    # load thermo log\n",
    "    run['thermo'] = get_thermo(md_dir / 'thermo.log')\n",
    "    T = np.array(run['thermo']['T']) \n",
    "    collapse_pt = len(T)\n",
    "    md_time = np.array(run['thermo']['time'])\n",
    "    \n",
    "    # track stability\n",
    "    bond_lens = distance_pbc(\n",
    "        gt_traj[:, bonds[0]], gt_traj[:, bonds[1]], torch.FloatTensor([30., 30., 30.]))\n",
    "    mean_bond_lens = bond_lens.mean(dim=0)\n",
    "    \n",
    "    for i in range(1, len(T)):\n",
    "        bond_lens = distance_pbc(\n",
    "            run['traj'][(i-1):i, bonds[0]], run['traj'][(i-1):i, bonds[1]], torch.FloatTensor([30., 30., 30.]))\n",
    "        max_dev = (bond_lens[0] - mean_bond_lens).abs().max()\n",
    "        if  max_dev > stability_threshold:\n",
    "            collapse_pt = i\n",
    "            break\n",
    "    run['collapse_pt'] = collapse_pt\n",
    "    \n",
    "    # compute h(r)\n",
    "    hist_pred = get_hr(run['traj'][0:collapse_pt], bins)\n",
    "    hr_mae = mae(hist_pred, hist_gt, xlim)\n",
    "    run['hr'] = hist_pred\n",
    "    run['hr_error'] = hr_mae\n",
    "    \n",
    "    # load test metrics\n",
    "    if (md_dir / 'test_metric.json').exists():\n",
    "        test_metrics = get_test_metrics(md_dir)\n",
    "        run.update(test_metrics)\n",
    "    \n",
    "    return run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval parameters\n",
    "stability_threshold = 0.5\n",
    "xlim = 10\n",
    "n_bins = 500\n",
    "bins = np.linspace(1e-6, xlim, n_bins + 1) # for computing h(r)\n",
    "\n",
    "# select molecule and get ground truth data\n",
    "molecule = 'aspirin'\n",
    "DATAPATH = f'DATAPATH/md17/{molecule}_dft.npz'\n",
    "gt_data = np.load(DATAPATH)\n",
    "gt_traj = torch.FloatTensor(gt_data.f.R)\n",
    "hist_gt = get_hr(gt_traj, bins)\n",
    "\n",
    "\n",
    "# load run and plot h(r)\n",
    "md_dir = Path('./example_sim/aspirin_dimenet')\n",
    "run = load_run(md_dir, xlim, bins, stability_threshold, gt_traj, hist_gt)\n",
    "plt.plot(bins[1:], hist_gt, label='Reference', linewidth=2, linestyle='--')\n",
    "plt.plot(bins[1:], run['hr'], label='Prediction', linewidth=2, linestyle='--')\n",
    "plt.xlabel('r')\n",
    "plt.ylabel('h(r)')\n",
    "plt.legend()\n",
    "\n",
    "# metrics\n",
    "force_mae = run['fmae'] * 1000\n",
    "collapse_ps = (run['collapse_pt']-1) / 20\n",
    "hr_mae = run['hr_error']\n",
    "print(f'force mae: {force_mae:.1f} meV/A \\nstability: {collapse_ps:.1f} ps \\nh(r) mae: {hr_mae:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualization\n",
    "nv.show_asetraj(Trajectory(md_dir / 'atoms.traj'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Water"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xlim = 6\n",
    "nbins = 1000\n",
    "bins = np.linspace(1e-6, xlim, nbins+1)\n",
    "stability_threshold = 3.0\n",
    "diffusivity_cutoff = 3000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def distance_pbc_select(x, lattices, indices0, indices1):\n",
    "    x0 = x[:, indices0]\n",
    "    x1 = x[:, indices1]\n",
    "    x0_size = x0.shape[1]\n",
    "    x1_size = x1.shape[1]\n",
    "    x0 = x0.repeat([1, x1_size, 1])\n",
    "    x1 = x1.repeat_interleave(x0_size, dim=1)\n",
    "    delta = torch.abs(x0 - x1)\n",
    "    delta = torch.where(delta > 0.5 * lattices, delta - lattices, delta)\n",
    "    return torch.sqrt((delta ** 2).sum(axis=-1))\n",
    "\n",
    "def get_water_rdfs(data_seq, ptypes, lattices, bins, device='cpu'):\n",
    "    \"\"\"\n",
    "    get atom-type conditioned water RDF curves.\n",
    "    \"\"\"\n",
    "    data_seq = data_seq.to(device).float()\n",
    "    lattices = lattices.to(device).float()\n",
    "    \n",
    "    type2indices = {\n",
    "        'H': ptypes == 1,\n",
    "        'O': ptypes == 8\n",
    "    }\n",
    "    pairs = [('O', 'O'), ('H', 'H'), ('H', 'O')]\n",
    "    \n",
    "    data_seq = ((data_seq / lattices) % 1) * lattices\n",
    "    all_rdfs = {}\n",
    "    n_rdfs = 3\n",
    "    for idx in range(n_rdfs):\n",
    "        type1, type2 = pairs[idx]    \n",
    "        indices0 = type2indices[type1].to(device)\n",
    "        indices1 = type2indices[type2].to(device)\n",
    "        data_pdist = distance_pbc_select(data_seq, lattices, indices0, indices1)\n",
    "        \n",
    "        data_pdist = data_pdist.flatten().cpu().numpy()\n",
    "        data_shape = data_pdist.shape[0]\n",
    "            \n",
    "        data_pdist = data_pdist[data_pdist != 0]\n",
    "        data_hist, _ = np.histogram(data_pdist, bins)\n",
    "        rho_data = data_shape / torch.prod(lattices).cpu().numpy() \n",
    "        Z_data = rho_data * 4 / 3 * np.pi * (bins[1:] ** 3 - bins[:-1] ** 3)\n",
    "        data_rdf = data_hist / Z_data\n",
    "        all_rdfs[type1 + type2] = data_rdf\n",
    "        \n",
    "    return all_rdfs\n",
    "\n",
    "def load_run(md_dir, atom_types, xlim, bins, stability_threshold, gt_rdfs, gt_diff):\n",
    "    if not isinstance(md_dir, Path):\n",
    "        md_dir = Path(md_dir)\n",
    "        \n",
    "    model_name = md_dir.parts[-2]\n",
    "    seed = md_dir.parts[-1][-1]\n",
    "    run = {'name': (model_name + f'_seed_{seed}'),}\n",
    "    \n",
    "    # process trajectory\n",
    "    traj = [x.positions for x in Trajectory(md_dir / 'atoms.traj')]\n",
    "    run['traj'] = torch.from_numpy(np.stack(traj))\n",
    "    run['traj'] = torch.unique(run['traj'], dim=0) # remove repeated frames from restarting.\n",
    "    \n",
    "    # load thermo log\n",
    "    run['thermo'] = get_thermo(md_dir / 'thermo.log')\n",
    "    md_time = np.array(run['thermo']['time'])\n",
    "    T = np.array(run['thermo']['T']) \n",
    "    collapse_pt = len(T)\n",
    "    for i in (range(1, len(T)-rdf_check_interval)):\n",
    "        timerange = torch.arange(i, i + rdf_check_interval)\n",
    "        current_rdf = get_water_rdfs(run['traj'][timerange], atom_types, lattices, bins)\n",
    "        rdf_mae_oo = mae(current_rdf['OO'], gt_rdfs['OO'], xlim)\n",
    "        rdf_mae_ho = mae(current_rdf['HO'], gt_rdfs['HO'], xlim)\n",
    "        rdf_mae_hh = mae(current_rdf['HH'], gt_rdfs['HH'], xlim)\n",
    "        if max([rdf_mae_oo, rdf_mae_ho, rdf_mae_hh]) > stability_threshold:\n",
    "            collapse_pt = i\n",
    "            break\n",
    "\n",
    "    run['collapse_pt'] = collapse_pt        \n",
    "\n",
    "    # at least 100 ps for computing diffusivity.\n",
    "    if collapse_pt >= 1000:\n",
    "        run['diffusivity'] = get_smoothed_diff(\n",
    "            run['traj'][:collapse_pt:10, atom_types == 8])[:100]\n",
    "        run['diff_error'] = float((run['diffusivity'][-1] - gt_diff[-1]).abs())\n",
    "        run['end_diff'] = float(run['diffusivity'][-1])\n",
    "    else:\n",
    "        run['diffusivity'] = None\n",
    "        run['diff_error'] = np.inf\n",
    "        run['end_diff'] = np.inf\n",
    "        \n",
    "    # at least 1 ps for computing RDFs.\n",
    "    if collapse_pt >= 10:\n",
    "        run['rdf'] = get_water_rdfs(run['traj'][:collapse_pt], atom_types, lattices, bins)\n",
    "        run['rdf_error'] = [mae(run['rdf'][k], gt_rdfs[k], xlim) for k in ['OO', 'HH', 'HO']]\n",
    "    else:\n",
    "        run['rdf'] = None\n",
    "        run['rdf_error'] = [np.inf] * 3\n",
    "        \n",
    "    # load test metrics\n",
    "    if (md_dir / 'test_metric.json').exists():\n",
    "        test_metrics = get_test_metrics(md_dir)\n",
    "        run.update(test_metrics)\n",
    "        \n",
    "    return run\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval parameters\n",
    "stability_threshold = 3.0\n",
    "rdf_check_interval = 10 # 1 ps\n",
    "xlim = 6\n",
    "n_bins = 500\n",
    "bins = np.linspace(1e-6, xlim, n_bins + 1) # for computing RDF\n",
    "\n",
    "# get ground truth data\n",
    "DATAPATH = 'DATAPATH/mdbench_data/flexwater/flexwater.npy'\n",
    "gt_data = np.load(DATAPATH, allow_pickle=True).item()\n",
    "atom_types = torch.tensor(gt_data['atom_types'])\n",
    "lattices = torch.tensor(gt_data['lengths'][0]).float()\n",
    "gt_traj = torch.tensor(gt_data['unwrapped_coords'])\n",
    "gt_diff = get_smoothed_diff(gt_traj[0::100, atom_types==8])[:100] # track diffusivity of oxygen atoms, unit is A^2/ps\n",
    "gt_rdfs = get_water_rdfs(gt_traj[::10], atom_types, lattices, bins) # match the recording frequency of 0.1 ps\n",
    "\n",
    "# load run and plot RDFs\n",
    "md_dir = Path('./example_sim/water-1k_schnet')\n",
    "run = load_run(md_dir, atom_types, xlim, bins, stability_threshold, gt_rdfs, gt_diff)\n",
    "\n",
    "plt.subplots_adjust()\n",
    "plt.rc('xtick', labelsize=16)\n",
    "plt.rc('ytick', labelsize=16) \n",
    "plt.rc('legend', fontsize=24)\n",
    "plt.rc('figure', titlesize=24)\n",
    "plt.rc('axes', titlesize=24)\n",
    "plt.rc('axes', labelsize=24)\n",
    "fig, axs = plt.subplots(1, 3)\n",
    "fig.set_size_inches(18, 4)\n",
    "fig.tight_layout(h_pad=3, w_pad=1)\n",
    "\n",
    "for i, elem in enumerate(['OO', 'HH', 'HO']):\n",
    "    axs[i].plot(bins[:-1], gt_rdfs[elem], label='Reference', linewidth=3, linestyle='--')\n",
    "    axs[i].plot(bins[:-1], run['rdf'][elem], label='Prediction', linewidth=3, linestyle='--')\n",
    "    axs[i].set(title=f'RDF {elem}', xlabel='r')\n",
    "    axs[i].legend()\n",
    "axs[0].set_ylabel('RDF(r)')\n",
    "\n",
    "# metrics\n",
    "force_mae = run['fmae'] * 1000\n",
    "collapse_ps = (run['collapse_pt']-1) / 10\n",
    "rdf_oo_mae = run['rdf_error'][0]\n",
    "rdf_hh_mae = run['rdf_error'][1]\n",
    "rdf_ho_mae = run['rdf_error'][2]\n",
    "diff_mae = run['diff_error'] * 10 # A^2/ps -> 10^-9 m^2/s \n",
    "print(f'force mae: {force_mae:.1f} meV/A \\nstability: {collapse_ps:.1f} ps \\nRDF (O,O) mae: {rdf_oo_mae:.2f}' + \n",
    "     f'\\nRDF (H,H) mae: {rdf_hh_mae:.2f} \\nRDF (H,O) mae: {rdf_ho_mae:.2f} \\nDiffusivity mae: {diff_mae:.2f} x 10^-9 m^2/s ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualization (PBC is not incorporated)\n",
    "nv.show_asetraj(Trajectory(md_dir / 'atoms.traj'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Alanine dipeptide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kBT=2.49 # constant at 300 K\n",
    "\n",
    "def os_obtain_fes(md_dir):\n",
    "    \"\"\"\n",
    "    run a plumed command to get the FES from simulation outputs.\n",
    "    \"\"\"\n",
    "    cmd = os.getcwd()\n",
    "    os.chdir(f'{str(md_dir)}')\n",
    "    os.system('plumed sum_hills --hills HILLS --mintozero --bin 60,60')\n",
    "    os.chdir(cmd)\n",
    "\n",
    "def load_fes(filename):\n",
    "    \"\"\"\n",
    "    load FES generated from plumed.\n",
    "    \"\"\"\n",
    "    X = np.loadtxt(filename, skiprows=9)\n",
    "    xg = X[:,0].reshape((60, 60))\n",
    "    yg = X[:,1].reshape((60, 60))\n",
    "    energies = zg = X[:,2].reshape((60, 60))\n",
    "    return xg, yg, zg\n",
    "\n",
    "def plot_fes(xg, yg, zg):\n",
    "    energies = zg\n",
    "    v = np.arange(-9, +0.5, 0.5) # Contours to plot.\n",
    "    fig = plt.figure(figsize=(10, 8.2))\n",
    "    energies = (energies - np.min(energies))\n",
    "    cmap=plt.cm.magma\n",
    "    cmap.set_bad(color='white')\n",
    "    img = plt.imshow(energies/kBT, interpolation='nearest', cmap = cmap)\n",
    "    plt.gca().invert_yaxis()\n",
    "    cb = plt.colorbar()\n",
    "    cb.set_label(\"G [kBT]\")\n",
    "    plt.xlabel(\"$\\phi$\")\n",
    "    plt.ylabel(\"$\\psi$\")\n",
    "    plt.xlim((-np.pi, np.pi))\n",
    "    plt.ylim((-np.pi, np.pi))\n",
    "    plt.axis(\"equal\")\n",
    "    cb.ax.set_title(r'$\\tilde{F}/k_BT$')\n",
    "    return xg, yg, zg\n",
    "\n",
    "def get_pmf(zg):\n",
    "    pmf = np.exp(-zg/kBT)\n",
    "    pmf = pmf/ pmf.sum()\n",
    "    X = pmf.sum(axis=0)\n",
    "    Y = pmf.sum(axis=1)\n",
    "    pmf_x = -kBT * np.log(X)\n",
    "    pmf_y = -kBT * np.log(Y)\n",
    "    pmf_x -= pmf_x.min()\n",
    "    pmf_y -= pmf_y.min()\n",
    "    return pmf_x, pmf_y\n",
    "\n",
    "def load_run(md_dir, gt_traj, stability_threshold):\n",
    "    if not isinstance(md_dir, Path):\n",
    "        md_dir = Path(md_dir)\n",
    "        \n",
    "    model_name = md_dir.parts[-2]\n",
    "    seed = md_dir.parts[-1][-1]\n",
    "    run = {'name': (model_name + f'_seed_{seed}'),}\n",
    "    \n",
    "    # get bonds\n",
    "    traj = Trajectory(md_dir / 'atoms.traj')\n",
    "    atoms = traj[0]\n",
    "    lattices = torch.from_numpy(atoms.cell.diagonal()).float()\n",
    "    NL = NeighborList(natural_cutoffs(atoms), self_interaction=False)\n",
    "    NL.update(atoms)\n",
    "    bonds = NL.get_connectivity_matrix().todense().nonzero()\n",
    "    bonds = torch.tensor(bonds)\n",
    "    \n",
    "    # process trajectory\n",
    "    traj = [x.positions for x in traj]\n",
    "    run['traj'] = torch.from_numpy(np.stack(traj))\n",
    "    run['traj'] = torch.unique(run['traj'], dim=0) # remove repeated frames from restarting.\n",
    "\n",
    "    # track stability\n",
    "    bond_lens = distance_pbc(gt_traj[:, bonds[0]], gt_traj[:, bonds[1]], lattices)\n",
    "    mean_bond_lens = bond_lens.mean(dim=0)\n",
    "    run['thermo'] = get_thermo(md_dir / 'thermo.log')\n",
    "    T = np.array(run['thermo']['T']) \n",
    "    collapse_pt = len(T)\n",
    "    md_time = np.array(run['thermo']['time'])\n",
    "    for i in range(1, len(T)):\n",
    "        bond_lens = distance_pbc(\n",
    "            run['traj'][(i-1):i, bonds[0]], run['traj'][(i-1):i, bonds[1]], lattices)\n",
    "        max_dev = (bond_lens[0] - mean_bond_lens).abs().max()\n",
    "        if  max_dev > stability_threshold:\n",
    "            collapse_pt = i\n",
    "            break\n",
    "\n",
    "    run['collapse_pt'] = collapse_pt\n",
    "    if (md_dir / 'fes.dat').exists():\n",
    "        _, _, zg = load_fes(md_dir / 'fes.dat')\n",
    "        run['fes'] = zg\n",
    "    else:\n",
    "        run['fes'] = None\n",
    "\n",
    "    # load test metrics\n",
    "    if (md_dir / 'test_metric.json').exists():\n",
    "        test_metrics = get_test_metrics(md_dir)\n",
    "        run.update(test_metrics)\n",
    "    \n",
    "    return run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stability_threshold = 0.5\n",
    "\n",
    "DATAPATH = 'DATAPATH/mdbench_data/ala_src/ala.npy'\n",
    "gt_data = np.load(DATAPATH, allow_pickle=True).item()\n",
    "gt_traj = torch.from_numpy(gt_data['pos'])\n",
    "_, _, gt_fes = load_fes('./alanine_dipeptide_files/fes.dat')\n",
    "pmf_phi, pmf_psi = get_pmf(gt_fes)\n",
    "\n",
    "md_dir = Path('./example_sim/ala_nequip')\n",
    "os_obtain_fes(md_dir)\n",
    "run = load_run(md_dir, gt_traj, stability_threshold)\n",
    "pmf_phi_pred, pmf_psi_pred = get_pmf(run['fes'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# metrics\n",
    "force_mae = run['fmae'] * 1000\n",
    "phi_mae = mae(pmf_phi, pmf_phi_pred, 6.28)\n",
    "psi_mae = mae(pmf_psi, pmf_psi_pred, 6.28)\n",
    "print(f'force mae: {force_mae:.1f} meV/A \\nphi mae: {phi_mae:.1f} \\npsi mae: {psi_mae:.1f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot PMF (one-dim FES)\n",
    "plt.rc('xtick', labelsize=22)\n",
    "plt.rc('ytick', labelsize=22) \n",
    "plt.rc('legend', fontsize=22)\n",
    "plt.rc('figure', titlesize=30)\n",
    "plt.rc('axes', titlesize=24)\n",
    "plt.rc('axes', labelsize=24)\n",
    "\n",
    "x_axis = np.linspace(-180, 180, 60)\n",
    "\n",
    "plt.subplots_adjust()\n",
    "fig, axs = plt.subplots(1, 2)\n",
    "fig.set_size_inches(11, 4.5)\n",
    "fig.tight_layout(pad=0)\n",
    "\n",
    "ax = axs[0]\n",
    "ax.plot(x_axis, pmf_phi, label='Reference', linewidth=4)\n",
    "ax.set(xlabel=r'$\\phi$ [degree]', ylabel='F [kJ/mol]', xticks=[-180, -90, 0, 90, 180])\n",
    "ax.plot(x_axis, pmf_phi_pred, label='Prediction',linewidth=4)\n",
    "lines, labels = ax.get_legend_handles_labels()\n",
    "ax.legend()\n",
    "\n",
    "ax = axs[1]\n",
    "ax.plot(x_axis, pmf_psi, label='Reference',linewidth=4)\n",
    "ax.set(xlabel=r'$\\psi$ [degree]', ylabel='', xticks=[-180, -90, 0, 90, 180])\n",
    "ax.plot(x_axis, pmf_psi_pred, label='Prediction',linewidth=4)\n",
    "lines, labels = ax.get_legend_handles_labels()\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot FES\n",
    "plt.subplots_adjust()\n",
    "fig, axs = plt.subplots(1, 2)\n",
    "fig.set_size_inches(12.5, 6)\n",
    "fig.tight_layout(pad=0)\n",
    "\n",
    "ax = axs[0]\n",
    "energies = gt_fes\n",
    "energies = (energies - np.min(energies))\n",
    "img = ax.imshow(energies, interpolation='nearest', cmap = plt.cm.magma)\n",
    "ax.invert_yaxis()\n",
    "divider = make_axes_locatable(ax)\n",
    "cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "cb = fig.colorbar(img, cax=cax, orientation='vertical')\n",
    "\n",
    "ax.set(xlabel=\"$\\phi$ [degree]\", ylabel=\"$\\psi$ [degree]\", title='Reference')\n",
    "ax.axis(\"equal\")\n",
    "ax.set(xticks=[0,15,30,45,59], xticklabels=[-180, -90, 0, 90, 180],\n",
    "       yticks=[0,15,30,45,59], yticklabels=[-180, -90, 0, 90, 180])\n",
    "cb.ax.set_title('F [kJ/mol]')\n",
    "\n",
    "\n",
    "ax = axs[1]\n",
    "energies = run['fes']\n",
    "energies = (energies - np.min(energies))\n",
    "img = ax.imshow(energies, interpolation='nearest', cmap = plt.cm.magma)\n",
    "ax.invert_yaxis()\n",
    "divider = make_axes_locatable(ax)\n",
    "ax.set(xticks=[0,15,30,45,59], xticklabels=[-180, -90, 0, 90, 180],\n",
    "       yticks=[],yticklabels=[])\n",
    "cax = divider.append_axes('right', size='5%', pad=0.05)\n",
    "cb = fig.colorbar(img, cax=cax, orientation='vertical')\n",
    "\n",
    "ax.set(xlabel=\"$\\phi$ [degree]\", title='Prediction')\n",
    "ax.axis(\"equal\")\n",
    "cb.ax.set_title('F [kJ/mol]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adjust positions to show under pbc properly\n",
    "traj = Trajectory(md_dir / 'atoms.traj')\n",
    "lattices = (traj[0].cell.diagonal())[0]\n",
    "new_traj = []\n",
    "for i in range(len(traj)):\n",
    "    atoms = traj[i]\n",
    "    atoms.positions = (atoms.positions + np.array([0,0,10])) % lattices\n",
    "    new_traj.append(atoms)\n",
    "\n",
    "# visualization\n",
    "nv.show_asetraj(new_traj)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LiPS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_image_flag(cell, fcoord1, fcoord2):\n",
    "    supercells = torch.FloatTensor(list(itertools.product((-1, 0, 1), repeat=3))).to(cell.device)\n",
    "    fcoords = fcoord2[:, None] + supercells\n",
    "    coords = fcoords @ cell\n",
    "    coord1 = fcoord1 @ cell\n",
    "    dists = torch.cdist(coord1[:, None], coords).squeeze()\n",
    "    image = dists.argmin(dim=-1)\n",
    "    return supercells[image].long()\n",
    "\n",
    "def frac2cart(fcoord, cell):\n",
    "    return fcoord @ cell\n",
    "\n",
    "def cart2frac(coord, cell):\n",
    "    invcell = torch.linalg.inv(cell)\n",
    "    return coord @ invcell\n",
    "\n",
    "# the source data is in wrapped coordinates. need to unwrap it for computing diffusivity.\n",
    "def unwrap(pos0, pos1, cell):\n",
    "    fcoords1 = cart2frac(pos0, cell)\n",
    "    fcoords2 = cart2frac(pos1, cell)\n",
    "    flags = compute_image_flag(cell, fcoords1, fcoords2)\n",
    "    remapped_frac_coords = cart2frac(pos1, cell) + flags\n",
    "    return frac2cart(remapped_frac_coords, cell)\n",
    "\n",
    "# different from previous functions, now needs to deal with non-cubic cells. \n",
    "def compute_distance_matrix_batch(cell, cart_coords, num_cells=1):\n",
    "    pos = torch.arange(-num_cells, num_cells+1, 1).to(cell.device)\n",
    "    combos = torch.stack(\n",
    "        torch.meshgrid(pos, pos, pos, indexing='xy')\n",
    "            ).permute(3, 2, 1, 0).reshape(-1, 3).to(cell.device)\n",
    "    shifts = torch.sum(cell.unsqueeze(0) * combos.unsqueeze(-1), dim=1)\n",
    "    # NxNxCells distance array\n",
    "    shifted = cart_coords.unsqueeze(2) + shifts.unsqueeze(0).unsqueeze(0)\n",
    "    dist = cart_coords.unsqueeze(2).unsqueeze(2) - shifted.unsqueeze(1)\n",
    "    dist = dist.pow(2).sum(dim=-1).sqrt()\n",
    "    # But we want only min\n",
    "    distance_matrix = dist.min(dim=-1)[0]\n",
    "    return distance_matrix\n",
    "\n",
    "def get_lips_rdf(data_seq, lattices, bins, device='cpu'):\n",
    "    data_seq = data_seq.to(device).float()\n",
    "    lattices = lattices.to(device).float()\n",
    "    \n",
    "    lattice_np = lattices.cpu().numpy()\n",
    "    volume = float(abs(np.dot(np.cross(lattice_np[0], lattice_np[1]), lattice_np[2])))\n",
    "    data_pdist = compute_distance_matrix_batch(lattices, data_seq)\n",
    "\n",
    "    data_pdist = data_pdist.flatten().cpu().numpy()\n",
    "    data_shape = data_pdist.shape[0]\n",
    "\n",
    "    data_pdist = data_pdist[data_pdist != 0]\n",
    "    data_hist, _ = np.histogram(data_pdist, bins)\n",
    "\n",
    "    rho_data = data_shape / volume\n",
    "    Z_data = rho_data * 4 / 3 * np.pi * (bins[1:] ** 3 - bins[:-1] ** 3)\n",
    "    rdf = data_hist / Z_data\n",
    "        \n",
    "    return rdf\n",
    "\n",
    "def load_run(md_dir, atomic_numbers, cell, xlim, bins, stability_threshold, gt_rdf, gt_diff):\n",
    "    if not isinstance(md_dir, Path):\n",
    "        md_dir = Path(md_dir)\n",
    "        \n",
    "    model_name = md_dir.parts[-2]\n",
    "    seed = md_dir.parts[-1][-1]\n",
    "    run = {'name': (model_name + f'_seed_{seed}')}\n",
    "\n",
    "    run['traj'] = Trajectory(md_dir / 'atoms.traj')\n",
    "    run['traj'] = torch.from_numpy(np.stack([run['traj'][i].positions \n",
    "                                                  for i in range(len(run['traj']))]))\n",
    "    run['thermo'] = get_thermo( md_dir / 'thermo.log')\n",
    "\n",
    "    md_time = np.array(run['thermo']['time'])\n",
    "    T = np.array(run['thermo']['T']) \n",
    "    collapse_pt = len(T)\n",
    "    for i in (range(1, len(T)-rdf_check_interval)):\n",
    "        timerange = torch.arange(i, i + rdf_check_interval)\n",
    "        current_rdf = get_lips_rdf(run['traj'][timerange], cell, bins)\n",
    "        rdf_mae = mae(current_rdf, gt_rdf, xlim)\n",
    "        if rdf_mae > stability_threshold:\n",
    "            collapse_pt = i\n",
    "            break\n",
    "\n",
    "    run['collapse_pt'] = collapse_pt \n",
    "\n",
    "    run['rdf'] = get_lips_rdf(run['traj'][:collapse_pt], cell, bins)\n",
    "    run['rdf_error'] = mae(run['rdf'], gt_rdf, xlim)\n",
    "\n",
    "    if collapse_pt > 3200:\n",
    "        # removing the first 5 ps for equilibrium. use the diffusivity at 40 ps as a convergence value.\n",
    "        # some random error is unavoidable with 50-ps reference simulations.\n",
    "        diff = get_smoothed_diff(run['traj'][400:collapse_pt:4, atomic_numbers == 3])\n",
    "        run['diffusivity'] = diff[700] * 20 * 1e-8\n",
    "        run['end_diff'] = float(run['diffusivity'])\n",
    "        run['diff_error'] = np.abs(float(run['diffusivity']) - float(gt_diff[700]))\n",
    "    else:\n",
    "        run['diffusivity'] = None\n",
    "        run['end_diff'] = np.inf\n",
    "        run['diff_error'] = np.inf\n",
    "\n",
    "    # load test metrics\n",
    "    if (md_dir / 'test_metric.json').exists():\n",
    "        test_metrics = get_test_metrics(md_dir)\n",
    "        run.update(test_metrics)\n",
    "\n",
    "    return run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stability_threshold = 1.0\n",
    "rdf_check_interval = 80 # 1 ps. recording freq is 0.0125 ps. \n",
    "\n",
    "xlim = 6\n",
    "nbins = 500\n",
    "bins = np.linspace(1e-6, xlim, nbins + 1)\n",
    "\n",
    "atoms = read('DATAPATH/lips/lips.xyz', index=':', format='extxyz')\n",
    "n_points = len(atoms)\n",
    "positions, cell, atomic_numbers = [], [], []\n",
    "for i in range(n_points):\n",
    "    positions.append(atoms[i].get_positions())\n",
    "    cell.append(atoms[i].get_cell())\n",
    "    atomic_numbers.append(atoms[i].get_atomic_numbers())\n",
    "positions = torch.from_numpy(np.array(positions))\n",
    "cell = torch.from_numpy(np.array(cell)[0])\n",
    "atomic_numbers = torch.from_numpy(np.array(atomic_numbers)[0])\n",
    "\n",
    "# unwrap positions\n",
    "all_displacements = []\n",
    "for i in (range(1, len(positions))):\n",
    "    next_pos = unwrap(positions[i-1], positions[i], cell)\n",
    "    displacements = next_pos - positions[i-1]\n",
    "    all_displacements.append(displacements)\n",
    "displacements = torch.stack(all_displacements)\n",
    "accum_displacements = torch.cumsum(displacements, dim=0)\n",
    "positions = torch.cat([positions[0].unsqueeze(0), positions[0] + accum_displacements], dim=0)\n",
    "\n",
    "\n",
    "gt_rdf = get_lips_rdf(positions[::], cell, bins, device='cpu')\n",
    "# Li diffusivity unit in m^2/s. remove the first 5 ps as equilibrium.\n",
    "# Desirably, we want longer trajectories for computing diffusivity.\n",
    "gt_diff = get_smoothed_diff((positions[2500:None:25, atomic_numbers == 3])) * 20 * 1e-8\n",
    "\n",
    "# load run and plot RDFs\n",
    "md_dir = Path('./example_sim/lips_gemnet-t')\n",
    "run = load_run(md_dir, atomic_numbers, cell, xlim, bins, stability_threshold, gt_rdf, gt_diff)\n",
    "\n",
    "xaxis = np.linspace(1e-6, xlim, nbins)\n",
    "plt.plot(xaxis, gt_rdf, label='Reference', linewidth=2, linestyle='--')\n",
    "plt.plot(xaxis, run['rdf'], label='Prediction', linewidth=2, linestyle='--')\n",
    "plt.legend()\n",
    "\n",
    "force_mae = run['fmae'] * 1000\n",
    "collapse_ps = (run['collapse_pt']-1) / 80\n",
    "rdf_mae = run['rdf_error']\n",
    "diff_mae = run['diff_error'] * 1e9\n",
    "print(f'force mae: {force_mae:.1f} meV/A \\nstability: {collapse_ps:.1f} ps \\nRDF mae: {rdf_mae:.2f}' +\n",
    "      f'\\nDiffusivity mae: {diff_mae:.2f} x 10^-9 m^2/s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualization\n",
    "traj = Trajectory(md_dir / 'atoms.traj')\n",
    "nv.show_asetraj(traj)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdbench",
   "language": "python",
   "name": "mdbench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
