{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nnxTmN9WC3FI"
      },
      "source": [
        "* If you're running this on Google Colab, uncomment and run the cell below to install `optax` and `flax`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kWXMRG5BC3FJ"
      },
      "outputs": [],
      "source": [
        "# !pip install optax\n",
        "# !pip install flax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "KoC-hLN4Oliv"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "import os\n",
        "import time\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "import matplotlib.pyplot as plt\n",
        "from tqdm import trange\n",
        "from jax import jvp, value_and_grad, random, jit, vmap, vjp\n",
        "from flax import linen as nn\n",
        "from typing import Sequence\n",
        "from functools import partial\n",
        "from jax.nn.initializers import  normal, ones, constant, uniform #truncated_normal,\n",
        "from functools import partial\n",
        "from typing import Any, Callable, Sequence, Tuple, Optional, Union, Dict\n",
        "\n",
        "from flax import linen as nn\n",
        "from flax.core.frozen_dict import freeze\n",
        "from pprint import pprint\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOqc5OoSN_5M"
      },
      "source": [
        "## 1. PIG model: Gaussian feature embedding + MLP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "id": "hxn8cP_6DQFx"
      },
      "outputs": [],
      "source": [
        "\n",
        "class Gaussian3d(nn.Module):\n",
        "    \"\"\"Learnable mixture-of-Gaussians feature embedding over three inputs.\n",
        "\n",
        "    Holds ``mlp_dim`` independent banks of ``num_gaussian`` axis-aligned 3D\n",
        "    Gaussians. Each bank yields one output feature: the weighted sum of its\n",
        "    Gaussians evaluated at the rescaled input point.\n",
        "\n",
        "    Attributes:\n",
        "        num_gaussian: Gaussians per output feature.\n",
        "        grid_range: Inputs are rescaled into [0, grid_range]; means are\n",
        "            initialised with ``uniform(grid_range)``, i.e. in [0, grid_range).\n",
        "        sigmas_range: Initial value of every per-axis standard deviation.\n",
        "        mlp_dim: Number of output features (independent Gaussian banks).\n",
        "    \"\"\"\n",
        "    num_gaussian: int = 100\n",
        "    grid_range: int = 2\n",
        "    sigmas_range: float = 0.5\n",
        "    mlp_dim: int = 4\n",
        "\n",
        "    def setup(self):\n",
        "        # Gaussian centres per axis, shape (mlp_dim, num_gaussian).\n",
        "        self.mu_x = self.param(\"mean_x\", uniform(self.grid_range), (self.mlp_dim, self.num_gaussian))\n",
        "        self.mu_y = self.param(\"mean_y\", uniform(self.grid_range), (self.mlp_dim, self.num_gaussian))\n",
        "        self.mu_z = self.param(\"mean_z\", uniform(self.grid_range), (self.mlp_dim, self.num_gaussian))\n",
        "        # Per-axis standard deviations, shape (mlp_dim, num_gaussian, 3).\n",
        "        self.sigmas = self.param(\"sigmas\", constant(self.sigmas_range), (self.mlp_dim, self.num_gaussian, 3))\n",
        "        # Mixture weights, shape (mlp_dim, num_gaussian, 1).\n",
        "        self.weight = self.param(\"weight\", normal(), (self.mlp_dim, self.num_gaussian, 1))\n",
        "\n",
        "    # All parameters are declared in setup(), so __call__ needs no @nn.compact.\n",
        "    def __call__(self, x, y, z):\n",
        "        \"\"\"Evaluate the Gaussian features at N points.\n",
        "\n",
        "        Args:\n",
        "            x, y, z: Column arrays, indexed below as shape (N, 1). ``x`` is\n",
        "                scaled by ``grid_range / 10`` (assumes it spans [0, 10], the\n",
        "                time range sampled in this notebook); ``y`` and ``z`` are\n",
        "                mapped from [-1, 1] onto [0, grid_range].\n",
        "\n",
        "        Returns:\n",
        "            Array of shape (N, mlp_dim).\n",
        "        \"\"\"\n",
        "        sigmas_x = self.sigmas[:, :, 0]\n",
        "        sigmas_y = self.sigmas[:, :, 1]\n",
        "        sigmas_z = self.sigmas[:, :, 2]\n",
        "\n",
        "        x = (x / 10.) * self.grid_range\n",
        "        y = ((y + 1.) / 2.) * self.grid_range\n",
        "        z = ((z + 1.) / 2.) * self.grid_range\n",
        "\n",
        "        # Broadcast to (mlp_dim, N, num_gaussian): points on axis 1, Gaussians on axis 2.\n",
        "        exponent = 0.5 * (\n",
        "            ((x[None, :, :] - self.mu_x[:, None, :]) / sigmas_x[:, None, :]) ** 2\n",
        "            + ((y[None, :, :] - self.mu_y[:, None, :]) / sigmas_y[:, None, :]) ** 2\n",
        "            + ((z[None, :, :] - self.mu_z[:, None, :]) / sigmas_z[:, None, :]) ** 2\n",
        "        )\n",
        "        pdf = jnp.exp(-exponent)\n",
        "\n",
        "        # Drop only the trailing singleton axis of `weight`. `.squeeze()` would\n",
        "        # also drop the mlp_dim / num_gaussian axis when either equals 1 and\n",
        "        # break the indexing below.\n",
        "        weights = self.weight[:, :, 0]\n",
        "        output = (pdf * weights[:, None, :]).sum(axis=2)  # (mlp_dim, N)\n",
        "\n",
        "        return output.T\n",
        "\n",
        "\n",
        "class PINN3d(nn.Module):\n",
        "    \"\"\"PINN built from a ``Gaussian3d`` embedding followed by a tanh MLP.\n",
        "\n",
        "    Attributes:\n",
        "        features: Dense-layer widths; every entry but the last is a hidden\n",
        "            layer followed by tanh, the last is the linear output layer.\n",
        "        num_gaussian, grid_range, sigmas_range, mlp_dim: Passed unchanged to\n",
        "            the ``Gaussian3d`` embedding.\n",
        "    \"\"\"\n",
        "    features: Sequence[int]\n",
        "    num_gaussian: int = 100\n",
        "    grid_range: int = 2\n",
        "    sigmas_range : float = 0.5\n",
        "    mlp_dim: int= 4\n",
        "\n",
        "    @nn.compact\n",
        "    def __call__(self, x, y, z):\n",
        "        \"\"\"Map column inputs (x, y, z) to an output of width ``features[-1]``.\"\"\"\n",
        "        embedding = Gaussian3d(self.num_gaussian, self.grid_range, self.sigmas_range, self.mlp_dim)\n",
        "        h = embedding(x, y, z)\n",
        "\n",
        "        kernel_init = nn.initializers.glorot_normal()\n",
        "        for width in self.features[:-1]:\n",
        "            dense = nn.Dense(width, kernel_init=kernel_init)\n",
        "            h = nn.activation.tanh(dense(h))\n",
        "\n",
        "        head = nn.Dense(self.features[-1], kernel_init=kernel_init)\n",
        "        return head(h)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "id": "3lmf86_ON_5N"
      },
      "outputs": [],
      "source": [
        "# hessian-vector product\n",
        "def hvp_fwdrev(f, primals, tangents, return_primals=False):\n",
        "    g = lambda primals: vjp(f, primals)[1](tangents[0])[0]\n",
        "    primals_out, tangents_out = jvp(g, primals, tangents)\n",
        "    if return_primals:\n",
        "        return primals_out, tangents_out\n",
        "    else:\n",
        "        return tangents_out\n",
        "\n",
        "\n",
        "# loss function\n",
        "def pinn_loss_klein_gordon3d(apply_fn, *train_data, residual_weight=0.1):\n",
        "    \"\"\"Build the scalar training loss for the Klein-Gordon PINN.\n",
        "\n",
        "    Residual: u_tt - u_xx - u_yy + u**2 - f, with f the source term ``uc``.\n",
        "\n",
        "    Args:\n",
        "        apply_fn: Forward pass ``apply_fn(params, t, x, y) -> u``.\n",
        "        *train_data: Exactly 12 arrays ``(tc, xc, yc, uc, ti, xi, yi, ui,\n",
        "            tb, xb, yb, ub)``: collocation points and source term, initial\n",
        "            points and targets, boundary points and targets.\n",
        "        residual_weight: Keyword-only weight on the residual MSE (default 0.1,\n",
        "            the value that used to be hard-coded). Initial and boundary MSEs\n",
        "            are weighted 1.\n",
        "\n",
        "    Returns:\n",
        "        ``fn(params) -> scalar loss`` with the training data closed over.\n",
        "    \"\"\"\n",
        "    def residual_loss(params, t, x, y, source_term):\n",
        "        # compute u\n",
        "        u = apply_fn(params, t, x, y)\n",
        "        # All-ones direction for the HVPs. This yields per-sample second\n",
        "        # derivatives only if the model acts row-wise (each output row depends\n",
        "        # only on its own input row) and u has the same shape as t, x, y.\n",
        "        v = jnp.ones(u.shape)\n",
        "        # 2nd derivatives of u\n",
        "        utt = hvp_fwdrev(lambda t: apply_fn(params, t, x, y), (t,), (v,))\n",
        "        uxx = hvp_fwdrev(lambda x: apply_fn(params, t, x, y), (x,), (v,))\n",
        "        uyy = hvp_fwdrev(lambda y: apply_fn(params, t, x, y), (y,), (v,))\n",
        "        return jnp.mean((utt - uxx - uyy + u**2 - source_term)**2)\n",
        "\n",
        "    def initial_boundary_loss(params, t, x, y, u):\n",
        "        # MSE between the prediction and the target values u\n",
        "        return jnp.mean((apply_fn(params, t, x, y) - u)**2)\n",
        "\n",
        "    # unpack data\n",
        "    tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub = train_data\n",
        "\n",
        "    # isolate loss function from redundant arguments\n",
        "    def fn(params):\n",
        "        return (residual_weight * residual_loss(params, tc, xc, yc, uc)\n",
        "                + initial_boundary_loss(params, ti, xi, yi, ui)\n",
        "                + initial_boundary_loss(params, tb, xb, yb, ub))\n",
        "\n",
        "    return fn\n",
        "\n",
        "\n",
        "# optimizer step function\n",
        "@partial(jax.jit, static_argnums=(0,))\n",
        "def update_model(optim, gradient, params, state):\n",
        "    \"\"\"Apply one optax update step.\n",
        "\n",
        "    ``optim`` (argnum 0) is static under jit, so it must be hashable and is\n",
        "    not traced. Returns ``(new_params, new_state)``.\n",
        "    \"\"\"\n",
        "    updates, new_state = optim.update(gradient, state)\n",
        "    new_params = optax.apply_updates(params, updates)\n",
        "    return new_params, new_state"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y3OErz7bN_5O"
      },
      "source": [
        "## 2. Data generator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "id": "VVY7wtfBN_5O"
      },
      "outputs": [],
      "source": [
        "# 2d time-dependent klein-gordon exact u\n",
        "def _klein_gordon3d_exact_u(t, x, y):\n",
        "    return (x + y) * jnp.cos(2*t) + (x * y) * jnp.sin(2*t)\n",
        "\n",
        "\n",
        "# 2d time-dependent klein-gordon source term\n",
        "def _klein_gordon3d_source_term(t, x, y):\n",
        "    \"\"\"Source term f = u**2 - 4*u evaluated at the exact solution u.\n",
        "\n",
        "    For that u, u_tt = -4u and u_xx = u_yy = 0, so u_tt - u_xx - u_yy + u**2 = u**2 - 4u.\n",
        "    \"\"\"\n",
        "    exact = _klein_gordon3d_exact_u(t, x, y)\n",
        "    return exact**2 - 4*exact\n",
        "\n",
        "\n",
        "# train data\n",
        "def pinn_train_generator_klein_gordon3d(nc, ni, nb, key):\n",
        "    \"\"\"Sample training points for the Klein-Gordon problem.\n",
        "\n",
        "    Domain: t in [0, 10], (x, y) in [-1, 1]^2. All outputs are column arrays.\n",
        "\n",
        "    Args:\n",
        "        nc: Number of collocation (interior) points.\n",
        "        ni: Number of initial-condition points at t = 0.\n",
        "        nb: Number of boundary points per edge (x = -1, x = 1, y = -1, y = 1).\n",
        "        key: JAX PRNG key; split into 13 subkeys.\n",
        "\n",
        "    Returns:\n",
        "        ``(tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub)``: ``uc`` is the\n",
        "        source term at the collocation points, ``ui``/``ub`` the exact solution\n",
        "        at the initial/boundary points (boundary arrays have 4*nb rows).\n",
        "    \"\"\"\n",
        "    keys = jax.random.split(key, 13)\n",
        "\n",
        "    def sample(k, n, lo, hi):\n",
        "        # n uniform samples in [lo, hi) as an (n, 1) column\n",
        "        return jax.random.uniform(k, (n, 1), minval=lo, maxval=hi)\n",
        "\n",
        "    def edge(value):\n",
        "        # nb copies of a fixed coordinate as an (nb, 1) column\n",
        "        return jnp.array([[value]] * nb)\n",
        "\n",
        "    # collocation points\n",
        "    tc = sample(keys[0], nc, 0., 10.)\n",
        "    xc = sample(keys[1], nc, -1., 1.)\n",
        "    yc = sample(keys[2], nc, -1., 1.)\n",
        "    uc = _klein_gordon3d_source_term(tc, xc, yc)\n",
        "\n",
        "    # initial points (t = 0)\n",
        "    ti = jnp.zeros((ni, 1))\n",
        "    xi = sample(keys[3], ni, -1., 1.)\n",
        "    yi = sample(keys[4], ni, -1., 1.)\n",
        "    ui = _klein_gordon3d_exact_u(ti, xi, yi)\n",
        "\n",
        "    # boundary points, one group per edge: x = -1, x = 1, y = -1, y = 1\n",
        "    t_parts = [sample(keys[i], nb, 0., 10.) for i in (5, 6, 7, 8)]\n",
        "    x_parts = [edge(-1.), edge(1.), sample(keys[9], nb, -1., 1.), sample(keys[10], nb, -1., 1.)]\n",
        "    y_parts = [sample(keys[11], nb, -1., 1.), sample(keys[12], nb, -1., 1.), edge(-1.), edge(1.)]\n",
        "    u_parts = [_klein_gordon3d_exact_u(tp, xp, yp) for tp, xp, yp in zip(t_parts, x_parts, y_parts)]\n",
        "\n",
        "    tb, xb, yb, ub = (jnp.concatenate(parts) for parts in (t_parts, x_parts, y_parts, u_parts))\n",
        "    return tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub\n",
        "\n",
        "\n",
        "# test data\n",
        "def pinn_test_generator_klein_gordon3d(nc_test):\n",
        "    \"\"\"Dense evaluation grid with the exact solution.\n",
        "\n",
        "    Builds an nc_test**3 grid over t in [0, 10], x in [-1, 1], y in [-1, 1]\n",
        "    (``indexing='ij'``) and flattens everything to column vectors.\n",
        "\n",
        "    Returns:\n",
        "        ``(t, x, y, u_gt)``, each of shape (nc_test**3, 1).\n",
        "    \"\"\"\n",
        "    bounds = ((0, 10), (-1, 1), (-1, 1))\n",
        "    axes = [jax.lax.stop_gradient(jnp.linspace(lo, hi, nc_test)) for lo, hi in bounds]\n",
        "    tm, xm, ym = jnp.meshgrid(*axes, indexing='ij')\n",
        "    u_grid = _klein_gordon3d_exact_u(tm, xm, ym)\n",
        "    t, x, y, u_gt = (arr.reshape(-1, 1) for arr in (tm, xm, ym, u_grid))\n",
        "    return t, x, y, u_gt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wEWeH3ZFN_5P"
      },
      "source": [
        "## 3. Utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 49,
      "metadata": {
        "id": "cLX1oaDUN_5P"
      },
      "outputs": [],
      "source": [
        "def relative_l2(u, u_gt):\n",
        "    return jnp.linalg.norm(u-u_gt) / jnp.linalg.norm(u_gt)\n",
        "\n",
        "def plot_klein_gordon3d(t, x, y, u, error_list):\n",
        "    \"\"\"Show the predicted field as a 3D scatter and the error history.\n",
        "\n",
        "    Args:\n",
        "        t, x, y: Coordinates of the evaluation points.\n",
        "        u: Predicted values at those points; used as scatter colours.\n",
        "        error_list: Sequence of relative L2 errors, plotted on a log scale.\n",
        "    \"\"\"\n",
        "    fig = plt.figure(figsize=(6, 6))\n",
        "    ax = fig.add_subplot(111, projection='3d')\n",
        "    points = ax.scatter(t, x, y, c=u, s=0.5, cmap='seismic')\n",
        "    # Without a colour bar the 'seismic' colours carry no scale.\n",
        "    fig.colorbar(points, ax=ax, shrink=0.6, pad=0.1, label='u')\n",
        "    ax.set_title('U(t, x, y)', fontsize=20, pad=-5)\n",
        "    ax.set_xlabel('t', fontsize=18, labelpad=10)\n",
        "    ax.set_ylabel('x', fontsize=18, labelpad=10)\n",
        "    ax.set_zlabel('y', fontsize=18, labelpad=10)\n",
        "    plt.show()\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(6, 6))\n",
        "    ax.plot(error_list, c='r')\n",
        "    ax.set_yscale('log')\n",
        "    ax.set_title('Relative L2 error during training')\n",
        "    ax.set_xlabel('iteration')\n",
        "    ax.set_ylabel('relative L2 error')\n",
        "    plt.show()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Q3WgLq_N_5P"
      },
      "source": [
        "## 4. Main function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "id": "VHtJazHuN_5Q"
      },
      "outputs": [],
      "source": [
        "def main(N_LAYERS, FEATURES, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, num_gaussian, grid_range, sigmas_range, mlp_dim, LOG_ITER):\n",
        "    \"\"\"Train PINN3d on the Klein-Gordon problem, report errors, and plot.\n",
        "\n",
        "    Args:\n",
        "        N_LAYERS: Dense layers in the MLP head: N_LAYERS - 1 hidden layers of\n",
        "            width FEATURES plus a width-1 output layer.\n",
        "        FEATURES: Hidden-layer width.\n",
        "        NC, NI, NB: Collocation, initial, and per-edge boundary sample counts.\n",
        "        NC_TEST: Points per axis of the test grid (NC_TEST**3 points in total).\n",
        "        SEED: Seed for the PRNG used for parameter init and data sampling.\n",
        "        LR: Adam learning rate.\n",
        "        EPOCHS: Number of full-batch optimisation steps.\n",
        "        num_gaussian, grid_range, sigmas_range, mlp_dim: Passed to PINN3d.\n",
        "        LOG_ITER: Print loss and error every LOG_ITER epochs (and at epoch 1).\n",
        "    \"\"\"\n",
        "    # NOTE: these only take effect if set before JAX initialises its GPU\n",
        "    # backend; once JAX has run a computation in this kernel they are ignored.\n",
        "    os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
        "    os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"]=\"false\"\n",
        "\n",
        "    # random key\n",
        "    key = jax.random.PRNGKey(SEED)\n",
        "    key, subkey = jax.random.split(key, 2)\n",
        "\n",
        "    # feature sizes\n",
        "    feat_sizes = tuple([FEATURES for _ in range(N_LAYERS - 1)] + [1])\n",
        "\n",
        "    # make & init model\n",
        "    model = PINN3d(feat_sizes, num_gaussian, grid_range, sigmas_range, mlp_dim)\n",
        "    params = model.init(subkey, jnp.ones((NC, 1)), jnp.ones((NC, 1)), jnp.ones((NC, 1)))\n",
        "    # optimizer\n",
        "    optim = optax.adam(LR)\n",
        "    state = optim.init(params)\n",
        "\n",
        "    # dataset\n",
        "    key, subkey = jax.random.split(key, 2)\n",
        "    train_data = pinn_train_generator_klein_gordon3d(NC, NI, NB, subkey)\n",
        "    t, x, y, u_gt = pinn_test_generator_klein_gordon3d(NC_TEST)\n",
        "\n",
        "    # forward & loss function\n",
        "    apply_fn = jax.jit(model.apply)\n",
        "    loss_fn = pinn_loss_klein_gordon3d(apply_fn, *train_data)\n",
        "\n",
        "    @jax.jit\n",
        "    def train_one_step(params, state):\n",
        "        # compute loss and gradient\n",
        "        loss, gradient = value_and_grad(loss_fn)(params)\n",
        "        # update state\n",
        "        params, state = update_model(optim, gradient, params, state)\n",
        "        return loss, params, state\n",
        "\n",
        "    error_list = []\n",
        "    start = time.time()\n",
        "    for e in trange(1, EPOCHS+1):\n",
        "        # single run\n",
        "        loss, params, state = train_one_step(params, state)\n",
        "        # test error on the full grid, recorded every epoch for the error plot\n",
        "        error = relative_l2(apply_fn(params, t, x, y), u_gt)\n",
        "        if e % LOG_ITER == 0 or e==1:\n",
        "            print(f'Epoch: {e}/{EPOCHS} --> loss: {loss:.8f}, error: {error:.8f}')\n",
        "        error_list.append(error)\n",
        "    # JAX dispatches work asynchronously; wait for the last step to finish\n",
        "    # before stopping the clock so the runtime is not under-reported.\n",
        "    if error_list:\n",
        "        error_list[-1].block_until_ready()\n",
        "    end = time.time()\n",
        "    # max(..., 1) avoids a ZeroDivisionError when EPOCHS == 0\n",
        "    print(f'Runtime: {((end-start)/max(EPOCHS, 1)*1000):.2f} ms/iter.')\n",
        "\n",
        "    print('Solution:')\n",
        "    u = apply_fn(params, t, x, y)\n",
        "    plot_klein_gordon3d(t, x, y, u, error_list)\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eCdzoogAN_5Q"
      },
      "source": [
        "## 5. Run!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "j-DGXwqYN_5Q",
        "outputId": "114eeb46-b72c-4b33-990d-0c14e30d5c11"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "  0%|          | 12/50000 [00:06<5:47:03,  2.40it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1/50000 --> loss: 2.22182012, error: 1.01501584\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "  2%|▏         | 1013/50000 [00:19<09:59, 81.72it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1000/50000 --> loss: 0.00078357, error: 0.04038550\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "  4%|▍         | 2013/50000 [00:32<10:05, 79.26it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 2000/50000 --> loss: 0.00042715, error: 0.02812523\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "  6%|▌         | 3015/50000 [00:45<10:21, 75.59it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 3000/50000 --> loss: 0.00031845, error: 0.01923372\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "  8%|▊         | 4015/50000 [00:59<10:14, 74.82it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 4000/50000 --> loss: 0.00020300, error: 0.01641631\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 10%|█         | 5015/50000 [01:12<09:52, 75.93it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 5000/50000 --> loss: 0.00040753, error: 0.01412683\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 12%|█▏        | 6013/50000 [01:25<09:06, 80.42it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 6000/50000 --> loss: 0.00009787, error: 0.01000255\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 14%|█▍        | 7013/50000 [01:38<08:56, 80.08it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 7000/50000 --> loss: 0.00008948, error: 0.00983457\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 16%|█▌        | 8013/50000 [01:51<08:40, 80.64it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 8000/50000 --> loss: 0.00027596, error: 0.01631125\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 18%|█▊        | 9013/50000 [02:04<08:33, 79.85it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 9000/50000 --> loss: 0.00004546, error: 0.00785934\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 19%|█▊        | 9311/50000 [02:08<08:52, 76.38it/s]"
          ]
        }
      ],
      "source": [
        "main(\n",
        "    # MLP head\n",
        "    N_LAYERS=2, FEATURES=16,\n",
        "    # sample counts (NC_TEST = points per axis of the test grid)\n",
        "    NC=16**3, NI=16**2, NB=16**2, NC_TEST=100,\n",
        "    # optimisation\n",
        "    SEED=444, LR=0.01, EPOCHS=50000, LOG_ITER=1000,\n",
        "    # Gaussian embedding\n",
        "    num_gaussian=100, grid_range=2, sigmas_range=0.5, mlp_dim=4,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VLCGWyVNC3FP"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    },
    "vscode": {
      "interpreter": {
        "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}