{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"_Discrete_inimnet_rotating_mnist_coupled_layers.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"Mk98tCrKfOfD"},"source":["from google.colab import drive\n","drive.mount('/gdrive')\n","%cd /gdrive/My Drive/InImNet/inim_code/"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"byftFOmFetBP"},"source":["import numpy as np\n","import os\n","import tensorflow as tf\n","#tf.get_logger().setLevel('ERROR')\n","import argparse\n","import matplotlib as mpl\n","import matplotlib.pyplot as plt\n","import random\n","\n","GPU_ID=0\n","# ELL_MAX is passed as num_resnet_layers (layers whose loss is computed in training),\n","# ELL_MAX_EVAL as num_eval_resnet_layers (residual steps evaluated in the forward pass).\n","ELL_MAX = 4\n","ELL_MAX_EVAL = 8\n","NTOTAL = 50\n","INFLATION_FACTOR = 2\n","ACTIVATION = tf.keras.activations.relu\n","NUM_INTERNAL_LAYERS = 3\n","BATCH_SIZE = 25\n","# I_EVAL is passed to the model as i_mask, N_MASK as n_mask.\n","I_EVAL = 4\n","N_MASK = 3\n","NUM_IMAGES_TOTAL = 16 \n","NUM_IMAGES_CONTEXT = 1\n","DIM_LATENT = 16 *  64\n","INNER_DIM_LATENT=20\n","DROPOUT_VALUE=0.3\n","LR_INIM = 0.001\n","NUM_EPOCHS = 500\n","REPORT_INTERVAL_EPOCHS = 5\n","t = np.linspace(0, 1, NUM_IMAGES_TOTAL).astype(np.float32)\n","t_eval = np.linspace(0, 1, NUM_IMAGES_TOTAL).astype(np.float32)\n","print ('t: ', t)\n","print ('t_eval: ', t_eval)\n","\n","\n","mpl.rcParams['figure.figsize'] = (8, 6)\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"YcIA8ZpMe4Ol"},"source":["class InputAutoencoder (tf.keras.Model):\n","    \"\"\"Convolutional encoder/decoder pair with dense bottleneck layers.\n","\n","    The encoder stacks three stride-2 5x5 convolutions (16, 32, 64 filters).\n","    dim_reducer / dim_reconstructor are dense maps to represent_dim_in /\n","    represent_dim_out units. The decoder stacks three stride-2 transposed\n","    convolutions (128, 64, 32 filters) followed by a single-channel sigmoid\n","    transposed convolution.\n","    \"\"\"\n","    def __init__(self, represent_dim_in=50, represent_dim_out=50):\n","        super(InputAutoencoder, self).__init__()\n","        \n","        self.encoder_input = []\n","        self.decoder_input = []\n","\n","\n","        self.encoder_input.append(tf.keras.layers.Conv2D(16, (5, 5), strides=2, padding='same'))\n","        self.encoder_input.append(tf.keras.layers.BatchNormalization())\n","        
self.encoder_input.append(tf.keras.layers.ReLU())\n","        self.encoder_input.append(tf.keras.layers.Conv2D(32, (5, 5), strides=2, padding='same'))\n","        self.encoder_input.append(tf.keras.layers.BatchNormalization())\n","        self.encoder_input.append(tf.keras.layers.ReLU())\n","        self.encoder_input.append(tf.keras.layers.Conv2D(64, (5, 5), strides=2, padding='same'))\n","        self.encoder_input.append(tf.keras.layers.ReLU())\n","        #self.encoder_input.append(tf.keras.layers.BatchNormalization())\n","        #self.encoder_input.append(tf.keras.layers.Conv2D(4, (5, 5), strides=1, activation='relu', padding='same'))\n","        #self.encoder_input.append(tf.keras.layers.Flatten())\n","\n","\n","        self.dim_reducer = tf.keras.layers.Dense (represent_dim_in)\n","        self.dim_reconstructor = tf.keras.layers.Dense (represent_dim_out)\n","        \n","        #self.max_pool_output_shape = [6, 6, 8]\n","        #self.flattened_max_pool_output_shape = 288\n","\n","        #self.decoder_input.append(tf.keras.layers.Dense(self.flattened_max_pool_output_shape, activation='relu')) \n","        #self.decoder_input.append(tf.keras.layers.Reshape(self.max_pool_output_shape))\n","        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=2, padding='same'))\n","        self.decoder_input.append(tf.keras.layers.BatchNormalization())\n","        self.decoder_input.append(tf.keras.layers.ReLU())\n","        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=2, padding='same'))\n","        self.decoder_input.append(tf.keras.layers.BatchNormalization())\n","        self.decoder_input.append(tf.keras.layers.ReLU())\n","        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=2, padding='same'))\n","        self.decoder_input.append(tf.keras.layers.BatchNormalization())\n","        self.decoder_input.append(tf.keras.layers.ReLU())\n","        
self.decoder_input.append(tf.keras.layers.Conv2DTranspose(1, (5, 5), padding='same', activation='sigmoid'))\n","\n","\n","    def call_encoder(self, x, training):\n","        z = x               \n","        for i in range (len(self.encoder_input)):       \n","            z = self.encoder_input[i](z, training=training)   \n","            #print ('encoder['+str(i)+'].shape: ', z.shape)\n","        return z\n","\n","    def call_dim_reducer(self, x, training):\n","        return self.dim_reducer (x, training=training) \n","\n","    def call_dim_reconstructor(self, x, training):\n","        return self.dim_reconstructor (x, training=training) \n","\n","    def call_decoder (self, x, training):\n","        z = x               \n","        for i in range (len(self.decoder_input)):       \n","            z = self.decoder_input[i](z, training=training)   \n","            #print ('decoder['+str(i)+'].shape: ', z.shape)\n","        # Drop a 2-pixel border on each spatial side of the decoder output\n","        # (presumably 32x32 -> 28x28 for the MNIST frames; TODO confirm).\n","        z = z[:, 2:-2, 2:-2, :]\n","        return z\n","    \n","\n","\n","class AutoDiffInImNet(tf.keras.Model):\n","    \"\"\"Autoencoder-wrapped residual network over (latent, time) vectors.\n","\n","    call() encodes the input frames, appends a time value to each latent\n","    vector, applies num_eval_resnet_layers Jacobian-weighted residual updates\n","    and decodes every intermediate state back to an image.\n","    \"\"\"\n","\n","    def __init__(self, dim: int, represent_dim_out: int,\n","                 num_resnet_layers: int,\n","                 activation,\n","                 num_internal_layers: int,\n","                 num_eval_resnet_layers: int,\n","                 bias_on: bool = True,\n","                 mult: int = 1,               \n","                 use_batch_norm: bool = False,\n","                 dropout=0.1,\n","                 lr_inim = 0.001,\n","                 approx_jacobian = True,\n","                 weight_regularisation_alpha = 0.0, \n","                 n_mask = 3,\n","                 i_mask = 4):\n","        super(AutoDiffInImNet, self).__init__()\n","        self.num_internal_layers = num_internal_layers\n","        self.num_resnet_layers = num_resnet_layers\n","        self.num_eval_resnet_layers = num_eval_resnet_layers\n","        self.activation = activation\n"," 
       self.dim = dim\n","        self.n_mask = n_mask\n","        self.i_mask = i_mask\n","        self.lr_inim = lr_inim#tf.keras.optimizers.schedules.PiecewiseConstantDecay (boundaries=[128], values=[lr_inim/50, lr_inim])#tf.keras.optimizers.schedules.ExponentialDecay(lr_inim, decay_steps=15625, decay_rate=0.5, staircase=True)\n","        self.represent_dim_out = represent_dim_out\n","        self.weight_regularisation_alpha = weight_regularisation_alpha\n","        self.t_reshaped_mask = None\n","        self.row_val = None\n","\n","        # One list of layers per residual block: (Dense, Dropout) pairs followed by\n","        # a Dense layer with dim+1 outputs.\n","        self.fc_network = []\n","        for i in range(num_resnet_layers):\n","            self.fc_network.append([])\n","            for j in range (num_internal_layers-1):\n","                self.fc_network[i].append(tf.keras.layers.Dense(dim*mult, activation=activation))\n","                #self.fc_network[i].append(tf.keras.layers.BatchNormalization())\n","                #self.fc_network[i].append(tf.keras.layers.Activation(activation))\n","                self.fc_network[i].append(tf.keras.layers.Dropout(dropout))\n","                \n","            self.fc_network[i].append(tf.keras.layers.Dense(dim+1, activation=activation))\n","        self.optimiser = []\n","        for i in range (self.num_resnet_layers):\n","            self.optimiser.append(tf.keras.optimizers.Adam(learning_rate=self.lr_inim))\n","        self.approx_jacobian = approx_jacobian\n","        self.autoencoder =  InputAutoencoder (represent_dim_in=dim, represent_dim_out=represent_dim_out)\n","\n","        \n","    def call_phi (self, ell, x, training):\n","        \"\"\"Pass x through the layers of one fc_network block and return the result.\n","\n","        NOTE: ell is overwritten with 0 below, so fc_network[0] is applied for every\n","        requested layer index; the other fc_network entries are built in __init__\n","        but are not evaluated by this method.\n","        \"\"\"\n","        z = x        \n","        ell = 0       \n","        for i in range (len(self.fc_network[ell])):       \n","            z = self.fc_network[ell][i](z, training=training)    \n","        return z\n","\n","    #@tf.function\n","    def optimise (self, x, y_in, t, loss): \n","        \"\"\"Run one training step and return the loss that was minimised.\n","\n","        Runs a masked forward pass on x, compares the decoded states\n","        z[1..num_resnet_layers] with the masked frames of y_in using loss, and\n","        minimises the last of these losses plus weight_regularisation_alpha times\n","        an L2 penalty on the fc_network weights, using optimiser[0] over the\n","        fc_network and autoencoder weights.\n","        \"\"\"\n","        self.output_dim = y_in.shape[2]\n","        #print ('optimise')\n","\n","        with 
tf.GradientTape(persistent=True) as g: \n","            z, out_mask = self (x, t, training=True, use_mask = True)\n","            y_in_shape = self._infer_shape (y_in)\n","            # flatten (sample, frame) into one axis so the targets line up with the row mask\n","            y = tf.reshape (y_in, [y_in_shape[0] * y_in_shape[1], y_in_shape[2], y_in_shape[3]])\n","            print ('y.shape: ', y.shape)\n","            print ('z[0].shape: ', z[0].shape)\n","            print ('out_mask.shape: ', out_mask.shape)\n","            y = tf.boolean_mask(y, out_mask)\n","            losses = []\n","            # NOTE(review): total_loss is accumulated but never read afterwards\n","            total_loss = 0\n","            for i in range (self.num_resnet_layers):\n","                #print (y[:, i, :])\n","                curr_loss = tf.math.reduce_mean(loss(z[i+1], y))\n","                losses.append(curr_loss)\n","                total_loss = total_loss + curr_loss\n","            print ('losses: ', losses)\n","            #for i in range (self.num_resnet_layers):\n","            #    for fc_ij in self.fc_network[i]:\n","            #        print ('fc_ij.trainable_weights: ', fc_ij.trainable_weights)\n","\n","            #   print ('[fc_ij.trainable_weights for fc_ij in self.encoder[i]+[self.decoder[i]]]: ', [fc_ij.trainable_weights for fc_ij in self.fc_network[i]])\n","            #    #if i < self.num_resnet_layers-1:\n","            #    self.optimiser[i].minimize (losses[i],[fc_ij.trainable_weights for fc_ij in self.fc_network[i]], tape=g)\n","                #else:\n","                #self.optimiser[i].minimize (losses[i],[fc_ij.trainable_weights for fc_ij in self.fc_network[i]+self.autoencoder.encoder_input+self.autoencoder.decoder_input+[self.autoencoder.dim_reducer]+[self.autoencoder.dim_reconstructor]], tape=g)\n","            #    print ('Minimisation finished...')\n","\n","            # L2 penalty over all fc_network weights (computed inside the tape so it is differentiable)\n","            norm_w = []\n","            for i in range (self.num_resnet_layers):\n","                for fc_ij in self.fc_network[i]:\n","                    norm_w.extend (fc_ij.trainable_weights)\n","            print ('norm_w: ', norm_w)\n","            
norms = [0.5 * tf.reduce_sum(tf.square(fc_ij_weights)) for fc_ij_weights in norm_w]\n","            print ('norms: ', norms)\n","            w_regularisation = tf.reduce_sum(norms)\n","            print ('w_regularisation: ', w_regularisation)\n","            # only the loss of the last trained layer enters the objective\n","            final_loss = losses[-1]+self.weight_regularisation_alpha * w_regularisation\n","\n","        # w is a list of per-layer weight lists (fc_network blocks, then autoencoder layers)\n","        w = []\n","        for i in range (self.num_resnet_layers):\n","            w.extend ([fc_ij.trainable_weights for fc_ij in self.fc_network[i]])  \n","        w.extend ([fc_ij.trainable_weights for fc_ij in self.autoencoder.encoder_input+self.autoencoder.decoder_input+[self.autoencoder.dim_reducer]+[self.autoencoder.dim_reconstructor]])\n","        self.optimiser[0].minimize (final_loss, w, tape=g)\n","\n","        #for i in range (len(losses)):\n","        #   self.autoencoder.optimise_autoencoder (losses[i], g)\n","        \n","        \n","        return final_loss\n","                  \n","    def jacobian (self, z, x, g):\n","        \"\"\"Return the per-example Jacobian of z with respect to x recorded on tape g.\"\"\"\n","        return g.batch_jacobian(z, x) \n","\n","    def _infer_shape(self, x):\n","        \"\"\"Return the shape of x as a list, using static dimensions where they are known\n","        and tf.shape entries otherwise (tf.shape(x) itself if the rank is unknown).\"\"\"\n","        x = tf.convert_to_tensor(x)\n","\n","        # If unknown rank, return dynamic shape\n","        if x.shape.dims is None:\n","            return tf.shape(x)\n","\n","        static_shape = x.shape.as_list()\n","        dynamic_shape = tf.shape(x)\n","\n","        ret = []\n","        for i in range(len(static_shape)):\n","            dim = static_shape[i]\n","            if dim is None:\n","                dim = dynamic_shape[i]\n","            ret.append(dim)\n","\n","        return ret\n","\n","    #@tf.function\n","    def call(self, x, t, training=False, use_mask = False):\n","        \"\"\"Encode x, pair it with every time in t and return the decoded states of all layers.\n","\n","        Returns a list with num_eval_resnet_layers+1 image batches; when use_mask is\n","        True the boolean row mask that was applied is returned as a second value.\n","        \"\"\"\n","        x_old_shape =  self._infer_shape(x)\n","\n","        #inferred_shape = x_old_shape\n","        #x = tf.reshape (x, [inferred_shape[0] * x.shape[1], x.shape[2], x.shape[3]])\n","        #x = tf.expand_dims (x, -1)\n","        # move the frame axis last so the encoder sees it as the channel axis\n","        x = tf.transpose (x, [0, 2, 3, 1])\n","        encoded_images = self.autoencoder.call_encoder 
(x, training=training)\n","    \n","        shape_encoded = self._infer_shape(encoded_images)\n","        dim = tf.reduce_prod(shape_encoded[1:])\n","        #encoded_images = tf.reshape(encoded_images, [-1, dim])\n","        #encoded_images_shape = self._infer_shape(encoded_images)\n","        #encoded_images = tf.reshape (encoded_images, [x_old_shape[0], x_old_shape[1], encoded_images_shape[1]])\n","        #encoded_images_shape = self._infer_shape(encoded_images)\n","        encoded_images = tf.reshape (encoded_images, [shape_encoded[0], dim])\n","        #print (encoded_images.shape)\n","        encoded_images = self.autoencoder.call_dim_reducer (encoded_images, training=training)\n","\n","        t_shape = self._infer_shape(t)\n","        print ('t_shape: ', t_shape)\n","  \n","        # one copy of the time grid per sample\n","        t_reshaped = tf.tile(tf.expand_dims(t, axis=0), [x_old_shape[0], 1])  \n","        \n","        if use_mask:\n","            #if self.t_reshaped_mask is None:\n","            #    self.t_reshaped_mask = tf.Variable(tf.zeros_like (t_reshaped))\n","            #else:\n","            #    self.t_reshaped_mask.assign(tf.zeros_like (t_reshaped))\n","            #t_reshaped_mask = self.t_reshaped_mask\n","            # Build one visibility row per sample: 1 = frame is kept, 0 = frame is hidden.\n","            # Rows are plain tensors (not a shared tf.Variable) so each sample keeps its own mask.\n","            t_reshaped_mask = []\n","            for ind_val in range(x_old_shape[0]): \n","                # candidate indices are all frames except i_mask, in random order\n","                i_val = tf.range(t_shape[0]-1)\n","                i_val = tf.where(i_val >=self.i_mask, i_val+1, i_val)\n","                i_val = tf.random.shuffle(i_val)\n","                ind_rm=i_val[:self.n_mask]\n","                # frame i_mask is always hidden ...\n","                row_val = 1-tf.one_hot([self.i_mask], t_shape[0])\n","                # ... and so are the n_mask randomly drawn frames: subtracting their one-hot\n","                # vectors sets exactly those entries to 0 (the indices are distinct)\n","                for v in range(self.n_mask):\n","         
           row_val = row_val - tf.one_hot([ind_rm[v]], t_shape[0])\n","                \n","                t_reshaped_mask.append (row_val)\n","            t_reshaped_mask = tf.stack (t_reshaped_mask)\n","            t_reshaped_mask = tf.reshape(t_reshaped_mask, [t_shape[0]*x_old_shape[0]])\n","            t_reshaped = tf.reshape(t_reshaped, [t_shape[0]*x_old_shape[0], 1])\n","            t_reshaped = t_reshaped[t_reshaped_mask > 0]\n","            x_reshaped = tf.repeat (encoded_images, t_shape[0], axis=0)\n","            x_reshaped = tf.boolean_mask(x_reshaped, t_reshaped_mask > 0)\n","        else:\n","            t_reshaped = tf.reshape(t_reshaped, [t_shape[0]*x_old_shape[0], 1])\n","            x_reshaped = tf.repeat (encoded_images, t_shape[0], axis=0)\n","        #x_reshaped = tf.reshape(x_reshaped, [x_reshaped.shape[0], x_reshaped.shape[1]*x_reshaped.shape[2]])\n","        # each row is (latent code, time)\n","        x = tf.concat([x_reshaped, t_reshaped], axis=1)\n","        \n","        z = [x]\n","        print ('x.shape: ', x.shape)\n","        print (z)\n","        with tf.GradientTape(persistent=True) as g:\n","          g.watch (x)\n","          for ell in range (self.num_eval_resnet_layers):\n","              print ('ell: ', ell)\n","              #dim = tf.reduce_prod(tf.shape(x)[1:])\n","              #x_flattened = tf.reshape(x, [-1, dim])\n","              #z_flattened = tf.reshape(z[-1], [-1, dim])\n","              # the Jacobian itself is computed with recording paused\n","              with g.stop_recording():\n","                  if not self.approx_jacobian:\n","                      jacobian_z_x = self.jacobian(z[-1], x, g) \n","                  else:\n","                      # accumulate the Jacobian from the previous step's phi instead of\n","                      # differentiating the current state\n","                      if ell == 0:\n","                          jacobian_z_x = self.jacobian(x, x, g) \n","                      else:\n","                          jacobian_z_x +=  self.jacobian(phi_curr, x, g) \n","                  #print (jacobian_z_x.shape)\n","              phi_curr = self.call_phi (ell, x, training=training)\n","              delta = jacobian_z_x @  
tf.expand_dims(phi_curr, -1)\n","              delta = tf.squeeze(delta, -1)\n","              #print(delta.shape)\n","              z.append(z[-1] +  delta)\n","\n","        # drop the time column, map back to the encoder output shape and decode every state\n","        for ell in range (self.num_eval_resnet_layers+1): \n","            print (ell)  \n","            z[ell] = z[ell][:, :-1]\n","            z[ell] = self.autoencoder.call_dim_reconstructor (z[ell], training=training)\n","            z_ell_shape = self._infer_shape(z[ell])\n","            z[ell] = tf.reshape (z[ell], [z_ell_shape[0], shape_encoded[1], shape_encoded[2], shape_encoded[3]])\n","            z[ell] = self.autoencoder.call_decoder (z[ell],training=training)\n","            z[ell] = z[ell] [:, :, :, 0]\n","        \n","        print ('call finished')\n","        if use_mask:\n","            return z, t_reshaped_mask > 0\n","        else: \n","            return z"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"rwPYE61td9xp"},"source":["import matplotlib.pyplot as plt\n","from scipy.io import loadmat\n","import os\n","import shutil\n","def load_data(file_name = 'data/rotated_mnist/rot-mnist-3s.mat', N=500, T=16):\n","    \"\"\"Load the trajectories stored under 'X' in file_name and split them.\n","\n","    The first N trajectories form the training set and the remaining ones the\n","    test set. Both are returned as float32 arrays reshaped to\n","    (num_trajectories, T, 28, 28).\n","    \"\"\"\n","    X = loadmat(file_name)['X'].squeeze() # (N, 16, 784)\n","    print (X.shape)\n","    X_train = X[:N, :, :].astype(np.float32)\n","    X_test = X[N:, :, :].astype(np.float32)\n","    # Reshape with the actual row counts: a file holding N or fewer trajectories\n","    # (including an empty test split) still loads, while a wrong T still raises.\n","    X_train   = X_train.reshape([X_train.shape[0],T,28,28])\n","    X_test = X_test.reshape([X_test.shape[0],T,28,28])\n","    return X_train, X_test\n","\n","\n","def do_training (task_id):\n","    LOG_DIR = 'logs_rotating_mnist'+str(task_id)    \n","    if os.path.isdir(LOG_DIR):\n","       shutil.rmtree(LOG_DIR, ignore_errors=True)\n","    # rmtree may leave the directory behind (ignore_errors=True), so tolerate an existing one\n","    os.makedirs (LOG_DIR, exist_ok=True)\n","    trajectories_train, trajectories_test = load_data ()\n","\n","    try:\n","        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n","        tf.config.experimental_connect_to_cluster(resolver)\n","        # This is the TPU initialization code that has to be at the beginning.\n","        
tf.tpu.experimental.initialize_tpu_system(resolver)\n","        print(\"All devices: \", tf.config.list_logical_devices('TPU'))\n","        strategy = tf.distribute.TPUStrategy(resolver)\n","    except Exception as tpu_error:\n","        # No usable TPU runtime: report why and fall back to single-device execution.\n","        # KeyboardInterrupt / SystemExit are deliberately not caught here.\n","        print ('TPU initialisation failed, running without a distribution strategy: ', tpu_error)\n","        strategy = None\n","    # NOTE(review): the model is built outside strategy.scope(); confirm this is intended for the TPU path.\n","    autodiff_inimnet =  AutoDiffInImNet(dim=INNER_DIM_LATENT, represent_dim_out=DIM_LATENT, \n","                num_resnet_layers=ELL_MAX,\n","                num_eval_resnet_layers=ELL_MAX_EVAL,\n","                activation=ACTIVATION,\n","                num_internal_layers=NUM_INTERNAL_LAYERS,\n","                bias_on=True,\n","                mult=INFLATION_FACTOR,\n","                dropout=DROPOUT_VALUE,\n","                lr_inim = tf.keras.optimizers.schedules.ExponentialDecay (LR_INIM, 30 * trajectories_train.shape[0] // BATCH_SIZE, 0.5, staircase=True),\n","                weight_regularisation_alpha = 0.0,#0.00001\n","                n_mask=N_MASK,\n","                i_mask=I_EVAL\n","                )\n","    losses = []\n","\n","    #images = trajectories_train[indices, :NUM_IMAGES_TOTAL, :, :]\n","    #x = images [:, :NUM_IMAGES_CONTEXT, :, :]\n","    dataset = tf.data.Dataset.from_tensor_slices((trajectories_train))\n","    dataset = dataset.shuffle(512).batch(BATCH_SIZE)\n","    if strategy is not None:\n","        dataset = strategy.experimental_distribute_dataset(dataset)\n","\n","    dataset_test = tf.data.Dataset.from_tensor_slices((trajectories_test))\n","    dataset_test = dataset_test.batch(BATCH_SIZE)\n","    if strategy is not None:\n","        dataset_test = strategy.experimental_distribute_dataset(dataset_test)\n","\n","    loss_train_x = []\n","    loss_train_y = []\n","    loss_valid_x = []\n","    loss_valid_y = []\n","    loss_test_x = []\n","    loss_test_y = []\n","    losses_per_epoch = []\n","\n","    @tf.function\n","    def distribute_train_step(data, async_exec=True):\n","        \"\"\"Run one optimisation step on a batch of trajectories and return its loss.\"\"\"\n","        def replica_fn (d):\n","   
         d_im = d[:, :NUM_IMAGES_TOTAL, :, :]\n","            d_x = d[:, :NUM_IMAGES_CONTEXT, :, :]\n","            print ('d_x.shape: ', d_x.shape)\n","            print ('d_im.shape: ', d_im.shape)\n","            return autodiff_inimnet.optimise (d_x, d_im, t, tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE))\n","              \n","        if strategy is not None:         \n","            if async_exec:\n","                per_replica_result = [strategy.run(replica_fn, args=(data,))]\n","                print ('per_replica_result: ', per_replica_result)\n","                return strategy.gather(per_replica_result, axis=0)\n","            else:\n","                results = replica_fn (strategy.gather(data, axis=0))\n","                print (results)\n","                return results\n","        else:\n","            results = replica_fn (data)\n","            return [results]\n","\n","    @tf.function\n","    def distribute_test_step (data):\n","        \"\"\"Return the input frames and the model predictions of every layer at t_eval.\"\"\"\n","        def replica_fn_test (d):\n","            d_im = d[:, :, :, :]\n","            d_x = d_im[:, :NUM_IMAGES_CONTEXT, :, :]\n","            print ('d_x.shape: ', d_x.shape)\n","            print ('d_im.shape: ', d_im.shape)\n","            return d_im, autodiff_inimnet (d_x, t_eval, training=False)\n","        if strategy is not None: \n","            return strategy.gather(strategy.run(replica_fn_test, args=(data,)), axis=0)\n","        else:\n","            return replica_fn_test(data)\n","          \n","    @tf.function\n","    def distribute_test_step_loss (data, async_exec=True):\n","        \"\"\"Return, for every layer, the pixel-mean squared error of each frame summed over the batch.\"\"\"\n","        def replica_fn_test_loss (d):\n","            d_im = d[:, :, :, :]\n","            d_x = d[:, :NUM_IMAGES_CONTEXT, :, :]\n","            print ('d_x.shape: ', d_x.shape)\n","            print ('d_im.shape: ', d_im.shape)\n","            result = autodiff_inimnet (d_x, t_eval, training=False)\n","            #d_im = tf.reshape (d_im, [d_im.shape[0] * d_im.shape[1], d_im.shape[2], 
d_im.shape[3]])\n","            # un-flatten the predictions back to (sample, frame, height, width)\n","            result = [tf.reshape (r, [ tf.shape(d_im)[0], d_im.shape[1], r.shape[1], r.shape[2]]) for r in result]\n","                \n","            return [tf.math.reduce_mean(tf.square(r - d_im), axis=[2, 3]) for r in result]\n","        if strategy is not None: \n","            if async_exec:\n","                per_replica_result = strategy.run(replica_fn_test_loss, args=(data,))\n","                results = [strategy.gather(single_result, axis=0) for single_result in per_replica_result]\n","                return [tf.reduce_sum(r, axis=0) for r in results]\n","            else:\n","                results = replica_fn_test_loss (strategy.gather(data, axis=0))\n","                return [tf.reduce_sum(r, axis=0) for r in results]\n","        else:\n","            results = replica_fn_test_loss (data)\n","            return [tf.reduce_sum(r, axis=0) for r in results]\n","\n","    best_value = 1e9\n","    for epoch in range(NUM_EPOCHS):\n","        running_losses = []\n","        for x in dataset:\n","            curr_loss = distribute_train_step(x)\n","            losses.extend(curr_loss)\n","            running_losses.extend(curr_loss)\n","            # remember the last training batch for the sample plots below\n","            x_last = x\n","        print ('training finished')\n","        losses_per_epoch.append(np.array(tf.reduce_mean(running_losses)))\n","           \n","        if epoch % REPORT_INTERVAL_EPOCHS == 0:\n","            print (epoch)\n","            curr_loss_train = None\n","            curr_loss_test = None\n","            curr_loss_valid = None\n","             \n","            for x_train in dataset:\n","                if curr_loss_train is None:\n","                    curr_loss_train = np.array(distribute_test_step_loss (x_train))\n","                else:\n","                    curr_loss_train += np.array(distribute_test_step_loss (x_train))\n","            # batch sums -> mean over all training trajectories\n","            curr_loss_train /= trajectories_train.shape[0]\n","           \n","            for x_test in dataset_test:          \n","             
   curr_loss_value = distribute_test_step_loss (x_test)\n","                curr_loss_value = np.array(curr_loss_value)\n","                if curr_loss_test is None:\n","                    curr_loss_test = curr_loss_value\n","                    # the first test batch is kept for the sample plots below\n","                    x_last_test = x_test\n","                else:\n","                    if not len(curr_loss_value.shape)  == len(curr_loss_test.shape):\n","                        curr_loss_value = np.array(distribute_test_step_loss (x_test, False))\n","                \n","                    curr_loss_test += curr_loss_value\n","                \n","            # batch sums -> mean over all test trajectories\n","            curr_loss_test /= trajectories_test.shape[0]\n","\n","            loss_train_x.append(epoch)\n","            loss_train_y.append(curr_loss_train)\n","            loss_test_x.append(epoch)\n","            loss_test_y.append(curr_loss_test)\n","\n","            # history of the frame-averaged test loss, one curve per layer (layer 0 is skipped)\n","            for layer_id in range (1, curr_loss_train.shape[0]):\n","                #plt.plot(loss_train_x, np.mean(np.array(loss_train_y)[:, layer_id, :], axis=1), label='train_loss (l' + str(layer_id)+')')\n","                plt.plot(loss_test_x, np.mean(np.array(loss_test_y)[:, layer_id, :], axis=1), label='test_loss (l' + str(layer_id)+')')        \n","                #plt.plot(loss_valid_x, np.mean(np.array(loss_valid_y)[:, layer_id, :], axis=1), label='valid_loss(l' + str(layer_id)+')')\n","                plt.legend()\n","\n","            plt.show()#.savefig(LOG_DIR+'/'+str('losses_history_'+str(epoch)+'.png'))\n","\n","            # latest per-frame test loss, one curve per layer\n","            for layer_id in range (1, curr_loss_train.shape[0]):\n","                #plt.plot(np.array(loss_train_y)[-1, layer_id, :], label='train_loss (l' + str(layer_id)+', '+str(np.mean(np.array(loss_train_y)[-1, layer_id, :]))+')')\n","                plt.plot(np.array(loss_test_y)[-1, layer_id, :], label='test_loss (l' + str(layer_id)+', '+str(np.mean(np.array(loss_test_y)[-1, layer_id, :]))+')')        \n","                #plt.plot(np.array(loss_valid_y)[-1, layer_id, :], label='valid_loss(l' + 
str(layer_id)+', '+str(np.mean(np.array(loss_valid_y)[-1, layer_id, :]))+')')\n","            plt.legend() \n","            plt.show()#.savefig(LOG_DIR+'/'+str('losses_'+str(epoch)+'.png'))\n","            \n","            print ('Loss (frame_id), training:')\n","            for layer_id in range (1, curr_loss_train.shape[0]):\n","                print  (np.array(loss_train_y)[-1, layer_id, :])\n","            print ('Loss (frame_id), testing:')\n","            for layer_id in range (1, curr_loss_train.shape[0]):\n","                print  (np.array(loss_test_y)[-1, layer_id, :])\n","            # keep the test losses of the report with the lowest mean training loss at layer ELL_MAX\n","            if np. mean (np.array(loss_train_y)[-1, ELL_MAX, :]) < best_value:\n","                best_value = np. mean (np.array(loss_train_y)[-1, ELL_MAX, :]) \n","                best_value_test_ind = epoch\n","                best_value_test = np.array(loss_test_y)[-1, :, :]            \n","            print('current best value: ', best_value)\n","            print ('best_value_test_ind: ', best_value_test_ind)\n","            print ('best_value_test: ', best_value_test)\n","            print ('best_value_test_mean:', np.mean (best_value_test, axis=1))\n","            print ('test_')        \n","            #plt.plot (losses_per_epoch)\n","            #plt.show()\n","            #with open(LOG_DIR+'/log_train'+'.txt', 'a+') as f:\n","            #     print('epoch: ', epoch, '; ', np.array(loss_train_y)[-1, 1:, :].flatten(), file=f)\n","            #with open(LOG_DIR+'/log_test'+'.txt', 'a+') as f:\n","            #     print('epoch: ', epoch, '; ', np.array(loss_test_y)[-1, 1:, :].flatten(), file=f)\n","            \n","            # show the first trajectory of the last training batch: ground truth, then every layer's prediction\n","            d_im, pred_train  =  distribute_test_step (x_last)\n","            print ('Data sample (training)')\n","            print ('GT:')\n","            fig, axes = plt.subplots(1, d_im.shape[1])\n","            \n","            print ('axes.shape[0]: ', axes.shape[0])\n","            for i in  range(1):\n","                for j in  
range(d_im.shape[1]):\n","                    axes[j].get_xaxis().set_visible(False)\n","                    axes[j].get_yaxis().set_visible(False)\n","                    axes[j].imshow(d_im[0, i*d_im.shape[1]+j, :, :])\n","            plt.show()#.savefig(LOG_DIR+'/data_sample_gt_training_'+str(epoch)+'.png')\n","\n","            print ('Prediction: ')\n","            # one row of frames per layer; rows 0..d_im.shape[1]-1 of the prediction are plotted\n","            for layer_id in range(len(pred_train)): \n","                print ('Layer ' + (str(layer_id)))\n","                fig, axes = plt.subplots(1, d_im.shape[1])\n","                for i in  range(1):\n","                     for j in  range(d_im.shape[1]):\n","                         axes[j].get_xaxis().set_visible(False)\n","                         axes[j].get_yaxis().set_visible(False)\n","                         axes[j].imshow(pred_train[layer_id][i*d_im.shape[1]+j, :, :])\n","                plt.show()#.savefig(LOG_DIR+'/data_sample_pred_training_'+str(epoch)+'_'+str(layer_id)+'.png')\n","\n","            # same plots for the first trajectory of the stored test batch\n","            d_im_test, pred_test  =  distribute_test_step (x_last_test)\n","            #print ('Data sample (testing)')\n","            #print ('GT:')\n","            fig, axes = plt.subplots(1, d_im_test.shape[1])\n","            print ('axes.shape[0]: ', d_im_test.shape[1])\n","            for i in  range(1):\n","                for j in  range(d_im_test.shape[1]):\n","                    axes[j].get_xaxis().set_visible(False)\n","                    axes[j].get_yaxis().set_visible(False)\n","                    axes[j].imshow(d_im_test[0, i*d_im_test.shape[1]+j, :, :])\n","            plt.show()#savefig(LOG_DIR+'/data_sample_gt_testing_'+str(epoch)+'.png')\n","\n","            #print ('Prediction: ')\n","            for layer_id in range(len(pred_test)): \n","                print ('Layer ' + (str(layer_id)))\n","                fig, axes = plt.subplots(1, d_im_test.shape[1])\n","                for i in  range(1):\n","                    for j in  range(d_im_test.shape[1]):\n","                       
axes[j].get_xaxis().set_visible(False)\n","                       axes[j].get_yaxis().set_visible(False)\n","                       axes[j].imshow(pred_test[layer_id][i*d_im_test.shape[1]+j, :, :])\n","                plt.show()#savefig(LOG_DIR+'/data_sample_pred_testing_'+str(epoch)+'_'+str(layer_id)+'.png')\n","      \n","      \n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"V7b2YlX8eIyF"},"source":["# seed the Python, NumPy and TensorFlow random generators before training\n","seed = 1\n","task_id=1\n","random.seed(seed)\n","np.random.seed(seed)\n","tf.random.set_seed(seed)\n","    \n","do_training (task_id)"],"execution_count":null,"outputs":[]}]}