{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HZkjHDQ_7Xte"
      },
      "source": [
        "# M-layer Robustness\n",
        "\n",
        "This notebook analyzes model robustness and validates robustness claims stated in \"Intelligent Matrix Exponentiation\"."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "usGerxU97S58"
      },
      "outputs": [],
      "source": [
        "# Imports and set-up.\n",
        "\n",
        "import collections\n",
        "import contextlib\n",
        "import hashlib\n",
        "import math\n",
        "import numbers\n",
        "import operator\n",
        "import os\n",
        "\n",
        "from matplotlib import pyplot\n",
        "import numpy\n",
        "import opt_einsum  # numpy.einsum() cannot handle fancy-type arrays, this can.\n",
        "import scipy.special\n",
        "\n",
        "import logging\n",
        "logging.getLogger('tensorflow').disabled = True\n",
        "\n",
        "import tensorflow.compat.v1.keras as keras\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mYwcwavVuQ43"
      },
      "source": [
        "Check integrity of  `cifar10_model.npy`. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BjKrFDTctHf6"
      },
      "outputs": [],
      "source": [
        "with open('cifar10_model.npy', 'rb') as h:\n",
        "        fp = hashlib.sha256(h.read()).hexdigest()\n",
        "        assert fp == ('abe15f55bffd5f26df664c87bd9a4db2',\n",
        "                      'd2cbc05dd086bc39c860cc66b7dc1d25'),\n",
        "                      \"Corrupted model file detected.\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XEQwed0G-tbw"
      },
      "outputs": [],
      "source": [
        "# Auxiliary Affine Arithmetic Python code (basic AA implementation).\n",
        "\n",
        "# Stack of dynamic \"environmentally active\" AAContext contexts.\n",
        "# See AANum.__repr__ for an explanation.\n",
        "_DynamicAAContexts = []  # pylint:disable=invalid-name\n",
        "\n",
        "\n",
        "def _merge_ess(linop, ess1, ess2):\n",
        "  \"\"\"Merges weighted-error-symbol collections.\"\"\"\n",
        "  ess = {}\n",
        "  for sym in set(ess1) | set(ess2):\n",
        "    coeff = linop(ess1.get(sym, 0), ess2.get(sym, 0))\n",
        "    if coeff:\n",
        "      ess[sym] = coeff\n",
        "  return ess\n",
        "\n",
        "\n",
        "class AAContext(contextlib.AbstractContextManager):\n",
        "  \"\"\"Affine Arithmetic Context.\n",
        "\n",
        "  Every AANumber refers to a context, which takes care of managing error\n",
        "  symbols. Only numbers from the same AAContext can be combined.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, name=None,\n",
        "               collapsing_threshold=0,\n",
        "               thorough=True,\n",
        "               max_num_symbols=float('inf')):\n",
        "    \"\"\"Initializes the instance.\n",
        "\n",
        "    Args:\n",
        "      name: The name of the context, for debugging.\n",
        "      collapsing_threshold: Error symbols smaller than\n",
        "        {largest error symbol} * {this factor} will get collected\n",
        "        into a new error symbol.\n",
        "      thorough: If True, work hard to properly keep track of products\n",
        "        of error symbols. If False, only do this for the 1st level\n",
        "        of products.\n",
        "      max_num_symbols: The maximal number of error symbols to keep\n",
        "        on one quantity.\n",
        "    \"\"\"\n",
        "    # Error symbol expansions.\n",
        "    # Key = Error symbol number.\n",
        "    # Value = `True` for fundamental error-symbols, or a sorted tuple of\n",
        "    # (symbol_id, power) of fundamental symbols and their power that\n",
        "    # the symbol was obtained from.\n",
        "    self._esym_expansions = {}\n",
        "    # The reverse mapping for higher-order error-symbols, to find an\n",
        "    # already-known error symbol given its expansion.\n",
        "    self._esym_by_expansion = {}\n",
        "    self._collapsing_threshold = float(collapsing_threshold)\n",
        "    self._max_num_symbols = float(max_num_symbols)\n",
        "    self._name = name\n",
        "    self._thorough = thorough\n",
        "    self._num_syms = 0\n",
        "    self._num_mul = 0\n",
        "\n",
        "  def __repr__(self):\n",
        "    return '\u003cAAContext name=%s\u003e' % self._name\n",
        "\n",
        "  def __enter__(self):\n",
        "    _DynamicAAContexts.append(self)\n",
        "    return self\n",
        "\n",
        "  def __exit__(self, exc_type, exc_value, traceback):\n",
        "    del exc_type, exc_value, traceback  # Unused.\n",
        "    _DynamicAAContexts.pop()\n",
        "\n",
        "  def product_symbol(self, sym_x, sym_y):\n",
        "    \"\"\"Finds an error symbol for the product of two error symbols.\"\"\"\n",
        "    expansion_x = self._esym_expansions.get(sym_x, False)\n",
        "    expansion_y = self._esym_expansions.get(sym_y, False)\n",
        "    if not self._thorough and not expansion_x is expansion_y is True:\n",
        "      # Non-thorough operation, higher product symbol.\n",
        "      ret = self._num_syms\n",
        "      self._num_syms += 1\n",
        "      return ret\n",
        "    #\n",
        "    if isinstance(expansion_x, bool):\n",
        "      expansion_x = ((sym_x, 1),)\n",
        "    if isinstance(expansion_y, bool):\n",
        "      expansion_y = ((sym_y, 1),)\n",
        "    accum_xy = {sym: power for sym, power in expansion_x}\n",
        "    for esym_y, power in expansion_y:\n",
        "      accum_xy[esym_y] = accum_xy.get(esym_y, 0) + power\n",
        "    expansion_xy = tuple(sorted(accum_xy.items()))\n",
        "    sym_xy = self._esym_by_expansion.get(expansion_xy)\n",
        "    if sym_xy is not None:\n",
        "      return sym_xy\n",
        "    # Otherwise, make a new entry.\n",
        "    sym_xy = self._num_syms\n",
        "    self._num_syms += 1\n",
        "    self._esym_expansions[sym_xy] = expansion_xy\n",
        "    self._esym_by_expansion[expansion_xy] = sym_xy\n",
        "    return sym_xy\n",
        "\n",
        "  def new_symbol(self, fundamental=False):\n",
        "    \"\"\"Generates a new independent error symbol, e.g. for collapsing.\"\"\"\n",
        "    sym = self._num_syms\n",
        "    self._num_syms += 1\n",
        "    if fundamental:\n",
        "      self._esym_expansions[sym] = True\n",
        "    return sym\n",
        "\n",
        "  def collapse_ess(self, ess):\n",
        "    \"\"\"Modifies ess by collapsing symbols with 'small' coefficients into one.\"\"\"\n",
        "    collapsing_threshold = self._collapsing_threshold\n",
        "    if not ess:\n",
        "      return  # Short-cut.\n",
        "    if collapsing_threshold \u003e 0:\n",
        "      threshold_abs_coeff = collapsing_threshold * max(\n",
        "          abs(x) for x in ess.values())\n",
        "      to_collapse = {sym: x for sym, x in ess.items()\n",
        "                     if abs(x) \u003c threshold_abs_coeff}\n",
        "    else:\n",
        "      to_collapse = set()\n",
        "    if len(ess) - len(to_collapse) + 1 \u003e self._max_num_symbols:\n",
        "      # Collapsing the symbols listed above will not get us to an acceptable\n",
        "      # number of symbols.\n",
        "      syms_by_coeff_magnitude = sorted(\n",
        "          (sym_x for sym_x in ess.items() if sym_x[0] not in to_collapse),\n",
        "          key=lambda sym_x: abs(sym_x[1]))\n",
        "      for sym, c in syms_by_coeff_magnitude[:-int(self._max_num_symbols)]:\n",
        "        to_collapse[sym] = c\n",
        "    if to_collapse:\n",
        "      collapsed_coeff = sum(map(abs, to_collapse.values()))\n",
        "      for sym in to_collapse:\n",
        "        del ess[sym]  # Removed by collapsing.\n",
        "      ess[self.new_symbol()] = collapsed_coeff\n",
        "\n",
        "  def num(self, val, radius=0):\n",
        "    \"\"\"Produces a new AANum with fresh error-symbol.\"\"\"\n",
        "    esym = self.new_symbol(fundamental=True)\n",
        "    return AANum(val, (esym, radius), context=self)\n",
        "\n",
        "\n",
        "class AANum(numbers.Number):\n",
        "  \"\"\"Affine-Arithmetic Number.\"\"\"\n",
        "\n",
        "  def __init__(self, val, *seq_esym_escale, context=None):\n",
        "    \"\"\"Initializes the instance.\"\"\"\n",
        "    if context is None:\n",
        "      if not _DynamicAAContexts:\n",
        "        raise RuntimeError('No AAContext available for initializing AANum.')\n",
        "      context = _DynamicAAContexts[-1]\n",
        "    self._context = context\n",
        "    self._val = val\n",
        "    self._ess = {esym: escale for esym, escale in seq_esym_escale}\n",
        "\n",
        "  @property\n",
        "  def radius(self):\n",
        "    return sum(map(abs, self._ess.values()))\n",
        "\n",
        "  def __str__(self):\n",
        "    return '\u003cAA %g +/- %g\u003e' % (self._val, self.radius)\n",
        "\n",
        "  def __repr__(self):\n",
        "    # Implementing __repr__ is tricky, since we cannot also serialize\n",
        "    # the context: The context must be identical for all the different\n",
        "    # AANum instances in an arithmetic expression.\n",
        "    # We resolve this by making __repr__ produce a representation that refers\n",
        "    # to the 'top dynamic context' which is set by using AAContext as\n",
        "    # a context-manager.\n",
        "    return 'AANum(%r, %s)' % (\n",
        "        self._val,\n",
        "        ', '.join(map(repr, sorted(self._ess.items()))))\n",
        "\n",
        "  def _linop(self, linop, other):\n",
        "    \"\"\"Lifts a linear operator to Affine Arithmetic.\"\"\"\n",
        "    if isinstance(other, AANum):\n",
        "      if self._context is not other._context:\n",
        "        raise ValueError('Cannot combine AANums from different contexts.')\n",
        "      result = AANum(linop(self._val, other._val), context=self._context)\n",
        "      ess = _merge_ess(linop, self._ess, other._ess)\n",
        "      self._context.collapse_ess(ess)\n",
        "      result._ess = ess\n",
        "      return result\n",
        "    else:\n",
        "      # Case: AANum +- {non-AA number}.\n",
        "      result = AANum(linop(self._val, other), context=self._context)\n",
        "      result._ess = _merge_ess(linop, self._ess, {})\n",
        "      return result\n",
        "\n",
        "  def __add__(self, other):\n",
        "    return self._linop(operator.add, other)\n",
        "\n",
        "  def __radd__(self, other):\n",
        "    return self.__add__(other)\n",
        "\n",
        "  def __sub__(self, other):\n",
        "    return self._linop(operator.sub, other)\n",
        "\n",
        "  def __rsub__(self, other):\n",
        "    return self._linop(lambda x, y: y - x, other)\n",
        "\n",
        "  def __mul__(self, other):\n",
        "    context = self._context\n",
        "    context._num_mul += 1\n",
        "    if not isinstance(other, AANum):\n",
        "      # Case: AANum * {non-AA number}.\n",
        "      result = AANum(self._val * other, context=context)\n",
        "      if other != 0:\n",
        "        result._ess = {sym: x * other for sym, x in self._ess.items()}\n",
        "      return result\n",
        "    # Otherwise, `other` is also an AANum.\n",
        "    if context is not other._context:\n",
        "      raise ValueError('Cannot multiply AANums from different contexts.')\n",
        "    xval = self._val\n",
        "    yval = other._val\n",
        "    # Scaling error symbol coefficients with xval and yval.\n",
        "    ess = {sym: x * yval for sym, x in self._ess.items()}\n",
        "    for sym, y in other._ess.items():\n",
        "      ess[sym] = ess.get(sym, 0) + xval * y\n",
        "    # Adding error symbol coefficients for products of error-symbols.\n",
        "    for sym_x, coeff_x in self._ess.items():\n",
        "      for sym_y, coeff_y in other._ess.items():\n",
        "        sym_xy = context.product_symbol(sym_x, sym_y)\n",
        "        ess[sym_xy] = ess.get(sym_xy, 0) + coeff_x * coeff_y\n",
        "    # Pruning\n",
        "    for sym_to_delete in [sym for sym, x in ess.items() if x == 0]:\n",
        "      del ess[sym_to_delete]  # Remove coefficient-zero entries.\n",
        "    # Collapsing small-coefficient symbols into one.\n",
        "    context.collapse_ess(ess)\n",
        "    result = AANum(xval * yval, context=context)\n",
        "    result._ess = ess\n",
        "    return result\n",
        "\n",
        "  def __rmul__(self, other):\n",
        "    return self.__mul__(other)\n",
        "\n",
        "  def __truediv__(self, other):\n",
        "    if not isinstance(other, AANum):\n",
        "      # Case: AANum / {non-AA number}.\n",
        "      result = AANum(self._val / other, context=self._context)\n",
        "      if other != 0:\n",
        "        result._ess = {sym: x / other for sym, x in self._ess.items()}\n",
        "      return result\n",
        "    # Dividing an AANum by another AANum is not implemented yet.\n",
        "    return NotImplemented\n",
        "\n",
        "  def __float__(self):\n",
        "    return float(self._val)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XwEYzEsb_CZY"
      },
      "outputs": [],
      "source": [
        "# Checking robustness claims from and around \"Intelligent Matrix Exponentiation\"\n",
        "\n",
        "NUM_CLASSES = 10\n",
        "\n",
        "\n",
        "def plot_loghist(xs, bins, filename=None, show=False):\n",
        "  \"\"\"Plots a histogram with logarithmic x-axis.\"\"\"\n",
        "  _, bins = numpy.histogram(xs, bins=bins)\n",
        "  log_bins = numpy.logspace(numpy.log10(bins[0]),\n",
        "                            numpy.log10(bins[-1]),\n",
        "                            len(bins))\n",
        "  fig = pyplot.figure()\n",
        "  axes = fig.gca()\n",
        "  axes.hist(xs, bins=log_bins, histtype='step')\n",
        "  axes.set_xscale('log')\n",
        "  axes.grid()\n",
        "  axes.set_title('Certified-Robustness Bounds Distribution')\n",
        "  if filename is not None:\n",
        "    fig.savefig(filename)\n",
        "  if show or filename is None:\n",
        "    pyplot.show()\n",
        "    fig.show()\n",
        "  return fig\n",
        "\n",
        "\n",
        "def sorted_xys(xs, ys):\n",
        "  \"\"\"Sorts features/labels into well-defined order.\n",
        "\n",
        "  tfds.load('cifar10') does not give us a guaranteed order of examples.\n",
        "  \"\"\"\n",
        "  fp_weights = numpy.random.RandomState(seed=0).uniform(\n",
        "      size=xs.size // xs.shape[0])\n",
        "  xys = sorted(zip(xs, ys),\n",
        "               key=lambda xy: numpy.dot(fp_weights, xy[0].reshape(-1)))\n",
        "  return (numpy.stack([x for x, _ in xys], axis=0),\n",
        "          numpy.stack([y for _, y in xys], axis=0))\n",
        "\n",
        "\n",
        "def load_cifar10():\n",
        "  \"\"\"Loads the CIFAR-10 dataset.\"\"\"\n",
        "  train = tfds.load('cifar10', split='train', with_info=False, batch_size=-1)\n",
        "  test = tfds.load('cifar10', split='test', with_info=False, batch_size=-1)\n",
        "  train_np = tfds.as_numpy(train)\n",
        "  test_np = tfds.as_numpy(test)\n",
        "  x_train, y_train = sorted_xys(train_np['image'], train_np['label'])\n",
        "  x_test, y_test = sorted_xys(test_np['image'], test_np['label'])\n",
        "  print(f'x_train shape: {x_train.shape}, x_test shape: {x_test.shape}')\n",
        "  y_train_cat = keras.utils.to_categorical(y_train, NUM_CLASSES)\n",
        "  y_test_cat = keras.utils.to_categorical(y_test, NUM_CLASSES)\n",
        "  x_train_range1 = x_train.astype('float32') / 255\n",
        "  x_test_range1 = x_test.astype('float32') / 255\n",
        "  return ((x_train_range1, y_train_cat), (x_test_range1, y_test_cat))\n",
        "\n",
        "\n",
        "def load_the_cifar10_model(filename='cifar10_model.npy'):\n",
        "  return np_load_arrays_from_file(\n",
        "      [('mb', (20, 20)), ('mk', (35, 20, 20)),\n",
        "       ('sb', (10,)), ('sk', (20, 20, 10)),\n",
        "       ('ub', (35,)), ('uk', (32, 32, 3, 35))],\n",
        "      filename)\n",
        "\n",
        "\n",
        "def _get_taylor_strategy(n_max, eye, m, prod=numpy.dot):\n",
        "  \"\"\"Finds out how to build x**N with low depth, given all the lower powers.\"\"\"\n",
        "  depth_and_tensor_power_by_exponent = [None] * (n_max + 1)\n",
        "  depth_and_tensor_power_by_exponent[0] = (0, eye)\n",
        "  depth_and_tensor_power_by_exponent[1] = (0, m)\n",
        "  for n in range(2, n_max + 1):\n",
        "    best_depth, best_k = min(\n",
        "        (1 + max(depth_and_tensor_power_by_exponent[k][0],\n",
        "                 depth_and_tensor_power_by_exponent[n - k][0]),\n",
        "         k) for k in range(1, n))\n",
        "    depth_and_tensor_power_by_exponent[n] = (\n",
        "        best_depth, prod(depth_and_tensor_power_by_exponent[best_k][1],\n",
        "                         depth_and_tensor_power_by_exponent[n - best_k][1]))\n",
        "  return depth_and_tensor_power_by_exponent\n",
        "\n",
        "\n",
        "def _expm_taylor(m, max_pow):\n",
        "  \"\"\"Matrix exponentiation via Taylor series.\"\"\"\n",
        "  m_id = numpy.eye(m.shape[0])\n",
        "  powers = _get_taylor_strategy(max_pow, m_id, m)\n",
        "  fact = 1\n",
        "  accum = m_id\n",
        "  for n in range(1, max_pow):\n",
        "    fact *= n\n",
        "    accum = accum + powers[n][1] / fact\n",
        "  return accum\n",
        "\n",
        "\n",
        "def expm(m, max_taylor_pow=8, max_abs_eigenvalue=0.5):\n",
        "  \"\"\"Approximate matrix exponentiation.\"\"\"\n",
        "  spectral_radius = max(abs(ev) for ev in numpy.linalg.eigvals(m.astype(float)))\n",
        "  num_halvings = max(\n",
        "      0, math.ceil(math.log(spectral_radius / max_abs_eigenvalue, 2)))\n",
        "  m_small = m * 0.5**num_halvings\n",
        "  ret = _expm_taylor(m_small, max_taylor_pow)\n",
        "  for _ in range(num_halvings):\n",
        "    ret = numpy.dot(ret, ret)\n",
        "  return ret\n",
        "\n",
        "\n",
        "def np_load_arrays_from_file(names_and_shapes, filename):\n",
        "  \"\"\"Loads a collection of numpy-arrays from a saved vector.\"\"\"\n",
        "  with open(filename, 'rb') as h:\n",
        "    all_data = numpy.load(h)\n",
        "  offset = 0\n",
        "  ret = {}\n",
        "  for name, shape in names_and_shapes:\n",
        "    a = numpy.zeros(shape)\n",
        "    a.flat = all_data[offset: offset + a.size]\n",
        "    offset += a.size\n",
        "    ret[name] = a\n",
        "  if offset != all_data.size:\n",
        "    raise ValueError('Data size mismatch.')\n",
        "  return ret\n",
        "\n",
        "\n",
        "def process_example(tensors, ex_img, expm=expm):\n",
        "  \"\"\"Processes an example.\"\"\"\n",
        "  v_lin_angles = opt_einsum.contract('yxca,yxc-\u003ea', tensors['uk'], ex_img)\n",
        "  v_aff_angles = v_lin_angles + tensors['ub']\n",
        "  m_lin_gen = opt_einsum.contract('amn,a-\u003emn', tensors['mk'], v_aff_angles)\n",
        "  m_aff_gen = m_lin_gen + tensors['mb']  # We could also absorb ub into mb.\n",
        "  m_exp = expm(m_aff_gen)\n",
        "  v_lin_exp = opt_einsum.contract('mn,mnp-\u003ep', m_exp, tensors['sk'])\n",
        "  v_aff_exp = v_lin_exp + tensors['sb']\n",
        "  float_weights = v_aff_exp.astype(float)\n",
        "  float_probabilities = numpy.exp(float_weights) / sum(numpy.exp(float_weights))\n",
        "  return {k: v for k, v in locals().items()}\n",
        "\n",
        "\n",
        "def fuzzify(aacontext, array, delta):\n",
        "  \"\"\"Fuzzifies a feature-array by adding uncertainty to every parameter.\"\"\"\n",
        "  return numpy.array([aacontext.num(x, radius=delta)\n",
        "                      for x in array.flat]).reshape(array.shape)\n",
        "\n",
        "\n",
        "def aa_process_example(tensors, ex_img, img_fuzz=1e-5,\n",
        "                       collapsing_threshold=1e-3,\n",
        "                       max_num_symbols=10,\n",
        "                       thorough=False):\n",
        "  \"\"\"Processes an example with Affine Arithmetic.\"\"\"\n",
        "  aac0 = AAContext(name='aac0',\n",
        "                   thorough=thorough,\n",
        "                   collapsing_threshold=collapsing_threshold,\n",
        "                   max_num_symbols=max_num_symbols)\n",
        "  fuzzed_img = fuzzify(aac0, ex_img, img_fuzz)\n",
        "  v_lin_angles = opt_einsum.contract('yxca,yxc-\u003ea', tensors['uk'], fuzzed_img)\n",
        "  aac1 = AAContext(name='aac1',\n",
        "                   thorough=thorough,\n",
        "                   collapsing_threshold=collapsing_threshold,\n",
        "                   max_num_symbols=max_num_symbols)\n",
        "  v_lin_angles_aa = numpy.array(\n",
        "      [aac1.num(float(x), x.radius) for x in v_lin_angles])\n",
        "  v_aff_angles_aa = v_lin_angles + tensors['ub']\n",
        "  m_lin_gen = opt_einsum.contract('amn,a-\u003emn', tensors['mk'], v_aff_angles_aa)\n",
        "  for aaz in m_lin_gen.flat:\n",
        "    aaz._context.collapse_ess(aaz._ess)\n",
        "  m_aff_gen = m_lin_gen + tensors['mb']  # We could also absorb ub into mb.\n",
        "  m_exp = expm(m_aff_gen)\n",
        "  v_lin_exp = opt_einsum.contract('mn,mnp-\u003ep', m_exp, tensors['sk'])\n",
        "  v_aff_exp = v_lin_exp + tensors['sb']\n",
        "  float_weights = v_aff_exp.astype(float)\n",
        "  float_probabilities = numpy.exp(float_weights) / sum(numpy.exp(float_weights))\n",
        "  return {k: v for k, v in locals().items()}\n",
        "\n",
        "\n",
        "def mspace_transform_tensors(tensors, mright_transform=None):\n",
        "  \"\"\"Applies M_right coordinate transformation to tensors.\"\"\"\n",
        "  uk = tensors['uk']\n",
        "  ub = tensors['ub']\n",
        "  mk = tensors['mk']\n",
        "  mb = tensors['mb']\n",
        "  sk = tensors['sk']\n",
        "  sb = tensors['sb']\n",
        "  imr_transform = numpy.linalg.inv(mright_transform)\n",
        "  mk = numpy.einsum('amn,nR,Qm-\u003eaQR',\n",
        "                    mk, mright_transform, imr_transform, optimize='greedy')\n",
        "  mb = numpy.einsum('mn,nR,Qm-\u003eQR',\n",
        "                    mb, mright_transform, imr_transform, optimize='greedy')\n",
        "  sk = numpy.einsum('mnp,Rn,mQ-\u003eQRp',\n",
        "                    sk, imr_transform, mright_transform, optimize='greedy')\n",
        "  return dict(uk=uk, ub=ub, mk=mk, mb=mb, sk=sk, sb=sb)\n",
        "\n",
        "\n",
        "def determine_robustness_aa(model_tensors,\n",
        "                            ex_img,\n",
        "                            Linf_bounds_to_check,\n",
        "                            coordinate_transform_to_example=False,\n",
        "                            **aa_kwargs):\n",
        "  \"\"\"Determines robustness bounds via AA.\"\"\"\n",
        "  if coordinate_transform_to_example:\n",
        "    processed = process_example(model_tensors, ex_img)\n",
        "    m_gen = processed['m_aff_gen']\n",
        "    _, e0g_eigvecsT = numpy.linalg.eig(m_gen)\n",
        "    model_tensors = mspace_transform_tensors(model_tensors, e0g_eigvecsT)\n",
        "  for fuzz in sorted(Linf_bounds_to_check, reverse=True):\n",
        "    aa_processed = aa_process_example(model_tensors, ex_img, img_fuzz=fuzz,\n",
        "                                      **aa_kwargs)\n",
        "    v_evidence = aa_processed['v_aff_exp']\n",
        "    m_margins = [[vj - vk for vk in v_evidence] for vj in v_evidence]\n",
        "    if any(all(float(v) \u003e= v.radius for v in row) for row in m_margins):\n",
        "      return fuzz  # For this fuzz, we have guaranteed robustness.\n",
        "  return 0\n",
        "\n",
        "\n",
        "def check_robustness_claims():\n",
        "  \"\"\"Checks claims about robustness from the paper.\"\"\"\n",
        "  (x_train, y_train), (x_test, y_test) = load_cifar10()\n",
        "  model_tensors = load_the_cifar10_model()\n",
        "  # Check that the model classifies a small set of examples as expected.\n",
        "  # Major problems (e.g. having loaded the wrong model file) would\n",
        "  # show up here.\n",
        "  probs = [\n",
        "      tuple(numpy.round(process_example(\n",
        "          model_tensors,\n",
        "          x_train[k],\n",
        "      )['float_probabilities'], 3))\n",
        "      for k in range(5)]\n",
        "  expected_probs = (\n",
        "      [(0.017, 0.035, 0.191, 0.072, 0.032, 0.013, 0.6, 0.026, 0.009, 0.006),\n",
        "       (0.003, 0.008, 0.123, 0.04, 0.703, 0.017, 0.046, 0.055, 0.001, 0.003),\n",
        "       (0.037, 0.036, 0.301, 0.041, 0.074, 0.012, 0.453, 0.024, 0.015, 0.009),\n",
        "       (0.266, 0.239, 0.345, 0.053, 0.001, 0.026, 0.026, 0.014, 0.014, 0.015),\n",
        "       (0.013, 0.174, 0.094, 0.076, 0.316, 0.034, 0.199, 0.028, 0.045, 0.02)])\n",
        "  assert probs == expected_probs, \\\n",
        "         'Model did not classify a small sample of examples as expected.'\n",
        "  #\n",
        "  test_weights = [\n",
        "      tuple(process_example(\n",
        "          model_tensors,\n",
        "          img,\n",
        "      )['float_weights'])\n",
        "      for img in x_test]\n",
        "  n_weights_of_correct_classifications = [\n",
        "      (n, weights) for n, (weights, y) in enumerate(zip(test_weights, y_test))\n",
        "      if y[numpy.argmax(weights)] == 1]\n",
        "  accuracy = len(n_weights_of_correct_classifications) / len(y_test)\n",
        "  assert accuracy \u003e 0.52  # Implicit claim: Accuracy of this small model is OK.\n",
        "  #\n",
        "  tu = numpy.einsum('yxca,amn-\u003eyxcmn',\n",
        "                    model_tensors['uk'],\n",
        "                    model_tensors['mk'])\n",
        "  m = numpy.einsum('yxcmn-\u003emn', abs(tu))\n",
        "  delta_in = numpy.linalg.svd(m @ m.T)[1][0]**.5\n",
        "  assert numpy.round(delta_in, 3) == 213.598, \\\n",
        "         'Claim L.376R: delta_in ~ 200'\n",
        "  s_2_norm = numpy.linalg.svd(model_tensors['sk'].reshape(-1, 10))[1][0]\n",
        "  assert numpy.round(s_2_norm, 3) == 3.153, 'Claim L.376R, |S|_2 ~ 3'\n",
        "  ms = numpy.einsum('yxcmn,byxc-\u003ebmn', tu, x_test)\n",
        "  m_2_norms = sorted([max(abs(ev) for ev in numpy.linalg.eigvals(m))\n",
        "                      for m in ms])\n",
        "  assert m_2_norms[int(0.98 * len(m_2_norms))] \u003c 4, \\\n",
        "         'Claim L.377R: |M|_2 \u003c 4 for 98% of test cases.'\n",
        "  # We actually can strengthen this claim.\n",
        "  top_evidences = [sorted(ws, reverse=True)[:2]\n",
        "                   for _, ws in n_weights_of_correct_classifications]\n",
        "  evidence_margins = sorted([first - second for first, second in top_evidences],\n",
        "                            reverse=True)\n",
        "  assert evidence_margins[int(0.63 * len(evidence_margins))] \u003e 1, \\\n",
        "         ('Claim L.380R: Evidence margins at least 1 for \u003e63% of '\n",
        "          'correctly-classified cases.')\n",
        "  #\n",
        "  def get_Linf_bound(ex_img):\n",
        "    \"\"\"Determines a L_infinity bound.\"\"\"\n",
        "    processed = process_example(model_tensors, ex_img)\n",
        "    m_gen = processed['m_aff_gen']\n",
        "    m_radius = numpy.linalg.svd(m_gen @ m_gen.T)[1][0]**.5\n",
        "    sensitivity_factor = (\n",
        "        m_gen.shape[0]**.5 * math.exp(m_radius) * s_2_norm)\n",
        "    weights = sorted(processed['float_weights'], reverse=True)\n",
        "    evidence_margin = 0.5 * (weights[0] - weights[1])\n",
        "    ret = scipy.special.lambertw(evidence_margin /\n",
        "                                 sensitivity_factor) / delta_in\n",
        "    assert abs(ret.imag) \u003c 1e-8\n",
        "    return ret.real\n",
        "  #\n",
        "  linf_bounds = [get_Linf_bound(x_test[n])\n",
        "                 for n, _ in n_weights_of_correct_classifications]\n",
        "  plot_loghist(linf_bounds, 50, filename=None)\n",
        "\n",
        "\n",
        "def check_analytic_formula_aligns_with_aa(num_samples=100):\n",
        "  \"\"\"Checks that the analytic formula compares well with AA.\n",
        "\n",
        "  This function asserts that it is not easy to outperform the\n",
        "  analytic result on guaranteed robustness with AA.\n",
        "  \"\"\"\n",
        "  (x_train, y_train), (x_test, y_test) = load_cifar10()\n",
        "  model_tensors = load_the_cifar10_model()\n",
        "  test_weights = [\n",
        "      tuple(process_example(\n",
        "          model_tensors,\n",
        "          img,\n",
        "      )['float_weights'])\n",
        "      for img in x_test]\n",
        "  n_weights_of_correct_classifications = [\n",
        "      (n, weights) for n, (weights, y) in enumerate(zip(test_weights, y_test))\n",
        "      if y[numpy.argmax(weights)] == 1]\n",
        "  #\n",
        "  rng = numpy.random.RandomState(seed=0)  # Make reproducible.\n",
        "  ns_correct = [n for n, _ in n_weights_of_correct_classifications]\n",
        "  rng.shuffle(ns_correct)\n",
        "  ns_correct_samples = ns_correct[:num_samples]\n",
        "  robustnesses = [\n",
        "      determine_robustness_aa(\n",
        "          model_tensors, x_test[n],\n",
        "          [1e-7,  # Skip: 3e-7\n",
        "           1e-6, 3e-6,\n",
        "           1e-5, 3e-5,\n",
        "           1e-4],\n",
        "          # Using more AA error-symbols has been found to only slightly\n",
        "          # improve bounds, while having a strong negative impact on\n",
        "          # runtime.\n",
        "          max_num_symbols=3, thorough=False)\n",
        "      for n in ns_correct_samples]\n",
        "  if num_samples == 100:\n",
        "    assert (sorted(collections.Counter(robustnesses).items()) ==\n",
        "            [(0, 1), (1e-07, 3), (1e-06, 13), (3e-06, 32), \n",
        "             (1e-05, 47), (3e-05, 4)\n",
        "             ]), 'AA-Robustnesses guarantees are not as expected.'\n",
        "  return list(zip(ns_correct, robustnesses))\n",
        "\n",
        "\n",
        "print('=== Checking Robustness Claims ===')\n",
        "check_robustness_claims()\n",
        "print('=== Formula/AA Comparison ===')\n",
        "check_analytic_formula_aligns_with_aa()\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "M-Layer Robustness",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
