{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from ase import units\n",
    "from dask.distributed import Client\n",
    "from dask_jobqueue import SLURMCluster\n",
    "from dotenv import load_dotenv\n",
    "from prefect import flow, task\n",
    "from prefect_dask import DaskTaskRunner\n",
    "\n",
    "from mlip_arena.models import REGISTRY, MLIPEnum\n",
    "from mlip_arena.tasks.md import run as MD\n",
    "from mlip_arena.tasks.stability.input import get_atoms_from_db\n",
    "\n",
    "# Read variables from a local .env file (if one exists) into os.environ.\n",
    "load_dotenv()\n",
    "\n",
    "# Optional credentials: each is None when the variable is not set.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "MP_API_KEY = os.getenv(\"MP_API_KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# SLURM resources requested for each Dask job (one job = one allocation).\n",
    "nodes_per_alloc = 1\n",
    "gpus_per_alloc = 4\n",
    "# Conda env activated on every worker; set MLIP_ARENA_CONDA_ENV to use your own.\n",
    "conda_env = os.environ.get(\n",
    "    \"MLIP_ARENA_CONDA_ENV\", \"/pscratch/sd/c/cyrusyc/.conda/mlip-arena\"\n",
    ")\n",
    "\n",
    "cluster_kwargs = dict(\n",
    "    # NOTE(review): cores=1/processes=1 gives one worker per job, while each job\n",
    "    # requests gpus_per_alloc GPUs; confirm the workload actually uses all of them.\n",
    "    cores=1,\n",
    "    memory=\"64 GB\",\n",
    "    processes=1,\n",
    "    shebang=\"#!/bin/bash\",\n",
    "    account=\"matgen\",\n",
    "    walltime=\"04:00:00\",\n",
    "    # SLURM --mem=0 requests all memory on the node.\n",
    "    job_mem=\"0\",\n",
    "    job_script_prologue=[\n",
    "        \"source ~/.bashrc\",\n",
    "        \"module load python\",\n",
    "        f\"source activate {conda_env}\",\n",
    "    ],\n",
    "    # Drop the generated -n/--cpus-per-task/-J lines; the job name is set below.\n",
    "    job_directives_skip=[\"-n\", \"--cpus-per-task\", \"-J\"],\n",
    "    job_extra_directives=[\n",
    "        \"-J stability-nvt\",\n",
    "        \"-q preempt\",\n",
    "        \"--time-min=00:30:00\",\n",
    "        \"--comment=12:00:00\",\n",
    "        f\"-N {nodes_per_alloc}\",\n",
    "        \"-C gpu\",\n",
    "        f\"-G {gpus_per_alloc}\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "cluster = SLURMCluster(**cluster_kwargs)\n",
    "print(cluster.job_script())\n",
    "# Keep between 10 and 50 SLURM jobs alive, scaled to the pending Dask work.\n",
    "cluster.adapt(minimum_jobs=10, maximum_jobs=50)\n",
    "client = Client(cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prefect.cache_policies import INPUTS, TASK_SOURCE\n",
    "from prefect.futures import wait\n",
    "\n",
    "from mlip_arena.tasks.utils import get_calculator\n",
    "\n",
    "# Only MLIPEnum members whose `name` appears here are benchmarked.\n",
    "selected_models = [\n",
    "    \"MACE-MP(M)\", \"CHGNet\", \"M3GNet\",\n",
    "    \"MatterSim\", \"eqV2(OMat)\", \"MACE-MPA\",\n",
    "    \"ORBv2\", \"SevenNet\", \"ALIGNN\",\n",
    "]\n",
    "\n",
    "\n",
    "@task(cache_policy=TASK_SOURCE + INPUTS)\n",
    "def run_one(\n",
    "    atoms,\n",
    "    model,\n",
    "):\n",
    "    \"\"\"Run one NVT MD simulation of `atoms` with the calculator for `model`.\n",
    "\n",
    "    Uses Nose-Hoover dynamics with temperature=[300, 3000] (presumably a\n",
    "    300 K -> 3000 K ramp; TODO confirm against `mlip_arena.tasks.md`) and writes\n",
    "    a trajectory to `<family>/<model>_<material_id>_<formula>_nvt.traj`.\n",
    "\n",
    "    Args:\n",
    "        atoms: ASE `Atoms` to simulate; `info['material_id']` is used in the\n",
    "            trajectory name when present (`'random'` otherwise).\n",
    "        model: `MLIPEnum` member; `model.name` selects the calculator and the\n",
    "            `REGISTRY` entry whose `family` is the trajectory directory.\n",
    "\n",
    "    Returns:\n",
    "        The result of the `mlip_arena.tasks.md.run` task.\n",
    "\n",
    "    Raises:\n",
    "        Exception: any error from building the calculator or running MD is\n",
    "            printed and re-raised. Returning it instead would leave this task\n",
    "            run Completed, so Prefect would skip the caller's retries and cache\n",
    "            the exception as the result for these inputs.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Force the inner MD task to re-run; caching is keyed on this outer\n",
    "        # task (source + inputs) instead.\n",
    "        return MD.with_options(refresh_cache=True)(\n",
    "            atoms=atoms,\n",
    "            calculator=get_calculator(\n",
    "                model.name,\n",
    "                calculator_kwargs=None,\n",
    "            ),\n",
    "            ensemble=\"nvt\",\n",
    "            dynamics=\"nose-hoover\",\n",
    "            time_step=None,\n",
    "            dynamics_kwargs=dict(ttime=25 * units.fs),\n",
    "            total_time=1e4,  # fs\n",
    "            temperature=[300, 3000],\n",
    "            pressure=None,\n",
    "            traj_file=f\"{REGISTRY[model.name]['family']}/{model.name}_{atoms.info.get('material_id', 'random')}_{atoms.get_chemical_formula()}_nvt.traj\",\n",
    "            traj_interval=10,\n",
    "        )\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        # Propagate so the task run ends Failed: retries apply and nothing is cached.\n",
    "        raise\n",
    "\n",
    "\n",
    "@flow\n",
    "def heat():\n",
    "    \"\"\"Run `run_one` for every (structure, selected model) pair and gather results.\n",
    "\n",
    "    Structures are read from `random-mixture.db`; models are the `MLIPEnum`\n",
    "    members whose `name` is listed in `selected_models`. Waits for all\n",
    "    submitted tasks, then returns the results of those that completed, in\n",
    "    submission order. Runs that did not complete are omitted.\n",
    "    \"\"\"\n",
    "    # The model filter does not depend on the structure, so compute it once.\n",
    "    models = [m for m in MLIPEnum if m.name in selected_models]\n",
    "    configured = run_one.with_options(\n",
    "        timeout_seconds=600, retries=2, refresh_cache=False\n",
    "    )\n",
    "\n",
    "    # To download the database automatically, `huggingface_hub login` or provide HF_TOKEN\n",
    "    structures = get_atoms_from_db(\"random-mixture.db\", force_download=False)\n",
    "    futures = [\n",
    "        configured.submit(atoms.copy(), model)\n",
    "        for atoms in structures\n",
    "        for model in models\n",
    "    ]\n",
    "\n",
    "    wait(futures)\n",
    "\n",
    "    finished = [f for f in futures if f.state.is_completed()]\n",
    "    return [f.result(timeout=None, raise_on_failure=False) for f in finished]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the flow on the Dask cluster and keep its return value in a named variable\n",
    "# (rather than only in the cell output) so later cells can use it.\n",
    "dask_runner = DaskTaskRunner(address=client.scheduler.address)\n",
    "heat_results = heat.with_options(task_runner=dask_runner, log_prints=True)()\n",
    "# Show only how many runs completed instead of dumping every result.\n",
    "len(heat_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlip-arena",
   "language": "python",
   "name": "mlip-arena"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
