{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Discrete_InImNet_tensorflow_time_series_bouncing_balls_conv.ipynb","provenance":[],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"JNiPg5Ygf6pX"},"source":[
"\n",
"from google.colab import drive\n",
"drive.mount('/gdrive')\n",
"%cd /gdrive/My Drive/InImNet/"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kyJMnDVZGX_q"},"source":[
"import os\n",
"import time\n",
"import tensorflow as tf\n",
"#tf.get_logger().setLevel('ERROR')\n",
"\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"\n",
"mpl.rcParams['figure.figsize'] = (8, 6)\n"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Iex1qdCJxUdK"},"source":[
"class InputAutoencoder (tf.keras.Model):\n",
"    \"\"\"Convolutional encoder/decoder used as the per-layer map phi of AutoDiffInImNet.\n",
"\n",
"    Encoder: three stride-2 Conv2D layers (16, 32, 64 filters, 'same' padding).\n",
"    Decoder: three stride-2 Conv2DTranspose layers (128, 64, 32 filters) followed by a\n",
"    stride-1 Conv2DTranspose with 4 output channels. The Dense dim_reducer /\n",
"    dim_reconstructor layers are only reached via call_dim_reducer / call_dim_reconstructor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, autoenc_lr: float = 0.001, represent_dim_in=50, represent_dim_out=50):\n",
"        super(InputAutoencoder, self).__init__()\n",
"        \n",
"        self.encoder_input = []\n",
"        self.decoder_input = []\n",
"\n",
"\n",
"        self.encoder_input.append(tf.keras.layers.Conv2D(16, (5, 5), strides=2, padding='same'))\n",
"        self.encoder_input.append(tf.keras.layers.BatchNormalization())\n",
"        self.encoder_input.append(tf.keras.layers.ReLU())\n",
"        self.encoder_input.append(tf.keras.layers.Conv2D(32, (5, 5), strides=2, padding='same'))\n",
"        self.encoder_input.append(tf.keras.layers.BatchNormalization())\n",
"        self.encoder_input.append(tf.keras.layers.ReLU())\n",
"        self.encoder_input.append(tf.keras.layers.Conv2D(64, (5, 5), strides=2, padding='same'))\n",
"        self.encoder_input.append(tf.keras.layers.ReLU())\n",
"        #self.encoder_input.append(tf.keras.layers.BatchNormalization())\n",
"        #self.encoder_input.append(tf.keras.layers.Conv2D(4, (5, 5), strides=1, activation='relu', padding='same'))\n",
"        #self.encoder_input.append(tf.keras.layers.Flatten())\n",
"\n",
"\n",
"        self.dim_reducer = tf.keras.layers.Dense (represent_dim_in)\n",
"        self.dim_reconstructor = tf.keras.layers.Dense (represent_dim_out)\n",
"        \n",
"        #self.max_pool_output_shape = [6, 6, 8]\n",
"        #self.flattened_max_pool_output_shape = 288\n",
"\n",
"        #self.decoder_input.append(tf.keras.layers.Dense(self.flattened_max_pool_output_shape, activation='relu')) \n",
"        #self.decoder_input.append(tf.keras.layers.Reshape(self.max_pool_output_shape))\n",
"        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=2, padding='same'))\n",
"        self.decoder_input.append(tf.keras.layers.BatchNormalization())\n",
"        self.decoder_input.append(tf.keras.layers.ReLU())\n",
"        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=2, padding='same'))\n",
"        self.decoder_input.append(tf.keras.layers.BatchNormalization())\n",
"        self.decoder_input.append(tf.keras.layers.ReLU())\n",
"        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=2, padding='same'))\n",
"        self.decoder_input.append(tf.keras.layers.BatchNormalization())\n",
"        self.decoder_input.append(tf.keras.layers.ReLU())\n",
"        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(4, (5, 5), padding='same'))\n",
"        lr_decayed_fn = autoenc_lr#tf.keras.optimizers.schedules.CosineDecay(autoenc_lr, 32)\n",
"        self.optimiser_autoencoder = tf.keras.optimizers.Adam(learning_rate=lr_decayed_fn)  \n",
"\n",
"\n",
"    def call_encoder(self, x, training):\n",
"        z = x               \n",
"        for i in range (len(self.encoder_input)):       \n",
"            z = self.encoder_input[i](z, training=training)   \n",
"            print ('encoder['+str(i)+'].shape: ', z.shape)\n",
"        return z\n",
"\n",
"    def call_dim_reducer(self, x, training):\n",
"        return self.dim_reducer (x, training=training) \n",
"\n",
"    def call_dim_reconstructor(self, x, training):\n",
"        return self.dim_reconstructor (x, training=training) \n",
"\n",
"    def call_decoder (self, x, training):\n",
"        z = x               \n",
"        for i in range (len(self.decoder_input)):       \n",
"            z = self.decoder_input[i](z, training=training)   \n",
"            print ('decoder['+str(i)+'].shape: ', z.shape)\n",
"        return z\n",
"    \n",
"    def optimise_autoencoder (self, loss, g): \n",
"        \"\"\"Minimise `loss` (recorded on tape `g`) over every autoencoder layer; returns 0.\"\"\"\n",
"        layers = self.encoder_input + self.decoder_input + [self.dim_reducer, self.dim_reconstructor]\n",
"        # minimize() expects a flat list of variables, not one list of weights per layer.\n",
"        var_list = [v for layer in layers for v in layer.trainable_weights]\n",
"        self.optimiser_autoencoder.minimize (loss, var_list, tape=g)\n",
"        return 0"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kbpUoMeHeyqa"},"source":[
"class AutoDiffInImNet(tf.keras.Model):\n",
"    \"\"\"Residual image predictor with one InputAutoencoder per residual layer.\n",
"\n",
"    call(x, t) repeats the context frames x once per entry of t, appends t as an extra\n",
"    channel and returns [z_0, ..., z_L] (L = num_resnet_layers) as channel-0 slices of the\n",
"    states z_{l+1} = z_l + J_l @ phi_l, where phi_l is autoencoder[l] applied to that input.\n",
"    J_l is the identity plus the Jacobians of phi_0..phi_{l-1} w.r.t. the input when\n",
"    approx_jacobian=True, or the Jacobian of z_l w.r.t. the input otherwise.\n",
"    optimise() minimises the loss of the last layer only.\n",
"    bias_on, mult, use_batch_norm and dropout are accepted but currently unused.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dim: int, represent_dim_out: int,\n",
"                 num_resnet_layers: int,\n",
"                 activation,\n",
"                 num_internal_layers: int,\n",
"                 bias_on: bool = True,\n",
"                 mult: int = 1,               \n",
"                 use_batch_norm: bool = False,\n",
"                 dropout=0.1,\n",
"                 lr_inim = 0.001,\n",
"                 lr_autoencoder = 0.001,\n",
"                 approx_jacobian = True,\n",
"                 weight_regularisation_alpha = 0.0):\n",
"        super(AutoDiffInImNet, self).__init__()\n",
"        self.num_internal_layers = num_internal_layers\n",
"        self.num_resnet_layers = num_resnet_layers\n",
"        self.activation = activation\n",
"        self.dim = dim\n",
"        self.lr_inim = lr_inim#tf.keras.optimizers.schedules.PiecewiseConstantDecay (boundaries=[128], values=[lr_inim/50, lr_inim])#tf.keras.optimizers.schedules.ExponentialDecay(lr_inim, decay_steps=15625, decay_rate=0.5, staircase=True)\n",
"        self.represent_dim_out = represent_dim_out\n",
"        self.weight_regularisation_alpha = weight_regularisation_alpha\n",
" \n",
"        self.optimiser = []\n",
"        for i in range (self.num_resnet_layers):\n",
"            self.optimiser.append(tf.keras.optimizers.Adam(learning_rate=self.lr_inim))\n",
"        self.approx_jacobian = approx_jacobian\n",
"        self.autoencoder =  []\n",
"        for i in range (self.num_resnet_layers):\n",
"            self.autoencoder.append(InputAutoencoder (lr_autoencoder, represent_dim_in=dim, represent_dim_out=represent_dim_out))\n",
"\n",
"        \n",
"    def call_phi (self, ell, x, training):\n",
"        \"\"\"Apply the encoder and then the decoder of autoencoder[ell] to x.\"\"\"\n",
"        encoded_images = self.autoencoder[ell].call_encoder (x, training=training)\n",
"        z = self.autoencoder[ell].call_decoder (encoded_images, training=training)\n",
"        return z\n",
"\n",
"    #@tf.function\n",
"    def optimise (self, x, y_in, t, loss): \n",
"        \"\"\"One training step: predict from context frames x at times t and fit y_in.\n",
"\n",
"        y_in is flattened to (y_in.shape[0] * y_in.shape[1], y_in.shape[2], y_in.shape[3]) to\n",
"        match the rows returned by call(); loss(pred, target) is averaged with reduce_mean.\n",
"        Only the last layer's loss is minimised (with self.optimiser[0]) and returned.\n",
"        \"\"\"\n",
"        self.output_dim = y_in.shape[2]\n",
"        #print ('optimise')\n",
"\n",
"        with tf.GradientTape(persistent=True) as g:\n",
"            z = self (x, t, training=True)\n",
"            y_in_shape = self._infer_shape (y_in)\n",
"            y = tf.reshape (y_in, [y_in_shape[0] * y_in_shape[1], y_in_shape[2], y_in_shape[3]])\n",
"            losses = []\n",
"            total_loss = 0\n",
"            for i in range (self.num_resnet_layers):\n",
"                #print (y[:, i, :])\n",
"                curr_loss = tf.math.reduce_mean(loss(z[i+1], y))\n",
"                losses.append(curr_loss)\n",
"                total_loss = total_loss + curr_loss\n",
"            print ('losses: ', losses)\n",
"            final_loss = losses[-1]\n",
"\n",
"        # minimize() expects a flat list of variables: extend with each layer's weights\n",
"        # instead of appending one list of weights per layer.\n",
"        w = []\n",
"        for i in range (self.num_resnet_layers):\n",
"            for fc_ij in self.autoencoder[i].encoder_input+self.autoencoder[i].decoder_input:\n",
"                w.extend (fc_ij.trainable_weights)\n",
"        self.optimiser[0].minimize (final_loss, w, tape=g)\n",
"\n",
"        #for i in range (len(losses)):\n",
"        #   self.autoencoder.optimise_autoencoder (losses[i], g)\n",
"        \n",
"        \n",
"        return final_loss\n",
"                  \n",
"    def jacobian (self, z, x, g):\n",
"        return g.batch_jacobian(z, x) \n",
"\n",
"    def _infer_shape(self, x):\n",
"        x = tf.convert_to_tensor(x)\n",
"\n",
"        # If unknown rank, return dynamic shape\n",
"        if x.shape.dims is None:\n",
"            return tf.shape(x)\n",
"\n",
"        static_shape = x.shape.as_list()\n",
"        dynamic_shape = tf.shape(x)\n",
"\n",
"        ret = []\n",
"        for i in range(len(static_shape)):\n",
"            dim = static_shape[i]\n",
"            if dim is None:\n",
"                dim = dynamic_shape[i]\n",
"            ret.append(dim)\n",
"\n",
"        return ret\n",
"\n",
"    #@tf.function\n",
"    def call(self, x, t, training=False):\n",
"        \"\"\"Predict one frame per entry of t from the context frames x.\n",
"\n",
"        x (batch, frames, H, W as sliced in the training cell) is transposed so frames become\n",
"        channels, repeated len(t) times along the batch axis and concatenated with a constant\n",
"        channel holding t. Row k * len(t) + j of every output belongs to sample k at time t[j].\n",
"        Returns num_resnet_layers + 1 tensors: channel 0 of each residual state z_l.\n",
"        \"\"\"\n",
"        x_old_shape =  self._infer_shape(x)\n",
"        print ('x_old_shape: ', x_old_shape)\n",
"        #inferred_shape = x_old_shape\n",
"        #x = tf.reshape (x, [x_old_shape[0] * x.shape[1], x.shape[2], x.shape[3]])\n",
"        #x = tf.expand_dims (x, -1)\n",
"        x = tf.transpose (x, [0, 2, 3, 1])\n",
"\n",
"        x_reshaped = tf.repeat (x, t.shape[0], axis=0)\n",
"        t_reshaped = tf.tile(tf.reshape(t, [t.shape[0], 1, 1, 1]), [x_old_shape[0], x.shape[1], x.shape[2], 1])\n",
"        x = tf.concat((x_reshaped[:, :, :, :], t_reshaped), axis=3)\n",
"        #encoded_images = self.autoencoder.call_encoder (x, training=training)\n",
"    \n",
"\n",
"        #shape_encoded = self._infer_shape(encoded_images)\n",
"        #dim = tf.reduce_prod(shape_encoded[1:])\n",
"        #encoded_images = tf.reshape(encoded_images, [-1, dim])\n",
"        #encoded_images_shape = self._infer_shape(encoded_images)\n",
"        #encoded_images = tf.reshape (encoded_images, [x_old_shape[0], x_old_shape[1], encoded_images_shape[1]])\n",
"        #encoded_images_shape = self._infer_shape(encoded_images)\n",
"        #encoded_images = tf.reshape (encoded_images, [shape_encoded[0], dim])\n",
"        #print (encoded_images.shape)\n",
"        #encoded_images = self.autoencoder.call_dim_reducer (encoded_images, training=training)\n",
"\n",
"        #t_shape = self._infer_shape(t)\n",
"        #print ('t_shape: ', t_shape)\n",
"        #t_reshaped = tf.tile(tf.reshape(t, [t_shape[0], 1]), [x_old_shape[0], 1])\n",
"        #x_reshaped = tf.repeat (encoded_images, t_shape[0], axis=0)\n",
"        #x_reshaped = tf.reshape(x_reshaped, [x_reshaped.shape[0], x_reshaped.shape[1]*x_reshaped.shape[2]])\n",
"        #x = tf.concat([x_reshaped, t_reshaped], axis=1)\n",
"        \n",
"        print ('x.shape: ', x.shape)\n",
"        with tf.GradientTape(persistent=True) as g:\n",
"          g.watch (x)\n",
"          # Flatten once and rebuild the image tensor from that flat view, so every z_l and\n",
"          # phi_l below depends on x_flattened. A fresh reshape per layer is a different\n",
"          # tensor, and batch_jacobian(., x_flattened) would see it as unconnected for ell >= 1.\n",
"          x_image_shape = self._infer_shape(x)\n",
"          dim = tf.reduce_prod(tf.shape(x)[1:])\n",
"          x_flattened = tf.reshape(x, [-1, dim])\n",
"          x = tf.reshape(x_flattened, x_image_shape)\n",
"          z = [x]\n",
"          print (z)\n",
"          for ell in range (self.num_resnet_layers):\n",
"              print ('ell: ', ell)\n",
"              z_flattened = tf.reshape(z[-1], [-1, dim])\n",
"              print ('tf.shape(x): ', tf.shape(x))\n",
"              with g.stop_recording():\n",
"                  if not self.approx_jacobian:\n",
"                      # Per-sample Jacobian of the whole current state (not just its last row).\n",
"                      jacobian_z_x = self.jacobian(z_flattened, x_flattened, g)\n",
"                  else:\n",
"                      if ell == 0:\n",
"                          jacobian_z_x = self.jacobian(x_flattened, x_flattened, g) \n",
"                      else:\n",
"                          jacobian_z_x +=  self.jacobian(phi_curr_flattened, x_flattened, g) \n",
"                  #print (jacobian_z_x.shape)\n",
"              phi_curr = self.call_phi (ell, x, training=training)\n",
"              print ('phi_curr.shape: ', phi_curr.shape)\n",
"              phi_curr_flattened = tf.reshape(phi_curr, [-1, tf.reduce_prod(phi_curr.shape[1:])])\n",
"              delta = jacobian_z_x @  tf.expand_dims(phi_curr_flattened, -1)  \n",
"              delta = tf.reshape(delta, z[-1].shape)\n",
"              #print(delta.shape)\n",
"              z.append(z[-1] +  delta)\n",
"\n",
"        for ell in range (self.num_resnet_layers+1): \n",
"        #    print (ell)  \n",
"        #    z[ell] = z[ell][:, :-1]\n",
"        #    z[ell] = self.autoencoder.call_dim_reconstructor (z[ell], training=training)\n",
"        #    z_ell_shape = self._infer_shape(z[ell])\n",
"        #    z[ell] = tf.reshape (z[ell], [z_ell_shape[0], shape_encoded[1], shape_encoded[2], shape_encoded[3]])\n",
"        #    z[ell] = self.autoencoder.call_decoder (z[ell],training=training)\n",
"            z[ell] = z[ell] [:, :, :, 0]\n",
"        \n",
"        print ('call finished')\n",
"        return z"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lElxZyFOWS9D"},"source":[
"# Activate InImNet (True) or ResNet (False) during training/testing\n",
"inim_on_training = True\n",
"inim_on_testing = True\n",
"\n",
"# Plotting options\n",
"save_plots = True\n",
"view_plots = False\n",
"initial_data_plot = False"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"cNZ7_FZhWa5g"},"source":[
"import numpy as np\n",
"import math \n",
"# Logging initiation\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)  # DEBUG, INFO, WARNING, ERROR, or CRITICAL\n",
"logger.addHandler(logging.StreamHandler())\n"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Na8dcSp3Juvc"},"source":[
"import matplotlib.pyplot as plt\n",
"import h5py\n",
"\n",
"print ('Training data')\n",
"with h5py.File('data/bouncing_ball_data/training.hkl', 'r') as f:\n",
"    print (f.keys())\n",
"    trajectories_train = np.array(f['data_0'])\n",
"    trajectories_train = np.reshape (trajectories_train,\\\n",
"                                     [trajectories_train.shape[0], trajectories_train.shape[1], 32, 32])\n",
"    #trajectories_train = trajectories_train[:, :, ::2, ::2]\n",
"    print (trajectories_train.shape)\n",
"plt.imshow (trajectories_train[0, 0, :, :])\n",
"plt.show()\n",
"\n",
"print ('Testing data')\n",
"with h5py.File('data/bouncing_ball_data/test.hkl', 'r') as f:\n",
"    print (f.keys())\n",
"    trajectories_test = np.array(f['data_0'])\n",
"    trajectories_test = np.reshape (trajectories_test,\\\n",
"                                     [trajectories_test.shape[0], trajectories_test.shape[1], 32, 32])\n",
"    \n",
"    #trajectories_test = trajectories_test[:, :, ::2, ::2]\n",
"    print (trajectories_test.shape)\n",
"plt.imshow (trajectories_test[0, 0, :, :])\n",
"plt.show()\n",
"\n",
"print ('Validation data')\n",
"with h5py.File('data/bouncing_ball_data/val.hkl', 'r') as f:\n",
"    print (f.keys())\n",
"    trajectories_val = np.array(f['data_0'])\n",
"    trajectories_val = np.reshape (trajectories_val,\\\n",
"                                     [trajectories_val.shape[0], trajectories_val.shape[1], 32, 32])\n",
"    \n",
"    #trajectories_val = trajectories_val[:, :, ::2, ::2]\n",
"    print (trajectories_val.shape)\n",
"plt.imshow (trajectories_val[0, 0, :, :])\n",
"plt.show()\n",
"\n",
"#test_data = h5F.load('data/bouncing_ball_data/test.hkl')\n",
"#val_data = hkl.load('data/bouncing_ball_data/val.hkl')"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xMSCH9li4NCX"},"source":[
"import matplotlib.animation\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"plt.rcParams[\"animation.html\"] = \"jshtml\"\n",
"plt.rcParams['figure.dpi'] = 150  \n",
"plt.ioff()\n",
"\n",
"fig, ax = plt.subplots()\n",
"x= np.linspace(0,10,100)\n",
"def animate(t):\n",
"    plt.cla()\n",
"    img = np.array(trajectories_train[0, t, :, :])\n",
"    plt.imshow(img)\n",
"\n",
"matplotlib.animation.FuncAnimation(fig, animate, frames=trajectories_train.shape[1])\n",
"        \n",
"     #print (im['label'])"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"b0fJmx9YWnJV"},"source":[
"#num_internal_layers = 3\n",
"#batch_size = 1\n",
"#autoenc_lr = 0.0001\n",
"#losses = []\n",
"#t = np.linspace(0, 1, ntotal).astype(np.float32)\n",
"\n",
"#autoencoder =  InputAutoencoder(autoenc_lr = autoenc_lr )\n",
"\n",
"#for epoch in range(num_epochs):\n",
"#    indices = np.random.choice (trajectories_train.shape[0], size=batch_size)\n",
"#    images = trajectories_train[indices, :, :, :]    \n",
"#    images = tf.reshape (images, [images.shape[0] * images.shape[1], images.shape[2], images.shape[3]])\n",
"#    images = tf.expand_dims (images, -1)\n",
"#    loss = autoencoder.optimise_autoencoder (images, tf.keras.losses.MSE)\n",
"#    if epoch % 1000 == 0:\n",
"#        print (epoch)\n",
"#        print (loss)\n",
"#        pred_train = autoencoder.call_decoder(autoencoder.call_encoder (images))\n",
"#        print ('pred_train.shape: ', pred_train.shape)\n",
"#        fig, axes = plt.subplots(5, images.shape[0]//5)\n",
"#        print ('axes.shape[0]: ', axes.shape[0])\n",
"#        for i in  range(axes.shape[0]):\n",
"#            for j in  range(axes.shape[1]):\n",
"#                axes[i][j].get_xaxis().set_visible(False)\n",
"#                axes[i][j].get_yaxis().set_visible(False)\n",
"#                axes[i][j].imshow(images[i*axes.shape[0]+j, :, :, 0])\n",
"#        plt.show()\n",
"#        print ('animation ended')    \n",
"\n",
"#        print ('animation started:l1')\n",
"#        fig, axes = plt.subplots(5, pred_train.shape[0]//5)\n",
"#        for i in  range(axes.shape[0]):\n",
"#            for j in  range(axes.shape[1]):\n",
"#                axes[i][j].get_xaxis().set_visible(False)\n",
"#                axes[i][j].get_yaxis().set_visible(False)\n",
"#                axes[i][j].imshow(pred_train[i*axes.shape[0]+j, :, :, 0])\n",
"#        plt.show()\n",
"#        print ('animation ended: l1')\n"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"w1M-iwYThtbc"},"source":[
"#NUM_IMAGES_CONTEXT = 3\n",
"#indices = np.random.choice (trajectories_train.shape[0], size=batch_size)\n",
"#images = trajectories_train[indices, :, :, :]\n",
"#x = images [:, :NUM_IMAGES_CONTEXT, :, :]\n",
"#x_old_shape = x.shape\n",
"#x = tf.reshape (x, [x.shape[0] * x.shape[1], x.shape[2], x.shape[3]])\n",
"#x = tf.expand_dims (x, -1)\n",
"#encoded_images = autoencoder.call_encoder (x)\n",
"#shape_encoded = encoded_images.shape\n",
"#dim = tf.reduce_prod(tf.shape(encoded_images)[1:])\n",
"#encoded_images = tf.reshape(encoded_images, [-1, dim])\n",
"#encoded_images = tf.reshape (encoded_images, [x_old_shape[0], x_old_shape[1], encoded_images.shape[1]])\n",
"#encoded_images = tf.reshape (encoded_images, [encoded_images.shape[0], encoded_images.shape[1]*encoded_images.shape[2]])\n",
"#print (encoded_images.shape)\n",
"\n",
"#decoded_images = tf.reshape (encoded_images[:, :dim], [x_old_shape[0], shape_encoded[1], shape_encoded[2], shape_encoded[3]])\n",
"#decoded_images = autoencoder.call_decoder (decoded_images)\n",
"#plt.imshow(decoded_images[0, :, :, 0])\n"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"eU-g7sfQtd0z"},"source":[
"ell_max = 1\n",
"ntotal = 50\n",
"double_mlp_on = True\n",
"triple_mlp_on = True\n",
"bias_on = True\n",
"inflation_factor = 2\n",
"test_activation = tf.keras.activations.relu\n",
"num_internal_layers = 3\n",
"batch_size = 2\n",
"NUM_IMAGES_TOTAL = 20\n",
"NUM_IMAGES_EVAL = 13\n",
"NUM_IMAGES_CONTEXT = 3\n",
"DIM_LATENT = 16 * 64\n",
"INNER_DIM_LATENT=50\n",
"DROPOUT_VALUE=0.3\n",
"lr_inim = 0.001\n",
"lr_autoencoder = 0.0004\n",
"num_epochs = 500\n",
"t = np.linspace(0, 1, NUM_IMAGES_TOTAL).astype(np.float32)\n",
"t_eval = np.linspace(0, 1, NUM_IMAGES_TOTAL).astype(np.float32) [:NUM_IMAGES_EVAL]\n",
"print ('t: ', t)\n",
"print ('t_eval: ', t_eval)\n",
"\n",
"try:\n",
"    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
"    tf.config.experimental_connect_to_cluster(resolver)\n",
"    # This is the TPU initialization code that has to be at the beginning.\n",
"    tf.tpu.experimental.initialize_tpu_system(resolver)\n",
"    print(\"All devices: \", tf.config.list_logical_devices('TPU'))\n",
"    strategy = tf.distribute.TPUStrategy(resolver)\n",
"except Exception as tpu_error:\n",
"    # No usable TPU (e.g. a GPU/CPU runtime): run without a strategy, but report why\n",
"    # instead of a bare except that also swallowed KeyboardInterrupt and hid the cause.\n",
"    print ('TPU not available, running without a distribution strategy: ', tpu_error)\n",
"    strategy = None\n",
"autodiff_inimnet =  AutoDiffInImNet(dim=INNER_DIM_LATENT, represent_dim_out=DIM_LATENT, \n",
"                            num_resnet_layers=ell_max,\n",
"                            activation=test_activation,\n",
"                            num_internal_layers=num_internal_layers,\n",
"                            bias_on=bias_on,\n",
"                            mult=inflation_factor,\n",
"                            dropout=DROPOUT_VALUE,\n",
"                            lr_inim = tf.keras.optimizers.schedules.ExponentialDecay (lr_inim, 30 * trajectories_train.shape[0] // batch_size, 0.5, staircase=True),#tf.keras.optimizers.schedules.CosineDecay(lr_inim, num_epochs * trajectories_train.shape[0] // batch_size),\n",
"                            lr_autoencoder = lr_autoencoder,\n",
"                            weight_regularisation_alpha = 0.0#0.00001\n",
"                            )\n",
"losses = []\n",
"\n",
"#images = trajectories_train[indices, :NUM_IMAGES_TOTAL, :, :]\n",
"#x = images [:, :NUM_IMAGES_CONTEXT, :, :]\n",
"dataset = tf.data.Dataset.from_tensor_slices((trajectories_train))\n",
"dataset = dataset.shuffle(512).batch(batch_size)\n",
"iter_dataset = iter(dataset)\n",
"if strategy is not None:\n",
"    dataset = strategy.experimental_distribute_dataset(dataset)"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"zA-VWJCQvgP1"},"source":[
"dataset_test = tf.data.Dataset.from_tensor_slices((trajectories_test))\n",
"dataset_test = dataset_test.batch(batch_size)\n",
"iter_dataset_test = iter(dataset_test)\n",
"if strategy is not None:\n",
"  dataset_test = strategy.experimental_distribute_dataset(dataset_test)\n",
"\n",
"dataset_valid = tf.data.Dataset.from_tensor_slices((trajectories_val))\n",
"dataset_valid = dataset_valid.batch(batch_size)\n",
"iter_dataset_valid = iter(dataset_valid)\n",
"if strategy is not None:\n",
"  dataset_valid = strategy.experimental_distribute_dataset(dataset_valid)"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"YQtZmsRjrO83"},"source":[
"UPDATE_FREQ = 1\n",
"SHOW_PLOTS = True\n",
"loss_train_x = []\n",
"loss_train_y = []\n",
"loss_valid_x = []\n",
"loss_valid_y = []\n",
"loss_test_x = []\n",
"loss_test_y = []\n",
"losses_per_epoch = []\n",
"\n",
"@tf.function\n",
"def distribute_train_step(data):\n",
"    def replica_fn (d):\n",
"        d_im = d[:, :NUM_IMAGES_TOTAL, :, :]\n",
"        d_x = d[:, :NUM_IMAGES_CONTEXT, :, :]\n",
"        print ('d_x.shape: ', d_x.shape)\n",
"        print ('d_im.shape: ', d_im.shape)\n",
"        return autodiff_inimnet.optimise (d_x, d_im, t, tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE))\n",
"                      \n",
"    if strategy is not None:         \n",
"        per_replica_result = [strategy.run(replica_fn, args=(data,))]\n",
"        print ('per_replica_result: ', per_replica_result)\n",
"        return strategy.gather(per_replica_result, axis=0)\n",
"    else:\n",
"        results = replica_fn (data)\n",
"        return [results]\n",
"\n",
"@tf.function\n",
"def distribute_test_step (data):\n",
"    def replica_fn_test (d):\n",
"        d_im = d[:, :NUM_IMAGES_EVAL, :, :]\n",
"        d_x = d_im[:, :NUM_IMAGES_CONTEXT, :, :]\n",
"        print ('d_x.shape: ', d_x.shape)\n",
"        print ('d_im.shape: ', d_im.shape)\n",
"        return d_im, autodiff_inimnet (d_x, t_eval, training=False)\n",
"    if strategy is not None: \n",
"        return strategy.gather(strategy.run(replica_fn_test, args=(data,)), axis=0)\n",
"    else:\n",
"        return replica_fn_test(data)\n",
"              \n",
"@tf.function\n",
"def distribute_test_step_loss (data, async_exec=True):\n",
"    \"\"\"Squared error of frames NUM_IMAGES_CONTEXT:NUM_IMAGES_EVAL, summed over the batch.\n",
"\n",
"    Returns one tensor per entry of autodiff_inimnet's output list (z_0 ... z_L), each\n",
"    holding the per-frame squared error averaged over H and W.\n",
"    \"\"\"\n",
"    def replica_fn_test_loss (d):\n",
"        d_im = d[:, NUM_IMAGES_CONTEXT:NUM_IMAGES_EVAL, :, :]\n",
"        d_x = d[:, :NUM_IMAGES_CONTEXT, :, :]\n",
"        print ('d_x.shape: ', d_x.shape)\n",
"        print ('d_im.shape: ', d_im.shape)\n",
"        result = autodiff_inimnet (d_x, t_eval[NUM_IMAGES_CONTEXT:], training=False)\n",
"        #d_im = tf.reshape (d_im, [d_im.shape[0] * d_im.shape[1], d_im.shape[2], d_im.shape[3]])\n",
"        result = [tf.reshape (r, [ tf.shape(d_im)[0], d_im.shape[1], r.shape[1], r.shape[2]]) for r in result]\n",
"                        \n",
"        return [tf.math.reduce_mean(tf.square(r - d_im), axis=[2, 3]) for r in result]\n",
"    if strategy is not None: \n",
"        if async_exec:\n",
"            per_replica_result = strategy.run(replica_fn_test_loss, args=(data,))\n",
"            results = [strategy.gather(single_result, axis=0) for single_result in per_replica_result]\n",
"            return [tf.reduce_sum(r, axis=0) for r in results]\n",
"        else:\n",
"            results = replica_fn_test_loss (strategy.gather(data, axis=0))\n",
"            return [tf.reduce_sum(r, axis=0) for r in results]\n",
"    else:\n",
"        results = replica_fn_test_loss (data)\n",
"        return [tf.reduce_sum(r, axis=0) for r in results]\n",
"\n",
"def _sum_eval_loss (eval_dataset):\n",
"    \"\"\"Sum distribute_test_step_loss over every batch of eval_dataset.\n",
"\n",
"    Returns (summed_loss, last_batch). summed_loss is a numpy array indexed [layer, frame]\n",
"    (None for an empty dataset). last_batch is the last batch whose loss had the same rank\n",
"    as the first batch's, counting the first batch itself (None for an empty dataset).\n",
"    A batch whose loss rank differs is recomputed with async_exec=False before being added.\n",
"    \"\"\"\n",
"    summed_loss = None\n",
"    last_batch = None\n",
"    for batch in eval_dataset:\n",
"        loss_value = np.array(distribute_test_step_loss (batch))\n",
"        if summed_loss is None:\n",
"            summed_loss = loss_value\n",
"            last_batch = batch\n",
"        elif len(loss_value.shape) != len(summed_loss.shape):\n",
"            summed_loss += np.array(distribute_test_step_loss (batch, False))\n",
"        else:\n",
"            summed_loss += loss_value\n",
"            last_batch = batch\n",
"    return summed_loss, last_batch\n",
"\n",
"for epoch in range(num_epochs):\n",
"    start_time = time.time()\n",
"    running_losses = []\n",
"    for x in dataset:\n",
"        curr_loss = distribute_train_step(x)\n",
"        losses.extend(curr_loss)\n",
"        running_losses.extend(curr_loss)\n",
"        x_last = x\n",
"    print ('training finished')\n",
"    losses_per_epoch.append(np.array(tf.reduce_mean(running_losses)))\n",
"\n",
"    if epoch % UPDATE_FREQ == 0:\n",
"        print (epoch)\n",
"        curr_loss_train, _ = _sum_eval_loss (dataset)\n",
"        curr_loss_train /= trajectories_train.shape[0]\n",
"        # The first test batch already counts as a plot candidate, so x_last_test is bound\n",
"        # even when the test set fits in a single batch.\n",
"        curr_loss_test, x_last_test = _sum_eval_loss (dataset_test)\n",
"        curr_loss_test /= trajectories_test.shape[0]\n",
"        curr_loss_valid, _ = _sum_eval_loss (dataset_valid)\n",
"        curr_loss_valid /= trajectories_val.shape[0]\n",
"\n",
"        loss_train_x.append(epoch)\n",
"        loss_train_y.append(curr_loss_train)\n",
"        loss_test_x.append(epoch)\n",
"        loss_test_y.append(curr_loss_test)\n",
"        loss_valid_x.append(epoch)\n",
"        loss_valid_y.append(curr_loss_valid)\n",
"\n",
"        end_time = time.time()\n",
"        print('Time per epoch: ', end_time - start_time)\n",
"        if not SHOW_PLOTS:\n",
"            continue\n",
"        for layer_id in range (1, curr_loss_train.shape[0]):\n",
"            plt.plot(loss_train_x, np.mean(np.array(loss_train_y)[:, layer_id, :], axis=1), label='train_loss (l' + str(layer_id)+')')\n",
"            plt.plot(loss_test_x, np.mean(np.array(loss_test_y)[:, layer_id, :], axis=1), label='test_loss (l' + str(layer_id)+')')        \n",
"            plt.plot(loss_valid_x, np.mean(np.array(loss_valid_y)[:, layer_id, :], axis=1), label='valid_loss(l' + str(layer_id)+')')\n",
"            plt.legend()\n",
"\n",
"        plt.show()\n",
"\n",
"        for layer_id in range (1, curr_loss_train.shape[0]):\n",
"            plt.plot(np.array(loss_train_y)[-1, layer_id, :], label='train_loss (l' + str(layer_id)+', '+str(np.mean(np.array(loss_train_y)[-1, layer_id, :]))+')')\n",
"            plt.plot(np.array(loss_test_y)[-1, layer_id, :], label='test_loss (l' + str(layer_id)+', '+str(np.mean(np.array(loss_test_y)[-1, layer_id, :]))+')')        \n",
"            plt.plot(np.array(loss_valid_y)[-1, layer_id, :], label='valid_loss(l' + str(layer_id)+', '+str(np.mean(np.array(loss_valid_y)[-1, layer_id, :]))+')')\n",
"            plt.legend() \n",
"\n",
"        plt.show()\n",
"        print ('Loss (frame_id), training:')\n",
"        for layer_id in range (1, curr_loss_train.shape[0]):\n",
"            print ('layer_id: ', layer_id)\n",
"            print  (np.array(loss_train_y)[-1, layer_id, :])\n",
"        print ('Loss (frame_id), validation:')\n",
"        for layer_id in range (1, curr_loss_train.shape[0]):\n",
"            print ('layer_id: ', layer_id)\n",
"            print  (np.array(loss_valid_y)[-1, layer_id, :])\n",
"        print ('Loss (frame_id), testing:')\n",
"        for layer_id in range (1, curr_loss_train.shape[0]):\n",
"            print ('layer_id: ', layer_id)\n",
"            print  (np.array(loss_test_y)[-1, layer_id, :])\n",
"\n",
"        plt.plot (losses_per_epoch)\n",
"        plt.show()\n",
"\n",
"        \n",
"        d_im, pred_train  =  distribute_test_step (x_last)\n",
"        print ('Data sample (training)')\n",
"        print ('GT:')\n",
"        fig, axes = plt.subplots(1, d_im.shape[1])\n",
"        \n",
"        #print ('axes.shape[0]: ', axes.shape[0])\n",
"        for i in  range(1):\n",
"            for j in  range(axes.shape[0]):\n",
"                axes[j].get_xaxis().set_visible(False)\n",
"                axes[j].get_yaxis().set_visible(False)\n",
"                axes[j].imshow(d_im[0, i*axes.shape[0]+j, :, :])\n",
"        plt.show()\n",
"\n",
"        print ('Prediction: ')\n",
"        for layer_id in range(len(pred_train)): \n",
"            print ('Layer ' + (str(layer_id)))\n",
"            fig, axes = plt.subplots(1, d_im.shape[1])\n",
"            for i in  range(1):\n",
"                for j in  range(axes.shape[0]):\n",
"                    axes[j].get_xaxis().set_visible(False)\n",
"                    axes[j].get_yaxis().set_visible(False)\n",
"                    axes[j].imshow(pred_train[layer_id][i*axes.shape[0]+j, :, :])\n",
"            plt.show()\n",
"\n",
"        d_im_test, pred_test  =  distribute_test_step (x_last_test)\n",
"        print ('Data sample (testing)')\n",
"        print ('GT:')\n",
"        fig, axes = plt.subplots(1, d_im_test.shape[1])\n",
"        print ('axes.shape[0]: ', axes.shape[0])\n",
"        for i in  range(1):\n",
"            for j in  range(axes.shape[0]):\n",
"                axes[j].get_xaxis().set_visible(False)\n",
"                axes[j].get_yaxis().set_visible(False)\n",
"                axes[j].imshow(d_im_test[0, i*axes.shape[0]+j, :, :])\n",
"        plt.show()\n",
"\n",
"        print ('Prediction: ')\n",
"        for layer_id in range(len(pred_test)): \n",
"            print ('Layer ' + (str(layer_id)))\n",
"            fig, axes = plt.subplots(1, d_im_test.shape[1])\n",
"            for i in  range(1):\n",
"                for j in  range(axes.shape[0]):\n",
"                    axes[j].get_xaxis().set_visible(False)\n",
"                    axes[j].get_yaxis().set_visible(False)\n",
"                    axes[j].imshow(pred_test[layer_id][i*axes.shape[0]+j, :, :])\n",
"            plt.show()\n",
"      \n",
"  \n",
"plt.plot(np.array(losses))"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"0YWhEPcJzLe4"},"source":[
"plt.figure(figsize=(10, 10))\n",
"# Without a strategy every entry of losses is a scalar, so np.array(losses) is 1-D and\n",
"# cumsum_value.shape[1] below would fail; keep an explicit (steps, columns) layout.\n",
"cumsum_value = np.reshape(np.array(losses), (len(losses), -1))#np.cumsum (np.array(losses), axis = 1)\n",
"interp = 10\n",
"cumsum_value = cumsum_value[:cumsum_value.shape[0]-cumsum_value.shape[0]%interp]\n",
"x_cumsum = np.array (range (cumsum_value.shape[0]))\n",
"x_downsampled = np.array([np.min (arr, axis=0) for arr in np.split (x_cumsum, interp, axis=0)])\n",
"cumsum_value = np.array([np.mean (arr, axis=0) for arr in np.split (cumsum_value, interp, axis=0)])\n",
"for i in range (cumsum_value.shape[1]):\n",
"    plt.plot(x_downsampled, cumsum_value[:, i], label = 'Layer' + str(i+1))\n",
"\n",
"plt.legend()\n",
"plt.show()"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"C-VNHebj3E1S"},"source":[
"x_test = next(iter_dataset_test)\n",
"\n",
"\n",
"# Both branches yield one prediction tensor per residual state; plot the last state.\n",
"if strategy is not None:\n",
"    d_im, pred_layers  =  distribute_test_step (x_test)\n",
"else:\n",
"    d_im = x_test[:, :NUM_IMAGES_TOTAL, :, :]\n",
"    d_x = x_test[:, :NUM_IMAGES_CONTEXT, :, :]\n",
"    pred_layers = autodiff_inimnet (d_x, t)\n",
"# The first rows of pred_test are the predicted frames of the first trajectory in the batch.\n",
"pred_test = pred_layers[-1]\n",
"\n",
"print ('pred_test.shape: ', pred_test.shape)\n",
"print ('animation started')\n",
"fig, axes = plt.subplots(5, d_im.shape[1]//5)\n",
"print ('axes.shape[0]: ', axes.shape[0])\n",
"for i in  range(axes.shape[0]):\n",
"    for j in  range(axes.shape[1]):\n",
"        axes[i][j].get_xaxis().set_visible(False)\n",
"        axes[i][j].get_yaxis().set_visible(False)\n",
"        axes[i][j].imshow(d_im[0, i*axes.shape[1]+j, :, :])\n",
"plt.show()\n",
"print ('animation ended')\n",
"\n",
"print ('animation started:l1')\n",
"fig, axes = plt.subplots(5, d_im.shape[1]//5)\n",
"for i in  range(axes.shape[0]):\n",
"    for j in  range(axes.shape[1]):\n",
"        axes[i][j].get_xaxis().set_visible(False)\n",
"        axes[i][j].get_yaxis().set_visible(False)\n",
"        axes[i][j].imshow(pred_test[i*axes.shape[1]+j, :, :])\n",
"plt.show()\n",
"print ('animation ended: l1')\n",
"\n"
],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"OUow-wEc1m88"},"source":[
"print (running_losses)"
],"execution_count":null,"outputs":[]}]}