{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf94d59d",
   "metadata": {
    "id": "bf94d59d"
   },
   "source": [
    "Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "09088a1a",
   "metadata": {
    "id": "09088a1a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[CudaDevice(id=0)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# choose GPU: CUDA_VISIBLE_DEVICES must be set before JAX (or any library that\n",
    "# may touch the devices) initializes the CUDA backend, so set it before importing jax\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import math\n",
    "import time\n",
    "\n",
    "import jax\n",
    "import flax\n",
    "import optax\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from jax import lax, random, numpy as jnp\n",
    "from jax import grad, vmap, hessian, jacfwd, jit\n",
    "from jax import config\n",
    "from jax.scipy.linalg import solve\n",
    "from flax import linen as nn\n",
    "from scipy import io\n",
    "from evojax.util import get_params_format_fn\n",
    "\n",
    "# enable float64 before any JAX arrays are created\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "n_gpu = len(jax.devices())\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31fe584f",
   "metadata": {
    "id": "31fe584f"
   },
   "source": [
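    "## 1D viscous Burgers' equation\n",
    "\n",
    "$$u_t + u\\,u_x = \\nu\\,u_{xx}$$\n",
    "\n",
    "Here $\\nu$ is the viscosity (`v` in the code). Each task corresponds to one value of $\\nu$ from `nu_list`, and the network learns $u(x, t, \\nu)$ from data."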
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "744088e9",
   "metadata": {
    "id": "744088e9"
   },
   "outputs": [],
   "source": [
    "# read data: grid inputs (x, t) and the per-task solution labels\n",
    "x_train = jnp.load('vbgs_input.npy')\n",
    "y_train_all = jnp.load('vbgs_label.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7dd30db-7f00-4a44-b9e8-5906311ddac9",
   "metadata": {
    "id": "e7dd30db-7f00-4a44-b9e8-5906311ddac9",
    "outputId": "e73d21b2-d4bf-490e-d570-156479d6b5a6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0 1.0 -1.0 1.0\n",
      "(6579, 2) (50, 6579, 1) (229, 2) (50, 229, 1)\n"
     ]
    }
   ],
   "source": [
    "dert_x = 1  # spatial subsampling stride\n",
    "dert_t = 1  # temporal subsampling stride\n",
    "seed_data = 10  ## fixed\n",
    "seed_model = 10  ## tuning   10, 20, 30, 40, 50\n",
    "\n",
    "n_all_task = y_train_all.shape[0]\n",
    "n_train = 16  ###  5,10,20,40,60,80,100  default 100\n",
    "n_test = n_all_task - n_train\n",
    "r_task = 8  ###  8 for single; 32 for window\n",
    "n_nodes = 128\n",
    "# NOTE: the optimizer cell re-assigns max_iters / max_lr; keep these values in sync\n",
    "max_iters = 5000\n",
    "max_lr = 5e-2\n",
    "\n",
    "x_step = 129  # number of x grid points in the raw data\n",
    "t_step = 51   # number of t grid points in the raw data\n",
    "nu_list = jnp.linspace(0.001, 0.05, 50)  # one viscosity value per task\n",
    "n_task = len(nu_list)\n",
    "assert n_task == n_all_task, f'nu_list has {n_task} values but the labels hold {n_all_task} tasks'\n",
    "assert 0 < n_train < n_all_task, f'n_train={n_train} must leave at least one of the {n_all_task} tasks for testing'\n",
    "\n",
    "# subsample the (t, x) grid with strides dert_t / dert_x\n",
    "x_train = x_train.T.reshape(2, t_step, -1)[:, ::dert_t, ::dert_x].reshape(2, -1).T\n",
    "y_train_all = y_train_all.reshape(n_task , t_step, -1)[:, ::dert_t, ::dert_x].reshape(n_task , -1, 1)\n",
    "\n",
    "# column 0 of x_train is x, column 1 is t\n",
    "t_l, t_u, x_l, x_u = np.min(x_train[:,1]), np.max(x_train[:,1]), np.min(x_train[:,0]), np.max(x_train[:,0])\n",
    "ext = [t_l, t_u, x_l, x_u]\n",
    "print (t_l, t_u, x_l, x_u)\n",
    "\n",
    "x_dim = math.floor((x_step - 1) / dert_x) + 1  # x grid points after subsampling\n",
    "t_dim = math.floor((t_step - 1) / dert_t) + 1  # t grid points after subsampling\n",
    "\n",
    "# split into BC / IC data: initial condition at t = t_l, boundaries at x = x_l and x = x_u\n",
    "bcl = (x_train[:,0] == x_l)\n",
    "bcu = (x_train[:,0] == x_u)\n",
    "ic = (x_train[:,1] == t_l)\n",
    "bic = ic | bcl | bcu\n",
    "data_X_BIC, data_Y_BIC = x_train[bic], y_train_all[:,bic]\n",
    "print (x_train.shape, y_train_all.shape, data_X_BIC.shape, data_Y_BIC.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "200ecf39-64a7-40b6-8167-820ca4473ad8",
   "metadata": {
    "id": "200ecf39-64a7-40b6-8167-820ca4473ad8",
    "outputId": "72f158e9-f5de-46bb-df4d-8997d80eaae1"
   },
   "outputs": [],
   "source": [
    "# convert to jnp\n",
    "x_train = jnp.array(x_train)\n",
    "y_train_all = jnp.array(y_train_all)\n",
    "data_X_BIC = jnp.array(data_X_BIC)\n",
    "data_Y_BIC = jnp.array(data_Y_BIC)\n",
    "nu_list = jnp.array(nu_list)\n",
    "\n",
    "# random train/test split of the task indices, fixed by seed_data\n",
    "rng = random.PRNGKey(seed_data)\n",
    "key, rng = random.split(rng)\n",
    "key, rng = random.split(rng)  # update random generator\n",
    "n_total = y_train_all.shape[0]\n",
    "\n",
    "all_indices = random.permutation(key, jnp.arange(n_total))\n",
    "test_task_list = all_indices[:n_test]                     # first n_test shuffled tasks\n",
    "train_task_list = all_indices[n_test:(n_test + n_train)]  # next n_train shuffled tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8WUvS8L9EeZ5",
   "metadata": {
    "id": "8WUvS8L9EeZ5"
   },
   "outputs": [],
   "source": [
    "# Variant of DNN that also returns the last hidden features and their x/t derivatives\n",
    "# (intended for physics-informed (PI) prediction)\n",
    "class PINN(nn.Module):\n",
    "    \"\"\"PINN feature model.\n",
    "\n",
    "    Same layer stack as DNN (4 x [Dense(n_nodes), tanh] plus a bias-free Dense(1)\n",
    "    read-out), but __call__ also returns the last hidden features f and their\n",
    "    derivatives.\n",
    "\n",
    "    Input:  (N, 3) array with columns (x, t, v).\n",
    "    Output: hstack([u, f, f_xx, f_x, f_t]) -> (N, 1 + 4*n_nodes), where f_xx, f_x\n",
    "            and f_t are d2f/dx2, df/dx and df/dt evaluated per sample.\n",
    "\n",
    "    NOTE(review): not instantiated in this notebook (model = DNN()), so this path\n",
    "    is unexercised here; confirm that the nested jacfwd/vmap over the bound Flax\n",
    "    submodules runs before relying on it.\n",
    "    \"\"\"\n",
    "    def setup(self):\n",
    "        self.layers = [nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh]\n",
    "        self.last_layer = nn.Dense(1, kernel_init = jax.nn.initializers.he_uniform(), use_bias=False)\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, inputs):\n",
    "        # split the input columns into x, t, v (each (N, 1))\n",
    "        x, t, v = inputs[:,0:1], inputs[:,1:2], inputs[:,2:3]\n",
    "        # forward pass on raw arrays; returns (u, last hidden features f)\n",
    "        def get_u(x, t, v):\n",
    "            # v is scaled by 1/5 (presumably input normalization - TODO confirm)\n",
    "            f = jnp.hstack([x, t, v /5.])\n",
    "            fs = []\n",
    "            for i, lyr in enumerate(self.layers):\n",
    "                f = lyr(f)\n",
    "                if (i == 0):\n",
    "                    # scale the first Dense pre-activation by 2*pi (before the first tanh)\n",
    "                    f = 2 *jnp.pi *f\n",
    "            #     if (i%2 != 0):\n",
    "            #         fs.append(f)\n",
    "            # f = jnp.hstack(fs)\n",
    "            u = self.last_layer(f)\n",
    "            return u, f\n",
    "\n",
    "        u, f = get_u(x, t, v)\n",
    "\n",
    "        # forward-mode derivatives of f for one sample (the u-derivatives are discarded):\n",
    "        # d2f/dx2, df/dx (argnums=0 by default) and df/dt (argnums=1)\n",
    "        def get_f_dir(get_u, x, t, v):\n",
    "            _, f_xx = jacfwd(jacfwd(get_u))(x, t, v)\n",
    "            _, f_x = jacfwd(get_u)(x, t, v)\n",
    "            _, f_t = jacfwd(get_u, argnums=1)(x, t, v)\n",
    "            return f_xx, f_x, f_t\n",
    "\n",
    "        # map over the rows of x, t, v; get_u itself is not mapped (in_axes None)\n",
    "        f_dir_vmap = vmap(get_f_dir, in_axes=(None, 0, 0, 0))\n",
    "        f_xx, f_x, f_t = f_dir_vmap(get_u, x, t, v)\n",
    "        # drop the trailing singleton dims from the per-sample (1,)-shaped inputs\n",
    "        f_xx, f_x, f_t = f_xx[:,:,0,0], f_x[:,:,0], f_t[:,:,0]\n",
    "\n",
    "        # columns: [u | f | f_xx | f_x | f_t]\n",
    "        outputs = jnp.hstack([u, f, f_xx, f_x, f_t])\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1389a078-0613-44df-a52a-e5136a61f390",
   "metadata": {
    "id": "1389a078-0613-44df-a52a-e5136a61f390"
   },
   "outputs": [],
   "source": [
    "# for training and direct prediction\n",
    "class DNN(nn.Module):\n",
    "    \"\"\"Fully connected surrogate u(x, t, v) for the parametric Burgers' equation.\n",
    "\n",
    "    Layers: 4 x [Dense(n_nodes) with He-uniform init, tanh] followed by a\n",
    "    bias-free Dense(1) read-out. The viscosity input v is divided by 5 and the\n",
    "    first Dense pre-activation is multiplied by 2*pi before its tanh.\n",
    "\n",
    "    Input:  (N, 3) array with columns (x, t, v).\n",
    "    Output: (N, 1) array of predicted u.\n",
    "    \"\"\"\n",
    "    def setup(self):\n",
    "        # submodule list/names (layers_0, layers_2, ...) define the parameter layout;\n",
    "        # keep them unchanged so saved flat parameter vectors stay compatible\n",
    "        self.layers = [nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.tanh]\n",
    "        self.last_layer = nn.Dense(1, kernel_init = jax.nn.initializers.he_uniform(), use_bias=False)\n",
    "\n",
    "    # all submodules are created in setup, so @nn.compact is not needed here\n",
    "    def __call__(self, inputs):\n",
    "        x, t, v = inputs[:, 0:1], inputs[:, 1:2], inputs[:, 2:3]\n",
    "        h = jnp.hstack([x, t, v / 5.])\n",
    "        for i, lyr in enumerate(self.layers):\n",
    "            h = lyr(h)\n",
    "            if i == 0:\n",
    "                # scale the first Dense pre-activation by 2*pi (before the first tanh)\n",
    "                h = 2 * jnp.pi * h\n",
    "        return self.last_layer(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "71970aed-3347-4317-895e-9a606971888b",
   "metadata": {
    "id": "71970aed-3347-4317-895e-9a606971888b",
    "outputId": "c63e689c-f0b3-4f3a-bb29-5d673b0f245d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int64(50176)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# jax.flatten_util is a submodule that `import jax` alone does not guarantee to load\n",
    "from jax.flatten_util import ravel_pytree\n",
    "\n",
    "# choose seed\n",
    "key, rng = random.split(random.PRNGKey(seed_model))\n",
    "\n",
    "# dummy input with the model's input shape (1, 3) = (x, t, v)\n",
    "a = random.normal(key, [1,3])\n",
    "\n",
    "# initialization call\n",
    "model = DNN()\n",
    "params = model.init(key, a)\n",
    "# format_params_fn maps a flat parameter vector back to the pytree model.apply expects\n",
    "num_params, format_params_fn = get_params_format_fn(params)\n",
    "\n",
    "# flatten initial params into a single 1-D vector\n",
    "params = ravel_pytree(params)[0]\n",
    "num_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "cbb3592f-f9e7-4d5e-afa0-d3bebc81d460",
   "metadata": {
    "id": "cbb3592f-f9e7-4d5e-afa0-d3bebc81d460",
    "outputId": "ba12603f-816e-4d9d-c59e-39ad4e37afe2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50178,)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Append 2 extra trainable scalars ('lambs', intended as loss weights) to the flat\n",
    "# parameter vector. The current data-only loss uses params[:-2], so they get zero gradient.\n",
    "key, rng = random.split(rng)  # advance the random generator\n",
    "ini_a = random.normal(key, [1,2])\n",
    "params = jnp.concatenate([params, ini_a.ravel()])\n",
    "params.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bee134fb-897e-4fb7-984c-27340e2ffc06",
   "metadata": {
    "id": "bee134fb-897e-4fb7-984c-27340e2ffc06"
   },
   "outputs": [],
   "source": [
    "# loss function\n",
    "def eval_loss(params, task):\n",
    "    \"\"\"Data-fit loss of the surrogate on a single task (one viscosity value).\n",
    "\n",
    "    params: flat parameter vector; the last 2 entries are not used by the network.\n",
    "    task:   index into y_train_all / nu_list.\n",
    "    Returns (loss, (ssr, mse, rl2)). Both loss and the 'ssr' slot hold the MSE here.\n",
    "    \"\"\"\n",
    "    labels = y_train_all[task]\n",
    "    # the task's viscosity becomes a constant third input column\n",
    "    v_col = jnp.full((x_train.shape[0], 1), nu_list[task])\n",
    "    net_in = jnp.hstack([x_train, v_col])\n",
    "    pred = model.apply(format_params_fn(params[:-2]), net_in)[:, 0:1]\n",
    "    residual = labels - pred\n",
    "    mse = jnp.mean(jnp.square(residual))\n",
    "    rl2 = jnp.linalg.norm(residual) / jnp.linalg.norm(labels)\n",
    "    return mse, (mse, mse, rl2)\n",
    "\n",
    "loss_grad = jax.jit(jax.value_and_grad(eval_loss, has_aux=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "296cd196-4065-4e56-8863-fd3d1e9fc0f2",
   "metadata": {
    "id": "296cd196-4065-4e56-8863-fd3d1e9fc0f2"
   },
   "outputs": [],
   "source": [
    "# weights update\n",
    "@jit\n",
    "def update(params, opt_state, key):\n",
    "    \"\"\"One optimizer step on the mean loss over a random mini-batch of training tasks.\n",
    "\n",
    "    Samples min(r_task, len(train_task_list)) distinct tasks and evaluates loss_grad\n",
    "    on all of them with a single vmap, instead of unrolling a Python loop inside jit.\n",
    "    It then averages the losses, metrics and gradients and applies the optimizer update.\n",
    "\n",
    "    Returns (params, opt_state, loss, ssr, mse, rl2) with batch-averaged metrics.\n",
    "    \"\"\"\n",
    "    # sampling without replacement needs batch <= population (e.g. n_train=5, r_task=8)\n",
    "    n_batch = min(r_task, train_task_list.shape[0])\n",
    "    batch_train_task = random.choice(key, train_task_list, (n_batch,), replace=False)\n",
    "    # evaluate all tasks at once; params are shared across the batch (in_axes=None)\n",
    "    (loss, (ssr, mse, rl2)), grad_b = vmap(loss_grad, in_axes=(None, 0))(params, batch_train_task)\n",
    "    loss_all, ssr_all, mse_all, rl2_all = loss.mean(), ssr.mean(), mse.mean(), rl2.mean()\n",
    "    grad_all = grad_b.mean(axis=0)\n",
    "    updates, opt_state = optimizer.update(grad_all, opt_state)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    return params, opt_state, loss_all, ssr_all, mse_all, rl2_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "fe019e01-0083-46fc-9e41-c82d2dc3cda7",
   "metadata": {
    "id": "fe019e01-0083-46fc-9e41-c82d2dc3cda7"
   },
   "outputs": [],
   "source": [
    "# optimizer\n",
    "# NOTE: these re-assignments override any max_iters / max_lr set in earlier cells\n",
    "max_iters = 5000\n",
    "max_lr = 5e-2 # 1e-2 at first\n",
    "n_warmup = int(max_iters*.4)  # first 40% of the schedule\n",
    "# init_value == peak_value, so the 'warmup' phase holds the lr constant at max_lr;\n",
    "# afterwards it cosine-decays towards end_value over the decay_steps-long schedule\n",
    "lr_scheduler = optax.warmup_cosine_decay_schedule(\n",
    "    init_value=max_lr,\n",
    "    peak_value=max_lr,\n",
    "    warmup_steps=n_warmup,\n",
    "    decay_steps=max_iters,\n",
    "    end_value=1e-6,\n",
    ")\n",
    "optimizer = optax.adam(learning_rate=lr_scheduler)  # Adam driven by the schedule above\n",
    "opt_state = optimizer.init(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "CNdJda_qF7UO",
   "metadata": {
    "id": "CNdJda_qF7UO"
   },
   "source": [
    "Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e5b49b20-d73f-4925-be58-b930bf26d0c8",
   "metadata": {
    "id": "e5b49b20-d73f-4925-be58-b930bf26d0c8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter. = 00000,  time = 002s,  loss = 4.27e-01  |  ssr = 4.27e-01,  mse = 4.27e-01,  rl2 = 1.15e+00\n",
      "iter. = 01000,  time = 011s,  loss = 6.48e-03  |  ssr = 6.48e-03,  mse = 6.48e-03,  rl2 = 1.38e-01\n",
      "iter. = 02000,  time = 019s,  loss = 2.93e-03  |  ssr = 2.93e-03,  mse = 2.93e-03,  rl2 = 9.39e-02\n",
      "iter. = 03000,  time = 028s,  loss = 2.97e-04  |  ssr = 2.97e-04,  mse = 2.97e-04,  rl2 = 3.11e-02\n",
      "iter. = 04000,  time = 036s,  loss = 1.28e-04  |  ssr = 1.28e-04,  mse = 1.28e-04,  rl2 = 1.92e-02\n",
      "iter. = 05000,  time = 044s,  loss = 1.35e-04  |  ssr = 1.35e-04,  mse = 1.35e-04,  rl2 = 2.01e-02\n"
     ]
    }
   ],
   "source": [
    "# training iteration\n",
    "runtime = 0\n",
    "train_iters = 0\n",
    "\n",
    "store = []  # rows of [iter, runtime, loss, ssr, mse, rl2], logged every 1000 iterations\n",
    "\n",
    "# iterations 0..max_iters inclusive, or until ~1200 s of measured step time\n",
    "while (train_iters <= max_iters) and (runtime < 1200):\n",
    "    # mini-batch update\n",
    "    start = time.time()\n",
    "    key, rng = random.split(rng) # update random generator\n",
    "    params, opt_state, loss, ssr, mse, rl2 = update(params, opt_state, key)\n",
    "    # JAX dispatches asynchronously: wait for the step to finish so the timer\n",
    "    # (and hence the 1200 s budget) measures compute, not just dispatch\n",
    "    jax.block_until_ready((params, loss))\n",
    "    end = time.time()\n",
    "    runtime += (end-start)\n",
    "    # log metrics\n",
    "    if (train_iters % 1000 == 0):\n",
    "        print ('iter. = %05d,  time = %03ds,  loss = %.2e  |  ssr = %.2e,  mse = %.2e,  rl2 = %.2e'%(train_iters, runtime, loss, ssr, mse, rl2))\n",
    "        store.append([train_iters, runtime, loss, ssr, mse, rl2])\n",
    "    train_iters += 1\n",
    "\n",
    "store = jnp.array(store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d802e0c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "##   save weight\n",
    "import pickle\n",
    "import jax\n",
    "\n",
    "# open(..., 'wb') raises FileNotFoundError when the target directory is missing\n",
    "os.makedirs('DNN model', exist_ok=True)\n",
    "\n",
    "params_np = jax.device_get(params)  # flat parameter vector fetched to the host\n",
    "# store_np = jax.device_get(store)\n",
    "with open(f\"DNN model/Fixed_IC_v_50_mse_n_train_{n_train}_seed_{seed_model}.pkl\", \"wb\") as f:\n",
    "    pickle.dump(params_np, f)\n",
    "# with open(\"try-Burgurs_newton_20_train_store.pkl\", \"wb\") as f:\n",
    "#     pickle.dump(store_np, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "863044cb-ed9d-4f51-8da7-7c41717fab6e",
   "metadata": {
    "id": "863044cb-ed9d-4f51-8da7-7c41717fab6e"
   },
   "source": [
    "Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "825b7706-28a9-42c0-80ea-c296c9870e38",
   "metadata": {
    "id": "825b7706-28a9-42c0-80ea-c296c9870e38"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Task 025:  MSE = 8.06e-05  RL2 = 1.59e-02\n",
      "Task 045:  MSE = 1.25e-04  RL2 = 2.12e-02\n",
      "Task 012:  MSE = 8.80e-05  RL2 = 1.59e-02\n",
      "Task 028:  MSE = 7.73e-05  RL2 = 1.57e-02\n",
      "Task 011:  MSE = 9.34e-05  RL2 = 1.63e-02\n",
      "Task 008:  MSE = 1.10e-04  RL2 = 1.75e-02\n",
      "Task 038:  MSE = 7.39e-05  RL2 = 1.59e-02\n",
      "Task 047:  MSE = 1.54e-04  RL2 = 2.37e-02\n",
      "Task 027:  MSE = 7.86e-05  RL2 = 1.58e-02\n",
      "Task 023:  MSE = 8.20e-05  RL2 = 1.59e-02\n",
      "Task 039:  MSE = 7.73e-05  RL2 = 1.63e-02\n",
      "Task 007:  MSE = 1.12e-04  RL2 = 1.76e-02\n",
      "Task 046:  MSE = 1.39e-04  RL2 = 2.24e-02\n",
      "Task 003:  MSE = 1.86e-04  RL2 = 2.24e-02\n",
      "Task 034:  MSE = 6.99e-05  RL2 = 1.53e-02\n",
      "Task 013:  MSE = 8.38e-05  RL2 = 1.55e-02\n",
      "Task 006:  MSE = 1.12e-04  RL2 = 1.75e-02\n",
      "Task 026:  MSE = 7.97e-05  RL2 = 1.59e-02\n",
      "Task 029:  MSE = 7.56e-05  RL2 = 1.56e-02\n",
      "Task 000:  MSE = 3.82e-04  RL2 = 3.18e-02\n",
      "Task 042:  MSE = 9.45e-05  RL2 = 1.83e-02\n",
      "Task 037:  MSE = 7.15e-05  RL2 = 1.56e-02\n",
      "Task 015:  MSE = 7.95e-05  RL2 = 1.52e-02\n",
      "Task 030:  MSE = 7.38e-05  RL2 = 1.55e-02\n",
      "Task 002:  MSE = 2.42e-04  RL2 = 2.55e-02\n",
      "Task 044:  MSE = 1.13e-04  RL2 = 2.01e-02\n",
      "Task 014:  MSE = 8.09e-05  RL2 = 1.53e-02\n",
      "Task 032:  MSE = 7.11e-05  RL2 = 1.53e-02\n",
      "Task 010:  MSE = 9.97e-05  RL2 = 1.68e-02\n",
      "Task 018:  MSE = 8.03e-05  RL2 = 1.55e-02\n",
      "Task 035:  MSE = 6.97e-05  RL2 = 1.53e-02\n",
      "Task 024:  MSE = 8.15e-05  RL2 = 1.59e-02\n",
      "Task 040:  MSE = 8.17e-05  RL2 = 1.69e-02\n",
      "Task 019:  MSE = 8.05e-05  RL2 = 1.56e-02\n",
      "MEDIAN :  MSE = 8.16e-05  RL2 = 1.59e-02\n",
      "MEAN   :  MSE = 1.06e-04  RL2 = 1.77e-02\n"
     ]
    }
   ],
   "source": [
    "# evaluate the trained surrogate on every held-out test task\n",
    "results = []\n",
    "for task in test_task_list:\n",
    "    labels = y_train_all[task]\n",
    "    nu_col = jnp.full((x_train.shape[0], 1), nu_list[task])  # viscosity as constant 3rd input\n",
    "    u = model.apply(format_params_fn(params[:-2]), jnp.hstack([x_train, nu_col]))[:, 0:1]\n",
    "    err = labels - u\n",
    "    mse = jnp.mean(jnp.square(err))\n",
    "    rl2 = jnp.linalg.norm(err) / jnp.linalg.norm(labels)\n",
    "    print ('Task %03d:  MSE = %.2e  RL2 = %.2e'%(task, mse, rl2))\n",
    "    results.append([task, mse, rl2])\n",
    "\n",
    "results = jnp.array(results)\n",
    "mse_col, rl2_col = results[:, 1:2], results[:, 2:3]\n",
    "\n",
    "print ('MEDIAN :  MSE = %.2e  RL2 = %.2e'%(jnp.median(mse_col), jnp.median(rl2_col)))\n",
    "print ('MEAN   :  MSE = %.2e  RL2 = %.2e'%(jnp.mean(mse_col), jnp.mean(rl2_col)))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "star",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
