{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from dask.distributed import Client\n",
    "from dask_jobqueue import SLURMCluster\n",
    "from mlip_arena.models import REGISTRY, MLIPEnum\n",
    "from mlip_arena.tasks.md import run as MD\n",
    "from mlip_arena.tasks.utils import get_calculator\n",
    "from prefect import flow\n",
    "from prefect_dask import DaskTaskRunner\n",
    "\n",
    "from ase import Atoms, units\n",
    "from ase.build import molecule\n",
    "from ase.io import read, write\n",
    "from pymatgen.core import Molecule\n",
    "from pymatgen.io.packmol import PackmolBoxGen"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Intial configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h2 = molecule(\"H2\")\n",
    "o2 = molecule(\"O2\")\n",
    "h2o = molecule(\"H2O\")\n",
    "\n",
    "write(\"h2.xyz\", h2)\n",
    "write(\"o2.xyz\", o2)\n",
    "write(\"h2o.xyz\", h2o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h2 = Molecule.from_file(\"h2.xyz\")\n",
    "o2 = Molecule.from_file(\"o2.xyz\")\n",
    "h2o = Molecule.from_file(\"h2o.xyz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "molecules = []\n",
    "\n",
    "for m, number in zip([h2, o2], [128, 64]):\n",
    "    molecules.append(\n",
    "        {\n",
    "            \"name\": m.composition.to_pretty_string(),\n",
    "            \"number\": number,\n",
    "            \"coords\": m,\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tolerance = 2.0\n",
    "input_gen = PackmolBoxGen(\n",
    "    tolerance=tolerance,\n",
    "    seed=1,\n",
    ")\n",
    "margin = 0.5 * tolerance\n",
    "\n",
    "a = 30\n",
    "\n",
    "packmol_set = input_gen.get_input_set(\n",
    "    molecules=molecules,\n",
    "    box=[margin, margin, margin, a - margin, a - margin, a - margin],\n",
    ")\n",
    "packmol_set.write_input(\".\")\n",
    "packmol_set.run(\".\")\n",
    "\n",
    "atoms = read(\"packmol_out.xyz\")\n",
    "atoms.cell = [a, a, a]\n",
    "atoms.pbc = True\n",
    "\n",
    "print(atoms)\n",
    "\n",
    "write(f\"{atoms.get_chemical_formula()}.extxyz\", atoms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run workflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "atoms = read(\"H256O128.extxyz\")\n",
    "print(atoms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes_per_alloc = 1\n",
    "gpus_per_alloc = 4\n",
    "ntasks = 1\n",
    "\n",
    "cluster_kwargs = dict(\n",
    "    cores=1,\n",
    "    memory=\"64 GB\",\n",
    "    shebang=\"#!/bin/bash\",\n",
    "    account=\"m4282\",\n",
    "    walltime=\"00:30:00\",\n",
    "    job_mem=\"0\",\n",
    "    job_script_prologue=[\n",
    "        \"source ~/.bashrc\",\n",
    "        \"module load python\",\n",
    "        \"source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\",\n",
    "    ],\n",
    "    job_directives_skip=[\"-n\", \"--cpus-per-task\", \"-J\"],\n",
    "    job_extra_directives=[\n",
    "        \"-J combustion-water\",\n",
    "        \"-q debug\",\n",
    "        f\"-N {nodes_per_alloc}\",\n",
    "        \"-C gpu\",\n",
    "        f\"-G {gpus_per_alloc}\",\n",
    "        \"--exclusive\",\n",
    "    ],\n",
    "    death_timeout=86400,\n",
    ")\n",
    "\n",
    "cluster = SLURMCluster(**cluster_kwargs)\n",
    "\n",
    "\n",
    "print(cluster.job_script())\n",
    "cluster.adapt(minimum_jobs=1, maximum_jobs=1)\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True)\n",
    "def combustion(atoms: Atoms):\n",
    "    futures = []\n",
    "\n",
    "    for model in MLIPEnum:\n",
    "        if model.name != \"MatterSim\":\n",
    "            continue\n",
    "\n",
    "        future = MD.submit(\n",
    "            atoms=atoms,\n",
    "            calculator=get_calculator(\n",
    "                calculator_name=model,\n",
    "                calculator_kwargs=None,\n",
    "            ),\n",
    "            ensemble=\"nvt\",\n",
    "            dynamics=\"nose-hoover\",\n",
    "            time_step=None,\n",
    "            dynamics_kwargs=dict(ttime=25 * units.fs, pfactor=None),\n",
    "            total_time=1000_000,\n",
    "            temperature=[300, 3000, 3000, 300],\n",
    "            pressure=None,\n",
    "            velocity_seed=0,\n",
    "            traj_file=Path(REGISTRY[model.name][\"family\"])\n",
    "            / f\"{model.name}_{atoms.get_chemical_formula()}.traj\",\n",
    "            traj_interval=1000,\n",
    "            restart=True,\n",
    "        )\n",
    "\n",
    "        futures.append(future)\n",
    "\n",
    "    return [future.result() for future in futures]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results = combustion(atoms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combustion(atoms: Atoms):\n",
    "    futures = []\n",
    "\n",
    "    for model in MLIPEnum:\n",
    "        if model.name != \"MatterSim\":\n",
    "            continue\n",
    "\n",
    "        future = MD(\n",
    "            atoms=atoms,\n",
    "            calculator=get_calculator(\n",
    "                calculator_name=model,\n",
    "                calculator_kwargs=None,\n",
    "            ),\n",
    "            ensemble=\"nvt\",\n",
    "            dynamics=\"nose-hoover\",\n",
    "            time_step=None,\n",
    "            dynamics_kwargs=dict(ttime=25 * units.fs, pfactor=None),\n",
    "            total_time=1000_000,\n",
    "            temperature=[300, 3000, 3000, 300],\n",
    "            pressure=None,\n",
    "            velocity_seed=0,\n",
    "            traj_file=Path(REGISTRY[model.name][\"family\"])\n",
    "            / f\"{model.name}_{atoms.get_chemical_formula()}.traj\",\n",
    "            traj_interval=1000,\n",
    "            restart=True,\n",
    "        )\n",
    "\n",
    "        futures.append(future)\n",
    "\n",
    "    return [future.result() for future in futures]\n",
    "\n",
    "\n",
    "results = combustion(atoms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlip-arena",
   "language": "python",
   "name": "mlip-arena"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
