install.packages(c('tensorflow','keras','reticulate','coro','tfprobability','httr'))
require(keras)
install_keras(version='gpu', extra_packages=c('tensorflow_probability==0.16','tensorflow-datasets','matplotlib'))

We use the R reticulate package to bridge the R and Python APIs. Internally, install_keras creates an r-miniconda environment under ~/.local/share/ and installs the Python packages there.
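If reticulate binds to a different Python than the one install_keras provisioned, the imports below will fail. A quick optional check (a sketch; the environment name is an assumption, adjust it to your install):

require(reticulate)
py_config()                                      # shows which Python/conda environment reticulate is using
# use_condaenv('r-reticulate', required = TRUE)  # uncomment to pin a specific environment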

In [ ]:
require(tensorflow)
require(keras)
require(reticulate)
Sys.setenv("TF_XLA_FLAGS" = "--tf_xla_enable_xla_devices")
# R is not supported natively in Colab; install your own R.
#use_python("/usr/local/bin/python")
require(tfprobability)
require(coro)
Loading required package: tensorflow

Loading required package: keras

Loading required package: reticulate

Loading required package: tfprobability

Loading required package: coro


Attaching package: ‘coro’


The following object is masked from ‘package:reticulate’:

    as_iterator


Duality between message passing and back propagation

In this section, we show that we can back-propagate the smoothing densities of a Bayesian neural network through message passing, and that we can also back-propagate the gradient of the loss with respect to the filter-density parameters. The two back-propagation methods give the same gradient of the loss with respect to the weight parameters. We use this example for two purposes:

  • to sanity-check Theorem 1 in the paper, and
  • to illustrate that autodiff can compute not only gradients, but also probability distributions.

We use tf.custom_gradient to pass smoothing probability distributions backward. We compare the smoothing densities of the latent features obtained from message passing (implemented in layer_dense_Bayesian) with those obtained from the gradient of the cross-entropy loss with respect to the mean and covariance of the latent-feature filter densities (computed with TensorFlow's autodiff in layer_dense_Bayesian_autodiff). The conversion is shown in the following code snippet, where x.update and P.update are the smoothing densities computed from gradients.

x.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[1]] + tf$einsum('bij,bi->bj', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[1]]))

P.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[2]] -
  tf$einsum('bij,bjk,blk->bil', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[2]], y_predBayesian_autodiff[[n]][[2]]) -
  tf$einsum('bi,bj->bij', x.update[[n]] - y_predBayesian_autodiff[[n]][[1]], x.update[[n]] - y_predBayesian_autodiff[[n]][[1]]))
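
The mechanism that makes this possible is that tf$custom_gradient lets grad_fn return any tensors of the right shapes, so the "gradients" flowing backward can carry distribution parameters instead of true derivatives. A minimal sketch of the pattern (a toy function, not part of the model above):

f <- tf$custom_gradient(function(x){
  y <- tf$square(x)
  grad_fn <- function(dy) dy * 0 + 42   # pass an arbitrary quantity backward
  list(y, grad_fn)
})
x <- tf$constant(3.0)
with(tf$GradientTape() %as% g, {
  g$watch(x)
  y <- f(x)
})
g$gradient(y, x)   # returns 42 rather than the true derivative 2*x = 6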
In [ ]:
mnist <- dataset_mnist()
In [ ]:
# negative cross entropy between two multivariate normals (batched over the first axis):
# CE.mvnorm(m0, Q0, m1, Q1) = E_{N(m0,Q0)} log N(x; m1, Q1) = -H(N(m0,Q0), N(m1,Q1))
CE.mvnorm = function(m0, Q0, m1, Q1){
  pinv.Q1 = tf$linalg$pinv(Q1)
  (tf$reduce_sum( pinv.Q1 * Q0, axis=c(1L, 2L)) + tf$einsum('bi,bij,bj->b', m1-m0, pinv.Q1, m1-m0) + tf$linalg$logdet(2 * pi * Q1)) * (-.5)
}
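
# Quick sanity check for CE.mvnorm (a sketch): with identical standard-normal arguments
# it returns minus the entropy, i.e. -k/2 * (1 + log(2*pi)) per batch element
# (approximately -2.8379 for k = 2).
k <- 2L
CE.mvnorm(tf$zeros(shape(3L, k)), tf$eye(k, batch_shape = list(3L)),
          tf$zeros(shape(3L, k)), tf$eye(k, batch_shape = list(3L)))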

layer_dense_Bayesian <- Layer(
  classname = "BayesDense",
  inherit = keras$layers$Layer,
  initialize = function(units = 7, f = tf$sigmoid, ...){
    super$initialize(...)
    self$units = as.integer(units)
    self$f = f
  },
  build = function(input_shape){
    #print(input_shape)
    super$build(input_shape)
    self$A = self$add_weight(name='A',shape=list(input_shape[[1]][[2]], self$units), initializer='normal', trainable=TRUE)
    self$B = self$add_weight(name='B',shape=list(input_shape[[1]][[2]], self$units), initializer='uniform', trainable=TRUE)
    self$W = tf$zeros_like(self$B[,1])
    self$`__forward__` = tf$custom_gradient(function(inputs){
      message('trace layer_dense_Bayesian forward ....')
      x = inputs[[1]]
      P0 = inputs[[2]]
      with(tf$GradientTape(persistent=TRUE) %as% g2, {
        with(tf$GradientTape(persistent=TRUE) %as% g, {
          g$watch(list(x,self$W))
          B.W = tf$einsum('ij,i->ij',self$B, self$W)
          h = self$f(tf$matmul(x, (self$A + B.W)))
          `__h__a` = self$f(tf$matmul(x, (self$A + B.W)))
          `__h__b` = self$f(tf$matmul(tf$stop_gradient(x), (self$A + B.W)))
        })

        dhdW = g$jacobian(`__h__b`, self$W)
        dhdx = tf$stop_gradient(g$batch_jacobian(`__h__a`, x))
        P = tf$matmul(tf$matmul(dhdx, P0), dhdx, transpose_b=TRUE) + tf$matmul(dhdW, dhdW, transpose_b=TRUE)
        P = P + tf$linalg$diag(1e-3*tf$ones_like(tf$linalg$diag_part(P[1,,])))
      })

      grad_fn = function(dh, dP, variables){
        message('trace layer_dense_Bayesian backward ....')
        with(g2, {
          h.update = dh
          P.update = dP
          G0 = tf$stop_gradient(tf$einsum('bip,bqp,bqj->bij', P0, dhdx, tf$linalg$pinv(P)))
          x.update = tf$stop_gradient(x + tf$einsum('bij,bj->bi',G0, h.update - h))
          P0.update = tf$stop_gradient(P0 - tf$einsum('bip,bpq,bjq->bij', G0, P - P.update, G0))

          nce = - tf$reduce_mean(CE.mvnorm(h.update - h - tf$einsum('bij,bj->bi',dhdx, x.update-x),
            tf$einsum('bip,bpq,bjq->bij',dhdx, P0.update, dhdx) + P.update - tf$einsum('bim,bqm,bjq->bij', P.update,G0,dhdx) - tf$einsum('bim,bqm,bjq->bji', P.update,G0,dhdx),
            tf$zeros_like(h),
            tf$matmul(dhdW, dhdW, transpose_b=TRUE) + tf$linalg$diag(1e-3*tf$ones_like(P[1,1,]))))
        })

        list(list(x.update, P0.update), g2$gradient(nce, variables))
      }
      list(list(h, P), grad_fn)
    })
  },
  call = function(inputs){
    self$`__forward__`(inputs)
  },
  compute_output_shape = function(input_shape){
    super()$compute_output_shape(input_shape)
  }
)

CategoricalObsModel = PyClass(classname = 'CategoricalObs', defs = list(
  inputs = NULL,
  shape = NULL,
  dtype=tf$float32,
  `__log.prob__` = NULL,
  `__init__` = function(self, inputs) {
    self$inputs = inputs
    self$shape = inputs[[1]]$shape
    self$`__log.prob__` = tf$custom_gradient(f = function(h, P, y){
      message('trace CategoricalObsModel forward ....')
      inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
      p = tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
      dydx = tf$stop_gradient(tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p))
      S = tf$einsum('bip,bpq,bjq->bij',dydx, P, dydx) + tf$einsum('b,ij->bij',tf$ones_like(inv.Z), diag(rep(.14,10))) #+ tf$linalg$diag( (tf$sign(.2 - tf$abs(y-p))+1)/2*.14 ) + diag(rep(.01, 10))# + diag(rep(.14,10))
      log.prob =   tf_probability()$distributions$MultivariateNormalFullCovariance( loc=p, covariance_matrix= S )$log_prob( y ) #
      grad_fn = function(dy, variables=NULL){
        message('trace CategoricalObsModel backward ....')
        # tf$assert_less(tf$abs(dy + 1.0), tf$constant(1e-3, dtype=tf$float32))
        K = tf$einsum('bip,bqp,bqj->bij', P, dydx, tf$linalg$pinv(S))
        h.update = tf$stop_gradient(h + tf$einsum('bij,bj->bi',K, y - p))
        P.update = tf$stop_gradient(P - tf$einsum('bip,bpq,bqj->bij', K, dydx, P))
        list(h.update, P.update, tf$ones_like(y))
      }

      list(log.prob, grad_fn)
    })
    NULL
  },
  `value` = function(self){
    h = tf_probability()$distributions$MultivariateNormalFullCovariance(loc=self$inputs[[1]], covariance_matrix=self$inputs[[2]])$sample()
    inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
    tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
  },
  loss = function(self, y){
    - self$`__log.prob__`(self$inputs[[1]], self$inputs[[2]], y)
  }
))

tf$register_tensor_conversion_function(CategoricalObsModel, function(x,...)x$`value`())
tf$keras$`__internal__`$utils$register_symbolic_tensor_type(CategoricalObsModel)
In [ ]:
layer_dense_Bayesian_autodiff <- Layer(
  classname = "BayesDense",
  inherit = keras$layers$Layer,
  initialize = function(units = 7, f = tf$sigmoid, ...){
    super$initialize(...)
    self$units = as.integer(units)
    self$f = f
  },
  build = function(input_shape){
    #print(input_shape)
    super$build(input_shape)
    self$A = self$add_weight(name='A',shape=list(input_shape[[1]][[2]], self$units), initializer='normal', trainable=TRUE)
    self$B = self$add_weight(name='B',shape=list(input_shape[[1]][[2]], self$units), initializer='uniform', trainable=TRUE)
    self$W = tf$zeros_like(self$B[,1])
    self$`__forward__` = function(inputs){
      message('trace layer_dense_Bayesian_autodiff forward ....')
      x = inputs[[1]]
      P0 = inputs[[2]]
      with(tf$GradientTape(persistent=TRUE) %as% g2, {
        with(tf$GradientTape(persistent=TRUE) %as% g, {
          g$watch(list(x,self$W))
          B.W = tf$einsum('ij,i->ij',self$B, self$W)
          h = self$f(tf$matmul(x, (self$A + B.W)))
          `__h__a` = self$f(tf$matmul(x, (self$A + B.W)))
          `__h__b` = self$f(tf$matmul(tf$stop_gradient(x), (self$A + B.W)))
        })

        dhdW = g$jacobian(`__h__b`, self$W)
        dhdx = tf$stop_gradient(g$batch_jacobian(`__h__a`, x))
        P = tf$matmul(tf$matmul(dhdx, P0), dhdx, transpose_b=TRUE) + tf$matmul(dhdW, dhdW, transpose_b=TRUE)
        P = P + tf$linalg$diag(1e-3*tf$ones_like(tf$linalg$diag_part(P[1,,])))
      })

      list(h, P)
    }
  },
  call = function(inputs){
    self$`__forward__`(inputs)
  },
  compute_output_shape = function(input_shape){
    super()$compute_output_shape(input_shape)
  }
)

CategoricalObsModel_autodiff = PyClass(classname = 'CategoricalObs', defs = list(
  inputs = NULL,
  shape = NULL,
  dtype=tf$float32,
  `__log.prob__` = NULL,
  `__init__` = function(self, inputs) {
    self$inputs = inputs
    self$shape = inputs[[1]]$shape
    self$`__log.prob__` = function(h, P, y){
      message('trace CategoricalObsModel_autodiff forward ....')
      inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
      p = tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
      dydx = tf$stop_gradient(tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p))
      S = tf$einsum('bip,bpq,bjq->bij',dydx, P, dydx) + tf$einsum('b,ij->bij',tf$ones_like(inv.Z), diag(rep(.14,10))) #+ tf$linalg$diag( (tf$sign(.2 - tf$abs(y-p))+1)/2*.14 ) + diag(rep(.01, 10))# + diag(rep(.14,10))
      log.prob =   tf_probability()$distributions$MultivariateNormalFullCovariance( loc=p, covariance_matrix= S )$log_prob( y ) #

      log.prob
    }
    NULL
  },
  `value` = function(self){
    h = tf_probability()$distributions$MultivariateNormalFullCovariance(loc=self$inputs[[1]], covariance_matrix=self$inputs[[2]])$sample()
    inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
    tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
  },
  loss = function(self, y){
    - self$`__log.prob__`(self$inputs[[1]], self$inputs[[2]], y)
  }
))#, inherit = tf$Module)

#a = CategoricalObsModel(layer_dense_Bayesian(units=10)(inputs))
#a$value()

tf$register_tensor_conversion_function(CategoricalObsModel_autodiff, function(x,...)x$`value`())
tf$keras$`__internal__`$utils$register_symbolic_tensor_type(CategoricalObsModel_autodiff)
In [ ]:
inputsBayesian <- list(layer_input(shape = 784L), layer_input(shape=tuple(784L)))
predictionsBayesian <- inputsBayesian %>%
  layer_dense_Bayesian(units=7) %>%
  layer_dense_Bayesian(units=7) %>%
  layer_dense_Bayesian(units=10, f = tf$identity) %>%
  layer_lambda(function(inputs) CategoricalObsModel(inputs))
modelBayesian <- keras_model(inputs = inputsBayesian, outputs = predictionsBayesian)

mylossBayesian = function(y, y_hat){
    y_hat$loss(y)
}
modelBayesian %>% compile(loss=mylossBayesian, optimizer=tf$keras$optimizers$SGD(learning_rate=.03), metrics='accuracy')

predictionsBayesian_autodiff = inputsBayesian %>%
  layer_dense_Bayesian_autodiff(units=7) %>%
  layer_dense_Bayesian_autodiff(units=7) %>%
  layer_dense_Bayesian_autodiff(units=10, f = tf$identity) %>%
  layer_lambda(function(inputs) CategoricalObsModel_autodiff(inputs))
modelBayesian_autodiff = keras_model(inputs = inputsBayesian, outputs = predictionsBayesian_autodiff)
modelBayesian_autodiff %>% compile(loss=mylossBayesian, optimizer=tf$keras$optimizers$SGD(learning_rate=.03), metrics='accuracy')
invisible(lapply(1:6, function(n) modelBayesian_autodiff$trainable_variables[[n]]$assign(modelBayesian$trainable_variables[[n]]) ))
tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

In [ ]:
n = 32L
indices = sample(dim(mnist$train$x)[1], size = n)
x = tf$constant(matrix(mnist$train$x[indices, , ] / 255.0, nrow = n))
y = tf$one_hot(mnist$train$y[indices], depth = 10L)
P0 = tf$einsum('b,ij->bij', tf$ones_like(x[,1]), tf$constant(diag(rep(1e-4, 784)), dtype=tf$float64) )

with(tf$GradientTape() %as% tape, {
    y_predBayesian = modelBayesian$call(list(x,P0))
    lossBayesian = tf$reduce_mean(mylossBayesian(y, y_predBayesian))
})
gradientsBayesian = tape$gradient(lossBayesian, modelBayesian$trainable_variables)

with(tf$GradientTape() %as% tape, {
  y_predBayesian_autodiff = modelBayesian_autodiff$call(list(x,P0))
  lossBayesian_autodiff = tf$reduce_mean(mylossBayesian(y, y_predBayesian_autodiff))
})
gradientsBayesian_autodiff = tape$gradient(lossBayesian_autodiff, modelBayesian_autodiff$trainable_variables)
tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

trace CategoricalObsModel forward ....

trace CategoricalObsModel backwards ....

tracking layer_dense_Bayesian backward ....

tracking layer_dense_Bayesian backward ....

tracking layer_dense_Bayesian backward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

trace CategoricalObsModel forward ....

In [ ]:
par('mfrow')
options('repr.plot.width')
options('repr.plot.height')
[1] 1 1
$repr.plot.width
[1] 7
$repr.plot.height
[1] 7

The following shows that whether we use TensorFlow's autodiff or back-propagate smoothing densities, we obtain the same gradient of the loss with respect to the variational mean and covariance of the weight parameters.
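A programmatic version of this check (a sketch; the plots below visualize the same comparison) simply looks at the largest elementwise discrepancy between the two gradient lists:

sapply(1:6, function(n) max(abs(gradientsBayesian[[n]]$numpy() - gradientsBayesian_autodiff[[n]]$numpy())))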

In [ ]:
#str(gradientsBayesian_autodiff)
#str(gradientsBayesian)

require(repr)
repr.plot.opt = options(repr.plot.width=11, repr.plot.height=16)
par.opt = par(mfrow=c(4,3))

for(n in 1:3){
  plot(gradientsBayesian_autodiff[[n*2-1]]$numpy(), gradientsBayesian[[n*2-1]]$numpy(), xlab='autodiff', ylab='custom_gradient', main=sprintf('d W_loc (layer %d)', n),asp=1)
  abline(coef=c(0,1),col='red')
}
for(n in 1:3){
  plot(gradientsBayesian_autodiff[[n*2]]$numpy(), gradientsBayesian[[n*2]]$numpy(), xlab='autodiff', ylab='custom_gradient', main=sprintf('d W_scale (layer %d)', n),asp=1)
  abline(coef=c(0,1),col='red')
}

for(n in 1:3){
  image(x = seq(dim(gradientsBayesian[[n*2-1]])[1]*2), y=seq(dim(gradientsBayesian[[n*2-1]])[2]), z = rbind(gradientsBayesian[[n*2-1]]$numpy(), gradientsBayesian_autodiff[[n*2-1]]$numpy()), main=sprintf('d W_loc (layer %d)',n), xlab='1:dim(input)', ylab='1:dim(output)', xaxt='n')
  abline(v=dim(gradientsBayesian[[n*2-1]])[1]+.5,col='black')
  axis(side=1, at = seq(dim(gradientsBayesian[[n*2-1]])[1]*2), label=rep(seq(dim(gradientsBayesian[[n*2-1]])[1]), times=2))
}

for(n in 1:3){
  image(x = seq(dim(gradientsBayesian[[n*2]])[1]*2), y=seq(dim(gradientsBayesian[[n*2]])[2]), z = rbind(gradientsBayesian[[n*2]]$numpy(), gradientsBayesian_autodiff[[n*2]]$numpy()), main=sprintf('d W_scale (layer %d)',n), xlab='1:dim(input)', ylab='1:dim(output)', xaxt='n')
  abline(v=dim(gradientsBayesian[[n*2]])[1]+.5,col='black')
  axis(side=1, at = seq(dim(gradientsBayesian[[n*2]])[1]*2), label=rep(seq(dim(gradientsBayesian[[n*2]])[1]), times=2))
}

par(par.opt)
options(repr.plot.opt)
Image

While we can obtain the gradient with respect to the parameters of the weight variational posterior from the smoothing distribution of the weights (treating the weights as latent features), in layer_dense_Bayesian we compute this gradient directly from the smoothing densities of the latent features, hoping to provide some deeper insight into back-propagating sensitivity with respect to functions versus sensitivity with respect to distributions.

nce = - tf$reduce_mean(CE.mvnorm(h.update - h - tf$einsum('bij,bj->bi',dhdx, x.update-x),
  tf$einsum('bip,bpq,bjq->bij',dhdx, P0.update, dhdx) + P.update - tf$einsum('bim,bqm,bjq->bij', P.update,G0,dhdx) - tf$einsum('bim,bqm,bjq->bji', P.update,G0,dhdx),
  tf$zeros_like(h),
  tf$matmul(dhdW, dhdW, transpose_b=TRUE) + tf$linalg$diag(1e-3*tf$ones_like(P[1,1,]))))

# the gradient with respect to the weight-distribution parameters is `g2$gradient(nce, variables)`

The derivation is below, where we treat a deep neural network as a linear Gaussian process and use the Kalman filter/smoother to back-propagate sensitivity over the filter distributions.

  • $x_t = F_t x_{t-1} + w_t$, where $w_t\sim {\cal N}(\vec 0, Q_t)$
  • $z_t = H_t x_t + v_t$, where $v_t\sim {\cal N}(\vec 0, R_t)$

The lower bound of the (latent-variable) log likelihood is the negative cross entropy plus the entropy under some proposal distribution: $\log p(y) = \log \sum_x p(x,y) \ge \mathbf E_{q(x)} \log p(x,y) - \mathbf E_{q(x)} \log q(x)$. The bound is tight because equality holds when $q(x)=p(x|y)$.
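For completeness, the bound follows from Jensen's inequality applied to the proposal $q$:

$$\log p(y) = \log \sum_x q(x)\,\frac{p(x,y)}{q(x)} \ge \sum_x q(x) \log \frac{p(x,y)}{q(x)} = \mathbf E_{q(x)} \log p(x,y) - \mathbf E_{q(x)} \log q(x).$$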

The cross-entropy between two probability distributions $p$ and $q$ measures the average number of bits needed to identify an event drawn from the true distribution $p$ with the coding scheme optimized for an estimated probability distribution $q$. $H(p, q) = - \mathbf E_p(\log q) = H(p)+D_{\mathrm{KL}}(p \| q)$. The K-L divergence between two multivariate normal distributions is $D_{\mathrm{KL}}\left(\mathcal{N}(\mu_{0},\Sigma_0) \| \mathcal{N}(\mu_{1},\Sigma_1)\right)=\frac{1}{2}\left\{\operatorname{tr}\left(\boldsymbol{\Sigma}_{1}^{-1} \boldsymbol{\Sigma}_{0}\right)+\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)^{\mathrm{T}} \boldsymbol{\Sigma}_{1}^{-1}\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)-k+\ln \frac{\left|\boldsymbol{\Sigma}_{1}\right|}{\left|\boldsymbol{\Sigma}_{0}\right|}\right\}$. The entropy of a multivariate normal distribution is $\frac{1}{2} \ln \operatorname{det}(2 \pi e \Sigma)$. So

$$\begin{aligned} H\left(\mathcal{N}(\mu_{0},\Sigma_0), \mathcal{N}(\mu_{1},\Sigma_1)\right) = & \frac{1}{2}\left\{\operatorname{tr}\left(\boldsymbol{\Sigma}_{1}^{-1} \boldsymbol{\Sigma}_{0}\right)+\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)^{\mathrm{T}} \boldsymbol{\Sigma}_{1}^{-1}\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)-k+\ln \frac{\operatorname{det}(\boldsymbol{\Sigma}_{1})}{\operatorname{det}(\boldsymbol{\Sigma}_{0})} + \ln \operatorname{det}(2 \pi e \Sigma_0) \right\}\\ = & \frac{1}{2}\left\{\operatorname{tr}\left(\boldsymbol{\Sigma}_{1}^{-1} \boldsymbol{\Sigma}_{0}\right)+\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)^{\mathrm{T}} \boldsymbol{\Sigma}_{1}^{-1}\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right) + \ln \operatorname{det}(2\pi \boldsymbol{\Sigma}_{1}) \right\} \end{aligned}$$
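
As a numerical cross-check (a sketch), the CE.mvnorm helper defined earlier returns exactly minus this cross entropy, and tfprobability exposes the same quantity through the distributions' cross_entropy method:

d0 <- tf_probability()$distributions$MultivariateNormalFullCovariance(
  loc = tf$zeros(shape(2L, 3L)), covariance_matrix = tf$eye(3L, batch_shape = list(2L)))
d1 <- tf_probability()$distributions$MultivariateNormalFullCovariance(
  loc = tf$ones(shape(2L, 3L)), covariance_matrix = 2 * tf$eye(3L, batch_shape = list(2L)))
d0$cross_entropy(d1)                                            # H(p0, p1)
- CE.mvnorm(d0$loc, d0$covariance(), d1$loc, d1$covariance())   # should match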

Expected log likelihood of $P(X_{1,\dots,T},Z_{1,\dots,T})$ over posterior distribution $P(X_{1,\dots,T} | Z_{1,\dots,T})$ is

$$\begin{aligned} & \mathbf{E}_{p(X_{1,\dots,T}|Z_{1,\dots,T})}\log P(X_{1,\dots,T},Z_{1,\dots,T}) = \log P(X_{1}) + \sum_{t=2}^{T} \log P(X_{t}|X_{t-1}) + \sum_{t=1}^{T} \log P(Z_{t}|X_{t})\\ = & {-\frac{1}{2}}\log\det(2\pi\Sigma_{0})-{1\over 2}\mathbf E_{p(x_{1}|z_{1,\dots,T})}\left(X_{1}-\mu_{0}\right)^{\text{T}}\Sigma_{0}^{-1}\left(X_{1}-\mu_{0}\right)+\\ & \sum_{t=2}^{T}{-\frac{1}{2}}\log\det(2\pi Q_t)-{1\over 2}\mathbf E_{p(x_{t-1,t}|z_{1,\dots,T})}\left(X_{t}-F_t\cdot X_{t-1}\right)^{\text{T}}Q^{-1}_t\left(X_{t}-F_t\cdot X_{t-1}\right)+\\ & \sum_{t=1}^{T}{-\frac{1}{2}}\log\det(2\pi R_t)-{1\over 2}\mathbf E_{p(x_{t}|z_{1,\dots,T})}\left(Z_{t}-H_t\cdot X_{t}\right)^{\text{T}}R^{-1}_t\left(Z_{t}-H_t\cdot X_{t}\right) \\ = & - H\left(\mathcal N(\hat x_{1|T}, P_{1|T}), \mathcal N(\mu_0, \Sigma_0)\right) + \\ & - H\left(\mathcal N\left(\hat x_{t|T} - F_t \hat x_{t-1|T}, \left(\begin{matrix} -F_t & I\end{matrix}\right) P_{t-1,t|T} \left(\begin{matrix}-F_t\\I\end{matrix}\right)\right), \mathcal N(X_t - F_t X_{t-1}; 0, Q_t)\right) + \\ & - H\left(\mathcal N(Z_t - H_t \hat x_{t|T}, H_t P_{t|T} H_t^\top), \mathcal N(Z_t - H_t X_t; 0, R_t)\right), \end{aligned}$$

where $P_{t,t+1|T} = \left(\begin{array}{c|c} P_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(P_{t+1|T}-P_{t+1|t}\right)P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}P_{t+1|T}\\ P_{t+1|T}P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t+1|T} \end{array}\right) = \left(\begin{array}{c|c} P_{t|T} & G_{t}P_{t+1|T}\\ P_{t+1|T}G_{t}^{\top} & P_{t+1|T} \end{array}\right)$

and hence

$\left(\begin{matrix}-F_{t+1} & I\end{matrix}\right)P_{t,t+1|T}\left(\begin{matrix}-F_{t+1}^{\top}\\ I \end{matrix}\right)=\left(\begin{matrix}-F_{t+1} & I\end{matrix}\right)\left(\begin{array}{c|c} P_{t|T} & G_{t}P_{t+1|T}\\ P_{t+1|T}G_{t}^{\top} & P_{t+1|T} \end{array}\right)\left(\begin{matrix}-F_{t+1}^{\top}\\ I \end{matrix}\right) = F_{t+1}P_{t|T}F_{t+1}^{\top} -F_{t+1}G_{t}P_{t+1|T} -P_{t+1|T}G_{t}^{\top}F_{t+1}^{\top} + P_{t+1|T}$

Derivation: the Kalman smoother gives $X_{t|T}=\hat{x}_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(x_{t+1|T}-F_{t+1}\hat{x}_{t|t}\right)+\epsilon$, where $\left(x_{t+1|T}-F_{t+1}\hat{x}_{t|t}\right)$ is the innovation brought forth by $x_{t+1|T}$, $P_{t+1|t}=F_{t+1}P_{t|t}F_{t+1}^{\top}+Q_{t+1}$ is the covariance of the innovation, $P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}$ is the smoothing gain, $\epsilon\sim\mathcal{N}(0,P_{t|t}-P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}F_{t+1}P_{t|t})$ is the noise term used to sample $X_{t|T}$ conditional on $X_{t+1|T}$, as in $p(X_t|x_{t+1},y_1,\dots,y_T)$, and $X_{t+1|T}\sim\mathcal{N}(\hat{x}_{t+1|T},P_{t+1|T})$. The joint distribution is

$$\begin{aligned} \left(\begin{array}{c} \mathbf{x}_{t|T}\\ \mathbf{x}_{t+1|T} \end{array}\right) & \sim\mathcal{N}\left(\left(\begin{array}{c} \hat{\mathbf{x}}_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(\hat{\mathbf{x}}_{t+1|T}-\hat{\mathbf{x}}_{t+1|t}\right)\\ \hat{\mathbf{x}}_{t+1|T} \end{array}\right),\left(\begin{array}{c|c} P_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(P_{t+1|T}-P_{t+1|t}\right)P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}P_{t+1|T}\\ P_{t+1|T}P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t+1|T} \end{array}\right)\right),\\ & \text{where }P_{t+1|t}=F_{t+1}P_{t|t}F_{t+1}^{\top}+Q_{t+1},\hat{\mathbf{x}}_{t+1|t}=F_{t+1}\hat{x}_{t|t}. \end{aligned}$$

In [ ]:
modelBayesian2 = keras_model(inputs = inputsBayesian,
  outputs = c(list(inputsBayesian), lapply(modelBayesian$layers[-2:-1], function(n) n$output)))

with(tf$GradientTape() %as% tape, {
  #tape$watch(inputsBayesian$)
  y_predBayesian = modelBayesian2$call(list(x,P0))
  tape$watch(y_predBayesian[1:4])
  lossBayesian = tf$reduce_mean(mylossBayesian(y, y_predBayesian[[5]]))
})
sensitivityBayesian = tape$gradient(lossBayesian, y_predBayesian[1:4])

modelBayesian_autodiff2 = keras_model(inputs = inputsBayesian,
  outputs = c(list(inputsBayesian), lapply(modelBayesian_autodiff$layers[-2:-1], function(n) n$output)))

with(tf$GradientTape() %as% tape, {
  tape$watch(P0)
  y_predBayesian_autodiff = modelBayesian_autodiff2$call(list(x,P0))
  tape$watch(y_predBayesian_autodiff[1:4])
  lossBayesian_autodiff = tf$reduce_mean(mylossBayesian(y, y_predBayesian_autodiff[[5]]))
})
sensitivityBayesian_autodiff = tape$gradient(lossBayesian_autodiff, y_predBayesian_autodiff[1:4])

x.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[1]] + tf$einsum('bij,bi->bj', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[1]]))

P.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[2]] -
  tf$einsum('bij,bjk,blk->bil', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[2]], y_predBayesian_autodiff[[n]][[2]]) -
  tf$einsum('bi,bj->bij', x.update[[n]] - y_predBayesian_autodiff[[n]][[1]], x.update[[n]] - y_predBayesian_autodiff[[n]][[1]]))
tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

trace CategoricalObsModel forward ....

trace CategoricalObsModel backwards ....

tracking layer_dense_Bayesian backward ....

tracking layer_dense_Bayesian backward ....

tracking layer_dense_Bayesian backward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

tracking layer_dense_Bayesian forward ....

trace CategoricalObsModel forward ....

The following shows that whether we use TensorFlow's autodiff or back-propagate smoothing densities, we obtain the same smoothing densities of the latent features. Because there is a one-to-one correspondence between the smoothing densities and the sensitivities with respect to the latent-feature filter-density parameters, message passing and back propagation give the same sensitivities.
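A programmatic version of this comparison (a sketch; the plots below show the same thing):

sapply(1:4, function(n) max(abs(x.update[[n]]$numpy() - sensitivityBayesian[[n]][[1]]$numpy())))
sapply(1:4, function(n) max(abs(P.update[[n]]$numpy() - sensitivityBayesian[[n]][[2]]$numpy())))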

In [ ]:
require(repr)
repr.plot.opt = options(repr.plot.width=16, repr.plot.height=4)
par.opt = par(mfcol=c(1,4))

for(n in 1:4){
    plot(x.update[[n]]$numpy(), sensitivityBayesian[[n]][[1]]$numpy(), xlab='autodiff', ylab='custom_gradient', main='smoothing mean',asp=1)
    abline(coef=c(0,1),col='red')
}
for(n in 1:4){
    plot(P.update[[n]]$numpy(), sensitivityBayesian[[n]][[2]]$numpy(), xlab='autodiff', ylab='custom_gradient', main='smoothing var',asp=1)
    abline(coef=c(0,1),col='red')
}
par(par.opt)
options(repr.plot.opt)
Loading required package: repr

Image
Image
In [ ]:
rm(list=ls())

Duality between Langevin dynamics and Fokker-Planck dynamics

In this section, we show that we can forward-propagate the filter densities of post-activations and back-propagate the smoothing densities in a Bayesian neural network, and that we can also forward-propagate post-activation tensor "particles" in an ensemble of vanilla (non-Bayesian) neural networks sampled from the BNN and back-propagate the sensitivities of these tensor particles. The two approximate methods should ultimately agree on the gradient of the loss with respect to the weight parameters as the variational bias and the Monte Carlo variance go to 0.

We also show that the smoothing densities contain information about the sensitivities of the post-activation particles, i.e., how the particles should move to decrease the loss. In other words, we can estimate the sensitivities of the post-activations of a vanilla NN from the smoothing densities of a Bayesian NN.
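For reference, the classical duality the section title alludes to is that the overdamped Langevin dynamics $dX_t = -\nabla U(X_t)\,dt + \sqrt{2}\,dW_t$ moves individual particles, while the Fokker-Planck equation $\partial_t p = \nabla\cdot\left(p\,\nabla U\right) + \Delta p$ moves their density; both describe the same stochastic process. The experiments below instantiate the analogous particle/density duality for a Bayesian neural network.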

We use this example for two purposes:

  • to sanity-check Theorem 2 in the paper, and
  • to illustrate that the power of a deep neural network lies in the randomness of the synaptic weights: the information between input and output rides on the random synaptic weights.
In [ ]:
require(tensorflow)
require(keras)
require(reticulate)
require(tfprobability)
require(coro)

mnist = dataset_mnist()

n = 32L
indices = 1:n
x = tf$constant(matrix(mnist$train$x[indices,,]/255.0, nrow = n), dtype=tf$float32)
y = tf$one_hot(mnist$train$y[indices], depth = 10L)
P0 = tf$einsum('b,ij->bij', tf$ones_like(x[,1]), diag(rep(1e-4, 784)))
#P0 = tf$einsum('b,ij->bij', tf$ones_like(x[,1]), diag(rep(0, 784)))
Loading required package: tensorflow

Loading required package: keras

Loading required package: reticulate

Loading required package: tfprobability

Loading required package: coro


Attaching package: ‘coro’


The following object is masked from ‘package:reticulate’:

    as_iterator


In [ ]:
layer_dense_Bayesian = Layer(
  classname = 'BayesianDense',
  inherit = tf$keras$layers$Dense,
  initialize = function(units = 7L, activation, name=NULL) {
    super$initialize(units=units, activation=activation, use_bias=TRUE, name=name)
    self$input_spec = list(tf$keras$layers$InputSpec(min_ndim = 2L),tf$keras$layers$InputSpec(min_ndim = 3L))
  },

  build = function(input_shape) {
    c(shp.mean, shp.var) %<-% input_shape
    last_dim = shp.mean[[-1]]
    self$A = self$add_weight(name='A',shape=list(last_dim, self$units), initializer='normal', trainable=TRUE)
    self$B = self$add_weight(name='B',shape=list(last_dim, self$units), initializer=tf$initializers$RandomUniform(minval = -.01, maxval = .01), trainable=TRUE)
    self$W = tf$zeros_like(self$B[,1])
    self$built=TRUE
    NULL
  },

  call = function(inputs) {
    c(x, P0) %<-% inputs
    with(tf$GradientTape(persistent=TRUE) %as% g, {
      self$bias = tf$zeros_like(self$B[1,])
      g$watch(list(x, self$W, self$bias))
      self$kernel = self$A + tf$einsum('ij,i->ij',self$B, self$W) # mvnorm with mean and variance
      h = super$call(x)
      `__h__a` = super$call(x)
      `__h__b` = super$call(tf$stop_gradient(x))
    })
    # propagate variance (accept a variance, first order approximation of state transition, output a variance)
    dhdW = g$jacobian(`__h__b`, self$W)
    dhdx = tf$stop_gradient(g$batch_jacobian(`__h__a`, x))
    dhdB = tf$stop_gradient(g$jacobian(`__h__a`, self$bias))
    P = tf$matmul(tf$matmul(dhdx, P0), dhdx, transpose_b=TRUE) + tf$matmul(dhdW, dhdW, transpose_b=TRUE) + tf$linalg$diag(tf$einsum('bij,j->bi',dhdB^2, 1e-4*tf$ones_like(self$bias)))
    list(h, P)
  }
)

CategoricalObsModel_autodiff = PyClass(classname = 'CategoricalObs', defs = list(
  inputs = NULL,
  shape = NULL,
  dtype=tf$float32,
  `__log.prob__` = NULL,
  `__init__` = function(self, inputs) {
    self$inputs = inputs
    self$shape = inputs[[1]]$shape
    self$`__log.prob__` = function(h, P, y){
      message('trace CategoricalObsModel forward ....')
      inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
      p = tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
      dydx = tf$stop_gradient(tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p))
      S = tf$einsum('bip,bpq,bjq->bij',dydx, P, dydx) + tf$einsum('b,ij->bij',tf$ones_like(inv.Z), diag(rep(.14,10))) #+ tf$linalg$diag( (tf$sign(.2 - tf$abs(y-p))+1)/2*.14 ) + diag(rep(.01, 10))# + diag(rep(.14,10))
      log.prob =   tf_probability()$distributions$MultivariateNormalFullCovariance( loc=p, covariance_matrix= S )$log_prob( y ) #

      log.prob
    }
    NULL
  },
  `value` = function(self){
    h = tf_probability()$distributions$MultivariateNormalFullCovariance(loc=self$inputs[[1]], covariance_matrix=self$inputs[[2]])$sample()
    inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
    tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
  },
  loss = function(self, y){
    - self$`__log.prob__`(self$inputs[[1]], self$inputs[[2]], y)
  }
))

tf$register_tensor_conversion_function(CategoricalObsModel_autodiff, function(x,...)x$`value`())
tf$keras$`__internal__`$utils$register_symbolic_tensor_type(CategoricalObsModel_autodiff)
In [ ]:
inputsBayesian <- list(layer_input(shape = 784L, name='x0'), layer_input(shape=tuple(784L, 784L), name='P0'))
predictionsBayesian <- inputsBayesian %>%
  layer_dense_Bayesian(units=7L, activation = tf$nn$sigmoid, name='bayesian_dense_1') %>%
  layer_dense_Bayesian(units=7L, activation = tf$nn$sigmoid, name='bayesian_dense_2') %>%
  layer_dense_Bayesian(units=10L, activation = tf$identity, name='bayesian_dense_3') %>%
  layer_lambda(function(inputs) CategoricalObsModel_autodiff(inputs), name='y_pred')
predictionsBayesian$`_type_spec` = tf$TensorSpec(shape=predictionsBayesian$type_spec$shape, dtype=predictionsBayesian$type_spec$dtype)
modelBayesian = keras_model(inputs = inputsBayesian, outputs = predictionsBayesian)

mylossBayesian = function(y, y_hat){
  y_hat$loss(y)
}
modelBayesian %>% compile(loss=mylossBayesian, optimizer='adam', metrics='accuracy')

#invisible(lapply(grep("bayesian_dense_[0-9]+/B:[0-9]+", sapply(modelBayesian$weights, function(n) n$name)), function(n){
#  modelBayesian$weights[[n]]$assign( modelBayesian$weights[[n]]/tf$reduce_max(tf$abs(modelBayesian$weights[[n]]))*.05 )
#}))
In [ ]:
modelBayesian2 = keras_model(inputs = inputsBayesian, outputs = c(list(inputsBayesian), lapply(modelBayesian$layers[-2:-1], function(n) n$output)))
modelBayesian2 %>% compile(loss=mylossBayesian, optimizer='adam', metrics='accuracy')

with(tf$GradientTape() %as% tape, {
  tape$watch(list(x, P0))
  y_predBayesian = modelBayesian2$call(list(x,P0))
  tape$watch(y_predBayesian[1:4])
  lossBayesian = y_predBayesian[[5]]$loss(y)
})
sensitivityBayesian = tape$gradient(lossBayesian, y_predBayesian[1:4])

str(y_predBayesian)
str(sensitivityBayesian)
modelBayesian2
trace CategoricalObsModel forward ....

List of 5
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 784), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 784, 784), dtype=float32, numpy=…>
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…>
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…>
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 10), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 10, 10), dtype=float32, numpy=…>
 $ :<CategoricalObs object at 0x7f37d04a1580>
List of 4
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 784), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 784, 784), dtype=float32, numpy=…>
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…>
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…>
 $ :List of 2
  ..$ :<tf.Tensor: shape=(32, 10), dtype=float32, numpy=…>
  ..$ :<tf.Tensor: shape=(32, 10, 10), dtype=float32, numpy=…>
Model: "model_1"
________________________________________________________________________________
 Layer (type)             Output Shape      Param #  Connected to               
================================================================================
 x0 (InputLayer)          [(None, 784)]     0        []                         
 P0 (InputLayer)          [(None, 784, 784  0        []                         
                          )]                                                    
 bayesian_dense_1 (Bayesi  [(None, 7),      10976    ['x0[0][0]',               
 anDense)                  (None, 7, 7)]              'P0[0][0]']               
 bayesian_dense_2 (Bayesi  [(None, 7),      98       ['bayesian_dense_1[0][0]', 
 anDense)                  (None, 7, 7)]              'bayesian_dense_1[0][1]'] 
 bayesian_dense_3 (Bayesi  [(None, 10),     140      ['bayesian_dense_2[0][0]', 
 anDense)                  (None, 10, 10)]            'bayesian_dense_2[0][1]'] 
 y_pred (Lambda)          (None, 10)        0        ['bayesian_dense_3[0][0]', 
                                                      'bayesian_dense_3[0][1]'] 
================================================================================
Total params: 11,214
Trainable params: 11,214
Non-trainable params: 0
________________________________________________________________________________

In the following, we sample an ensemble of 64 (non-Bayesian) neural networks from the Bayesian neural network.
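Concretely, each sampled dense kernel is drawn as $W_{ij} = A_{ij} + B_{ij}\,\epsilon_i$ with $\epsilon_i \sim \mathcal N(0,1)$ shared across input row $i$, matching the parameterization $A + \operatorname{diag}(\epsilon)B$ used in layer_dense_Bayesian above; the bias of each sampled layer is drawn as small Gaussian noise.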

In [ ]:
layer_dense_linearization = Layer(
  classname = 'DenseLinearization',
  inherit = tf$keras$layers$Layer,
  initialize = function(activation=NULL, ...) {
    super$initialize()
    self$dense = tf$keras$layers$Dense(activation=NULL, ...)
    self$activation = keras$activations$get(activation)
    py_set_attr(self, '_name', self$dense$name)
  },

  build = function(input_shape) {
    self$dense$build(input_shape)
    self$built=TRUE
  },

  call = function(inputs) {
    with(tf$GradientTape(persistent = TRUE) %as% tape, {
      tape$watch(inputs)
      h = self$dense$call(inputs)
      y0 = self$activation(h)
    })
    # first-order linearization: numerically y equals y0, but gradients flow only through
    # the linear term f'(h) * (h - stop_gradient(h)), i.e. with the slope frozen at h
    y = tf$stop_gradient(y0) + tf$stop_gradient(tape$gradient(y0, h)) * (h - tf$stop_gradient(h))
  }
)
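
The linearization trick used in the layer above can be illustrated in isolation (a minimal sketch, not tied to the layer): the forward value is unchanged, while the gradient is taken with the activation's slope frozen at the current pre-activation.

h <- tf$constant(c(0.5, -1.0, 2.0))
with(tf$GradientTape() %as% g, {
  g$watch(h)
  y0 <- tf$sigmoid(h)
  y  <- tf$stop_gradient(y0) + tf$stop_gradient(y0 * (1 - y0)) * (h - tf$stop_gradient(h))
})
y$numpy() - y0$numpy()        # zeros: the forward value is identical to sigmoid(h)
g$gradient(y, h)$numpy()      # equals sigmoid'(h): the gradient flows through the linear term only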
In [ ]:
### sample vanilla NN from Bayesian NN, assuming weights are independent
inputs = layer_input(shape = 784L, name='input')
predictions = lapply(1:64, function(i){
  inputs %>%
    layer_lambda(function(x) tfp$distributions$MultivariateNormalDiag(loc = x, scale_identity_multiplier = 1e-2 )$sample(), name = sprintf('sample_%01d_0',i) ) %>%
    layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE, name = sprintf('dense_%01d_1',i)) %>%
    layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE, name = sprintf('dense_%01d_2',i)) %>%
    layer_dense_linearization(units=10, activation = NULL, use_bias = TRUE, name = sprintf('dense_%01d_3',i)) #%>%
  #layer_activation(activation = 'softmax', name=sprintf('softmax_%01d_3',i))
})
model = keras_model(inputs = inputs, outputs = predictions)

myloss = function(y, y_hat){
  p = tf$math$softmax(y_hat, axis=-1L)
  dydx = tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p)
  y_pred = tf$stop_gradient(p) + tf$einsum('bij,bi->bj', tf$stop_gradient(dydx), y_hat - tf$stop_gradient(y_hat))

  - tf_probability()$distributions$MultivariateNormalDiag(loc=y_pred, scale_diag = tf$ones_like(y)*.14^.5)$log_prob(y)
}
model %>% compile(loss=myloss, optimizer=modelBayesian$optimizer, metrics='accuracy')

invisible(lapply(1:3, function(j){
  lapply(1:length(predictions), function(i){
    A = modelBayesian$get_layer(name = sprintf('bayesian_dense_%d',j))$weights[[1]]
    B = modelBayesian$get_layer(name = sprintf('bayesian_dense_%d',j))$weights[[2]]
    W = tf_probability()$distributions$MultivariateNormalDiag(loc=tf$zeros_like(B[,1]), scale_diag=tf$ones_like(B[,1]))$sample()
    model$get_layer(name = sprintf('dense_%01d_%01d',i,j))$weights[[1]]$assign(A + tf$einsum('ij,i->ij',B, W))
    # do not forget the noise added to the latent layer; here the bias plays that role and is applied after the activation, i.e., at the input side
    bias = tf_probability()$distributions$MultivariateNormalDiag(loc=tf$zeros_like(B[1,]), scale_diag=tf$ones_like(B[1,]))$sample() * 1e-2
    model$get_layer(name = sprintf('dense_%01d_%01d',i,j))$weights[[2]]$assign(bias)
  })
}))
In [ ]:
### find latent state and sensitivities
model2 = keras_model(
  inputs = inputs,
  outputs = c(x = model$get_layer('input')$output, split(lapply(model$layers[-1], function(l) l$output), gsub('(dense|sample)_[0-9]+_([0-9])', '\\2', sapply(model$layers[-1], function(n) n$name)))))

# `tf.keras.Model.output` gives structured model output, while `tf.keras.Model.outputs` gives a list of output tensors
# model2$output

with(tf$GradientTape(persistent = TRUE) %as% t2, {
  with(tf$GradientTape() %as% tape, {
    tape$watch(x)
    ret = model2$call(x)
    t2$watch(ret)
    tape$watch(ret)
    `__loss__` = sapply(1:length(ret[[5]]), function(i) {
      myloss(y, ret[[5]][[i]])
    })
  })
  sensitivity = tape$gradient(`__loss__`, ret)
})

#plot(tf$reduce_mean(`__loss__`, axis=0L), lossBayesian)
hessian = outer(2:length(ret), 1:length(ret[[2]]), Vectorize(function(m,n) t2$batch_jacobian(sensitivity[[m]][[n]], ret[[m]][[n]]) ))
In [ ]:
x.filter.vi = sapply(1L:4L, function(i){ y_predBayesian[[i]][[1]] })
P.filter.vi = sapply(1L:4L, function(i){ y_predBayesian[[i]][[2]] })
x.sensitivity.vi = sapply(1L:4L, function(i){ sensitivityBayesian[[i]][[1]] })
P.sensitivity.vi = sapply(1L:4L, function(i){ sensitivityBayesian[[i]][[2]] })

x.smoothing.vi = sapply(1:4, function(i){
  y_predBayesian[[i]][[1]] +tf$einsum('bim,bm->bi', y_predBayesian[[i]][[2]], sensitivityBayesian[[i]][[1]])
})

P.smoothing.vi = sapply(1:4, function(i){
  y_predBayesian[[i]][[2]] +
    tf$einsum('bij,bjk,blk->bil', y_predBayesian[[i]][[2]], sensitivityBayesian[[i]][[2]], y_predBayesian[[i]][[2]])*2 -
    tf$einsum('bim,bm,bn,bjn->bij', y_predBayesian[[i]][[2]], sensitivityBayesian[[i]][[1]], sensitivityBayesian[[i]][[1]], y_predBayesian[[i]][[2]])
})

# Hessian
H.vi = lapply(1:4, function(i){
  - tf$linalg$pinv(P.filter.vi[[i]] + tf$linalg$pinv(- sensitivityBayesian[[i]][[2]]*2 - tf$einsum('bm,bn->bmn', sensitivityBayesian[[i]][[1]], sensitivityBayesian[[i]][[1]]) ))
})

# Fisher divergence, modulo a data-dependent term
FD.vi = -tf$linalg$diag_part(sensitivityBayesian[[1]][[2]]) - .5 * tf$linalg$diag_part(H.vi[[1]])
In [ ]:
x.filter.mc = sapply(1:4, function(i){
  tf$reduce_mean(tf$stack(ret[[i+1]]), axis=0L)
})

P.filter.mc = sapply(1:4, function(i){
  `__x__` = tf$stack(ret[[i+1]])
  `__2nd_moment__` = tf$einsum('lbi,lbj->bij', `__x__`, `__x__`)/dim(`__x__`)[1]
  `__1st_moment__` = tf$reduce_mean(`__x__`, axis=0L)
  `__2nd_moment__` - tf$einsum('bi,bj->bij',`__1st_moment__`,`__1st_moment__`)
})

posterior = tf$math$divide_no_nan(tf$einsum('lbi,bi->lb', tf$nn$softmax(tf$stack(ret[[5]]), axis=-1L), y), tf$einsum('lbi,bi->b', tf$nn$softmax(tf$stack(ret[[5]]), axis=-1L), y))

## it seems that the posterior estimated in this way has higher variance than unweighted average
x.sensitivity.mc = sapply(1:4, function(i){
  tf$einsum('lbi,lb->bi',tf$stack(sensitivity[[i+1]]), posterior)
})

# Hessian Bayesian
H.mc = lapply(1:4, function(i){
  tf$reduce_mean(tf$stack(hessian[i,]), axis=0L)
})

P.sensitivity.mc = sapply(1:4, function(i){
  A = tf$einsum('bsi,bsj,sb->bij',tf$stack(sensitivity[[i+1]],axis=1L), tf$stack(sensitivity[[i+1]],axis=1L), posterior)
  H = H.mc[[i]]
  (H - A) * .5
})

# Fisher divergence, modulo a data-dependent term
FD.mc = .5 * tf$reduce_mean(tf$stack(sensitivity[[2]], axis=0L)^2, axis=0L) - tf$reduce_mean(tf$stack(lapply(hessian[1,], tf$linalg$diag_part), axis=0L), axis=0L)

The following shows that the smoothing densities of the post-activations (in the Bayesian NN) can be estimated from the sensitivities of the post-activations (in the ensemble of 64 vanilla, non-Bayesian neural networks). Monte Carlo estimation is very noisy, though, especially for the Hessian, even in this tiny toy example.

In [ ]:
require(repr)
repr.plot.opt = options(repr.plot.width=17, repr.plot.height=4.5)
par.opt = par(mfcol=c(1,4))

plot(x.filter.vi[[1]]$numpy(), x.filter.mc[[1]]$numpy(), xlab='VI', ylab='MC', main = 'forward x[1]')
abline(coef=c(0,1),col='red')
plot(x.filter.vi[[2]]$numpy(), x.filter.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'forward x[2]')
abline(coef=c(0,1),col='red')
plot(x.filter.vi[[3]]$numpy(), x.filter.mc[[3]]$numpy(), xlab='VI', ylab='MC', main = 'forward x[3]')
abline(coef=c(0,1),col='red')
plot(x.filter.vi[[4]]$numpy(), x.filter.mc[[4]]$numpy(), xlab='VI', ylab='MC', main = 'forward x[4]')
abline(coef=c(0,1),col='red')

plot(P.filter.vi[[1]][1:8,,]$numpy(), P.filter.mc[[1]][1:8,,]$numpy(), xlab='VI', ylab='MC', main = 'forward P[1]')
points(tf$linalg$diag_part(P.filter.vi[[1]])$numpy(), tf$linalg$diag_part(P.filter.mc[[1]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[2]]$numpy(), P.filter.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'forward P[2]')
points(tf$linalg$diag_part(P.filter.vi[[2]])$numpy(), tf$linalg$diag_part(P.filter.mc[[2]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[3]]$numpy(), P.filter.mc[[3]]$numpy(), xlab='VI', ylab='MC', main = 'forward P[3]')
points(tf$linalg$diag_part(P.filter.vi[[3]])$numpy(), tf$linalg$diag_part(P.filter.mc[[3]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[4]]$numpy(), P.filter.mc[[4]]$numpy(), xlab='VI', ylab='MC', main = 'forward P[4]')
points(tf$linalg$diag_part(P.filter.vi[[4]])$numpy(), tf$linalg$diag_part(P.filter.mc[[4]])$numpy(),col='red')
abline(coef=c(0,1),col='red')

plot(x.sensitivity.vi[[1]]$numpy(), x.sensitivity.mc[[1]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[1]')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[2]]$numpy(), x.sensitivity.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[2]')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[3]]$numpy(), x.sensitivity.mc[[3]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[3]')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[4]]$numpy(), x.sensitivity.mc[[4]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[4]')
abline(coef=c(0,1),col='red')

local({
  ndx = sample(prod(dim(P.sensitivity.mc[[1]])), 4096)
  plot(P.sensitivity.vi[[1]]$numpy()[ndx], P.sensitivity.mc[[1]]$numpy()[ndx], asp=1, xlab='VI', ylab='MC', main = 'backward P[1]')
  points(tf$linalg$diag_part(P.sensitivity.vi[[1]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[1]])$numpy(), col='red')
  abline(coef=c(0,1), col='red')
})
plot(P.sensitivity.vi[[2]]$numpy(), P.sensitivity.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward P[2]')
points(tf$linalg$diag_part(P.sensitivity.vi[[2]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[2]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
plot(P.sensitivity.vi[[3]]$numpy(), P.sensitivity.mc[[3]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward P[3]')
points(tf$linalg$diag_part(P.sensitivity.vi[[3]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[3]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
plot(P.sensitivity.vi[[4]]$numpy(), P.sensitivity.mc[[4]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward P[4]')
points(tf$linalg$diag_part(P.sensitivity.vi[[4]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[4]])$numpy(), col='red')
abline(coef=c(0,1), col='red')

local({
  ndx = sample(prod(dim(H.mc[[1]])), 4096)
  plot(H.vi[[1]]$numpy()[ndx], H.mc[[1]]$numpy()[ndx], asp=1, xlab='VI', ylab='MC', main = 'Hessian [1]')
  points(tf$linalg$diag_part(H.vi[[1]])$numpy(), tf$linalg$diag_part(H.mc[[1]])$numpy(), col='red')
  abline(coef=c(0,1), col='red')
})

plot(H.vi[[2]]$numpy(), H.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'Hessian [2]')
abline(coef=c(0,1),col='red')
plot(H.vi[[3]]$numpy(), H.mc[[3]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'Hessian [3]')
abline(coef=c(0,1),col='red')
plot(H.vi[[4]]$numpy(), H.mc[[4]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'Hessian [4]')
abline(coef=c(0,1),col='red')

par(par.opt)
options(repr.plot.opt)
Loading required package: repr

Image
Image
Image
Image
Image

The following shows that we can compute the particle sensitivities from the (Bayesian) sensitivities with respect to the parameters of the particles' forward (filter) distributions. Computing the particle sensitivities through the inverses of the smoothing and filter densities obtained from the Bayesian sensitivities is not numerically stable; a more numerically stable way is to compute the sensitivity of the loss over particles directly from the sensitivity of the loss over the forward-distribution parameters of the particles, as derived below.

$\require{cancel}$

\begin{align*} \hat{\mathbf{x}}_{l|L}= & \hat{\mathbf{x}}_{l}+P_{l}\left(\nabla_{\hat{\mathbf{x}}_{l}}p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\\ P_{l|L}= & P_{l}+P_{l}\left[2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right]P_{l}\\ P_{l|L}^{-1}= & P_{l}^{-1}-\cancel{P_{l}^{-1}P_{l}}\left[\left(2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right)^{-1}+P_{l}P_{l}^{-1}P_{l}\right]^{-1}\cancel{P_{l}P_{l}^{-1}}\\ = & p_{l}^{-1}-\left[\left(2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right)^{-1}+P_{l}P_{l}^{-1}P_{l}\right]^{-1}\\ \nabla_{\mathbf{x}_{l}}\log p(\mathbf{y}|\mathbf{x}_{l},\mathbf{x}_{0}=\mathbf{x})= & P_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)-P_{l|L}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l|L}\right)\\ = & P_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)-\left\{ p_{l}^{-1}-\left[\left(2\nabla_{P_{l}}-\nabla_{\hat{\mathbf{x}}_{l}}\nabla_{\hat{\mathbf{x}}_{l}}^{\top}\right)^{-1}+P_{l}\right]^{-1}\right\} \left(\mathbf{x}_{l}-P_{l}\nabla_{\hat{\mathbf{x}}_{l}}-\hat{\mathbf{x}}_{l}\right)\\ = & \cancel{P_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)}\cancel{-p_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)}+\bcancel{p_{l}^{-1}}\bcancel{P_{l}}\nabla_{\hat{\mathbf{x}}_{l}}+\left[\left(2\nabla_{P_{l}}-\nabla_{\hat{\mathbf{x}}_{l}}\nabla_{\hat{\mathbf{x}}_{l}}^{\top}\right)^{-1}+P_{l}\right]^{-1}\left(\mathbf{x}_{l}-P_{l}\nabla_{\hat{\mathbf{x}}_{l}}-\hat{\mathbf{x}}_{l}\right)\\ = & \left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)+\left[\left(2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right)^{-1}+P_{l}\right]^{-1}\left(\mathbf{x}_{l}-P_{l}\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)-\hat{\mathbf{x}}_{l}\right) \end{align*} The loss is $J=-\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})$, so if we formulate particle sensitivity using (Bayesian) sensitivity over forward distribution parameters, the corresponding signs should be changed.

\begin{align*} - \nabla_{\mathbf{x}_{l}}J = & -\nabla_{\hat{\mathbf{x}}_{l}}J+\left[\left(-2\nabla_{P_{l}}J-\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)^{\top}\right)^{-1}+P_{l}\right]^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}+P_{l}\nabla_{\hat{\mathbf{x}}_{l}}J\right)\text{ or }\\ \nabla_{\mathbf{x}_{l}}J = & \nabla_{\hat{\mathbf{x}}_{l}}J - \left[P_{l} - \left(2\nabla_{P_{l}}J+\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)^{\top}\right)^{-1}\right]^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}+P_{l}\nabla_{\hat{\mathbf{x}}_{l}}J\right). \end{align*}

The relationship is based on the assumption that the "state transition" of particles across layers is linear Gaussian. In particular, in the linear Gaussian model the activation function is linearized with a first-order Taylor expansion at the mean pre-activation. If this linear Gaussian assumption held exactly for the vanilla NN samples, the relationship between the particle sensitivities and the Bayesian sensitivities would be exact. However, the vanilla NN samples are not drawn exactly from the linear Gaussian distribution; in particular, the gradient of the sensitivity functions with respect to the pre-activations is not constant. That is why the particle sensitivities scatter around the prediction based on the linear Gaussian assumption.

In [ ]:
dldx = sapply(1:4, function(i){
  A = tf$linalg$pinv(P.filter.vi[[i]] + tf$linalg$pinv(- sensitivityBayesian[[i]][[2]]*2 - tf$einsum('bm,bn->bmn', sensitivityBayesian[[i]][[1]], sensitivityBayesian[[i]][[1]]) ))
  sensitivityBayesian[[i]][[1]] - tf$einsum('bij,lbj->lbi', A, tf$stack(ret[[i+1]]) - x.filter.vi[[i]] + tf$einsum('bij,bj->bi',P.filter.vi[[i]], sensitivityBayesian[[i]][[1]]))
})

require(repr)
repr.plot.opt = options(repr.plot.width=17, repr.plot.height=4.5)
par.opt = par(mfcol=c(1,4))

require(MASS)
image(kde2d(dldx[[1]]$numpy(), tf$stack(sensitivity[[2]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[2]])$numpy()), times=2)))
#points(dldx[[1]]$numpy(), tf$stack(sensitivity[[2]])$numpy(), pch='.')
abline(coef=c(0,1))

image(kde2d(dldx[[2]]$numpy(), tf$stack(sensitivity[[3]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[3]])$numpy()), times=2)))
points(dldx[[2]]$numpy(), tf$stack(sensitivity[[3]])$numpy(), pch='.')
abline(coef=c(0,1))

image(kde2d(dldx[[3]]$numpy(), tf$stack(sensitivity[[4]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[4]])$numpy()), times=2)))
points(dldx[[3]]$numpy(), tf$stack(sensitivity[[4]])$numpy(), pch='.')
abline(coef=c(0,1))

image(kde2d(dldx[[4]]$numpy(), tf$stack(sensitivity[[5]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[5]])$numpy()), times=2)))
points(dldx[[4]]$numpy(), tf$stack(sensitivity[[5]])$numpy(), pch='.')
abline(coef=c(0,1))

par(par.opt)
options(repr.plot.opt)
Loading required package: MASS

Image

Fisher divergence with a Bayesian neural network

Since the Hessian of the loss with respect to the input is related to the input filter and smoothing densities in a Gaussian Bayesian neural network, we can efficiently compute the Fisher divergence with one forward propagation and one backward propagation. The following compares the Fisher divergence obtained through Monte Carlo integration over multiple vanilla NNs with the one obtained from a Bayesian NN.

  • formulation of an energy-based model $p_\theta(\mathbf{x})=\exp \left(f_\theta(\mathbf{x}) - \log Z(\theta)\right)$.
  • (Stein) score function $s_\theta(\mathbf{x}):=\nabla_{\mathbf{x}} \log p_\theta(\mathbf{x})=\nabla_{\mathbf{x}} f_\theta(\mathbf{x})-\nabla_{\mathbf{x}} \log Z(\theta)=\nabla_{\mathbf{x}} f_\theta(\mathbf{x})$.
  • Fisher divergence between 2 distributions $p_{\text{data}}(\mathbf{x})$ and $p_\theta(\mathbf{x})$ is \begin{align*} D_F(p_{\text{data}}, p_\theta):= & \frac{1}{2}\mathbf{E}_{\mathbf{x}\sim p_{\text{data }}}\left[\left\Vert \nabla_{\mathbf{x}}\log p_{\text{data }}(\mathbf{x})-\nabla_{\mathbf{x}}\log p_{\theta}(\mathbf{x})\right\Vert _{2}^{2}\right]\\ = & ....\\ = & \mathbf{E}_{\mathbf{x}\sim p_{\text{data }}}\left[\frac{1}{2}\left\Vert \nabla_{\mathbf{x}}\log p_{\boldsymbol{\theta}}\left(\mathbf{x}\right)\right\Vert _{2}^{2}+\Delta_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}\right)\right] + \text{const }\\ \approx & \frac{1}{n}\sum_{i=1}^{n}\left[\frac{1}{2}\left\Vert \nabla_{\mathbf{x}}\log p_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right\Vert _{2}^{2}+\Delta_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right]+\text{const }\\ = & \frac{1}{n}\sum_{i=1}^{n}\left[\frac{1}{2}\left\Vert \nabla_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right\Vert _{2}^{2}+\Delta_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right]+\text{const }. \end{align*}
  • The score matching algorithm minimizes the Fisher divergence between the data distribution $p_{\text{data}}(\mathbf x)$ and the model distribution $p_{\boldsymbol\theta}(\mathbf x)$, here formulated as an energy-based model (a toy one-dimensional sketch follows after these equations). Pro: no sampling; con: the Hessian is hard to estimate.
    • Sample a mini-batch of datapoints $\left\{\mathbf{x}_1, \mathbf{x}_2, \cdots, \mathbf{x}_n\right\} \sim p_{\text {data }}(\mathbf{x})$.
    • Estimate the score matching loss with the empirical mean \begin{aligned} & \frac{1}{n} \sum_{i=1}^n\left[\frac{1}{2}\left\|\nabla_{\mathbf{x}} \log p_\theta\left(\mathbf{x}_i\right)\right\|_2^2+\operatorname{tr}\left(\nabla_{\mathbf{x}}^2 \log p_\theta\left(\mathbf{x}_i\right)\right)\right] \\ = & \frac{1}{n} \sum_{i=1}^n\left[\frac{1}{2}\operatorname{tr}\left[\left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right) \left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right)^\top\right]+\operatorname{tr}\left(\nabla_{\mathbf{x}}^2 f_\theta\left(\mathbf{x}_i\right)\right)\right] \end{aligned}
  • Score matching with Bayesian neural network.

\begin{align*} & \mathbf{E}_{\gamma(\mathbf{x}_{l})}\left(\nabla_{\mathbf{x}_{l}}\log p(\mathbf{y}|\mathbf{x}_{l},\mathbf{x}_{0}=\mathbf{x})\right)\left(\nabla_{\mathbf{x}_{l}}\log p(\mathbf{y}|\mathbf{x}_{l},\mathbf{x}_{0}=\mathbf{x})\right)^{\top}\\ = & P_{l}^{-1}\left[P_{l|L}-P_{l}+\left(\hat{\mathbf{x}}_{l|L}-\hat{\mathbf{x}}_{l}\right)\left(\hat{\mathbf{x}}_{l|L}-\hat{\mathbf{x}}_{l}\right)^{\top}\right]P_{l}^{-1}+\left(P_{l|L}^{-1}-P_{l}^{-1}\right)\\ = & 2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left[\left(2 \nabla_{P_l} \log p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)-\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)^{\top}\right)^{-1} + P_l\right]^{-1}\\ & \nabla_{\mathbf{x}_l \mathbf{x}_l^{\top}} \log p\left(\mathbf{y} \mid \mathbf{x}_l, \mathbf{x}_0=\mathbf{x}\right) = P_l^{-1} - P_{l \mid L}^{-1}\\ = & \left[\left(2 \nabla_{P_l} \log p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)-\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)^{\top}\right)^{-1} + P_l\right]^{-1}\\ \\ & \frac{1}{n} \sum_{i=1}^n\left[\frac{1}{2}\left\|\nabla_{\mathbf{x}} \log p_\theta\left(\mathbf{x}_i\right)\right\|_2^2+\operatorname{tr}\left(\nabla_{\mathbf{x}}^2 \log p_\theta\left(\mathbf{x}_i\right)\right)\right] \\ = & \frac{1}{n} \sum_{i=1}^n \operatorname{tr}\left[\frac{1}{2}\left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right) \left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right)^\top + \nabla_{\mathbf{x}}^2 f_\theta\left(\mathbf{x}_i\right)\right]\\ = & \frac{1}{n} \sum_{i=1}^n \operatorname{tr}\left\lbrace \nabla_{P_{l}}\log p(\mathbf{y}_i|\hat{\mathbf{x}}_{i},P_{i}) + .5 \left[P_i + \left(2 \nabla_{P_i} \log p\left(\mathbf{y} \mid \hat{\mathbf{x}}_i, P_i\right)-\left(\nabla_{\hat{\mathbf{x}}_i} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_i, P_i\right)\right)\left(\nabla_{\hat{\mathbf{x}}_i} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_i, P_i\right)\right)^{\top}\right)^{-1}\right]^{-1} \right\rbrace. \end{align*}
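
A toy one-dimensional sketch of the plain score-matching estimator above (an illustration under its own assumptions, not part of the notebook's model): take a Gaussian energy $f_\sigma(x) = -x^2/(2\sigma^2)$, so the score is $-x/\sigma^2$ and the Laplacian is $-1/\sigma^2$; minimizing the empirical objective recovers the data scale.

x_data  <- rnorm(1000, sd = 2)
sm_loss <- function(sigma) mean(0.5 * (x_data / sigma^2)^2 - 1 / sigma^2)
optimize(sm_loss, interval = c(0.1, 10))$minimum   # close to sd(x_data), i.e. about 2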

In [ ]:
FD.mc = .5 * tf$reduce_mean(tf$stack(sensitivity[[2]], axis=0L)^2, axis=0L) - tf$reduce_mean(tf$stack(lapply(hessian[1,], tf$linalg$diag_part), axis=0L), axis=0L)
FD.vi = -tf$linalg$diag_part(sensitivityBayesian[[1]][[2]]) - .5 * tf$linalg$diag_part(H.vi[[1]])
plot(tf$reduce_sum(FD.vi, axis=-1L), tf$reduce_sum(FD.mc, axis=-1L), xlab='VI', ylab='MC', main = 'Fisher divergence')
#abline(coef=c(0,1),col='red')
[Image: scatter plot of per-example Fisher divergence, VI (x-axis) vs MC (y-axis)]

In the following, we show that linearizing a densely connected layer does not affect learning, since learning is gradient-based and already assumes that the loss surface is locally linear.
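To isolate the trick used in myloss below: replacing a function value by its first-order expansion around a stop_gradient copy of the input leaves the forward value unchanged while freezing the Jacobian in the backward pass. A small standalone sketch (my own, not part of the pipeline) with a single softmax:

# First-order linearization via stop_gradient, in isolation:
#   f_lin(x) = f(x0) + J(x0) (x - x0),  with x0 = stop_gradient(x)
x = tf$constant(c(0.5, -1.0, 2.0))
p = tf$math$softmax(x)
J = tf$linalg$diag(p) - tf$einsum('i,j->ij', p, p)   # softmax Jacobian at x
p_lin = tf$stop_gradient(p) +
  tf$linalg$matvec(tf$stop_gradient(J), x - tf$stop_gradient(x))
max(abs(as.array(p - p_lin)))   # 0: forward values are identical; gradients use the frozen J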

In [ ]:
mnist_train_ds = tf$data$Dataset$from_tensor_slices(mnist$train)$shuffle(2048L)$batch(32L)$map(function(xy){
  c(x,y) %<-% xy
  x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
  y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
  tuple(x,y.onehot)
})

mnist_test_ds = tf$data$Dataset$from_tensor_slices(mnist$test)$batch(1024L)$map(function(xy){
  c(x,y) %<-% xy
  x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
  y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
  tuple(x,y.onehot)
})

inputs = layer_input(shape = 784L, name='input')
predictions = inputs %>%
    layer_lambda(function(x) tfp$distributions$MultivariateNormalDiag(loc = x, scale_identity_multiplier = 1e-2 )$sample() ) %>%
    layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE) %>%
    layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE) %>%
    layer_dense_linearization(units=10, activation = NULL, use_bias = TRUE)
model <- keras_model(inputs = inputs, outputs = predictions)

myloss = function(y, y_hat){
  # linearize the softmax around the current logits: y_pred equals softmax(y_hat) in the
  # forward pass, but its Jacobian is frozen (stop_gradient) in the backward pass
  p = tf$math$softmax(y_hat, axis=-1L)
  dydx = tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p)  # per-example softmax Jacobian
  y_pred = tf$stop_gradient(p) + tf$einsum('bij,bi->bj', tf$stop_gradient(dydx), y_hat - tf$stop_gradient(y_hat))

  # Gaussian negative log-likelihood of the one-hot target around the linearized prediction
  #- tf_probability()$distributions$MultivariateNormalDiag(loc=y_pred, scale_diag = tf$ones_like(y)*.14^.5)$log_prob(y)
  - tf_probability()$distributions$MultivariateNormalDiag(loc=y_pred, scale_diag = tf$ones_like(y)*.01^.5)$log_prob(y)
}
model %>% compile(loss=myloss, optimizer='adam', metrics='accuracy')


history = model %>% fit(mnist_train_ds, epochs=20, validation_data=mnist_test_ds)
model %>% evaluate(mnist_test_ds)
model %>% evaluate(mnist_train_ds)
evaluate(mnist_test_ds):  loss = -6.58690929412842, accuracy = 0.904500007629395
evaluate(mnist_train_ds): loss = -7.37333059310913, accuracy = 0.917349994182587
In [ ]:
rm(list=ls())

Bayesian layers¶

In this section, we provide Bayesian layers with a mean-field posterior according to Algorithm 1, trained with natural gradient. We hope to show two things:

  • The Bayesian layers generally wrap/decorate the corresponding non-Bayesian layers, so we believe that adding a stochastic model to a non-Bayesian computational graph should involve minimal effort.
  • Using these layers, we can bring a very deep neural network to convergence with competitive performance even without any normalization layers. This hopefully points to the potential of Bayesian modeling to handle the vanishing/exploding gradient problem that has long plagued deep learning.

We implement the natural gradient by decorating TensorFlow's autodiff, as in the following code snippet.

grad_fn = function(dh, dP, variables){
        # pull the gradients w.r.t. the layer inputs (x, P0) and the trainable variables
        # back through the tape g2 recorded during the forward pass
        c(dx, dP0, dv) %<-% g2$gradient(list(h, P, divergence), list(x, P0, variables), list(dh, dP, tf$constant(1.0, dtype=self$compute_dtype)))
        # precondition the variable gradients: the location (A) and scale (B) gradients are
        # rescaled by functions of the posterior scale softplus(kernel_scale), i.e. an
        # (approximate) inverse-Fisher scaling for the mean-field Gaussian posterior;
        # the bias gradient is passed through unchanged
        dv2 = lapply(1:length(variables), function(n){
          switch(gsub('[_a-zA-Z0-9]+/(A|B|bias):[0-9]+','\\1',variables[[n]]$name),
                 A = dv[[n]] * tf$math$softplus(self$kernel_scale)^2 ,
                 B = dv[[n]] * tf$math$softplus(self$kernel_scale)^4 / tf$math$sigmoid(self$kernel_scale)^2 * 2,
                 bias = dv[[n]] )
        })
        return(list(list(dx, dP0), dv2))
      }
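For readers unfamiliar with the mechanics, the following standalone sketch (my own; the factor 2.0 is only a stand-in for the Fisher-based rescaling above) shows how tf$custom_gradient lets the backward pass rescale a variable's gradient before the optimizer sees it:

scale_var = tf$Variable(1.0)
scaled_identity = tf$custom_gradient(function(x){
  y = x * scale_var
  # because the forward pass uses a tf$Variable, grad_fn must accept `variables`
  # and return list(gradient w.r.t. inputs, list of gradients w.r.t. variables)
  grad_fn = function(dy, variables){
    list(dy * scale_var, list(dy * x * 2.0))   # 2.0 stands in for a Fisher-based factor
  }
  list(y, grad_fn)
})

with(tf$GradientTape() %as% g, {
  x = tf$constant(3.0)
  g$watch(x)
  y = scaled_identity(x)
})
g$gradient(y, list(x, scale_var))   # list(1.0, 6.0): the variable gradient was rescaled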
In [ ]:
my_divergence_fn = function(q, p, ignore=NULL){
  tfp$distributions$kl_divergence(q, p)  *0
}

## Earlier variant of the Bayesian convolution layer, without tf$custom_gradient; kept commented out for reference.
# layer_conv_2d_Bayesian = Layer(
#   classname = 'BayesianConv2D',
#   inherit = tf$keras$layers$Layer,
#   initialize = function(activation = NULL, divergence_fn = my_divergence_fn, use_bias = TRUE, ...) {
#     super$initialize()
#     self$activation = keras$activations$get(activation)
#     self$divergence_fn = divergence_fn
#     self$use_bias = use_bias
#     self$conv2d = tf$keras$layers$Conv2D(activation = NULL, use_bias=FALSE, ...)
#   },

#   build = function(input_shape) {
#     self$conv2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
#     self$kernel_loc = self$add_weight(name='A',shape=self$conv2d$kernel$shape, initializer=self$conv2d$kernel_initializer, trainable=TRUE)
#     self$kernel_scale = self$add_weight(name='B',shape=self$conv2d$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)
#     #'uniform', default for mean_field_posterior is -3
#     if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$conv2d$filters), initializer = 'zeros') else self$bias = NULL
#     self$built=TRUE
#   },

#   call = function(inputs) {
#     c(x, P0) %<-% inputs
#     with(tf$GradientTape(persistent=TRUE) %as% g, {
#       g$watch(list(x))
#       self$conv2d$kernel = self$kernel_loc
#       a = self$conv2d$call(x)
#       if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
#       h = self$activation(a)
#       `__h__` = self$activation(a)
#     })
#     # propagate variance (accept a variance, first order approximation of state transition, output a variance)
#     #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
#     self$conv2d$kernel = tf$stop_gradient(self$kernel_loc^2)
#     dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$conv2d$call(P0)
#     #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
#     self$conv2d$kernel = tf$nn$softplus(self$kernel_scale+1e-9)^2
#     dhdw.sq.sum_w = g$gradient(h, a)^2 * self$conv2d$call(tf$stop_gradient(x^2))
#     P =  dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9

#     # reinterpreted_batch_ndims is the number of event dimensions, so every dimensions before that `:-reinterpreted_batch_ndims` is event dimension
#     q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=self$kernel_loc, scale = tf$nn$softplus(self$kernel_scale+1e-9)), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
#     p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc), scale = 1.0/prod(dim(self$kernel_loc)[1:3])^.5), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
#     self$add_loss(self$divergence_fn(q_, p_))

#     list(h, P)
#   }
# )

layer_conv_2d_Bayesian = Layer(
  classname = 'BayesianConv2D',
  inherit = tf$keras$layers$Layer,
  initialize = function(activation = NULL, divergence_fn = my_divergence_fn, use_bias = TRUE, ...) {
    super$initialize()
    self$activation = keras$activations$get(activation)
    self$divergence_fn = divergence_fn
    self$use_bias = use_bias
    self$conv2d = tf$keras$layers$Conv2D(activation = NULL, use_bias=FALSE, ...)
  },

  build = function(input_shape) { #tf.keras.initializers.VarianceScaling
    self$conv2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    self$kernel_loc = self$add_weight(name='A',shape=self$conv2d$kernel$shape, initializer=self$conv2d$kernel_initializer, trainable=TRUE)
    #self$kernel_loc = self$add_weight(name='A',shape=self$conv2d$kernel$shape, initializer=initializer_variance_scaling(), trainable=TRUE)
    self$kernel_scale = self$add_weight(name='B',shape=self$conv2d$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)
    #'uniform', default for mean_field_posterior is -3
    if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$conv2d$filters), initializer = 'zeros') else self$bias = NULL

    self$`__forward__` = tf$custom_gradient(function(inputs){
      c(x, P0) %<-% inputs
      with(tf$GradientTape(persistent=TRUE) %as% g2, {
        g2$watch(list(x, P0))
        with(tf$GradientTape(persistent=TRUE) %as% g, {
          g$watch(list(x, P0))
          self$conv2d$kernel = self$kernel_loc
          a = self$conv2d$call(x)
          if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
          h = self$activation(a)
          `__h__` = self$activation(a)
        })
        # propagate variance (accept a variance, first order approximation of state transition, output a variance)
        #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
        self$conv2d$kernel = tf$stop_gradient(self$kernel_loc^2)
        dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$conv2d$call(P0)
        #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
        self$conv2d$kernel = tf$math$softplus(self$kernel_scale)^2
        dhdw.sq.sum_w = tf$stop_gradient(g$gradient(h, a)^2) * self$conv2d$call(tf$stop_gradient(x^2))
        P =  dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-6

        q_ = tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale))
        p_ = tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/prod(dim(self$kernel_loc)[1:3])^.5)
        divergence = tf$reduce_sum(self$divergence_fn(q_, p_))
        #q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale)), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
        #p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/prod(dim(self$kernel_loc)[1:3])^.5), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
        #divergence = self$divergence_fn(q_, p_)
        self$add_loss(tf$stop_gradient(divergence))
      })

      grad_fn = function(dh, dP, variables){
        c(dx, dP0, dv) %<-% g2$gradient(list(h, P, divergence), list(x, P0, variables), list(dh, dP, tf$constant(1.0, dtype=self$compute_dtype)))
        dv2 = lapply(1:length(variables), function(n){
          switch(gsub('[_a-zA-Z0-9]+/(A|B|bias):[0-9]+','\\1',variables[[n]]$name),
                 A = dv[[n]] * tf$math$softplus(self$kernel_scale)^2 ,
                 B = dv[[n]] * tf$math$softplus(self$kernel_scale)^4 / tf$math$sigmoid(self$kernel_scale)^2 * 2,
                 bias = dv[[n]] )
        })
        return(list(list(dx, dP0), dv2))
      }
      list(list(h, P), grad_fn)
    })

    self$built=TRUE
  },

  call = function(inputs) {
    return(self$`__forward__`(inputs))
  }
)

layer_average_pooling_2d_Bayesian = Layer(
  classname = 'BayesianAvgPool2D',
  inherit = tf$keras$layers$Layer,
  initialize = function(pool_size = c(2L, 2L), strides = NULL, ...) {
    super$initialize()
    if(is.null(strides)) strides = pool_size
    self$depthwise_conv_2d = tf$keras$layers$DepthwiseConv2D(kernel_size = pool_size, strides = strides, use_bias = FALSE, activation = NULL, ...)
  },

  build = function(input_shape) {
    self$depthwise_conv_2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    self$depthwise_conv_2d$depthwise_kernel = self$depthwise_conv_2d$depthwise_kernel * 0 + 1
    self$built=TRUE
  },

  call = function(inputs) {
    c(x, P0) %<-% inputs
    h = a = self$depthwise_conv_2d$call(x) / self$depthwise_conv_2d$call( tf$ones_like(x))
    # propagate variance (accept a variance, first order approximation of state transition, output a variance)
    #dhdx.sq.sum_x = tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
    dhdx.sq.sum_x = self$depthwise_conv_2d$call(P0) / self$depthwise_conv_2d$call(tf$ones_like(P0))^2
    P =  dhdx.sq.sum_x + 1e-9
    list(h, P)
  }
)

# Simpler variant built on AveragePooling2D; this definition supersedes the DepthwiseConv2D-based one above.
layer_average_pooling_2d_Bayesian = Layer(
  classname = 'BayesianAvgPool2D',
  inherit = tf$keras$layers$Layer,
  initialize = function(...) {
    super$initialize()
    self$avg_pooling_2d = tf$keras$layers$AveragePooling2D(...)
  },

  build = function(input_shape) {
    self$avg_pooling_2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    self$built=TRUE
  },

  call = function(inputs) {
    c(x, P0) %<-% inputs
    h = a = self$avg_pooling_2d$call(x)
    # propagate variance (accept a variance, first order approximation of state transition, output a variance)
    #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
    dhdx.sq.sum_x = self$avg_pooling_2d$call(P0) / tf$cast(tf$reduce_prod(self$avg_pooling_2d$pool_size), P0$dtype)

    P =  dhdx.sq.sum_x + 1e-9
    list(h, P)
  }
)

layer_smooth_max_pooling_2d_Bayesian = Layer(
  classname = 'BayesianSmoothMaxPool2D',
  inherit = tf$keras$layers$Layer,
  initialize = function(pool_size = c(2L, 2L), strides = NULL, alpha = 50.0, ...) {
    super$initialize()
    self$alpha = alpha
    if(is.null(strides)) strides = pool_size
    self$depthwise_conv_2d = tf$keras$layers$DepthwiseConv2D(kernel_size = pool_size, strides = strides, use_bias = FALSE, activation = NULL, ...)
    self$conv_2d_transpose = tf$keras$layers$Conv2DTranspose(filters = NULL, kernel_size = pool_size, strides = strides, use_bias = FALSE, activation = NULL)
  },

  build = function(input_shape) {
    self$depthwise_conv_2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    self$depthwise_conv_2d$depthwise_kernel = self$depthwise_conv_2d$depthwise_kernel * 0 + 1 # replace with a tensor

    channel_axis = self$conv_2d_transpose$`_get_channel_axis`()
    if(channel_axis < 0) channel_axis = length(input_shape[[1]]) + 1 + channel_axis # python index => R index
    self$conv_2d_transpose$filters = input_shape[[1]][[channel_axis]]
    self$conv_2d_transpose$build(input_shape[[1]])
    self$conv_2d_transpose$kernel$assign(self$conv_2d_transpose$kernel * 0)  # replace with a tensor
    for(i in 1L:self$conv_2d_transpose$filters)
      self$conv_2d_transpose$kernel[,,i,i]$assign(self$conv_2d_transpose$kernel[,,i,i] + 1.0)
    #self$conv_2d_transpose$kernel[,,i,i] = self$conv_2d_transpose$kernel[,,i,i] + tf$constant(1.0, dtype=tf$float32)
    self$conv_2d_transpose$kernel = self$conv_2d_transpose$kernel + 0
    self$built=TRUE
  },

  call = function(inputs) {
    c(x, P0) %<-% inputs
    my_max = tf$nn$max_pool2d(x, self$depthwise_conv_2d$kernel_size, self$depthwise_conv_2d$strides, 'VALID')
    my_max_in = self$conv_2d_transpose$call(my_max)
    b = tf$math$exp(self$alpha * (x - tf$stop_gradient(my_max_in)))
    a = self$depthwise_conv_2d$call(b)
    h = tf$math$log(a)/self$alpha + tf$stop_gradient(my_max)

    # propagate variance
    #dhdx.sq.sum_x = tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
    #self$depthwise_conv_2d$kernel are all 1's
    dhdx.sq.sum_x = tf$math$divide_no_nan(self$depthwise_conv_2d$call(b^2 * P0), tf$stop_gradient(a^2))
    #dhdw.sq.sum_w = g$gradient(h, a)^2 * self$conv2d$call(tf$stop_gradient(x^2)) #???
    P =  dhdx.sq.sum_x + 1e-9
    list(h, P)
  }
)
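As a quick sanity check of the smooth max above (my own snippet, plain R): the scaled log-sum-exp approaches the hard maximum as alpha grows, which is why alpha = 50 behaves like max pooling while staying differentiable.

x = c(0.3, 0.9, 0.5, 0.7)
smooth_max = function(alpha) log(sum(exp(alpha * (x - max(x))))) / alpha + max(x)
sapply(c(5, 50, 500), smooth_max)   # approx. 0.988, 0.900, 0.900: converges to max(x) = 0.9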

## Earlier variant of the Bayesian dense layer, without tf$custom_gradient; kept commented out for reference (includes leftover debugging prints).
# layer_dense_Bayesian = Layer(
#   classname = 'BayesianDense',
#   inherit = tf$keras$layers$Layer,
#   initialize = function(activation=NULL, use_bias = TRUE, divergence_fn = my_divergence_fn, ...) {
#     super$initialize()
#     self$divergence_fn = divergence_fn
#     self$activation = keras$activations$get(activation)
#     self$use_bias = use_bias
#     self$dense = tf$keras$layers$Dense(activation=NULL, use_bias=FALSE, dtype = self$compute_dtype, ...)
#   },

#   build = function(input_shape) {
#     self$dense$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
#     # redefine BNN weights as Gaussian random variables in terms of trainable tf variables
#     self$kernel_loc = self$add_weight(name='A',shape=self$dense$kernel$shape, initializer=self$dense$kernel_initializer, trainable=TRUE)
#     self$kernel_scale = self$add_weight(name='B',shape=self$dense$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)#'uniform'
#     if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$dense$units), initializer='zeros') else self$bias = NULL
#     self$built=TRUE
#   },

#   call = function(inputs) {
#     c(x, P0) %<-% inputs
#     with(tf$GradientTape(persistent=TRUE) %as% g, {
#       g$watch(list(x)) #g$watch(list(x, self$W))
#       self$dense$kernel = self$kernel_loc #+ tf$einsum('ij,i->ij',self$B, self$W) # mvnorm with mean and variance
#       a = self$dense$call(x) # tf$matmul(a=x, b=self$kernel) + self$bias
#       if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
#       h = self$activation(a)
#       `__h__` = self$activation(a)
#     })
#     # propagate variance (accept a variance, first order approximation of state transition, output a variance)
#     #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
#     self$dense$kernel = tf$stop_gradient(self$kernel_loc^2)
#       print('here')
#     dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$dense$call(P0)
#     #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
#       print('here')
#     self$dense$kernel = tf$nn$softplus(self$kernel_scale)^2
#       print('here')
#     dhdw.sq.sum_w = g$gradient(h, a)^2 * self$dense$call(tf$stop_gradient(x^2))
#     P =  dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9
#       print('here')

#     ## add regularization loss
#     #q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=self$kernel_loc, scale = tf$nn$softplus(self$kernel_scale+1e-9)), reinterpreted_batch_ndims = 2L)
#     #p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc), scale = 1.0/dim(x)[2]^.5), reinterpreted_batch_ndims = 2L)
#     #self$add_loss(self$divergence_fn(q_, p_))

#     list(h, P)
#   }
# )

layer_dense_Bayesian = Layer(
  classname = 'BayesianDense',
  inherit = tf$keras$layers$Layer,
  initialize = function(activation=NULL, use_bias = TRUE, divergence_fn = my_divergence_fn, ...) {
    super$initialize()
    self$divergence_fn = divergence_fn
    self$activation = keras$activations$get(activation)
    self$use_bias = use_bias
    self$dense = tf$keras$layers$Dense(activation=NULL, use_bias=FALSE, ...)
  },

  build = function(input_shape) {
    self$dense$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    # redefine BNN weights as Gaussian random variables in terms of trainable tf variables
    self$kernel_loc = self$add_weight(name='A',shape=self$dense$kernel$shape, initializer=self$dense$kernel_initializer, trainable=TRUE)
    self$kernel_scale = self$add_weight(name='B',shape=self$dense$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)#'uniform'
    if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$dense$units), initializer='zeros') else self$bias = NULL

    self$`__forward__` = tf$custom_gradient(function(inputs){
      #message('tracing layer_dense_BayesianEP forward ....')
      c(x, P0) %<-% inputs

      with(tf$GradientTape(persistent=TRUE) %as% g2, {
        g2$watch(list(x, P0))
        with(tf$GradientTape(persistent=TRUE) %as% g, {
          g$watch(list(x, P0))
          self$dense$kernel = self$kernel_loc
          a = self$dense$call(x)
          if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
          h = self$activation(a)
          `__h__` = self$activation(a)
        })
        # propagate variance (accept a variance, first order approximation of state transition, output a variance)
        #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
        self$dense$kernel = tf$stop_gradient(self$kernel_loc^2)
        dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$dense$call(P0)
        #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
        self$dense$kernel = tf$math$softplus(self$kernel_scale)^2
        dhdw.sq.sum_w = tf$stop_gradient(g$gradient(h, a)^2) * self$dense$call(tf$stop_gradient(x^2))
        P =  dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-6

        q_ = tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale))
        p_ = tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/dim(x)[2]^.5)
        divergence = tf$reduce_sum(self$divergence_fn(q_, p_))
        #q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale)), reinterpreted_batch_ndims = 2L)
        #p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/dim(x)[2]^.5), reinterpreted_batch_ndims = 2L)
        #divergence = self$divergence_fn(q_, p_)
        self$add_loss(tf$stop_gradient(divergence))
      })

      grad_fn = function(dh, dP, variables){
        c(dx, dP0, dv) %<-% g2$gradient(list(h, P, divergence), list(x, P0, variables), list(dh, dP, tf$constant(1.0, dtype=self$compute_dtype)))
        dv2 = lapply(1:length(variables), function(n){
         switch(gsub('[_a-zA-Z0-9]+/(A|B|bias):[0-9]+','\\1',variables[[n]]$name),
                A = dv[[n]] * tf$math$softplus(self$kernel_scale)^2 ,
                B = dv[[n]] * tf$math$softplus(self$kernel_scale)^4 / tf$math$sigmoid(self$kernel_scale)^2  * 2.0,
                bias = dv[[n]] )
        })
        return(list(list(dx, dP0), dv2))
      }
      list(list(h, P), grad_fn)
    })
    self$built=TRUE
  },

  call = function(inputs) {
    self$`__forward__`(inputs)
  }
)

layer_activation_Bayesian = Layer(
  classname = 'BayesianActivation',
  inherit = tf$keras$layers$Layer,
  initialize = function(activation){
    super$initialize()
    self$activation = keras$activations$get(activation)
    #self$activation = tf$keras$layers$tf$keras$layers$Activation(...)
  },

  build = function(input_shape) {
    self$built=TRUE
  },

  call = function(inputs, training) {
    c(x, P0) %<-% inputs

    with(tf$GradientTape(persistent=TRUE) %as% g, {
      g$watch(list(x))
      a = x
      h = self$activation(a)
      `__h__` = self$activation(a)
    })
    # propagate variance (accept a variance, first order approximation of state transition, output a variance)
    dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * P0
    #dhdw.sq.sum_w = g$gradient(h, a)^2 * tf$stop_gradient(x^2)
    dhdw.sq.sum_w = 0 #g$gradient(h, a)^2 * x^2 # need to back propagate to the weight scales from last layer
    P =  dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9
    list(h, P)
  }
)
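The activation layer above relies on the same first-order (delta-method) variance propagation as the other layers: for h = f(a) with a ~ N(mu, P), Var[h] is approximated by f'(mu)^2 P. A standalone numeric check against Monte Carlo (my own sketch, not part of the layers):

mu = tf$constant(c(-1.0, 0.0, 2.0))
P  = tf$constant(c(0.05, 0.02, 0.10))
with(tf$GradientTape() %as% g, {
  g$watch(mu)
  h = tf$sigmoid(mu)
})
P_linear = g$gradient(h, mu)^2 * P                       # delta-method variance
a_mc = tf$random$normal(shape = c(100000L, 3L)) * tf$sqrt(P) + mu
P_mc = tf$math$reduce_variance(tf$sigmoid(a_mc), axis = 0L)
rbind(linear = as.array(P_linear), mc = as.array(P_mc))  # the two rows agree closely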

While there is no need for a "Bayesian" normalization layer to achieve a competitive convergence rate and generalization, we can certainly write one, as follows.

layer_batch_normalization_Bayesian = Layer(
  classname = 'BayesianBatchNorm',
  inherit = tf$keras$layers$Layer,
  initialize = function(...) {
    super$initialize()
    self$batch_normalization = tf$keras$layers$BatchNormalization(...)
  },

  build = function(input_shape) {
    self$batch_normalization$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    self$built=TRUE
  },

  call = function(inputs, training) {
    c(x, P0) %<-% inputs
    h = a = self$batch_normalization$call(x, training)
    # propagate variance: if is training, scale with batch variance (+P0), else scale with moving average variance, assume last dimension is BN dimension
    #dhdx.sq.sum_x = tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
    if(training) var = (tf$reduce_mean(x^2, axis=0L:(length(dim(x))-2)) - tf$reduce_mean(x, axis=0L:(length(dim(x))-2))^2 + 1e-3)
    else var = (self$batch_normalization$moving_variance + 1e-3)
    dhdx.sq.sum_x = P0 * tf$stop_gradient(tf$math$divide_no_nan(self$batch_normalization$gamma^2, var))
    P =  dhdx.sq.sum_x + 1e-9
    list(h, P)
  }
)
In [ ]:
rm(list=ls())

NN architectures¶

In [ ]:
cifar10 = dataset_cifar10()

cifar10.datagen = image_data_generator(
  featurewise_center=TRUE,
  featurewise_std_normalization=TRUE,
  ##rotation_range=15,
  zoom_range=0.1,
  width_shift_range=0.1,
  height_shift_range=0.1,
  horizontal_flip=TRUE,
)
cifar10.datagen$fit(cifar10$train$x)
In [ ]:
cifar100 = dataset_cifar100()

cifar100.datagen = image_data_generator(
  featurewise_center=TRUE,  # set input mean to 0 over the dataset
  featurewise_std_normalization=TRUE,  # divide inputs by std of the dataset
  #rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)
  zoom_range=0.1, # Range for random zoom. If a float, [lower, upper] = [1-zoom_range, 1+zoom_range].
  width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
  height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
  horizontal_flip=TRUE,  # randomly flip images
)
cifar100.datagen$fit(cifar100$train$x)

DenseNet-BC [L=100, k=12]¶

On all datasets except ImageNet, the DenseNet used in our experiments has three dense blocks that each have an equal number of layers. Before entering the first dense block, a convolution with 16 (or twice the growth rate for DenseNet-BC [note: a convolution with $2k = 2\times 12 = 24$ output channels here, for DenseNet-Bottleneck+Compression]) output channels is performed on the input images. For convolutional layers with kernel size $3\times3$, each side of the inputs is zero-padded by one pixel to keep the feature-map size fixed. We use a $1\times1$ convolution followed by $2\times2$ average pooling as transition layers between two contiguous dense blocks. At the end of the last dense block, a global average pooling is performed and then a softmax classifier is attached. The feature-map sizes in the three dense blocks are $32\times 32$, $16\times 16$, and $8\times 8$, respectively. We experiment with the basic DenseNet structure with configurations {L = 40, k = 12}, {L = 100, k = 12} [note: growth rate = 12. Among the 100 convolutional layers, 4 sit between the input, the 3 dense blocks, and the output, and the remaining $100-4=96$ layers are distributed over the 3 dense blocks. Since each dense block is composed of convolution blocks with one $1\times1$ convolution layer and one $3\times 3$ convolution layer, each dense block has $96/3/2=16$ convolution blocks] and {L = 100, k = 24}. For DenseNet-BC, the networks with configurations {L = 100, k = 12}, {L = 250, k = 24} and {L = 190, k = 40} are evaluated.

In our experiments on ImageNet, we use a DenseNet-BC structure with 4 dense blocks on $224\times 224$ input images. The initial convolution layer comprises 2k convolutions of size $7\times 7$ with stride 2; the number of feature-maps in all other layers also follows from setting k. The exact network configurations we used on ImageNet are shown in Table 1.

All the networks are trained using stochastic gradient descent (SGD). On CIFAR and SVHN we train using batch size 64 for 300 and 40 epochs, respectively. The initial learning rate is set to 0.1 and is divided by 10 at 50% and 75% of the total number of training epochs. On ImageNet, we train models for 90 epochs with a batch size of 256. The learning rate is set to 0.1 initially and is lowered by a factor of 10 at epochs 30 and 60.

Following [8], we use a weight decay of $10^{-4}$ and a Nesterov momentum [35] of 0.9 without dampening. We adopt the weight initialization introduced by [10]. ... we add a dropout layer [33] after each convolutional layer (except the first one) and set the dropout rate to 0.2. The test errors were only evaluated once for each task and model setting.

Densenet architecture for ImageNet can also be confirmed from tf.keras.applications.DenseNet121. The code is here: https://github.com/keras-team/keras/blob/v2.12.0/keras/applications/densenet.py#L63

DenseNet121 = tf$keras$applications$DenseNet121(input_tensor = inputs %>%  layer_resizing(128L, 128L), include_top=TRUE, weights=NULL, classes=10L)

Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4700-4708).

In [ ]:
## NN architectures
# output width and height are the same, output channels = input channels + growth rate
conv_bottleneck = function(x, growth_rate){
  conv = x %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1
    layer_activation_relu() %>%
    layer_conv_2d(filters = 4 * growth_rate, kernel_size = c(1L,1L), use_bias = FALSE) %>% # conv1
    layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
    layer_activation_relu() %>%
    layer_conv_2d(filters = growth_rate, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE)  # conv2
  list(conv, x) %>% layer_concatenate(axis=-1L) # channel dimension 4 for channel last and 2 for channel first, keras default is channel last
}

# output width and height are half of the input, output channels = reduction * (input channels + growth rate * number of blocks)
dense_layer = function(x, block, growth_rate, num_blocks){
  layers = x
  for(i in 1:num_blocks) layers = layers %>% block(., growth_rate)
  print(sprintf('dense_layer: growth_rate = %d, num_blocks = %d', growth_rate, num_blocks))
  print(layers)
  layers
}

transition = function(x, reduction = .5){
  x %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
    layer_activation_relu() %>%
    layer_conv_2d(filters = floor(reduction*dim(x)[4]), kernel_size = c(1L,1L), use_bias = FALSE) %>%
    layer_average_pooling_2d(pool_size = 2L)
}
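Before building the CIFAR models, a quick channel-count check for DenseNet-BC {L = 100, k = 12}, as implemented by the helpers above (my own bookkeeping, consistent with the tensor shapes printed when the CIFAR-100 model is constructed below):

# stem:    2k = 24 channels at 32x32
# block 1: 24  + 16*12 = 216 at 32x32 -> transition: floor(216/2) = 108 at 16x16
# block 2: 108 + 16*12 = 300 at 16x16 -> transition: floor(300/2) = 150 at 8x8
# block 3: 150 + 16*12 = 342 at 8x8   -> transition: floor(342/2) = 171 at 4x4
# then BN, ReLU, 4x4 average pooling, flatten, dense softmax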

CIFAR-10¶

In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

g.r. = 12L
predictions = inputs %>%
  layer_conv_2d(filters = 2*g.r., kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 1
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>%  # layer 2
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>%  # layer 3
  layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
  layer_activation_relu() %>%
  layer_average_pooling_2d(pool_size = 4L) %>%
  layer_flatten() %>%
  layer_dense(activation = 'softmax', units = 10L)

densenet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
densenet.cifar10$compile(loss='categorical_crossentropy',optimizer= 'adam',metrics=list('accuracy'))
In [ ]:
tf$keras$utils$plot_model(densenet.cifar10, to_file='DenseNetBC_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'DenseNetBC_CIFAR10.png')
[Image: DenseNet-BC CIFAR-10 model graph (DenseNetBC_CIFAR10.png)]
PiecewiseConstantDecay, 900 epochs, best accuracy = 94.83%¶
In [ ]:
lr = tf$keras$optimizers$schedules$PiecewiseConstantDecay(
  boundaries = floor(nrow(cifar10$train$x)/64*3) * c(150, 225), values = c(0.1, 0.01, 0.001))
opt = tf$keras$optimizers$experimental$SGD(learning_rate = lr, momentum = .9, nesterov=TRUE, weight_decay=1e-4)

densenet.cifar10$compile(loss='categorical_crossentropy',optimizer= opt,metrics=list('accuracy'))

mc = callback_model_checkpoint('vanilla_DenseNet_CIFAR10-{epoch:03d}_{val_loss:.2f}_{val_accuracy:.4f}.ckpt', monitor = 'val_accuracy', save_best_only = FALSE, mode = 'auto', save_weights_only = FALSE, save_freq = 'epoch')

py_itertools = import('itertools')
history = densenet.cifar10 %>% fit(
  py_itertools$chain(cifar10.datagen$flow(cifar10$train$x, to_categorical(cifar10$train$y, num_classes=10L), batch_size = 64L,  shuffle=TRUE)),
  epochs = 300,
  steps_per_epoch = floor(nrow(cifar10$train$x)/64*3),
  validation_data = cifar10.datagen$flow(cifar10$test$x, to_categorical(cifar10$test$y, num_classes=10L), batch_size = 1000L),
  callbacks = list(mc, callback_terminate_on_naan(), callback_csv_logger(filename = 'vanilla_DenseNet_CIFAR10.txt')))
In [ ]:
saveRDS(history, file='history.RDS')
print(data.frame(history$metrics))
plot(history)
epoch        loss  accuracy  val_loss val_accuracy
1   1.0606576204 0.6190676 0.7112364       0.7490
2   0.5555183291 0.8077606 0.5330710       0.8142
3   0.4249432683 0.8514031 0.4879132       0.8343
4   0.3554627299 0.8766621 0.4141560       0.8583
5   0.3093782067 0.8921219 0.3998110       0.8680
6   0.2732833624 0.9048115 0.3694597       0.8735
7   0.2472873926 0.9132957 0.3417197       0.8851
8   0.2257734388 0.9217665 0.3297557       0.8902
9   0.2033435851 0.9290759 0.3143096       0.8966
10  0.1888653487 0.9334148 0.3545958       0.8887
11  0.1780259758 0.9376669 0.3140873       0.8992
12  0.1672797501 0.9402568 0.3510486       0.8932
13  0.1580464393 0.9439082 0.3051358       0.9025
14  0.1496922374 0.9472258 0.3340426       0.8959
15  0.1391135305 0.9510974 0.3314134       0.8961
16  0.1332269609 0.9528530 0.3541025       0.8939
17  0.1284293830 0.9545685 0.2997865       0.9058
18  0.1225150228 0.9569783 0.3118749       0.9023
19  0.1199598089 0.9578861 0.3187947       0.9031
20  0.1113056317 0.9602625 0.3066177       0.9067
21  0.1080974266 0.9614640 0.3223408       0.9013
22  0.1062026396 0.9624453 0.3428392       0.8985
23  0.1007550508 0.9648617 0.3262803       0.9004
24  0.0995790511 0.9649819 0.3304870       0.9079
25  0.0982940122 0.9653556 0.3201887       0.9074
26  0.0932092071 0.9670311 0.3043353       0.9100
27  0.0926348567 0.9670244 0.3192368       0.9093
28  0.0876843706 0.9692606 0.3416500       0.9072
29  0.0887350067 0.9681259 0.3270891       0.9078
30  0.0857352316 0.9697213 0.3130928       0.9112
31  0.0838128403 0.9703287 0.3545492       0.9023
32  0.0831116959 0.9709495 0.3275500       0.9082
33  0.0814140663 0.9713567 0.3276057       0.9093
34  0.0796294361 0.9718306 0.3247251       0.9110
35  0.0807739720 0.9714568 0.3234583       0.9057
36  0.0767577216 0.9729187 0.3139792       0.9117
37  0.0788810626 0.9722311 0.3314813       0.9056
38  0.0763527676 0.9732457 0.3112955       0.9153
39  0.0712465569 0.9749079 0.3344874       0.9094
40  0.0740116462 0.9743472 0.3217798       0.9112
41  0.0710906610 0.9751949 0.3079209       0.9119
42  0.0723003969 0.9744673 0.3389498       0.9061
43  0.0692465156 0.9755086 0.3561204       0.9073
44  0.0726492554 0.9745741 0.3510042       0.9046
45  0.0719058067 0.9747410 0.3259953       0.9110
46  0.0686678439 0.9759826 0.3109468       0.9171
47  0.0664239749 0.9768637 0.3433867       0.9107
48  0.0686304122 0.9761761 0.3223711       0.9145
49  0.0669132695 0.9765500 0.3606417       0.9015
50  0.0686195195 0.9760694 0.3030332       0.9153
51  0.0645372346 0.9771507 0.3438092       0.9100
52  0.0669341534 0.9765099 0.3082600       0.9133
53  0.0657915920 0.9769038 0.3295853       0.9173
54  0.0636986792 0.9774178 0.3267203       0.9146
55  0.0643099099 0.9770239 0.3500418       0.9075
56  0.0650324449 0.9771841 0.3297406       0.9143
57  0.0633029789 0.9776581 0.3390211       0.9084
58  0.0617425144 0.9779384 0.3268529       0.9110
59  0.0605978072 0.9788930 0.3014122       0.9175
60  0.0631497800 0.9781454 0.3264746       0.9137
61  0.0600434393 0.9789664 0.3173001       0.9152
62  0.0627141073 0.9777381 0.3339829       0.9115
63  0.0591538474 0.9791200 0.3134220       0.9138
64  0.0617894158 0.9784591 0.3290031       0.9134
65  0.0598159730 0.9789130 0.3570973       0.9098
66  0.0582546704 0.9792134 0.3111114       0.9158
67  0.0583410747 0.9796206 0.3444687       0.9133
68  0.0592083447 0.9788529 0.3139260       0.9190
69  0.0597190559 0.9786994 0.3309248       0.9066
70  0.0590096340 0.9790198 0.3088430       0.9138
71  0.0586256087 0.9787795 0.3090871       0.9161
72  0.0589672364 0.9793936 0.3169852       0.9131
73  0.0589291938 0.9795872 0.3083241       0.9168
74  0.0580609776 0.9795138 0.3298902       0.9128
75  0.0563900061 0.9804483 0.3594465       0.9098
76  0.0553537831 0.9810090 0.3260561       0.9108
77  0.0573993884 0.9797274 0.3440355       0.9063
78  0.0573506542 0.9799610 0.3495877       0.9102
79  0.0571017750 0.9803014 0.3246139       0.9155
80  0.0532786362 0.9817767 0.3332381       0.9126
81  0.0563676469 0.9803548 0.3305483       0.9149
82  0.0565234423 0.9803148 0.3168681       0.9181
83  0.0537937097 0.9813895 0.3186008       0.9184
84  0.0549138300 0.9807287 0.3212758       0.9139
85  0.0543248877 0.9809356 0.3000783       0.9196
86  0.0546503663 0.9809222 0.3095301       0.9201
87  0.0537732653 0.9814963 0.3222055       0.9185
88  0.0537350178 0.9813161 0.3283524       0.9146
89  0.0554774553 0.9809623 0.3122342       0.9175
90  0.0536041111 0.9811025 0.3306598       0.9130
91  0.0515636429 0.9818234 0.3411500       0.9130
92  0.0540375970 0.9808888 0.3678151       0.9074
93  0.0560667403 0.9800478 0.3255465       0.9190
94  0.0506795608 0.9824575 0.3091723       0.9186
95  0.0523284823 0.9814963 0.3185725       0.9156
96  0.0532100461 0.9817032 0.3255993       0.9164
97  0.0549728312 0.9810624 0.3057711       0.9211
98  0.0507088639 0.9825376 0.3040094       0.9218
99  0.0525476150 0.9816298 0.3115408       0.9164
100 0.0527624451 0.9817299 0.3068883       0.9213
101 0.0512051545 0.9819636 0.3352647       0.9149
102 0.0559204891 0.9807354 0.2855245       0.9230
103 0.0523553006 0.9821905 0.3204745       0.9150
104 0.0523235574 0.9819369 0.3082117       0.9144
105 0.0506386310 0.9824508 0.3159414       0.9202
106 0.0512908772 0.9821104 0.3225025       0.9170
107 0.0500498302 0.9826578 0.3163248       0.9189
108 0.0506265685 0.9823975 0.3259247       0.9155
109 0.0488624610 0.9829115 0.3409326       0.9168
110 0.0496078096 0.9827312 0.3256379       0.9154
111 0.0514903776 0.9821038 0.2905671       0.9240
112 0.0511920117 0.9822306 0.3217449       0.9166
113 0.0516947322 0.9821371 0.2977969       0.9248
114 0.0493852347 0.9831251 0.3106343       0.9193
115 0.0510071889 0.9822840 0.2865785       0.9221
116 0.0478954501 0.9830583 0.3111472       0.9196
117 0.0513347872 0.9820237 0.3196720       0.9183
118 0.0507817641 0.9822506 0.3100148       0.9223
119 0.0458829440 0.9842932 0.3168711       0.9175
120 0.0501914546 0.9826845 0.2909254       0.9223
121 0.0481719598 0.9832986 0.3518380       0.9077
122 0.0498914644 0.9825310 0.3043787       0.9172
123 0.0477419458 0.9834121 0.3157377       0.9185
124 0.0490820520 0.9828314 0.3377511       0.9147
125 0.0475868434 0.9834121 0.3097544       0.9197
126 0.0468654074 0.9836591 0.3282976       0.9149
127 0.0484195203 0.9831651 0.3279636       0.9112
128 0.0494253896 0.9830984 0.3284343       0.9181
129 0.0495714210 0.9828714 0.3118525       0.9222
130 0.0454745181 0.9842732 0.3151410       0.9216
131 0.0508676767 0.9825376 0.3171544       0.9178
132 0.0490486734 0.9831251 0.3347439       0.9142
133 0.0478783101 0.9834455 0.2989904       0.9217
134 0.0480618253 0.9832519 0.3214881       0.9210
135 0.0496036075 0.9829782 0.3240486       0.9179
136 0.0484068766 0.9834521 0.3222163       0.9155
137 0.0500256419 0.9828247 0.3285780       0.9164
138 0.0463866331 0.9845936 0.3160755       0.9167
139 0.0486600287 0.9830717 0.3094704       0.9216
140 0.0469037145 0.9835523 0.3097007       0.9196
141 0.0452570617 0.9843199 0.3071093       0.9208
142 0.0470497832 0.9839928 0.3157251       0.9212
143 0.0474198461 0.9835256 0.3316968       0.9133
144 0.0462016203 0.9837726 0.3216465       0.9186
145 0.0471414886 0.9837258 0.2952453       0.9255
146 0.0463583656 0.9841464 0.3129661       0.9184
147 0.0453013144 0.9841664 0.3125087       0.9189
148 0.0466537140 0.9841797 0.3025919       0.9186
149 0.0462883748 0.9837525 0.3085271       0.9207
150 0.0465994962 0.9838393 0.3129956       0.9184
151 0.0179355461 0.9942593 0.2434872       0.9358
152 0.0099607548 0.9971030 0.2347471       0.9396
153 0.0083332034 0.9977238 0.2319643       0.9409
154 0.0069651329 0.9981376 0.2407704       0.9403
155 0.0062687914 0.9983579 0.2419545       0.9404
156 0.0053172470 0.9986049 0.2472901       0.9415
157 0.0050161425 0.9986783 0.2494557       0.9381
158 0.0048122229 0.9986984 0.2474525       0.9404
159 0.0040299026 0.9990788 0.2543326       0.9379
160 0.0039685452 0.9990521 0.2421567       0.9432
161 0.0039107334 0.9990655 0.2487621       0.9418
162 0.0039215549 0.9989653 0.2472721       0.9404
163 0.0034787229 0.9991322 0.2542889       0.9414
164 0.0032531412 0.9991589 0.2506921       0.9433
165 0.0032355059 0.9992324 0.2688347       0.9403
166 0.0031834957 0.9991723 0.2480445       0.9431
167 0.0029187598 0.9992991 0.2372066       0.9431
168 0.0029702932 0.9993325 0.2661545       0.9404
169 0.0029564695 0.9992123 0.2500715       0.9411
170 0.0028517786 0.9993325 0.2564101       0.9418
171 0.0027707925 0.9993859 0.2541938       0.9425
172 0.0025904675 0.9993125 0.2590739       0.9435
173 0.0027950117 0.9993392 0.2597617       0.9420
174 0.0024462701 0.9994593 0.2701470       0.9409
175 0.0028818778 0.9992123 0.2515804       0.9409
176 0.0025572372 0.9993926 0.2587079       0.9443
177 0.0024516101 0.9994059 0.2657771       0.9409
178 0.0023870892 0.9994259 0.2603669       0.9413
179 0.0022910915 0.9994593 0.2654268       0.9415
180 0.0021746829 0.9995528 0.2575617       0.9433
181 0.0020988819 0.9995328 0.2583516       0.9424
182 0.0020656686 0.9995795 0.2605426       0.9409
183 0.0022016745 0.9994727 0.2626130       0.9426
184 0.0018577089 0.9995862 0.2597829       0.9427
185 0.0019622957 0.9995194 0.2550812       0.9454
186 0.0021711886 0.9995061 0.2598144       0.9425
187 0.0018792929 0.9995394 0.2525495       0.9456
188 0.0019036195 0.9995661 0.2515860       0.9439
189 0.0018809845 0.9995461 0.2624856       0.9402
190 0.0017438964 0.9996395 0.2562388       0.9444
191 0.0017094957 0.9996262 0.2676136       0.9437
192 0.0017950477 0.9996195 0.2710856       0.9431
193 0.0019127132 0.9994794 0.2703643       0.9445
194 0.0018835242 0.9995328 0.2650857       0.9395
195 0.0019588454 0.9994927 0.2652496       0.9424
196 0.0017334609 0.9995595 0.2621994       0.9431
197 0.0017132707 0.9995995 0.2660967       0.9426
198 0.0018062600 0.9995528 0.2693620       0.9455
199 0.0017370149 0.9995595 0.2682706       0.9420
200 0.0016068980 0.9996128 0.2722749       0.9427
201 0.0016427378 0.9995795 0.2598840       0.9410
202 0.0016552433 0.9995795 0.2665069       0.9426
203 0.0015986745 0.9996062 0.2735996       0.9438
204 0.0015606694 0.9996796 0.2536149       0.9439
205 0.0015719079 0.9996395 0.2672495       0.9429
206 0.0016013680 0.9996462 0.2578813       0.9427
207 0.0014818421 0.9996328 0.2665404       0.9408
208 0.0015539966 0.9996796 0.2606453       0.9418
209 0.0017199826 0.9995928 0.2720141       0.9436
210 0.0016483091 0.9996462 0.2589284       0.9423
211 0.0016058885 0.9996128 0.2605408       0.9423
212 0.0015012980 0.9996395 0.2762633       0.9425
213 0.0015125026 0.9996328 0.2594804       0.9445
214 0.0014314755 0.9996862 0.2473979       0.9450
215 0.0014756023 0.9997463 0.2599065       0.9441
216 0.0016070793 0.9995862 0.2627328       0.9437
217 0.0013638893 0.9997063 0.2660143       0.9419
218 0.0014398288 0.9997129 0.2545859       0.9441
219 0.0015194830 0.9996395 0.2630380       0.9410
220 0.0015468349 0.9996595 0.2656153       0.9451
221 0.0013944136 0.9996529 0.2523253       0.9437
222 0.0013571376 0.9996929 0.2552068       0.9452
223 0.0014020227 0.9996929 0.2580487       0.9435
224 0.0013498960 0.9997463 0.2624156       0.9427
225 0.0016269906 0.9995661 0.2587347       0.9438
226 0.0013636855 0.9996796 0.2672617       0.9450
227 0.0013746357 0.9996996 0.2526856       0.9470
228 0.0012725247 0.9997063 0.2670771       0.9455
229 0.0013421818 0.9996595 0.2594174       0.9433
230 0.0013267826 0.9996996 0.2623502       0.9416
231 0.0013328877 0.9997196 0.2666925       0.9427
232 0.0012907290 0.9997463 0.2553433       0.9447
233 0.0012470031 0.9997396 0.2454526       0.9480
234 0.0012071745 0.9997396 0.2559259       0.9436
235 0.0012451368 0.9996996 0.2539373       0.9447
236 0.0011727633 0.9997663 0.2663018       0.9452
237 0.0012135674 0.9997730 0.2606533       0.9426
238 0.0012924110 0.9997263 0.2532763       0.9457
239 0.0012165137 0.9997597 0.2509798       0.9453
240 0.0011595050 0.9997931 0.2571831       0.9435
241 0.0012686322 0.9997129 0.2585961       0.9420
242 0.0012253302 0.9997931 0.2700651       0.9405
243 0.0011738786 0.9996929 0.2546351       0.9446
244 0.0012180576 0.9997196 0.2649637       0.9431
245 0.0012297052 0.9996996 0.2534186       0.9427
246 0.0010129946 0.9998064 0.2585182       0.9448
247 0.0010431785 0.9997931 0.2635866       0.9427
248 0.0011360900 0.9997663 0.2624811       0.9443
249 0.0011431404 0.9997530 0.2592306       0.9427
250 0.0010917124 0.9997730 0.2638656       0.9420
251 0.0011282272 0.9997931 0.2685623       0.9436
252 0.0012394985 0.9997196 0.2590341       0.9466
253 0.0010789371 0.9997997 0.2541146       0.9452
254 0.0011112305 0.9997530 0.2587816       0.9448
255 0.0011482992 0.9997931 0.2540508       0.9453
256 0.0010297050 0.9997864 0.2542107       0.9453
257 0.0010894474 0.9998131 0.2502379       0.9464
258 0.0012215456 0.9996996 0.2591000       0.9431
259 0.0010120156 0.9998131 0.2605835       0.9433
260 0.0010174772 0.9998331 0.2492044       0.9461
261 0.0011001317 0.9998065 0.2490786       0.9462
262 0.0010827924 0.9998064 0.2564962       0.9444
263 0.0011595695 0.9997597 0.2577664       0.9429
264 0.0011081379 0.9998198 0.2430533       0.9461
265 0.0012434288 0.9996796 0.2599364       0.9418
266 0.0010980372 0.9997997 0.2565093       0.9456
267 0.0010212519 0.9997997 0.2591139       0.9443
268 0.0010855130 0.9997864 0.2760078       0.9420
269 0.0011163345 0.9997797 0.2608415       0.9423
270 0.0009602397 0.9997663 0.2555375       0.9441
271 0.0011822636 0.9997396 0.2572889       0.9452
272 0.0010497085 0.9997730 0.2497728       0.9483
273 0.0009901815 0.9998398 0.2649685       0.9436
274 0.0010544929 0.9997797 0.2532931       0.9447
275 0.0011266209 0.9997396 0.2632809       0.9436
276 0.0010210563 0.9998131 0.2508913       0.9437
277 0.0011818097 0.9997530 0.2542528       0.9436
278 0.0011202770 0.9997530 0.2587285       0.9421
279 0.0009996522 0.9997730 0.2505912       0.9449
280 0.0010250015 0.9998331 0.2539868       0.9469
281 0.0012771676 0.9996328 0.2584452       0.9432
282 0.0010434379 0.9997663 0.2590295       0.9438
283 0.0010412202 0.9997663 0.2557743       0.9444
284 0.0010772038 0.9998064 0.2634786       0.9432
285 0.0009531339 0.9997931 0.2565470       0.9437
286 0.0009395551 0.9998465 0.2603623       0.9419
287 0.0011374332 0.9997530 0.2583754       0.9430
288 0.0011194443 0.9997730 0.2653035       0.9413
289 0.0009858023 0.9998264 0.2620707       0.9449
290 0.0012037378 0.9997463 0.2633227       0.9414
291 0.0010569250 0.9998198 0.2656918       0.9436
292 0.0009989782 0.9997797 0.2484338       0.9446
293 0.0010349407 0.9997864 0.2612786       0.9420
294 0.0012185022 0.9997063 0.2617649       0.9430
295 0.0009881820 0.9998531 0.2632387       0.9452
296 0.0011108380 0.9997931 0.2669857       0.9438
297 0.0011600659 0.9997663 0.2645464       0.9420
298 0.0009890068 0.9997931 0.2572861       0.9447
299 0.0010017925 0.9997997 0.2663090       0.9431
300 0.0010720694 0.9997463 0.2676879       0.9416
[Image: training/validation loss and accuracy curves from plot(history)]

CIFAR-100¶

In [ ]:
inputs = layer_input(shape = dim(cifar100$train$x)[-1])

g.r. = 12L
predictions = inputs %>%
  layer_conv_2d(filters = 2*g.r., kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 1
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>%  # layer 2
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>%  # layer 3
  layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
  layer_activation_relu() %>%
  layer_average_pooling_2d(pool_size = 4L) %>%
  layer_flatten() %>%
  layer_dense(activation = 'softmax', units = 100L)

densenet.cifar100 = keras_model(inputs = inputs, outputs = predictions)
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='concatenate_15/concat:0', description="created by layer 'concatenate_15'")
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='concatenate_31/concat:0', description="created by layer 'concatenate_31'")
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='concatenate_47/concat:0', description="created by layer 'concatenate_47'")
PiecewiseConstantDecay, 900 epochs, best accuracy = 75.61%¶
In [ ]:
lr = tf$keras$optimizers$schedules$PiecewiseConstantDecay(
  boundaries = nrow(cifar100$train$x) / 64 * c(150, 225) * 3, values = c(0.1, 0.01, 0.001))
opt = tf$keras$optimizers$experimental$SGD(learning_rate = lr, momentum = .9, nesterov=TRUE, weight_decay=1e-4)

densenet.cifar100$compile(loss='categorical_crossentropy',optimizer= opt,metrics=list('accuracy'))

mc = callback_model_checkpoint('vanilla_DenseNet_CIFAR100-{epoch:03d}_{val_loss:.2f}_{val_accuracy:.4f}.ckpt', monitor = 'val_accuracy', save_best_only = FALSE, mode = 'auto', save_weights_only = TRUE, save_freq = 'epoch')

history = densenet.cifar100 %>% fit(
  cifar100.datagen$flow(cifar100$train$x, to_categorical(cifar100$train$y, num_classes=100L), batch_size = 64L,  shuffle=TRUE),
  epochs = 300 * 3,
  validation_data = cifar100.datagen$flow(cifar100$test$x, to_categorical(cifar100$test$y, num_classes=100L), batch_size = 1000L),
  callbacks = list(mc, callback_terminate_on_naan(), callback_csv_logger(filename = 'vanilla_DenseNet_CIFAR100.txt')))
In [ ]:
saveRDS(history, file='history.RDS')
print(data.frame(history$metrics))
plot(history)
epoch      loss accuracy val_loss val_accuracy
1   3.853438854  0.10366 3.595496       0.1520
2   3.029376268  0.24250 2.706999       0.2965
3   2.406250715  0.36314 2.238456       0.3950
4   2.041969538  0.44308 2.022791       0.4536
5   1.804827213  0.49968 1.995402       0.4626
6   1.627967000  0.54150 1.696434       0.5296
7   1.504715323  0.57346 1.596601       0.5555
8   1.398904324  0.59946 1.539676       0.5723
9   1.312076092  0.62070 1.454368       0.5912
10  1.241668940  0.63882 1.450769       0.5950
11  1.187236190  0.65216 1.403780       0.6011
12  1.121083736  0.66956 1.372716       0.6166
13  1.075478315  0.68436 1.381446       0.6170
14  1.027493477  0.69464 1.370398       0.6239
15  0.992836535  0.70366 1.318132       0.6346
16  0.960465789  0.71146 1.352921       0.6263
17  0.929455519  0.72070 1.330776       0.6296
18  0.887544990  0.73242 1.327415       0.6342
19  0.859275520  0.73898 1.317398       0.6452
20  0.832880676  0.74668 1.291956       0.6491
21  0.813728988  0.74904 1.279784       0.6491
22  0.784946859  0.75992 1.270259       0.6560
23  0.759003758  0.76634 1.327843       0.6396
24  0.740605950  0.77122 1.311594       0.6503
25  0.719220102  0.77670 1.300073       0.6566
26  0.703240573  0.78366 1.263511       0.6599
27  0.681695402  0.78788 1.326917       0.6568
28  0.671582043  0.79046 1.311492       0.6535
29  0.655797660  0.79498 1.302583       0.6557
30  0.638647318  0.79950 1.278390       0.6635
31  0.622537255  0.80260 1.276709       0.6617
32  0.603244066  0.80974 1.301411       0.6696
33  0.585440874  0.81524 1.317993       0.6663
34  0.569371819  0.81776 1.310695       0.6614
35  0.565074861  0.81964 1.309143       0.6687
36  0.544965208  0.82728 1.300821       0.6699
37  0.538428247  0.82800 1.363141       0.6611
38  0.525849342  0.83112 1.362929       0.6646
39  0.513292432  0.83350 1.427855       0.6473
40  0.514054894  0.83336 1.303089       0.6781
41  0.495509148  0.84118 1.368043       0.6728
42  0.492607355  0.84066 1.345099       0.6750
43  0.482471704  0.84248 1.329873       0.6764
44  0.461801440  0.84874 1.370691       0.6646
45  0.461309999  0.85152 1.355694       0.6702
46  0.453512102  0.85210 1.316995       0.6801
47  0.444712967  0.85450 1.371333       0.6672
48  0.437105894  0.85720 1.399071       0.6682
49  0.435878217  0.85702 1.372445       0.6745
50  0.433400244  0.86004 1.370352       0.6759
51  0.418163389  0.86434 1.347827       0.6846
52  0.409113050  0.86630 1.413429       0.6696
53  0.398824692  0.87068 1.383677       0.6736
54  0.404064953  0.86696 1.384419       0.6721
55  0.388572395  0.87130 1.379344       0.6750
56  0.378892452  0.87394 1.400450       0.6806
57  0.381080717  0.87472 1.444072       0.6745
58  0.368275046  0.87922 1.439468       0.6765
59  0.370565772  0.87692 1.391768       0.6778
60  0.368072003  0.87986 1.381385       0.6829
61  0.354157895  0.88380 1.441241       0.6717
62  0.370540589  0.87834 1.468642       0.6741
63  0.350557864  0.88708 1.443068       0.6729
64  0.349365741  0.88296 1.397306       0.6798
65  0.347443432  0.88392 1.423342       0.6827
66  0.335909754  0.89014 1.431867       0.6835
67  0.333221674  0.89018 1.402137       0.6838
68  0.337600589  0.88768 1.445183       0.6755
69  0.325612396  0.89220 1.441339       0.6776
70  0.322922081  0.89226 1.434373       0.6778
71  0.309412003  0.89700 1.463794       0.6741
72  0.310336828  0.89748 1.454021       0.6803
73  0.311839283  0.89774 1.466136       0.6867
74  0.295288116  0.90196 1.461647       0.6788
75  0.307723254  0.89814 1.445068       0.6802
76  0.298129112  0.90130 1.439320       0.6830
77  0.299325556  0.89994 1.481609       0.6778
78  0.286619693  0.90466 1.488152       0.6731
79  0.291644156  0.90222 1.440696       0.6866
80  0.290452898  0.90256 1.511275       0.6800
81  0.278444290  0.90764 1.521133       0.6798
82  0.283030033  0.90574 1.486773       0.6824
83  0.277228087  0.90898 1.436233       0.6879
84  0.284176886  0.90476 1.506604       0.6832
85  0.273584932  0.90962 1.472148       0.6843
86  0.273817390  0.90904 1.479098       0.6840
87  0.276113123  0.90852 1.510844       0.6794
88  0.270873100  0.91066 1.479823       0.6882
89  0.265368849  0.91170 1.573855       0.6753
90  0.268944442  0.90872 1.531904       0.6741
91  0.272095144  0.90840 1.520514       0.6858
92  0.261273175  0.91266 1.508646       0.6852
93  0.261625320  0.91286 1.483993       0.6825
94  0.255627960  0.91422 1.487268       0.6847
95  0.252025664  0.91458 1.540516       0.6813
96  0.253052950  0.91530 1.496504       0.6842
97  0.247655407  0.91686 1.505142       0.6844
98  0.243123919  0.91808 1.525993       0.6862
99  0.245150745  0.91878 1.565540       0.6875
100 0.252356112  0.91550 1.506533       0.6849
101 0.243642926  0.91832 1.485817       0.6919
102 0.240009710  0.92040 1.563167       0.6807
103 0.249794781  0.91674 1.463683       0.6956
104 0.239165962  0.91850 1.533352       0.6851
105 0.238270521  0.92112 1.497235       0.6876
106 0.230006620  0.92312 1.540964       0.6885
107 0.236917302  0.92100 1.589840       0.6782
108 0.234731048  0.92226 1.500657       0.6916
109 0.234176680  0.92178 1.523285       0.6881
110 0.221878007  0.92660 1.504256       0.6855
111 0.229166731  0.92318 1.528715       0.6882
112 0.230130494  0.92286 1.531066       0.6859
113 0.225189477  0.92336 1.554032       0.6840
114 0.227609709  0.92446 1.518346       0.6904
115 0.224551991  0.92466 1.530482       0.6905
116 0.226433456  0.92340 1.533297       0.6849
117 0.222123355  0.92596 1.544646       0.6864
118 0.218439654  0.92652 1.570457       0.6912
119 0.217200741  0.92624 1.563900       0.6925
120 0.211740687  0.92922 1.635813       0.6758
121 0.222971067  0.92610 1.492358       0.6970
122 0.225700825  0.92490 1.519286       0.6965
123 0.207204282  0.93174 1.584369       0.6891
124 0.214719176  0.92776 1.581977       0.6869
125 0.211682513  0.92874 1.540503       0.6862
126 0.220860347  0.92572 1.551330       0.6923
127 0.206330746  0.93236 1.546438       0.6875
128 0.212723926  0.92988 1.560123       0.6866
129 0.204282403  0.93184 1.542713       0.6881
130 0.216955975  0.92758 1.540074       0.6851
131 0.212414458  0.92972 1.536068       0.6943
132 0.204815716  0.93142 1.630420       0.6792
133 0.216589123  0.92848 1.527262       0.6950
134 0.201428577  0.93354 1.526483       0.6879
135 0.203964844  0.93304 1.587180       0.6892
136 0.206053287  0.93204 1.585572       0.6870
137 0.196083024  0.93474 1.541427       0.6888
138 0.208957046  0.92942 1.536776       0.6901
139 0.196804479  0.93438 1.547233       0.6900
140 0.210610121  0.93096 1.614300       0.6811
141 0.196455941  0.93412 1.557264       0.6887
142 0.199247748  0.93400 1.557708       0.6933
143 0.198018119  0.93450 1.653597       0.6843
144 0.214937598  0.92836 1.580083       0.6847
145 0.192141831  0.93472 1.576070       0.6890
146 0.187078848  0.93702 1.547123       0.6909
147 0.198283032  0.93414 1.580042       0.6848
148 0.201220945  0.93268 1.540498       0.6917
149 0.193608880  0.93552 1.556743       0.6901
150 0.200757384  0.93296 1.542952       0.6927
151 0.196128398  0.93386 1.581223       0.6874
152 0.181882381  0.93954 1.638187       0.6843
153 0.197095975  0.93556 1.549024       0.6953
154 0.193902224  0.93452 1.598127       0.6855
155 0.200428516  0.93266 1.570941       0.6849
156 0.198103458  0.93220 1.533105       0.6952
157 0.186783805  0.93764 1.598982       0.6920
158 0.184159100  0.93850 1.564770       0.6840
159 0.183524430  0.93958 1.641119       0.6850
160 0.196081296  0.93478 1.586900       0.6884
161 0.190530345  0.93574 1.542786       0.6929
162 0.182412386  0.93898 1.549326       0.6956
163 0.174402520  0.94140 1.624365       0.6878
164 0.186636686  0.93786 1.563943       0.7005
165 0.187461212  0.93800 1.572465       0.6941
166 0.179597780  0.94010 1.514007       0.6906
167 0.193388849  0.93614 1.577327       0.6916
168 0.186349198  0.93704 1.590873       0.6865
169 0.190994591  0.93686 1.576589       0.6971
170 0.186066061  0.93866 1.544907       0.6966
171 0.181229964  0.94046 1.590444       0.6885
172 0.188194081  0.93654 1.578899       0.6972
173 0.174478278  0.94192 1.588653       0.6908
174 0.188555017  0.93720 1.643187       0.6799
175 0.182899877  0.93932 1.624101       0.6867
176 0.182647854  0.93832 1.582345       0.6905
177 0.178410992  0.94128 1.593744       0.6818
178 0.179302543  0.94044 1.627043       0.6882
179 0.187712759  0.93692 1.614830       0.6843
180 0.178961322  0.94038 1.549768       0.6979
181 0.176136956  0.94136 1.621494       0.6854
182 0.170544937  0.94378 1.559111       0.6981
183 0.176616132  0.94202 1.652917       0.6890
184 0.180924907  0.93990 1.560009       0.6966
185 0.175623924  0.94090 1.562559       0.6968
186 0.176832005  0.94082 1.616949       0.6858
187 0.175889254  0.94228 1.558364       0.6925
188 0.171840131  0.94210 1.604219       0.6956
189 0.179239318  0.94250 1.674956       0.6826
190 0.187601402  0.93660 1.663619       0.6814
191 0.176178679  0.94002 1.555668       0.6957
192 0.172151938  0.94220 1.603836       0.6876
193 0.170420021  0.94332 1.583708       0.6868
194 0.177237272  0.94082 1.631758       0.6867
195 0.177409858  0.94110 1.549829       0.7010
196 0.168837130  0.94432 1.561375       0.6931
197 0.174638569  0.94186 1.635010       0.6855
198 0.175447032  0.94142 1.624403       0.6898
199 0.173982367  0.94232 1.584686       0.6910
200 0.169081569  0.94310 1.625709       0.6959
201 0.171336114  0.94238 1.567278       0.6968
202 0.167572185  0.94490 1.709782       0.6830
203 0.179914773  0.93986 1.657342       0.6878
204 0.176740378  0.94206 1.571865       0.6891
205 0.171712026  0.94266 1.576882       0.6987
206 0.173517361  0.94174 1.551179       0.6974
207 0.168849468  0.94506 1.609468       0.6973
208 0.175623998  0.94068 1.623109       0.6915
209 0.168745100  0.94408 1.581399       0.6916
210 0.172160462  0.94276 1.571459       0.6888
211 0.171243861  0.94204 1.630494       0.6943
212 0.169804722  0.94292 1.632143       0.6860
213 0.164358437  0.94636 1.598293       0.6939
214 0.168748274  0.94380 1.577195       0.6895
215 0.168900877  0.94362 1.568724       0.6916
216 0.174118951  0.94240 1.610893       0.6873
217 0.158761218  0.94704 1.596302       0.6941
218 0.156621978  0.94784 1.609557       0.6911
219 0.168153256  0.94380 1.618056       0.6933
220 0.173418626  0.94240 1.612581       0.6874
221 0.165395319  0.94466 1.590066       0.6978
222 0.164805934  0.94572 1.600944       0.6980
223 0.160869703  0.94696 1.621505       0.6912
224 0.160189614  0.94674 1.636526       0.6874
225 0.170317620  0.94280 1.582918       0.6946
226 0.172883928  0.94234 1.635332       0.6840
227 0.165740788  0.94484 1.551298       0.6940
228 0.170838267  0.94292 1.625545       0.6873
229 0.156748801  0.94824 1.588758       0.6976
230 0.169500157  0.94354 1.730745       0.6899
231 0.166624740  0.94470 1.580095       0.6948
232 0.156552985  0.94780 1.634737       0.6894
233 0.155373946  0.94882 1.594615       0.6902
234 0.160767585  0.94574 1.604833       0.6961
235 0.162376836  0.94630 1.618500       0.6950
236 0.163342729  0.94514 1.598217       0.6923
237 0.164869830  0.94504 1.589932       0.6901
238 0.167083070  0.94388 1.623780       0.6922
239 0.166366309  0.94386 1.686258       0.6806
240 0.163939670  0.94596 1.672970       0.6881
241 0.163656726  0.94590 1.584843       0.6938
242 0.160308957  0.94566 1.559196       0.6938
243 0.158870742  0.94752 1.782576       0.6799
244 0.153315708  0.94776 1.677821       0.6868
245 0.160054103  0.94704 1.633507       0.6918
246 0.166394845  0.94496 1.650819       0.6902
247 0.152245060  0.95044 1.574513       0.6950
248 0.153835684  0.94816 1.601563       0.6918
249 0.162520513  0.94680 1.587545       0.6883
250 0.158546731  0.94636 1.600158       0.6994
251 0.148602009  0.95026 1.591422       0.6957
252 0.169357643  0.94304 1.573461       0.6911
253 0.160394460  0.94652 1.615093       0.6948
254 0.155707747  0.94874 1.635089       0.6903
255 0.162835807  0.94572 1.613227       0.6941
256 0.158834085  0.94674 1.635267       0.6871
257 0.160693660  0.94676 1.644222       0.6879
258 0.158949137  0.94728 1.588645       0.6935
259 0.150822088  0.94974 1.624476       0.6933
260 0.156115428  0.94852 1.551984       0.6989
261 0.161673680  0.94612 1.680822       0.6930
262 0.156695291  0.94822 1.572773       0.7005
263 0.159275517  0.94632 1.582463       0.6926
264 0.160095990  0.94624 1.586808       0.6980
265 0.153429657  0.94806 1.652047       0.6873
266 0.161438271  0.94698 1.645907       0.6932
267 0.162474379  0.94528 1.612501       0.6924
268 0.156747237  0.94746 1.638494       0.6885
269 0.159331560  0.94750 1.645733       0.6891
270 0.161978960  0.94576 1.551418       0.6932
271 0.151905313  0.95026 1.642252       0.6929
272 0.154146463  0.94852 1.566631       0.6956
273 0.163320765  0.94596 1.600148       0.6946
274 0.163466081  0.94606 1.647755       0.6906
275 0.160146400  0.94612 1.621664       0.7021
276 0.156542495  0.94740 1.616311       0.6963
277 0.155010879  0.94790 1.628595       0.6961
278 0.164308131  0.94536 1.621129       0.6925
279 0.150064841  0.95040 1.577397       0.6974
280 0.151650503  0.95006 1.632705       0.6927
281 0.156152159  0.94858 1.573657       0.6960
282 0.158748269  0.94712 1.555497       0.6990
283 0.145519659  0.95174 1.585111       0.6957
284 0.153746843  0.95004 1.628406       0.6940
285 0.160314709  0.94686 1.614646       0.6980
286 0.156300768  0.94788 1.693703       0.6868
287 0.155918702  0.94786 1.598323       0.6956
288 0.152074277  0.95042 1.609434       0.6991
289 0.149731889  0.94964 1.672662       0.6942
290 0.149375886  0.95082 1.583402       0.7012
291 0.155525327  0.94850 1.657447       0.6954
292 0.154608265  0.94880 1.610406       0.6922
293 0.147619769  0.95094 1.619599       0.6947
294 0.159224764  0.94780 1.602643       0.6971
295 0.150616378  0.95076 1.598443       0.6945
296 0.149812341  0.95028 1.569883       0.6978
297 0.142961130  0.95236 1.606472       0.6973
298 0.161538079  0.94624 1.649357       0.6905
299 0.155460984  0.94824 1.585353       0.6947
300 0.157084525  0.94720 1.586110       0.6967
301 0.150351375  0.94986 1.628255       0.6902
302 0.148078367  0.95054 1.655204       0.6911
303 0.155564860  0.94816 1.588208       0.7007
304 0.156112134  0.94816 1.586227       0.6901
305 0.139473066  0.95328 1.666981       0.6885
306 0.145163044  0.95116 1.641278       0.6960
307 0.156359464  0.94784 1.604285       0.6973
308 0.158834875  0.94600 1.699039       0.6882
309 0.148029536  0.95076 1.571804       0.7037
310 0.141125873  0.95274 1.604127       0.6931
311 0.134430289  0.95552 1.587296       0.6973
312 0.153468445  0.94936 1.668628       0.6904
313 0.156663150  0.94830 1.611753       0.6944
314 0.156011999  0.94772 1.665336       0.6905
315 0.150991023  0.94944 1.627367       0.6882
316 0.151845798  0.94974 1.601071       0.6991
317 0.153294876  0.94994 1.600884       0.6964
318 0.139552265  0.95402 1.703377       0.6829
319 0.153586730  0.94932 1.647605       0.6914
320 0.165302575  0.94488 1.616824       0.6912
321 0.147717029  0.95126 1.593752       0.7007
322 0.135137230  0.95562 1.601599       0.7013
323 0.151293650  0.94984 1.650213       0.6959
324 0.154071644  0.94896 1.620587       0.6985
325 0.145795420  0.95104 1.605487       0.6982
326 0.143773049  0.95344 1.628502       0.6913
327 0.152538463  0.94996 1.578949       0.6943
328 0.148834243  0.94968 1.589458       0.7036
329 0.155491024  0.94816 1.605622       0.6982
330 0.155179396  0.94782 1.613582       0.6952
331 0.146019667  0.95184 1.597742       0.6996
332 0.139302239  0.95458 1.605506       0.6921
333 0.146719351  0.95150 1.560899       0.7039
334 0.144795492  0.95200 1.668529       0.6879
335 0.147545025  0.95144 1.609496       0.6988
336 0.145810306  0.95078 1.571884       0.6969
337 0.145682499  0.95078 1.659166       0.6902
338 0.151341334  0.94866 1.643639       0.6923
339 0.154521286  0.94870 1.623893       0.6939
340 0.148744017  0.95090 1.603120       0.7023
341 0.155419812  0.94862 1.594821       0.6927
342 0.145073339  0.95260 1.639588       0.6942
343 0.141883448  0.95342 1.596513       0.6979
344 0.148682743  0.95096 1.613608       0.6917
345 0.145902410  0.95102 1.620131       0.7011
346 0.138717175  0.95380 1.610826       0.6949
347 0.144152358  0.95216 1.554307       0.7042
348 0.139813170  0.95396 1.625822       0.6982
349 0.148287818  0.95108 1.635874       0.6904
350 0.141597345  0.95280 1.589076       0.7012
351 0.143589795  0.95270 1.641688       0.6880
352 0.150514320  0.94920 1.612702       0.6932
353 0.134994134  0.95512 1.594752       0.6996
354 0.146415249  0.95174 1.618544       0.6921
355 0.147946894  0.95128 1.595780       0.6941
356 0.141507193  0.95332 1.587531       0.6987
357 0.143222481  0.95218 1.642227       0.6977
358 0.136874095  0.95442 1.611352       0.7001
359 0.134477109  0.95490 1.587188       0.7052
360 0.147681177  0.95076 1.618784       0.7007
361 0.151681140  0.94910 1.624903       0.7002
362 0.148775756  0.94986 1.625516       0.6916
363 0.140271589  0.95264 1.695393       0.6960
364 0.152510107  0.94912 1.614601       0.6947
365 0.143475816  0.95292 1.631860       0.6932
366 0.146330699  0.95134 1.674337       0.6927
367 0.145930842  0.95130 1.583898       0.7017
368 0.141588286  0.95262 1.618281       0.6956
369 0.137425229  0.95444 1.687356       0.6918
370 0.145295680  0.95290 1.647281       0.6901
371 0.152795196  0.94978 1.579288       0.7001
372 0.136128083  0.95522 1.563160       0.7031
373 0.137540638  0.95492 1.597466       0.6983
374 0.139822215  0.95386 1.554187       0.6991
375 0.139861718  0.95402 1.588858       0.6996
376 0.143360466  0.95258 1.653458       0.6931
377 0.140761986  0.95272 1.612385       0.6944
378 0.136317670  0.95414 1.629950       0.6937
379 0.137556970  0.95476 1.632674       0.6931
380 0.147856683  0.95210 1.682567       0.6927
381 0.146144703  0.95162 1.601905       0.6996
382 0.142928511  0.95296 1.604671       0.6957
383 0.132159099  0.95612 1.595814       0.7000
384 0.135413557  0.95550 1.583656       0.7009
385 0.142993405  0.95282 1.614029       0.6969
386 0.154201001  0.94866 1.644896       0.6927
387 0.135699034  0.95504 1.590318       0.6913
388 0.142919347  0.95224 1.613929       0.6941
389 0.139648736  0.95394 1.612259       0.6923
390 0.142419606  0.95416 1.542146       0.7025
391 0.144736722  0.95178 1.573453       0.6961
392 0.136304215  0.95480 1.585395       0.7015
393 0.151811987  0.94834 1.635323       0.6952
394 0.145787388  0.95106 1.621509       0.6951
395 0.132778943  0.95624 1.573607       0.7008
396 0.131337613  0.95538 1.629827       0.6917
397 0.145763516  0.95292 1.601867       0.6992
398 0.140917927  0.95296 1.635482       0.7001
399 0.139591441  0.95354 1.623037       0.6961
400 0.130842701  0.95670 1.623863       0.6911
401 0.137896061  0.95490 1.583350       0.6995
402 0.143798321  0.95318 1.701605       0.6870
403 0.147250786  0.95230 1.662524       0.6929
404 0.134439930  0.95474 1.655495       0.6968
405 0.134252042  0.95570 1.677080       0.6918
406 0.150872976  0.95048 1.594711       0.6981
407 0.144244492  0.95178 1.633507       0.6970
408 0.133209661  0.95506 1.668014       0.6935
409 0.135291576  0.95468 1.608640       0.6964
410 0.139853045  0.95466 1.642455       0.6914
411 0.137768671  0.95490 1.600708       0.6923
412 0.137735337  0.95466 1.657257       0.6959
413 0.146236598  0.95182 1.577553       0.7015
414 0.136942446  0.95502 1.586368       0.6981
415 0.137865007  0.95344 1.600882       0.7012
416 0.133690596  0.95608 1.625185       0.6973
417 0.140155673  0.95338 1.634989       0.6879
418 0.147813395  0.95078 1.671531       0.6919
419 0.136237681  0.95542 1.617158       0.7006
420 0.138827220  0.95378 1.563253       0.6991
421 0.139598206  0.95344 1.549036       0.7017
422 0.141333848  0.95292 1.610616       0.7003
423 0.136539862  0.95516 1.617414       0.7006
424 0.134891361  0.95586 1.606521       0.6937
425 0.131976843  0.95620 1.615617       0.6993
426 0.129547238  0.95720 1.579768       0.7001
427 0.141107827  0.95330 1.645673       0.6948
428 0.137960166  0.95500 1.568360       0.7003
429 0.136556715  0.95540 1.613927       0.6992
430 0.137796089  0.95438 1.606424       0.6946
431 0.139588729  0.95360 1.651905       0.6933
432 0.141046390  0.95436 1.593631       0.7004
433 0.135906413  0.95390 1.641060       0.6918
434 0.128570244  0.95612 1.609333       0.7014
435 0.139806166  0.95424 1.659620       0.6971
436 0.134976178  0.95486 1.628250       0.6965
437 0.140654042  0.95310 1.587914       0.6961
438 0.141185418  0.95328 1.636467       0.6909
439 0.137370929  0.95416 1.662184       0.6898
440 0.133600637  0.95536 1.589532       0.7033
441 0.134294435  0.95578 1.646176       0.6974
442 0.147875652  0.95060 1.608860       0.6972
443 0.139603287  0.95392 1.591139       0.7025
444 0.136673346  0.95556 1.612237       0.7019
445 0.132182434  0.95656 1.595953       0.7034
446 0.130647138  0.95746 1.615778       0.6988
447 0.130669087  0.95620 1.639447       0.6942
448 0.130038485  0.95750 1.616603       0.6973
449 0.135308728  0.95486 1.659323       0.6933
450 0.124344736  0.95886 1.441587       0.7229
451 0.050656807  0.98490 1.407676       0.7322
452 0.035060871  0.98980 1.401264       0.7297
453 0.029878315  0.99262 1.383097       0.7313
454 0.026154773  0.99322 1.375369       0.7358
455 0.024405673  0.99418 1.393669       0.7343
456 0.021527730  0.99504 1.378236       0.7383
457 0.020021332  0.99536 1.387158       0.7409
458 0.020443695  0.99490 1.372585       0.7412
459 0.018110137  0.99588 1.374436       0.7407
460 0.017333362  0.99628 1.382423       0.7392
461 0.016203463  0.99682 1.382789       0.7418
462 0.015902476  0.99652 1.392298       0.7403
463 0.015864998  0.99646 1.396140       0.7430
464 0.015451864  0.99640 1.384595       0.7408
465 0.014641048  0.99688 1.395020       0.7395
466 0.013376734  0.99722 1.386763       0.7403
467 0.013458423  0.99720 1.391173       0.7430
468 0.013263275  0.99708 1.395430       0.7436
469 0.012452913  0.99736 1.380908       0.7424
470 0.012391060  0.99736 1.384527       0.7422
471 0.012044979  0.99720 1.386881       0.7453
472 0.011625202  0.99764 1.405479       0.7406
473 0.012104890  0.99742 1.400176       0.7426
474 0.010641610  0.99794 1.398543       0.7415
475 0.010204412  0.99808 1.410941       0.7391
476 0.010913683  0.99772 1.393620       0.7408
477 0.010438056  0.99792 1.416264       0.7423
478 0.010024033  0.99814 1.385298       0.7472
479 0.010866049  0.99752 1.383568       0.7429
480 0.009872962  0.99818 1.418330       0.7423
481 0.009655324  0.99802 1.416580       0.7397
482 0.010174521  0.99770 1.425981       0.7409
483 0.009646485  0.99806 1.397506       0.7432
484 0.010067051  0.99794 1.412241       0.7420
485 0.009147012  0.99840 1.434500       0.7390
486 0.009346766  0.99814 1.407985       0.7428
487 0.009006250  0.99816 1.430214       0.7409
488 0.008698191  0.99822 1.413028       0.7459
489 0.008830434  0.99836 1.407192       0.7406
490 0.008609996  0.99834 1.413055       0.7460
491 0.008674718  0.99834 1.415996       0.7461
492 0.008479036  0.99846 1.419722       0.7431
493 0.008542957  0.99822 1.404585       0.7415
494 0.008159215  0.99854 1.402936       0.7443
495 0.007854393  0.99840 1.385249       0.7453
496 0.008298021  0.99854 1.410055       0.7466
497 0.008261654  0.99844 1.426474       0.7433
498 0.007983575  0.99838 1.407051       0.7437
499 0.007840581  0.99846 1.403839       0.7461
500 0.008229845  0.99834 1.403572       0.7431
501 0.008912349  0.99796 1.414895       0.7471
502 0.007427551  0.99874 1.412331       0.7452
503 0.007547786  0.99842 1.419775       0.7473
504 0.008169206  0.99832 1.401624       0.7436
505 0.007995957  0.99834 1.426146       0.7447
506 0.007517054  0.99860 1.420662       0.7449
507 0.007447371  0.99858 1.416089       0.7463
508 0.007361646  0.99852 1.421918       0.7470
509 0.007002162  0.99874 1.412154       0.7458
510 0.006709813  0.99870 1.423001       0.7422
511 0.007253907  0.99874 1.431654       0.7461
512 0.007252783  0.99858 1.412890       0.7466
513 0.007185285  0.99848 1.416620       0.7460
514 0.007244453  0.99872 1.395788       0.7450
515 0.006935450  0.99872 1.416527       0.7457
516 0.006749317  0.99874 1.405227       0.7472
517 0.007246925  0.99846 1.420532       0.7474
518 0.006517833  0.99878 1.417111       0.7485
519 0.006669237  0.99886 1.441695       0.7451
520 0.006493097  0.99890 1.439675       0.7439
521 0.006711027  0.99888 1.410877       0.7452
522 0.006741743  0.99876 1.420701       0.7459
523 0.006743353  0.99866 1.407158       0.7435
524 0.006978492  0.99870 1.441018       0.7440
525 0.006896419  0.99856 1.432550       0.7452
526 0.006289320  0.99904 1.438000       0.7411
527 0.006299040  0.99884 1.423004       0.7430
528 0.006359863  0.99880 1.429271       0.7439
529 0.006170395  0.99874 1.422184       0.7483
530 0.006173828  0.99888 1.426472       0.7440
531 0.006638598  0.99884 1.405762       0.7456
532 0.006081163  0.99882 1.409936       0.7443
533 0.006174278  0.99884 1.425069       0.7455
534 0.006584657  0.99864 1.401203       0.7492
535 0.006540495  0.99880 1.427776       0.7443
536 0.006161072  0.99894 1.398188       0.7470
537 0.006527155  0.99876 1.416759       0.7476
538 0.006122320  0.99896 1.433329       0.7419
539 0.006534172  0.99880 1.425503       0.7480
540 0.006129253  0.99876 1.417909       0.7467
541 0.006209362  0.99874 1.427651       0.7424
542 0.006163097  0.99906 1.424983       0.7455
543 0.006365490  0.99874 1.411736       0.7498
544 0.005908345  0.99892 1.436467       0.7478
545 0.005721463  0.99904 1.417147       0.7476
546 0.006163025  0.99888 1.404277       0.7475
547 0.005539848  0.99892 1.412556       0.7478
548 0.005734612  0.99894 1.396025       0.7440
549 0.005844006  0.99896 1.433410       0.7478
550 0.005525191  0.99906 1.404340       0.7515
551 0.005642379  0.99910 1.428759       0.7503
552 0.005488868  0.99916 1.408633       0.7469
553 0.006140531  0.99888 1.455970       0.7466
554 0.005134188  0.99920 1.434279       0.7448
555 0.006134112  0.99886 1.426066       0.7467
556 0.005669209  0.99902 1.415759       0.7488
557 0.005865015  0.99870 1.432459       0.7460
558 0.005482425  0.99908 1.402396       0.7495
559 0.005404666  0.99912 1.450093       0.7455
560 0.005720284  0.99886 1.421450       0.7486
561 0.005665937  0.99914 1.424225       0.7473
562 0.005583452  0.99898 1.399133       0.7466
563 0.005681403  0.99896 1.432307       0.7474
564 0.005543076  0.99890 1.406861       0.7475
565 0.005151812  0.99922 1.408505       0.7499
566 0.004985921  0.99928 1.419432       0.7476
567 0.005783788  0.99892 1.413324       0.7448
568 0.005192699  0.99914 1.428858       0.7476
569 0.005511692  0.99906 1.425979       0.7454
570 0.005181483  0.99912 1.404181       0.7457
571 0.005310256  0.99900 1.426186       0.7419
572 0.005155613  0.99908 1.420886       0.7458
573 0.005286245  0.99910 1.410772       0.7471
574 0.005153177  0.99916 1.411543       0.7542
575 0.005183875  0.99900 1.426109       0.7454
576 0.005280633  0.99910 1.395915       0.7494
577 0.004710726  0.99934 1.414770       0.7470
578 0.005187221  0.99894 1.433996       0.7435
579 0.005299417  0.99894 1.408595       0.7457
580 0.005011478  0.99924 1.423391       0.7460
581 0.005053359  0.99924 1.421120       0.7446
582 0.005513133  0.99900 1.398511       0.7466
583 0.005261915  0.99906 1.429351       0.7443
584 0.005426925  0.99902 1.413547       0.7450
585 0.004856053  0.99922 1.419502       0.7482
586 0.004881298  0.99926 1.418158       0.7448
587 0.004796502  0.99940 1.397063       0.7470
588 0.005019376  0.99914 1.395555       0.7448
589 0.005098244  0.99920 1.394117       0.7472
590 0.004879857  0.99912 1.402306       0.7497
591 0.005088651  0.99906 1.420829       0.7463
592 0.004997355  0.99922 1.414816       0.7474
593 0.004906875  0.99918 1.395988       0.7484
594 0.004503739  0.99942 1.394794       0.7464
595 0.004700697  0.99926 1.403609       0.7457
596 0.005166062  0.99902 1.432539       0.7445
597 0.005204814  0.99890 1.406911       0.7462
598 0.004721060  0.99928 1.413863       0.7477
599 0.004791856  0.99932 1.394592       0.7531
600 0.004706315  0.99920 1.403072       0.7470
601 0.004697399  0.99930 1.410332       0.7513
602 0.005670072  0.99886 1.401724       0.7440
603 0.004857319  0.99916 1.393421       0.7516
604 0.004929340  0.99916 1.420501       0.7433
605 0.004938248  0.99920 1.406967       0.7473
606 0.004946361  0.99908 1.411371       0.7483
607 0.004767010  0.99920 1.393423       0.7486
608 0.005068546  0.99906 1.391346       0.7506
609 0.004958915  0.99920 1.445292       0.7450
610 0.004783743  0.99924 1.413274       0.7449
611 0.005048754  0.99906 1.414828       0.7461
612 0.004865469  0.99916 1.391849       0.7481
613 0.005005597  0.99908 1.406553       0.7457
614 0.005072235  0.99920 1.420675       0.7466
615 0.004741145  0.99908 1.416713       0.7429
616 0.004723425  0.99918 1.427763       0.7442
617 0.004865183  0.99926 1.406725       0.7450
618 0.005269920  0.99890 1.392820       0.7402
619 0.004278865  0.99942 1.392555       0.7501
620 0.004817104  0.99924 1.413443       0.7455
621 0.004909006  0.99914 1.396229       0.7517
622 0.004438785  0.99930 1.424255       0.7427
623 0.004656434  0.99914 1.409922       0.7487
624 0.004672626  0.99918 1.387659       0.7445
625 0.004300877  0.99930 1.408186       0.7465
626 0.004252086  0.99952 1.380214       0.7499
627 0.004458995  0.99936 1.391417       0.7508
628 0.004539465  0.99930 1.376831       0.7517
629 0.004472214  0.99936 1.408462       0.7493
630 0.004452997  0.99926 1.395432       0.7463
631 0.004546248  0.99922 1.398501       0.7442
632 0.004568384  0.99924 1.390968       0.7493
633 0.004459475  0.99930 1.387382       0.7483
634 0.004710467  0.99918 1.392331       0.7484
635 0.004839281  0.99906 1.396761       0.7471
636 0.004367354  0.99932 1.372210       0.7478
637 0.004415509  0.99924 1.400341       0.7481
638 0.004599945  0.99926 1.390681       0.7480
639 0.004971456  0.99904 1.391984       0.7458
640 0.004943959  0.99914 1.404574       0.7441
641 0.004728653  0.99926 1.384480       0.7535
642 0.004177971  0.99942 1.359593       0.7519
643 0.004834421  0.99908 1.422060       0.7468
644 0.004372767  0.99942 1.386489       0.7482
645 0.004461181  0.99920 1.412726       0.7462
646 0.004836538  0.99912 1.412337       0.7530
647 0.004467005  0.99922 1.409206       0.7456
648 0.004232721  0.99934 1.394285       0.7444
649 0.004597748  0.99928 1.382640       0.7469
650 0.004212113  0.99946 1.396048       0.7439
651 0.004277823  0.99940 1.374542       0.7491
652 0.004639094  0.99926 1.394459       0.7490
653 0.004335594  0.99928 1.402994       0.7485
654 0.004215192  0.99926 1.398258       0.7511
655 0.004126162  0.99944 1.401571       0.7453
656 0.004210052  0.99932 1.407287       0.7461
657 0.004066105  0.99948 1.385116       0.7486
658 0.004494595  0.99946 1.399248       0.7464
659 0.004397073  0.99920 1.365847       0.7506
660 0.004421169  0.99920 1.383437       0.7484
661 0.004021584  0.99944 1.372819       0.7491
662 0.004440803  0.99918 1.405566       0.7469
663 0.004477837  0.99920 1.397726       0.7472
664 0.004594364  0.99918 1.392519       0.7490
665 0.004321391  0.99928 1.400039       0.7467
666 0.003884707  0.99936 1.388085       0.7477
667 0.004417255  0.99940 1.409185       0.7443
668 0.004120239  0.99942 1.378207       0.7489
669 0.004492601  0.99934 1.371402       0.7470
670 0.004287312  0.99924 1.382235       0.7508
671 0.004465098  0.99918 1.373153       0.7480
672 0.003855312  0.99942 1.401969       0.7481
673 0.004302789  0.99926 1.390971       0.7502
674 0.004210928  0.99932 1.388709       0.7457
675 0.004166720  0.99926 1.399589       0.7466
676 0.004043411  0.99952 1.395130       0.7484
677 0.004054969  0.99942 1.394243       0.7471
678 0.004202967  0.99940 1.386384       0.7494
679 0.004249067  0.99930 1.398514       0.7464
680 0.004161104  0.99932 1.402154       0.7478
681 0.004315372  0.99934 1.401574       0.7471
682 0.004056135  0.99942 1.395551       0.7503
683 0.004266312  0.99934 1.378255       0.7537
684 0.004363045  0.99912 1.384869       0.7490
685 0.004116495  0.99944 1.398063       0.7503
686 0.004099772  0.99950 1.377018       0.7525
687 0.003875785  0.99944 1.376819       0.7491
688 0.004003887  0.99928 1.386256       0.7510
689 0.003963573  0.99942 1.384135       0.7505
690 0.003895512  0.99956 1.377205       0.7516
691 0.003829194  0.99942 1.380257       0.7509
692 0.003802414  0.99940 1.395562       0.7496
693 0.003770972  0.99946 1.373237       0.7514
694 0.003860362  0.99946 1.387800       0.7529
695 0.003928786  0.99940 1.404927       0.7452
696 0.003780503  0.99950 1.383660       0.7499
697 0.003906305  0.99950 1.397846       0.7510
698 0.003901723  0.99950 1.373569       0.7507
699 0.003651343  0.99944 1.371993       0.7515
700 0.003898600  0.99936 1.384891       0.7506
701 0.003819313  0.99948 1.386870       0.7492
702 0.004223661  0.99936 1.384581       0.7481
703 0.003855399  0.99950 1.384949       0.7520
704 0.003507904  0.99966 1.383653       0.7475
705 0.004018778  0.99936 1.366909       0.7540
706 0.003871137  0.99938 1.381888       0.7482
707 0.004155347  0.99932 1.400326       0.7461
708 0.004133416  0.99918 1.397508       0.7484
709 0.004068155  0.99936 1.373228       0.7492
710 0.003763632  0.99942 1.374093       0.7488
711 0.003687674  0.99944 1.393669       0.7473
712 0.003559326  0.99952 1.398147       0.7492
713 0.003739260  0.99956 1.370705       0.7518
714 0.003812938  0.99944 1.386667       0.7515
715 0.003516178  0.99950 1.368693       0.7487
716 0.003728548  0.99938 1.379789       0.7508
717 0.003725690  0.99960 1.379278       0.7489
718 0.003798536  0.99954 1.400110       0.7521
719 0.003879965  0.99942 1.403550       0.7457
720 0.004001467  0.99928 1.389689       0.7494
721 0.003742206  0.99946 1.374500       0.7531
722 0.003800071  0.99942 1.380778       0.7513
723 0.003963613  0.99938 1.396288       0.7499
724 0.003654225  0.99942 1.400507       0.7476
725 0.003773930  0.99944 1.391011       0.7496
726 0.003594571  0.99954 1.387267       0.7499
727 0.003741001  0.99944 1.392733       0.7506
728 0.003641232  0.99952 1.392243       0.7506
729 0.003903184  0.99936 1.396375       0.7469
730 0.003855030  0.99934 1.383650       0.7504
731 0.003584256  0.99948 1.380648       0.7477
732 0.003754498  0.99950 1.382176       0.7498
733 0.003618537  0.99954 1.392480       0.7500
734 0.003623980  0.99952 1.381816       0.7509
735 0.003822622  0.99948 1.391498       0.7515
736 0.003614904  0.99940 1.384733       0.7522
737 0.003754279  0.99932 1.382403       0.7496
738 0.003695032  0.99942 1.402883       0.7452
739 0.003441158  0.99960 1.405548       0.7480
740 0.003841226  0.99950 1.398224       0.7452
741 0.003648196  0.99958 1.398626       0.7470
742 0.003622203  0.99954 1.387685       0.7487
743 0.003837872  0.99944 1.394218       0.7487
744 0.003723893  0.99946 1.393869       0.7480
745 0.003575010  0.99956 1.379912       0.7528
746 0.003627359  0.99956 1.376933       0.7506
747 0.003672693  0.99946 1.383514       0.7488
748 0.003955462  0.99942 1.403958       0.7485
749 0.004004475  0.99946 1.391278       0.7477
750 0.003516122  0.99958 1.379871       0.7500
751 0.003871972  0.99934 1.377972       0.7504
752 0.003995358  0.99930 1.403870       0.7493
753 0.003779306  0.99952 1.381453       0.7508
754 0.003834113  0.99932 1.385096       0.7464
755 0.003847099  0.99940 1.380614       0.7519
756 0.003640553  0.99948 1.382856       0.7503
757 0.003545614  0.99950 1.375053       0.7482
758 0.004023343  0.99934 1.380218       0.7507
759 0.003765266  0.99944 1.377648       0.7493
760 0.003566456  0.99966 1.382731       0.7499
761 0.003774222  0.99950 1.373917       0.7497
762 0.003949451  0.99944 1.381418       0.7525
763 0.004019988  0.99944 1.387758       0.7480
764 0.003585106  0.99946 1.381206       0.7493
765 0.003709744  0.99946 1.382514       0.7488
766 0.003641587  0.99948 1.362586       0.7513
767 0.003836775  0.99940 1.391095       0.7490
768 0.003816722  0.99946 1.394251       0.7469
769 0.003510774  0.99968 1.387712       0.7499
770 0.003702947  0.99954 1.393254       0.7478
771 0.003654839  0.99948 1.371176       0.7450
772 0.003893751  0.99942 1.407273       0.7507
773 0.003593972  0.99950 1.385557       0.7469
774 0.003570807  0.99952 1.383940       0.7504
775 0.003660334  0.99956 1.360500       0.7490
776 0.003710129  0.99936 1.373093       0.7526
777 0.003749571  0.99940 1.394454       0.7482
778 0.003621595  0.99948 1.391579       0.7485
779 0.003916530  0.99946 1.404675       0.7462
780 0.003965732  0.99928 1.407158       0.7472
781 0.003507888  0.99954 1.362682       0.7557
782 0.003921149  0.99920 1.397395       0.7483
783 0.003888350  0.99944 1.388364       0.7484
784 0.004002109  0.99946 1.377400       0.7456
785 0.003597769  0.99958 1.386281       0.7507
786 0.003968587  0.99946 1.372886       0.7479
787 0.003746550  0.99942 1.389609       0.7485
788 0.003471009  0.99950 1.376045       0.7561
789 0.003515029  0.99950 1.391333       0.7492
790 0.003740238  0.99946 1.395389       0.7474
791 0.003662769  0.99952 1.381256       0.7506
792 0.003735510  0.99950 1.385208       0.7501
793 0.003777692  0.99936 1.377771       0.7509
794 0.003426865  0.99962 1.377166       0.7532
795 0.003850501  0.99950 1.381726       0.7488
796 0.003854274  0.99942 1.357744       0.7516
797 0.003896037  0.99942 1.376736       0.7502
798 0.003623167  0.99950 1.373237       0.7514
799 0.003336385  0.99958 1.380837       0.7527
800 0.003677030  0.99936 1.367090       0.7544
801 0.003757585  0.99938 1.381894       0.7539
802 0.003691996  0.99946 1.400380       0.7497
803 0.003496195  0.99952 1.386381       0.7531
804 0.003449135  0.99958 1.383694       0.7471
805 0.003695381  0.99956 1.388623       0.7473
806 0.003624609  0.99942 1.380503       0.7489
807 0.003855928  0.99934 1.399426       0.7477
808 0.003624713  0.99946 1.392185       0.7488
809 0.003710056  0.99944 1.386121       0.7504
810 0.003599518  0.99956 1.388235       0.7457
811 0.003652663  0.99936 1.416115       0.7488
812 0.003694006  0.99932 1.376989       0.7496
813 0.003279331  0.99958 1.370095       0.7499
814 0.003546879  0.99950 1.394469       0.7490
815 0.003811984  0.99938 1.390461       0.7494
816 0.003956810  0.99932 1.386067       0.7486
817 0.003529836  0.99950 1.389819       0.7472
818 0.003880393  0.99950 1.384268       0.7481
819 0.003608852  0.99948 1.376074       0.7542
820 0.003665391  0.99956 1.380117       0.7528
821 0.003703953  0.99942 1.388243       0.7485
822 0.003461364  0.99946 1.401620       0.7469
823 0.003734413  0.99936 1.391865       0.7481
824 0.003790044  0.99940 1.367412       0.7498
825 0.003895878  0.99934 1.378636       0.7477
826 0.003676614  0.99948 1.394193       0.7468
827 0.003633945  0.99954 1.376814       0.7473
828 0.003587150  0.99938 1.366035       0.7526
829 0.003597163  0.99948 1.380416       0.7495
830 0.003318495  0.99962 1.413152       0.7470
831 0.003712012  0.99942 1.371905       0.7471
832 0.003428303  0.99972 1.393037       0.7473
833 0.003723841  0.99938 1.385642       0.7472
834 0.003771480  0.99944 1.398134       0.7466
835 0.003604361  0.99958 1.384968       0.7474
836 0.003694733  0.99952 1.357234       0.7533
837 0.003581399  0.99950 1.394210       0.7483
838 0.003818823  0.99928 1.382098       0.7482
839 0.003807909  0.99934 1.392830       0.7486
840 0.003629315  0.99942 1.379845       0.7494
841 0.003871643  0.99932 1.381863       0.7483
842 0.003787301  0.99952 1.388470       0.7514
843 0.003391408  0.99948 1.392044       0.7521
844 0.003425142  0.99972 1.382607       0.7468
845 0.003776078  0.99942 1.385040       0.7475
846 0.003456869  0.99960 1.382704       0.7475
847 0.003438824  0.99948 1.400075       0.7479
848 0.003668452  0.99948 1.380061       0.7503
849 0.003408735  0.99968 1.403111       0.7484
850 0.003459629  0.99960 1.367557       0.7461
851 0.003659104  0.99944 1.394181       0.7476
852 0.003756752  0.99940 1.380636       0.7505
853 0.003638913  0.99954 1.387339       0.7490
854 0.003841528  0.99934 1.379344       0.7491
855 0.003615711  0.99946 1.382432       0.7498
856 0.003638959  0.99940 1.385617       0.7489
857 0.003685101  0.99940 1.414099       0.7474
858 0.003408466  0.99950 1.380612       0.7508
859 0.003256983  0.99956 1.379402       0.7494
860 0.003420987  0.99956 1.393150       0.7504
861 0.003469174  0.99958 1.391155       0.7528
862 0.003435688  0.99946 1.370455       0.7478
863 0.003549380  0.99956 1.365100       0.7502
864 0.003598281  0.99946 1.376953       0.7494
865 0.003442684  0.99962 1.394220       0.7496
866 0.003403823  0.99956 1.405198       0.7521
867 0.003663774  0.99946 1.386756       0.7482
868 0.003658666  0.99938 1.383487       0.7474
869 0.003466758  0.99950 1.378211       0.7530
870 0.003663639  0.99936 1.400924       0.7450
871 0.003621953  0.99948 1.393332       0.7487
872 0.003522192  0.99950 1.393829       0.7473
Image

ResNet¶

The network inputs are $32 \times 32$ images, with the per-pixel mean subtracted. The first layer is a $3 \times 3$ convolution. Then we use a stack of $6n$ layers with $3 \times 3$ convolutions on feature maps of sizes $\{32, 16, 8\}$, with $2n$ layers for each feature map size. The numbers of filters are $\{16, 32, 64\}$ respectively. Subsampling is performed by convolutions with a stride of 2. The network ends with global average pooling, a 10-way fully-connected layer, and softmax. There are $6n+2$ stacked weighted layers in total. The following table summarizes the architecture:

output map size | $32 \times 32$ | $16 \times 16$ | $8 \times 8$
# layers        | $1+2n$         | $2n$           | $2n$
# filters       | 16             | 32             | 64
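
As a quick aside (not part of the quoted text), the depth formula can be checked directly; the CIFAR-10 model assembled later in this notebook uses 18 blocks per feature-map size, i.e. $n = 18$ and a 110-layer ResNet. Here `resnet_depth` is a hypothetical helper name:

resnet_depth = function(n) 6L*n + 2L   # 6n conv layers + the first conv + the final dense layer
resnet_depth(18L)                      # 110, matching num_blocks = 18 in the model built below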

When shortcut connections are used, they are connected to the pairs of $3 \times 3$ layers ($3n$ shortcuts in total). On this dataset we use identity shortcuts in all cases (i.e., option A), so our residual models have exactly the same depth, width, and number of parameters as the plain counterparts.

We use a weight decay of 0.0001 and momentum of 0.9, and adopt the weight initialization in [13] and $\mathrm{BN}$ [16] but with no dropout. These models are trained with a minibatch size of 128 on two GPUs. We start with a learning rate of 0.1, divide it by 10 at $32\mathrm{k}$ and $48\mathrm{k}$ iterations, and terminate training at $64\mathrm{k}$ iterations, which is determined on a $45\mathrm{k}/5\mathrm{k}$ train/val split. We follow the simple data augmentation in [24] for training: 4 pixels are padded on each side, and a $32 \times 32$ crop is randomly sampled from the padded image or its horizontal flip. For testing, we only evaluate the single view of the original $32 \times 32$ image.
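
A minimal sketch of this recipe in the R/TensorFlow idiom used in this notebook (not the notebook's actual training code; `lr_resnet` and `cifar_augment` are hypothetical names, and `cifar_augment` assumes a single HWC image):

lr_resnet = tf$keras$optimizers$schedules$PiecewiseConstantDecay(
  boundaries = list(32000L, 48000L), values = list(0.1, 0.01, 0.001))  # 0.1, divided by 10 at 32k and 48k steps

cifar_augment = function(img){                                 # img: a single 32x32x3 image
  img = tf$image$resize_with_crop_or_pad(img, 40L, 40L)        # zero-pad 4 pixels on each side
  img = tf$image$random_crop(img, size = c(32L, 32L, 3L))      # random 32x32 crop
  tf$image$random_flip_left_right(img)                         # random horizontal flip
}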

Our implementation for ImageNet follows the practice in $[21,41]$. The image is resized with its shorter side randomly sampled in $[256,480]$ for scale augmentation [41]. A $224 \times 224$ crop is randomly sampled from an image or its horizontal flip, with the per-pixel mean subtracted [21]. The standard color augmentation in [21] is used. We adopt batch normalization (BN) [16] right after each convolution and before activation, following [16]. We initialize the weights as in [13] and train all plain/residual nets from scratch. We use SGD with a mini-batch size of 256. The learning rate starts from 0.1 and is divided by 10 when the error plateaus, and the models are trained for up to $60 \times 10^4$ iterations. We use a weight decay of 0.0001 and a momentum of 0.9. We do not use dropout [14], following the practice in [16]. In testing, for comparison studies we adopt the standard 10-crop testing [21]. For best results, we adopt the fully-convolutional form as in $[41,13]$, and average the scores at multiple scales (images are resized such that the shorter side is in $\{224,256,384,480,640\}$). ... See Table 1 for detailed architectures. ...
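
As a hedged sketch of the scale augmentation just described (an illustration, not code from this notebook; `imagenet_scale_augment` is a hypothetical name and a single HWC image is assumed):

imagenet_scale_augment = function(img){
  img = tf$cast(img, tf$float32)
  c(h, w, ch) %<-% tf$unstack(tf$cast(tf$shape(img), tf$float32))
  short_side = tf$random$uniform(shape = list(), minval = 256, maxval = 480)  # shorter side sampled in [256, 480]
  scale = short_side / tf$minimum(h, w)
  img = tf$image$resize(img, tf$cast(tf$round(tf$stack(list(h*scale, w*scale))), tf$int32))
  img = tf$image$random_crop(img, size = c(224L, 224L, 3L))   # random 224x224 crop
  tf$image$random_flip_left_right(img)                        # per-pixel mean subtraction and color augmentation would follow
}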

He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

In [ ]:
myconv2d = function(x, filters, strides){
  if(strides[1]>1)
    x %>%
    layer_zero_padding_2d(padding=list(list(1L, 0L),list(1L, 0L))) %>%
    layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'valid', use_bias = FALSE)
  else
    x %>%
    layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE)
}

# conv_block: a basic residual block with two conv-bn units;
# the 1st conv uses the specified `strides`, the 2nd always uses `strides = 1`
# `filters` sets the number of filters of both convolutions
# the input is added to the output; if their shapes differ, a 1x1 conv + bn shortcut matches them
conv_block = function(x, filters, strides = 1L){
  #print('x')
  #print(x)
  out = x %>%
    layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
    #myconv2d(filters = filters, strides = strides) %>%  # conv1
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1, by default channel is last dimension
    layer_activation_relu() %>%
    layer_conv_2d(filters = filters, strides = 1L, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>% # conv2
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  # bn2
  #print('out')
  #print(out)
  if(dim(x)[4]==filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
  else shortcut =  x %>%
      layer_conv_2d(filters = filters, strides = strides, kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) %>%
      layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
  #print('shortcut')
  #print(shortcut)
  (out + shortcut) %>% layer_activation_relu()
}

# ImageNet-style ResNets stack 4 groups of blocks with 64, 128, 256 and 512 filters
# the strides of the 4 groups are 1, 2, 2 and 2, so from the 2nd group on, width and height halve
# each group contains a number of blocks/bottlenecks, which determines the total depth of the ResNet
# the network starts with a conv-bn-relu stem that outputs as many features as the 1st group, uses a 3x3 kernel
# and keeps width and height (padding = 'same', strides = 1), and it ends with flatten-dense
# (a sketch of this ImageNet-style variant is given after the CIFAR-10 model below)

conv_bottleneck = function(x, filters, strides = 1L){
  expansion = 4L
  #print('x')
  #print(x)
  out = x %>%
    layer_conv_2d(filters = filters, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1
    layer_activation_relu() %>%
    layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>% # conv2
    layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
    layer_activation_relu() %>%
    layer_conv_2d(filters = expansion * filters, kernel_size = c(1L,1L), padding = 'same', use_bias = FALSE) %>% # conv3
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  # bn3
  #print('out')
  #print(out)
  if(dim(x)[4]== expansion * filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
  else shortcut =  x %>%
      layer_conv_2d(filters = expansion * filters, strides = strides, kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) %>%
      layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
  #print('shortcut')
  #print(shortcut)
  (out + shortcut) %>% layer_activation_relu()
}

conv_layer = function(x, block, filters, num_blocks, strides){
  layers = x
  print(sprintf('new conv_layer:filters = %d, num_blocks = %d, strides = %d', filters, num_blocks, strides))
  for(i in 1:num_blocks) {
    layers = layers %>% block(., filters, strides)
    strides = 1 # only shrink image once at the first block, if strides > 1
  }
  layers
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

predictions = inputs %>%
  layer_conv_2d(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
  layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
  layer_activation_relu %>%
  conv_layer(., block = conv_block, filters = 16, num_blocks = 18, strides = 1) %>%  # layer 1
  conv_layer(., block = conv_block, filters = 32, num_blocks = 18, strides = 2) %>%  # layer 2
  conv_layer(., block = conv_block, filters = 64, num_blocks = 18, strides = 2) %>%  # layer 3
  layer_average_pooling_2d(pool_size = 8L) %>%
  layer_flatten() %>%
  layer_dense(activation = 'softmax', units = 10L)

resnet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
[1] "new conv_layer:filters = 16, num_blocks = 18, strides = 1"
[1] "new conv_layer:filters = 32, num_blocks = 18, strides = 2"
[1] "new conv_layer:filters = 64, num_blocks = 18, strides = 2"
In [ ]:
tf$keras$utils$plot_model(resnet.cifar10, to_file='ResNet_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'ResNet_CIFAR10.png')
Image

EfficientNet B0¶

The implementation follows the Keras reference implementation. We scale the input to $(224, 224)$.

inputs = layer_input(shape = dim(cifar10$train$x)[-1])
efficientNetB0.cifar10 = tf$keras$applications$EfficientNetB0(input_tensor = inputs %>%
  layer_lambda(function(img) tf$image$resize(img, c(224L, 224L))), include_top=TRUE, weights=NULL, classes=10L)
efficientNetB0.cifar10$compile(loss='categorical_crossentropy',optimizer= 'adam', metrics=list('accuracy'))
mc = callback_model_checkpoint('efficientNetB0_CIFAR10.ckpt', monitor = 'val_accuracy', save_best_only = TRUE, mode = 'auto', save_weights_only = TRUE)

tf$keras$utils$plot_model(efficientNetB0.cifar10, to_file='efficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'efficientNetB0_CIFAR10.png')
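
The checkpoint callback `mc` defined above would be passed to `fit`. A hedged usage sketch only (the epoch count, batch size and one-hot label handling are illustrative assumptions, not taken from the notebook):

history = efficientNetB0.cifar10 %>% fit(
  cifar10$train$x, to_categorical(cifar10$train$y, 10L),
  validation_data = list(cifar10$test$x, to_categorical(cifar10$test$y, 10L)),
  epochs = 50L, batch_size = 128L, callbacks = list(mc))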

Tan, M., & Le, Q. (2019, May). Efficientnet: Rethinking model scaling for convolutional neural networks. In International conference on machine learning (pp. 6105-6114). PMLR.

In [ ]:
c_k_i = initializer_variance_scaling(scale=2.0, mode='fan_out')
convBnAct = function(x, filters, kernel_size=3L, stride=1L, padding='valid', bn=TRUE, act='swish', use_bias=FALSE, groups=1L){
  #if(stride==2L & kernel_size == 3L) x = x %>%
  #  layer_zero_padding_2d(padding=list(list(0L, 1L),list(0L, 1L))) %>%
  #  layer_conv_2d(filters=filters, kernel_size=kernel_size, stride=stride, padding='valid', use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
  #else
  x = layer_conv_2d(x, filters=filters, kernel_size=kernel_size, strides=stride, padding=padding, use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
  if(bn) x = layer_batch_normalization(x, momentum = .9, epsilon = 1e-5)
  if(!is.null(act)) x %>% layer_activation(activation = act) else x
}

squeezeExcitation = function(x, filters){
  y = x %>%
    layer_global_average_pooling_2d(keepdims=TRUE) %>% # equivalent to an adaptive average pool with 1x1 spatial output
    layer_conv_2d(filters=filters, kernel_size=1L, activation='swish',kernel_initializer = c_k_i) %>%
    layer_conv_2d(filters=dim(x)[4], kernel_size=1L, activation='sigmoid',kernel_initializer = c_k_i)
  x * y
}

# stochastic depth: with probability 1 - survival_prob the whole branch is dropped (replaced by 0),
# otherwise it is kept and rescaled by 1/survival_prob
# (defined here for reference; MBConvN below uses layer_dropout with a per-sample noise_shape instead)
stochasticDepth = function(x, survival_prob = .8){
  tf$where(tf$random$uniform(list()) < survival_prob, x / survival_prob, 0)
}

# MBConv: MobileNet-style inverted bottleneck block with expansion factor `expansion_factor`
MBConvN = function(x, filters, kernel_size=3L, stride=1L, expansion_factor = 6L, reduction = 4L, survival_prob = .8){
  filters_in = dim(x)[4]
  residual = x
  intermediate_channels = filters_in * expansion_factor
  if(expansion_factor!=1L) x = convBnAct(x, filters = intermediate_channels, kernel_size=1L)

  x = x %>%
    convBnAct(filters = intermediate_channels, kernel_size = kernel_size, stride = stride, padding = 'same', groups = intermediate_channels) %>%
    squeezeExcitation(filters = floor(filters_in/reduction)) %>%
    convBnAct(filters = filters, kernel_size = 1L, act=NULL, use_bias=FALSE)

  if(stride==1L & filters_in==filters){
    # residual connection; the noise_shape makes dropout zero the whole branch per sample (stochastic-depth style)
    residual + layer_dropout(x, rate = .2, noise_shape=c(py_none(), 1L, 1L, 1L))
  } else x
}

featureExtractor = function(x, width_mult, depth_mult, last_channel){
  kernels = c(3, 3, 5, 3, 5, 5, 3)
  expansions = c(1, 6, 6, 6, 6, 6, 6)
  scaled_num_channels = 4 * ceiling( c(16, 24, 40, 80, 112, 192, 320) * width_mult/4 )
  scaled_num_layers = c(1, 2, 2, 3, 3, 4, 1) * depth_mult
  strides = c(1, 2, 2, 2, 1, 2, 1)

  x = x %>% convBnAct(filters = 4 * ceiling(32 * width_mult/4), kernel_size = 3L, stride = 2L, padding = 'same')
  for(i in 1:length(scaled_num_layers)){
    for(j in 1:scaled_num_layers[i]){
      x = x %>% MBConvN(filters = scaled_num_channels[i], kernel_size = kernels[i], stride = if(j==1) strides[i] else 1L, expansion_factor = expansions[i])
    }
  }
  x %>% convBnAct(filters = last_channel, kernel_size = 1L, stride = 1L, padding='valid')
}
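# note: with width_mult = depth_mult = 1 (EfficientNet-B0), the per-stage kernels, strides,
# channel counts and repeat counts above are used unscaled; larger variants would scale
# channels by width_mult and repeats by depth_mult.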
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

predictions = inputs %>%
  featureExtractor(width_mult = 1, depth_mult = 1, last_channel = 1280) %>%
  layer_global_average_pooling_2d() %>%
  layer_dropout(.2) %>%
  layer_dense(units = 10L, activation='softmax')
efficientNetB0.cifar10 = keras_model(inputs = inputs, outputs = predictions)
#efficientNetB0.cifar10
In [ ]:
tf$keras$utils$plot_model(efficientNetB0.cifar10, to_file='efficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'efficientNetB0_CIFAR10.png')
Image

WideResNet-28-10¶

Compared to the original architecture [11], in [13] the order of batch normalization, activation and convolution in the residual block was changed from conv-BN-ReLU to BN-ReLU-conv. As the latter was shown to train faster and achieve better results, we do not consider the original version.

The general structure of our residual networks is illustrated in table 1: it consists of an initial convolutional layer conv1 that is followed by 3 groups (each of size $N$) of residual blocks conv2, conv3 and conv4, followed by average pooling and a final classification layer. The size of conv1 is fixed in all of our experiments, while the introduced widening factor $k$ scales the width of the residual blocks in the three groups conv2-4 (e.g., the original «basic» architecture is equivalent to $k=1$).

group name | output size    | block type $=B(3,3)$
conv1      | $32 \times 32$ | $[3 \times 3, 16]$
conv2      | $32 \times 32$ | $\left[\begin{array}{c}3 \times 3, 16 \times \mathrm{k} \\ 3 \times 3, 16 \times \mathrm{k}\end{array}\right] \times \mathrm{N}$
conv3      | $16 \times 16$ | $\left[\begin{array}{c}3 \times 3, 32 \times \mathrm{k} \\ 3 \times 3, 32 \times \mathrm{k}\end{array}\right] \times \mathrm{N}$
conv4      | $8 \times 8$   | $\left[\begin{array}{c}3 \times 3, 64 \times \mathrm{k} \\ 3 \times 3, 64 \times \mathrm{k}\end{array}\right] \times \mathrm{N}$
avg-pool   | $1 \times 1$   | $[8 \times 8]$

As widening increases the number of parameters, we would like to study ways of regularization. Residual networks already have batch normalization, which provides a regularization effect; however, it requires heavy data augmentation, which we would like to avoid, and that is not always possible. We add a dropout layer into each residual block between convolutions, as shown in fig. 1(d), and after ReLU, to perturb batch normalization in the next residual block and prevent it from overfitting. In very deep residual networks this should help deal with the diminishing feature reuse problem by enforcing learning in different residual blocks. ... We trained networks with dropout inserted into the residual block between convolutions on all datasets. We used cross-validation to determine the dropout probability values: 0.3 on CIFAR and 0.4 on SVHN. Also, we did not have to increase the number of training epochs compared to baseline networks without dropout.

In all our experiments we use SGD with Nesterov momentum and cross-entropy loss. The initial learning rate is set to 0.1, weight decay to 0.0005, dampening to 0, momentum to 0.9 and minibatch size to 128. On CIFAR the learning rate is dropped by a factor of 0.2 at 60, 120 and 160 epochs and we train for 200 epochs in total. On SVHN the initial learning rate is set to 0.01 and we drop it at 80 and 120 epochs by 0.1, training for 160 epochs in total. Our implementation is based on Torch [6]. We use [21] to reduce the memory footprint of all our networks. For ImageNet experiments we used the fb.resnet.torch implementation [10]. Our code and models are available at https://github.com/szagoruyko/wide-residual-networks.
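
The epoch-based schedule above can be expressed as a step schedule in the Keras API. A minimal sketch, assuming the 50,000-image CIFAR-10 training set and the batch size of 128 quoted above (`lr_wrn` and `opt_wrn` are hypothetical names, and this is not the notebook's training code):

steps_per_epoch = ceiling(50000 / 128)
lr_wrn = tf$keras$optimizers$schedules$PiecewiseConstantDecay(       # 0.1, multiplied by 0.2 at epochs 60, 120, 160
  boundaries = as.integer(c(60, 120, 160) * steps_per_epoch),
  values = list(0.1, 0.1*0.2, 0.1*0.2^2, 0.1*0.2^3))
opt_wrn = tf$keras$optimizers$SGD(learning_rate = lr_wrn, momentum = 0.9, nesterov = TRUE)
# the 0.0005 weight decay of the paper would be added via kernel_regularizer on the conv layers
# or via tf$keras$optimizers$experimental$SGD(..., weight_decay = 5e-4)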

Some implementations on GitHub by the authors and others:

  • https://github.com/szagoruyko/wide-residual-networks
  • https://github.com/titu1994/Wide-Residual-Networks/
  • https://github.com/paradoxysm/wideresnet

Zagoruyko, S., & Komodakis, N. (2016, January). Wide Residual Networks. In British Machine Vision Conference 2016. British Machine Vision Association.

In [ ]:
conv_block = function(x, base, k, dropout = 0, strides = c(1L, 1L), shortcut = FALSE){
  h = x %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>%
    layer_activation_relu()
  out = h %>%
    layer_conv_2d(base*k, c(3L,3L), padding = 'same', strides=strides, kernel_initializer='he_normal', use_bias = FALSE) %>%  # conv1
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>%
    layer_activation_relu() %>%
    layer_dropout(rate = dropout) %>%  # dropout between the two convolutions, as in fig. 1(d) of the paper
    layer_conv_2d(base*k, c(3L,3L), padding = 'same', kernel_initializer='he_normal', use_bias = FALSE)  # conv2

  if(shortcut) skip = h %>%
    layer_conv_2d(base*k, c(1L, 1L), strides = strides, kernel_initializer='he_normal', use_bias = FALSE)
  else skip = x
  layer_add(list(out, skip))
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

k = 10

predictions = inputs %>%
  layer_conv_2d(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
  # WRN-28-10: depth 28 = 6*N + 4, so N = 4 blocks per group; dropout 0.3 on CIFAR as in the quoted paper
  conv_block(., 16, k, shortcut=TRUE) %>% conv_block(., 16, k, dropout=.3) %>% conv_block(., 16, k, dropout=.3) %>% conv_block(., 16, k, dropout=.3) %>%
  conv_block(., 32, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 32, k, dropout=.3) %>% conv_block(., 32, k, dropout=.3) %>% conv_block(., 32, k, dropout=.3) %>%
  conv_block(., 64, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 64, k, dropout=.3) %>% conv_block(., 64, k, dropout=.3) %>% conv_block(., 64, k, dropout=.3) %>%
  layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% layer_activation_relu() %>% layer_average_pooling_2d(pool_size = 8L) %>% layer_flatten() %>%
  layer_dense(activation = 'softmax', units = 10L)

wrn_28_10.cifar10 = keras_model(inputs = inputs, outputs = predictions)
In [ ]:
tf$keras$utils$plot_model(wrn_28_10.cifar10, to_file='WRN_28_10_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'WRN_28_10_CIFAR10.png')
Image

MLP-Mixer¶

Tolstikhin, Ilya O., Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung et al. "Mlp-mixer: An all-mlp architecture for vision." Advances in neural information processing systems 34 (2021): 24261-24272.

In [ ]:
layer_patch = Layer(
  classname = 'Patch',
  inherit = tf$keras$layers$Layer,
  initialize = function(patch_size, projection_dim, ...) {
    super$initialize()
    #print('initialize')
    self$patch_size = patch_size
    self$projection_dim = projection_dim
  },

  build = function(input_shape) {
    # assert: image is square, patch is square
    #print('build')
    self$num_patches =  as.integer(floor(input_shape[[2]]/self$patch_size)^2)
    self$projection = tf$keras$layers$Dense(units = self$projection_dim)
    self$position_embedding = tf$keras$layers$Embedding(input_dim = self$num_patches, output_dim = self$projection_dim)
    self$built=TRUE
    #print('end of build')
    return()
  },

  call = function(images) {
    #print('call')
    patches = tf$image$extract_patches(
      images = images,
      sizes = c(1L, self$patch_size, self$patch_size, 1L),
      strides = c(1L, self$patch_size, self$patch_size, 1L),
      rates = c(1L, 1L, 1L, 1L),
      padding = 'VALID'
    )
    patches = tf$reshape(
      patches,
      shape=c(tf$shape(images)[1L], -1L, tf$reduce_prod(tf$shape(patches)[4L]))
    )
    positions = tf$range(start=0, limit = self$num_patches, delta = 1)
    self$projection(patches) + self$position_embedding(positions)
  }
)
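# e.g. with the 64x64 resizing and patch_size = 8 used below, layer_patch produces
# num_patches = (64/8)^2 = 64 patch tokens, each projected to projection_dim = 256 features
# (plus a learned position embedding per patch).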

mlp_mixer_block = function(x){
  num_patches = dim(x)[2]
  projection_dim = dim(x)[3]
  mlp1 = x %>%  # token-mixing MLP: transpose so the dense layers mix information across patches
    layer_layer_normalization(epsilon = 1e-6) %>%
    tf$linalg$matrix_transpose(.) %>%
    layer_dense(units = num_patches, activation='gelu') %>%
    layer_dense(units = num_patches) %>%
    layer_dropout(rate = .2) %>%
    tf$linalg$matrix_transpose(.)
  x = layer_add(list(x, mlp1))
  mlp2 = x %>%  # channel-mixing MLP: dense layers mix information across feature channels
    layer_layer_normalization(epsilon = 1e-6) %>%
    layer_dense(units = projection_dim, activation='gelu') %>%
    layer_dense(units = projection_dim) %>%
    layer_dropout(rate = .2)
  layer_add(list(x, mlp2))
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

predictions = inputs %>%
  layer_resizing(64L, 64L) %>%
  layer_patch(patch_size = 8L, projection_dim = 256L) %>%
  mlp_mixer_block %>%
  mlp_mixer_block %>%
  mlp_mixer_block %>%
  mlp_mixer_block %>%
  layer_global_average_pooling_1d %>%
  layer_dropout(rate = .2) %>%
  layer_dense(units=10L, activation = 'softmax')

mlp_mixer.cifar10 = keras_model(inputs = inputs, outputs = predictions)
mlp_mixer.cifar10 %>% compile(loss='categorical_crossentropy',optimizer= 'adam',metrics=list('accuracy'))
In [ ]:
tf$keras$utils$plot_model(mlp_mixer.cifar10, to_file='MLPMixer_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'MLPMixer_CIFAR10.png')
Image

Bayesian NN architectures for Message Passing¶

Bayesian DenseNet-BC¶

In [ ]:
# KL divergence between the posterior q and the prior p, scaled down by the number of
# training examples (50,000) and an additional factor of 500
my_divergence_fn = function(q, p, ignore=NULL){
  tfp$distributions$kl_divergence(q, p) / 50000 / 500
}

conv_bottleneckEP = function(x, growth_rate){
  conv = x %>%
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1
    layer_activation_Bayesian('relu') %>%
    layer_conv_2d_Bayesian(filters = 4 * growth_rate, kernel_size = c(1L,1L), use_bias = FALSE, activation=NULL) %>% # conv1
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
    layer_activation_Bayesian('relu') %>%
    layer_conv_2d_Bayesian(filters = growth_rate, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE, activation=NULL)  # conv2
    # channel dimension 4 for channel last and 2 for channel first, keras default is channel last
  layer_lambda(list(x, conv), function(x_conv){
    c(x, conv) %<-% x_conv
    c(c(conv[[1]], x[[1]]) %>% layer_concatenate(axis=-1L),
      c(conv[[2]], x[[2]]) %>% layer_concatenate(axis=-1L) )
  })
}

# a dense block (dense_layerEP) followed by a transition (transitionEP) halves width and height;
# output channels = reduction * (input channels + growth rate * number of blocks)
dense_layerEP = function(x, block, growth_rate, num_blocks){
  layers = x
  for(i in 1:num_blocks) layers = layers %>% block(., growth_rate)
  print(sprintf('dense_layer: growth_rate = %d, num_blocks = %d', growth_rate, num_blocks))
  print(layers)
  layers
}

transitionEP = function(x, reduction = .5){
  x %>%
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
    layer_activation_Bayesian('relu') %>%
    layer_conv_2d_Bayesian(filters = floor(reduction*dim(x[[1]])[4]), kernel_size = c(1L,1L), use_bias = FALSE, activation=NULL) %>%
    layer_average_pooling_2d_Bayesian(pool_size = 2L)
}
In [ ]:
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))

g.r. = 12L
predictionsEP = inputsEP %>%
  layer_conv_2d_Bayesian(filters = 2*g.r., kernel_size = 3L, padding = 'same', use_bias = FALSE, activation = NULL) %>%
  dense_layerEP(., block = conv_bottleneckEP, growth_rate = g.r., num_blocks = 16) %>% transitionEP(.) %>% # layer 1
  dense_layerEP(., block = conv_bottleneckEP, growth_rate = g.r., num_blocks = 16) %>% transitionEP(.) %>%  # layer 2
  dense_layerEP(., block = conv_bottleneckEP, growth_rate = g.r., num_blocks = 16) %>% transitionEP(.) %>%  # layer 3
  #layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
  layer_activation_Bayesian('relu') %>%
  layer_average_pooling_2d_Bayesian(pool_size = 4L) %>%
  layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
  layer_dense_Bayesian(activation = NULL, units = 10L) %>%
  layer_lambda(f = function(mv){
    tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
  })

densenet.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
[[1]]
KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='lambda_92/concatenate/concat:0', description="created by layer 'lambda_92'")

[[2]]
KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='lambda_92/concatenate_1/concat:0', description="created by layer 'lambda_92'")

[1] "dense_layer: growth_rate = 12, num_blocks = 16"
[[1]]
KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='lambda_108/concatenate/concat:0', description="created by layer 'lambda_108'")

[[2]]
KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='lambda_108/concatenate_1/concat:0', description="created by layer 'lambda_108'")

[1] "dense_layer: growth_rate = 12, num_blocks = 16"
[[1]]
KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='lambda_124/concatenate/concat:0', description="created by layer 'lambda_124'")

[[2]]
KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='lambda_124/concatenate_1/concat:0', description="created by layer 'lambda_124'")
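
The printed channel counts agree with the comment above dense_layerEP: with growth rate 12, 16 blocks per dense layer, reduction .5 and a 2*12 = 24-channel stem, a quick sanity check in R (not part of the model code) reproduces 216, 300 and 342:

# channels after each dense layer = channels entering it (after the previous transition) + growth_rate * num_blocks
c(24 + 12*16,
  floor(.5*(24 + 12*16)) + 12*16,
  floor(.5*(floor(.5*(24 + 12*16)) + 12*16)) + 12*16)
# [1] 216 300 342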

In [ ]:
tf$keras$utils$plot_model(densenet.cifar10EP, to_file='Bayesian_DenseNetBC_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_DenseNetBC_CIFAR10.png')
Image
In [ ]:
densenet.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
lr = tf$keras$optimizers$schedules$CosineDecayRestarts(initial_learning_rate = 1, first_decay_steps = 50000/100*100)
opt = tf$keras$mixed_precision$LossScaleOptimizer(tf$keras$optimizers$experimental$SGD(learning_rate = 1, momentum=1-100/50000, nesterov=TRUE, global_clipnorm = 20))
#opt = tf$keras$optimizers$experimental$SGD(learning_rate = lr, momentum= 1-200/50000, nesterov=TRUE, global_clipnorm = 20) # if use `tf$keras$mixed_precision$set_global_policy('mixed_float16')`
densenet.cifar10EP$compile(loss='categorical_crossentropy',optimizer= opt, metrics=list('accuracy'))
mc = callback_model_checkpoint('Bayesian_DenseNet_CIFAR10-{epoch:03d}_{val_loss:.2f}_{val_accuracy:.4f}.ckpt', monitor = 'val_accuracy', save_best_only = FALSE, mode = 'auto', save_weights_only = TRUE, save_freq = 'epoch')
historyEP = densenet.cifar10EP %>% fit(
  cifar10.datagen$flow(list(cifar10$train$x, cifar10$train$x*0 + 1e-6), to_categorical(cifar10$train$y, num_classes=10L), batch_size = 100L,  shuffle=TRUE),
  epochs = 100 , #* (1 + 2 +4),
  validation_data = cifar10.datagen$flow(list(cifar10$test$x, cifar10$test$x*0 + 1e-6), to_categorical(cifar10$test$y, num_classes=10L), batch_size = 200L),
  callbacks = list(mc, callback_terminate_on_naan(), callback_csv_logger(filename = 'Bayesian_DenseNet_CIFAR10.txt')))
In [ ]:
print(data.frame(historyEP$metrics))
plot(historyEP)
           loss accuracy  val_loss val_accuracy
1   1.507544398  0.44768 1.2507046       0.5612
2   0.991522789  0.64886 0.9443011       0.6739
3   0.740942478  0.74084 0.7780184       0.7346
4   0.612260640  0.78602 0.6886450       0.7636
5   0.536301196  0.81394 0.6576619       0.7827
6   0.482742935  0.83092 0.5292346       0.8158
7   0.436886132  0.84730 0.5240674       0.8228
8   0.400461048  0.86060 0.5130056       0.8269
9   0.376282543  0.86894 0.4926983       0.8323
10  0.350261360  0.87790 0.4437283       0.8458
11  0.325476199  0.88632 0.4471522       0.8512
12  0.308074325  0.89242 0.3922808       0.8646
13  0.289072394  0.89988 0.4368181       0.8588
14  0.268653154  0.90690 0.4473700       0.8478
15  0.264465302  0.90738 0.3601842       0.8793
16  0.241475150  0.91528 0.4337616       0.8576
17  0.240748286  0.91622 0.3829522       0.8738
18  0.226144493  0.91936 0.3759227       0.8750
19  0.211735114  0.92444 0.3631086       0.8839
20  0.203297466  0.92838 0.3736733       0.8778
21  0.195500612  0.93112 0.3687617       0.8828
22  0.188374653  0.93408 0.3488044       0.8862
23  0.176909372  0.93760 0.3591793       0.8847
24  0.165337682  0.94158 0.3468764       0.8890
25  0.163461506  0.94152 0.4085495       0.8701
26  0.154732496  0.94490 0.3538958       0.8904
27  0.146540150  0.94788 0.3708380       0.8867
28  0.137216091  0.95126 0.3410468       0.8954
29  0.136211067  0.95266 0.3348385       0.8994
30  0.130257204  0.95428 0.3426736       0.8979
31  0.123988517  0.95580 0.3463230       0.8968
32  0.124184981  0.95624 0.3239390       0.8974
33  0.112449415  0.96014 0.3321945       0.8988
34  0.106419988  0.96166 0.3460067       0.8952
35  0.102975868  0.96316 0.3502261       0.8966
36  0.096118785  0.96630 0.3444697       0.9002
37  0.096012004  0.96578 0.3461250       0.9011
38  0.092155643  0.96708 0.3436781       0.9029
39  0.089766242  0.96746 0.3207698       0.9059
40  0.080002733  0.97040 0.3379705       0.9060
41  0.077242859  0.97284 0.3454384       0.9043
42  0.075790472  0.97320 0.3285445       0.9069
43  0.072758012  0.97384 0.3520682       0.9058
44  0.066192716  0.97618 0.3450360       0.9033
45  0.066513471  0.97668 0.3412762       0.9071
46  0.062482961  0.97754 0.3346150       0.9068
47  0.056999233  0.97952 0.3461328       0.9122
48  0.053404722  0.98166 0.3443820       0.9072
49  0.054178670  0.98082 0.3488195       0.9090
50  0.048612613  0.98312 0.3341542       0.9093
51  0.045235343  0.98398 0.3151713       0.9159
52  0.042834073  0.98466 0.3431740       0.9116
53  0.040008824  0.98602 0.3264747       0.9134
54  0.037579529  0.98694 0.3391601       0.9130
55  0.037912995  0.98660 0.3395627       0.9142
56  0.033098727  0.98802 0.3570186       0.9143
57  0.033382103  0.98832 0.3471235       0.9137
58  0.031156521  0.98944 0.3390150       0.9171
59  0.027277537  0.99058 0.3605517       0.9136
60  0.024240302  0.99174 0.3551328       0.9167
61  0.025279714  0.99062 0.3415895       0.9172
62  0.021866318  0.99310 0.3632956       0.9138
63  0.017950036  0.99420 0.3465253       0.9197
64  0.016946517  0.99478 0.3460971       0.9181
65  0.017682793  0.99402 0.3401641       0.9201
66  0.016459063  0.99428 0.3403973       0.9194
67  0.015079840  0.99514 0.3495630       0.9186
68  0.013104274  0.99582 0.3509070       0.9170
69  0.013004387  0.99574 0.3425159       0.9215
70  0.012999629  0.99578 0.3469994       0.9202
71  0.012204343  0.99622 0.3681124       0.9188
72  0.010561449  0.99686 0.3632663       0.9216
73  0.010643934  0.99692 0.3487050       0.9224
74  0.009710396  0.99686 0.3552284       0.9206
75  0.009160192  0.99718 0.3517626       0.9245
76  0.008669264  0.99732 0.3502497       0.9210
77  0.008502587  0.99752 0.3495335       0.9249
78  0.007657965  0.99748 0.3379188       0.9254
79  0.006842029  0.99808 0.3321439       0.9261
80  0.007568012  0.99776 0.3456600       0.9210
81  0.006723642  0.99810 0.3463190       0.9230
82  0.006581388  0.99806 0.3491128       0.9238
83  0.006428984  0.99808 0.3509857       0.9234
84  0.006708242  0.99788 0.3510391       0.9259
85  0.006418953  0.99806 0.3576565       0.9235
86  0.005776848  0.99836 0.3454177       0.9255
87  0.005551881  0.99860 0.3389569       0.9263
88  0.005839402  0.99826 0.3447225       0.9235
89  0.005101785  0.99872 0.3424715       0.9250
90  0.005926417  0.99826 0.3483006       0.9235
91  0.005643965  0.99836 0.3434032       0.9282
92  0.004899295  0.99872 0.3490171       0.9240
93  0.005357474  0.99846 0.3457305       0.9259
94  0.004880808  0.99860 0.3442807       0.9244
95  0.005065570  0.99850 0.3454005       0.9250
96  0.005079482  0.99864 0.3439325       0.9253
97  0.005028059  0.99868 0.3427972       0.9254
98  0.004800315  0.99864 0.3581291       0.9262
99  0.005389702  0.99838 0.3636788       0.9205
100 0.005091061  0.99848 0.3469904       0.9242
Image
In [ ]:
# cosine decay restarts schedule
lr = tf$keras$optimizers$schedules$CosineDecayRestarts(initial_learning_rate = 1, first_decay_steps = 50000/100*100)
plot( (1:(50000/100*100*(1+2+4)))/(50000/100), lr(1:(50000/100*100*(1+2+4))), type='l', xlab='epoch', ylab='lr schedule', main='cosine decay annealing with warm restart')
abline(v=100*cumsum(c(1,2,4)), h = 0, col='red')
Image

Bayesian ResNet¶

In [ ]:
conv_blockEP = function(xP, filters, strides = 1L){
  #print('x')
  #print(xP)
  out = xP %>%
    layer_conv_2d_Bayesian(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1, by default channel is last dimension
    layer_activation_Bayesian('relu') %>%
    layer_conv_2d_Bayesian(filters = filters, strides = 1L, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) #%>% # conv2
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5)  # bn2
  #print('out')
  #print(out)
  if(dim(xP[[1]])[4]== filters & strides==1L) shortcut = xP # channel dimension 4 for channel last and 2 for channel first
  else shortcut =  xP %>%
      layer_conv_2d_Bayesian(filters = filters, strides = rep(strides,2), kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) #%>%
      #layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
  #print('shortcut')
  #print(shortcut)
  layer_lambda(list(out, shortcut), function(out_shortcut){
    c(out, shortcut) %<-% out_shortcut
    list(out[[1]] + shortcut[[1]], out[[2]]+shortcut[[2]]) %>% layer_activation_Bayesian('relu')
  })
}

conv_bottleneckEP = function(xP, filters, strides = 1L){
  expansion = 4L
  #print('x')
  #print(xP)
  out = xP %>%
    layer_conv_2d_Bayesian(filters = filters, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
    #layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5)  %>% # bn1
    layer_activation_Bayesian('relu') %>%
    layer_conv_2d_Bayesian(filters = filters, strides = rep(strides,2), kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>% # conv2
    #layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5) %>% # bn2
    layer_activation_Bayesian('relu') %>%
    layer_conv_2d_Bayesian(filters = expansion * filters, kernel_size = c(1L,1L), padding = 'same', use_bias = FALSE, activation = NULL) #%>% # conv3
    #layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5)  # bn3
  #print("out")
  #print(out)
  if(dim(xP[[1]])[4]== expansion * filters & strides==1L) shortcut = xP # channel dimension 4 for channel last and 2 for channel first
  else shortcut =  xP %>%
    layer_conv_2d_Bayesian(filters = expansion * filters, strides = rep(strides,2), kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) #%>%
    #layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
  #print('shortcut')
  #print(shortcut)
  layer_lambda(list(out, shortcut), function(out_shortcut){
    c(out, shortcut) %<-% out_shortcut
    list(out[[1]] + shortcut[[1]], out[[2]]+shortcut[[2]]) %>% layer_activation_Bayesian('relu')
  })
}

conv_layerEP = function(xP, block, filters, num_blocks, strides){
  layers = xP
  print(sprintf('new conv_layer:filters = %d, num_blocks = %d, strides = %d', filters, num_blocks, strides))
  for(i in 1:num_blocks) {
    layers = layers %>% block(., filters, strides)
    strides = as.integer(1L) # only shrink image once at the first block, if strides > 1
  }
  layers
}
In [ ]:
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))

predictionsEP = inputsEP %>%
  layer_conv_2d_Bayesian(filters = 64, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
  #layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5) %>%
  layer_activation_Bayesian('relu') %>%
  conv_layerEP(., block = conv_blockEP, filters = 16L, num_blocks = 18, strides = 1L) %>% # layer 1
  conv_layerEP(., block = conv_blockEP, filters = 32L, num_blocks = 18, strides = 2L) %>%  # layer 2
  conv_layerEP(., block = conv_blockEP, filters = 64L, num_blocks = 18, strides = 2L) %>%  # layer 3
  layer_average_pooling_2d_Bayesian(pool_size = 8L) %>%
  layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
  layer_dense_Bayesian(units = 10L, activation = NULL) %>%
  layer_lambda(f = function(mv){
    tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
  })

resnet.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
resnet.cifar10EP$compile(loss='categorical_crossentropy',optimizer= 'adam',metrics=list('accuracy'))
[1] "new conv_layer:filters = 16, num_blocks = 18, strides = 1"
[1] "new conv_layer:filters = 32, num_blocks = 18, strides = 2"
[1] "new conv_layer:filters = 64, num_blocks = 18, strides = 2"
In [ ]:
tf$keras$utils$plot_model(resnet.cifar10EP, to_file='Bayesian_ResNet_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_ResNet_CIFAR10.png')
Image

Bayesian EfficientNet B0¶

In [ ]:
c_k_i = initializer_variance_scaling(scale=2.0, mode='fan_out')
convBnAct = function(xP, filters, kernel_size=3L, strides=1L, padding='valid', bn=TRUE, act='swish', use_bias=FALSE, groups=1L){
  #if(stride==2L & kernel_size == 3L) x = x %>%
  #  layer_zero_padding_2d(padding=list(list(0L, 1L),list(0L, 1L))) %>%
  #  layer_conv_2d(filters=filters, kernel_size=kernel_size, stride=stride, padding='valid', use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
  #else
  xP = layer_conv_2d_Bayesian(xP, filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
  #if(bn) x = layer_batch_normalization(x, momentum = .9, epsilon = 1e-5)
  if(!is.null(act)) xP %>% layer_activation_Bayesian(activation = act) else xP
}

squeezeExcitation = function(x, filters){
  y = x %>%
    layer_average_pooling_2d_Bayesian(pool_size = dim(x[[1]])[2:3]) %>%  ## ??? layer_global_average_pooling_2d(keepdims=TRUE) %>% # is this adaptive average pool 2d with output dim 1x1?
    layer_conv_2d_Bayesian(filters=filters, kernel_size=1L, activation='swish', kernel_initializer = c_k_i) %>%
    layer_conv_2d_Bayesian(filters=dim(x[[1]])[4], kernel_size=1L, activation='sigmoid', kernel_initializer = c_k_i)
  #list(x[[1]]*y[[1]], x[[2]]*y[[1]]^2 + x[[1]]^2 * y[[2]])#x * y
  layer_lambda(list(x,y), function(xy){
    c(x,y) %<-% xy
    list(x[[1]]*y[[1]], x[[2]]*y[[1]]^2 + x[[1]]^2 * y[[2]])
  })#x * y
}

stochasticDepth = function(x, survival_prob = .8){
  tf$where(tf$random$uniform(list()) < survival_prob, x / survival_prob, 0)
}

#Mobile-net conv Block with expansion factor N
MBConvN = function(xP, filters, kernel_size=3L, strides=1L, expansion_factor = 6L, reduction = 4L, survival_prob = .8){
  filters_in = dim(xP[[1]])[4]
  residual = xP
  intermediate_channels = filters_in * expansion_factor
  if(expansion_factor!=1L) xP = convBnAct(xP, filters = intermediate_channels, kernel_size=1L)
  xP = xP %>%
    convBnAct(filters = intermediate_channels, kernel_size = kernel_size, strides = strides, padding = 'same', groups = intermediate_channels) %>%
    squeezeExcitation(filters = floor(filters_in/reduction)) %>%
    convBnAct(filters = filters, kernel_size = 1L, act=NULL, use_bias=FALSE)
  if(strides==1L & filters_in==filters)
    layer_lambda(list(xP, residual), function(xP_residual){
      c(xP, residual) %<-% xP_residual
      list(residual[[1]] + xP[[1]], residual[[2]] + xP[[2]]) #layer_dropout(xP, rate = .2, noise_shape=c(py_none(), 1L, 1L, 1L))
    })
  else xP
}

featureExtractor = function(x, width_mult, depth_mult, last_channel){
  kernels = c(3L, 3L, 5L, 3L, 5L, 5L, 3L)
  expansions = c(1L, 6L, 6L, 6L, 6L, 6L, 6L)
  scaled_num_channels = as.integer(4 * ceiling( c(16, 24, 40, 80, 112, 192, 320) * width_mult/4 ))
  scaled_num_layers = as.integer(c(1, 2, 2, 3, 3, 4, 1) * depth_mult)
  strides = c(1L, 2L, 2L, 2L, 1L, 2L, 1L)

  x = x %>% convBnAct(filters = 4 * ceiling(32 * width_mult/4), kernel_size = 3L, strides = 2L, padding = 'same')
  for(i in 1:length(scaled_num_layers)){
    for(j in 1:scaled_num_layers[i]){
      x = x %>% MBConvN(filters = scaled_num_channels[i], kernel_size = kernels[i], strides = if(j==1) strides[i] else 1L, expansion_factor = expansions[i])
    }
  }
  x %>% convBnAct(filters = last_channel, kernel_size = 1L, strides = 1L, padding='valid')
}
In [ ]:
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))

predictionsEP = inputsEP %>%
  featureExtractor(width_mult = 1, depth_mult = 1, last_channel = 1280) %>%
  layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
  layer_dense_Bayesian(units = 10L, activation = NULL) %>%
  layer_lambda(f = function(mv){
    tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
  })
efficientNetB0.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
In [ ]:
tf$keras$utils$plot_model(efficientNetB0.cifar10EP, to_file='Bayesian_EfficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_EfficientNetB0_CIFAR10.png')
Image

Bayesian WideResNet-28-10¶

In [ ]:
conv_blockEP = function(xP, base, k, strides = c(1L, 1L), shortcut = FALSE){
  h = xP %>%
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>%
    layer_activation_Bayesian('relu')
  out = h %>%
    layer_conv_2d_Bayesian(filters = base*k, kernel_size = c(3L,3L), padding = 'same', strides=strides, kernel_initializer='he_normal', use_bias = FALSE) %>%
    #layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>%
    layer_activation_Bayesian('relu') %>%
    #layer_dropout(rate = dropout) %>%  # conv1
    layer_conv_2d_Bayesian(filters = base*k, kernel_size = c(3L,3L), padding = 'same', kernel_initializer='he_normal', use_bias = FALSE)

  if(shortcut) skip = h %>%
    layer_conv_2d_Bayesian(filters = base*k, kernel_size = c(1L, 1L), strides = strides, kernel_initializer='he_normal', use_bias = FALSE)
  else skip = xP
  layer_lambda(list(out, skip), function(out_skip){
    c(out, skip) %<-% out_skip
    list(out[[1]] + skip[[1]], out[[2]]+skip[[2]]) #layer_add(list(out, skip))
  })
}
In [ ]:
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))

k = 10L

predictionsEP = inputsEP %>%
  layer_conv_2d_Bayesian(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
  conv_blockEP(., 16, k, shortcut=TRUE) %>% conv_blockEP(., 16, k) %>% conv_blockEP(., 16, k) %>%
  conv_blockEP(., 32, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_blockEP(., 32, k) %>% conv_blockEP(., 32, k) %>%
  conv_blockEP(., 64, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_blockEP(., 64, k) %>% conv_blockEP(., 64, k) %>%
  #layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% layer_activation_relu() %>% layer_average_pooling_2d(pool_size = 8L) %>% layer_flatten() %>%
  layer_activation_Bayesian('relu') %>%
  layer_average_pooling_2d_Bayesian(pool_size = 4L) %>%
  layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
  layer_dense_Bayesian(units = 10L, activation = NULL) %>%
  layer_lambda(f = function(mv){
    tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
  })

wrn_28_10.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
In [ ]:
tf$keras$utils$plot_model(wrn_28_10.cifar10EP, to_file='Bayesian_WRN-28-10_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_WRN-28-10_CIFAR10.png')
Image

MLP-Mixer¶

Bayesian NN constructed from tfp.layers¶

We bring tfp.layers to competitive performance by

  • sampling multiple non-Bayesian neural networks (AverageCase),
  • wrapping the call function to implement natural gradient, and
  • considering adversarial training (WorstCase); a usage sketch follows the class definitions below.
In [ ]:
AverageCase(tf$keras$Model) %py_class% {

  test_step = function(self, data){
    x = data[[1]]
    y = data[[2]]
    if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL

    #y_pred = self$call(x)
    y_pred = tf$reduce_mean(tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L), axis=0L)
    self$compiled_loss(y, y_pred, regularization_losses=self$losses)
    self$compiled_metrics$update_state(y, y_pred, sample_weight=sample_weight)
    mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
  }

  train_step = function(self, data) {
    x = data[[1]]
    y = data[[2]]
    if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL

    with(tf$GradientTape() %as% tape, {
      #y_pred = tf$reduce_logsumexp(tf$stack(lapply(1:4,function(n)self$call(x)), axis=0L), axis=0L)
      y_pred = tf$reduce_mean(tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L), axis=0L)
      loss = self$compiled_loss(y, y_pred, regularization_losses=self$losses)
    })
    gradients = tape$gradient(loss, self$trainable_variables)
    self$optimizer$apply_gradients(mapply(list, gradients, self$trainable_variables, SIMPLIFY =FALSE))
    self$compiled_metrics$update_state(y, y_pred, sample_weight=sample_weight)
    mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
  }
}

WorstCase(tf$keras$Model) %py_class% {

  test_step = function(self, data){
    x = data[[1]]
    y = data[[2]]
    if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL

    #y_pred = self$call(x)
    y_pred = tf$reduce_mean(tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L), axis=0L)
    self$compiled_loss(y, y_pred, regularization_losses=self$losses)
    self$compiled_metrics$update_state(y, y_pred, sample_weight=sample_weight)
    mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
  }

  train_step = function(self, data) {
    x = data[[1]]
    y = data[[2]]
    if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL

    with(tf$GradientTape() %as% tape, {
      y_pred = tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L)
      self$compiled_loss(y, tf$reduce_mean(y_pred, axis=0L), regularization_losses=self$losses) # loss to display

      loss_ = tf$losses$categorical_crossentropy(tf$einsum('b,ij->bij',tf$ones_like(y_pred[,1,1]), y), y_pred)
      loss_ = tf$reduce_max(tf$reduce_mean(loss_, axis=1L))
      loss = loss_ + self$losses # loss to optimize
    })
    gradients = tape$gradient(loss, self$trainable_variables)
    self$optimizer$apply_gradients(mapply(list, gradients, self$trainable_variables, SIMPLIFY =FALSE))
    self$compiled_metrics$update_state(y, tf$reduce_mean(y_pred, axis=0L), sample_weight=sample_weight)
    mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))

  }
}
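
A minimal usage sketch, not run here: since both classes subclass tf.keras.Model, they can (under the usual Keras behaviour for models with a custom train_step) be instantiated functional-style with inputs and outputs, so any of the tfp-based graphs below can be trained with the averaged or worst-case objective simply by swapping the constructor. The small layer stack is purely illustrative; it assumes cifar10 has been loaded as in the Datasets section below.

# hedged sketch: build a functional graph, then wrap it in the custom Model class
inputs_ac = layer_input(shape = dim(cifar10$train$x)[-1])
outputs_ac = inputs_ac %>%
  layer_flatten() %>%
  layer_dense_reparameterization(units = 128L, activation = 'relu') %>%
  layer_dense_reparameterization(units = 10L, activation = 'softmax')
model_ac = AverageCase(inputs = inputs_ac, outputs = outputs_ac)  # or WorstCase(...)
model_ac$compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = list('accuracy'))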
In [ ]:
my.posterior = tfp$layers$default_mean_field_normal_fn(untransformed_scale_constraint = function(w){
  tf$clip_by_value(w, tfp$math$softplus_inverse(tf$experimental$numpy$finfo('float32')$resolution*10.0), 1.0)
})

layer_dense_reparameterization_NG = Layer(
  classname = 'ReparameterizationNGDense',
  inherit = tf$keras$layers$Layer,
  initialize = function(...) {
    super$initialize()
    self$denseReparameterization = tfp$layers$DenseReparameterization(...)
  },

  build = function(input_shape) {
    self$denseReparameterization$build(input_shape)

    self$`__forward__` = tf$custom_gradient(function(inputs){

      with(tf$GradientTape(persistent=TRUE) %as% tape, {
        tape$watch(inputs)
        outputs = self$denseReparameterization$call(inputs)
      })

      grad_fn = function(upstream, variables){
        c(dx, dv) %<-% tape$gradient(outputs, list(inputs, variables), upstream)
        dv2 = lapply(1:length(variables), function(n){
          switch(
            gsub('.+/(bias_posterior_loc|kernel_posterior_loc|kernel_posterior_untransformed_scale):[0-9]+','\\1', variables[[n]]$name),
            bias_posterior_loc = dv[[n]],
            kernel_posterior_loc = dv[[n]] * tf$math$softplus(self$denseReparameterization$variables[[2]])^2,
            kernel_posterior_untransformed_scale = dv[[n]] * tf$math$softplus(self$denseReparameterization$variables[[2]])^4 / tf$math$sigmoid(self$denseReparameterization$variables[[2]])^2  * 2)
        })
        list(dx, dv2)
      }
      list(outputs, grad_fn)
    })
    # self$`__forward__` = function(inputs){
    #   self$denseReparameterization$call(inputs)
    # }

    self$built=TRUE
  },

  call = function(inputs) {
    self$`__forward__`(inputs)
    # self$denseReparameterization$call(inputs)
  }
)

layer_conv2d_reparameterization_NG = Layer(
  classname = 'ReparameterizationNGConv2D',
  inherit = tf$keras$layers$Layer,
  initialize = function(...) {
    super$initialize()
    self$conv2dReparameterization = tfp$layers$Convolution2DReparameterization(...)
  },

  build = function(input_shape) {
    self$conv2dReparameterization$build(input_shape)

    self$`__forward__` = tf$custom_gradient(function(inputs){

      with(tf$GradientTape(persistent=TRUE) %as% tape, {
        tape$watch(inputs)
        outputs = self$conv2dReparameterization$call(inputs)
      })

      grad_fn = function(upstream, variables){
        c(dx, dv) %<-% tape$gradient(outputs, list(inputs, variables), upstream)
        dv2 = lapply(1:length(variables), function(n){
          switch(
            gsub('.+/(bias_posterior_loc|kernel_posterior_loc|kernel_posterior_untransformed_scale):[0-9]+','\\1', variables[[n]]$name),
            bias_posterior_loc = dv[[n]],
            kernel_posterior_loc = dv[[n]] * tf$math$softplus(self$conv2dReparameterization$variables[[2]])^2,
            kernel_posterior_untransformed_scale = dv[[n]] * tf$math$softplus(self$conv2dReparameterization$variables[[2]])^4 / tf$math$sigmoid(self$conv2dReparameterization$variables[[2]])^2  * 2)
        })
        list(dx, dv2)
      }
      list(outputs, grad_fn)
    })

    self$built=TRUE
  },

  call = function(inputs) {
    self$`__forward__`(inputs)
  }
)
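
Both wrappers keep the forward pass of the wrapped tfp layer and only rescale the incoming gradients of the posterior parameters inside grad_fn, so they are meant as drop-in replacements; constructor arguments are forwarded unchanged through the ... argument. A minimal sketch (assumption: an MNIST-style 784-dimensional input, as used later in this notebook; units and activations are illustrative):

inputs_ng = layer_input(shape = 784L)
predictions_ng = inputs_ng %>%
  layer_dense_reparameterization_NG(units = 40L, activation = tf$nn$sigmoid, kernel_posterior_fn = my.posterior) %>%
  layer_dense_reparameterization_NG(units = 10L, activation = 'softmax', kernel_posterior_fn = my.posterior)
mlp_ng = keras_model(inputs = inputs_ng, outputs = predictions_ng)
mlp_ng$compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = list('accuracy'))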

DenseNet-BC from tfp.layers¶

In [ ]:
divergence_fn = function(q, p, ignore){
  0.0 #tfp$distributions$kl_divergence(q, p) /50000
}

#my.posterior = tfp$layers$default_mean_field_normal_fn(untransformed_scale_initializer=tf$random_normal_initializer(mean=-9.0, stddev=0.05))
my.posterior = tfp$layers$default_mean_field_normal_fn()

# output width and height are the same, output channels = input channels + growth rate
conv_bottleneck = function(x, growth_rate){
  conv = x %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1
    layer_activation_relu() %>%
    layer_conv_2d_reparameterization(filters = 4 * growth_rate, kernel_size = c(1L,1L), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv1
    layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
    layer_activation_relu() %>%
    layer_conv_2d_reparameterization(filters = growth_rate, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)  # conv2
  list(conv, x) %>% layer_concatenate(axis=-1L) # channel dimension 4 for channel last and 2 for channel first, keras default is channel last
}

# output width and height are half of the input, output channels = reduction * (input channels + growth rate * number of blocks)
dense_layer = function(x, block, growth_rate, num_blocks){
  layers = x
  for(i in 1:num_blocks) layers = layers %>% block(., growth_rate)
  print(sprintf('dense_layer: growth_rate = %d, num_blocks = %d', growth_rate, num_blocks))
  print(layers)
  layers
}

transition = function(x, reduction = .5){
  x %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
    layer_activation_relu() %>%
    layer_conv_2d_reparameterization(filters = floor(reduction*dim(x)[4]), kernel_size = c(1L,1L), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
    layer_average_pooling_2d(pool_size = 2L)
}
In [ ]:
g.r. = 12L
predictions = inputs %>%
  #layer_lambda(function(img) tf$image$resize(img, c(224L, 224L))) %>%
  layer_conv_2d_reparameterization(filters = 2*g.r., kernel_size = 3L, padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 1
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>%  # layer 2
  dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>%  # layer 3
  layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
  layer_activation_relu() %>%
  layer_average_pooling_2d(pool_size = 4L) %>%
  layer_flatten() %>%
  layer_dense(activation = 'softmax', units = 10L)

densenet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='concatenate_207/concat:0', description="created by layer 'concatenate_207'")
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='concatenate_223/concat:0', description="created by layer 'concatenate_223'")
[1] "dense_layer: growth_rate = 12, num_blocks = 16"
KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='concatenate_239/concat:0', description="created by layer 'concatenate_239'")
In [ ]:
tf$keras$utils$plot_model(densenet.cifar10, to_file='tfp_DenseNetBC_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_DenseNetBC_CIFAR10.png')
Image

ResNet from tfp.layers¶

In [ ]:
divergence_fn = function(q, p, ignore){
  0.0 #tfp$distributions$kl_divergence(q, p) /50000
}

#my.posterior = tfp$layers$default_mean_field_normal_fn(untransformed_scale_initializer=tf$random_normal_initializer(mean=-9.0, stddev=0.05))
my.posterior = tfp$layers$default_mean_field_normal_fn()


myconv2d = function(x, filters, strides){
  if(strides[1]>1)
    x %>%
    layer_zero_padding_2d(padding=list(list(1L, 0L),list(1L, 0L))) %>%
    layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'valid', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
  else
    x %>%
    layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
}

conv_block = function(x, filters, strides = 1L){
  #print('x')
  #print(x)
  out = x %>%
    layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
    #myconv2d(filters = filters, strides = strides) %>%  # conv1
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1, by default channel is last dimension
    layer_activation_relu() %>%
    layer_conv_2d_reparameterization(filters = filters, strides = 1L, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv2
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  # bn2
  #print('out')
  #print(out)
  if(dim(x)[4]==filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
  else shortcut =  x %>%
      layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = 1L, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c')), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
      layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
  #print('shortcut')
  #print(shortcut)
  (out + shortcut) %>% layer_activation_relu()
}

conv_bottleneck = function(x, filters, strides = 1L){
  expansion = 4L
  #print('x')
  #print(x)
  out = x %>%
    layer_conv_2d_reparameterization(filters = filters, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% # bn1
    layer_activation_relu() %>%
    layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv2
    layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
    layer_activation_relu() %>%
    layer_conv_2d_reparameterization(filters = expansion * filters, kernel_size = c(1L,1L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv3
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  # bn3
  #print('out')
  #print(out)
  if(dim(x)[4]== expansion * filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
  else shortcut =  x %>%
      layer_conv_2d_reparameterization(filters = expansion * filters, strides = strides, kernel_size = 1L, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c')), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
      layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
  #print('shortcut')
  #print(shortcut)
  (out + shortcut) %>% layer_activation_relu()
}

conv_layer = function(x, block, filters, num_blocks, strides){
  layers = x
  print(sprintf('new conv_layer:filters = %d, num_blocks = %d, strides = %d', filters, num_blocks, strides))
  for(i in 1:num_blocks) {
    layers = layers %>% block(., filters, strides)
    strides = 1 # only shrink image once at the first block, if strides > 1
  }
  layers
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

predictions = inputs %>%
  layer_conv_2d_reparameterization(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
  layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
  layer_activation_relu %>%
  conv_layer(., block = conv_block, filters = 16, num_blocks = 18, strides = 1) %>%  # layer 1
  conv_layer(., block = conv_block, filters = 32, num_blocks = 18, strides = 2) %>%  # layer 2
  conv_layer(., block = conv_block, filters = 64, num_blocks = 18, strides = 2) %>%  # layer 3
  layer_average_pooling_2d(pool_size = 8L) %>%
  layer_flatten() %>%
  layer_dense_reparameterization(activation = 'softmax', units = 10L, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)

resnet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
[1] "new conv_layer:filters = 16, num_blocks = 18, strides = 1"
[1] "new conv_layer:filters = 32, num_blocks = 18, strides = 2"
[1] "new conv_layer:filters = 64, num_blocks = 18, strides = 2"
In [ ]:
tf$keras$utils$plot_model(resnet.cifar10, to_file='tfp_ResNet_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_ResNet_CIFAR10.png')
Image

EfficientNet B0 from tfp.layers¶

In [ ]:
divergence_fn = function(q, p, ignore){
  0.0 #tfp$distributions$kl_divergence(q, p) /50000
}

my.posterior = tfp$layers$default_mean_field_normal_fn(loc_initializer = initializer_variance_scaling(scale=2.0, mode='fan_out'))

convBnAct = function(x, filters, kernel_size=3L, stride=1L, padding='valid', bn=TRUE, act='swish', use_bias=FALSE, groups=1L){
  #if(stride==2L & kernel_size == 3L) x = x %>%
  #  layer_zero_padding_2d(padding=list(list(0L, 1L),list(0L, 1L))) %>%
  #  layer_conv_2d(filters=filters, kernel_size=kernel_size, stride=stride, padding='valid', use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
  #else
  x = layer_conv_2d_reparameterization(x, filters=filters, kernel_size=kernel_size, stride=stride, padding=padding, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
  if(bn) x = layer_batch_normalization(x, momentum = .9, epsilon = 1e-5)
  if(!is.null(act)) x %>% layer_activation(activation = act) else x
}

squeezeExcitation = function(x, filters){
  y = x %>%
    layer_global_average_pooling_2d(keepdims=TRUE) %>% # is this adaptive average pool 2d with output dim 1x1?
    layer_conv_2d_reparameterization(filters=filters, kernel_size=1L, activation='swish', kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
    layer_conv_2d_reparameterization(filters=dim(x)[4], kernel_size=1L, activation='sigmoid', kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
  x * y
}

stochasticDepth = function(x, survival_prob = .8){
  tf$where(tf$random$uniform(list()) < survival_prob, x / survival_prob, 0)
}

#Mobile-net conv Block with expansion factor N
MBConvN = function(x, filters, kernel_size=3L, stride=1L, expansion_factor = 6L, reduction = 4L, survival_prob = .8){
  filters_in = dim(x)[4]
  residual = x
  intermediate_channels = filters_in * expansion_factor
  if(expansion_factor!=1L) x = convBnAct(x, filters = intermediate_channels, kernel_size=1L)

  x = x %>%
    convBnAct(filters = intermediate_channels, kernel_size = kernel_size, stride = stride, padding = 'same', groups = intermediate_channels) %>%
    squeezeExcitation(filters = floor(filters_in/reduction)) %>%
    convBnAct(filters = filters, kernel_size = 1L, act=NULL, use_bias=FALSE)

  if(stride==1L & filters_in==filters){
    residual + layer_dropout(x, rate = .2, noise_shape=c(py_none(), 1L, 1L, 1L))
  } else x
}

featureExtractor = function(x, width_mult, depth_mult, last_channel){
  kernels = c(3, 3, 5, 3, 5, 5, 3)
  expansions = c(1, 6, 6, 6, 6, 6, 6)
  scaled_num_channels = 4 * ceiling( c(16, 24, 40, 80, 112, 192, 320) * width_mult/4 )
  scaled_num_layers = c(1, 2, 2, 3, 3, 4, 1) * depth_mult
  strides = c(1, 2, 2, 2, 1, 2, 1)

  x = x %>% convBnAct(filters = 4 * ceiling(32 * width_mult/4), kernel_size = 3L, stride = 2L, padding = 'same')
  for(i in 1:length(scaled_num_layers)){
    for(j in 1:scaled_num_layers[i]){
      x = x %>% MBConvN(filters = scaled_num_channels[i], kernel_size = kernels[i], stride = if(j==1) strides[i] else 1L, expansion_factor = expansions[i])
    }
  }
  x %>% convBnAct(filters = last_channel, kernel_size = 1L, stride = 1L, padding='valid')
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

predictions = inputs %>%
  featureExtractor(width_mult = 1, depth_mult = 1, last_channel = 1280) %>%
  layer_global_average_pooling_2d() %>%
  layer_dropout(.2) %>%
  layer_dense_reparameterization(units = 10L, activation='softmax', kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
efficientNetB0.cifar10 = keras_model(inputs = inputs, outputs = predictions)
#efficientNetB0.cifar10
In [ ]:
tf$keras$utils$plot_model(efficientNetB0.cifar10, to_file='tfp_EfficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_EfficientNetB0_CIFAR10.png')
Image

Wide ResNet 28-10 from tfp.layers¶

In [ ]:
divergence_fn = function(q, p, ignore){
  0.0 #tfp$distributions$kl_divergence(q, p) /50000
}

my.posterior = tfp$layers$default_mean_field_normal_fn(loc_initializer='he_normal')

conv_block = function(x, base, k, dropout = 0, strides = c(1L, 1L), shortcut = FALSE){
  h = x %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>%
    layer_activation_relu()
  out = h %>%
    layer_conv_2d_reparameterization(base*k, c(3L,3L), padding = 'same', strides=strides, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
    layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>%
    layer_activation_relu() %>%
    layer_dropout(rate = dropout) %>%  # conv1
    layer_conv_2d_reparameterization(base*k, c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)

  if(shortcut) skip = h %>%
    layer_conv_2d_reparameterization(base*k, c(1L, 1L), strides = strides, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
  else skip = x
  layer_add(list(out, skip))
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

k = 10

predictions = inputs %>%
  layer_conv_2d_reparameterization(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
  conv_block(., 16, k, shortcut=TRUE) %>% conv_block(., 16, k, dropout=.4) %>% conv_block(., 16, k, dropout=.3) %>%
  conv_block(., 32, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 32, k, dropout=.4) %>% conv_block(., 32, k, dropout=.4) %>%
  conv_block(., 64, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 64, k, dropout=.4) %>% conv_block(., 64, k, dropout=.4) %>%
  layer_batch_normalization(momentum = .9, epsilon = 1e-5)  %>% layer_activation_relu() %>% layer_average_pooling_2d(pool_size = 8L) %>% layer_flatten() %>%
  layer_dense_reparameterization(activation = 'softmax', units = 10L, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)

wrn_28_10.cifar10 = keras_model(inputs = inputs, outputs = predictions)
In [ ]:
tf$keras$utils$plot_model(wrn_28_10.cifar10, to_file='tfp_WRN-28-10_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_WRN-28-10_CIFAR10.png')
Image

MLP-Mixer¶

In [ ]:
layer_patch = Layer(
  classname = 'Patch',
  inherit = tf$keras$layers$Layer,
  initialize = function(patch_size, projection_dim, ...) {
    super$initialize()
    self$patch_size = patch_size
    self$projection_dim = projection_dim
  },

  build = function(input_shape) {
    # assert: image is square, patch is square
    self$num_patches =  as.integer(floor(input_shape[[2]]/self$patch_size)^2)
    self$projection = layer_dense_reparameterization(units = self$projection_dim)
    self$position_embedding = tf$keras$layers$Embedding(input_dim = self$num_patches, output_dim = self$projection_dim)
    self$built=TRUE
    return()
  },

  call = function(images) {
    patches = tf$image$extract_patches(
      images = images,
      sizes = c(1L, self$patch_size, self$patch_size, 1L),
      strides = c(1L, self$patch_size, self$patch_size, 1L),
      rates = c(1L, 1L, 1L, 1L),
      padding = 'VALID'
    )
    patches = tf$reshape(
      patches,
      shape=c(tf$shape(images)[1L], -1L, tf$reduce_prod(tf$shape(patches)[4L]))
    )
    positions = tf$range(start=0, limit = self$num_patches, delta = 1)
    self$projection(patches) + self$position_embedding(positions)
  }
)

mlp_mixer_block = function(x){
  num_patches = dim(x)[2]
  projection_dim = dim(x)[3]
  mlp1 = x %>%
    layer_layer_normalization(epsilon = 1e-6) %>%
    tf$linalg$matrix_transpose(.) %>%
    layer_dense_reparameterization(units = num_patches, activation='gelu') %>%
    layer_dense_reparameterization(units = num_patches) %>%
    layer_dropout(rate = .2) %>%
    tf$linalg$matrix_transpose(.)
  x = layer_add(list(x, mlp1))
  mlp2 = x %>%
    layer_layer_normalization(epsilon = 1e-6) %>%
    layer_dense_reparameterization(units = projection_dim, activation='gelu') %>%
    layer_dense_reparameterization(units = projection_dim) #%>%
    #layer_dropout(rate = .2)  # dropout disabled here along with the commented pipe above
  layer_add(list(x, mlp2))
}
In [ ]:
inputs = layer_input(shape = dim(cifar10$train$x)[-1])

predictions = inputs %>%
  layer_resizing(64L, 64L) %>%
  layer_patch(patch_size = 8L, projection_dim = 256L) %>%
  mlp_mixer_block %>%
  mlp_mixer_block %>%
  mlp_mixer_block %>%
  mlp_mixer_block %>%
  layer_global_average_pooling_1d %>%
  layer_dropout(rate = .2) %>%
  layer_dense_reparameterization(units=10L, activation = 'softmax')

mlp_mixer.cifar10 = keras_model(inputs = inputs, outputs = predictions)
In [ ]:
tf$keras$utils$plot_model(mlp_mixer.cifar10, to_file='tfp_MLPMixer_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_MLPMixer_CIFAR10.png')
Image

Datasets involved¶

CIFAR 10¶

In [ ]:
cifar10 = dataset_cifar10()

cifar10.datagen = image_data_generator(
  featurewise_center=TRUE,
  featurewise_std_normalization=TRUE,
  ##rotation_range=15,
  zoom_range=0.1,
  width_shift_range=0.1,
  height_shift_range=0.1,
  horizontal_flip=TRUE,
)
cifar10.datagen$fit(cifar10$train$x)

# stream the train/test sets with `tf.keras.preprocessing.image.ImageDataGenerator.flow` in the following way
#cifar10.datagen$flow(cifar10$train$x, to_categorical(cifar10$train$y, num_classes=10L), batch_size = 64L,  shuffle=TRUE)
#cifar10.datagen$flow(cifar10$test$x, to_categorical(cifar10$test$y, num_classes=10L), batch_size = 1000L)

CIFAR 100¶

In [ ]:
cifar100 = dataset_cifar100()

cifar100.datagen = image_data_generator(
  featurewise_center=TRUE,
  featurewise_std_normalization=TRUE,
  ##rotation_range=15,
  zoom_range=0.1,
  width_shift_range=0.1,
  height_shift_range=0.1,
  horizontal_flip=TRUE,
)
cifar100.datagen$fit(cifar100$train$x)

# stream the train/test sets with `tf.keras.preprocessing.image.ImageDataGenerator.flow` in the following way
#cifar100.datagen$flow(cifar100$train$x, to_categorical(cifar100$train$y, num_classes=100L), batch_size = 64L,  shuffle=TRUE)
#cifar100.datagen$flow(cifar100$test$x, to_categorical(cifar100$test$y, num_classes=100L), batch_size = 1000L)

ImageNet resized/64x64¶

In [ ]:
httr::GET('https://image-net.org/login.php', authenticate("<username>", "<password>"))
tfds = reticulate::import('tensorflow_datasets')
In [ ]:
c(imagenet64x64, info) %<-% tfds$load('imagenet_resized/64x64', shuffle_files=FALSE, with_info=TRUE)
In [ ]:
info
tfds.core.DatasetInfo(
    name='imagenet_resized',
    full_name='imagenet_resized/64x64/0.1.0',
    description="""
    This dataset consists of the ImageNet dataset resized to fixed size. The images
    here are the ones provided by Chrabaszcz et. al. using the box resize method.
    
    For [downsampled ImageNet](http://image-net.org/download.php) for unsupervised
    learning see `downsampled_imagenet`.
    
    WARNING: The integer labels used are defined by the authors and do not match
    those from the other ImageNet datasets provided by Tensorflow datasets. See the
    original
    [label list](https://github.com/PatrykChrabaszcz/Imagenet32_Scripts/blob/master/map_clsloc.txt),
    and the
    [labels used by this dataset](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/imagenet_resized_labels.txt).
    Additionally, the original authors 1 index there labels which we convert to 0
    indexed by subtracting one.
    """,
    config_description="""
    Images resized to 64x64
    """,
    homepage='https://patrykchrabaszcz.github.io/Imagenet32/',
    data_path='/root/tensorflow_datasets/imagenet_resized/64x64/0.1.0',
    file_format=tfrecord,
    download_size=13.13 GiB,
    dataset_size=10.29 GiB,
    features=FeaturesDict({
        'image': Image(shape=(64, 64, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=1000),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'train': <SplitInfo num_examples=1281167, num_shards=128>,
        'validation': <SplitInfo num_examples=50000, num_shards=4>,
    },
    citation="""@article{chrabaszcz2017downsampled,
      title={A downsampled variant of imagenet as an alternative to the cifar datasets},
      author={Chrabaszcz, Patryk and Loshchilov, Ilya and Hutter, Frank},
      journal={arXiv preprint arXiv:1707.08819},
      year={2017}
    }""",
)
In [ ]:
fig = tfds$show_examples(imagenet64x64$train, info)
fig$savefig('imagenet64x64.png')
IRdisplay::display_png(file = 'imagenet64x64.png')
Image
In [ ]:
# use keras preprocessing layer for normalization and basic augmentation
#inputs = layer_input(imagenet64x64$train$element_spec$image$shape)
normalization = layer_normalization()
normalization$adapt(imagenet64x64$validation$map(function(xy){
  c(x,y) %<-% xy
  x
})$batch(1000L))
normalization$mean
normalization$variance^.5
tf.Tensor([[[[120.70578  114.888855 102.38361 ]]]], shape=(1, 1, 1, 3), dtype=float32)
tf.Tensor([[[[67.8931  66.11974 69.63358]]]], shape=(1, 1, 1, 3), dtype=float32)

Tiny ImageNet¶

https://github.com/ksachdeva/tiny-imagenet-tfds/blob/master/tiny_imagenet/_imagenet.py

  • ~/.local/share/r-miniconda/bin/conda init
  • . ~/.bashrc
  • conda info --envs
  • conda activate r-reticulate
  • pip install git+https://github.com/ksachdeva/tiny-imagenet-tfds.git
  • python -c "import site; print(''.join(site.getsitepackages()))"
  • need to comment out two lines in the installed _imagenet.py:
    • vi /root/.local/share/r-miniconda/envs/r-reticulate/lib/python3.9/site-packages/tiny_imagenet/_imagenet.py
    • #urls=["https://tiny-imagenet.herokuapp.com/"],
    • #num_shards=1,
import os
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds
from tiny_imagenet import TinyImagenetDataset

# optional
tf.compat.v1.enable_eager_execution()

tiny_imagenet_builder = TinyImagenetDataset()

# this call (download_and_prepare) will trigger the download of the dataset
# and preparation (conversion to tfrecords)
#
# This will be done only once and on next usage tfds will
# use the cached version on your host.
#
# You can pass optional argument to this method called
# DownloadConfig (https://www.tensorflow.org/datasets/api_docs/python/tfds/download/DownloadConfig)
# to customize the location where the dataset is downloaded, extracted and processed.
tiny_imagenet_builder.download_and_prepare()

train_dataset = tiny_imagenet_builder.as_dataset(split="train")
validation_dataset = tiny_imagenet_builder.as_dataset(split="validation")

assert(isinstance(train_dataset, tf.data.Dataset))
assert(isinstance(validation_dataset, tf.data.Dataset))

for a_train_example in train_dataset.take(5):
    image, label, id = a_train_example["image"], a_train_example["label"], a_train_example["id"]
    print(f"Image Shape - {image.shape}")
    print(f"Label - {label.numpy()}")
    print(f"Id - {id.numpy()}")

# print info about the data
print(tiny_imagenet_builder.info)
In [ ]:
tiny_imagenet = reticulate::import('tiny_imagenet')
tiny_imagenet_builder = tiny_imagenet$TinyImagenetDataset()
tiny_imagenet_builder$info
tiny_imagenet_builder$download_and_prepare()
train_dataset = tiny_imagenet_builder$as_dataset(split="train")
validation_dataset = tiny_imagenet_builder$as_dataset(split="validation")
tfds.core.DatasetInfo(
    name='tiny_imagenet_dataset',
    full_name='tiny_imagenet_dataset/0.1.0',
    description="""
    Tiny ImageNet Challenge is a similar challenge as ImageNet with a smaller dataset but
                             less image classes. It contains 200 image classes, a training
                             dataset of 100, 000 images, a validation dataset of 10, 000
                             images, and a test dataset of 10, 000 images. All images are
                             of size 64×64.
    """,
    homepage='https://www.tensorflow.org/datasets/catalog/tiny_imagenet_dataset',
    data_path='/root/tensorflow_datasets/tiny_imagenet_dataset/0.1.0',
    file_format=tfrecord,
    download_size=236.61 MiB,
    dataset_size=210.18 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=string),
        'image': Image(shape=(64, 64, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=200),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'train': <SplitInfo num_examples=100000, num_shards=2>,
        'validation': <SplitInfo num_examples=10000, num_shards=1>,
    },
    citation="""@article{tiny-imagenet,
                                  author = {Li,Fei-Fei}, {Karpathy,Andrej} and {Johnson,Justin}"}""",
)
In [ ]:
tfds = reticulate::import('tensorflow_datasets')
fig = tfds$show_examples(train_dataset, tiny_imagenet_builder$info)
fig$savefig('tiny_imagenet.png')
IRdisplay::display_png(file = 'tiny_imagenet.png')
Image
In [ ]:
str(train_dataset$take(1L)$get_single_element())
List of 3
 $ id   :<tf.Tensor: shape=(), dtype=string, numpy=b'n02279972'>
 $ image:<tf.Tensor: shape=(64, 64, 3), dtype=uint8, numpy=…>
 $ label:<tf.Tensor: shape=(), dtype=int64, numpy=13>

Characterizing BNN loss and noise¶

In [ ]:
require(tensorflow)
require(keras)
require(reticulate)
require(tfprobability)
require(coro)
Sys.setenv("TF_XLA_FLAGS" = "--tf_xla_enable_xla_devices")
In [ ]:
mnist <- dataset_mnist()

mnist_train_ds = tf$data$Dataset$from_tensor_slices(mnist$train)$shuffle(2048L)$batch(256L)$map(function(xy){
  c(x,y) %<-% xy
  x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
  y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
  tuple(x,y.onehot)
})

mnist_test_ds = tf$data$Dataset$from_tensor_slices(mnist$test)$batch(1024L)$map(function(xy){
  c(x,y) %<-% xy
  x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
  y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
  tuple(x,y.onehot)
})

Implementation of tfp.layers.DenseReparameterization¶

According to the documentation, it implements the following reparameterization estimator:

  • kernel, bias ~ posterior
  • outputs = activation(matmul(inputs, kernel) + bias)

"You can access the kernel and/or bias posterior and prior distributions after the layer is built via the kernel_posterior, kernel_prior, bias_posterior and bias_prior properties."

"Upon being built, this layer adds losses (accessible via the losses property) representing the divergences of kernel and/or bias surrogate posteriors and their respective priors. When doing minibatch stochastic optimization, make sure to scale this loss such that it is applied just once per epoch."

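As a small sketch of that bookkeeping (modelR is the model defined further below): once built, each variational layer registers its scaled KL term, and the sum of the entries in modelR$losses is the regularizer added to every minibatch loss.

total_kl <- tf$add_n(modelR$losses)   # one scaled KL tensor per variational layer
total_kl
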
  • tfp.layers.DenseReparameterization is a subclass of tfp.layers._DenseVariational. The forward propagation is defined in the call method, which in turn calls tfp.layers.DenseReparameterization._apply_variational_kernel, tfp.layers._DenseVariational._apply_variational_bias, and tfp.layers._DenseVariational._apply_divergence with the provided kernel_divergence_fn/kernel_posterior/kernel_prior and bias_divergence_fn/bias_posterior/bias_prior.
def _apply_variational_kernel(self, inputs):
    self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
        self.kernel_posterior)
    self.kernel_posterior_affine = None
    self.kernel_posterior_affine_tensor = None
    return tf.matmul(inputs, self.kernel_posterior_tensor)
  • _apply_variational_kernel calls tfp.layers._DenseVariational.kernel_posterior_tensor_fn, which by default calls sample() on the tfp.layers._DenseVariational.kernel_posterior attribute. That attribute is constructed in tfp.layers._DenseVariational.build by tfp.layers.DenseReparameterization.kernel_posterior_fn, which defaults to tfp.layers.default_mean_field_normal_fn, which in turn calls tfp.layers.default_loc_scale_fn to create the Variables representing the mean and scale. The scale is a DeferredTensor (an R sketch of the same transform follows the snippet below).
scale = tfp_util.DeferredTensor(
    untransformed_scale,
    lambda x: (np.finfo(dtype.as_numpy_dtype).eps + tf.nn.softplus(x)))
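
The same transform in R, as a sketch: the trainable variable is unconstrained, and the positive scale is machine epsilon plus its softplus (1e-7 below stands in for np.finfo(float32).eps).

untransformed_scale <- tf$Variable(tf$fill(list(5L), -3.0))
scale <- 1e-7 + tf$nn$softplus(untransformed_scale)   # strictly positive standard deviation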

Implementation of tfp.layers.DenseLocalReparameterization¶

def _apply_variational_kernel(self, inputs):
    # ... omitted for brevity
    self.kernel_posterior_affine = normal_lib.Normal(
        loc=tf.matmul(inputs, self.kernel_posterior.distribution.loc),
        scale=tf.sqrt(tf.matmul(
            tf.square(inputs),
            tf.square(self.kernel_posterior.distribution.scale))))
    self.kernel_posterior_affine_tensor = (
        self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
    self.kernel_posterior_tensor = None
    return self.kernel_posterior_affine_tensor
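
The same trick written out in R, as a sketch with illustrative names (not the tfp internals): instead of sampling a kernel, sample the pre-activations directly; their mean and variance follow from the mean-field posterior.

local_reparam_forward <- function(x, kernel_loc, kernel_raw_scale, bias) {
  kernel_scale <- tf$nn$softplus(kernel_raw_scale)
  loc   <- tf$matmul(x, kernel_loc)                                    # mean of x %*% W
  scale <- tf$sqrt(tf$matmul(tf$square(x), tf$square(kernel_scale)))   # std. dev. of x %*% W
  a <- tfp$distributions$Normal(loc = loc, scale = scale)$sample()     # sample the pre-activations
  tf$nn$relu(a + bias)
}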

Implementation of tfp.layers.DenseFlipout¶

def _apply_variational_kernel(self, inputs):
    # ... omitted for brevity
    self.kernel_posterior_affine = normal_lib.Normal(
        loc=tf.zeros_like(self.kernel_posterior.distribution.loc),
        scale=self.kernel_posterior.distribution.scale)
    self.kernel_posterior_affine_tensor = (
        self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
    self.kernel_posterior_tensor = None

    input_shape = tf.shape(inputs)
    batch_shape = input_shape[:-1]

    seed_stream = SeedStream(self.seed, salt='DenseFlipout')

    sign_input = tfp_random.rademacher(
        input_shape,
        dtype=inputs.dtype,
        seed=seed_stream())
    sign_output = tfp_random.rademacher(
        tf.concat([batch_shape,
                   tf.expand_dims(self.units, 0)], 0),
        dtype=inputs.dtype,
        seed=seed_stream())
    perturbed_inputs = tf.matmul(
        inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output

    outputs = tf.matmul(inputs, self.kernel_posterior.distribution.loc)
    outputs += perturbed_inputs
    return outputs
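
A sketch of the flipout estimator in R (illustrative names): a single zero-mean kernel perturbation is shared across the batch and decorrelated per example by random +/-1 sign flips.

flipout_forward <- function(x, kernel_loc, kernel_raw_scale, bias) {
  kernel_scale <- tf$nn$softplus(kernel_raw_scale)
  # zero-mean perturbation of the kernel, shared by the whole batch
  delta <- tfp$distributions$Normal(loc = tf$zeros_like(kernel_loc), scale = kernel_scale)$sample()
  out_mean <- tf$matmul(x, kernel_loc)
  sign_in  <- tfp$random$rademacher(tf$shape(x), dtype = x$dtype)
  sign_out <- tfp$random$rademacher(tf$shape(out_mean), dtype = x$dtype)
  perturbed <- tf$matmul(x * sign_in, delta) * sign_out   # pseudo-independent per-example noise
  tf$nn$relu(out_mean + perturbed + bias)
}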
In [ ]:
divergence_fn = function(q, p, ignore){
  tfp$distributions$kl_divergence(q, p) / 60000
}

inputs = layer_input(shape = 784L, name='input')
predictionsR = inputs %>%
  layer_dense_reparameterization(units = 40L, activation = tf$nn$sigmoid, kernel_divergence_fn = divergence_fn, bias_posterior_fn = NULL) %>%
  layer_dense_reparameterization(units = 40L, activation = tf$nn$sigmoid, kernel_divergence_fn = divergence_fn, bias_posterior_fn = NULL) %>%
  layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = divergence_fn, bias_posterior_fn = NULL)

modelR = keras_model(inputs = inputs, outputs = predictionsR)
modelR %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelR$save_weights('model0.ckpt')
modelR$load_weights('model0.ckpt')

modelR %>% fit(mnist_train_ds, epochs=100, validation_data=mnist_test_ds, verbose=0L)

modelR %>% evaluate(mnist_test_ds)
modelR$save_weights('modelR.ckpt')
modelR$load_weights('modelR.ckpt')
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f128c19ef70>
loss
0.402956306934357
accuracy
0.947799980640411
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f128a940d30>
In [ ]:
c(x,y) %<-% coro::collect(mnist_test_ds, n = 1)[[1]]
# draw 16 stochastic gradients for the same test batch; the randomness comes
# from re-sampling the variational weights on every forward pass
gradR = lapply(1:16, function(n){
  with(tf$GradientTape(persistent=TRUE) %as% tape, {
    y_pred = modelR$call(x)
    loss = modelR$compiled_loss(y_true = y, y_pred = y_pred, regularization_losses = modelR$losses)
  })
  grad = tape$gradient(loss, modelR$weights)
})

# element-wise gradient noise across the 16 draws: Var[g] = E[g^2] - E[g]^2
b = lapply(1:length(gradR[[1]]), function(m){
   tf$reduce_mean(tf$stack(lapply(gradR, function(n) n[[m]]), axis=0L)^2, axis=0L) - tf$reduce_mean(tf$stack(lapply(gradR, function(n) n[[m]]), axis=0L), axis=0L)^2
 })

for(m in seq_along(b) - 1){  # 0-based; weights alternate kernel loc, untransformed scale
  if(m %%2 ==0){
    print(sprintf('mean %d', m/2))
    print(summary(c(b[[m+1]]$numpy())))
  }
  else {
    print(sprintf('scale %d', (m-1)/2))
    print(summary(c(b[[m+1]]$numpy())))

  }
}


# element-wise gradient mean across the 16 draws
b = lapply(1:length(gradR[[1]]), function(m){
   tf$reduce_mean(tf$stack(lapply(gradR, function(n) n[[m]]), axis=0L), axis=0L)
 })

for(m in seq_along(b) - 1){  # 0-based; weights alternate kernel loc, untransformed scale
  if(m %%2 ==0){
    print(sprintf('mean %d', m/2))
    print(summary(c(b[[m+1]]$numpy())))
  }
  else {
    print(sprintf('scale %d', (m-1)/2))
    print(summary(c(b[[m+1]]$numpy()))) #tf$math$softplus(

  }
}
In [ ]:
my_divergence_fn = function(q, p, ignore=NULL){
  # without extra scaling, the KL regularization term (~15) would dominate the data loss
  # (which approaches 0), so scale it down by a further 1e-3
  tfp$distributions$kl_divergence(q, p) / 60000 / 1e3
}

inputs = layer_input(shape = 784L, name='input')
In [ ]:
predictionsR = inputs %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)

modelR = keras_model(inputs = inputs, outputs = predictionsR)
modelR %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelR$save_weights('model0.ckpt')

Below is how to set the Bayesian NN parameter initializers to match the vanilla NN initializers; this doesn't seem necessary for MNIST.

my_kernel_posterior_fn = function() tfp$layers$util$default_mean_field_normal_fn(loc_initializer = 'glorot_uniform')
my_bias_posterior_fn = function() tfp$layers$util$default_mean_field_normal_fn(loc_initializer = 'zeros',is_singular=TRUE)
predictionsR = inputs %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn, kernel_posterior_fn = my_kernel_posterior_fn(), bias_posterior_fn = my_bias_posterior_fn()) %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn, kernel_posterior_fn = my_kernel_posterior_fn(), bias_posterior_fn = my_bias_posterior_fn()) %>%
  layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn, kernel_posterior_fn = my_kernel_posterior_fn(), bias_posterior_fn = my_bias_posterior_fn())
In [ ]:
modelR$load_weights('model0.ckpt')
# there are 60k training examples; with steps_per_epoch = 1, each 'epoch' is a single optimizer step, so 2000 epochs ≈ 2000 minibatch updates
historyR = modelR %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'R.txt')))
modelR %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6ed06b88e0>
loss
0.113996714353561
accuracy
0.974399983882904
In [ ]:
predictionsL = inputs %>%
  layer_dense_local_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_local_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_local_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)

modelL = keras_model(inputs = inputs, outputs = predictionsL)
modelL %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')

modelL$load_weights('model0.ckpt')

historyL = modelL %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'L.txt')))
modelL %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6ec86c1060>
loss
0.111960597336292
accuracy
0.978500008583069
In [ ]:
predictionsF = inputs %>%
  layer_dense_flipout(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_flipout(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_flipout(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)

modelF = keras_model(inputs = inputs, outputs = predictionsF)
modelF %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')

modelF$load_weights('model0.ckpt')

historyF = modelF %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'F.txt')))
modelF %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6ec858e410>
loss
0.116981789469719
accuracy
0.972800016403198
In [ ]:
predictions0 = inputs %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)

model0 = keras_model(inputs = inputs, outputs = predictions0)
model0$load_weights('model0.ckpt')

predictions = inputs %>%
  layer_dense(units = 512L, activation = 'relu') %>%
  layer_dense(units = 512L, activation = 'relu') %>%
  layer_dense(units = 10L, activation = 'softmax')

model = keras_model(inputs = inputs, outputs = predictions)
model %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')

invisible({
  model$weights[[1]]$assign(model0$weights[[1]]) # layer 1 - kernel
  model$weights[[2]]$assign(model0$weights[[3]]) # layer 1 - bias
  model$weights[[3]]$assign(model0$weights[[4]]) # layer 2 - kernel
  model$weights[[4]]$assign(model0$weights[[6]]) # layer 2 - bias
  model$weights[[5]]$assign(model0$weights[[7]]) # layer 3 - kernel
  model$weights[[6]]$assign(model0$weights[[9]]) # layer 3 - bias
})
rm(predictions0, model0)

history = model %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'V.txt')))
model %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6e44783670>
loss
0.0837457105517387
accuracy
0.979200005531311
In [ ]:
# expected log-density of N(m1, diag(Q1)) under N(m0, diag(Q0)) (i.e. the negative cross entropy):
#   -0.5 * sum_i ( Q0_i/Q1_i + (m1_i - m0_i)^2/Q1_i + log(2*pi*Q1_i) )
CE.mvnorm.diag = function(m0, Q0, m1, Q1){
  pinv.Q1 = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), Q1)
  (tf$reduce_sum( pinv.Q1 * Q0 + (m1-m0)^2 * pinv.Q1 + tf$math$log(2 * pi * Q1), axis=c(1L))) * (-.5)
}

#my_divergence_fn = function(q, p, ignore=NULL){
#  tfp$distributions$kl_divergence(q, p) / 60000 / 1e3
#}

layer_dense_Bayesian = Layer(
  classname = 'BayesianDense',
  inherit = tf$keras$layers$Layer,
  initialize = function(activation=NULL, use_bias = TRUE, divergence_fn = my_divergence_fn, ...) {
    super$initialize()
    self$divergence_fn = divergence_fn
    self$activation = keras$activations$get(activation)
    self$use_bias = use_bias
    self$dense = tf$keras$layers$Dense(activation=NULL, use_bias=FALSE, ...)
  },

  build = function(input_shape) {
    self$dense$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
    # redefine BNN weights as Gaussian random variables in terms of trainable tf variables
    self$kernel_loc = self$add_weight(name='A',shape=self$dense$kernel$shape, initializer=self$dense$kernel_initializer, trainable=TRUE)
    self$kernel_scale = self$add_weight(name='B',shape=self$dense$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)#'uniform'
    if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$dense$units), initializer='zeros') else self$bias = NULL
    self$built=TRUE
  },

  call = function(inputs) {
    c(x, P0) %<-% inputs
    with(tf$GradientTape(persistent=TRUE) %as% g, {
      g$watch(list(x)) #g$watch(list(x, self$W))
      self$dense$kernel = self$kernel_loc #+ tf$einsum('ij,i->ij',self$B, self$W) # mvnorm with mean and variance
      a = self$dense$call(x) # tf$matmul(a=x, b=self$kernel) + self$bias
      if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
      h = self$activation(a)
      `__h__` = self$activation(a)
    })
    # propagate variance (accept a variance, first order approximation of state transition, output a variance)
    #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
    self$dense$kernel = tf$stop_gradient(self$kernel_loc^2)
    dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$dense$call(P0)
    #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
    self$dense$kernel = tf$nn$softplus(self$kernel_scale+1e-9)^2
    dhdw.sq.sum_w = g$gradient(h, a)^2 * self$dense$call(tf$stop_gradient(x^2))
    P =  dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9

    ## add regularization loss
    q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=self$kernel_loc, scale = tf$nn$softplus(self$kernel_scale+1e-9)), reinterpreted_batch_ndims = 2L)
    p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc), scale = 1.0), reinterpreted_batch_ndims = 2L)
    self$add_loss(self$divergence_fn(q_, p_))

    list(h, P)
  }
)
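
A quick eager sanity check of the layer's (mean, variance) interface, as a sketch; the shapes mimic the MNIST pipeline defined below, and the layer is called directly on a pair of eager tensors.

x  <- tf$random$uniform(list(4L, 784L))
P0 <- tf$zeros_like(x) + 1e-6
c(h, P) %<-% (list(x, P0) %>% layer_dense_Bayesian(units = 16L, activation = 'relu'))
h$shape   # (4, 16): the feature mean
P$shape   # (4, 16): the propagated feature variance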
In [ ]:
mnist_train_dsEP = tf$data$Dataset$from_tensor_slices(mnist$train)$`repeat`(-1L)$shuffle(2048L)$batch(300L)$map(function(xy){
  c(x,y) %<-% xy
  x = tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
  P = tf$zeros_like(x) + 1e-6
  y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
  tuple(tuple(x,P),y.onehot)
})

mnist_test_dsEP = tf$data$Dataset$from_tensor_slices(mnist$test)$batch(1024L)$map(function(xy){
  c(x,y) %<-% xy
  x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
  P = tf$zeros_like(x) + 1e-6
  y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
  tuple(tuple(x,P),y.onehot)
})

mnist_train_dsEP$element_spec
[[1]]
[[1]][[1]]
TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)

[[1]][[2]]
TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)


[[2]]
TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
In [ ]:
inputsEP = list(layer_input(shape = 784L), layer_input(shape=tuple(784L)))
predictionsEP = inputsEP %>%
  layer_dense_Bayesian(units=512, activation = 'relu') %>%
  layer_dense_Bayesian(units=512, activation = 'relu') %>%
  layer_dense_Bayesian(units=10, activation = tf$identity) %>%
  layer_lambda(f = function(mv){
    tf$nn$softmax(tf$reduce_mean(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(20L), axis=0L))
  })
modelEP = keras_model(inputs = inputsEP, outputs = predictionsEP)
modelEP %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')

modelEP$load_weights('model0.ckpt')

historyEP = modelEP %>% fit(mnist_train_dsEP, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_dsEP, verbose=0L, callbacks = list(callback_csv_logger(filename = 'MRF.txt')))
modelEP %>% evaluate(mnist_test_dsEP)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6e44515a80>
loss
0.109516456723213
accuracy
0.979200005531311
In [ ]:
Vadam(keras$optimizers$experimental$Optimizer) %py_class% {
  initialize <- function(train_set_size, learning_rate=1e-3, beta_1=.9, beta_2 = .999, prior_prec=1.0, prec_init=0.0, num_samples=1, name='Vadam') {
    super$initialize(name=name)
    self$`_learning_rate` = self$`_build_learning_rate`(learning_rate)
    self$beta_1 = beta_1
    self$beta_2 = beta_2
    self$gamma = prior_prec
    self$gamma_0 = prec_init
    self$N = train_set_size
    self$num_samples = num_samples
  }
  build = function(var_list){
    super$build(var_list)
    if(py_has_attr(self, "_built") && self$`_built`) return()
    self$m = lapply(var_list, function(var) self$add_variable_from_reference(model_variable=var, variable_name="m"))
    self$mu = lapply(var_list, function(var) self$add_variable_from_reference(model_variable=var, variable_name="mu", initial_value=var))
    self$s = lapply(var_list, function(var) self$add_variable_from_reference(model_variable=var, variable_name="s"))
    self$`_built` = TRUE
  }
  update_step = function(gradient, variable){
     lr = tf$cast(self$learning_rate, variable$dtype)
     var_key = self$`_var_key`(variable)
     s = self$s[[self$`_index_dict`[var_key]]] #+1L? python index or R index?
     m = self$m[[self$`_index_dict`[var_key]]]
     mu = self$mu[[self$`_index_dict`[var_key]]]
     #v.assign_add((tf.square(gradient) - v) * (1 - self.beta_2))
     s$assign_add((gradient^2-s)*(1 - self$beta_2))
     #m.assign_add((gradient - m) * (1 - self.beta_1))
     #m$assign_add((1-self$beta_1)*(gradient-m)) # Adam
     m$assign_add((1-self$beta_1)*(gradient-m+self$gamma/self$N*mu)) # Vadam

     #local_step = tf.cast(self.iterations + 1, variable.dtype)
     #beta_1_power = tf.pow(tf.cast(self.beta_1, variable.dtype), local_step)
     #beta_2_power = tf.pow(tf.cast(self.beta_2, variable.dtype), local_step)
     #alpha = lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)
     #variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
     local_step = tf$cast(self$iterations + 1, variable$dtype)
     beta_1_power = tf$pow(tf$cast(self$beta_1, variable$dtype), local_step)
     beta_2_power = tf$pow(tf$cast(self$beta_2, variable$dtype), local_step)
     alpha = lr * tf$sqrt(1 - beta_2_power) / (1 - beta_1_power)
     #mu$assign_sub(tf$math$divide_no_nan(m * alpha, tf$sqrt(s))) # Adam
     mu$assign_sub(tf$math$divide_no_nan(m * alpha, tf$sqrt(s)+self$gamma/self$N)) # Vadam

     #variable$assign(mu) # Adam
     variable$assign(tfp$distributions$Normal(loc=mu, scale=tf$math$divide_no_nan(1.0, (self$N*s+self$gamma)^.5))$sample()) #Vadam
  }
  get_config = function(){
    super$get_config()
  }
}
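
For reference, the update implemented above, read directly off the code (g_t is the minibatch gradient, N = train_set_size, γ = prior_prec, η the learning rate):

$$
\begin{aligned}
s_{t+1} &= s_t + (1-\beta_2)\,\big(g_t^2 - s_t\big)\\
m_{t+1} &= m_t + (1-\beta_1)\,\big(g_t - m_t + \tfrac{\gamma}{N}\,\mu_t\big)\\
\alpha_t &= \eta\,\sqrt{1-\beta_2^{\,t}}\,\big/\,\big(1-\beta_1^{\,t}\big)\\
\mu_{t+1} &= \mu_t - \alpha_t\,\frac{m_{t+1}}{\sqrt{s_{t+1}} + \gamma/N}\\
w_{t+1} &\sim \mathcal{N}\!\Big(\mu_{t+1},\ \big(N\,s_{t+1} + \gamma\big)^{-1}\Big)
\end{aligned}
$$

The commented-out lines show the corresponding Adam steps; Vadam differs by the γ/N terms and by sampling the weights around μ instead of assigning μ directly.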
In [ ]:
predictionsVadam = inputs %>%
  layer_dense(units = 512L, activation = 'relu') %>%
  layer_dense(units = 512L, activation = 'relu') %>%
  layer_dense(units = 10L, activation = 'softmax')

modelVadam = keras_model(inputs = inputs, outputs = predictionsVadam)

predictions0 = inputs %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
  layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)

model0 = keras_model(inputs = inputs, outputs = predictions0)
model0$load_weights('model0.ckpt')

invisible({
  modelVadam$weights[[1]]$assign(model0$weights[[1]]) # layer 1 - kernel
  modelVadam$weights[[2]]$assign(model0$weights[[3]]) # layer 1 - bias
  modelVadam$weights[[3]]$assign(model0$weights[[4]]) # layer 2 - kernel
  modelVadam$weights[[4]]$assign(model0$weights[[6]]) # layer 2 - bias
  modelVadam$weights[[5]]$assign(model0$weights[[7]]) # layer 3 - kernel
  modelVadam$weights[[6]]$assign(model0$weights[[9]]) # layer 3 - bias
})
rm(predictions0, model0)


# note: train_set_size is inflated by 1e4, which shrinks the gamma/N prior terms
# (cf. the extra 1e-3 down-scaling of the KL term above)
modelVadam %>% compile(loss='categorical_crossentropy', optimizer=Vadam(train_set_size=60000 * 1e4, learning_rate=0.001, beta_1=.9, beta_2 = .999), metrics='accuracy')
historyVadam = modelVadam %>% fit(mnist_train_ds$unbatch()$batch(60L), steps_per_epoch = 5, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'Vadam.txt')))
modelVadam %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6f4a78f490>
loss
0.124228939414024
accuracy
0.97189998626709
In [ ]:
# smooth the per-step losses with a recursive (exponential) filter y_t = x_t + 0.1*y_{t-1};
# dividing by 1.1 (approximately 1/(1 - 0.1)) undoes the filter gain
matplot(filter(cbind(
  historyR$metrics$loss,
  historyR$metrics$val_loss,
  historyL$metrics$loss,
  historyL$metrics$val_loss,
  historyF$metrics$loss,
  historyF$metrics$val_loss,
  history$metrics$loss,
  history$metrics$val_loss,
  historyEP$metrics$loss,
  historyEP$metrics$val_loss
  ), method='recursive', filter=.1)/1.1,
  lty = rep(c(1,2), times=5), col = rep(1:5, each=2),
  type='l', ylab='loss',xlab='epoch', log='xy')
legend('bottomleft',
  lty=rep(1:2, times=5),
  col=rep(1:5, each=2),
  legend=c('reparam.','reparam. val.', 'local reparam.', 'local reparam. val', 'flipout', 'flipout val', 'vanilla', 'vanilla val', 'VBP', 'VBP val'))

matplot(cbind(
  historyR$metrics$accuracy,
  historyR$metrics$val_accuracy,
  historyL$metrics$accuracy,
  historyL$metrics$val_accuracy,
  historyF$metrics$accuracy,
  historyF$metrics$val_accuracy,
  history$metrics$accuracy,
  history$metrics$val_accuracy,
  historyEP$metrics$accuracy,
  historyEP$metrics$val_accuracy
  ), lty = rep(c(1,2), times=5), col = rep(1:5, each=2),
  type='l', ylab='accuracy',xlab='epoch', log='xy')
legend('bottomright',
  lty=rep(1:2, times=5),
  col=rep(1:5, each=2),
  legend=c('reparam.','reparam. val.', 'local reparam.', 'local reparam. val', 'flipout', 'flipout val', 'vanilla', 'vanilla val', 'VBP', 'VBP val'))
[Figure: smoothed training/validation loss curves (log-log) for the five models]
[Figure: training/validation accuracy curves (log-log) for the five models]