{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"projectile_motion.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3 (ipykernel)","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.6"}},"cells":[{"cell_type":"markdown","metadata":{"id":"hklzPsgw693D"},"source":["# InImNet Solution of Projectile motion"]},{"cell_type":"markdown","metadata":{"id":"KwevE66669T-"},"source":["## Imports and directories"]},{"cell_type":"code","metadata":{"id":"hqRtdOKv3vD4"},"source":["# Quick parameters\n","npoints = 4 #32  # Number of p-layers\n","nsamples = 4 #100  # = No. training samples = No. validation samples\n","nlayers = 2  # Dense layers in the Phi-network\n","inflate_width = 4  # Inflate the hidden network width\n","seed = 0  # Global random seed for reproducible samples and weights\n","gravity = 9.81  # Vertical deceleration constant\n","hmax = 10.  
# Max height fixed for plotting\n","load_samples = False  # Requires existence of files else computes samples\n","save_samples = True  # Only when loading samples fails\n","\n","# Import modules\n","import os\n","import sys\n","from matplotlib import rcParams, cycler\n","import matplotlib.pyplot as plt  # %matplotlib notebook\n","import numpy as np\n","import random\n","from tqdm import tqdm\n","from tqdm.notebook import tqdm_notebook\n","\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","import tensorflow_probability as tfp  # !pip install tensorflow_probability\n","\n","# Seed every random number generator in use so reruns reproduce results\n","random.seed(seed)\n","np.random.seed(seed)\n","tf.random.set_seed(seed)\n","\n","# pgrid defining InIm layers: pgrid[-1] = 1 = q\n","pgrid = tf.linspace(0., 1., npoints + 1)\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"U93C4uEuhSxu"},"source":["from google.colab import drive\n","drive.mount('/content/gdrive', force_remount=True)\n","\n","pmotion_path = '/content/gdrive/MyDrive/Colab/pmotion'\n","if os.getcwd() != pmotion_path:\n","    os.chdir(pmotion_path)\n","print('Current working directory: ', os.getcwd())\n","sys.path.insert(0, pmotion_path)\n","data_path = 'data/p' + str(npoints) + 'b' + str(nsamples)\n","graphics_path = 'graphics/p' + str(npoints) + 'b' + str(nsamples)\n","if not os.path.exists(data_path):\n","    os.makedirs(data_path)\n","    print('New data directory.')\n","if not os.path.exists(graphics_path):\n","    os.makedirs(graphics_path)\n","    print('New graphics directory.')\n","\n","ptrain_path = data_path + '/ptrain{p_i:03d}'\n","pvalid_path = data_path + '/pvalid{p_i:03d}'\n","vtrain_path = data_path + '/vtrain'\n","vvalid_path = data_path + '/vvalid'\n","inimsave_path = data_path + '/inimsol'\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"P-zgc8joRgCk"},"source":["%%script false\n","# TPU initialisation\n","resolver = 
tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n","tf.config.experimental_connect_to_cluster(resolver)\n","tf.tpu.experimental.initialize_tpu_system(resolver)\n","print(\"All devices: \", tf.config.list_logical_devices('TPU'))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ifZbeQ0Y3rhV"},"source":["# GPU info\n","print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))\n","gpu_info = !nvidia-smi\n","gpu_info = '\\n'.join(gpu_info)\n","if gpu_info.find('failed') >= 0:\n","    print('Not connected to a GPU')\n","else:\n","    print(gpu_info)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"2tMBAXqY7SQc"},"source":["## Generate Samples"]},{"cell_type":"markdown","metadata":{"id":"Zs7vFh75PSiX"},"source":["### Projectile ODE"]},{"cell_type":"code","metadata":{"id":"0esicgvG48n3"},"source":["class ProjectileODE(tfp.math.ode.DormandPrince):\n","    \"\"\"Projectile dynamics dz/dt = A z + b for the state z = [h, v].\n","\n","    Subclassing tfp.math.ode.DormandPrince provides the self.solve method;\n","    A and b are tiled over a batch of batch_dim initial values.\n","\n","    Args:\n","        batch_dim: Number of initial values integrated in parallel.\n","        gravity: Constant vertical deceleration, enters b = [0, -gravity].\n","        c: Coefficient of v in dv/dt, enters A = [[0, 1], [0, c]].\n","    \"\"\"\n","\n","    def __init__(self,\n","                 batch_dim,\n","                 gravity=9.81,\n","                 c=0,\n","                 ):\n","        super(ProjectileODE, self).__init__()\n","\n","        # ODE matrix multiplier, shape (batch_dim, 2, 2)\n","        self.A = tf.constant([[0, 1],\n","                              [0, c]], dtype=tf.float32)\n","        self.A = tf.expand_dims(self.A, axis=0)\n","        self.A = tf.repeat(self.A, batch_dim, axis=0)\n","\n","        # ODE shift, shape (batch_dim, 2)\n","        self.b = tf.constant([0, -gravity], dtype=tf.float32)\n","        self.b = tf.expand_dims(self.b, axis=0)\n","        self.b = tf.repeat(self.b, batch_dim, axis=0)\n","\n","    def dynamics(self, t, z):\n","        \"\"\"Implements dz/dt = A@z + b with A=[[0, 1], [0, c]] and b=[0, -gravity] (t unused).\"\"\"\n","        return tf.linalg.matvec(self.A, z) + self.b\n","\n","    def inimdynamics(self, z, p, x, gtape):\n","        \"\"\"Invariant Imbedding p-gradient of z(q; p, x): -(dz/dx) @ dynamics(p, x).\n","\n","        `gtape` must be a GradientTape watching x that recorded z.\n","        \"\"\"\n","       
 dzdx = gtape.batch_jacobian(z, x)\n","        phi = self.dynamics(t=p, z=x)\n","        return -tf.linalg.matvec(dzdx, phi)\n","\n","    def inimsolve(self, pgrid, x, gtape):\n","        \"\"\"Forward-Euler solve of dz/dp = inimdynamics() from q = pgrid[-1] down to pgrid[0].\n","\n","        Returns a stacked tensor out with out[i] = z(q; pgrid[i], x) and out[-1] = x.\n","        \"\"\"\n","        z = [x]\n","        q = pgrid[-1]\n","        for p in tqdm_notebook(pgrid[-2::-1]):  # reverse pgrid, omit pgrid[-1]\n","            delta = p - q\n","            z.append(z[-1] + delta * self.inimdynamics(z[-1], p, x, gtape))\n","            q = p\n","        z.reverse()  # Reorder solution according to pgrid\n","        return tf.stack(z)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7OJF0YKGPW93"},"source":["### Generate samples"]},{"cell_type":"code","metadata":{"id":"29zCFcDpioA3"},"source":["def get_projectile_samples(pgrid,\n","                           nsamples,\n","                           saving=False,\n","                           gravity=9.81,\n","                           ):\n","    \"\"\"Integrates state z(t; p, x) = [h, v] from p to pgrid[-1].\n","\n","    - Completes this for each p in pgrid[:-1] with t-solns at pgrid[i:].\n","    - Also integrates the InIm equation numerically (forward Euler, see\n","      ProjectileODE.inimsolve) over pgrid[::-1].\n","    - 2 * nsamples initial velocities are drawn; the first half is training\n","      data and the second half validation data.\n","\n","    Returns:\n","        (ptrain, pvalid, vtrain, vvalid, inimsoln) where ptrain[i][v, t, j] =\n","        z(t; p_i, x=[0, v])_j, vtrain / vvalid are the squeezed training /\n","        validation initial speeds and inimsoln[i, v, j] = z(q; p_i, x=[0, v])_j\n","        for all 2 * nsamples speeds v.\n","    \"\"\"\n","\n","    if not all(p < q for p, q in zip(pgrid, pgrid[1:])):\n","        raise ValueError(\"Tensor `pgrid` should be strictly increasing.\")\n","\n","    # Adjust sample parameters (vertical speed) for suitable plots:\n","    # vmin keeps h(q) >= 0 (still airborne at t = q), vmax keeps the apex v^2/(2g) <= hmax\n","    vmin = gravity * (pgrid[-1] - pgrid[0]) / 2\n","    vmax = tf.math.sqrt(2 * gravity * hmax)\n","    print('vmax: ', vmax)\n","    print('vmin: ', vmin)\n","\n","    # Initial input values for z(p; p, x) = x = [h=0, v=vinit]\n","    vinit = tf.random.uniform(shape=[2 * nsamples, 1],\n","                              minval=vmin,\n","                              maxval=vmax,\n","                              dtype=tf.dtypes.float32,\n","                              )\n","    hinit = 
tf.zeros_like(vinit)\n","    x = tf.concat([hinit, vinit], axis=-1)\n","\n","    # ODE dynamics objects\n","    ode = ProjectileODE(2*nsamples, gravity)\n","\n","    # Iterate t-integral for initial values p\n","    ptrain = []\n","    pvalid = []\n","    for i in tqdm_notebook(range(len(pgrid) - 1)):\n","        # Integrate t-solution\n","        tsol = ode.solve(ode.dynamics, pgrid[i], x, pgrid[i:])\n","        # i-th tstates[v, t, j] = z(t; p_i, x=[0, v])_j\n","        tstates = tf.transpose(\n","            tsol.states, perm=[1, 0, 2], name = 'z(t; p_{}, x)'.format(i+1),\n","            )\n","        # Split training/validation data 50:50\n","        ttrain = tf.data.Dataset.from_tensors(tstates[:nsamples])\n","        tvalid = tf.data.Dataset.from_tensors(tstates[nsamples:])\n","        if saving:\n","            # Save p-th t-state\n","            tf.data.experimental.save(ttrain, ptrain_path.format(p_i=i+1))\n","            tf.data.experimental.save(tvalid, pvalid_path.format(p_i=i+1))\n","        # Accumulate as list\n","        ptrain.append(ttrain.get_single_element())\n","        pvalid.append(tvalid.get_single_element())\n","\n","    # Save initial values\n","    vtrain = tf.data.Dataset.from_tensors(vinit[:nsamples])\n","    vvalid = tf.data.Dataset.from_tensors(vinit[nsamples:])\n","    if saving:\n","        tf.data.experimental.save(vtrain, vtrain_path)\n","        tf.data.experimental.save(vvalid, vvalid_path)\n","\n","    # Compute InIm dynamics\n","    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:\n","        g.watch(x)\n","        # Integrate InImODE: inimsoln[i, v, j] = z(q; p_i, x=[0, v])_j\n","        inimsoln = ode.inimsolve(pgrid, x, g)\n","    if saving:\n","        tf.data.experimental.save(tf.data.Dataset.from_tensors(inimsoln),\n","                                  inimsave_path,\n","                                  )\n","    return (ptrain,\n","            pvalid,\n","            
tf.squeeze(vtrain.get_single_element(), axis=-1),\n","            tf.squeeze(vvalid.get_single_element(), axis=-1),\n","            inimsoln,\n","            )\n","\n","# Load training data\n","if load_samples and os.path.exists(ptrain_path.format(p_i=len(pgrid) - 1)):\n","    print('\\nLoading samples from file...')\n","    ptrain = []\n","    pvalid = []\n","    for i in tqdm_notebook(range(len(pgrid) - 1)):\n","        ttrain = tf.data.experimental.load(ptrain_path.format(p_i=i+1))\n","        tvalid = tf.data.experimental.load(pvalid_path.format(p_i=i+1))\n","        ptrain.append(ttrain.get_single_element())\n","        pvalid.append(tvalid.get_single_element())\n","    # Saved speeds have shape (nsamples, 1); squeeze to match the computed branch\n","    vtrain = tf.squeeze(tf.data.experimental.load(vtrain_path).get_single_element(), axis=-1)\n","    vvalid = tf.squeeze(tf.data.experimental.load(vvalid_path).get_single_element(), axis=-1)\n","    inimsoln = tf.data.experimental.load(inimsave_path).get_single_element()\n","    print('Complete.')\n","else:\n","    print('\\nComputing samples...')\n","    samples = get_projectile_samples(pgrid,\n","                                     nsamples,\n","                                     saving=save_samples,\n","                                     gravity=gravity,\n","                                     )\n","    ptrain, pvalid, vtrain, vvalid, inimsoln = samples\n","    print('\\nComplete.')\n","\n","# inimsoln[i, v, j] = z(q; p_i, [0, v])_j with len(v) = 2*nsamples\n","inimtrain = inimsoln[:, :nsamples, :]\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1qRiPEJYfvdg"},"source":["### Cultivate data"]},{"cell_type":"code","metadata":{"id":"rqaXuN2D5Atw"},"source":["# Training data\n","vtrain = tf.reshape(vtrain, [-1, 1])  # (nsamples, 1); reshape keeps re-runs idempotent\n","htrain = tf.zeros_like(vtrain)\n","xtrain = tf.concat([htrain, vtrain], axis=-1)\n","ytrain = ptrain  # ytrain[i][v, t, j] = z(t; p_i, [0, v])_j & len(v) = nsamples\n","\n","print('vtrain.shape: ', vtrain.shape)\n","print('htrain.shape: ', htrain.shape)\n","print('xtrain.shape: ', 
xtrain.shape)\n","print('ytrain[0].shape: ', ytrain[0].shape)\n","\n","# Validation data, built exactly like the training data\n","vvalid = tf.reshape(vvalid, [-1, 1])  # (nsamples, 1); a 1-D vvalid would make xvalid 1-D\n","hvalid = tf.zeros_like(vvalid)\n","xvalid = tf.concat([hvalid, vvalid], axis=-1)\n","yvalid = pvalid  # yvalid[i][v, t, j] = z(t; p_i, [0, v])_j & len(v) = nsamples"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-yXnm6glKJUH"},"source":["print(vtrain)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NBHH6DnB7XPc"},"source":["### Plot $t$-varying projectiles"]},{"cell_type":"code","metadata":{"id":"R00eRJ52--ZO"},"source":["## Plot t-varying output\n","# ptrain[i][v, t, j] = z(t; p_i, [0, v])_j\n","print('No. of p-states: ', len(ptrain))\n","print('No. of v-samples: ', len(vtrain))\n","print('Shape of typical p_states: ', ptrain[0].shape)\n","\n","print('First input x = [h, v] in batch:')\n","print('h(p) = ', ptrain[0][0][0, 0])\n","print('v(p) = ', ptrain[0][0][0, 1])\n","print('\\n')\n","\n","# Create figure\n","# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in newer matplotlib\n","plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available\n","              else 'seaborn-whitegrid')\n","cmap = plt.cm.autumn\n","rcParams['axes.prop_cycle'] = cycler(\n","    color=cmap(np.linspace(0, 1, len(pgrid) - 1))\n","    )\n","fig, ax = plt.subplots(figsize=[10, 8])\n","ax.set_xlim(pgrid[0], pgrid[-1])\n","ax.set_xlabel('t', fontsize=16)\n","ax.set_ylim(0, hmax)\n","ax.set_ylabel('h(t; p, x=[0, v(p)])', fontsize=16)\n","ax.set_title('Projectile path sample set.', fontsize=20)\n","\n","# Plot lines at p = pgrid[ipt] for the first knum values of v\n","cstr = ['b', 'aquamarine', 'xkcd:sky blue', 'r']\n","knum = min(len(cstr), nsamples)  # cannot plot more samples than exist\n","ipt = 0\n","msize = 100\n","for k in range(knum):\n","    ax.plot(pgrid[ipt:],\n","            ptrain[ipt][k][:, 0],\n","            cstr[k],\n","            label='$v(p)$ = {:.2f}'.format(float(vtrain[k])),\n","            )\n","    # Marker at the end point t = q of each path\n","    ax.scatter(pgrid[-1],\n","               ptrain[ipt][k][-1, 0],\n","               c=cstr[k],\n","               
s=msize,\n","               marker=8,\n","               )\n","\n","# Plot lines at each p in pgrid[:-1] for the last plotted velocity v = vtrain[knum-1]\n","for i in range(0, len(pgrid) - 1):\n","    ax.plot(pgrid[i:], ptrain[i][knum-1][:, 0])\n","ax.legend(loc='upper left')\n","fig.savefig(graphics_path + '/tsol.png')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NOfahb0M7oYS"},"source":["## Define model"]},{"cell_type":"code","metadata":{"id":"yH5mfp9Et3ip"},"source":["class InImNet(tf.keras.Model):\n","    \"\"\"Time series regression with InImNet.\n","\n","    A Dense network phi maps R^input_dim to R^input_dim and plays the role\n","    of the dynamics; call() integrates the Invariant Imbedding equation\n","    dz/dp = -(dz/dx) @ phi(x) from q = pgrid[-1] down to pgrid[0].\n","\n","    Args:\n","        input_dim: State dimension (input and output width of phi).\n","        num_layers: Number of Dense layers in phi (num_layers - 1 hidden).\n","        width_mult: Hidden layers have input_dim * width_mult units.\n","        activation_in, activation_out: Hidden / output layer activations.\n","        cost: Per-sample cost; only 'mse' is implemented.\n","        bias_in, bias_out: Whether hidden / output layers use a bias.\n","        t_grad_on: Also integrate the t-component of the augmented adjoint.\n","        name: Keras model name.\n","    \"\"\"\n","\n","    def __init__(self,\n","                 input_dim: int=2,\n","                 num_layers: int=2,\n","                 width_mult: int=2,\n","                 activation_in='relu',\n","                 activation_out=None,\n","                 cost: str='mse',\n","                 bias_in: bool=True,\n","                 bias_out: bool=False,\n","                 t_grad_on: bool=False,\n","                 name=\"InImNet\"):\n","        super(InImNet, self).__init__(name=name)\n","        if cost == 'mse':\n","            self.cost_fn = tf.keras.losses.MeanSquaredError(\n","                                reduction=tf.keras.losses.Reduction.NONE)\n","        else:\n","            raise ValueError(\"Unsupported cost '{}': only 'mse' is implemented.\".format(cost))\n","        self.t_grad_on = t_grad_on\n","        self.phi = tf.keras.Sequential()\n","        self.phi.add(tf.keras.Input(shape=(input_dim,)))\n","        for layer in range(num_layers - 1):\n","            self.phi.add(tf.keras.layers.Dense(input_dim * width_mult,\n","                                               activation=activation_in,\n","                                               use_bias=bias_in,\n","                                               name='HiddenLayer{}'\n","                                                    .format(layer),\n","                                               ))\n","        
self.phi.add(tf.keras.layers.Dense(input_dim,\n","                                           activation=activation_out,\n","                                           use_bias=bias_out,\n","                                           name='OutputLayer',\n","                                           ))\n","\n","    def inim_dynamics(self, z, p, x, gtape):\n","        \"\"\"Invariant Imbedding p-gradient of z(q; p, x): -(dz/dx) @ phi(x).\n","\n","        `gtape` must be a GradientTape watching x that recorded z.\n","        \"\"\"\n","        return -tf.linalg.matvec(gtape.batch_jacobian(z, x), self.phi(x))\n","\n","    def call(self, pgrid, x):\n","        \"\"\"Fwd-Euler integration of p-gradient of z(q; p, x) over pgrid.\n","\n","        Args:\n","            pgrid: Strictly increasing grid of p values, with pgrid[-1] = q.\n","            x: Tensor, input to the system for any given p in pgrid\n","\n","        Returns: 3D Tensor `out`, such that out[i, v, j] = z(q; p_i, x=[0, v])_j\n","        \"\"\"\n","\n","        if not all(p < q for p, q in zip(pgrid, pgrid[1:])):\n","            raise ValueError(\"The list 'pgrid' should be strictly increasing.\")\n","\n","        with tf.GradientTape(\n","            persistent=True, watch_accessed_variables=False) as g:\n","            g.watch(x)\n","            z = [x]\n","            q = pgrid[-1]\n","            for p in tqdm_notebook(pgrid[-2::-1]):  # p = pgrid[-2], ..., [0]\n","                delta = p - q\n","                z.append(z[-1] + delta * self.inim_dynamics(z[-1], p, x, g))\n","                q = p\n","        z.reverse()\n","        return tf.stack(z)\n","\n","    def aug_adjoint_dynamics(self, aa, p, x, gtape):\n","        \"\"\"Implements the p-derivative of the augmented adjoint (aa) state.\n","\n","        Args:\n","            aa: The Augmented Adjoint, a length 2 or 3 tuple of tensors with\n","                aa[0] = z derivative of the adjoint\n","                aa[1] = parameter derivative of the adjoint\n","                aa[2] = t derivative of the adjoint [optional]\n","            p: Scalar 
tensor (unused: phi does not depend on p)\n","            x: Tensor, InIm input of shape (batch_size, 2)\n","            gtape: Persistent GradientTape watching x that recorded aa.\n","\n","        Returns:\n","            (grad_z, grad_params, grad_t): p-derivatives of aa[0], aa[1] and\n","            aa[2]; grad_t is None unless self.t_grad_on.\n","        \"\"\"\n","\n","        # Watch x in phi(x, var) for Jacobian computation\n","        with tf.GradientTape(\n","            persistent=True, watch_accessed_variables=False) as g:\n","            g.watch(x)\n","            for var in self.variables:\n","                g.watch(var)\n","            phi = self.phi(x)\n","        dphi = g.batch_jacobian(phi, x)\n","\n","        # Gradient wrt. state z\n","        grad_z = - tf.linalg.matvec(gtape.batch_jacobian(aa[0], x), phi) \\\n","                 - tf.linalg.matvec(tf.transpose(dphi, perm=[0, 2, 1]), aa[0])\n","\n","        # Gradient wrt. params\n","        jacobians_wrt_params = []\n","        for var in self.variables:\n","            # vjp(z=self.phi(p, x), x=param, v_like_z=lam_aug[0]).view(-1))\n","            jacobians_wrt_params.append(tf.linalg.matvec(\n","                tf.transpose(tf.reshape(g.jacobian(phi, var),\n","                                        shape = phi.shape.as_list() + [-1,],\n","                                        ),\n","                             perm=[0, 2, 1],\n","                             ),\n","                aa[0],\n","                ))\n","\n","        # grad_lam_th = - jvp(z=lam_aug[1], x=x, v_like_x=self.phi(p, x)) \\\n","                    #   - torch.cat(vjps_wrt_params)\n","        grad_params = - tf.linalg.matvec(gtape.batch_jacobian(aa[1], x), phi) \\\n","                      - tf.concat(jacobians_wrt_params, axis=-1)\n","\n","        # Activate the t-component of the lambda derivative\n","        if self.t_grad_on:\n","            # grad_lam_t = - jvp(z=lam_aug[2], x=x, v_like_x=self.phi(p, x)) \\\n","                        #  - torch.bmm(jvp(self.phi(p, x), x, self.phi(p, x)),\n","                                    #  lam_aug[0],\n","                                    #  )\n","            grad_t = 
tf.linalg.matvec(gtape.batch_jacobian(aa[2], x), phi) \\\n","                     - tf.reduce_sum(tf.multiply(tf.linalg.matvec(dphi, phi),\n","                                                 aa[0],\n","                                                 ),\n","                                     axis=1,\n","                                     keepdims=True,\n","                                     )\n","        else:\n","            grad_t = None\n","\n","        return grad_z, grad_params, grad_t\n","\n","    def call_aug_grads(self, pgrid, x, y):\n","        \"\"\"Integrates the InIm state z and its augmented adjoint (aa) over pgrid.\n","\n","        Both are stepped with forward Euler from q = pgrid[-1] down to\n","        pgrid[0]; at each p_i the gradient of the running loss against y is\n","        added to the z-adjoint.\n","\n","        Args:\n","            pgrid: Strictly increasing grid of p values, with pgrid[-1] = q.\n","            x: Tensor, input to the system for any given p in pgrid\n","            y: List of len(pgrid) - 1 targets with y[i][v, t, j] =\n","               z(t; p_i, x=[0, v])_j; only the last t (t = q) is used.\n","\n","        Returns:\n","            (z, aa): lists ordered like pgrid, with z[i] = z(q; p_i, x)\n","            (so z[-1] = x) and aa[i] the augmented adjoint tuple at p_i;\n","            aa[i][1] holds the per-sample parameter gradients.\n","        \"\"\"\n","\n","        if not all(p < q for p, q in zip(pgrid, pgrid[1:])):\n","            raise ValueError(\"The list 'pgrid' should be strictly increasing.\")\n","\n","        # Lists of InIm states and augmented-adjoint tuples, one per p in pgrid\n","        aa = []\n","        z = []\n","\n","        with tf.GradientTape(\n","            persistent=True, watch_accessed_variables=False) as gx:\n","            gx.watch(x)\n","\n","            # Initial value of InIm solution\n","            z.append(x)\n","\n","            # Initial values of the augmented adjoint\n","            with tf.GradientTape(\n","                persistent=True, watch_accessed_variables=False) as gz:\n","                gz.watch(z[0])\n","                init_rloss = tf.expand_dims(self.cost_fn(z[0], x), -1)\n","            # Squeeze only axis 1 so that a batch of size 1 keeps its batch axis\n","            aa.append((tf.squeeze(\n","                tf.linalg.matmul(gz.batch_jacobian(init_rloss, z[0]),\n","                                 gx.batch_jacobian(z[0], x)),\n","                axis=1),))\n","            aa[0] = aa[0] + 
(tf.math.multiply(\n","                tf.zeros(shape=(x.shape[0], sum(var.shape.num_elements()\n","                                                for var in self.variables))),\n","                tf.reduce_sum(x)),)  # Multiply by sum(x) so the Jacobian wrt x is 0, not None\n","            if self.t_grad_on:\n","                # Initial t-adjoint: per-sample dot product of the z-adjoint with phi(x)\n","                aa[0] = aa[0] + (tf.reduce_sum(tf.multiply(aa[0][0], self.phi(x)),\n","                                               axis=-1, keepdims=True),)\n","\n","            # Integrate over pgrid\n","            q = pgrid[-1]\n","            ii = 0  # Count the y data points in reverse p-order\n","            for p in tqdm_notebook(pgrid[-2::-1]):  # p = pgrid[-2], ..., [0]\n","                delta = p - q\n","                z.append(z[-1] + delta * self.inim_dynamics(z[-1], p, x, gx))\n","                ii += 1\n","                yi = tf.transpose(y[-ii], perm=[1, 0, 2])[-1]\n","                with tf.GradientTape(\n","                    persistent=True, watch_accessed_variables=False) as gz:\n","                    gz.watch(z[-1])\n","                    r_loss = tf.expand_dims(self.cost_fn(z[-1], yi), -1)\n","                r_grad = tf.squeeze(\n","                    tf.linalg.matmul(gz.batch_jacobian(r_loss, z[-1]),\n","                                     gx.batch_jacobian(z[-1], x)),\n","                    axis=1)\n","                prev = aa[-1]\n","                aa_dynamics = self.aug_adjoint_dynamics(prev, p, x, gx)\n","                new_aa = (\n","                    prev[0] + r_grad + delta * aa_dynamics[0],\n","                    prev[1] + delta * aa_dynamics[1],\n","                    )\n","                if self.t_grad_on:\n","                    # Step from the previous t-adjoint (new_aa has no [2] entry yet)\n","                    new_aa = new_aa + (prev[2] + delta * aa_dynamics[2],)\n","                aa.append(new_aa)\n","                q = p\n","        z.reverse()\n","        aa.reverse()\n","        return z, aa\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aNwEXPWGeytS"},"source":["### Instantiation"]},{"cell_type":"code","metadata":{"id":"cgQUwe-FeYG0"},"source":["# Instantiate model\n","inimodel = InImNet(input_dim = 
2,\n","                   num_layers=nlayers,\n","                   width_mult=inflate_width)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fGymwDGEPkbs"},"source":["## Training Phase"]},{"cell_type":"code","metadata":{"id":"s8h2WoMWcYyL"},"source":["epochs = 10\n","learning_rate = 0.001\n","\n","# Performance record\n","record_training_loss = np.zeros((epochs,))\n","mse = tf.keras.losses.MeanSquaredError()  # batch-averaged loss for monitoring\n","\n","# Train with plain gradient descent on the adjoint parameter gradients\n","for ep in tqdm_notebook(range(epochs)):\n","    print('Epoch ', ep)\n","\n","    # Backward pass model\n","    ztrain, aatrain = inimodel.call_aug_grads(pgrid, xtrain, ytrain)\n","    # aatrain[0] is the adjoint at p = pgrid[0]; sum its parameter part over the batch\n","    param_grads = tf.math.reduce_sum(aatrain[0][1], axis=0)\n","    new_weights = []\n","    start = 0\n","    for w in inimodel.get_weights():\n","        end = start + w.size  # param_grads is the flat concatenation of all weights\n","        wloss = tf.reshape(param_grads[start:end], shape=w.shape)\n","        new_weights.append(w - learning_rate * wloss)\n","        start = end\n","    inimodel.set_weights(new_weights)\n","\n","    # Record loss (of the weights used in this epoch's forward pass)\n","    losses = [mse(z, tf.transpose(y, perm=[1, 0, 2])[-1])\n","              for z, y in zip(ztrain[:-1], ytrain)]\n","    record_training_loss[ep] = float(tf.add_n(losses)) / len(losses)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sOC5p0A97fRH"},"source":["## Plot results"]},{"cell_type":"code","metadata":{"id":"Pf4gcJypoGSr"},"source":["# Model evaluated on training data, shape (npoints+1, batch_dim, dim=2)\n","ztrain = tf.stack(ztrain)\n","yptrain = [tf.transpose(y, perm=[1, 0, 2])[-1] for y in ytrain]\n","yptrain.append(xtrain)\n","yptrain = tf.stack(yptrain)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5z9zR_L17Fs-"},"source":["# Normalise loss by untrained 
parameters\n","training_loss = np.array(record_training_loss)\n","# Guard against a zero initial loss to avoid division by zero\n","norm_loss = training_loss / training_loss[0] if training_loss[0] else training_loss\n","\n","# Plot the normalised training loss per epoch\n","fig, ax = plt.subplots()\n","ax.set_title('Loss versus epoch')\n","ax.plot(norm_loss)\n","ax.set_ylabel('Training loss')\n","ax.set_xlabel('Number of epochs');"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wsMAAaKSx9tX"},"source":["## Plot Learnt output from training data at t=q (for varying p)\n","\n","print('First input x = [h, v] in batch:')\n","print('h(q; q, x=[0, v]) = ', ztrain[-1, 0, 0])\n","print('v(q; q, x=[0, v]) = ', ztrain[-1, 0, 1])\n","print('\\n')\n","\n","# Create figure\n","# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in newer matplotlib\n","plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available\n","              else 'seaborn-whitegrid')\n","cmap = plt.cm.autumn\n","rcParams['axes.prop_cycle'] = cycler(\n","    color=cmap(np.linspace(0, 1, 5))\n","    )\n","fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[20, 8])\n","\n","## Axis 1\n","ipt = 0\n","ax1.set_xlim(pgrid[0], pgrid[-1])\n","ax1.set_xlabel('t', fontsize=24)\n","ax1.set_ylim(0, hmax)\n","ax1.set_ylabel('h(t; p, x=[0, v])', fontsize=24)\n","ax1.set_title('Fixed p = {:.2f}, varying initial velocity $v$'\n","              .format(float(pgrid[ipt])), fontsize=24)\n","# Set tick positions explicitly: set_xticklabels alone relabels whatever\n","# ticks the locator chose, so labels and positions could disagree\n","xlabs = np.linspace(float(pgrid[0]), float(pgrid[-1]), 5)\n","ylabs = np.linspace(0, hmax, 6)\n","ax1.set_xticks(xlabs)\n","ax1.set_yticks(ylabs)\n","ax1.tick_params(labelsize=24)\n","\n","# Plot lines at p = pgrid[ipt] for the first knum values of v\n","cstr = ['b', 'aquamarine', 'xkcd:sky blue', 'r']\n","knum = min(len(cstr), nsamples)  # cannot plot more samples than exist\n","msize = 100\n","for k in range(knum):\n","    ax1.plot(pgrid[ipt:],\n","             ytrain[ipt][k][:, 0],\n","             cstr[k],\n","             label='$v$ = {:.2f}'.format(float(vtrain[k])),\n","             )\n","    ax1.scatter(pgrid[-1],\n","                ztrain[ipt, k, 0],\n","                c=cstr[k],\n","        
        s=msize,\n","                marker=8,\n","                )\n","ax1.legend(loc='upper left', fontsize=20, framealpha=1)\n","\n","## Axis 2\n","ax2.set_xlim(pgrid[0], pgrid[-1])\n","ax2.set_xlabel('t', fontsize=24)\n","ax2.set_ylim(0, hmax)\n","ax2.set_title('Fixed initial velocity $v$, varying $p$', fontsize=24)\n","# Same tick positions as axis 1; y tick labels hidden as both axes share a scale\n","ax2.set_xticks(np.linspace(float(pgrid[0]), float(pgrid[-1]), 5))\n","ax2.set_yticks(np.linspace(0, hmax, 6))\n","ax2.tick_params(labelsize=24)\n","ax2.set_yticklabels([])\n","\n","# Plot lines at each p in pgrid[:-1] for the last plotted velocity v = vtrain[knum-1]\n","for i in range(0, len(pgrid) - 1):\n","    # Label only the first line: every line shares the same velocity\n","    label = '$v$ = {:.2f}'.format(float(vtrain[knum-1])) if i == 0 else None\n","    ax2.plot(pgrid[i:], ytrain[i][knum-1][:, 0], label=label)\n","    ax2.scatter(pgrid[-1],\n","                ztrain[i, knum-1, 0],\n","                s=msize,\n","                marker=8,\n","                )\n","ax2.legend(loc='upper center', fontsize=20)\n","# Save before showing: some backends release the figure once it is shown\n","fig.savefig(graphics_path + '/psol.pdf')\n","plt.show()\n"],"execution_count":null,"outputs":[]}]}