{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ICLR MultiMNIST/Fashion.ipynb",
      "provenance": [
        {
          "file_id": "1GsqZjyFqzMOuAYjsY_wYbIaVsM6-P-o7",
          "timestamp": 1601325520976
        },
        {
          "file_id": "1xrFS3S5OPJjk8TmbvGjoc2GtnD8yfA-t",
          "timestamp": 1599684513781
        },
        {
          "file_id": "1mvp25SrzH1w4f8qrAnTHFZXJHZkL_JuN",
          "timestamp": 1589231089092
        }
      ],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "14QJtPJWs5xa"
      },
      "source": [
        "### MultiMNIST/MultiFashion Experiments for Information Transfer in Multi-Task Learning (ICLR 2021 Submission)\n",
        "\n",
        "Licensed under the Apache License, Version 2.0"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r14fs3V2XLOl"
      },
      "source": [
        "import itertools\n",
        "import pickle\n",
        "import time\n",
        "import copy\n",
        "import math\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "from collections import namedtuple\n",
        "from tqdm import tqdm\n",
        "\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "from absl import app\n",
        "from absl import flags"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vXMZI0KtEipW"
      },
      "source": [
        "# Adapted from https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py\n",
        "# Default for compute_gradients' gate_gradients argument (ignored by PCGrad).\n",
        "GATE_OP = 1\n",
        "\n",
        "class PCGrad(tf.compat.v1.train.Optimizer):\n",
        "  \"\"\"PCGrad. https://arxiv.org/pdf/2001.06782.pdf.\n",
        "\n",
        "  Wraps another optimizer. `compute_gradients` takes a list of per-task\n",
        "  losses, removes from each task gradient its component along any other task\n",
        "  gradient it conflicts with (negative inner product), and sums the projected\n",
        "  per-task gradients. The slot/apply hooks below delegate the actual update to\n",
        "  the wrapped optimizer.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, opt, use_locking=False, name=\"PCGrad\"):\n",
        "    \"\"\"opt: the optimizer being wrapped (e.g. a tf.compat.v1 MomentumOptimizer).\"\"\"\n",
        "    super(PCGrad, self).__init__(use_locking, name)\n",
        "    self.optimizer = opt\n",
        "\n",
        "  def compute_gradients(self, loss, var_list=None,\n",
        "                        gate_gradients=GATE_OP,\n",
        "                        aggregation_method=None,\n",
        "                        colocate_gradients_with_ops=False,\n",
        "                        grad_loss=None):\n",
        "    \"\"\"Returns PCGrad-projected (gradient, variable) pairs.\n",
        "\n",
        "    Args:\n",
        "      loss: list of scalar per-task loss tensors.\n",
        "      var_list: list of variables to differentiate. Required in practice: it\n",
        "        is iterated to unpack the flattened gradients.\n",
        "      gate_gradients, aggregation_method, colocate_gradients_with_ops,\n",
        "        grad_loss: accepted for interface compatibility; unused.\n",
        "\n",
        "    Returns:\n",
        "      List of (sum over tasks of projected gradients, variable) tuples, in\n",
        "      var_list order.\n",
        "    \"\"\"\n",
        "    assert isinstance(loss, list)\n",
        "    num_tasks = len(loss)\n",
        "    loss = tf.stack(loss)\n",
        "    # NOTE(review): the shuffled tensor is discarded, so tasks are always\n",
        "    # projected in list order. Kept as in the upstream PCGrad code.\n",
        "    tf.random.shuffle(loss)\n",
        "\n",
        "    # Compute per-task gradients, each flattened into one vector.\n",
        "    grads_task = tf.vectorized_map(lambda x: tf.concat(\n",
        "        [tf.reshape(grad, [-1,]) for grad in tf.gradients(\n",
        "            x, var_list) if grad is not None], axis=0), loss)\n",
        "\n",
        "    # Compute gradient projections.\n",
        "    def project_conflicting(grad_task):\n",
        "      for k in range(num_tasks):\n",
        "        inner_product = tf.reduce_sum(grad_task*grads_task[k])\n",
        "        # divide_no_nan gives 0 (no projection) for an all-zero task gradient\n",
        "        # instead of a NaN that would poison the whole update.\n",
        "        proj_direction = tf.math.divide_no_nan(\n",
        "            inner_product, tf.reduce_sum(grads_task[k]*grads_task[k]))\n",
        "        # Only subtract the component along grads_task[k] when they conflict.\n",
        "        grad_task = grad_task - tf.minimum(proj_direction, 0.) * grads_task[k]\n",
        "      return grad_task\n",
        "\n",
        "    proj_grads_flatten = tf.vectorized_map(project_conflicting, grads_task)\n",
        "\n",
        "    # Unpack flattened projected gradients back to their original shapes and\n",
        "    # sum them across tasks.\n",
        "    proj_grads = []\n",
        "    for j in range(num_tasks):\n",
        "      start_idx = 0\n",
        "      for idx, var in enumerate(var_list):\n",
        "        grad_shape = var.get_shape()\n",
        "        # num_elements() is an int (1 for scalars); np.prod over an empty dims\n",
        "        # list would return the float 1.0 and break the slice below.\n",
        "        flatten_dim = grad_shape.num_elements()\n",
        "        task_grad = proj_grads_flatten[j][start_idx:start_idx+flatten_dim]\n",
        "        task_grad = tf.reshape(task_grad, grad_shape)\n",
        "        if len(proj_grads) < len(var_list):\n",
        "          proj_grads.append(task_grad)\n",
        "        else:\n",
        "          proj_grads[idx] += task_grad\n",
        "        start_idx += flatten_dim\n",
        "    grads_and_vars = list(zip(proj_grads, var_list))\n",
        "    return grads_and_vars\n",
        "\n",
        "  # The hooks below forward slot creation and updates to the wrapped optimizer.\n",
        "  def _create_slots(self, var_list):\n",
        "    self.optimizer._create_slots(var_list)\n",
        "\n",
        "  def _prepare(self):\n",
        "    self.optimizer._prepare()\n",
        "\n",
        "  def _apply_dense(self, grad, var):\n",
        "    return self.optimizer._apply_dense(grad, var)\n",
        "\n",
        "  def _resource_apply_dense(self, grad, var):\n",
        "    return self.optimizer._resource_apply_dense(grad, var)\n",
        "\n",
        "  def _apply_sparse_shared(self, grad, var, indices, scatter_add):\n",
        "    return self.optimizer._apply_sparse_shared(grad, var, indices, scatter_add)\n",
        "\n",
        "  def _apply_sparse(self, grad, var):\n",
        "    return self.optimizer._apply_sparse(grad, var)\n",
        "\n",
        "  def _resource_scatter_add(self, x, i, v):\n",
        "    return self.optimizer._resource_scatter_add(x, i, v)\n",
        "\n",
        "  def _resource_apply_sparse(self, grad, var, indices):\n",
        "    return self.optimizer._resource_apply_sparse(grad, var, indices)\n",
        "\n",
        "  def _finish(self, update_ops, name_scope):\n",
        "    return self.optimizer._finish(update_ops, name_scope)\n",
        "\n",
        "  def _call_if_callable(self, param):\n",
        "    \"\"\"Call the function if param is callable.\"\"\"\n",
        "    return param() if callable(param) else param"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ooXuDeQVe_MH"
      },
      "source": [
        "def del_all_flags(FLAGS):\n",
        "    flags_dict = FLAGS._flags()    \n",
        "    keys_list = [keys for keys in flags_dict]    \n",
        "    for keys in keys_list:\n",
        "        FLAGS.__delattr__(keys)\n",
        "del_all_flags(flags.FLAGS)\n",
        "\n",
        "FLAGS = flags.FLAGS\n",
        "\n",
        "flags.DEFINE_enum('dataset', 'MultiMNIST', ['MultiMNIST', 'MultiFashion'],\n",
        "                  'Dataset: MNIST or Fashion MNIST.')\n",
        "flags.DEFINE_integer('steps', 100, 'Number of epochs to train.')\n",
        "flags.DEFINE_enum('eval', 'valid', ['valid', 'test'], 'The eval dataset.')\n",
        "flags.DEFINE_enum('method', 'mtl', ['mtl', 'it_mtl', 'uncertainty', 'gradnorm', 'pcgrad', 'mgda', \n",
        "                                    'it_uncertainty', 'uncertainty_pcgrad'],\n",
        "                  'Multitask Training Method.')\n",
        "flags.DEFINE_float('lr', 0.001, 'Learning rate.')\n",
        "flags.DEFINE_float('alpha', 0.5, 'Weight for first-pass task 1 and task 2.')\n",
        "flags.DEFINE_integer('batch_size', 256, 'Training data batch size')\n",
        "flags.DEFINE_bool('eval_every_step', True, 'Whether or not to run eval every step or just the last step.')\n",
        "flags.DEFINE_bool('nesterov', False, 'Whether or not to update task-specific parameters before looking ahead.')\n",
        "\n",
        "# NOTE(review): presumably the number of training examples left after\n",
        "# load_dataset removes its 20000-example validation split -- confirm.\n",
        "TRAIN_DATASET_SIZE = 100000\n",
        "# Matches the 20000 validation examples held out in load_dataset.\n",
        "EVAL_DATASET_SIZE = 20000\n",
        "# Presumably the window size for add_average-style metric smoothing -- confirm.\n",
        "METRICS_AVERAGE = 5\n",
        "\n",
        "# Globs for gradnorm and uncertainty weighing.\n",
        "# The uncertainty variables start as None and are created lazily inside the\n",
        "# uncertainty train steps; the others are set by code outside this cell.\n",
        "l_weight = r_weight = l_l0_loss = r_l0_loss = l_uncertainty = grad_masks = r_uncertainty = l_fp_uncertainty = r_fp_uncertainty = l_la_uncertainty = r_la_uncertainty = None"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nArMmed5iZmC"
      },
      "source": [
        "class LeNetBase(tf.keras.Model):\n",
        "  \"\"\"Shared convolutional trunk.\n",
        "\n",
        "  conv(10 filters, 5x5, 'valid', relu) -> 2x2 max-pool ->\n",
        "  conv(20 filters, 5x5, 'valid', relu) -> 2x2 max-pool -> flatten.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(LeNetBase, self).__init__()\n",
        "    self.conv1 = layers.Conv2D(\n",
        "        filters=10,\n",
        "        kernel_size=5,\n",
        "        strides=(1, 1),\n",
        "        padding='valid',\n",
        "        activation='relu')\n",
        "    self.conv2 = layers.Conv2D(\n",
        "        filters=20,\n",
        "        kernel_size=5,\n",
        "        strides=(1, 1),\n",
        "        padding='valid',\n",
        "        activation='relu')\n",
        "    # The weight-free pooling/flatten layers are built once here instead of\n",
        "    # being re-instantiated on every call(). They are created after the convs,\n",
        "    # so trainable_weights keeps the conv1/conv2 (kernel, bias) order.\n",
        "    self.pool1 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))\n",
        "    self.pool2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))\n",
        "    self.flat = layers.Flatten()\n",
        "\n",
        "  def call(self, inputs, mask=None):\n",
        "    x = self.pool1(self.conv1(inputs))\n",
        "    x = self.pool2(self.conv2(x))\n",
        "    return self.flat(x)\n",
        "\n",
        "\n",
        "class LeNetTower(tf.keras.Model):\n",
        "  \"\"\"Task head: Dense(50, relu) -> Dense(50, relu) -> Dense(10) logits.\n",
        "\n",
        "  The sub-layers are assigned in fc1, fc2, fc3 order; code elsewhere indexes\n",
        "  this head's trainable_weights positionally, so keep that order.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(LeNetTower, self).__init__()\n",
        "    self.fc1 = layers.Dense(units=50, input_shape=(720,), activation='relu')\n",
        "    self.fc2 = layers.Dense(units=50, input_shape=(50,), activation='relu')\n",
        "    self.fc3 = layers.Dense(units=10, input_shape=(50,), activation=None)\n",
        "\n",
        "  def call(self, inputs, mask=None):\n",
        "    hidden = inputs\n",
        "    for dense in (self.fc1, self.fc2, self.fc3):\n",
        "      hidden = dense(hidden)\n",
        "    return hidden"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yNULkMbTXnER"
      },
      "source": [
        "def decay_lr(step, optimizer):\n",
        "  if (step + 1) % 30 == 0:\n",
        "    optimizer.lr = optimizer.lr / 2.\n",
        "    print('Decreasing the learning rate by 1/2. New Learning Rate: {}'.format(optimizer.lr))\n",
        "\n",
        "def decay_pcgrad_lr(step, lr_var):\n",
        "  if (step + 1) % 30 == 0:\n",
        "    lr_var.assign(lr_var / 2.)\n",
        "    print('Decreasing the learning rate by 1/2.')\n",
        "\n",
        "def load_dataset(batch_size, dataset='MultiMNIST'):\n",
        "  if dataset == 'MultiMNIST':\n",
        "    path = 'LOCAL PATH HERE/multi_mnist.pickle'\n",
        "    if 'LOCAL PATH HERE' in path:\n",
        "      raise Exception(\"Please change multi_mnist directory path to your local download. Datasets can be downloaded from ParetoMTL (https://github.com/Xi-L/ParetoMTL).\")\n",
        "  elif dataset == 'MultiFashion':\n",
        "    path = 'LOCAL PATH HERE/multi_fashion.pickle'\n",
        "    if 'LOCAL PATH HERE' in path:\n",
        "      raise Exception(\"Please change multi_fashion directory path to your local download. Datasets can be downloaded from ParetoMTL (https://github.com/Xi-L/ParetoMTL).\")\n",
        "\n",
        "  with open(path, 'rb') as f:\n",
        "    trainX, trainY, testX, testY = pickle.load(f)\n",
        "  trainX, testX = np.expand_dims(\n",
        "      trainX, axis=-1).astype(np.float32), np.expand_dims(\n",
        "          testX, axis=-1).astype(np.float32)\n",
        "\n",
        "  valid_indices = [i * 6 for i in range(20000)]\n",
        "  validX = trainX[valid_indices]\n",
        "  validY = trainY[valid_indices]\n",
        "  trainX = np.delete(trainX, valid_indices, axis=0)\n",
        "  trainY = np.delete(trainY, valid_indices, axis=0)\n",
        "\n",
        "  trainYL = tf.keras.utils.to_categorical(trainY[:, 0], num_classes=10)\n",
        "  trainYR = tf.keras.utils.to_categorical(trainY[:, 1], num_classes=10)\n",
        "  testYL = tf.keras.utils.to_categorical(testY[:, 0], num_classes=10)\n",
        "  testYR = tf.keras.utils.to_categorical(testY[:, 1], num_classes=10)\n",
        "  validYL = tf.keras.utils.to_categorical(validY[:, 0], num_classes=10)\n",
        "  validYR = tf.keras.utils.to_categorical(validY[:, 1], num_classes=10)\n",
        "\n",
        "  train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainYL, trainYR))\n",
        "  train_dataset = train_dataset.shuffle(\n",
        "      buffer_size=trainX.shape[0], seed=0,\n",
        "      reshuffle_each_iteration=True).batch(batch_size)\n",
        "  test_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "      (testX, testYL, testYR)).batch(batch_size)\n",
        "  valid_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "      (validX, validYL, validYR)).batch(batch_size)\n",
        "\n",
        "  Dataset = namedtuple('Dataset', ['train', 'valid', 'test'])\n",
        "  return Dataset(train_dataset, valid_dataset, test_dataset)\n",
        "\n",
        "# Frank-Wolfe Solver to find alphas.\n",
        "def min_norm_solver(vecs, dps):\n",
        "  \"\"\"Find the minimum norm solution as a combination of the two points.\"\"\"\n",
        "  dmin = 1e8\n",
        "  for i in range(len(vecs)):\n",
        "    for j in range(i+1, len(vecs)):\n",
        "      if (i,j) not in dps:\n",
        "        dps[(i,j)] = 0.0\n",
        "        for k in range(len(vecs[i])):\n",
        "          dps[(i,j)] += tf.tensordot(vecs[i][k], vecs[j][k], axes=1)\n",
        "        dps[(j,i)] = dps[(i,j)]\n",
        "      if (i,i) not in dps:\n",
        "        dps[(i,i)] = 0.0\n",
        "        for k in range(len(vecs[i])):\n",
        "          dps[(i,i)] += tf.tensordot(vecs[i][k], vecs[i][k], axes=1)\n",
        "      if (j,j) not in dps:\n",
        "        dps[(j,j)] = 0.0\n",
        "        for k in range(len(vecs[i])):\n",
        "          dps[(j,j)] += tf.tensordot(vecs[j][k], vecs[j][k], axes=1)\n",
        "      c,d = min_norm_element(dps[(i,i)], dps[(i,j)], dps[(j,j)])\n",
        "      if d < dmin:\n",
        "        dmin = d\n",
        "        sol = [(i,j), c, d]\n",
        "    \n",
        "    # Initial sol is good enough with tasks = 2 for convex hull.\n",
        "    sol_vec = [0., 0.]\n",
        "    sol_vec[sol[0][0]] = sol[1]\n",
        "    sol_vec[sol[0][1]] = 1 - sol[1]\n",
        "    return sol_vec\n",
        "\n",
        "# Closed form solution for min_{c} |cx_1 + (1-c)x_2|_2^2\n",
        "def min_norm_element(v1v1, v1v2, v2v2):\n",
        "  if v1v2 >= v1v1: # Fig. 1, third column\n",
        "    gamma = 0.99\n",
        "    cost = v1v1\n",
        "    return gamma, cost\n",
        "  if v1v2 >= v2v2: # Fig. 1, first column\n",
        "    gamma = 0.001\n",
        "    cost = v2v2\n",
        "    return gamma, cost\n",
        "  gamma = -1.0*((v1v2-v2v2) / (v1v1 + v2v2 - 2*v1v2))\n",
        "  cost = v2v2 + gamma*(v1v2 - v2v2)\n",
        "  return gamma, cost"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zOrJoED1E8Ib"
      },
      "source": [
        "def add_average(lst, val, n):\n",
        "  if len(lst) < n:\n",
        "    lst.append(val)\n",
        "  elif len(lst) == n:\n",
        "    lst.pop(0)\n",
        "    lst.append(val)\n",
        "  elif len(lst) > n:\n",
        "    raise Exception('List size is greater than n. This should never happen.')\n",
        "\n",
        "def _head_loss(features, params, labels):\n",
        "  \"\"\"Mean softmax cross-entropy of a functional Dense-relu, Dense-relu, Dense head.\n",
        "\n",
        "  params: [fc1_kernel, fc1_bias, fc2_kernel, fc2_bias, fc3_kernel, fc3_bias].\n",
        "  \"\"\"\n",
        "  hidden = tf.nn.relu(tf.matmul(features, params[0]) + params[1])\n",
        "  hidden = tf.nn.relu(tf.matmul(hidden, params[2]) + params[3])\n",
        "  logits = tf.matmul(hidden, params[4]) + params[5]\n",
        "  return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))\n",
        "\n",
        "def base_step(trainX, labelL, labelR, base_updated, l_params, r_params):\n",
        "  \"\"\"Evaluates both task losses with a candidate set of shared-trunk weights.\n",
        "\n",
        "  Re-implements the trunk forward pass functionally (conv -> bias -> relu ->\n",
        "  2x2 max-pool, twice, then flatten) with `base_updated`, then feeds the\n",
        "  result through both task heads given as raw weight lists.\n",
        "\n",
        "  Args:\n",
        "    trainX: input image batch in NHWC layout (as tf.nn.conv2d expects here).\n",
        "    labelL, labelR: one-hot labels for the L and R tasks.\n",
        "    base_updated: [conv1_kernel, conv1_bias, conv2_kernel, conv2_bias].\n",
        "    l_params, r_params: [fc1_kernel, fc1_bias, fc2_kernel, fc2_bias,\n",
        "      fc3_kernel, fc3_bias] for the L and R heads.\n",
        "\n",
        "  Returns:\n",
        "    (l_la_loss, r_la_loss): mean softmax cross-entropy of each head.\n",
        "  \"\"\"\n",
        "  x = trainX\n",
        "  for kernel, bias in ((base_updated[0], base_updated[1]),\n",
        "                       (base_updated[2], base_updated[3])):\n",
        "    x = tf.nn.conv2d(x, kernel, strides=(1, 1), padding='VALID')\n",
        "    x = tf.nn.relu(tf.nn.bias_add(x, bias))\n",
        "    x = tf.nn.max_pool2d(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')\n",
        "  # tf.shape keeps the batch dimension dynamic; the static trainX.shape[0]\n",
        "  # can be None inside a tf.function traced with an unknown batch size.\n",
        "  base_out = tf.reshape(x, [tf.shape(trainX)[0], -1])\n",
        "\n",
        "  r_la_loss = _head_loss(base_out, r_params, labelR)\n",
        "  l_la_loss = _head_loss(base_out, l_params, labelL)\n",
        "  return l_la_loss, r_la_loss\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jSJ7yW8HX8IU"
      },
      "source": [
        "def train(params):\n",
        "  print(params)\n",
        "  global l_uncertainty\n",
        "  global r_uncertainty\n",
        "  # Reset both lazily-created uncertainty weights so a repeated train() call\n",
        "  # in the same kernel creates fresh variables instead of reusing the\n",
        "  # previous run's trained ones.\n",
        "  l_uncertainty = None\n",
        "  r_uncertainty = None\n",
        "  LeBase = LeNetBase()\n",
        "  LeDigitR = LeNetTower()\n",
        "  LeDigitL = LeNetTower()\n",
        "  global_step = tf.Variable(0, trainable=False)\n",
        "  dataset = load_dataset(params.batch_size, FLAGS.dataset)\n",
        "\n",
        "  optimizer = tf.keras.optimizers.SGD(params.lr, momentum=0.9)\n",
        "  if 'pcgrad' in FLAGS.method:\n",
        "    # Learning rate held in a variable, presumably so it can be decayed in\n",
        "    # place via lr_var.assign.\n",
        "    lr_var = tf.Variable(params.lr)\n",
        "    optimizer = PCGrad(tf.compat.v1.train.MomentumOptimizer(lr_var, momentum=0.9))\n",
        "    uw_optimizer = tf.compat.v1.train.MomentumOptimizer(lr_var, momentum=0.9)\n",
        "\n",
        "  @tf.function()\n",
        "  def train_step(trainX, labelL, labelR, first_step=False):\n",
        "    \"\"\"One plain multi-task step.\n",
        "\n",
        "    The shared trunk is updated with the gradient of the summed weighted task\n",
        "    losses, and each head with the gradient of its own weighted loss (L weighted\n",
        "    by 1 - params.alpha, R by params.alpha). `first_step` is accepted but\n",
        "    unused (train_it_step takes the same argument).\n",
        "\n",
        "    Returns:\n",
        "      (sum of per-example L losses, sum of per-example R losses), unweighted.\n",
        "    \"\"\"\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelR, logits=outR)\n",
        "      l_loss = (1. - params.alpha)*tf.reduce_mean(outL_loss)\n",
        "      r_loss = params.alpha*tf.reduce_mean(outR_loss)\n",
        "      # Summed inside the tape so its gradient is recorded.\n",
        "      total_loss = l_loss + r_loss\n",
        "\n",
        "    # Gradients are taken outside the recording context so the persistent tape\n",
        "    # does not also record the gradient and optimizer ops.\n",
        "    base_grads = tape.gradient(total_loss, LeBase.trainable_weights)\n",
        "    l_fp_grads = tape.gradient(l_loss, LeDigitL.trainable_weights)\n",
        "    r_fp_grads = tape.gradient(r_loss, LeDigitR.trainable_weights)\n",
        "\n",
        "    optimizer.apply_gradients(zip(base_grads, LeBase.trainable_weights))\n",
        "    optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "    optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "\n",
        "    global_step.assign_add(1)\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  @tf.function()\n",
        "  def train_it_step(trainX, labelL, labelR, first_step=False):\n",
        "    \"\"\"One information-transfer (lookahead) step.\n",
        "\n",
        "    Three candidate shared-trunk gradients (from the L loss, the R loss and\n",
        "    their sum) are each scored by simulating one optimizer step on the trunk\n",
        "    and measuring the summed relative decrease of both task losses via\n",
        "    base_step. The sum is applied if its score strictly beats both others,\n",
        "    then L under the same rule; otherwise (including ties) R. Each head is\n",
        "    updated with its own gradient: before the lookahead if FLAGS.nesterov,\n",
        "    after it otherwise.\n",
        "\n",
        "    Args:\n",
        "      first_step: Python bool; when True the simulated step ignores momentum\n",
        "        and is plain param - lr * grad.\n",
        "\n",
        "    Returns:\n",
        "      (sum of per-example L losses, sum of per-example R losses), unweighted.\n",
        "    \"\"\"\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelR, logits=outR)\n",
        "      l_loss = (1. - params.alpha)*tf.reduce_mean(outL_loss)\n",
        "      r_loss = params.alpha*tf.reduce_mean(outR_loss)\n",
        "      # Summed inside the tape so its gradient is recorded.\n",
        "      s_loss = r_loss + l_loss\n",
        "\n",
        "    l_fp_grads = tape.gradient(l_loss, LeDigitL.trainable_weights)\n",
        "    r_fp_grads = tape.gradient(r_loss, LeDigitR.trainable_weights)\n",
        "\n",
        "    if FLAGS.nesterov:\n",
        "      optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "      optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "\n",
        "    l_params = list(LeDigitL.trainable_weights)\n",
        "    r_params = list(LeDigitR.trainable_weights)\n",
        "    base_weights = LeBase.trainable_weights\n",
        "    mean_l_loss = tf.reduce_mean(outL_loss)\n",
        "    mean_r_loss = tf.reduce_mean(outR_loss)\n",
        "\n",
        "    def lookahead_gain(base_grads):\n",
        "      # Summed relative loss decrease of both tasks after a simulated trunk step.\n",
        "      if first_step:\n",
        "        updated = [param - optimizer.lr*grads\n",
        "                   for param, grads in zip(base_weights, base_grads)]\n",
        "      else:\n",
        "        # Mirrors Keras SGD momentum: v = momentum*v - lr*g; w = w + v. On\n",
        "        # Keras SGD `_momentum` is only a bool flag; the coefficient is the\n",
        "        # `momentum` hyperparameter.\n",
        "        updated = [param + (optimizer.momentum*optimizer.get_slot(param, 'momentum') - optimizer.lr*grads)\n",
        "                   for param, grads in zip(base_weights, base_grads)]\n",
        "      l_la_loss, r_la_loss = base_step(trainX, labelL, labelR, updated, l_params, r_params)\n",
        "      return tf.cast(1.0 - (r_la_loss / mean_r_loss) + 1.0 - (l_la_loss / mean_l_loss), tf.float32)\n",
        "\n",
        "    l_base_gradients = tape.gradient(l_loss, base_weights)\n",
        "    r_base_gradients = tape.gradient(r_loss, base_weights)\n",
        "    s_base_gradients = tape.gradient(s_loss, base_weights)\n",
        "    l_grad_gain = lookahead_gain(l_base_gradients)\n",
        "    r_grad_gain = lookahead_gain(r_base_gradients)\n",
        "    s_grad_gain = lookahead_gain(s_base_gradients)\n",
        "\n",
        "    if s_grad_gain > l_grad_gain and s_grad_gain > r_grad_gain:\n",
        "      optimizer.apply_gradients(zip(s_base_gradients, base_weights))\n",
        "    elif l_grad_gain > s_grad_gain and l_grad_gain > r_grad_gain:\n",
        "      optimizer.apply_gradients(zip(l_base_gradients, base_weights))\n",
        "    else:\n",
        "      optimizer.apply_gradients(zip(r_base_gradients, base_weights))\n",
        "\n",
        "    # If nesterov is enabled, then we applied the update to the task-specific params earlier.\n",
        "    if not FLAGS.nesterov:\n",
        "      optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "      optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "\n",
        "    global_step.assign_add(1)\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  @tf.function()\n",
        "  def train_uncertainty_step(trainX, labelL, labelR):\n",
        "    \"\"\"One step of uncertainty-weighted joint training.\n",
        "\n",
        "    Each task loss is weighted as mean_loss / exp(2 * s) + s, where s is a\n",
        "    learned per-task scalar clipped to [0.01, 10.0]. The trunk, both heads and\n",
        "    both scalars are updated from the gradient of the summed weighted loss.\n",
        "\n",
        "    Returns:\n",
        "      (sum of per-example L losses, sum of per-example R losses), unweighted.\n",
        "    \"\"\"\n",
        "    global l_uncertainty\n",
        "    global r_uncertainty\n",
        "    # Created lazily, only while the module-level globals are still None.\n",
        "    if l_uncertainty is None:\n",
        "      l_uncertainty = tf.Variable(1.0)\n",
        "    if r_uncertainty is None:\n",
        "      r_uncertainty = tf.Variable(1.0)\n",
        "\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "          labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "          labels=labelR, logits=outR)\n",
        "\n",
        "      l_clip_uncertainty = tf.clip_by_value(l_uncertainty, 0.01, 10.0)\n",
        "      l_loss = tf.reduce_mean(outL_loss) / tf.exp(2 * l_clip_uncertainty) + l_clip_uncertainty\n",
        "      r_clip_uncertainty = tf.clip_by_value(r_uncertainty, 0.01, 10.0)\n",
        "      r_loss = tf.reduce_mean(outR_loss) / tf.exp(2 * r_clip_uncertainty) + r_clip_uncertainty\n",
        "\n",
        "      loss = l_loss + r_loss\n",
        "\n",
        "    # Gradients are taken outside the recording context so the persistent tape\n",
        "    # does not also record the gradient and optimizer ops.\n",
        "    base_gradients = tape.gradient(loss, LeBase.trainable_weights)\n",
        "    l_digit_gradients = tape.gradient(loss, LeDigitL.trainable_weights)\n",
        "    r_digit_gradients = tape.gradient(loss, LeDigitR.trainable_weights)\n",
        "    uncertainty_gradients = tape.gradient(\n",
        "        loss, [l_uncertainty, r_uncertainty])\n",
        "\n",
        "    optimizer.apply_gradients(zip(base_gradients, LeBase.trainable_weights))\n",
        "    optimizer.apply_gradients(\n",
        "        zip(l_digit_gradients, LeDigitL.trainable_weights))\n",
        "    optimizer.apply_gradients(\n",
        "        zip(r_digit_gradients, LeDigitR.trainable_weights))\n",
        "    optimizer.apply_gradients(\n",
        "        zip(uncertainty_gradients, [l_uncertainty, r_uncertainty]))\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "\n",
        "  @tf.function()\n",
        "  def train_it_uncertainty_step(trainX, labelL, labelR, first_step=False):\n",
        "    \"\"\"Lookahead (information-transfer) step with uncertainty-weighted losses.\n",
        "\n",
        "    Each task loss is weighted as mean_loss / exp(2 * s) + s with a learned\n",
        "    scalar s clipped to [0.01, 10.0]. Three candidate shared-trunk gradients\n",
        "    (from the weighted L loss, the weighted R loss and their sum) are each\n",
        "    scored by simulating one optimizer step on the trunk and measuring the\n",
        "    summed relative decrease of the unweighted task losses via base_step. The\n",
        "    sum is applied if its score strictly beats both others, then L under the\n",
        "    same rule; otherwise (including ties) R. Heads are updated with their own\n",
        "    weighted-loss gradients (before the lookahead if FLAGS.nesterov, after it\n",
        "    otherwise), and the scalars with the gradient of the total loss.\n",
        "\n",
        "    Args:\n",
        "      first_step: Python bool; when True the simulated step ignores momentum\n",
        "        and is plain param - lr * grad.\n",
        "\n",
        "    Returns:\n",
        "      (sum of per-example L losses, sum of per-example R losses), unweighted.\n",
        "    \"\"\"\n",
        "    global l_uncertainty\n",
        "    global r_uncertainty\n",
        "    # Created lazily, only while the module-level globals are still None.\n",
        "    if l_uncertainty is None:\n",
        "      l_uncertainty = tf.Variable(1.0)\n",
        "    if r_uncertainty is None:\n",
        "      r_uncertainty = tf.Variable(1.0)\n",
        "\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "          labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "          labels=labelR, logits=outR)\n",
        "\n",
        "      l_clip_uncertainty = tf.clip_by_value(l_uncertainty, 0.01, 10.0)\n",
        "      l_loss = tf.reduce_mean(outL_loss) / tf.exp(2 * l_clip_uncertainty) + l_clip_uncertainty\n",
        "      r_clip_uncertainty = tf.clip_by_value(r_uncertainty, 0.01, 10.0)\n",
        "      r_loss = tf.reduce_mean(outR_loss) / tf.exp(2 * r_clip_uncertainty) + r_clip_uncertainty\n",
        "\n",
        "      # Summed inside the tape so its gradient is recorded.\n",
        "      loss = l_loss + r_loss\n",
        "\n",
        "    l_fp_grads = tape.gradient(l_loss, LeDigitL.trainable_weights)\n",
        "    r_fp_grads = tape.gradient(r_loss, LeDigitR.trainable_weights)\n",
        "\n",
        "    if FLAGS.nesterov:\n",
        "      optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "      optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "\n",
        "    l_params = list(LeDigitL.trainable_weights)\n",
        "    r_params = list(LeDigitR.trainable_weights)\n",
        "    base_weights = LeBase.trainable_weights\n",
        "    mean_l_loss = tf.reduce_mean(outL_loss)\n",
        "    mean_r_loss = tf.reduce_mean(outR_loss)\n",
        "\n",
        "    def lookahead_gain(base_grads):\n",
        "      # Summed relative loss decrease of both tasks after a simulated trunk step.\n",
        "      if first_step:\n",
        "        updated = [param - optimizer.lr*grads\n",
        "                   for param, grads in zip(base_weights, base_grads)]\n",
        "      else:\n",
        "        # Mirrors Keras SGD momentum: v = momentum*v - lr*g; w = w + v. On\n",
        "        # Keras SGD `_momentum` is only a bool flag; the coefficient is the\n",
        "        # `momentum` hyperparameter.\n",
        "        updated = [param + (optimizer.momentum*optimizer.get_slot(param, 'momentum') - optimizer.lr*grads)\n",
        "                   for param, grads in zip(base_weights, base_grads)]\n",
        "      l_la_loss, r_la_loss = base_step(trainX, labelL, labelR, updated, l_params, r_params)\n",
        "      return tf.cast(1.0 - (r_la_loss / mean_r_loss) + 1.0 - (l_la_loss / mean_l_loss), tf.float32)\n",
        "\n",
        "    l_base_gradients = tape.gradient(l_loss, base_weights)\n",
        "    r_base_gradients = tape.gradient(r_loss, base_weights)\n",
        "    s_base_gradients = tape.gradient(loss, base_weights)\n",
        "    l_grad_gain = lookahead_gain(l_base_gradients)\n",
        "    r_grad_gain = lookahead_gain(r_base_gradients)\n",
        "    s_grad_gain = lookahead_gain(s_base_gradients)\n",
        "\n",
        "    if s_grad_gain > l_grad_gain and s_grad_gain > r_grad_gain:\n",
        "      optimizer.apply_gradients(zip(s_base_gradients, base_weights))\n",
        "    elif l_grad_gain > s_grad_gain and l_grad_gain > r_grad_gain:\n",
        "      optimizer.apply_gradients(zip(l_base_gradients, base_weights))\n",
        "    else:\n",
        "      optimizer.apply_gradients(zip(r_base_gradients, base_weights))\n",
        "\n",
        "    # If nesterov is enabled, then we applied the update to the task-specific params earlier.\n",
        "    if not FLAGS.nesterov:\n",
        "      optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "      optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "\n",
        "    uncertainty_gradients = tape.gradient(loss, [l_uncertainty, r_uncertainty])\n",
        "    optimizer.apply_gradients(zip(uncertainty_gradients, [l_uncertainty, r_uncertainty]))\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  # GradNorm step. Both heads share LeBase; the left/right task losses are\n",
        "  # scaled by the learned globals l_weight / r_weight, created on the first call\n",
        "  # with value 1.\n",
        "  # Weight update: for each task, the L2 norm of its weighted-loss gradient w.r.t.\n",
        "  # LeBase.trainable_weights[-2] is pulled (absolute difference) towards the mean\n",
        "  # of the two norms times r_i ** params.alpha, where r_i is the task's current\n",
        "  # loss divided by its first-call loss, normalised by the mean of that ratio over\n",
        "  # both tasks. The target is wrapped in stop_gradient, and after the optimizer\n",
        "  # step the weights are rescaled so that they sum to 2.\n",
        "  # Model update: LeBase and both heads then take an optimizer step on `loss`,\n",
        "  # which was built from the task weights as they were before this step's update.\n",
        "  # Returns the summed (not averaged) per-example losses of the left and right task.\n",
        "  @tf.function()\n",
        "  def train_gradnorm_step(trainX, labelL, labelR):\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelR, logits=outR)\n",
        "\n",
        "      # Task weights are created lazily (initialised to 1) and kept across calls.\n",
        "      global l_weight\n",
        "      global r_weight\n",
        "      if l_weight is None:\n",
        "        l_weight = tf.Variable(1.)\n",
        "      if r_weight is None:\n",
        "        r_weight = tf.Variable(1.)\n",
        "\n",
        "      l_loss = tf.reduce_mean(outL_loss) * l_weight\n",
        "      r_loss = tf.reduce_mean(outR_loss) * r_weight\n",
        "      loss = l_loss + r_loss\n",
        "      # Per-task gradient norms on one shared weight (the second-to-last trainable\n",
        "      # weight of LeBase); taken inside the tape so they stay differentiable\n",
        "      # w.r.t. the task weights.\n",
        "      l_gradnorm = tf.norm(tape.gradient(l_loss, LeBase.trainable_weights[-2]), ord=2)\n",
        "      r_gradnorm = tf.norm(tape.gradient(r_loss, LeBase.trainable_weights[-2]), ord=2)\n",
        "      gradnorm = (l_gradnorm + r_gradnorm) / 2.\n",
        "\n",
        "      # Losses from the first call, used as the reference for the loss ratios.\n",
        "      global l_l0_loss\n",
        "      global r_l0_loss\n",
        "      if l_l0_loss is None:\n",
        "        l_l0_loss = tf.Variable(tf.reduce_mean(outL_loss))\n",
        "      if r_l0_loss is None:\n",
        "        r_l0_loss = tf.Variable(tf.reduce_mean(outR_loss))\n",
        "      l_li_loss = tf.reduce_mean(outL_loss) / l_l0_loss\n",
        "      r_li_loss = tf.reduce_mean(outR_loss) / r_l0_loss\n",
        "      li_expected = (l_li_loss + r_li_loss) / 2.\n",
        "      l_ri_loss = tf.math.pow(l_li_loss/li_expected, params.alpha)\n",
        "      r_ri_loss = tf.math.pow(r_li_loss/li_expected, params.alpha)\n",
        "      l_gradnorm_loss = tf.norm(l_gradnorm - tf.stop_gradient(gradnorm*l_ri_loss), ord=1)\n",
        "      r_gradnorm_loss = tf.norm(r_gradnorm - tf.stop_gradient(gradnorm*r_ri_loss), ord=1)\n",
        "      gradnorm_loss = l_gradnorm_loss + r_gradnorm_loss\n",
        "\n",
        "      # Update the task weights, then rescale them so they sum to 2.\n",
        "      l_weight_gradient = tape.gradient(gradnorm_loss, l_weight)\n",
        "      r_weight_gradient = tape.gradient(gradnorm_loss, r_weight)\n",
        "      optimizer.apply_gradients(zip([l_weight_gradient, r_weight_gradient], [l_weight, r_weight]))\n",
        "      scale = 2. / (l_weight + r_weight)\n",
        "      l_weight.assign(scale*l_weight)\n",
        "      r_weight.assign(scale*r_weight)\n",
        "\n",
        "      # Model update from the weighted loss built before the weight update above.\n",
        "      base_gradients = tape.gradient(loss, LeBase.trainable_weights)\n",
        "      l_digit_gradients = tape.gradient(loss, LeDigitL.trainable_weights)\n",
        "      r_digit_gradients = tape.gradient(loss, LeDigitR.trainable_weights)\n",
        "\n",
        "      optimizer.apply_gradients(zip(base_gradients, LeBase.trainable_weights))\n",
        "      optimizer.apply_gradients(zip(l_digit_gradients, LeDigitL.trainable_weights))\n",
        "      optimizer.apply_gradients(zip(r_digit_gradients, LeDigitR.trainable_weights))\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  # GradNorm step whose shared-base update is chosen by lookahead. The task\n",
        "  # weights l_weight / r_weight are tuned exactly as in train_gradnorm_step. Then\n",
        "  # three candidate LeBase updates (from the left loss, the right loss, and their\n",
        "  # sum) are each scored with base_step (assumed to return the left/right losses\n",
        "  # at the given parameters - confirm) as the summed relative loss decrease, and\n",
        "  # only the best candidate is applied; if no candidate strictly beats both\n",
        "  # others, the right-loss update is used.\n",
        "  # first_step must be True while the optimizer has no slot variables yet: the\n",
        "  # lookahead then takes a plain lr-scaled step instead of reading the 'momentum'\n",
        "  # slot. The training loop passes first_step=(len(optimizer.variables()) == 0).\n",
        "  # Returns the summed (not averaged) per-example losses of the left and right task.\n",
        "  @tf.function()\n",
        "  def train_it_gradnorm_step(trainX, labelL, labelR, first_step=False):\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelR, logits=outR)\n",
        "\n",
        "      global l_weight\n",
        "      global r_weight\n",
        "      if l_weight is None:\n",
        "        l_weight = tf.Variable(1.)\n",
        "      if r_weight is None:\n",
        "        r_weight = tf.Variable(1.)\n",
        "\n",
        "      l_loss = tf.reduce_mean(outL_loss) * l_weight\n",
        "      r_loss = tf.reduce_mean(outR_loss) * r_weight\n",
        "      l_gradnorm = tf.norm(tape.gradient(l_loss, LeBase.trainable_weights[-2]), ord=2)\n",
        "      r_gradnorm = tf.norm(tape.gradient(r_loss, LeBase.trainable_weights[-2]), ord=2)\n",
        "      gradnorm = (l_gradnorm + r_gradnorm) / 2.\n",
        "\n",
        "      global l_l0_loss\n",
        "      global r_l0_loss\n",
        "      if l_l0_loss is None:\n",
        "        l_l0_loss = tf.Variable(tf.reduce_mean(outL_loss))\n",
        "      if r_l0_loss is None:\n",
        "        r_l0_loss = tf.Variable(tf.reduce_mean(outR_loss))\n",
        "      l_li_loss = tf.reduce_mean(outL_loss) / l_l0_loss\n",
        "      r_li_loss = tf.reduce_mean(outR_loss) / r_l0_loss\n",
        "      li_expected = (l_li_loss + r_li_loss) / 2.\n",
        "      l_ri_loss = tf.math.pow(l_li_loss/li_expected, params.alpha)\n",
        "      r_ri_loss = tf.math.pow(r_li_loss/li_expected, params.alpha)\n",
        "      l_gradnorm_loss = tf.norm(l_gradnorm - tf.stop_gradient(gradnorm*l_ri_loss), ord=1)\n",
        "      r_gradnorm_loss = tf.norm(r_gradnorm - tf.stop_gradient(gradnorm*r_ri_loss), ord=1)\n",
        "      gradnorm_loss = l_gradnorm_loss + r_gradnorm_loss\n",
        "\n",
        "      # Update the task weights, then rescale them so they sum to 2.\n",
        "      l_weight_gradient = tape.gradient(gradnorm_loss, l_weight)\n",
        "      r_weight_gradient = tape.gradient(gradnorm_loss, r_weight)\n",
        "      optimizer.apply_gradients(zip([l_weight_gradient, r_weight_gradient], [l_weight, r_weight]))\n",
        "      scale = 2. / (l_weight + r_weight)\n",
        "      l_weight.assign(scale*l_weight)\n",
        "      r_weight.assign(scale*r_weight)\n",
        "\n",
        "      l_fp_grads = tape.gradient(l_loss, LeDigitL.trainable_weights)\n",
        "      r_fp_grads = tape.gradient(r_loss, LeDigitR.trainable_weights)\n",
        "      if FLAGS.nesterov:\n",
        "        optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "        optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "\n",
        "      l_params = [param for param in LeDigitL.trainable_weights]\n",
        "      r_params = [param for param in LeDigitR.trainable_weights]\n",
        "\n",
        "      # Compute the left lookahead gradients.\n",
        "      l_base_gradients = tape.gradient(l_loss, LeBase.trainable_weights)\n",
        "      if first_step:\n",
        "        l_base_update = [optimizer.lr*grads for param, grads in zip(LeBase.trainable_weights, l_base_gradients)]\n",
        "        l_base_updated = [param - update for param, update in zip(LeBase.trainable_weights, l_base_update)]\n",
        "      else:\n",
        "        l_base_update = [(optimizer._momentum*optimizer.get_slot(param, 'momentum') - optimizer.lr*grads) for param, grads in zip(LeBase.trainable_weights, l_base_gradients)]\n",
        "        l_base_updated = [param + update for param, update in zip(LeBase.trainable_weights, l_base_update)]\n",
        "\n",
        "      l_la_loss, r_la_loss = base_step(trainX, labelL, labelR, l_base_updated, l_params, r_params)\n",
        "      l_grad_gain = tf.cast(1.0 - (r_la_loss / tf.reduce_mean(outR_loss)) + 1.0 - (l_la_loss / tf.reduce_mean(outL_loss)), tf.float32)\n",
        "\n",
        "      # Compute the right lookahead gradients.\n",
        "      r_base_gradients = tape.gradient(r_loss, LeBase.trainable_weights)\n",
        "      if first_step:\n",
        "        r_base_update = [optimizer.lr*grads for param, grads in zip(LeBase.trainable_weights, r_base_gradients)]\n",
        "        r_base_updated = [param - update for param, update in zip(LeBase.trainable_weights, r_base_update)]\n",
        "      else:\n",
        "        r_base_update = [(optimizer._momentum*optimizer.get_slot(param, 'momentum') - optimizer.lr*grads) for param, grads in zip(LeBase.trainable_weights, r_base_gradients)]\n",
        "        r_base_updated = [param + update for param, update in zip(LeBase.trainable_weights, r_base_update)]\n",
        "\n",
        "      l_la_loss, r_la_loss = base_step(trainX, labelL, labelR, r_base_updated, l_params, r_params)\n",
        "      r_grad_gain = tf.cast(1.0 - (r_la_loss / tf.reduce_mean(outR_loss)) + 1.0 - (l_la_loss / tf.reduce_mean(outL_loss)), tf.float32)\n",
        "\n",
        "      # Compute the shared lookahead gradients.\n",
        "      s_base_gradients = tape.gradient(r_loss + l_loss, LeBase.trainable_weights)\n",
        "      if first_step:\n",
        "        s_base_update = [optimizer.lr*grads for param, grads in zip(LeBase.trainable_weights, s_base_gradients)]\n",
        "        s_base_updated = [param - update for param, update in zip(LeBase.trainable_weights, s_base_update)]\n",
        "      else:\n",
        "        s_base_update = [(optimizer._momentum*optimizer.get_slot(param, 'momentum') - optimizer.lr*grads) for param, grads in zip(LeBase.trainable_weights, s_base_gradients)]\n",
        "        s_base_updated = [param + update for param, update in zip(LeBase.trainable_weights, s_base_update)]\n",
        "\n",
        "      l_la_loss, r_la_loss = base_step(trainX, labelL, labelR, s_base_updated, l_params, r_params)\n",
        "      s_grad_gain = tf.cast(1.0 - (r_la_loss / tf.reduce_mean(outR_loss)) + 1.0 - (l_la_loss / tf.reduce_mean(outL_loss)), tf.float32)\n",
        "\n",
        "      # Apply the base update whose lookahead gave the largest gain.\n",
        "      if s_grad_gain > l_grad_gain and s_grad_gain > r_grad_gain:\n",
        "        optimizer.apply_gradients(zip(tape.gradient(l_loss + r_loss, LeBase.trainable_weights), LeBase.trainable_weights))\n",
        "      elif l_grad_gain > s_grad_gain and l_grad_gain > r_grad_gain:\n",
        "        optimizer.apply_gradients(zip(tape.gradient(l_loss, LeBase.trainable_weights), LeBase.trainable_weights))\n",
        "      else:\n",
        "        optimizer.apply_gradients(zip(tape.gradient(r_loss, LeBase.trainable_weights), LeBase.trainable_weights))\n",
        "\n",
        "      # If nesterov is enabled, then we applied the update to the task-specific params earlier.\n",
        "      if not FLAGS.nesterov:\n",
        "        optimizer.apply_gradients(zip(l_fp_grads, LeDigitL.trainable_weights))\n",
        "        optimizer.apply_gradients(zip(r_fp_grads, LeDigitR.trainable_weights))\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  # PCGrad step with fixed task weights: the left loss is scaled by\n",
        "  # (1 - params.alpha) and the right loss by params.alpha. The shared LeBase gets\n",
        "  # the gradient that optimizer.compute_gradients derives from both losses\n",
        "  # (presumably the PCGrad-projected gradient - confirm against the optimizer\n",
        "  # used for pcgrad methods); each head gets the gradient of its own loss.\n",
        "  # Returns the summed (not averaged) per-example losses of the left and right task.\n",
        "  @tf.function()\n",
        "  def train_pcgrad_step(trainX, labelL, labelR):\n",
        "    rep = LeBase(trainX)\n",
        "    outL = LeDigitL(rep)\n",
        "    outL_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "        labels=labelL, logits=outL)\n",
        "    outR = LeDigitR(rep)\n",
        "    outR_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "        labels=labelR, logits=outR)\n",
        "\n",
        "    l_loss = (1. - params.alpha)*tf.reduce_mean(outL_loss)\n",
        "    r_loss = params.alpha*tf.reduce_mean(outR_loss)\n",
        "\n",
        "    base_gradvars = optimizer.compute_gradients([l_loss, r_loss], LeBase.trainable_weights)\n",
        "    l_digit_gradvars = optimizer.compute_gradients([l_loss], LeDigitL.trainable_weights)\n",
        "    r_digit_gradvars = optimizer.compute_gradients([r_loss], LeDigitR.trainable_weights)\n",
        "\n",
        "    optimizer.apply_gradients(base_gradvars)\n",
        "    optimizer.apply_gradients(l_digit_gradvars)\n",
        "    optimizer.apply_gradients(r_digit_gradvars)\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  # PCGrad step with learned uncertainty-based task weighting: each task loss\n",
        "  # becomes loss / exp(2 * s) + s, where s is that task's uncertainty variable\n",
        "  # clipped to [0.01, 10]. The uncertainty variables are created lazily (value 1)\n",
        "  # as globals on the first call. Their gradients are computed by uw_optimizer and\n",
        "  # applied by uw_optimizer as well, so its own settings drive their update.\n",
        "  # Returns the summed (not averaged) per-example losses of the left and right task.\n",
        "  @tf.function()\n",
        "  def train_uncertainty_pcgrad_step(trainX, labelL, labelR):\n",
        "    global l_uncertainty\n",
        "    global r_uncertainty\n",
        "    if l_uncertainty is None:\n",
        "      l_uncertainty = tf.Variable(1.0)\n",
        "    if r_uncertainty is None:\n",
        "      r_uncertainty = tf.Variable(1.0)\n",
        "\n",
        "    rep = LeBase(trainX)\n",
        "    outL = LeDigitL(rep)\n",
        "    outL_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "        labels=labelL, logits=outL)\n",
        "    outR = LeDigitR(rep)\n",
        "    outR_loss = tf.nn.softmax_cross_entropy_with_logits(\n",
        "        labels=labelR, logits=outR)\n",
        "\n",
        "    l_loss = tf.reduce_mean(outL_loss)\n",
        "    r_loss = tf.reduce_mean(outR_loss)\n",
        "\n",
        "    l_clip_uncertainty = tf.clip_by_value(l_uncertainty, 0.01, 10.0)\n",
        "    l_loss = l_loss / tf.exp(2 * l_clip_uncertainty) + l_clip_uncertainty\n",
        "    r_clip_uncertainty = tf.clip_by_value(r_uncertainty, 0.01, 10.0)\n",
        "    r_loss = r_loss / tf.exp(2 * r_clip_uncertainty) + r_clip_uncertainty\n",
        "\n",
        "    base_gradvars = optimizer.compute_gradients([l_loss, r_loss], LeBase.trainable_weights)\n",
        "    l_digit_gradvars = optimizer.compute_gradients([l_loss], LeDigitL.trainable_weights)\n",
        "    r_digit_gradvars = optimizer.compute_gradients([r_loss], LeDigitR.trainable_weights)\n",
        "    l_uncertainty_gradvars = uw_optimizer.compute_gradients(l_loss, l_uncertainty)\n",
        "    r_uncertainty_gradvars = uw_optimizer.compute_gradients(r_loss, r_uncertainty)\n",
        "\n",
        "    optimizer.apply_gradients(base_gradvars)\n",
        "    optimizer.apply_gradients(l_digit_gradvars)\n",
        "    optimizer.apply_gradients(r_digit_gradvars)\n",
        "    # Apply with the same optimizer that computed these gradients.\n",
        "    uw_optimizer.apply_gradients(l_uncertainty_gradvars)\n",
        "    uw_optimizer.apply_gradients(r_uncertainty_gradvars)\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  # MGDA step (runs eagerly, not wrapped in tf.function). Each head is first\n",
        "  # updated with the gradient of its own loss; LeBase is then updated with\n",
        "  # alphas[0] * grad(l_loss) + alphas[1] * grad(r_loss), where alphas come from\n",
        "  # min_norm_solver applied to the per-task gradients w.r.t. the shared\n",
        "  # representation rep. Increments global_step.\n",
        "  # first_step is accepted for call-site parity with the other steps and is unused.\n",
        "  # Returns the summed (not averaged) per-example losses of the left and right task.\n",
        "  def train_mgda_step(trainX, labelL, labelR, first_step=False):\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = LeBase(trainX)\n",
        "      outL = LeDigitL(rep)\n",
        "      outL_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelL, logits=outL)\n",
        "      outR = LeDigitR(rep)\n",
        "      outR_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labelR, logits=outR)\n",
        "      l_loss = tf.reduce_mean(outL_loss)\n",
        "      r_loss = tf.reduce_mean(outR_loss)\n",
        "\n",
        "    # Gradients are taken outside the recording context so the persistent tape\n",
        "    # does not also record the gradient computations themselves.\n",
        "    # Update task-specific parameters.\n",
        "    optimizer.apply_gradients(zip(tape.gradient(l_loss, LeDigitL.trainable_weights), LeDigitL.trainable_weights))\n",
        "    optimizer.apply_gradients(zip(tape.gradient(r_loss, LeDigitR.trainable_weights), LeDigitR.trainable_weights))\n",
        "\n",
        "    # Compute gradient with respect to rep.\n",
        "    l_rep_grads = tape.gradient(l_loss, rep)\n",
        "    r_rep_grads = tape.gradient(r_loss, rep)\n",
        "\n",
        "    # Scaled back propagation.\n",
        "    alphas = min_norm_solver([l_rep_grads, r_rep_grads], {})\n",
        "    l_base_grads = tape.gradient(l_loss, LeBase.trainable_weights)\n",
        "    r_base_grads = tape.gradient(r_loss, LeBase.trainable_weights)\n",
        "    del tape  # Release the resources held by the persistent tape.\n",
        "    s_grad_updates = [alphas[0]*l_grad + alphas[1]*r_grad for l_grad, r_grad in zip(l_base_grads, r_base_grads)]\n",
        "    optimizer.apply_gradients(zip(s_grad_updates, LeBase.trainable_weights))\n",
        "\n",
        "    global_step.assign_add(1)\n",
        "    return tf.reduce_sum(outL_loss), tf.reduce_sum(outR_loss)\n",
        "\n",
        "  # Evaluates both heads on one batch. Returns an Eval namedtuple holding the\n",
        "  # summed per-example cross-entropy of each head and the count of examples whose\n",
        "  # argmax prediction matches the argmax label; the training loop below divides\n",
        "  # these totals by EVAL_DATASET_SIZE.\n",
        "  @tf.function()\n",
        "  def eval_step(evalX, evalYL, evalYR):\n",
        "    rep = LeBase(evalX)\n",
        "    outL = LeDigitL(rep)\n",
        "    outR = LeDigitR(rep)\n",
        "    l_loss = tf.reduce_sum(\n",
        "        tf.nn.softmax_cross_entropy_with_logits(labels=evalYL, logits=outL))\n",
        "    r_loss = tf.reduce_sum(\n",
        "        tf.nn.softmax_cross_entropy_with_logits(labels=evalYR, logits=outR))\n",
        "\n",
        "    l_pred = tf.math.argmax(outL, axis=1, output_type=tf.dtypes.int32)\n",
        "    r_pred = tf.math.argmax(outR, axis=1, output_type=tf.dtypes.int32)\n",
        "    l_acc = tf.math.count_nonzero(\n",
        "        tf.equal(l_pred,\n",
        "                 tf.math.argmax(evalYL, axis=1, output_type=tf.dtypes.int32)))\n",
        "    r_acc = tf.math.count_nonzero(\n",
        "        tf.equal(r_pred,\n",
        "                 tf.math.argmax(evalYR, axis=1, output_type=tf.dtypes.int32)))\n",
        "    Eval = namedtuple('Eval', ['l_loss', 'r_loss', 'l_acc', 'r_acc'])\n",
        "    return Eval(l_loss, r_loss, l_acc, r_acc)\n",
        "\n",
        "  # Training loop. Each step is one pass over dataset.train (printed as an\n",
        "  # epoch) using the train step selected by FLAGS.method. Evaluation on the\n",
        "  # FLAGS.eval split runs after every step when FLAGS.eval_every_step is set,\n",
        "  # otherwise only after the last step. Each reported metric is np.mean over the\n",
        "  # history that add_average maintains (presumably the last METRICS_AVERAGE\n",
        "  # values - confirm); final_eval collects the reported values per metric.\n",
        "  if FLAGS.eval not in ('valid', 'test'):\n",
        "    # Fail before training rather than after the last epoch.\n",
        "    raise ValueError('Unknown eval split: {}'.format(FLAGS.eval))\n",
        "  eval_metrics = {'l_loss': [], 'r_loss': [], 'l_acc': [], 'r_acc': []}\n",
        "  final_eval = {'l_loss': [], 'r_loss': [], 'l_acc': [], 'r_acc': []}\n",
        "  for step in range(FLAGS.steps):\n",
        "    print('epoch: {}'.format(step))\n",
        "    if \"pcgrad\" not in FLAGS.method:\n",
        "      decay_lr(step, optimizer) # Halve the learning rate every 30 steps.\n",
        "    else:\n",
        "      decay_pcgrad_lr(step, lr_var)\n",
        "    epoch_l_loss = epoch_r_loss = batch_l_loss = batch_r_loss = 0.\n",
        "    for trainX, trainYL, trainYR in dataset.train:\n",
        "      if FLAGS.method == 'mtl':\n",
        "        batch_l_loss, batch_r_loss = train_step(trainX, trainYL, trainYR, first_step=(len(optimizer.variables()) == 0))\n",
        "      elif FLAGS.method == 'it_mtl':\n",
        "        batch_l_loss, batch_r_loss = train_it_step(trainX, trainYL, trainYR, first_step=(len(optimizer.variables()) == 0))\n",
        "      elif FLAGS.method == 'uncertainty':\n",
        "        batch_l_loss, batch_r_loss = train_uncertainty_step(\n",
        "            trainX, trainYL, trainYR)\n",
        "      elif FLAGS.method == 'it_uncertainty':\n",
        "        batch_l_loss, batch_r_loss = train_it_uncertainty_step(\n",
        "            trainX, trainYL, trainYR, first_step=(len(optimizer.variables()) == 0))\n",
        "      elif FLAGS.method == 'gradnorm':\n",
        "        batch_l_loss, batch_r_loss = train_gradnorm_step(\n",
        "            trainX, trainYL, trainYR)\n",
        "      elif FLAGS.method == 'it_gradnorm':\n",
        "        batch_l_loss, batch_r_loss = train_it_gradnorm_step(\n",
        "            trainX, trainYL, trainYR, first_step=(len(optimizer.variables()) == 0))\n",
        "      elif FLAGS.method == 'pcgrad':\n",
        "        batch_l_loss, batch_r_loss = train_pcgrad_step(\n",
        "            trainX, trainYL, trainYR)\n",
        "      elif FLAGS.method == 'mgda':\n",
        "        batch_l_loss, batch_r_loss = train_mgda_step(trainX, trainYL, trainYR, first_step=(len(optimizer.variables()) == 0))\n",
        "      elif FLAGS.method == 'uncertainty_pcgrad':\n",
        "        batch_l_loss, batch_r_loss = train_uncertainty_pcgrad_step(\n",
        "            trainX, trainYL, trainYR)\n",
        "      else:\n",
        "        raise ValueError('Unknown method chosen.')\n",
        "      epoch_l_loss += batch_l_loss / TRAIN_DATASET_SIZE\n",
        "      epoch_r_loss += batch_r_loss / TRAIN_DATASET_SIZE\n",
        "\n",
        "    print(\n",
        "        'total train loss: {:.4f} || left digit loss: {:.4f} || right digit loss: {:.4f}'\n",
        "        .format(((epoch_l_loss + epoch_r_loss)), epoch_l_loss, epoch_r_loss))\n",
        "\n",
        "    epoch_eval = {'l_loss': 0, 'r_loss': 0, 'l_acc': 0, 'r_acc': 0}\n",
        "    if FLAGS.eval_every_step or step == FLAGS.steps - 1:\n",
        "      eval_dataset = dataset.valid if FLAGS.eval == 'valid' else dataset.test\n",
        "      for evalX, evalYL, evalYR in eval_dataset:\n",
        "        eval_data = eval_step(evalX, evalYL, evalYR)\n",
        "        for key, val in zip(eval_data._fields, eval_data):\n",
        "          epoch_eval[key] += val\n",
        "      for key in epoch_eval:\n",
        "        epoch_eval[key] /= EVAL_DATASET_SIZE\n",
        "        add_average(eval_metrics[key], epoch_eval[key], METRICS_AVERAGE)\n",
        "        epoch_eval[key] = np.mean(eval_metrics[key])\n",
        "        final_eval[key].append(epoch_eval[key])\n",
        "      print('total eval loss: {:.4f}'.format(epoch_eval['l_loss'] +\n",
        "                                                      epoch_eval['r_loss']))\n",
        "      print(\n",
        "          'left digit accuracy: {:.2f} || right digit accuracy: {:.2f}'.format(\n",
        "              100.*epoch_eval['l_acc'], 100.*epoch_eval['r_acc']))\n",
        "\n",
        "  return final_eval"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cImVvSUgcGhD"
      },
      "source": [
        "# Hyper-parameters consumed by train(). alpha is method specific: e.g. it is the\n",
        "# right-task loss weight in the pcgrad step and the exponent in the gradnorm steps.\n",
        "Params = namedtuple(\"Params\", ['batch_size', 'lr', 'alpha'])\n",
        "params = Params(\n",
        "    batch_size=256,\n",
        "    lr=0.001,\n",
        "    alpha=0.5,\n",
        ")\n",
        "\n",
        "# Experiment switches. FLAGS.method is one of: 'mtl', 'it_mtl', 'uncertainty',\n",
        "# 'it_uncertainty', 'gradnorm', 'it_gradnorm', 'pcgrad', 'mgda',\n",
        "# 'uncertainty_pcgrad'. FLAGS.eval is 'valid' or 'test'; with\n",
        "# FLAGS.eval_every_step = False only the final step is evaluated.\n",
        "FLAGS.dataset = 'MultiFashion'\n",
        "FLAGS.method = 'pcgrad'\n",
        "FLAGS.steps = 100  # Number of passes over the training set.\n",
        "FLAGS.eval = 'test'\n",
        "FLAGS.eval_every_step = False\n",
        "# In the it_* steps, update the task heads before the lookahead.\n",
        "FLAGS.nesterov = True"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "urok0MhIvUER"
      },
      "source": [
        "# %%capture\n",
        "# Run the experiment configured above.\n",
        "tf.compat.v1.reset_default_graph()\n",
        "tf.compat.v1.enable_eager_execution()\n",
        "\n",
        "# The training steps create these task-weighting variables lazily (through\n",
        "# `global` declarations) whenever they are None, so reset them to get fresh\n",
        "# variables on every run. This cell already executes at module scope, so no\n",
        "# `global` statement is needed here.\n",
        "l_uncertainty = r_uncertainty = None  # Uncertainty-weighting variables.\n",
        "l_weight = r_weight = None  # GradNorm task weights.\n",
        "l_l0_loss = r_l0_loss = None  # GradNorm first-call reference losses.\n",
        "\n",
        "test = train(params)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U_-sSgXj959Y"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}