{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.multiprocessing as mp\n",
    "\n",
    "from burgers.approximator import BurgersApproximator\n",
    "from GNRK.experiment import run\n",
    "from GNRK.hyperparameter import get_hp\n",
    "from GNRK.path import DATA_DIR\n",
    "from heat.approximator import HeatApproximator\n",
    "from kuramoto.approximator import KuramotoApproximator\n",
    "from rossler.approximator import RosslerApproximator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the run configuration by passing CLI-style flags to get_hp.\n",
    "# Flags that are commented out inside the list are simply not passed.\n",
    "hp = get_hp(\n",
    "    [\n",
    "        # ----------------- Data -----------------\n",
    "        # Presumably surfaced as hp.equation / hp.dataset, which later cells use to\n",
    "        # pick the approximator class and the pickle file names -- TODO confirm.\n",
    "        \"--equation=burgers\",\n",
    "        \"--dataset=burgers_dataset1\",\n",
    "        # ------------------ NN ------------------\n",
    "        \"--rk=RK4\",\n",
    "        \"--approximator_state_embedding\", \"8\",\n",
    "        # \"--approximator_node_embedding\", \"16\",\n",
    "        \"--approximator_edge_embedding\", \"8\",\n",
    "        \"--approximator_glob_embedding\", \"8\",\n",
    "        \"--approximator_edge_hidden=32\",\n",
    "        \"--approximator_node_hidden=32\",\n",
    "        \"--approximator_activation=gelu\",\n",
    "        \"--approximator_dropout=0.0\",\n",
    "        \"--approximator_bn_momentum=-1.0\",\n",
    "        # --------------- Scheduler --------------\n",
    "        \"--scheduler_name=step\",\n",
    "        \"--scheduler_lr=0.0001\",\n",
    "        \"--scheduler_lr_max=0.004\",\n",
    "        \"--scheduler_lr_max_mult=0.5\",\n",
    "        \"--scheduler_period=20\",\n",
    "        \"--scheduler_period_mult=1.5\",\n",
    "        \"--scheduler_warmup=0\",\n",
    "        # -------------- Early Stop --------------\n",
    "        # \"--earlystop_patience=60\",\n",
    "        # \"--earlystop_delta=0.0\",\n",
    "        # ------------ Train config --------------\n",
    "        \"--weight_decay=0.0\",\n",
    "        # One value per device. Presumably becomes hp.device, whose length the\n",
    "        # training cell uses to choose between a direct run() call and mp.spawn\n",
    "        # -- TODO confirm against get_hp.\n",
    "        \"--device\", \"0\", \"1\", \"2\", \"3\",\n",
    "        # \"--seed=0\",\n",
    "        # Presumably becomes hp.port, which the training cell exports as MASTER_PORT.\n",
    "        \"--port=3184\",\n",
    "        \"--epochs=2\",\n",
    "        \"--batch_size=64\",\n",
    "        # \"--rollout_batch_size=256\",\n",
    "        \"--tqdm\",\n",
    "        # \"--wandb\",\n",
    "        \"--amp\",\n",
    "    ]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled training and validation frames and report how long it took.\n",
    "start = time.perf_counter()\n",
    "\n",
    "# Both files live in DATA_DIR and share the dataset name as a prefix.\n",
    "train_path = DATA_DIR / f\"{hp.dataset}_train.pkl\"\n",
    "val_path = DATA_DIR / f\"{hp.dataset}_val.pkl\"\n",
    "\n",
    "train_df = pd.read_pickle(train_path)\n",
    "val_df = pd.read_pickle(val_path)\n",
    "\n",
    "print(f\"Reading data took {time.perf_counter()-start} seconds\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create governing equation approximator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dispatch table: supported equation name -> approximator class.\n",
    "approximator_classes = {\n",
    "    \"burgers\": BurgersApproximator,\n",
    "    \"heat\": HeatApproximator,\n",
    "    \"kuramoto\": KuramotoApproximator,\n",
    "    \"rossler\": RosslerApproximator,\n",
    "}\n",
    "\n",
    "approximator_cls = approximator_classes.get(hp.equation)\n",
    "if approximator_cls is None:\n",
    "    # Any name outside the table is rejected instead of falling back to a default.\n",
    "    raise NotImplementedError(f\"No such equation {hp.equation}\")\n",
    "\n",
    "# Every approximator class is built the same way, from the approximator sub-config.\n",
    "approximator = approximator_cls.from_hp(hp.approximator)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch training (in-process or one worker per device) and report the elapsed time.\n",
    "start = time.perf_counter()\n",
    "save = False\n",
    "\n",
    "# Arguments shared by both launch paths; the leading process index is supplied separately.\n",
    "run_args = (hp, approximator, train_df, val_df, save)\n",
    "\n",
    "if len(hp.device) != 1:\n",
    "    # Export the address/port before the workers start (presumably consumed by the\n",
    "    # distributed setup inside run -- confirm there).\n",
    "    os.environ.update(MASTER_ADDR=\"localhost\", MASTER_PORT=f\"{hp.port}\")\n",
    "\n",
    "    # mp.spawn calls run(i, *args) in each of the nprocs worker processes.\n",
    "    mp.spawn(  # type:ignore\n",
    "        run,\n",
    "        args=run_args,\n",
    "        nprocs=len(hp.device),\n",
    "        join=True,\n",
    "    )\n",
    "else:\n",
    "    # Single device: call run directly in this process, passing 0 in the slot\n",
    "    # that mp.spawn would otherwise fill with the process index.\n",
    "    run(0, *run_args)\n",
    "\n",
    "print(f\"Training took {time.perf_counter()-start} seconds\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
