{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf94d59d",
   "metadata": {
    "id": "bf94d59d"
   },
   "source": [
    "Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "09088a1a",
   "metadata": {
    "id": "09088a1a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[CudaDevice(id=0)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import flax\n",
    "import optax\n",
    "from jax import lax, random, numpy as jnp\n",
    "from jax import random, grad, vmap, hessian, jacfwd, jit\n",
    "from jax import config\n",
    "from flax import linen as nn\n",
    "from evojax.util import get_params_format_fn\n",
    "from scipy import io\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "from jax.scipy.linalg import solve\n",
    "# choose GPU\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "n_gpu = len(jax.devices())\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "61479f56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read data\n",
    "x_train = jnp.load(os.path.join('vbgs_input.npy'))\n",
    "y_train_all = jnp.load(os.path.join('vbgs_label.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a076465d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0 1.0 -1.0 1.0\n",
      "(6579, 2) (50, 6579, 1) (229, 2) (50, 229, 1)\n"
     ]
    }
   ],
   "source": [
    "dert_x = 1\n",
    "dert_t = 1\n",
    "seed_data = 10  # Random seed for train–test split\n",
    "seed_model = 10 # Random seed for model parameter initialization\n",
    "n_step = 6   # Nonlinear iteration steps (n_step = 1 for linear PDEs)\n",
    "n_step_test = n_step\n",
    "n_all_task = y_train_all.shape[0]\n",
    "n_train = 16    # Number of training tasks\n",
    "n_test = n_all_task - n_train\n",
    "r_task = 8\n",
    "n_nodes = 128 # Number of nodes in each neural network layer\n",
    "max_iters = 1000 # Total training (meta-learning) epoch\n",
    "max_lr = 5e-3 # Learning rate\n",
    "\n",
    "x_step = 129\n",
    "t_step = 51\n",
    "nu_list = jnp.linspace(0.001, 0.05, 50)\n",
    "n_task = len(nu_list)\n",
    "\n",
    "x_train = x_train.T.reshape(2, t_step, -1)[:, ::dert_t, ::dert_x].reshape(2, -1).T\n",
    "y_train_all = y_train_all.reshape(n_task , t_step, -1)[:, ::dert_t, ::dert_x].reshape(n_task , -1, 1)\n",
    "\n",
    "t_l, t_u, x_l, x_u = np.min(x_train[:,1]), np.max(x_train[:,1]), np.min(x_train[:,0]), np.max(x_train[:,0])\n",
    "ext = [t_l, t_u, x_l, x_u]\n",
    "print (t_l, t_u, x_l, x_u)\n",
    "\n",
    "x_dim = math.floor((x_step - 1) / dert_x) + 1 #251   63   126\n",
    "t_dim = math.floor((t_step - 1) / dert_t) + 1 #251   63   126\n",
    "\n",
    "# split into BC / IC data\n",
    "bic = (x_train[:,1] == 0) | (x_train[:,0] == x_l) | (x_train[:,0] == x_u)\n",
    "bcl = (x_train[:,0] == x_l)\n",
    "bcu = (x_train[:,0] == x_u)\n",
    "ic = (x_train[:,1] == 0)\n",
    "data_X_BIC, data_Y_BIC = x_train[bic], y_train_all[:,bic]\n",
    "print (x_train.shape, y_train_all.shape, data_X_BIC.shape, data_Y_BIC.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "349d0aad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to jnp\n",
    "x_train, y_train_all, data_X_BIC, data_Y_BIC, nu_list = jnp.array(x_train), jnp.array(y_train_all), jnp.array(data_X_BIC), jnp.array(data_Y_BIC)\\\n",
    "    ,jnp.array(nu_list)\n",
    "\n",
    "# Perform a fully random train–test split\n",
    "key, rng = random.split(random.PRNGKey(seed_data))\n",
    "key, rng = random.split(rng) # update random generator\n",
    "n_total = y_train_all.shape[0]\n",
    "\n",
    "all_indices = random.permutation(key, jnp.arange(n_total))\n",
    "test_task_list = all_indices[:n_test]\n",
    "train_task_list = all_indices[n_test:(n_test + n_train)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1389a078-0613-44df-a52a-e5136a61f390",
   "metadata": {
    "id": "1389a078-0613-44df-a52a-e5136a61f390"
   },
   "outputs": [],
   "source": [
    "class PINN(nn.Module):\n",
    "    \"\"\"PINNs\"\"\"\n",
    "    def setup(self):\n",
    "        self.layers = [nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.sin,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.sin,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.sin,\n",
    "                       nn.Dense(n_nodes, kernel_init = jax.nn.initializers.he_uniform()),\n",
    "                       jnp.sin]\n",
    "        #self.last_layer = nn.Dense(1, kernel_init = jax.nn.initializers.he_uniform(), use_bias=False)\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, inputs):\n",
    "        # split the two variables, probably just by slicing\n",
    "        x, t = inputs[:,0:1], inputs[:,1:2]\n",
    "        def get_u(x, t):\n",
    "            f = jnp.hstack([x, t])\n",
    "            fs = []\n",
    "            for i, lyr in enumerate(self.layers):\n",
    "                f = lyr(f)\n",
    "                if (i == 0):\n",
    "                    f = 2 *jnp.pi *f\n",
    "                if (i%2 != 0):\n",
    "                    fs.append(f)\n",
    "                # print(f.shape)\n",
    "            f = jnp.hstack(fs)\n",
    "            #u = self.last_layer(f)\n",
    "            return f\n",
    "\n",
    "        f = get_u(x, t)\n",
    "\n",
    "        # obtain f_xx, f_x, f_t\n",
    "        def get_f_dir(get_u, x, t):\n",
    "            f_xx = jacfwd(jacfwd(get_u))(x, t)\n",
    "            f_x, f_t = jacfwd(get_u)(x, t), jacfwd(get_u, argnums=1)(x, t)\n",
    "            return f_xx, f_x, f_t\n",
    "\n",
    "        f_dir_vmap = vmap(get_f_dir, in_axes=(None, 0, 0))\n",
    "        f_xx, f_x, f_t = f_dir_vmap(get_u, x, t)\n",
    "        f_xx, f_x, f_t = f_xx[:,:,0,0], f_x[:,:,0], f_t[:,:,0]\n",
    "\n",
    "        outputs = jnp.hstack([f, f_xx, f_x, f_t])\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "71970aed-3347-4317-895e-9a606971888b",
   "metadata": {
    "id": "71970aed-3347-4317-895e-9a606971888b",
    "outputId": "21abc408-c265-46d2-b447-3719d5f0dab1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int64(49920)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# choose seed\n",
    "key, rng = random.split(random.PRNGKey(seed_model))\n",
    "\n",
    "# dummy input\n",
    "a = random.normal(key, [1,2])\n",
    "\n",
    "# initialization call\n",
    "model = PINN()\n",
    "params = model.init(key, a)\n",
    "num_params, format_params_fn = get_params_format_fn(params)\n",
    "\n",
    "# flatten initial params\n",
    "params = jax.flatten_util.ravel_pytree(params)[0]\n",
    "num_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cbb3592f-f9e7-4d5e-afa0-d3bebc81d460",
   "metadata": {
    "id": "cbb3592f-f9e7-4d5e-afa0-d3bebc81d460",
    "outputId": "61eeac25-4a7d-490a-e4e0-ea7ba9b1bde7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(49922,)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hyper-parameter \"lambs\" via SGD - append params to include 2 more parameters\n",
    "key, rng = random.split(rng) # update random generator\n",
    "ini_a = random.normal(key, [1,2])\n",
    "params = jnp.append(params, ini_a)\n",
    "params.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "36486f02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([20, 36,  1,  4,  5, 31, 16, 41, 33, 43,  9, 17, 22, 48, 21, 49],      dtype=int64)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_task_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee134fb-897e-4fb7-984c-27340e2ffc06",
   "metadata": {
    "id": "bee134fb-897e-4fb7-984c-27340e2ffc06"
   },
   "outputs": [],
   "source": [
    "# loss function——Pichard linearization\n",
    "def eval_loss(params, task):\n",
    "    inputs, labels, gamma = x_train, y_train_all[task], nu_list[task]\n",
    "    pred = model.apply(format_params_fn(params[:-2]), inputs)\n",
    "    f, f_xx, f_x, f_t = jnp.split(pred, 4, axis=1)\n",
    "    # initial u\n",
    "    u = jnp.tile(y_train_all[task, ic], (t_dim, 1)).reshape(-1, 1)  #jnp.zeros((f.shape[0], 1))\n",
    "    f_bc = f[bcl] - f[bcu] # enforce periodic BC\n",
    "    p_bc = jnp.zeros((f_bc.shape[0], 1))\n",
    "    fx_bc = f_x[bcl] - f_x[bcu] # enforce periodic BC\n",
    "    px_bc = jnp.zeros((fx_bc.shape[0], 1))\n",
    "    lmbda = 10**(nn.sigmoid(params[-2]) *8 - 6) # transform into (1e-6, 1e2)\n",
    "    lamb = 10**(nn.sigmoid(params[-1]) *16 - 14) # transform into (1e-14, 1e2)\n",
    "    # to implement loop\n",
    "    for i in range(n_step):\n",
    "        pde = f_t + u*f_x - gamma*f_xx\n",
    "        # construct least square problem - populate A & b\n",
    "        A = jnp.vstack([pde * lmbda, f_bc, fx_bc, f[ic]])\n",
    "        b = jnp.vstack([labels * 0, p_bc, px_bc, labels[ic]])\n",
    "        # alternative solve (n_sample >> n_node)\n",
    "        #w = jnp.linalg.inv(lamb*jnp.eye(A.shape[1]) + (A.T@A))@A.T@b\n",
    "        As, bs = lamb*jnp.eye(A.shape[1]) + (A.T@A), A.T@b\n",
    "        w = solve(As, bs)\n",
    "        u = f @ w\n",
    "    # ssr, mse, rl2\n",
    "    ssr = np.sum((b - A @ w)**2)\n",
    "    mse = jnp.mean(jnp.square(labels.reshape(51, -1)[::dert_t, ::dert_x] - u.reshape(51, -1)[::dert_t, ::dert_x]))\n",
    "    rl2 = jnp.linalg.norm(labels - u) / jnp.linalg.norm(labels)\n",
    "    loss = mse\n",
    "    return loss, (ssr, mse, rl2)\n",
    "\n",
    "loss_grad = jax.jit(jax.value_and_grad(eval_loss, has_aux=True))\n",
    "# loss_grad = jax.vmap(loss_grad)\n",
    "# loss_grad = jax.pmap(loss_grad)  # span multi GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7f930650",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def update(params, opt_state, key):\n",
    "    loss_all, ssr_all, mse_all, rl2_all, grad_all = 0, 0, 0, 0, 0\n",
    "    batch_train_task = random.choice(key, train_task_list, (r_task,), replace=False)\n",
    "    for task in batch_train_task:\n",
    "        (loss, (ssr, mse, rl2)), grad = loss_grad(params, task)\n",
    "        loss_all += loss\n",
    "        ssr_all += ssr\n",
    "        mse_all += mse\n",
    "        rl2_all += rl2\n",
    "        grad_all += grad\n",
    "    loss_all /= r_task\n",
    "    ssr_all /= r_task\n",
    "    rl2_all /= r_task\n",
    "    mse_all /= r_task\n",
    "    grad_all /= r_task\n",
    "    updates, opt_state = optimizer.update(grad_all, opt_state)   ### adamw:  params=params\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    return params, opt_state, loss_all, ssr_all, mse_all, rl2_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fe019e01-0083-46fc-9e41-c82d2dc3cda7",
   "metadata": {
    "id": "fe019e01-0083-46fc-9e41-c82d2dc3cda7"
   },
   "outputs": [],
   "source": [
    "# optimizer\n",
    "lr_scheduler = optax.warmup_cosine_decay_schedule(init_value=max_lr, peak_value=max_lr, warmup_steps=int(max_iters*.4),\n",
    "                                                  decay_steps=max_iters, end_value=1e-6)\n",
    "optimizer = optax.adam(learning_rate=lr_scheduler) # Choose the method\n",
    "opt_state = optimizer.init(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b49b20-d73f-4925-be58-b930bf26d0c8",
   "metadata": {
    "id": "e5b49b20-d73f-4925-be58-b930bf26d0c8",
    "outputId": "ce826088-137a-4420-ee93-77f829c0d1d2"
   },
   "outputs": [],
   "source": [
    "# training iteration\n",
    "runtime = 0\n",
    "train_iters = 0\n",
    "\n",
    "store = []\n",
    "\n",
    "while (train_iters <= max_iters) and (runtime < 1200):\n",
    "    # mini-batch update\n",
    "    start = time.time()\n",
    "    key, rng = random.split(rng) # update random generator\n",
    "    params, opt_state, loss, ssr, mse, rl2 = update(params, opt_state, key)\n",
    "    end = time.time()\n",
    "    runtime += (end-start)\n",
    "    # append weights\n",
    "    if (train_iters % 100 == 0):\n",
    "        print ('iter. = %05d,  time = %03ds,  loss = %.2e  |  ssr = %.2e,  mse = %.2e,  rl2 = %.2e'%(train_iters, runtime, loss, ssr, mse, rl2))\n",
    "    store.append([train_iters, runtime, loss, ssr, mse, rl2])\n",
    "    train_iters += 1\n",
    "\n",
    "store = jnp.array(store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c50be89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "##   save weight\n",
    "import pickle\n",
    "import jax\n",
    "\n",
    "params_np = jax.device_get(params)\n",
    "store_np = jax.device_get(store)\n",
    "save_dict = {\n",
    "    \"params\": params_np,\n",
    "    \"store\": store_np,\n",
    "    \"seed\": seed_model\n",
    "}\n",
    "filename = f\"PINet--Burgers.pkl\"\n",
    "\n",
    "with open(filename, \"wb\") as f:\n",
    "    pickle.dump(save_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dd9e524",
   "metadata": {},
   "outputs": [],
   "source": [
    "###   load weight\n",
    "import pickle\n",
    "import jax\n",
    "\n",
    "with open(\"PINet--Burgers.pkl\", \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "params = jax.numpy.asarray(data[\"params\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41f1c6a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Performance of PINet on test tasks\n",
    "from jax.scipy.linalg import solve\n",
    "import time\n",
    "\n",
    "# n_task = y_train_all.shape[0]\n",
    "u_all = []\n",
    "label_all = []\n",
    "mse_all = []\n",
    "rl2_all = []\n",
    "runtime = 0.0\n",
    "\n",
    "start = time.time()\n",
    "inputs = x_train\n",
    "pred = model.apply(format_params_fn(params[:-2]), inputs)\n",
    "f, f_xx, f_x, f_t = jnp.split(pred, 4, axis=1)\n",
    "f_bc = f[bcl] - f[bcu] # enforce periodic BC\n",
    "p_bc = jnp.zeros((f_bc.shape[0], 1))\n",
    "fx_bc = f_x[bcl] - f_x[bcu] # enforce periodic BC\n",
    "px_bc = jnp.zeros((fx_bc.shape[0], 1))\n",
    "    \n",
    "lmbda = 10**(nn.sigmoid(params[-2]) *8 - 6) # transform into (1e-6, 1e2)\n",
    "lamb = 10**(nn.sigmoid(params[-1]) *16 - 14) # transform into (1e-14, 1e2)\n",
    "for task in test_task_list:\n",
    "    labels = y_train_all[task]\n",
    "    u = jnp.tile(y_train_all[task, ic], (t_dim, 1)).reshape(-1, 1)  #jnp.zeros((f.shape[0], 1))\n",
    "    u_x = u\n",
    "    gamma = nu_list[task]\n",
    "\n",
    "    for i in range(n_step):\n",
    "        pde = f_t + u*f_x - gamma*f_xx\n",
    "        # construct least square problem - populate A & b\n",
    "        A = jnp.vstack([pde * lmbda, f_bc, fx_bc, f[ic]])\n",
    "        b = jnp.vstack([0 * labels, p_bc, px_bc, labels[ic]])\n",
    "        # alternative solve (n_sample >> n_node)\n",
    "        #w = jnp.linalg.inv(lamb*jnp.eye(A.shape[1]) + (A.T@A))@A.T@b\n",
    "        As, bs = lamb*jnp.eye(A.shape[1]) + (A.T@A), A.T@b\n",
    "        w = solve(As, bs)\n",
    "        u = f @ w\n",
    "        # u_x = f_x @ w\n",
    "    u_all.append(u)\n",
    "    label_all.append(labels)\n",
    "    mse_all.append(jnp.mean((labels - u) ** 2))\n",
    "    rl2_all.append(jnp.linalg.norm(labels - u) / jnp.linalg.norm(labels))\n",
    "end = time.time()\n",
    "runtime = end - start\n",
    "u_all = jnp.stack(u_all)\n",
    "label_all = jnp.stack(label_all)\n",
    "mse_all = jnp.array(mse_all)\n",
    "rl2_all = jnp.array(rl2_all)\n",
    "print(f\"Total runtime: {runtime:.2f} s\")\n",
    "print(\"Mean MSE:\", mse_all.mean())\n",
    "print(\"Mean RL2:\", rl2_all.mean())\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "star",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
