{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ICLR 2021 CelebA.ipynb",
      "provenance": [
        {
          "file_id": "1OplIUXiVWHjmeXr8A0FMe4kH8NLN1JcD",
          "timestamp": 1601336479922
        },
        {
          "file_id": "1mvp25SrzH1w4f8qrAnTHFZXJHZkL_JuN",
          "timestamp": 1589231089092
        }
      ],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_jij6FTtcxn"
      },
      "source": [
        "### CelebA Experiments for Information Transfer In Multi-Task Learning ICLR 2021 Submission \n",
        "\n",
        "Licensed under the Apache License, Version 2.0\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r14fs3V2XLOl"
      },
      "source": [
        "import itertools\n",
        "import pickle\n",
        "import time\n",
        "import copy\n",
        "import math\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "from collections import namedtuple, OrderedDict\n",
        "from tqdm import tqdm\n",
        "\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras.initializers import glorot_uniform\n",
        "from tensorflow.keras.layers import Activation, Add, AveragePooling2D, BatchNormalization, Conv2D, Dense, Flatten, MaxPooling2D, Lambda\n",
        "import scipy.integrate as it\n",
        "\n",
        "from absl import flags"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ooXuDeQVe_MH",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1601344955777,
          "user_tz": 420,
          "elapsed": 13299,
          "user": {
            "displayName": "Christopher Fifty",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiYYrGujgM-yL4Ol5R-LyAzmYiPF4BIVzCq-g4=s64",
            "userId": "11483318465600365808"
          }
        },
        "outputId": "c61e9e29-2f49-4a39-f547-9f3c6451309c"
      },
      "source": [
        "def del_all_flags(FLAGS):\n",
        "    \"\"\"Unregister every flag currently defined on `FLAGS`.\n",
        "\n",
        "    The flag names are snapshotted into a list first so that the registry\n",
        "    returned by `FLAGS._flags()` is not modified while it is being walked.\n",
        "    \"\"\"\n",
        "    registered_names = list(FLAGS._flags())\n",
        "    for flag_name in registered_names:\n",
        "        FLAGS.__delattr__(flag_name)\n",
        "# Remove every flag already registered on flags.FLAGS before (re)defining the flags below.\n",
        "# NOTE(review): presumably needed so this cell can be re-run without absl rejecting duplicate definitions -- confirm.\n",
        "del_all_flags(flags.FLAGS)\n",
        "\n",
        "FLAGS = flags.FLAGS\n",
        "\n",
        "# Experiment flags. A later configuration cell overrides these defaults by direct assignment.\n",
        "# 'order' == 1 limits permute / permute_list to the single tasks plus the one all-task\n",
        "# combination; any other value makes them enumerate every task combination.\n",
        "flags.DEFINE_integer('steps', 100, 'Number of epoch to train.')\n",
        "flags.DEFINE_integer('batch_size', 256, 'Number of examples in a minibatch.')\n",
        "flags.DEFINE_integer('order', -1, 'Order of permutations to consider.')\n",
        "flags.DEFINE_enum('eval', 'test', ['valid', 'test'], 'The eval dataset.')\n",
        "flags.DEFINE_enum('method', 'mtl', ['mtl', 'smart_mtl'],'Multitask Training Method.')\n",
        "flags.DEFINE_bool('eval_every_step', True, \"Whether or not to run eval every step.\")\n",
        "flags.DEFINE_list('tasks', ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', \n",
        "                   'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',\n",
        "                   'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', \n",
        "                   'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', \n",
        "                   'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open',\n",
        "                   'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin',\n",
        "                   'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', \n",
        "                   'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', \n",
        "                   'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', \n",
        "                   'Wearing_Necktie', 'Young'], \"The attributes to predict in CelebA.\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<absl.flags._flagvalues.FlagHolder at 0x7f2881958898>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jTH-4Vc5OX-F"
      },
      "source": [
        "# Constants shared by the model, data pipeline and training loop.\n",
        "SEED = 0  # Seed for the glorot_uniform initializers and for shuffling the training set.\n",
        "METRICS_AVERAGE = 5  # Number of most recent epochs kept in the running-average window for reported metrics.\n",
        "EPSILON = 0.001  # variance_epsilon passed to tf.nn.batch_normalization in the functional forward passes.\n",
        "# Number of examples per split; used as the shuffle buffer size (train) and to turn\n",
        "# per-batch losses / correct-prediction counts into per-epoch averages.\n",
        "TRAIN_SIZE = 162770\n",
        "VALID_SIZE = 19867\n",
        "TEST_SIZE = 19962"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nArMmed5iZmC"
      },
      "source": [
        "class ResBlock(tf.keras.Model):\n",
        "  \"\"\"Residual block: conv -> BN -> ReLU -> conv -> BN, added to a shortcut, then ReLU.\n",
        "\n",
        "  Args:\n",
        "    filters: Pair of output-channel counts for the first and second convolution.\n",
        "    kernel_size: Pair of kernel sizes for the first and second convolution.\n",
        "    strides: Strides of the first convolution, given as an int or a 2-sequence.\n",
        "      With strides of (1, 1) the shortcut is the identity; otherwise it is a\n",
        "      1x1 convolution + BN projection that uses the same strides, so both\n",
        "      paths produce the same spatial size.\n",
        "    name: Suffix used to build the names of the layers inside the block.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, filters, kernel_size, strides, name):\n",
        "    super(ResBlock, self).__init__()\n",
        "    # Normalize to a tuple so that an int or a list such as [1, 1] is compared\n",
        "    # correctly against (1, 1) when choosing the shortcut below.\n",
        "    strides = (strides, strides) if isinstance(strides, int) else tuple(strides)\n",
        "    self.conv1 = Conv2D(\n",
        "        filters=filters[0],\n",
        "        kernel_size=kernel_size[0],\n",
        "        strides=strides,\n",
        "        name='conv{}_1'.format(name),\n",
        "        kernel_initializer=glorot_uniform(seed=SEED),\n",
        "        padding='same',\n",
        "        use_bias=False)\n",
        "    self.bn1 = BatchNormalization(axis=3, name='bn{}_1'.format(name))\n",
        "    self.conv2 = Conv2D(\n",
        "        filters=filters[1],\n",
        "        kernel_size=kernel_size[1],\n",
        "        strides=(1,1),\n",
        "        name='conv{}_2'.format(name),\n",
        "        kernel_initializer=glorot_uniform(seed=SEED),\n",
        "        padding='same',\n",
        "        use_bias=False)\n",
        "    self.bn2 = BatchNormalization(axis=3, name='bn{}_2'.format(name))\n",
        "\n",
        "    if strides == (1,1):\n",
        "      # NOTE(review): the identity shortcut assumes the block input already has\n",
        "      # filters[1] channels -- confirm for any new call site.\n",
        "      self.shortcut = Lambda(lambda x : x)\n",
        "    else:\n",
        "      # Projection shortcut. It must downsample by the same strides as conv1,\n",
        "      # otherwise the two paths cannot be added.\n",
        "      self.shortcut = tf.keras.Sequential()\n",
        "      shortcut_conv = Conv2D(filters=filters[1], \n",
        "                             kernel_size=1, \n",
        "                             strides=strides, \n",
        "                             name='skip_conv{}_1'.format(name), \n",
        "                             kernel_initializer=glorot_uniform(seed=SEED),\n",
        "                             padding='valid',\n",
        "                             use_bias=False)\n",
        "      shortcut_bn = BatchNormalization(axis=3, name='skip_bn{}_1'.format(name))\n",
        "      self.shortcut.add(shortcut_conv)\n",
        "      self.shortcut.add(shortcut_bn)\n",
        "\n",
        "  def call(self, inputs):\n",
        "    \"\"\"Applies the block to `inputs` and returns the activated sum of both paths.\"\"\"\n",
        "    x = inputs\n",
        "    x = Activation('relu')(self.bn1(self.conv1(x)))\n",
        "    x = self.bn2(self.conv2(x))\n",
        "    # The shortcut is applied to the original block input, not to `x`.\n",
        "    x = Add()([x, self.shortcut(inputs)])\n",
        "    return Activation('relu')(x)\n",
        "\n",
        "class ResNet18(tf.keras.Model):\n",
        "  \"\"\"Shared base network: a 3x3 conv + BN stem, max pooling, and one ResBlock.\n",
        "\n",
        "  NOTE(review): despite the name, only the stem and a single residual block\n",
        "  are defined here.\n",
        "  \"\"\"\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    # Layers are created in this order on purpose: stem conv, stem BN, block.\n",
        "    self.conv1_1 = Conv2D(\n",
        "        name='conv1_1',\n",
        "        filters=64,\n",
        "        kernel_size=3,\n",
        "        strides=(1, 1),\n",
        "        padding='same',\n",
        "        use_bias=False,\n",
        "        kernel_initializer=glorot_uniform(seed=SEED))\n",
        "    self.bn1_1 = BatchNormalization(name='bn1_1', axis=3)\n",
        "    self.resblock_2 = ResBlock(\n",
        "        filters=[64, 64], kernel_size=[3, 3], strides=(1, 1), name='1')\n",
        "\n",
        "  def call(self, inputs):\n",
        "    \"\"\"Runs stem -> ReLU -> 3x3/2 max pool -> residual block.\"\"\"\n",
        "    stem = self.bn1_1(self.conv1_1(inputs))\n",
        "    activated = Activation('relu')(stem)\n",
        "    pooled = MaxPooling2D(pool_size=(3,3), strides=(2,2))(activated)\n",
        "    return self.resblock_2(pooled)\n",
        "\n",
        "class AttributeDecoder(tf.keras.Model):\n",
        "  \"\"\"Per-attribute head producing two logits from the shared representation.\n",
        "\n",
        "  The input is average pooled, flattened and passed through three dense\n",
        "  layers (1000 -> 500 -> 2 units). No activation is configured on any of them.\n",
        "  \"\"\"\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self.fc_a = Dense(units=1000, kernel_initializer=glorot_uniform(seed=SEED))\n",
        "    self.fc_b = Dense(units=500, kernel_initializer=glorot_uniform(seed=SEED))\n",
        "    self.fc1 = Dense(units=2, kernel_initializer=glorot_uniform(seed=SEED))\n",
        "\n",
        "  def call(self, inputs):\n",
        "    \"\"\"Returns the two-way logits for `inputs`.\"\"\"\n",
        "    pooled = AveragePooling2D(pool_size=(2,2), name='avg_pool')(inputs)\n",
        "    features = Flatten()(pooled)\n",
        "    hidden = self.fc_b(self.fc_a(features))\n",
        "    return self.fc1(hidden)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pf5H7GvHWVAN"
      },
      "source": [
        "def res_block_step(inputs, base_updated):\n",
        "  \"\"\"Functional forward pass of a stride-2 residual block with a projection shortcut.\n",
        "\n",
        "  Batch normalization uses the moments of the current batch.\n",
        "\n",
        "  Args:\n",
        "    inputs: Input feature map (assumed channels-last, since moments are taken\n",
        "      over axes 0-2 -- TODO confirm).\n",
        "    base_updated: Indexable collection of nine tensors. Entries 0-2 are the\n",
        "      kernel, gamma and beta of the first conv/BN, entries 3-5 those of the\n",
        "      second, and entries 6-8 those of the shortcut conv/BN.\n",
        "\n",
        "  Returns:\n",
        "    ReLU of the sum of the main path and the shortcut path.\n",
        "  \"\"\"\n",
        "  def conv_bn(x, offset, strides, padding):\n",
        "    # Convolution followed by batch norm, using the three weights at `offset`.\n",
        "    kernel = base_updated[offset]\n",
        "    conv = tf.nn.conv2d(x, kernel, strides=strides, padding=padding)\n",
        "    mean, variance = tf.nn.moments(conv, axes=[0,1,2])\n",
        "    gamma, beta = base_updated[offset + 1], base_updated[offset + 2]\n",
        "    return tf.nn.batch_normalization(conv, mean, variance, offset=beta, scale=gamma, variance_epsilon=EPSILON)\n",
        "\n",
        "  main_path = tf.nn.relu(conv_bn(inputs, 0, (2,2), \"SAME\"))\n",
        "  main_path = conv_bn(main_path, 3, (1,1), \"SAME\")\n",
        "  shortcut_path = conv_bn(inputs, 6, (2,2), \"VALID\")\n",
        "  return tf.nn.relu(main_path + shortcut_path)\n",
        "\n",
        "def base_step(inputs, base_updated):\n",
        "  \"\"\"Functional forward pass of the shared base using explicitly supplied weights.\n",
        "\n",
        "  Computes stem conv/BN -> ReLU -> 3x3/2 max pool -> identity-shortcut residual\n",
        "  block. Batch normalization uses the moments of the current batch.\n",
        "\n",
        "  Args:\n",
        "    inputs: Batch of input images (assumed channels-last -- TODO confirm).\n",
        "    base_updated: Indexable collection of nine tensors, consumed as three\n",
        "      (kernel, gamma, beta) triples: stem, first block conv, second block conv.\n",
        "      NOTE(review): presumably this matches the order of\n",
        "      ResNet18.trainable_weights -- verify against the caller.\n",
        "\n",
        "  Returns:\n",
        "    The activated output of the residual block.\n",
        "  \"\"\"\n",
        "  def conv_bn(x, offset, keepdims=False):\n",
        "    # Stride-1 SAME convolution followed by batch norm with the weights at `offset`.\n",
        "    conv = tf.nn.conv2d(x, base_updated[offset], strides=(1,1), padding=\"SAME\")\n",
        "    mean, variance = tf.nn.moments(conv, axes=[0,1,2], keepdims=keepdims)\n",
        "    gamma, beta = base_updated[offset + 1], base_updated[offset + 2]\n",
        "    return tf.nn.batch_normalization(conv, mean, variance, offset=beta, scale=gamma, variance_epsilon=EPSILON)\n",
        "\n",
        "  # ResNet Block 1 Output.\n",
        "  stem = tf.nn.relu(conv_bn(inputs, 0, keepdims=True))\n",
        "  res_block_1 = tf.nn.max_pool2d(stem, ksize=[1,3,3,1], strides=[1,2,2,1], padding=\"VALID\")\n",
        "\n",
        "  # ResNet Block 2: two conv/BN stages with the pooled stem output as identity shortcut.\n",
        "  hidden = tf.nn.relu(conv_bn(res_block_1, 3))\n",
        "  return tf.nn.relu(conv_bn(hidden, 6) + res_block_1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5yJVRK0cWaYK"
      },
      "source": [
        "def permute(losses):\n",
        "  \"\"\"Returns summed losses for combinations of the tasks in the loss dictionary.\n",
        "\n",
        "  Tasks are processed in sorted order and combination keys join the task\n",
        "  names with '_'. When FLAGS.order == 1 only the individual tasks plus the\n",
        "  single all-task combination are returned; for any other value every\n",
        "  non-empty combination of tasks is returned.\n",
        "  \"\"\"\n",
        "  first_order_only = FLAGS.order == 1\n",
        "  ordered_losses = OrderedDict(sorted(losses.items()))\n",
        "  combined = {}\n",
        "  for task, loss in ordered_losses.items():\n",
        "    additions = {task: loss}\n",
        "    if not first_order_only:\n",
        "      # Extend every combination found so far with the current task.\n",
        "      for earlier_name, earlier_loss in combined.items():\n",
        "        additions[\"{}_{}\".format(earlier_name, task)] = loss + earlier_loss\n",
        "    combined.update(additions)\n",
        "\n",
        "  if first_order_only:\n",
        "    combined[\"_\".join(ordered_losses.keys())] = sum(ordered_losses.values())\n",
        "  return combined\n",
        "\n",
        "\n",
        "def permute_list(lst):\n",
        "  \"\"\"Returns the names of the task combinations built from the task list.\n",
        "\n",
        "  Tasks are combined in sorted order and joined with '_'. When\n",
        "  FLAGS.order == 1 the result holds the individual tasks followed by the\n",
        "  single all-task combination; for any other value it holds every non-empty\n",
        "  combination of tasks.\n",
        "\n",
        "  Args:\n",
        "    lst: List of task names. It is not modified.\n",
        "\n",
        "  Returns:\n",
        "    A new list of combination names.\n",
        "  \"\"\"\n",
        "  # Sort a copy: callers pass shared lists (e.g. a flag value) that must not be reordered.\n",
        "  ordered_tasks = sorted(lst)\n",
        "  first_order_only = FLAGS.order == 1\n",
        "  rtn = []\n",
        "  for task in ordered_tasks:\n",
        "    tmp_lst = [task]\n",
        "    if not first_order_only:\n",
        "      # Extend every combination found so far with the current task.\n",
        "      tmp_lst.extend(\"{}_{}\".format(saved_task, task) for saved_task in rtn)\n",
        "    rtn += tmp_lst\n",
        "\n",
        "  if first_order_only:\n",
        "    rtn.append(\"_\".join(ordered_tasks))\n",
        "  return rtn\n",
        "\n",
        "def decay_lr(step, optimizer):\n",
        "  \"\"\"Halves `optimizer.lr` whenever `step + 1` is a multiple of 30.\"\"\"\n",
        "  if (step + 1) % 30 != 0:\n",
        "    return\n",
        "  optimizer.lr = optimizer.lr / 2.\n",
        "  print('Decreasing the learning rate by 1/2. New Learning Rate: {}'.format(optimizer.lr))\n",
        "\n",
        "def decay_pcgrad_lr(step, lr_var):\n",
        "  \"\"\"Assigns half of its current value to `lr_var` whenever `step + 1` is a multiple of 30.\"\"\"\n",
        "  if (step + 1) % 30 != 0:\n",
        "    return\n",
        "  halved = lr_var / 2.\n",
        "  lr_var.assign(halved)\n",
        "  print('Decreasing the learning rate by 1/2.')\n",
        "\n",
        "def add_average(lst, metrics_dict, n):\n",
        "  \"\"\"Appends `metrics_dict` to `lst`, keeping at most the `n` newest entries.\n",
        "\n",
        "  When `lst` already holds `n` entries the oldest one is dropped first.\n",
        "  Raises an Exception if `lst` holds more than `n` entries on entry.\n",
        "  \"\"\"\n",
        "  current_size = len(lst)\n",
        "  if current_size > n:\n",
        "    raise Exception('List size is greater than n. This should never happen.')\n",
        "  if current_size == n:\n",
        "    lst.pop(0)\n",
        "  lst.append(metrics_dict)\n",
        "\n",
        "def compute_average(metrics_list, n):\n",
        "  \"\"\"Averages per-task metrics over the entries currently in the window.\n",
        "\n",
        "  Args:\n",
        "    metrics_list: List of dicts mapping task name to a numeric metric. The\n",
        "      keys of the first entry determine the keys of the result.\n",
        "    n: Maximum window size. Kept for interface compatibility; the mean is\n",
        "      taken over the entries actually present, so a window that is not yet\n",
        "      full (fewer than `n` entries) is no longer biased towards zero.\n",
        "\n",
        "  Returns:\n",
        "    Dict mapping each task to its mean value, or {} when the list is empty.\n",
        "  \"\"\"\n",
        "  if not metrics_list:\n",
        "    return {}\n",
        "  # Divide by the number of recorded entries, not by the window capacity.\n",
        "  count = float(len(metrics_list))\n",
        "  rtn = {task:0. for task in metrics_list[0]}\n",
        "  for metric in metrics_list:\n",
        "    for task in metric:\n",
        "      rtn[task] += metric[task] / count\n",
        "  return rtn\n",
        "\n",
        "def load_dataset(batch_size):\n",
        "  \"\"\"Loads CelebA from TFDS as batched (attributes, 64x64 float image) datasets.\n",
        "\n",
        "  The training split is shuffled with a buffer of TRAIN_SIZE and reshuffled on\n",
        "  every iteration. Only the eval split selected by FLAGS.eval is loaded; the\n",
        "  other one is returned as None.\n",
        "\n",
        "  Args:\n",
        "    batch_size: Number of examples per batch for every split.\n",
        "\n",
        "  Returns:\n",
        "    A namedtuple `Dataset(train, valid, test)`.\n",
        "  \"\"\"\n",
        "  def to_labels_and_image(example):\n",
        "    # Convert the image to float32 and resize it to 64x64; labels come first.\n",
        "    image = tf.image.convert_image_dtype(example['image'], tf.float32)\n",
        "    return example['attributes'], tf.image.resize(image, [64, 64])\n",
        "\n",
        "  def load_split(split):\n",
        "    return tfds.load('celeb_a', split=split).map(to_labels_and_image)\n",
        "\n",
        "  final_train = load_split('train').shuffle(\n",
        "      buffer_size=TRAIN_SIZE, seed=SEED,\n",
        "      reshuffle_each_iteration=True).batch(batch_size)\n",
        "  # Save a bit of time by selectively loading the eval dataset\n",
        "  if FLAGS.eval == 'valid':\n",
        "    final_valid = load_split('validation').batch(batch_size)\n",
        "    final_test = None\n",
        "  elif FLAGS.eval == 'test':\n",
        "    final_valid = None\n",
        "    final_test = load_split('test').batch(batch_size)\n",
        "\n",
        "  Dataset = namedtuple('Dataset', ['train', 'valid', 'test'])\n",
        "  return Dataset(train=final_train, valid=final_valid, test=final_test)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jSJ7yW8HX8IU"
      },
      "source": [
        "def train(params):\n",
        "  \"\"\"Trains a shared base with one decoder per task and records per-task metrics.\n",
        "\n",
        "  A ResNet18 base and one AttributeDecoder per entry of FLAGS.tasks are trained\n",
        "  for FLAGS.steps epochs with SGD (momentum 0.9) on the summed task losses.\n",
        "  At every training step the effect of each task combination's gradient on\n",
        "  every task's loss is also measured (see `train_step`).\n",
        "\n",
        "  Args:\n",
        "    params: Object with an `lr` attribute used as the initial learning rate.\n",
        "\n",
        "  Returns:\n",
        "    A tuple `(final_metrics, gradient_metrics, smart_metrics)`:\n",
        "      final_metrics: dict with keys 'train_loss', 'eval_loss' and 'eval_acc',\n",
        "        each a list with one running-average dict (task -> value) per epoch.\n",
        "      gradient_metrics: dict mapping each task combination from permute_list to\n",
        "        a list with one dict (task -> accumulated gain) per epoch.\n",
        "      smart_metrics: dict of zero-filled lists; it is never updated here.\n",
        "  \"\"\"\n",
        "  print(params)\n",
        "\n",
        "  ResBase = ResNet18()\n",
        "  ResTowers = {task:AttributeDecoder() for task in FLAGS.tasks}\n",
        "\n",
        "  dataset = load_dataset(FLAGS.batch_size)\n",
        "  # Incremented once per train step; not read anywhere else in this function.\n",
        "  global_step = tf.Variable(0, trainable=False)\n",
        "  optimizer = tf.keras.optimizers.SGD(params.lr, momentum=0.9)\n",
        "\n",
        "  @tf.function()\n",
        "  def train_step(input, labels, first_step=False):\n",
        "    \"\"\"Runs one optimization step and measures per-combination task gains.\n",
        "\n",
        "    Args:\n",
        "      input: Batch of images (the name shadows the Python builtin).\n",
        "      labels: Dict mapping task name to the labels of the batch, used as the\n",
        "        `labels` argument of softmax_cross_entropy_with_logits.\n",
        "      first_step: When True the lookahead update is plain `lr * grad` and the\n",
        "        optimizer's momentum slots are not read.\n",
        "\n",
        "    Returns:\n",
        "      `(losses, task_gains)`: the per-task losses before the update, and a dict\n",
        "      mapping each task combination to a dict (task -> gain).\n",
        "    \"\"\"\n",
        "    with tf.GradientTape(persistent=True) as tape:\n",
        "      rep = ResBase(input, training=True)\n",
        "      preds = {task:model(rep, training=True) for (task, model) in ResTowers.items()}\n",
        "      losses = {task: tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n",
        "                      labels=labels[task], \n",
        "                      logits=preds[task]))\n",
        "                for task in labels}\n",
        "      loss = tf.add_n(list(losses.values()))\n",
        "\n",
        "      # Compute the gradient of the task-specific loss w.r.t. the shared base.\n",
        "      # NOTE(review): these tape.gradient calls run inside the tape context, so the\n",
        "      # persistent tape presumably records them as well -- confirm this is intended.\n",
        "      task_gains = {}\n",
        "      task_permutations = permute(losses)\n",
        "      combined_task_gradients = [(combined_task, tape.gradient(task_permutations[combined_task], ResBase.trainable_weights)) for combined_task in task_permutations]\n",
        "\n",
        "    # Lookahead: for each task combination, build the base weights that a single\n",
        "    # SGD step using only that combination's gradient would produce, re-run the\n",
        "    # forward pass with them and compare every task's loss to its current value.\n",
        "    for combined_task, task_gradient in combined_task_gradients:\n",
        "      if first_step:\n",
        "        base_update = [optimizer.lr*grad for grad in task_gradient]\n",
        "        base_updated = [param - update for param,update in zip(ResBase.trainable_weights, base_update)]\n",
        "      else:\n",
        "        # Momentum form of the update: momentum * slot - lr * grad.\n",
        "        # Note that `optimizer._momentum` is a private attribute of the optimizer.\n",
        "        base_update = [(optimizer._momentum*optimizer.get_slot(param, 'momentum') - optimizer.lr*grad) \n",
        "                        for param, grad in zip(ResBase.trainable_weights, task_gradient)]\n",
        "        base_updated = [param + update for param, update in zip(ResBase.trainable_weights, base_update)]\n",
        "      # base_step consumes the updated weights positionally.\n",
        "      task_update_rep = base_step(input, base_updated)\n",
        "      task_update_preds = {task:model(task_update_rep, training=True) for (task, model) in ResTowers.items()}\n",
        "      task_update_losses = {task: tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n",
        "                                  labels=labels[task], \n",
        "                                  logits=task_update_preds[task]))\n",
        "                            for task in labels}\n",
        "      # Gain = relative decrease of the task loss after the lookahead step, divided by the learning rate.\n",
        "      task_gain = {task:(1.0 - task_update_losses[task]/losses[task])/optimizer.lr for task in FLAGS.tasks}\n",
        "      task_gains[combined_task] = task_gain\n",
        "\n",
        "    # DO NOT apply Nesterov in normal mtl training.\n",
        "    # Each decoder is updated with the gradient of the summed loss w.r.t. its own weights.\n",
        "    for task,model in ResTowers.items():\n",
        "      task_grads = tape.gradient(loss, model.trainable_weights)\n",
        "      optimizer.apply_gradients(zip(task_grads, model.trainable_weights))\n",
        "    \n",
        "    # Apply the traditional MTL update since this is a normal train step.\n",
        "    base_grads = tape.gradient(loss, ResBase.trainable_weights)\n",
        "    optimizer.apply_gradients(zip(base_grads, ResBase.trainable_weights))\n",
        "\n",
        "    global_step.assign_add(1)\n",
        "    return losses, task_gains\n",
        "\n",
        "  @tf.function()\n",
        "  def eval_step(input, labels):\n",
        "    \"\"\"Evaluates one batch.\n",
        "\n",
        "    Returns:\n",
        "      An `Eval(losses, accuracies)` namedtuple of dicts keyed by task. `losses`\n",
        "      holds the mean cross-entropy of the batch; `accuracies` holds the number\n",
        "      of correct predictions in the batch (a count, not a ratio).\n",
        "    \"\"\"\n",
        "    rep = ResBase(input)\n",
        "    # The `model` loop variable is unused; the decoder is looked up by task name.\n",
        "    preds = {task:ResTowers[task](rep) for (task, model) in ResTowers.items()}\n",
        "    int_preds = {task:tf.math.argmax(preds[task], 1, tf.dtypes.int32) for task in labels}\n",
        "    int_labels = {task:tf.math.argmax(labels[task], 1, tf.dtypes.int32) for task in labels}\n",
        "    losses = {task: tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n",
        "                labels=tf.cast(labels[task], tf.float32), \n",
        "                logits=preds[task])) \n",
        "              for task in labels}\n",
        "    accuracies = {task:tf.math.count_nonzero(tf.equal(int_preds[task], int_labels[task])) for task in labels}\n",
        "    Eval = namedtuple('Eval', ['losses', 'accuracies'])\n",
        "    return Eval(losses, accuracies)\n",
        "\n",
        "  # Training Loop.\n",
        "  # `metrics` holds sliding windows of the last METRICS_AVERAGE epochs; `final_metrics`\n",
        "  # holds one averaged snapshot of those windows per epoch.\n",
        "  metrics = {'train_loss': [], 'eval_loss': [], 'eval_acc': []}\n",
        "  gradient_metrics = {task:[] for task in permute_list(FLAGS.tasks)}\n",
        "  final_metrics = {'train_loss': [], 'eval_loss': [], 'eval_acc': []}\n",
        "  # NOTE(review): smart_metrics is initialized to zeros and only printed / returned below.\n",
        "  smart_metrics = {task:[0 for _ in range(FLAGS.steps)] for task in permute_list(FLAGS.tasks)}\n",
        "\n",
        "  for step in range(FLAGS.steps):\n",
        "    print('epoch: {}'.format(step))\n",
        "    decay_lr(step, optimizer) # Halve the learning rate every 30 steps.\n",
        "    batch_train_loss = {task:0. for task in FLAGS.tasks}\n",
        "    batch_grad_metrics = {combined_task:{task:0. for task in FLAGS.tasks} for combined_task in gradient_metrics}\n",
        "    for labels, img in dataset.train:\n",
        "      # Keep only the configured tasks and one-hot encode their labels into two classes.\n",
        "      labels = {task:tf.keras.utils.to_categorical(labels[task], num_classes=2) for task in labels if task in FLAGS.tasks}\n",
        "      if FLAGS.method == 'mtl':\n",
        "         # first_step is True while the optimizer reports no variables.\n",
        "         losses, task_gains = train_step(img, labels, first_step=(len(optimizer.variables()) == 0))\n",
        "      else:\n",
        "        raise Exception(\"Unrecognized Method Selected.\")\n",
        "      \n",
        "      # Record batch-level training and gradient metrics. \n",
        "      # Each batch value is divided by the number of batches per epoch, so the sums are epoch means.\n",
        "      for combined_task,task_gain_map in task_gains.items():\n",
        "        for task,gain in task_gain_map.items():\n",
        "          batch_grad_metrics[combined_task][task] += gain.numpy() / (math.ceil(TRAIN_SIZE / FLAGS.batch_size))\n",
        "      for task,loss in losses.items():\n",
        "        batch_train_loss[task] += loss.numpy() / (math.ceil(TRAIN_SIZE / FLAGS.batch_size))\n",
        "\n",
        "    # Record epoch-level training and gradient metrics.\n",
        "    add_average(metrics['train_loss'], batch_train_loss, METRICS_AVERAGE)\n",
        "    for combined_task,task_gain_map in batch_grad_metrics.items():\n",
        "      gradient_metrics[combined_task].append(task_gain_map)\n",
        "\n",
        "    if FLAGS.eval_every_step or step == FLAGS.steps - 1:\n",
        "      batch_eval_loss = {task:0. for task in FLAGS.tasks}\n",
        "      batch_eval_acc = {task:0. for task in FLAGS.tasks}\n",
        "      for labels, img in dataset.test if FLAGS.eval == 'test' else dataset.valid:\n",
        "        labels = {task:tf.keras.utils.to_categorical(labels[task], num_classes=2) for task in labels if task in FLAGS.tasks}\n",
        "        eval_metrics = eval_step(img, labels)\n",
        "        for task in FLAGS.tasks:\n",
        "          EVAL_SIZE = TEST_SIZE if FLAGS.eval == 'test' else VALID_SIZE\n",
        "          # Loss: mean over batches. Accuracy: correct-prediction counts divided by the split size.\n",
        "          batch_eval_loss[task] += eval_metrics.losses[task].numpy() / (math.ceil(EVAL_SIZE / FLAGS.batch_size))\n",
        "          batch_eval_acc[task] += eval_metrics.accuracies[task].numpy() / EVAL_SIZE\n",
        "      add_average(metrics['eval_loss'], batch_eval_loss, METRICS_AVERAGE)\n",
        "      add_average(metrics['eval_acc'], batch_eval_acc, METRICS_AVERAGE)\n",
        "\n",
        "    # Snapshot the running average of every metric window for this epoch.\n",
        "    for metric in metrics:\n",
        "      final_metrics[metric].append(compute_average(metrics[metric], METRICS_AVERAGE))\n",
        "\n",
        "    print_train_loss = \"\\n\".join([\"{}: {:.4f}\".format(task, metric) for task, metric in final_metrics['train_loss'][-1].items()])\n",
        "    print(\"Train Loss:\\n{}\\n\".format(print_train_loss))\n",
        "\n",
        "    # NOTE(review): these print the complete accumulated structures on every epoch.\n",
        "    print(\"grad metrics for fun: {}\".format(gradient_metrics))\n",
        "    print(\"smart metrics for fun: {}\".format(smart_metrics))\n",
        "\n",
        "    if FLAGS.eval_every_step or step == FLAGS.steps - 1:\n",
        "      print_eval_loss = \"\\n\".join([\"{}: {:.4f}\".format(task, metric) for task, metric in final_metrics['eval_loss'][-1].items()])\n",
        "      print(\"Eval Loss:\\n{}\\n\".format(print_eval_loss))\n",
        "      print_eval_acc = \"\\n\".join([\"{}: {:.2f}\".format(task, 100.0*metric) for task, metric in final_metrics['eval_acc'][-1].items()])\n",
        "      print(\"Eval Accuracy:\\n{}\\n\".format(print_eval_acc))\n",
        "      print(\"\\n-------------\\n\")\n",
        "\n",
        "  return final_metrics, gradient_metrics, smart_metrics"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cImVvSUgcGhD"
      },
      "source": [
        "# Experiment configuration: overrides the flag defaults by direct assignment.\n",
        "Params = namedtuple(\"Params\", ['lr'])\n",
        "params = Params(lr=0.001)  # Initial SGD learning rate read by train().\n",
        "FLAGS.steps = 40\n",
        "FLAGS.batch_size = 256\n",
        "FLAGS.eval = 'test'\n",
        "FLAGS.method = 'mtl'\n",
        "FLAGS.order = 1  # Single tasks plus the one all-task combination (see permute / permute_list).\n",
        "FLAGS.eval_every_step = True\n",
        "FLAGS.tasks = ['5_o_Clock_Shadow', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Goatee', 'Mustache', 'No_Beard', 'Rosy_Cheeks', 'Wearing_Hat'] # 9 out of 40 attributes."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "urok0MhIvUER"
      },
      "source": [
        "# %%capture\n",
        "# Run a single training job with the configuration above.\n",
        "# `eval_metrics` receives train()'s first return value, which holds the per-epoch\n",
        "# 'train_loss' as well as 'eval_loss' and 'eval_acc'.\n",
        "tf.compat.v1.reset_default_graph()\n",
        "eval_metrics, gradient_metrics, smart_metrics = train(params)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "796UoYFkxtXt"
      },
      "source": [
        "# Print out the best total accuracy.\n",
        "# `eval_metrics['eval_acc']` holds one dict (task -> running-average accuracy) per epoch.\n",
        "# The epoch whose accuracies sum highest is reported; ties keep the earliest epoch.\n",
        "max_acc = 0\n",
        "best_eval_acc = {}\n",
        "step = -1  # Stays -1 when no epoch has a positive summed accuracy.\n",
        "for index,eval_acc_dict in enumerate(eval_metrics['eval_acc']):\n",
        "  eval_acc = sum(eval_acc_dict.values())\n",
        "  if eval_acc > max_acc:\n",
        "    best_eval_acc = eval_acc_dict \n",
        "    step = index\n",
        "    max_acc = eval_acc\n",
        "print(\"step: {}\".format(step))\n",
        "print(best_eval_acc)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NtwY_fFHvFjt"
      },
      "source": [
        "# Plot the per-task running-average train and eval loss, one point per epoch.\n",
        "plt.rcParams['figure.figsize'] = [10,10]\n",
        "x_axis = np.arange(FLAGS.steps)\n",
        "# Turn the per-epoch list of {task: loss} dicts into one list of values per task.\n",
        "task_train_losses = {task:[] for task in FLAGS.tasks}\n",
        "for metrics in eval_metrics['train_loss']:\n",
        "  for task,val in metrics.items():\n",
        "    task_train_losses[task].append(val)\n",
        "\n",
        "for task in task_train_losses:\n",
        "  plt.plot(x_axis, task_train_losses[task], label=\"TRAIN LOSS ({})\".format(task))\n",
        "\n",
        "# NOTE(review): assumes every epoch has an eval entry for every task (i.e. evaluation\n",
        "# ran every epoch); otherwise these lists are shorter than x_axis -- confirm.\n",
        "task_eval_losses = {task:[] for task in FLAGS.tasks}\n",
        "for metrics in eval_metrics['eval_loss']:\n",
        "  for task,val in metrics.items():\n",
        "    task_eval_losses[task].append(val)\n",
        "\n",
        "for task in task_eval_losses:\n",
        "  plt.plot(x_axis, task_eval_losses[task], label=\"EVAL LOSS ({})\".format(task))\n",
        "\n",
        "plt.legend(loc='upper right')\n",
        "plt.xlabel('epoch')\n",
        "plt.ylabel('cross entropy loss')\n",
        "plt.title('Comparison of Train and Eval Loss')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VR8b1fT7UBJt"
      },
      "source": [
        "# Plot the per-task running-average eval accuracy, one point per epoch.\n",
        "plt.rcParams['figure.figsize'] = [10,10]\n",
        "x_axis = np.arange(FLAGS.steps)\n",
        "# Turn the per-epoch list of {task: accuracy} dicts into one list of values per task.\n",
        "task_eval_acc = {task:[] for task in FLAGS.tasks}\n",
        "for metrics in eval_metrics['eval_acc']:\n",
        "  for task,val in metrics.items():\n",
        "    task_eval_acc[task].append(val)\n",
        "\n",
        "for task in task_eval_acc:\n",
        "  plt.plot(x_axis, task_eval_acc[task], label=\"EVAL ACC ({})\".format(task))\n",
        "\n",
        "plt.legend(loc='best')\n",
        "plt.xlabel('epoch')\n",
        "plt.ylabel('accuracy')\n",
        "plt.title('Eval Accuracy')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "77gR_mkqggTP"
      },
      "source": [
        "from matplotlib.patches import Patch\n",
        "from matplotlib.lines import Line2D\n",
        "\n",
        "# Plot, for every task combination, the per-epoch gain its gradient produces on each task,\n",
        "# and integrate each curve for the summary computed afterwards.\n",
        "legend_elements = []\n",
        "color_legend_set = {}\n",
        "marker_legend_set = {}\n",
        "# integrals[task][combined_task]: area under the gain curve of `combined_task`'s gradient on `task`.\n",
        "integrals = {}\n",
        "\n",
        "# SciPy renamed `cumtrapz` to `cumulative_trapezoid` and later dropped the old name;\n",
        "# use whichever one the installed version provides.\n",
        "cumulative_trapezoid = getattr(it, 'cumulative_trapezoid', None) or it.cumtrapz\n",
        "\n",
        "plt.rcParams['figure.figsize'] = [15,15]\n",
        "x_axis = np.arange(FLAGS.steps)\n",
        "markers = ['*', '^', 'o', 'v', '2', 'X', 'D', '*', '^', 'o', 'v', '2', 'X', 'D']\n",
        "colors = ['b', 'g', 'r', 'orange', 'k', 'pink', 'brown', 'b', 'g', 'r', 'orange', 'k', 'pink', 'brown']\n",
        "\n",
        "for i, combined_task in enumerate(gradient_metrics):\n",
        "  # One marker per task combination; wrap around instead of failing when there\n",
        "  # are more combinations than markers.\n",
        "  marker = markers[i % len(markers)]\n",
        "\n",
        "  # Turn the per-epoch list of {task: gain} dicts into one list of values per task.\n",
        "  combined_task_grad_dict = {task:[] for task in FLAGS.tasks}\n",
        "  for combined_task_grad_metrics in gradient_metrics[combined_task]:\n",
        "    for task,gain in combined_task_grad_metrics.items():\n",
        "      combined_task_grad_dict[task].append(gain)\n",
        "\n",
        "  # NOTE(review): the title is overwritten on every iteration, so the figure ends up\n",
        "  # titled with the last combination only.\n",
        "  plt.title(\"{} Gradient Effect on Model Performance\".format(combined_task))\n",
        "  for c_i, task in enumerate(combined_task_grad_dict):\n",
        "    # One color per affected task; wraps around like the markers.\n",
        "    color = colors[c_i % len(colors)]\n",
        "    if task not in color_legend_set:\n",
        "      color_legend_set[task] = True\n",
        "      legend_elements.append(Line2D([0], [0], color=color, lw=4, label='{}'.format(task)))\n",
        "    plt.plot(x_axis, combined_task_grad_dict[task], label=\"{}\".format(task), marker=marker, color=color, markersize=5.0)\n",
        "    if task not in integrals:\n",
        "      integrals[task] = {}\n",
        "    # The last cumulative value is the trapezoidal integral over all epochs.\n",
        "    integrals[task][combined_task] = cumulative_trapezoid(combined_task_grad_dict[task])[-1]\n",
        "  if combined_task not in marker_legend_set:\n",
        "    marker_legend_set[combined_task] = True\n",
        "    legend_elements.append(Line2D([0], [0], marker=marker, color='k', label='{}'.format(combined_task)))\n",
        "\n",
        "plt.legend(handles=legend_elements, loc='best')\n",
        "plt.xlabel('epoch')\n",
        "plt.ylabel('Gain')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F5yE9bpW_z53"
      },
      "source": [
        "# Keep only the effect of single-task gradients on each task.\n",
        "# The top-level keys of `integrals` are exactly the individual task names, so membership\n",
        "# in `integrals` identifies a single-task gradient regardless of how many underscores the\n",
        "# task names contain or how many tasks are configured.\n",
        "revised_integrals = {}\n",
        "for task in integrals:\n",
        "  revised_integrals[task] = {}\n",
        "  for gradient in integrals[task]:\n",
        "    if gradient in integrals: # Remove combined gradient updates on a given task.\n",
        "      revised_integrals[task][gradient] = integrals[task][gradient] \n",
        "\n",
        "# Normalize each value by the task-gradient on itself.\n",
        "# NOTE(review): a task whose own integral is zero makes this division degenerate.\n",
        "final_integrals = {}\n",
        "for task in revised_integrals:\n",
        "  final_integrals[task] = {}\n",
        "  self_task = revised_integrals[task][task]\n",
        "  for gradient in revised_integrals[task]:\n",
        "    final_integrals[task][gradient] = (self_task + revised_integrals[task][gradient])/self_task - 1.0\n",
        "    \n",
        "# Zero out same task gradient for radar chart.\n",
        "for task in final_integrals:\n",
        "  final_integrals[task][task] = -0.0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P9kQ3n8UCnkm"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.patches import Circle, RegularPolygon\n",
        "from matplotlib.path import Path\n",
        "from matplotlib.projections.polar import PolarAxes\n",
        "from matplotlib.projections import register_projection\n",
        "from matplotlib.spines import Spine\n",
        "from matplotlib.transforms import Affine2D\n",
        "from matplotlib import rc\n",
        "\n",
        "plt.rcParams['font.family'] = 'Roboto'\n",
        "\n",
        "def radar_factory(num_vars, frame='circle'):\n",
        "    \"\"\"Create a radar chart with `num_vars` axes.\n",
        "\n",
        "    This function creates a RadarAxes projection and registers it under the\n",
        "    projection name 'radar'.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    num_vars : int\n",
        "        Number of variables for radar chart.\n",
        "    frame : {'circle' | 'polygon'}\n",
        "        Shape of frame surrounding axes.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    theta : numpy.ndarray\n",
        "        `num_vars` evenly spaced angles in radians covering [0, 2*pi), one\n",
        "        per axis.\n",
        "\n",
        "    \"\"\"\n",
        "    # calculate evenly-spaced axis angles\n",
        "    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)\n",
        "\n",
        "    class RadarAxes(PolarAxes):\n",
        "        \"\"\"Polar axes whose plotted lines and fills are closed by default.\"\"\"\n",
        "\n",
        "        name = 'radar'\n",
        "        # use 1 line segment to connect specified points\n",
        "        RESOLUTION = 1\n",
        "\n",
        "        def __init__(self, *args, **kwargs):\n",
        "            super().__init__(*args, **kwargs)\n",
        "            # rotate plot such that the first axis is at the top\n",
        "            self.set_theta_zero_location('N')\n",
        "\n",
        "        def fill(self, *args, closed=True, **kwargs):\n",
        "            \"\"\"Override fill so that line is closed by default\"\"\"\n",
        "            return super().fill(closed=closed, *args, **kwargs)\n",
        "\n",
        "        def plot(self, *args, **kwargs):\n",
        "            \"\"\"Override plot so that line is closed by default\"\"\"\n",
        "            lines = super().plot(*args, **kwargs)\n",
        "            for line in lines:\n",
        "                self._close_line(line)\n",
        "            # Hand the Line2D objects back, as Axes.plot does, so callers can\n",
        "            # keep using them (e.g. as legend handles).\n",
        "            return lines\n",
        "\n",
        "        def _close_line(self, line):\n",
        "            \"\"\"Append the first point to `line` unless its first and last x already match.\"\"\"\n",
        "            x, y = line.get_data()\n",
        "            # An empty line has nothing to close; indexing x[0] would raise.\n",
        "            if len(x) == 0:\n",
        "                return\n",
        "            # FIXME: markers at x[0], y[0] get doubled-up\n",
        "            if x[0] != x[-1]:\n",
        "                x = np.concatenate((x, [x[0]]))\n",
        "                y = np.concatenate((y, [y[0]]))\n",
        "                line.set_data(x, y)\n",
        "\n",
        "        def set_varlabels(self, labels):\n",
        "            \"\"\"Label the axes at the `theta` angles with `labels`.\"\"\"\n",
        "            self.set_thetagrids(np.degrees(theta), labels)\n",
        "\n",
        "        def _gen_axes_patch(self):\n",
        "            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5\n",
        "            # in axes coordinates.\n",
        "            if frame == 'circle':\n",
        "                return Circle((0.5, 0.5), 0.5)\n",
        "            elif frame == 'polygon':\n",
        "                return RegularPolygon((0.5, 0.5), num_vars,\n",
        "                                      radius=.5, edgecolor=\"k\")\n",
        "            else:\n",
        "                raise ValueError(\"unknown value for 'frame': %s\" % frame)\n",
        "\n",
        "        def _gen_axes_spines(self):\n",
        "            if frame == 'circle':\n",
        "                return super()._gen_axes_spines()\n",
        "            elif frame == 'polygon':\n",
        "                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.\n",
        "                spine = Spine(axes=self,\n",
        "                              spine_type='circle',\n",
        "                              path=Path.unit_regular_polygon(num_vars))\n",
        "                # unit_regular_polygon gives a polygon of radius 1 centered at\n",
        "                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,\n",
        "                # 0.5) in axes coordinates.\n",
        "                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)\n",
        "                                    + self.transAxes)\n",
        "                return {'polar': spine}\n",
        "            else:\n",
        "                raise ValueError(\"unknown value for 'frame': %s\" % frame)\n",
        "\n",
        "    register_projection(RadarAxes)\n",
        "    return theta\n",
        "\n",
        "\n",
        "def example_data():\n",
        "    \"\"\"Assemble the radar-chart inputs from the module-level `final_integrals`.\n",
        "\n",
        "    Returns a two-element list. The first element is the list of keys of\n",
        "    `final_integrals` (the spoke labels). The second is a `(title, rows)`\n",
        "    tuple in which `rows` has one list per gradient key of the first task's\n",
        "    entry, holding that gradient's value for every task in spoke order.\n",
        "    \"\"\"\n",
        "    spoke_tasks = list(final_integrals)\n",
        "    gradient_names = final_integrals[list(final_integrals.keys())[0]]\n",
        "    rows = []\n",
        "    for grad in gradient_names:\n",
        "        rows.append([final_integrals[task][grad] for task in final_integrals])\n",
        "    chart = ('Transference in CelebA', rows)\n",
        "    return [spoke_tasks, chart]\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    # Draw one radar chart: one spoke per key of `final_integrals`, one outlined\n",
        "    # and filled polygon per gradient key.\n",
        "    N = len(final_integrals.keys())\n",
        "    theta = radar_factory(N, frame='polygon')\n",
        "\n",
        "    data = example_data()\n",
        "    spoke_labels = data.pop(0)\n",
        "    colors = ['b', 'r', 'g', 'k', 'magenta', 'gold', 'saddlebrown', 'mediumspringgreen', 'orange']\n",
        "    title, case_data = data[0]\n",
        "    # Rows of `case_data` are built by iterating the gradient keys of the first\n",
        "    # task's entry (see example_data), so this list lines up with them.\n",
        "    labels = [grad for grad in final_integrals[list(final_integrals.keys())[0]]]\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(9, 9), nrows=1, ncols=1,\n",
        "                          subplot_kw=dict(projection='radar'))\n",
        "    ax.tick_params(pad=20)\n",
        "\n",
        "    ax.set_rgrids([-1.0, -0.75, -0.5, -0.25, 0.0], angle=38.5, fontweight='heavy', color='black')\n",
        "    ax.set_title(title, weight='bold', size='medium', position=(0.5, 1.1),\n",
        "                  horizontalalignment='center', verticalalignment='center')\n",
        "    for idx, (label, d) in enumerate(zip(labels, case_data)):\n",
        "        # Cycle through the palette so no series is silently dropped when\n",
        "        # there are more gradients than colors.\n",
        "        color = colors[idx % len(colors)]\n",
        "        # Label the outline explicitly and keep the fill out of the legend.\n",
        "        # Passing a bare label list to `ax.legend` pairs labels with artists by\n",
        "        # position, which can attach them to the fill polygons instead.\n",
        "        ax.plot(theta, d, color=color, label=label)\n",
        "        ax.fill(theta, d, facecolor=color, alpha=0.25, label='_nolegend_')\n",
        "    ax.set_varlabels(spoke_labels)\n",
        "\n",
        "    legend = ax.legend(loc=(0.9, .95),\n",
        "                       labelspacing=0.1, fontsize='medium', title='Gradient Update')\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JwQU1I08MXrY"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}