{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from ase import units\n",
    "from dask.distributed import Client\n",
    "from dask_jobqueue import SLURMCluster\n",
    "from dotenv import load_dotenv\n",
    "from prefect import flow, task\n",
    "from prefect_dask import DaskTaskRunner\n",
    "\n",
    "from mlip_arena.models import REGISTRY, MLIPEnum\n",
    "from mlip_arena.tasks.md import run as MD\n",
    "from mlip_arena.tasks.stability.input import get_atoms_from_db\n",
    "from mlip_arena.tasks.utils import get_calculator\n",
    "\n",
    "# Pull credentials from a local .env file (if one exists) so they never live in the notebook.\n",
    "load_dotenv()\n",
    "\n",
    "# Both default to None when unset. HF_TOKEN is what the database auto-download in `compress` relies on.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "MP_API_KEY = os.getenv(\"MP_API_KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# SLURM allocation shape for each Dask worker job.\n",
    "nodes_per_alloc = 1\n",
    "gpus_per_alloc = 4\n",
    "# Bounds for Dask adaptive scaling, counted in SLURM jobs.\n",
    "min_jobs, max_jobs = 5, 10\n",
    "# Conda env activated on every worker. Override with MLIP_ARENA_CONDA_ENV\n",
    "# instead of editing this user-specific default path.\n",
    "conda_env = os.environ.get(\n",
    "    \"MLIP_ARENA_CONDA_ENV\", \"/pscratch/sd/c/cyrusyc/.conda/mlip-arena\"\n",
    ")\n",
    "\n",
    "# NOTE(review): cores=1 / processes=1 gives one single-threaded Dask worker per\n",
    "# job, so at most one task runs at a time on a 4-GPU allocation -- confirm intended.\n",
    "cluster_kwargs = dict(\n",
    "    cores=1,\n",
    "    memory=\"64 GB\",\n",
    "    processes=1,\n",
    "    shebang=\"#!/bin/bash\",\n",
    "    account=\"matgen\",\n",
    "    walltime=\"03:00:00\",\n",
    "    job_mem=\"0\",  # NOTE(review): presumably rendered as --mem=0 (whole-node memory in SLURM)\n",
    "    job_script_prologue=[\n",
    "        \"source ~/.bashrc\",\n",
    "        \"module load python\",\n",
    "        f\"source activate {conda_env}\",\n",
    "    ],\n",
    "    # Suppress dask-jobqueue's generated -n/--cpus-per-task/-J lines; -J is re-added below.\n",
    "    job_directives_skip=[\"-n\", \"--cpus-per-task\", \"-J\"],\n",
    "    job_extra_directives=[\n",
    "        \"-J stability-npt\",\n",
    "        \"-q preempt\",\n",
    "        \"--time-min=00:30:00\",\n",
    "        \"--comment=12:00:00\",\n",
    "        f\"-N {nodes_per_alloc}\",\n",
    "        \"-C gpu\",\n",
    "        f\"-G {gpus_per_alloc}\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "cluster = SLURMCluster(**cluster_kwargs)\n",
    "print(cluster.job_script())  # show the generated batch script for inspection\n",
    "cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models to benchmark. A model actually runs only if it is also a member of\n",
    "# MLIPEnum and its REGISTRY entry lists \"stability\" under \"gpu-tasks\" (see `compress`).\n",
    "selected_models = [\n",
    "    \"MACE-MP(M)\",\n",
    "    \"CHGNet\",\n",
    "    \"M3GNet\",\n",
    "    \"MatterSim\",\n",
    "    \"eqV2(OMat)\",\n",
    "    \"MACE-MPA\",\n",
    "    \"ORBv2\",\n",
    "    \"SevenNet\",\n",
    "    \"ALIGNN\",\n",
    "]\n",
    "\n",
    "\n",
    "@task\n",
    "def run_one(\n",
    "    atoms,\n",
    "    model,\n",
    "):\n",
    "    \"\"\"Run one NPT (Nose-Hoover) compression MD simulation for a structure/model pair.\n",
    "\n",
    "    Temperature and pressure are passed as [start, end] pairs (300 -> 3000,\n",
    "    presumably K; 0 -> 500 GPa). The pressure comment below implies a linear ramp\n",
    "    over ``total_time``. The trajectory is written to\n",
    "    ``<REGISTRY family>/<model>_<material_id or 'random'>_<formula>_npt.traj``.\n",
    "\n",
    "    Args:\n",
    "        atoms: ASE ``Atoms`` to simulate; ``atoms.info`` and\n",
    "            ``get_chemical_formula()`` feed the trajectory file name.\n",
    "        model: ``MLIPEnum`` member; ``model.name`` selects the calculator and\n",
    "            the REGISTRY entry.\n",
    "\n",
    "    Returns:\n",
    "        The result of the ``mlip_arena.tasks.md.run`` task.\n",
    "    \"\"\"\n",
    "    # NOTE(review): `compress` already submits run_one with timeout_seconds=600 and\n",
    "    # retries=2, so these inner retries may never run before the outer timeout\n",
    "    # fires -- confirm the intended timeout/retry budget.\n",
    "    md_task = MD.with_options(timeout_seconds=600, retries=2, refresh_cache=True)\n",
    "    calculator = get_calculator(model.name, calculator_kwargs=None)\n",
    "    traj_file = f\"{REGISTRY[model.name]['family']}/{model.name}_{atoms.info.get('material_id', 'random')}_{atoms.get_chemical_formula()}_npt.traj\"\n",
    "\n",
    "    return md_task(\n",
    "        atoms=atoms,\n",
    "        calculator=calculator,\n",
    "        ensemble=\"npt\",\n",
    "        dynamics=\"nose-hoover\",\n",
    "        time_step=None,  # NOTE(review): presumably lets MD choose its default step\n",
    "        # Presumably forwarded to ase.md.npt.NPT: ttime is the thermostat time\n",
    "        # constant; pfactor ~ ptime**2 * B with ptime = 75 fs and B ~ 100 GPa.\n",
    "        dynamics_kwargs=dict(\n",
    "            ttime=25 * units.fs,\n",
    "            pfactor=((75 * units.fs) ** 2) * 1e2 * units.GPa,\n",
    "        ),\n",
    "        total_time=1e4,  # fs\n",
    "        temperature=[300, 3000],\n",
    "        pressure=[0, 5e2 * units.GPa],  # 500 GPa / 10 ps = 50 GPa / 1 ps\n",
    "        traj_file=traj_file,\n",
    "        traj_interval=10,\n",
    "    )\n",
    "\n",
    "\n",
    "@flow\n",
    "def compress():\n",
    "    \"\"\"Submit one NPT compression run per (structure, model) pair and gather results.\n",
    "\n",
    "    Structures are read from ``random-mixture.db``. A model is used when it is a\n",
    "    member of ``MLIPEnum``, is listed in ``selected_models``, and its REGISTRY\n",
    "    entry has \"stability\" in ``\"gpu-tasks\"``. Pairs that fail to *submit* are\n",
    "    reported and skipped so one bad pair does not abort the whole sweep.\n",
    "\n",
    "    Returns:\n",
    "        list: one ``run_one`` result per submitted task, in submission order.\n",
    "        Because ``raise_on_failure=False``, a failed task's outcome is returned\n",
    "        rather than raised.\n",
    "    \"\"\"\n",
    "    # Model eligibility does not depend on the structure, so filter only once.\n",
    "    models = [\n",
    "        model\n",
    "        for model in MLIPEnum\n",
    "        if model.name in selected_models\n",
    "        and \"stability\" in REGISTRY[model.name][\"gpu-tasks\"]\n",
    "    ]\n",
    "\n",
    "    futures = []\n",
    "    # To download the database automatically, `huggingface_hub login` or provide HF_TOKEN\n",
    "    for atoms in get_atoms_from_db(\"random-mixture.db\", force_download=False):\n",
    "        for model in models:\n",
    "            try:\n",
    "                future = run_one.with_options(\n",
    "                    timeout_seconds=600, retries=2, refresh_cache=False\n",
    "                ).submit(atoms.copy(), model)\n",
    "            except Exception as exc:\n",
    "                # Best-effort sweep: report the failed submission (visible in the Prefect\n",
    "                # logs when run with log_prints=True) and move on. Catching Exception\n",
    "                # rather than everything lets KeyboardInterrupt/SystemExit stop the run.\n",
    "                print(\n",
    "                    f\"Failed to submit {model.name} for \"\n",
    "                    f\"{atoms.get_chemical_formula()}: {exc!r}\"\n",
    "                )\n",
    "                continue\n",
    "            futures.append(future)\n",
    "\n",
    "    return [future.result(raise_on_failure=False) for future in futures]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dask_runner = DaskTaskRunner(address=client.scheduler.address)\n",
    "\n",
    "# Run the sweep on the SLURM-backed Dask cluster; log_prints=True routes print() output to the Prefect run logs.\n",
    "compress.with_options(task_runner=dask_runner, log_prints=True)()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NERSC Python",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
