{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import equinox as eqx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import exponax as ex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TaylorGreenVorticity(eqx.Module):\n",
    "    \"\"\"Analytic vorticity of a (shifted) two-dimensional Taylor-Green vortex.\n",
    "\n",
    "    Evaluates ``w(t, x) = 2 * sin(x_0) * cos(x_1) * exp(-2 * nu * t)``. The\n",
    "    spatial profile is a single Fourier mode with Laplacian eigenvalue -2, so\n",
    "    viscous diffusion damps it by ``exp(-2 * nu * t)`` while the nonlinear\n",
    "    advection term vanishes. It is therefore an exact solution of the 2D\n",
    "    Navier-Stokes vorticity equation on a 2*pi-periodic square domain, which\n",
    "    makes it a convenient reference for checking a solver.\n",
    "    \"\"\"\n",
    "\n",
    "    # Diffusivity (viscosity) nu; sets the decay factor exp(-2 * nu * t).\n",
    "    nu: float\n",
    "\n",
    "    def __init__(self, domain_extent, diffusivity):\n",
    "        \"\"\"Store the diffusivity after validating the domain size.\n",
    "\n",
    "        Args:\n",
    "            domain_extent: Side length of the square domain. Must equal\n",
    "                2 * pi (up to floating-point tolerance), the period of the\n",
    "                sin/cos profile.\n",
    "            diffusivity: Diffusivity nu of the flow, stored as ``self.nu``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``domain_extent`` is not 2 * pi.\n",
    "        \"\"\"\n",
    "        # Compare with a tolerance rather than exact equality so that, e.g.,\n",
    "        # a float32 representation of 2 * pi is not rejected due to rounding.\n",
    "        if not jnp.isclose(domain_extent, 2 * jnp.pi):\n",
    "            raise ValueError(\"Domain extent must be 2 * pi\")\n",
    "        self.nu = diffusivity\n",
    "\n",
    "    def __call__(self, t, x):\n",
    "        \"\"\"Evaluate the vorticity at time ``t`` on the coordinates ``x``.\n",
    "\n",
    "        Args:\n",
    "            t: Time at which the solution is evaluated.\n",
    "            x: Coordinate array whose leading axis indexes the spatial\n",
    "                dimension: ``x[0]`` holds the first and ``x[1]`` the second\n",
    "                coordinate.\n",
    "\n",
    "        Returns:\n",
    "            Vorticity array. For scalar ``t`` its shape is\n",
    "            ``(1, *x.shape[1:])``, i.e. it has a single leading channel axis.\n",
    "        \"\"\"\n",
    "        f_term = jnp.exp(-2 * self.nu * t)\n",
    "        # Slicing with 0:1 / 1:2 (instead of plain indexing) keeps the leading\n",
    "        # axis, so the result carries a singleton channel dimension.\n",
    "        vorticity = 2 * jnp.sin(x[0:1]) * jnp.cos(x[1:2]) * f_term\n",
    "\n",
    "        return vorticity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-20 10:14:47.179970: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
     ]
    }
   ],
   "source": [
    "# Coordinate grid matching the stepper below:\n",
    "# 2 spatial dimensions, extent 2 * pi, 60 points per dimension.\n",
    "grid = ex.make_grid(\n",
    "    2,\n",
    "    2 * jnp.pi,\n",
    "    60,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analytic reference; its diffusivity must match the stepper's.\n",
    "tg = TaylorGreenVorticity(\n",
    "    domain_extent=2 * jnp.pi,\n",
    "    diffusivity=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial condition: the analytic vorticity evaluated at t = 0\n",
    "ic = tg(t=0.0, x=grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Navier-Stokes stepper (vorticity formulation) on the same grid as the\n",
    "# analytic Taylor-Green reference.\n",
    "ns_stepper = ex.stepper.NavierStokesVorticity(\n",
    "    2,  # number of spatial dimensions\n",
    "    2 * jnp.pi,  # domain extent\n",
    "    60,  # grid points per dimension\n",
    "    0.1,  # time step dt\n",
    "    diffusivity=0.1,  # must equal the nu used by the reference solution\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rel_error(pred, ref):\n",
    "    \"\"\"Relative L2 error ``||pred - ref|| / ||ref||``.\n",
    "\n",
    "    Both norms are ``jnp.linalg.norm`` with default arguments, i.e. the\n",
    "    2-norm of the flattened array for inputs of any rank. No guard is applied\n",
    "    when ``ref`` has zero norm.\n",
    "    \"\"\"\n",
    "    residual = pred - ref\n",
    "    return jnp.linalg.norm(residual) / jnp.linalg.norm(ref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(2.3546949e-07, dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rel_error(\n",
    "    ns_stepper(ic),  # a single step of dt = 0.1\n",
    "    tg(0.1, grid),  # analytic vorticity at t = 1 * 0.1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(1.3551877e-06, dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rel_error(\n",
    "    ex.repeat(ns_stepper, 10)(ic),  # 10 repeated steps of dt = 0.1\n",
    "    tg(1.0, grid),  # analytic vorticity at t = 10 * 0.1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(1.2082977e-05, dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rel_error(\n",
    "    ex.repeat(ns_stepper, 100)(ic),  # 100 repeated steps of dt = 0.1\n",
    "    tg(10.0, grid),  # analytic vorticity at t = 100 * 0.1\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_fresh",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
