{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f9be6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af9e752a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the PET-MAD reference and the per-run O-O / H-H radial distribution functions,\n",
    "# plus the densities shown in the density panel of Figure 2.\n",
    "thermostat = \"langevin\"\n",
    "\n",
    "\n",
    "def load_runs(filename, thermostat=thermostat):\n",
    "    \"\"\"Load `filename` from every `<model>-<time step>-True` run directory.\n",
    "\n",
    "    Returns a dict keyed by run directory name (e.g. \"water-4-True\"). A run whose\n",
    "    file cannot be opened is reported and skipped, so the other runs stay usable.\n",
    "    \"\"\"\n",
    "    runs = {}\n",
    "    for model in [\"universal\", \"water\"]:\n",
    "        for time_step_fs in [1, 4, 16]:\n",
    "            # The trailing \"True\" is part of the directory name (the only variant used here).\n",
    "            run = f\"{model}-{time_step_fs}-True\"\n",
    "            try:\n",
    "                runs[run] = np.genfromtxt(f\"../water/nvt-{thermostat}/{run}/{filename}\")\n",
    "            except OSError as err:\n",
    "                print(f\"skipping {run}: {err}\")\n",
    "    return runs\n",
    "\n",
    "\n",
    "petmad_goo = np.genfromtxt(f\"../water/nvt-{thermostat}/petmad/md.grOO\")\n",
    "arrays_goo = load_runs(\"md.grOO\")\n",
    "\n",
    "petmad_ghh = np.genfromtxt(f\"../water/nvt-{thermostat}/petmad/md.grHH\")\n",
    "arrays_ghh = load_runs(\"md.grHH\")\n",
    "\n",
    "# The `cft.dat` files of the nvt-svr-2 runs load the same way (reference: petmad/cft.dat):\n",
    "# arrays_cuu = load_runs(\"cft.dat\", thermostat=\"svr-2\")\n",
    "\n",
    "# Densities in kg/m³ (the unit of the density axis in Figure 2): the PET-MAD MD reference,\n",
    "# then a (mean, uncertainty) pair per `<model>-<time step>` run, drawn as error bars.\n",
    "# NOTE: `volumes` actually holds densities; the name is kept because the figure cell uses it.\n",
    "pet_mad_mean = 961.20658\n",
    "pet_mad_error = 2.6445108\n",
    "volumes = {\n",
    "    \"universal-1\": (957.85919, 2.9179511),\n",
    "    \"universal-4\": (1013.7755, 4.9451266),\n",
    "    \"universal-16\": (913.54652, 8.9114886),\n",
    "    \"water-1\": (940.55517, 3.8402561),\n",
    "    \"water-4\": (968.66631, 3.4192689),\n",
    "    \"water-16\": (948.23159, 3.9815192),\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fabd6bd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 2: g_OO(r), g_HH(r) and density vs. time step for the water and universal models.\n",
    "models = [\"water\", \"universal\"]\n",
    "time_steps_fs = [1, 4, 16]\n",
    "\n",
    "fig = plt.figure(figsize=(11, 2.5), dpi=300)\n",
    "\n",
    "# Five grid columns: the panels occupy columns 0, 2 and 4, while the narrow empty\n",
    "# columns 1 and 3 provide the gaps between them (hence wspace=0).\n",
    "gs = fig.add_gridspec(nrows=1, ncols=5, width_ratios=[1, 0.25, 1, 0.4, 1])\n",
    "fig.subplots_adjust(wspace=0, hspace=0.6)\n",
    "\n",
    "ax1 = fig.add_subplot(gs[0])  # g_OO(r)\n",
    "ax2 = fig.add_subplot(gs[2])  # g_HH(r)\n",
    "ax3 = fig.add_subplot(gs[4])  # density vs. time step\n",
    "\n",
    "# RDF panels: dashed PET-MAD MD reference drawn on top (zorder=100), then one curve per\n",
    "# run in the default colour cycle (water before universal, increasing time step).\n",
    "rdf_panels = [(ax1, petmad_goo, arrays_goo), (ax2, petmad_ghh, arrays_ghh)]\n",
    "for ax, reference, runs in rdf_panels:\n",
    "    ax.plot(reference[:, 0], reference[:, 1], \"--\", label=\"MD (0.25 fs)\", color=\"black\", linewidth=1.5, zorder=100)\n",
    "    for model in models:\n",
    "        for dt in time_steps_fs:\n",
    "            rdf = runs[f\"{model}-{dt}-True\"]\n",
    "            ax.plot(rdf[:, 0], rdf[:, 1], label=f\"{model} ({dt} fs)\")\n",
    "\n",
    "ax1.set_xlabel(r\"$r$ [Å]\", fontsize=14)\n",
    "ax1.set_ylabel(r\"$g_{OO}(r)$\", fontsize=14)\n",
    "ax1.legend(fontsize=7, ncols=1)\n",
    "ax1.set_xlim(1.8, 6)\n",
    "ax1.set_ylim(-0.2, 4.2)\n",
    "\n",
    "ax2.set_xlabel(r\"$r$ [Å]\", fontsize=14)\n",
    "ax2.set_ylabel(r\"$g_{HH}(r)$\", fontsize=14)\n",
    "ax2.set_xlim(0.8, 6)\n",
    "ax2.set_xticks([1, 2, 3, 4, 5, 6])\n",
    "ax2.set_ylim(-0.1555, 3)\n",
    "ax2.set_yticks([0, 1, 2])\n",
    "\n",
    "# Density panel: one error bar per run (markersize=0 hides the marker); only the first\n",
    "# bar of each model carries a label so the legend lists each model once.\n",
    "model_colors = {\"water\": \"dodgerblue\", \"universal\": \"green\"}\n",
    "for model in models:\n",
    "    for i, dt in enumerate(time_steps_fs):\n",
    "        density, density_err = volumes[f\"{model}-{dt}\"]\n",
    "        ax3.errorbar(\n",
    "            dt, density, density_err, fmt=\"o\", markersize=0, linewidth=2, color=model_colors[model],\n",
    "            label=model if i == 0 else \"_nolegend_\",\n",
    "        )\n",
    "\n",
    "ax3.axhline(pet_mad_mean, color=\"black\", linestyle=\"--\", label=\"MD (0.25 fs)\", linewidth=1)\n",
    "# Reference uncertainty band. axhspan spans the full axes width for any x scale, unlike a\n",
    "# fill_between starting at x = 0, which has no finite position on the log axis used here.\n",
    "ax3.axhspan(pet_mad_mean - pet_mad_error, pet_mad_mean + pet_mad_error, color=\"black\", alpha=0.3, lw=0)\n",
    "ax3.set_xscale(\"log\")\n",
    "ax3.set_xlim(0.5, 20)\n",
    "ax3.set_xticks(time_steps_fs, time_steps_fs)\n",
    "ax3.set_xlabel(\"Time step [fs]\", fontsize=12)\n",
    "ax3.set_ylabel(\"Density [kg/m³]\", fontsize=12)\n",
    "ax3.legend(fontsize=7, loc=\"lower center\")\n",
    "\n",
    "fig.savefig(\"figure2_water.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf46bad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Energy fluctuations (each series minus its mean) without and with the energy-conservation filter.\n",
    "energy_dir = \"../ablation/energy_enforcement\"\n",
    "frame_interval_ps = 0.004  # time between consecutive saved energies on the plotted time axis\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(11, 3), dpi=300)\n",
    "fig.subplots_adjust(wspace=0.15, hspace=0)\n",
    "\n",
    "for ax, rescale in zip(axes, [False, True]):\n",
    "    # Plotting order (kinetic, potential, total) sets the legend order.\n",
    "    for kind, color in [(\"kinetic\", \"skyblue\"), (\"potential\", \"lightgreen\"), (\"total\", \"black\")]:\n",
    "        energies = np.load(f\"{energy_dir}/{kind}_energies_{rescale}.npy\")\n",
    "        time_ps = np.arange(len(energies)) * frame_interval_ps\n",
    "        ax.plot(time_ps, energies - energies.mean(), label=kind, color=color)\n",
    "\n",
    "    panel_title = \"With E conservation filter\" if rescale else \"Without E conservation filter\"\n",
    "    ax.text(0.03, 0.95, panel_title, ha=\"left\", va=\"top\", transform=ax.transAxes, fontsize=12)\n",
    "    ax.set_xlabel(\"Time [ps]\", fontsize=12)\n",
    "    ax.set_xlim(-0.2, 4.2)\n",
    "    ax.set_ylim(-2.8, 2.8)\n",
    "\n",
    "axes[0].set_ylabel(\"Energy [eV]\", fontsize=12)\n",
    "axes[1].legend(loc=\"lower right\")\n",
    "\n",
    "fig.savefig(\"figure_energy_conservation.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9594273d-0239-47e1-a767-4153aa2edbd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Errors of the water and universal models vs. time step (dict keys, in fs): mass-scaled\n",
    "# positions (q, Å·sqrt(u)), mass-scaled momenta (p, sqrt(eV)) and energy conservation\n",
    "# (meV/atom), matching the axis labels of the error figure below.\n",
    "water_errors_q = {\n",
    "    1: 6.2e-5,\n",
    "    4: 5.4e-4,\n",
    "    16: 6.2e-3,\n",
    "    64: 0.25,\n",
    "    256: 1.3,\n",
    "}\n",
    "water_errors_p = {\n",
    "    1: 1.2e-3,\n",
    "    4: 2.5e-3,\n",
    "    16: 1.1e-2,\n",
    "    64: 0.18,\n",
    "    256: 0.22,\n",
    "}\n",
    "\n",
    "universal_errors_q = {\n",
    "    1: 1.3e-4,\n",
    "    4: 1.3e-3,\n",
    "    16: 0.015,\n",
    "    64: 0.15,\n",
    "    256: 1.0,\n",
    "}\n",
    "universal_errors_p = {\n",
    "    1: 2.3e-3,\n",
    "    4: 6e-3,\n",
    "    16: 0.021,\n",
    "    64: 0.098,\n",
    "    256: 0.23,\n",
    "}\n",
    "\n",
    "water_errors_energy = {\n",
    "    1: 0.03,\n",
    "    4: 0.10,\n",
    "    16: 0.57,\n",
    "    64: 300,\n",
    "    256: 10000,\n",
    "}\n",
    "universal_errors_energy = {\n",
    "    1: 1.7,\n",
    "    4: 5.3,\n",
    "    16: 22,\n",
    "    64: 320,\n",
    "    256: 2300,\n",
    "}\n",
    "\n",
    "# PET-MAD reference level for the energy-conservation panel (meV/atom). Deliberately not\n",
    "# named `pet_mad_error`: that name holds the density uncertainty shaded in Figure 2, and\n",
    "# overwriting it here would corrupt Figure 2 if its cell were re-run after this one.\n",
    "pet_mad_energy_error = 30\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1003450a-388d-40f7-8706-9bd9844a843e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Error figure: (a) mass-scaled positions, (b) mass-scaled momenta, (c) energy\n",
    "# conservation, (d) symplecticity, each vs. time step.\n",
    "fig, ax = plt.subplots(1, 4, figsize=(11, 2.5), dpi=300)\n",
    "fig.subplots_adjust(wspace=0.5, hspace=0)\n",
    "\n",
    "# PET-MAD reference level in panel c, added before the markers so they are drawn over it.\n",
    "# NOTE(review): this line is labelled but panel c has no legend; name it in the caption or call ax[2].legend().\n",
    "ax[2].axhline(pet_mad_energy_error, color=\"black\", linestyle=\"--\", label=\"PET-MAD\")\n",
    "\n",
    "# Panels a-c share one layout: log-log axes, major ticks exactly at the time steps, no minor ticks.\n",
    "error_panels = [\n",
    "    (ax[0], water_errors_q, universal_errors_q, r\"RMSE [$Å \\sqrt{u}$]\"),\n",
    "    (ax[1], water_errors_p, universal_errors_p, r\"RMSE [$\\sqrt{eV}$]\"),\n",
    "    (ax[2], water_errors_energy, universal_errors_energy, \"RMSE [meV/at.]\"),\n",
    "]\n",
    "for panel, water_errors, universal_errors, ylabel in error_panels:\n",
    "    panel.plot(list(water_errors), list(water_errors.values()), \"o\", label=\"water\", markersize=5)\n",
    "    panel.plot(list(universal_errors), list(universal_errors.values()), \"o\", label=\"universal\", markersize=5)\n",
    "    panel.set_xscale(\"log\")\n",
    "    panel.set_yscale(\"log\")\n",
    "    panel.set_xticks(list(universal_errors), list(universal_errors))\n",
    "    panel.xaxis.set_minor_locator(plt.NullLocator())\n",
    "    panel.set_xlabel(\"Time step [fs]\")\n",
    "    panel.set_ylabel(ylabel)\n",
    "\n",
    "# Panel d: symplecticity per run, with a dashed reference line at 1.0.\n",
    "symplecticity = {'water-1': 0.9988048007458071, 'water-4': 0.8154773786656506, 'water-16': 0.3844331911748218, 'universal-1': 0.998672098724109, 'universal-4': 0.8190309905799852, 'universal-16': 0.2723277747104351}\n",
    "for model_type in [\"water\", \"universal\"]:\n",
    "    # Match the model prefix exactly rather than by substring.\n",
    "    model_runs = {key: value for key, value in symplecticity.items() if key.split(\"-\")[0] == model_type}\n",
    "    run_time_steps = [int(key.split(\"-\")[1]) for key in model_runs]\n",
    "    ax[3].plot(run_time_steps, list(model_runs.values()), \"o\", label=model_type, markersize=5)\n",
    "ax[3].axhline(1.0, color=\"black\", linestyle=\"--\")\n",
    "ax[3].set_ylim(-0.1, 1.3)\n",
    "ax[3].set_xscale(\"log\")\n",
    "ax[3].set_xticks([1, 4, 16], [1, 4, 16])\n",
    "ax[3].xaxis.set_minor_locator(plt.NullLocator())  # consistent with panels a-c\n",
    "ax[3].set_xlabel(\"Time step [fs]\")\n",
    "ax[3].set_ylabel(\"Symplecticity\")\n",
    "ax[3].legend()\n",
    "\n",
    "for panel, letter in zip(ax, \"abcd\"):\n",
    "    panel.text(0.05, 0.95, letter, ha=\"left\", va=\"top\", transform=panel.transAxes, fontsize=12, weight=\"bold\")\n",
    "\n",
    "fig.savefig(\"figure_errors.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047c46e3-805b-40d7-937f-e8bc77c68fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Energy misalignment: PET-MAD vs. FlashMD energies for two checkpoints, with a parity line.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(11, 3), dpi=300)\n",
    "fig.subplots_adjust(wspace=0.3, hspace=0)\n",
    "\n",
    "checkpoints = [(\"broken\", \"Checkpoint A\"), (\"final\", \"Checkpoint B\")]\n",
    "for ax, (model, panel_title) in zip(axes, checkpoints):\n",
    "    # First axis: index 0 holds the energies plotted as PET-MAD (x), index 1 as FlashMD (y).\n",
    "    energies = np.load(f\"../ablation/energy_misalignment/energy_errors_{model}.npy\")\n",
    "    energies_md, energies_skipmd = energies[0], energies[1]\n",
    "    # Report the RMSE between the two energy sets alongside the figure.\n",
    "    rmse = np.sqrt(np.mean((energies_md - energies_skipmd) ** 2))\n",
    "    print(f\"{panel_title} ({model}): RMSE = {rmse:.4g} eV\")\n",
    "\n",
    "    ax.plot(energies_md, energies_skipmd, \".\", markersize=5, alpha=0.5, mew=0, color=\"#00B4D8\")\n",
    "    # Dashed y = x parity line over the PET-MAD energy range.\n",
    "    lo, hi = energies_md.min(), energies_md.max()\n",
    "    ax.plot([lo, hi], [lo, hi], \"--\", color=\"black\")\n",
    "    ax.set_xlabel(\"PET-MAD energy [eV]\", fontsize=12)\n",
    "    ax.set_ylabel(\"FlashMD energy [eV]\", fontsize=12)\n",
    "    ax.text(0.03, 0.95, panel_title, ha=\"left\", va=\"top\", transform=ax.transAxes, fontsize=12)\n",
    "\n",
    "fig.savefig(\"figure_energy_misalignment.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
