{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3200850a-b8fb-4f50-9815-16ae8da0f942",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from scipy.interpolate import UnivariateSpline\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from ase import Atom, Atoms\n",
    "from ase.data import chemical_symbols, covalent_radii, vdw_alvarez\n",
    "from ase.io import read, write\n",
    "from mlip_arena.models import REGISTRY, MLIPEnum\n",
    "from pymatgen.core import Element"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02ff9cf9-49a2-4cec-80d3-56c6661a513b",
   "metadata": {
    "tags": []
   },
   "source": [
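    "# Compute MLIP homonuclear diatomics\n",
    "\n",
    "For each element, two identical atoms are placed in a periodic box whose edge is twice the largest separation. The cell is made slightly non-cubic (`a`, `a + 0.001`, `a + 0.002`). The MLIP energy and forces are then evaluated on a grid of interatomic distances. The grid has a 0.01 Å spacing and runs from 0.9 × the covalent radius up to 3.1 × the Alvarez van der Waals radius, or up to 6 Å when no vdW radius is tabulated. Every frame is appended to `{family}/{XX}/{model}.extxyz`, so an interrupted run resumes from the frames already on disk."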
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90887faa-1601-4c4c-9c44-d16731471d7f",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Models to compute in this run; all other MLIPEnum entries are skipped.\n",
    "MODELS_TO_RUN = {\"MACE-MPA\"}\n",
    "\n",
    "R_STEP = 0.01  # Å, spacing of the interatomic-distance grid\n",
    "R_MIN_SCALE = 0.9  # rmin = R_MIN_SCALE * covalent radius\n",
    "R_MAX_SCALE = 3.1  # rmax = R_MAX_SCALE * Alvarez vdW radius\n",
    "R_MAX_FALLBACK = 6.0  # Å, rmax when no vdW radius is tabulated\n",
    "\n",
    "\n",
    "def dimer_positions(r, a):\n",
    "    \"\"\"Cartesian positions of two atoms `r` apart along x, centred at (a/2, a/2, a/2).\"\"\"\n",
    "    return [\n",
    "        [a / 2 - r / 2, a / 2, a / 2],\n",
    "        [a / 2 + r / 2, a / 2, a / 2],\n",
    "    ]\n",
    "\n",
    "\n",
    "for model in MLIPEnum:\n",
    "    model_name = model.name\n",
    "\n",
    "    if model_name not in MODELS_TO_RUN:\n",
    "        continue\n",
    "\n",
    "    print(f\"========== {model_name} ==========\")\n",
    "\n",
    "    calc = model.value()\n",
    "\n",
    "    # chemical_symbols[0] is the placeholder \"X\", hence the [1:]\n",
    "    for symbol in tqdm(chemical_symbols[1:]):\n",
    "        da = symbol + symbol\n",
    "\n",
    "        try:\n",
    "            z = Atom(symbol).number\n",
    "            rmin = R_MIN_SCALE * covalent_radii[z]\n",
    "            rvdw = (\n",
    "                vdw_alvarez.vdw_radii[z]\n",
    "                if z < len(vdw_alvarez.vdw_radii)\n",
    "                else np.nan\n",
    "            )\n",
    "            rmax = R_MAX_SCALE * rvdw if not np.isnan(rvdw) else R_MAX_FALLBACK\n",
    "            rs = np.linspace(rmin, rmax, int((rmax - rmin) / R_STEP))\n",
    "\n",
    "            # Box edge is twice the largest separation so periodic images stay far apart\n",
    "            a = 2 * rmax\n",
    "\n",
    "            out_dir = Path(REGISTRY[model_name][\"family\"]) / da\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "            traj_fpath = out_dir / f\"{model_name}.extxyz\"\n",
    "\n",
    "            # Resume: frames already on disk are taken to be the first len(traj) grid points\n",
    "            traj = read(traj_fpath, index=\":\") if traj_fpath.exists() else []\n",
    "            skip = len(traj)\n",
    "            if skip >= len(rs):\n",
    "                continue  # curve already complete\n",
    "\n",
    "            if traj:\n",
    "                atoms = traj[-1]\n",
    "            else:\n",
    "                atoms = Atoms(\n",
    "                    da,\n",
    "                    positions=dimer_positions(rs[0], a),\n",
    "                    # slightly anisotropic cell; presumably to avoid exact cubic symmetry\n",
    "                    cell=[a, a + 0.001, a + 0.002],\n",
    "                    pbc=True,\n",
    "                )\n",
    "\n",
    "            atoms.calc = calc\n",
    "\n",
    "            for r in tqdm(rs[skip:], desc=da):\n",
    "                atoms.set_positions(dimer_positions(r, a))\n",
    "                # Request energy AND forces so both end up in the extxyz frame;\n",
    "                # the analysis cell reads atoms.get_forces() back from disk.\n",
    "                atoms.get_potential_energy()\n",
    "                atoms.get_forces()\n",
    "                write(traj_fpath, atoms, append=True)\n",
    "        except Exception as e:\n",
    "            print(f\"[{model_name}] {da}: {e}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1bbfae1-790d-4586-9d7d-79c1ba658dcb",
   "metadata": {},
   "source": [
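    "# Analysis and output\n",
    "\n",
    "For every model, this section reads the saved diatomic trajectories and sorts them by decreasing distance. For the metrics, energies are shifted so that the largest separation is zero. It then computes smoothness and physicality metrics:\n",
    "\n",
    "- force and energy sign flips, total variation and jumps;\n",
    "- tortuosity;\n",
    "- the deviation between the projected force and $-dE/dr$;\n",
    "- Spearman rank correlations on either side of the force and energy minima;\n",
    "- the MAE against PBE reference curves in `./vasp/{XX}/PBE.extxyz`, when those exist.\n",
    "\n",
    "Results are merged into `{family}/homonuclear-diatomics.json`, keeping the latest row per (`name`, `method`)."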
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ac2c09-370b-4fdd-bf74-ea5c4ade0215",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "for model in MLIPEnum:\n",
    "    model_name = model.name\n",
    "\n",
    "    print(f\"========== {model_name} ==========\")\n",
    "\n",
    "    family_dir = Path(REGISTRY[model_name][\"family\"])\n",
    "    records = []  # one dict per element; turned into a DataFrame once at the end\n",
    "\n",
    "    for symbol in tqdm(chemical_symbols[1:]):\n",
    "        da = symbol + symbol\n",
    "        traj_fpath = family_dir / da / f\"{model_name}.extxyz\"\n",
    "\n",
    "        if not traj_fpath.exists():\n",
    "            continue\n",
    "\n",
    "        traj = read(traj_fpath, index=\":\")\n",
    "\n",
    "        # np.gradient below needs at least two samples\n",
    "        if len(traj) < 2:\n",
    "            print(f\"[{model_name}] {da}: {len(traj)} frame(s), skipped\")\n",
    "            continue\n",
    "\n",
    "        # S2s stays empty: magnetic moments are not recorded in the trajectories\n",
    "        Rs, Es, Fs, S2s = [], [], [], []\n",
    "        for atoms in traj:\n",
    "            vec = atoms.positions[1] - atoms.positions[0]\n",
    "            r = np.linalg.norm(vec)\n",
    "            Rs.append(r)\n",
    "            Es.append(atoms.get_potential_energy())\n",
    "            # force on atom 1 projected on the bond axis (positive pushes the atoms apart)\n",
    "            Fs.append(np.inner(vec / r, atoms.get_forces()[1]))\n",
    "\n",
    "        # sort by decreasing distance and align energy to zero at the largest separation\n",
    "        order = np.argsort(Rs)[::-1]\n",
    "        rs = np.array(Rs)[order]\n",
    "        fs = np.array(Fs)[order]\n",
    "        es = np.array(Es)[order]\n",
    "        eshift = es[0]\n",
    "        es = es - eshift\n",
    "\n",
    "        iminf = np.argmin(fs)\n",
    "        imine = np.argmin(es)\n",
    "\n",
    "        de_dr = np.gradient(es, rs)\n",
    "\n",
    "        # force sign flips; |F| < 1e-2 (10 meV/Å) counts as zero to avoid numerical noise\n",
    "        fs_sign = np.sign(np.where(np.abs(fs) < 1e-2, 0.0, fs))\n",
    "        fs_sign = fs_sign[fs_sign != 0]\n",
    "        f_flip = np.diff(fs_sign) != 0\n",
    "\n",
    "        # force jumps: sizes of the consecutive differences on both sides of each slope reversal\n",
    "        fdiff = np.diff(fs)\n",
    "        fdiff = fdiff[np.sign(fdiff) != 0]\n",
    "        fdiff_flip = np.diff(np.sign(fdiff)) != 0\n",
    "        fjump = (\n",
    "            np.abs(fdiff[:-1][fdiff_flip]).sum() + np.abs(fdiff[1:][fdiff_flip]).sum()\n",
    "        )\n",
    "\n",
    "        # same for energy, with steps below 1e-3 (1 meV) treated as flat\n",
    "        ediff = np.diff(es)\n",
    "        ediff[np.abs(ediff) < 1e-3] = 0\n",
    "        ediff = ediff[np.sign(ediff) != 0]\n",
    "        ediff_flip = np.diff(np.sign(ediff)) != 0\n",
    "        ejump = (\n",
    "            np.abs(ediff[:-1][ediff_flip]).sum() + np.abs(ediff[1:][ediff_flip]).sum()\n",
    "        )\n",
    "\n",
    "        pbe_energy_mae = None\n",
    "        pbe_force_mae = None\n",
    "        try:\n",
    "            pbe_traj = read(f\"./vasp/{da}/PBE.extxyz\", index=\":\")\n",
    "\n",
    "            pbe_rs, pbe_es, pbe_fs = [], [], []\n",
    "            for atoms in pbe_traj:\n",
    "                vec = atoms.positions[1] - atoms.positions[0]\n",
    "                r = np.linalg.norm(vec)\n",
    "                pbe_rs.append(r)\n",
    "                pbe_es.append(atoms.get_potential_energy())\n",
    "                pbe_fs.append(np.inner(vec / r, atoms.get_forces()[1]))\n",
    "\n",
    "            # UnivariateSpline needs increasing x; zero PBE energy at its largest separation\n",
    "            pbe_order = np.argsort(pbe_rs)\n",
    "            pbe_rs = np.array(pbe_rs)[pbe_order]\n",
    "            pbe_es = np.array(pbe_es)[pbe_order]\n",
    "            pbe_fs = np.array(pbe_fs)[pbe_order]\n",
    "            pbe_es = pbe_es - pbe_es[-1]\n",
    "\n",
    "            # Only score MLIP points inside the PBE range: the interpolating spline\n",
    "            # (s=0) would otherwise be extrapolated, which is unreliable.\n",
    "            within = (rs >= pbe_rs[0] - 1e-6) & (rs <= pbe_rs[-1] + 1e-6)\n",
    "            if within.any():\n",
    "                e_spline = UnivariateSpline(pbe_rs, pbe_es, s=0)\n",
    "                pbe_energy_mae = np.mean(np.abs(es[within] - e_spline(rs[within])))\n",
    "                f_spline = UnivariateSpline(pbe_rs, pbe_fs, s=0)\n",
    "                pbe_force_mae = np.mean(np.abs(fs[within] - f_spline(rs[within])))\n",
    "        except Exception as e:\n",
    "            print(f\"[{model_name}] {da}: no PBE comparison ({e})\")\n",
    "            pbe_energy_mae = None\n",
    "            pbe_force_mae = None\n",
    "\n",
    "        etv = np.sum(np.abs(np.diff(es)))\n",
    "\n",
    "        records.append(\n",
    "            {\n",
    "                \"name\": da,\n",
    "                \"method\": model_name,\n",
    "                \"R\": rs,\n",
    "                \"E\": es + eshift,\n",
    "                \"F\": fs,\n",
    "                \"S^2\": S2s,\n",
    "                \"force-flip-times\": np.sum(f_flip),\n",
    "                \"force-total-variation\": np.sum(np.abs(np.diff(fs))),\n",
    "                \"force-jump\": fjump,\n",
    "                \"energy-diff-flip-times\": np.sum(ediff_flip),\n",
    "                \"energy-grad-norm-max\": np.max(np.abs(de_dr)),\n",
    "                \"energy-jump\": ejump,\n",
    "                \"energy-total-variation\": etv,\n",
    "                \"tortuosity\": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),\n",
    "                # projected force should equal -dE/dr for a conservative model\n",
    "                \"conservation-deviation\": np.mean(np.abs(fs + de_dr)),\n",
    "                \"spearman-descending-force\": stats.spearmanr(\n",
    "                    rs[iminf:], fs[iminf:]\n",
    "                ).statistic,\n",
    "                \"spearman-ascending-force\": stats.spearmanr(\n",
    "                    rs[:iminf], fs[:iminf]\n",
    "                ).statistic,\n",
    "                \"spearman-repulsion-energy\": stats.spearmanr(\n",
    "                    rs[imine:], es[imine:]\n",
    "                ).statistic,\n",
    "                \"spearman-attraction-energy\": stats.spearmanr(\n",
    "                    rs[:imine], es[:imine]\n",
    "                ).statistic,\n",
    "                \"pbe-energy-mae\": pbe_energy_mae,\n",
    "                \"pbe-force-mae\": pbe_force_mae,\n",
    "            }\n",
    "        )\n",
    "\n",
    "    if not records:\n",
    "        continue  # nothing to add for this model; leave any existing JSON untouched\n",
    "\n",
    "    df = pd.DataFrame(records)\n",
    "    json_fpath = family_dir / \"homonuclear-diatomics.json\"\n",
    "\n",
    "    # merge with previous results, keeping the newest row per (name, method)\n",
    "    if json_fpath.exists():\n",
    "        df = pd.concat([pd.read_json(json_fpath), df], ignore_index=True)\n",
    "        df = df.drop_duplicates(subset=[\"name\", \"method\"], keep=\"last\")\n",
    "\n",
    "    df.to_json(json_fpath, orient=\"records\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
