install.packages(c('tensorflow','keras','reticulate','coro','tfprobability','httr'))
require(keras)
install_keras(version='gpu', extra_packages=c('tensorflow_probability==0.16','tensorflow-datasets','matplotlib'))
We use the R reticulate package to bring together the R and Python APIs. Internally, install_keras creates an r-miniconda environment under ~/.local/share/ and installs the Python packages there.
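As an optional sanity check (not part of the original setup), we can confirm which Python environment reticulate and TensorFlow are bound to:
reticulate::py_config()
tensorflow::tf_config()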
require(tensorflow)
require(keras)
require(reticulate)
Sys.setenv("TF_XLA_FLAGS" = "--tf_xla_enable_xla_devices")
# R is not supported in Colab by default; install your own R.
#use_python("/usr/local/bin/python")
require(tfprobability)
require(coro)
Loading required package: tensorflow Loading required package: keras Loading required package: reticulate Loading required package: tfprobability Loading required package: coro Attaching package: ‘coro’ The following object is masked from ‘package:reticulate’: as_iterator
Duality between message passing and back propagation¶
In this section, we show that we can back-propagate the smoothing densities of a Bayesian neural network through message passing, and that we can also back-propagate the gradient of the loss with respect to the filter density parameters. The two back-propagation methods give the same gradient of the loss with respect to the weight parameters. We use this example for two purposes:
- to sanity-check Theorem 1 in the paper, and
- to illustrate that autodiff can compute not only gradients, but also probability distributions.
We use tf.custom_gradient to pass smoothing probability distributions backward. We compare the smoothing densities of latent features obtained from message passing (implemented in layer_dense_Bayesian) with those obtained from the gradient of the cross-entropy loss with respect to the mean and variance of the latent-feature smoothing densities (implemented through TensorFlow's autodiff in layer_dense_Bayesian_autodiff). The conversion is in the following code snippet, where x.update and P.update are the smoothing densities computed from gradients.
x.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[1]] + tf$einsum('bij,bi->bj', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[1]]))
P.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[2]] -
tf$einsum('bij,bjk,blk->bil', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[2]], y_predBayesian_autodiff[[n]][[2]]) -
tf$einsum('bi,bj->bij', x.update[[n]] - y_predBayesian_autodiff[[n]][[1]], x.update[[n]] - y_predBayesian_autodiff[[n]][[1]]))
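For readers unfamiliar with tf$custom_gradient, here is a minimal standalone sketch of the pattern (illustrative only; square_with_custom_grad, x0 and y0 are not used elsewhere in this notebook). The decorated function returns the forward value together with a backward function; in layer_dense_Bayesian below, that backward function returns smoothing-density messages instead of ordinary gradients.
square_with_custom_grad = tf$custom_gradient(function(x){
y = x^2
grad_fn = function(dy){
dy * 2 * x # the ordinary chain rule here; any tensor can be passed backward
}
list(y, grad_fn)
})
with(tf$GradientTape() %as% tape, {
x0 = tf$constant(3.0)
tape$watch(x0)
y0 = square_with_custom_grad(x0)
})
tape$gradient(y0, x0) # 6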
mnist <- dataset_mnist()
CE.mvnorm = function(m0, Q0, m1, Q1){
pinv.Q1 = tf$linalg$pinv(Q1)
(tf$reduce_sum( pinv.Q1 * Q0, axis=c(1L, 2L)) + tf$einsum('bi,bij,bj->b', m1-m0, pinv.Q1, m1-m0) + tf$linalg$logdet(2 * pi * Q1)) * (-.5)
}
layer_dense_Bayesian <- Layer(
classname = "BayesDense",
inherit = keras$layers$Layer,
initialize = function(units = 7, f = tf$sigmoid, ...){
super$initialize(...)
self$units = as.integer(units)
self$f = f
},
build = function(input_shape){
#print(input_shape)
super$build(input_shape)
self$A = self$add_weight(name='A',shape=list(input_shape[[1]][[2]], self$units), initializer='normal', trainable=TRUE)
self$B = self$add_weight(name='B',shape=list(input_shape[[1]][[2]], self$units), initializer='uniform', trainable=TRUE)
self$W = tf$zeros_like(self$B[,1])
self$`__forward__` = tf$custom_gradient(function(inputs){
message('trace layer_dense_Bayesian forward ....')
x = inputs[[1]]
P0 = inputs[[2]]
with(tf$GradientTape(persistent=TRUE) %as% g2, {
with(tf$GradientTape(persistent=TRUE) %as% g, {
g$watch(list(x,self$W))
B.W = tf$einsum('ij,i->ij',self$B, self$W)
h = self$f(tf$matmul(x, (self$A + B.W)))
`__h__a` = self$f(tf$matmul(x, (self$A + B.W)))
`__h__b` = self$f(tf$matmul(tf$stop_gradient(x), (self$A + B.W)))
})
dhdW = g$jacobian(`__h__b`, self$W)
dhdx = tf$stop_gradient(g$batch_jacobian(`__h__a`, x))
P = tf$matmul(tf$matmul(dhdx, P0), dhdx, transpose_b=TRUE) + tf$matmul(dhdW, dhdW, transpose_b=TRUE)
P = P + tf$linalg$diag(1e-3*tf$ones_like(tf$linalg$diag_part(P[1,,])))
})
grad_fn = function(dh, dP, variables){
message('trace layer_dense_Bayesian backward ....')
with(g2, {
h.update = dh
P.update = dP
G0 = tf$stop_gradient(tf$einsum('bip,bqp,bqj->bij', P0, dhdx, tf$linalg$pinv(P)))
x.update = tf$stop_gradient(x + tf$einsum('bij,bj->bi',G0, h.update - h))
P0.update = tf$stop_gradient(P0 - tf$einsum('bip,bpq,bjq->bij', G0, P - P.update, G0))
nce = - tf$reduce_mean(CE.mvnorm(h.update - h - tf$einsum('bij,bj->bi',dhdx, x.update-x),
tf$einsum('bip,bpq,bjq->bij',dhdx, P0.update, dhdx) + P.update - tf$einsum('bim,bqm,bjq->bij', P.update,G0,dhdx) - tf$einsum('bim,bqm,bjq->bji', P.update,G0,dhdx),
tf$zeros_like(h),
tf$matmul(dhdW, dhdW, transpose_b=TRUE) + tf$linalg$diag(1e-3*tf$ones_like(P[1,1,]))))
})
list(list(x.update, P0.update), g2$gradient(nce, variables))
}
list(list(h, P), grad_fn)
})
},
call = function(inputs){
self$`__forward__`(inputs)
},
compute_output_shape = function(input_shape){
super()$compute_output_shape(input_shape)
}
)
CategoricalObsModel = PyClass(classname = 'CategoricalObs', defs = list(
inputs = NULL,
shape = NULL,
dtype=tf$float32,
`__log.prob__` = NULL,
`__init__` = function(self, inputs) {
self$inputs = inputs
self$shape = inputs[[1]]$shape
self$`__log.prob__` = tf$custom_gradient(f = function(h, P, y){
message('trace CategoricalObsModel forward ....')
inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
p = tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
dydx = tf$stop_gradient(tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p))
S = tf$einsum('bip,bpq,bjq->bij',dydx, P, dydx) + tf$einsum('b,ij->bij',tf$ones_like(inv.Z), diag(rep(.14,10))) #+ tf$linalg$diag( (tf$sign(.2 - tf$abs(y-p))+1)/2*.14 ) + diag(rep(.01, 10))# + diag(rep(.14,10))
log.prob = tf_probability()$distributions$MultivariateNormalFullCovariance( loc=p, covariance_matrix= S )$log_prob( y ) #
grad_fn = function(dy, variables=NULL){
message('trace CategoricalObsModel backward ....')
# tf$assert_less(tf$abs(dy + 1.0), tf$constant(1e-3, dtype=tf$float32))
K = tf$einsum('bip,bqp,bqj->bij', P, dydx, tf$linalg$pinv(S))
h.update = tf$stop_gradient(h + tf$einsum('bij,bj->bi',K, y - p))
P.update = tf$stop_gradient(P - tf$einsum('bip,bpq,bqj->bij', K, dydx, P))
list(h.update, P.update, tf$ones_like(y))
}
list(log.prob, grad_fn)
})
NULL
},
`value` = function(self){
h = tf_probability()$distributions$MultivariateNormalFullCovariance(loc=self$inputs[[1]], covariance_matrix=self$inputs[[2]])$sample()
inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
},
loss = function(self, y){
- self$`__log.prob__`(self$inputs[[1]], self$inputs[[2]], y)
}
))
tf$register_tensor_conversion_function(CategoricalObsModel, function(x,...)x$`value`())
tf$keras$`__internal__`$utils$register_symbolic_tensor_type(CategoricalObsModel)
layer_dense_Bayesian_autodiff <- Layer(
classname = "BayesDense",
inherit = keras$layers$Layer,
initialize = function(units = 7, f = tf$sigmoid, ...){
super$initialize(...)
self$units = as.integer(units)
self$f = f
},
build = function(input_shape){
#print(input_shape)
super$build(input_shape)
self$A = self$add_weight(name='A',shape=list(input_shape[[1]][[2]], self$units), initializer='normal', trainable=TRUE)
self$B = self$add_weight(name='B',shape=list(input_shape[[1]][[2]], self$units), initializer='uniform', trainable=TRUE)
self$W = tf$zeros_like(self$B[,1])
self$`__forward__` = function(inputs){
message('trace layer_dense_Bayesian_autodiff forward ....')
x = inputs[[1]]
P0 = inputs[[2]]
with(tf$GradientTape(persistent=TRUE) %as% g2, {
with(tf$GradientTape(persistent=TRUE) %as% g, {
g$watch(list(x,self$W))
B.W = tf$einsum('ij,i->ij',self$B, self$W)
h = self$f(tf$matmul(x, (self$A + B.W)))
`__h__a` = self$f(tf$matmul(x, (self$A + B.W)))
`__h__b` = self$f(tf$matmul(tf$stop_gradient(x), (self$A + B.W)))
})
dhdW = g$jacobian(`__h__b`, self$W)
dhdx = tf$stop_gradient(g$batch_jacobian(`__h__a`, x))
P = tf$matmul(tf$matmul(dhdx, P0), dhdx, transpose_b=TRUE) + tf$matmul(dhdW, dhdW, transpose_b=TRUE)
P = P + tf$linalg$diag(1e-3*tf$ones_like(tf$linalg$diag_part(P[1,,])))
})
list(h, P)
}
},
call = function(inputs){
self$`__forward__`(inputs)
},
compute_output_shape = function(input_shape){
super()$compute_output_shape(input_shape)
}
)
CategoricalObsModel_autodiff = PyClass(classname = 'CategoricalObs', defs = list(
inputs = NULL,
shape = NULL,
dtype=tf$float32,
`__log.prob__` = NULL,
`__init__` = function(self, inputs) {
self$inputs = inputs
self$shape = inputs[[1]]$shape
self$`__log.prob__` = function(h, P, y){
message('trace CategoricalObsModel_autodiff forward ....')
inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
p = tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
dydx = tf$stop_gradient(tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p))
S = tf$einsum('bip,bpq,bjq->bij',dydx, P, dydx) + tf$einsum('b,ij->bij',tf$ones_like(inv.Z), diag(rep(.14,10))) #+ tf$linalg$diag( (tf$sign(.2 - tf$abs(y-p))+1)/2*.14 ) + diag(rep(.01, 10))# + diag(rep(.14,10))
log.prob = tf_probability()$distributions$MultivariateNormalFullCovariance( loc=p, covariance_matrix= S )$log_prob( y ) #
log.prob
}
NULL
},
`value` = function(self){
h = tf_probability()$distributions$MultivariateNormalFullCovariance(loc=self$inputs[[1]], covariance_matrix=self$inputs[[2]])$sample()
inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
},
loss = function(self, y){
- self$`__log.prob__`(self$inputs[[1]], self$inputs[[2]], y)
}
))#, inherit = tf$Module)
#a = CategoricalObsModel(layer_dense_Bayesian(units=10)(inputs))
#a$value()
tf$register_tensor_conversion_function(CategoricalObsModel_autodiff, function(x,...)x$`value`())
tf$keras$`__internal__`$utils$register_symbolic_tensor_type(CategoricalObsModel_autodiff)
inputsBayesian <- list(layer_input(shape = 784L), layer_input(shape=tuple(784L, 784L)))
predictionsBayesian <- inputsBayesian %>%
layer_dense_Bayesian(units=7) %>%
layer_dense_Bayesian(units=7) %>%
layer_dense_Bayesian(units=10, f = tf$identity) %>%
layer_lambda(function(inputs) CategoricalObsModel(inputs))
modelBayesian <- keras_model(inputs = inputsBayesian, outputs = predictionsBayesian)
mylossBayesian = function(y, y_hat){
y_hat$loss(y)
}
modelBayesian %>% compile(loss=mylossBayesian, optimizer=tf$keras$optimizers$SGD(learning_rate=.03), metrics='accuracy')
predictionsBayesian_autodiff = inputsBayesian %>%
layer_dense_Bayesian_autodiff(units=7) %>%
layer_dense_Bayesian_autodiff(units=7) %>%
layer_dense_Bayesian_autodiff(units=10, f = tf$identity) %>%
layer_lambda(function(inputs) CategoricalObsModel_autodiff(inputs))
modelBayesian_autodiff = keras_model(inputs = inputsBayesian, outputs = predictionsBayesian_autodiff)
modelBayesian_autodiff %>% compile(loss=mylossBayesian, optimizer=tf$keras$optimizers$SGD(learning_rate=.03), metrics='accuracy')
invisible(lapply(1:6, function(n) modelBayesian_autodiff$trainable_variables[[n]]$assign(modelBayesian$trainable_variables[[n]]) ))
tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward ....
n = 32L
indices = sample(dim(mnist$train$x)[1], size = n)
x = tf$constant(matrix(mnist$train$x[indices, , ] / 255.0, nrow = n))
y = tf$one_hot(mnist$train$y[indices], depth = 10L)
P0 = tf$einsum('b,ij->bij', tf$ones_like(x[,1]), tf$constant(diag(rep(1e-4, 784)), dtype=tf$float64) )
with(tf$GradientTape() %as% tape, {
y_predBayesian = modelBayesian$call(list(x,P0))
lossBayesian = tf$reduce_mean(mylossBayesian(y, y_predBayesian))
})
gradientsBayesian = tape$gradient(lossBayesian, modelBayesian$trainable_variables)
with(tf$GradientTape() %as% tape, {
y_predBayesian_autodiff = modelBayesian_autodiff$call(list(x,P0))
lossBayesian_autodiff = tf$reduce_mean(mylossBayesian(y, y_predBayesian_autodiff))
})
gradientsBayesian_autodiff = tape$gradient(lossBayesian_autodiff, modelBayesian_autodiff$trainable_variables)
tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... trace CategoricalObsModel forward .... trace CategoricalObsModel backwards .... tracking layer_dense_Bayesian backward .... tracking layer_dense_Bayesian backward .... tracking layer_dense_Bayesian backward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... trace CategoricalObsModel forward ....
The following shows that whether we use TensorFlow's autodiff or back-propagate smoothing densities, we obtain the same gradient of the loss with respect to the variational mean and covariance of the weight parameters.
#str(gradientsBayesian_autodiff)
#str(gradientsBayesian)
require(repr)
repr.plot.opt = options(repr.plot.width=11, repr.plot.height=16)
par.opt = par(mfrow=c(4,3))
for(n in 1:3){
plot(gradientsBayesian_autodiff[[n*2-1]]$numpy(), gradientsBayesian[[n*2-1]]$numpy(), xlab='autodiff', ylab='custom_gradient', main=sprintf('d W_loc (layer %d)', n),asp=1)
abline(coef=c(0,1),col='red')
}
for(n in 1:3){
plot(gradientsBayesian_autodiff[[n*2]]$numpy(), gradientsBayesian[[n*2]]$numpy(), xlab='autodiff', ylab='custom_gradient', main=sprintf('d W_scale (layer %d)', n),asp=1)
abline(coef=c(0,1),col='red')
}
for(n in 1:3){
image(x = seq(dim(gradientsBayesian[[n*2-1]])[1]*2), y=seq(dim(gradientsBayesian[[n*2-1]])[2]), z = rbind(gradientsBayesian[[n*2-1]]$numpy(), gradientsBayesian_autodiff[[n*2-1]]$numpy()), main=sprintf('d W_loc (layer %d)',n), xlab='1:dim(input)', ylab='1:dim(output)', xaxt='n')
abline(v=dim(gradientsBayesian[[n*2-1]])[1]+.5,col='black')
axis(side=1, at = seq(dim(gradientsBayesian[[n*2-1]])[1]*2), label=rep(seq(dim(gradientsBayesian[[n*2-1]])[1]), times=2))
}
for(n in 1:3){
image(x = seq(dim(gradientsBayesian[[n*2]])[1]*2), y=seq(dim(gradientsBayesian[[n*2]])[2]), z = rbind(gradientsBayesian[[n*2]]$numpy(), gradientsBayesian_autodiff[[n*2]]$numpy()), main=sprintf('d W_scale (layer %d)',n), xlab='1:dim(input)', ylab='1:dim(output)', xaxt='n')
abline(v=dim(gradientsBayesian[[n*2]])[1]+.5,col='black')
axis(side=1, at = seq(dim(gradientsBayesian[[n*2]])[1]*2), label=rep(seq(dim(gradientsBayesian[[n*2]])[1]), times=2))
}
par(par.opt)
options(repr.plot.opt)
While we can obtain the gradient with respect to the parameters of the weight variational posterior from the smoothing distribution of the weights (treating the weights as latent features), in our layer_dense_Bayesian we compute the gradient directly from the smoothing densities of the latent features, in the hope of providing some deeper insight into back-propagating sensitivity with respect to functions versus sensitivity with respect to distributions.
nce = - tf$reduce_mean(CE.mvnorm(h.update - h - tf$einsum('bij,bj->bi',dhdx, x.update-x),
tf$einsum('bip,bpq,bjq->bij',dhdx, P0.update, dhdx) + P.update - tf$einsum('bim,bqm,bjq->bij', P.update,G0,dhdx) - tf$einsum('bim,bqm,bjq->bji', P.update,G0,dhdx),
tf$zeros_like(h),
tf$matmul(dhdW, dhdW, transpose_b=TRUE) + tf$linalg$diag(1e-3*tf$ones_like(P[1,1,]))))
})
# gradient wrt weight distributions parameters is `tape$gradient(nce, variables)`
The derivation is below, where we treat a deep neural network as a Gaussian linear process, and use Kalman filter/smoother to backpropagate sensitivity over filter distributions.
- $x_t = F_t x_{t-1} + w_t$, where $w_t\sim {\cal N}(\vec 0, Q_t)$
- $z_t = H_t x_t + v_t$, where $v_t\sim {\cal N}(\vec 0, R_t)$
The lower bound of the (latent-variable) log likelihood is the negative cross entropy plus the entropy under some proposal distribution: $\log p(y) = \log \sum_x p(x,y) \ge \mathbf E_{q(x)} \log p(x,y) - \mathbf E_{q(x)} \log q(x)$. The gap is $D_{\mathrm{KL}}(q(x)\,\|\,p(x|y))$, so the bound is tight: equality holds when $q(x)=p(x|y)$.
The cross-entropy between two probability distributions $p$ and $q$ measures the average number of bits needed to identify an event drawn from the true distribution $p$ with the coding scheme optimized for an estimated probability distribution $q$. $H(p, q) = - \mathbf E_p(\log q) = H(p)+D_{\mathrm{KL}}(p \| q)$. The K-L divergence between two multivariate normal distributions is $D_{\mathrm{KL}}\left(\mathcal{N}(\mu_{0},\Sigma_0) \| \mathcal{N}(\mu_{1},\Sigma_1)\right)=\frac{1}{2}\left\{\operatorname{tr}\left(\boldsymbol{\Sigma}_{1}^{-1} \boldsymbol{\Sigma}_{0}\right)+\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)^{\mathrm{T}} \boldsymbol{\Sigma}_{1}^{-1}\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)-k+\ln \frac{\left|\boldsymbol{\Sigma}_{1}\right|}{\left|\boldsymbol{\Sigma}_{0}\right|}\right\}$. The entropy of a multivariate normal distribution is $\frac{1}{2} \ln \operatorname{det}(2 \pi e \Sigma)$. So
$$\begin{aligned} H\left(\mathcal{N}(\mu_{0},\Sigma_0), \mathcal{N}(\mu_{1},\Sigma_1)\right) = & \frac{1}{2}\left\{\operatorname{tr}\left(\boldsymbol{\Sigma}_{1}^{-1} \boldsymbol{\Sigma}_{0}\right)+\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)^{\mathrm{T}} \boldsymbol{\Sigma}_{1}^{-1}\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)-k+\ln \frac{\operatorname{det}(\boldsymbol{\Sigma}_{1})}{\operatorname{det}(\boldsymbol{\Sigma}_{0})} + \ln \operatorname{det}(2 \pi e \Sigma_0) \right\}\\ = & \frac{1}{2}\left\{\operatorname{tr}\left(\boldsymbol{\Sigma}_{1}^{-1} \boldsymbol{\Sigma}_{0}\right)+\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right)^{\mathrm{T}} \boldsymbol{\Sigma}_{1}^{-1}\left(\boldsymbol{\mu}_{1}-\boldsymbol{\mu}_{0}\right) + \ln \operatorname{det}(2\pi \boldsymbol{\Sigma}_{1}) \right\} \end{aligned}$$
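As a quick numerical sanity check of the cross-entropy formula above (a sketch assuming tf is loaded and CE.mvnorm from earlier is in scope; the variable names below are illustrative only, and recall that CE.mvnorm returns the negative cross entropy, batched over the first axis):
m0 = tf$constant(matrix(c(0, 0), nrow=1), dtype=tf$float32)
m1 = tf$constant(matrix(c(1, -1), nrow=1), dtype=tf$float32)
Q0 = tf$constant(array(diag(c(1, 2)), dim=c(1, 2, 2)), dtype=tf$float32)
Q1 = tf$constant(array(diag(c(2, 1)), dim=c(1, 2, 2)), dtype=tf$float32)
# closed form: .5 * ( tr(Q1^-1 Q0) + (m1-m0)' Q1^-1 (m1-m0) + log det(2*pi*Q1) )
H.closed = .5 * (sum(diag(solve(diag(c(2, 1))) %*% diag(c(1, 2)))) + t(c(1, -1)) %*% solve(diag(c(2, 1))) %*% c(1, -1) + log(det(2*pi*diag(c(2, 1)))))
H.tf = - CE.mvnorm(m0, Q0, m1, Q1) # should match H.closed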
Expected log likelihood of $P(X_{1,\dots,T},Z_{1,\dots,T})$ over posterior distribution $P(X_{1,\dots,T} | Z_{1,\dots,T})$ is
$$\begin{aligned} & \mathbf{E}_{p(X_{1,\dots,T}|Z_{1,\dots,T})}\log P(X_{1,\dots,T},Z_{1,\dots,T}) = \log P(X_{1}) + \sum_{t=2}^{T} \log P(X_{t}|X_{t-1}) + \sum_{t=1}^{T} \log P(Z_{t}|X_{t})\\ = & {-\frac{1}{2}}\log\det(2\pi\Sigma_{0})-{1\over 2}\mathbf E_{p(x_{1}|z_{1,\dots,T})}\left(X_{1}-\mu_{0}\right)^{\text{T}}\Sigma_{0}^{-1}\left(X_{1}-\mu_{0}\right)+\\ & \sum_{t=2}^{T}{-\frac{1}{2}}\log\det(2\pi Q_t)-{1\over 2}\mathbf E_{p(x_{t-1,t}|z_{1,\dots,T})}\left(X_{t}-F_t\cdot X_{t-1}\right)^{\text{T}}Q^{-1}_t\left(X_{t}-F_t\cdot X_{t-1}\right)+\\ & \sum_{t=1}^{T}{-\frac{1}{2}}\log\det(2\pi R_t)-{1\over 2}\mathbf E_{p(x_{t}|z_{1,\dots,T})}\left(Z_{t}-H_t\cdot X_{t}\right)^{\text{T}}R^{-1}_t\left(Z_{t}-H_t\cdot X_{t}\right) \\ = & - H\left(\mathcal N(\hat x_{1|T}, P_{1|T}), \mathcal N(\mu_0, \Sigma_0)\right) + \\ & - H\left(\mathcal N\left(\hat x_{t|T} - F_t \hat x_{t-1|T}, \left(\begin{matrix} -F_t & I\end{matrix}\right) P_{t-1,t|T} \left(\begin{matrix}-F_t\\I\end{matrix}\right)\right), \mathcal N(X_t - F_t X_{t-1}; 0, Q_t)\right) + \\ & - H\left(\mathcal N(Z_t - H_t \hat x_{t|T}, H_t P_{t|T} H_t^\top), \mathcal N(Z_t - H_t X_t; 0, R_t)\right), \end{aligned}$$
where $P_{t,t+1|T} = \left(\begin{array}{c|c} P_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(P_{t+1|T}-P_{t+1|t}\right)P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}P_{t+1|T}\\ P_{t+1|T}P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t+1|T} \end{array}\right) = \left(\begin{array}{c|c} P_{t|T} & G_{t}P_{t+1|T}\\ P_{t+1|T}G_{t}^{\top} & P_{t+1|T} \end{array}\right)$
and hence
$\left(\begin{matrix}-F_{t+1} & I\end{matrix}\right)P_{t,t+1|T}\left(\begin{matrix}-F_{t+1}^{\top}\\ I \end{matrix}\right)=\left(\begin{matrix}-F_{t+1} & I\end{matrix}\right)\left(\begin{array}{c|c} P_{t|T} & G_{t}P_{t+1|T}\\ P_{t+1|T}G_{t}^{\top} & P_{t+1|T} \end{array}\right)\left(\begin{matrix}-F_{t+1}^{\top}\\ I \end{matrix}\right) = F_{t+1}P_{t|T}F_{t+1}^{\top} -F_{t+1}G_{t}P_{t+1|T} -P_{t+1|T}G_{t}^{\top}F_{t+1}^{\top} + P_{t+1|T}$
Derivation: the Kalman smoother gives $X_{t|T}=\hat{x}_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(x_{t+1|T}-F_{t+1}\hat{x}_{t|t}\right)+\epsilon$, where $\left(x_{t+1|T}-F_{t+1}\hat{x}_{t|t}\right)$ is the innovation brought forth by $x_{t+1|T}$, $P_{t+1|t}=F_{t+1}P_{t|t}F_{t+1}^{\top}+Q_{t+1}$ is the covariance of the innovation, $P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}$ is the smoothing gain, $\epsilon\sim\mathcal{N}(0,P_{t|t}-P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}F_{t+1}P_{t|t})$ is the noise term used to sample $X_{t|T}$ conditional on $X_{t+1|T}$, as in $p(X_t|x_{t+1},y_1,\dots,y_T)$, and $X_{t+1|T}\sim\mathcal{N}(\hat{x}_{t+1|T},P_{t+1|T})$. The joint distribution is
$$\begin{aligned} \left(\begin{array}{c} \mathbf{x}_{t|T}\\ \mathbf{x}_{t+1|T} \end{array}\right) & \sim\mathcal{N}\left(\left(\begin{array}{c} \hat{\mathbf{x}}_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(\hat{\mathbf{x}}_{t+1|T}-\hat{\mathbf{x}}_{t+1|t}\right)\\ \hat{\mathbf{x}}_{t+1|T} \end{array}\right),\left(\begin{array}{c|c} P_{t|t}+P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}\left(P_{t+1|T}-P_{t+1|t}\right)P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t|t}F_{t+1}^{\top}P_{t+1|t}^{-1}P_{t+1|T}\\ P_{t+1|T}P_{t+1|t}^{-1}F_{t+1}P_{t|t} & P_{t+1|T} \end{array}\right)\right),\\ & \text{where }P_{t+1|t}=F_{t+1}P_{t|t}F_{t+1}^{\top}+Q_{t+1},\hat{\mathbf{x}}_{t+1|t}=F_{t+1}\hat{\mathbf{x}}_{t|t}. \end{aligned}$$
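The following toy sketch (base R only; scalars promoted to 1x1 matrices, and all names such as F.t1 and P.tt are illustrative rather than taken from the code above) evaluates the smoother gain and the joint smoothing covariance using the block formulas just derived:
F.t1 = matrix(0.9); Q.t1 = matrix(0.1) # state transition and process noise
P.tt = matrix(0.5) # filtered covariance P_{t|t}
P.t1t = F.t1 %*% P.tt %*% t(F.t1) + Q.t1 # predicted covariance P_{t+1|t}
P.t1T = matrix(0.3) # smoothed covariance P_{t+1|T}, assumed given
G.t = P.tt %*% t(F.t1) %*% solve(P.t1t) # smoother gain P_{t|t} F' P_{t+1|t}^{-1}
P.tT = P.tt + G.t %*% (P.t1T - P.t1t) %*% t(G.t) # smoothed covariance P_{t|T}
# joint smoothing covariance of (x_t, x_{t+1}); compare with the block matrix above
P.joint = rbind(cbind(P.tT, G.t %*% P.t1T), cbind(P.t1T %*% t(G.t), P.t1T))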
modelBayesian2 = keras_model(inputs = inputsBayesian,
outputs = c(list(inputsBayesian), lapply(modelBayesian$layers[-2:-1], function(n) n$output)))
with(tf$GradientTape() %as% tape, {
#tape$watch(inputsBayesian$)
y_predBayesian = modelBayesian2$call(list(x,P0))
tape$watch(y_predBayesian[1:4])
lossBayesian = tf$reduce_mean(mylossBayesian(y, y_predBayesian[[5]]))
})
sensitivityBayesian = tape$gradient(lossBayesian, y_predBayesian[1:4])
modelBayesian_autodiff2 = keras_model(inputs = inputsBayesian,
outputs = c(list(inputsBayesian), lapply(modelBayesian_autodiff$layers[-2:-1], function(n) n$output)))
with(tf$GradientTape() %as% tape, {
tape$watch(P0)
y_predBayesian_autodiff = modelBayesian_autodiff2$call(list(x,P0))
tape$watch(y_predBayesian_autodiff[1:4])
lossBayesian_autodiff = tf$reduce_mean(mylossBayesian(y, y_predBayesian_autodiff[[5]]))
})
sensitivityBayesian_autodiff = tape$gradient(lossBayesian_autodiff, y_predBayesian_autodiff[1:4])
x.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[1]] + tf$einsum('bij,bi->bj', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[1]]))
P.update = lapply(1:4, function(n) y_predBayesian_autodiff[[n]][[2]] -
tf$einsum('bij,bjk,blk->bil', y_predBayesian_autodiff[[n]][[2]], sensitivityBayesian_autodiff[[n]][[2]], y_predBayesian_autodiff[[n]][[2]]) -
tf$einsum('bi,bj->bij', x.update[[n]] - y_predBayesian_autodiff[[n]][[1]], x.update[[n]] - y_predBayesian_autodiff[[n]][[1]]))
tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... trace CategoricalObsModel forward .... trace CategoricalObsModel backwards .... tracking layer_dense_Bayesian backward .... tracking layer_dense_Bayesian backward .... tracking layer_dense_Bayesian backward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... tracking layer_dense_Bayesian forward .... trace CategoricalObsModel forward ....
The following shows that whether we use TensorFlow's autodiff or back-propagate smoothing densities, we obtain the same latent-feature smoothing densities. Because there is a one-to-one correspondence between smoothing densities and the sensitivities of the latent-feature filter density parameters, message passing and back propagation give the same sensitivities.
require(repr)
repr.plot.opt = options(repr.plot.width=16, repr.plot.height=4)
par.opt = par(mfcol=c(1,4))
for(n in 1:4){
plot(x.update[[n]]$numpy(), sensitivityBayesian[[n]][[1]]$numpy(), xlab='autodiff', ylab='custom_gradient', main='smoothing mean',asp=1)
abline(coef=c(0,1),col='red')
}
for(n in 1:4){
plot(P.update[[n]]$numpy(), sensitivityBayesian[[n]][[2]]$numpy(), xlab='autodiff', ylab='custom_gradient', main='smoothing var',asp=1)
abline(coef=c(0,1),col='red')
}
par(par.opt)
options(repr.plot.opt)
Loading required package: repr
rm(list=ls())
Duality between Langevin dynamics and Fokker-Planck dynamics¶
In this section, we show that we can forward-propagate the filter densities of post-activations and back-propagate the smoothing densities in a Bayesian neural network, and that we can also forward-propagate tensor post-activation "particles" through an ensemble of vanilla (non-Bayesian) neural networks sampled from the BNN and back-propagate the sensitivities of these tensor particles. The two approximate methods should ultimately agree on the gradient of the loss with respect to the weight parameters as the variational bias and the Monte Carlo variance go to zero.
We also show that the smoothing densities contain information about the sensitivities of the tensor post-activation particles --- how they should move to decrease the loss. In other words, we can estimate the sensitivities of the post-activations of a vanilla NN from the smoothing densities of a Bayesian NN.
We use this example for two purposes:
- to sanity-check Theorem 2 in the paper, and
- to illustrate that the power of a deep neural network lies in the randomness in the synaptic weights. The information between input and output rides on the random synaptic weights.
require(tensorflow)
require(keras)
require(reticulate)
require(tfprobability)
require(coro)
mnist = dataset_mnist()
n = 32L
indices = 1:n
x = tf$constant(matrix(mnist$train$x[indices,,]/255.0, nrow = n), dtype=tf$float32)
y = tf$one_hot(mnist$train$y[indices], depth = 10L)
P0 = tf$einsum('b,ij->bij', tf$ones_like(x[,1]), diag(rep(1e-4, 784)))
#P0 = tf$einsum('b,ij->bij', tf$ones_like(x[,1]), diag(rep(0, 784)))
Loading required package: tensorflow Loading required package: keras Loading required package: reticulate Loading required package: tfprobability Loading required package: coro Attaching package: ‘coro’ The following object is masked from ‘package:reticulate’: as_iterator
layer_dense_Bayesian = Layer(
classname = 'BayesianDense',
inherit = tf$keras$layers$Dense,
initialize = function(units = 7L, activation, name=NULL) {
super$initialize(units=units, activation=activation, use_bias=TRUE, name=name)
self$input_spec = list(tf$keras$layers$InputSpec(min_ndim = 2L),tf$keras$layers$InputSpec(min_ndim = 3L))
},
build = function(input_shape) {
c(shp.mean, shp.var) %<-% input_shape
last_dim = shp.mean[[-1]]
self$A = self$add_weight(name='A',shape=list(last_dim, self$units), initializer='normal', trainable=TRUE)
self$B = self$add_weight(name='B',shape=list(last_dim, self$units), initializer=tf$initializers$RandomUniform(minval = -.01, maxval = .01), trainable=TRUE)
self$W = tf$zeros_like(self$B[,1])
self$built=TRUE
NULL
},
call = function(inputs) {
c(x, P0) %<-% inputs
with(tf$GradientTape(persistent=TRUE) %as% g, {
self$bias = tf$zeros_like(self$B[1,])
g$watch(list(x, self$W, self$bias))
self$kernel = self$A + tf$einsum('ij,i->ij',self$B, self$W) # mvnorm with mean and variance
h = super$call(x)
`__h__a` = super$call(x)
`__h__b` = super$call(tf$stop_gradient(x))
})
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
dhdW = g$jacobian(`__h__b`, self$W)
dhdx = tf$stop_gradient(g$batch_jacobian(`__h__a`, x))
dhdB = tf$stop_gradient(g$jacobian(`__h__a`, self$bias))
P = tf$matmul(tf$matmul(dhdx, P0), dhdx, transpose_b=TRUE) + tf$matmul(dhdW, dhdW, transpose_b=TRUE) + tf$linalg$diag(tf$einsum('bij,j->bi',dhdB^2, 1e-4*tf$ones_like(self$bias)))
list(h, P)
}
)
CategoricalObsModel_autodiff = PyClass(classname = 'CategoricalObs', defs = list(
inputs = NULL,
shape = NULL,
dtype=tf$float32,
`__log.prob__` = NULL,
`__init__` = function(self, inputs) {
self$inputs = inputs
self$shape = inputs[[1]]$shape
self$`__log.prob__` = function(h, P, y){
message('trace CategoricalObsModel forward ....')
inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
p = tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
dydx = tf$stop_gradient(tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p))
S = tf$einsum('bip,bpq,bjq->bij',dydx, P, dydx) + tf$einsum('b,ij->bij',tf$ones_like(inv.Z), diag(rep(.14,10))) #+ tf$linalg$diag( (tf$sign(.2 - tf$abs(y-p))+1)/2*.14 ) + diag(rep(.01, 10))# + diag(rep(.14,10))
log.prob = tf_probability()$distributions$MultivariateNormalFullCovariance( loc=p, covariance_matrix= S )$log_prob( y ) #
log.prob
}
NULL
},
`value` = function(self){
h = tf_probability()$distributions$MultivariateNormalFullCovariance(loc=self$inputs[[1]], covariance_matrix=self$inputs[[2]])$sample()
inv.Z = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), tf$einsum('bi->b',tf$math$exp(h)))
tf$einsum('bi,b->bi', tf$math$exp(h), inv.Z) # prediction y
},
loss = function(self, y){
- self$`__log.prob__`(self$inputs[[1]], self$inputs[[2]], y)
}
))
tf$register_tensor_conversion_function(CategoricalObsModel_autodiff, function(x,...)x$`value`())
tf$keras$`__internal__`$utils$register_symbolic_tensor_type(CategoricalObsModel_autodiff)
inputsBayesian <- list(layer_input(shape = 784L, name='x0'), layer_input(shape=tuple(784L, 784L), name='P0'))
predictionsBayesian <- inputsBayesian %>%
layer_dense_Bayesian(units=7L, activation = tf$nn$sigmoid, name='bayesian_dense_1') %>%
layer_dense_Bayesian(units=7L, activation = tf$nn$sigmoid, name='bayesian_dense_2') %>%
layer_dense_Bayesian(units=10L, activation = tf$identity, name='bayesian_dense_3') %>%
layer_lambda(function(inputs) CategoricalObsModel_autodiff(inputs), name='y_pred')
predictionsBayesian$`_type_spec` = tf$TensorSpec(shape=predictionsBayesian$type_spec$shape, dtype=predictionsBayesian$type_spec$dtype)
modelBayesian = keras_model(inputs = inputsBayesian, outputs = predictionsBayesian)
mylossBayesian = function(y, y_hat){
y_hat$loss(y)
}
modelBayesian %>% compile(loss=mylossBayesian, optimizer='adam', metrics='accuracy')
#invisible(lapply(grep("bayesian_dense_[0-9]+/B:[0-9]+", sapply(modelBayesian$weights, function(n) n$name)), function(n){
# modelBayesian$weights[[n]]$assign( modelBayesian$weights[[n]]/tf$reduce_max(tf$abs(modelBayesian$weights[[n]]))*.05 )
#}))
modelBayesian2 = keras_model(inputs = inputsBayesian, outputs = c(list(inputsBayesian), lapply(modelBayesian$layers[-2:-1], function(n) n$output)))
modelBayesian2 %>% compile(loss=mylossBayesian, optimizer='adam', metrics='accuracy')
with(tf$GradientTape() %as% tape, {
tape$watch(list(x, P0))
y_predBayesian = modelBayesian2$call(list(x,P0))
tape$watch(y_predBayesian[1:4])
lossBayesian = y_predBayesian[[5]]$loss(y)
})
sensitivityBayesian = tape$gradient(lossBayesian, y_predBayesian[1:4])
str(y_predBayesian)
str(sensitivityBayesian)
modelBayesian2
trace CategoricalObsModel forward ....
List of 5 $ :List of 2 ..$ :<tf.Tensor: shape=(32, 784), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 784, 784), dtype=float32, numpy=…> $ :List of 2 ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…> $ :List of 2 ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…> $ :List of 2 ..$ :<tf.Tensor: shape=(32, 10), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 10, 10), dtype=float32, numpy=…> $ :<CategoricalObs object at 0x7f37d04a1580> List of 4 $ :List of 2 ..$ :<tf.Tensor: shape=(32, 784), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 784, 784), dtype=float32, numpy=…> $ :List of 2 ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…> $ :List of 2 ..$ :<tf.Tensor: shape=(32, 7), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 7, 7), dtype=float32, numpy=…> $ :List of 2 ..$ :<tf.Tensor: shape=(32, 10), dtype=float32, numpy=…> ..$ :<tf.Tensor: shape=(32, 10, 10), dtype=float32, numpy=…>
Model: "model_1" ________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================ x0 (InputLayer) [(None, 784)] 0 [] P0 (InputLayer) [(None, 784, 784 0 [] )] bayesian_dense_1 (Bayesi [(None, 7), 10976 ['x0[0][0]', anDense) (None, 7, 7)] 'P0[0][0]'] bayesian_dense_2 (Bayesi [(None, 7), 98 ['bayesian_dense_1[0][0]', anDense) (None, 7, 7)] 'bayesian_dense_1[0][1]'] bayesian_dense_3 (Bayesi [(None, 10), 140 ['bayesian_dense_2[0][0]', anDense) (None, 10, 10)] 'bayesian_dense_2[0][1]'] y_pred (Lambda) (None, 10) 0 ['bayesian_dense_3[0][0]', 'bayesian_dense_3[0][1]'] ================================================================================ Total params: 11,214 Trainable params: 11,214 Non-trainable params: 0 ________________________________________________________________________________
In the following, we sample an ensemble of 64 vanilla (non-Bayesian) neural networks from the Bayesian neural network.
layer_dense_linearization = Layer(
classname = 'DenseLinearization',
inherit = tf$keras$layers$Layer,
initialize = function(activation=NULL, ...) {
super$initialize()
self$dense = tf$keras$layers$Dense(activation=NULL, ...)
self$activation = keras$activations$get(activation)
py_set_attr(self, '_name', self$dense$name)
},
build = function(input_shape) {
self$dense$build(input_shape)
self$built=TRUE
},
call = function(inputs) {
with(tf$GradientTape(persistent = TRUE) %as% tape, {
tape$watch(inputs)
h = self$dense$call(inputs)
y0 = self$activation(h)
})
# linearize the activation around h: the forward value equals y0, while the
# gradient flows through the first-order Taylor expansion of the activation
y = tf$stop_gradient(y0) + tf$stop_gradient(tape$gradient(y0, h)) * (h - tf$stop_gradient(h))
}
)
### sample vanilla NN from Bayesian NN, assuming weights are independent
inputs = layer_input(shape = 784L, name='input')
predictions = lapply(1:64, function(i){
inputs %>%
layer_lambda(function(x) tfp$distributions$MultivariateNormalDiag(loc = x, scale_identity_multiplier = 1e-2 )$sample(), name = sprintf('sample_%01d_0',i) ) %>%
layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE, name = sprintf('dense_%01d_1',i)) %>%
layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE, name = sprintf('dense_%01d_2',i)) %>%
layer_dense_linearization(units=10, activation = NULL, use_bias = TRUE, name = sprintf('dense_%01d_3',i)) #%>%
#layer_activation(activation = 'softmax', name=sprintf('softmax_%01d_3',i))
})
model = keras_model(inputs = inputs, outputs = predictions)
myloss = function(y, y_hat){
p = tf$math$softmax(y_hat, axis=-1L)
dydx = tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p)
y_pred = tf$stop_gradient(p) + tf$einsum('bij,bi->bj', tf$stop_gradient(dydx), y_hat - tf$stop_gradient(y_hat))
- tf_probability()$distributions$MultivariateNormalDiag(loc=y_pred, scale_diag = tf$ones_like(y)*.14^.5)$log_prob(y)
}
model %>% compile(loss=myloss, optimizer=modelBayesian$optimizer, metrics='accuracy')
invisible(lapply(1:3, function(j){
lapply(1:length(predictions), function(i){
A = modelBayesian$get_layer(name = sprintf('bayesian_dense_%d',j))$weights[[1]]
B = modelBayesian$get_layer(name = sprintf('bayesian_dense_%d',j))$weights[[2]]
W = tf_probability()$distributions$MultivariateNormalDiag(loc=tf$zeros_like(B[,1]), scale_diag=tf$ones_like(B[,1]))$sample()
model$get_layer(name = sprintf('dense_%01d_%01d',i,j))$weights[[1]]$assign(A + tf$einsum('ij,i->ij',B, W))
# do not forget noise added to the latent layer, my bias is applied after activation, i.e., at the input side
bias = tf_probability()$distributions$MultivariateNormalDiag(loc=tf$zeros_like(B[1,]), scale_diag=tf$ones_like(B[1,]))$sample() * 1e-2
model$get_layer(name = sprintf('dense_%01d_%01d',i,j))$weights[[2]]$assign(bias)
})
}))
### find latent state and sensitivities
model2 = keras_model(
inputs = inputs,
outputs = c(x = model$get_layer('input')$output, split(lapply(model$layers[-1], function(l) l$output), gsub('(dense|sample)_[0-9]+_([0-9])', '\\2', sapply(model$layers[-1], function(n) n$name)))))
# `tf.keras.Model.output` gives structured model output, while `tf.keras.Model.outputs` gives a list of output tensors
# model2$output
with(tf$GradientTape(persistent = TRUE) %as% t2, {
with(tf$GradientTape() %as% tape, {
tape$watch(x)
ret = model2$call(x)
t2$watch(ret)
tape$watch(ret)
`__loss__` = sapply(1:length(ret[[5]]), function(i) {
myloss(y, ret[[5]][[i]])
})
})
sensitivity = tape$gradient(`__loss__`, ret)
})
#plot(tf$reduce_mean(`__loss__`, axis=0L), lossBayesian)
hessian = outer(2:length(ret), 1:length(ret[[2]]), Vectorize(function(m,n) t2$batch_jacobian(sensitivity[[m]][[n]], ret[[m]][[n]]) ))
x.filter.vi = sapply(1L:4L, function(i){ y_predBayesian[[i]][[1]] })
P.filter.vi = sapply(1L:4L, function(i){ y_predBayesian[[i]][[2]] })
x.sensitivity.vi = sapply(1L:4L, function(i){ sensitivityBayesian[[i]][[1]] })
P.sensitivity.vi = sapply(1L:4L, function(i){ sensitivityBayesian[[i]][[2]] })
x.smoothing.vi = sapply(1:4, function(i){
y_predBayesian[[i]][[1]] +tf$einsum('bim,bm->bi', y_predBayesian[[i]][[2]], sensitivityBayesian[[i]][[1]])
})
P.smoothing.vi = sapply(1:4, function(i){
y_predBayesian[[i]][[2]] +
tf$einsum('bij,bjk,blk->bil', y_predBayesian[[i]][[2]], sensitivityBayesian[[i]][[2]], y_predBayesian[[i]][[2]])*2 -
tf$einsum('bim,bm,bn,bjn->bij', y_predBayesian[[i]][[2]], sensitivityBayesian[[i]][[1]], sensitivityBayesian[[i]][[1]], y_predBayesian[[i]][[2]])
})
# Hessian
H.vi = lapply(1:4, function(i){
- tf$linalg$pinv(P.filter.vi[[i]] + tf$linalg$pinv(- sensitivityBayesian[[i]][[2]]*2 - tf$einsum('bm,bn->bmn', sensitivityBayesian[[i]][[1]], sensitivityBayesian[[i]][[1]]) ))
})
# Fisher divergence, modulo a data-dependent term
FD.vi = -tf$linalg$diag_part(sensitivityBayesian[[1]][[2]]) - .5 * tf$linalg$diag_part(H.vi[[1]])
x.filter.mc = sapply(1:4, function(i){
tf$reduce_mean(tf$stack(ret[[i+1]]), axis=0L)
})
P.filter.mc = sapply(1:4, function(i){
`__x__` = tf$stack(ret[[i+1]])
`__2nd_moment__` = tf$einsum('lbi,lbj->bij', `__x__`, `__x__`)/dim(`__x__`)[1]
`__1st_moment__` = tf$reduce_mean(`__x__`, axis=0L)
`__2nd_moment__` - tf$einsum('bi,bj->bij',`__1st_moment__`,`__1st_moment__`)
})
posterior = tf$math$divide_no_nan(tf$einsum('lbi,bi->lb', tf$nn$softmax(tf$stack(ret[[5]]), axis=-1L), y), tf$einsum('lbi,bi->b', tf$nn$softmax(tf$stack(ret[[5]]), axis=-1L), y))
## it seems that the posterior estimated in this way has higher variance than unweighted average
x.sensitivity.mc = sapply(1:4, function(i){
tf$einsum('lbi,lb->bi',tf$stack(sensitivity[[i+1]]), posterior)
})
# Hessian Bayesian
H.mc = lapply(1:4, function(i){
tf$reduce_mean(tf$stack(hessian[i,]), axis=0L)
})
P.sensitivity.mc = sapply(1:4, function(i){
A = tf$einsum('bsi,bsj,sb->bij',tf$stack(sensitivity[[i+1]],axis=1L), tf$stack(sensitivity[[i+1]],axis=1L), posterior)
H = H.mc[[i]]
(H - A) * .5
})
# Fisher divergence, modulo a data-dependent term
FD.mc = .5 * tf$reduce_mean(tf$stack(sensitivity[[2]], axis=0L)^2, axis=0L) - tf$reduce_mean(tf$stack(lapply(hessian[1,], tf$linalg$diag_part), axis=0L), axis=0L)
The following shows that the smoothing densities of the post-activations (in the Bayesian NN) can be estimated from the sensitivities of the post-activations (in the ensemble of 64 vanilla, non-Bayesian neural networks). Monte Carlo estimation is quite noisy, though, especially for the Hessian, even in this tiny toy example.
require(repr)
repr.plot.opt = options(repr.plot.width=17, repr.plot.height=4.5)
par.opt = par(mfcol=c(1,4))
plot(x.filter.vi[[1]]$numpy(), x.filter.mc[[1]]$numpy(), xlab='VI', ylab='MC', main = 'forward x[1]')
abline(coef=c(0,1),col='red')
plot(x.filter.vi[[2]]$numpy(), x.filter.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'forward x[2]')
abline(coef=c(0,1),col='red')
plot(x.filter.vi[[3]]$numpy(), x.filter.mc[[3]]$numpy(), xlab='VI', ylab='MC', main = 'forward x[3]')
abline(coef=c(0,1),col='red')
plot(x.filter.vi[[4]]$numpy(), x.filter.mc[[4]]$numpy(), xlab='VI', ylab='MC', main = 'forward x[4]')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[1]][1:8,,]$numpy(), P.filter.mc[[1]][1:8,,]$numpy(), xlab='VI', ylab='MC', main = 'forward P[1]')
points(tf$linalg$diag_part(P.filter.vi[[1]])$numpy(), tf$linalg$diag_part(P.filter.mc[[1]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[2]]$numpy(), P.filter.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'forward P[2]')
points(tf$linalg$diag_part(P.filter.vi[[2]])$numpy(), tf$linalg$diag_part(P.filter.mc[[2]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[3]]$numpy(), P.filter.mc[[3]]$numpy(), xlab='VI', ylab='MC', main = 'forward P[3]')
points(tf$linalg$diag_part(P.filter.vi[[3]])$numpy(), tf$linalg$diag_part(P.filter.mc[[3]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(P.filter.vi[[4]]$numpy(), P.filter.mc[[4]]$numpy(), xlab='VI', ylab='MC', main = 'forward P[4]')
points(tf$linalg$diag_part(P.filter.vi[[4]])$numpy(), tf$linalg$diag_part(P.filter.mc[[4]])$numpy(),col='red')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[1]]$numpy(), x.sensitivity.mc[[1]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[1]')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[2]]$numpy(), x.sensitivity.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[2]')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[3]]$numpy(), x.sensitivity.mc[[3]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[3]')
abline(coef=c(0,1),col='red')
plot(x.sensitivity.vi[[4]]$numpy(), x.sensitivity.mc[[4]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward x[4]')
abline(coef=c(0,1),col='red')
local({
ndx = sample(prod(dim(P.sensitivity.mc[[1]])), 4096)
plot(P.sensitivity.vi[[1]]$numpy()[ndx], P.sensitivity.mc[[1]]$numpy()[ndx], asp=1, xlab='VI', ylab='MC', main = 'backward P[1]')
points(tf$linalg$diag_part(P.sensitivity.vi[[1]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[1]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
})
plot(P.sensitivity.vi[[2]]$numpy(), P.sensitivity.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward P[2]')
points(tf$linalg$diag_part(P.sensitivity.vi[[2]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[2]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
plot(P.sensitivity.vi[[3]]$numpy(), P.sensitivity.mc[[3]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward P[3]')
points(tf$linalg$diag_part(P.sensitivity.vi[[3]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[3]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
plot(P.sensitivity.vi[[4]]$numpy(), P.sensitivity.mc[[4]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'backward P[4]')
points(tf$linalg$diag_part(P.sensitivity.vi[[4]])$numpy(), tf$linalg$diag_part(P.sensitivity.mc[[4]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
local({
ndx = sample(prod(dim(H.mc[[1]])), 4096)
plot(H.vi[[1]]$numpy()[ndx], H.mc[[1]]$numpy()[ndx], asp=1, xlab='VI', ylab='MC', main = 'Hessian [1]')
points(tf$linalg$diag_part(H.vi[[1]])$numpy(), tf$linalg$diag_part(H.mc[[1]])$numpy(), col='red')
abline(coef=c(0,1), col='red')
})
plot(H.vi[[2]]$numpy(), H.mc[[2]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'Hessian [2]')
abline(coef=c(0,1),col='red')
plot(H.vi[[3]]$numpy(), H.mc[[3]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'Hessian [3]')
abline(coef=c(0,1),col='red')
plot(H.vi[[4]]$numpy(), H.mc[[4]]$numpy(), asp=1, xlab='VI', ylab='MC', main = 'Hessian [4]')
abline(coef=c(0,1),col='red')
par(par.opt)
options(repr.plot.opt)
Loading required package: repr
The following shows that we can compute particle sensitivities from the (Bayesian) sensitivities with respect to the particle distribution parameters obtained in the forward pass. Computing particle sensitivities through the inverses of the smoothing and filter densities calculated from the Bayesian sensitivities is not numerically stable; below is a more numerically stable way to compute the sensitivity of the loss with respect to the particles from the sensitivity of the loss with respect to the forward distribution parameters of the particles.
$\require{cancel}$
\begin{align*} \hat{\mathbf{x}}_{l|L}= & \hat{\mathbf{x}}_{l}+P_{l}\left(\nabla_{\hat{\mathbf{x}}_{l}}p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\\ P_{l|L}= & P_{l}+P_{l}\left[2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right]P_{l}\\ P_{l|L}^{-1}= & P_{l}^{-1}-\cancel{P_{l}^{-1}P_{l}}\left[\left(2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right)^{-1}+P_{l}P_{l}^{-1}P_{l}\right]^{-1}\cancel{P_{l}P_{l}^{-1}}\\ = & p_{l}^{-1}-\left[\left(2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right)^{-1}+P_{l}P_{l}^{-1}P_{l}\right]^{-1}\\ \nabla_{\mathbf{x}_{l}}\log p(\mathbf{y}|\mathbf{x}_{l},\mathbf{x}_{0}=\mathbf{x})= & P_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)-P_{l|L}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l|L}\right)\\ = & P_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)-\left\{ p_{l}^{-1}-\left[\left(2\nabla_{P_{l}}-\nabla_{\hat{\mathbf{x}}_{l}}\nabla_{\hat{\mathbf{x}}_{l}}^{\top}\right)^{-1}+P_{l}\right]^{-1}\right\} \left(\mathbf{x}_{l}-P_{l}\nabla_{\hat{\mathbf{x}}_{l}}-\hat{\mathbf{x}}_{l}\right)\\ = & \cancel{P_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)}\cancel{-p_{l}^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}\right)}+\bcancel{p_{l}^{-1}}\bcancel{P_{l}}\nabla_{\hat{\mathbf{x}}_{l}}+\left[\left(2\nabla_{P_{l}}-\nabla_{\hat{\mathbf{x}}_{l}}\nabla_{\hat{\mathbf{x}}_{l}}^{\top}\right)^{-1}+P_{l}\right]^{-1}\left(\mathbf{x}_{l}-P_{l}\nabla_{\hat{\mathbf{x}}_{l}}-\hat{\mathbf{x}}_{l}\right)\\ = & \left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)+\left[\left(2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)^{\top}\right)^{-1}+P_{l}\right]^{-1}\left(\mathbf{x}_{l}-P_{l}\left(\nabla_{\hat{\mathbf{x}}_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})\right)-\hat{\mathbf{x}}_{l}\right) \end{align*} The loss is $J=-\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})$, so if we formulate particle sensitivity using (Bayesian) sensitivity over forward distribution parameters, the corresponding signs should be changed.
\begin{align*} - \nabla_{\mathbf{x}_{l}}J = & -\nabla_{\hat{\mathbf{x}}_{l}}J+\left[\left(-2\nabla_{P_{l}}J-\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)^{\top}\right)^{-1}+P_{l}\right]^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}+P_{l}\nabla_{\hat{\mathbf{x}}_{l}}J\right)\text{ or }\\ \nabla_{\mathbf{x}_{l}}J = & \nabla_{\hat{\mathbf{x}}_{l}}J - \left[P_{l} - \left(2\nabla_{P_{l}}J+\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)\left(\nabla_{\hat{\mathbf{x}}_{l}}J\right)^{\top}\right)^{-1}\right]^{-1}\left(\mathbf{x}_{l}-\hat{\mathbf{x}}_{l}+P_{l}\nabla_{\hat{\mathbf{x}}_{l}}J\right). \end{align*}
The relationship relies on the assumption that the "state transition" of particles across layers is linear Gaussian; in particular, in the linear Gaussian model the activation function is linearized with a first-order Taylor expansion at the mean pre-activation. If this linear Gaussian assumption held exactly for the sampled vanilla NNs, the relationship between particle sensitivity and Bayesian sensitivity would be exact. However, the vanilla NN samples do not follow the linear Gaussian model exactly; in particular, the gradient of the sensitivity functions with respect to the pre-activations is not constant. That is why we see the particle sensitivities scattered around the prediction (which is based on the linear Gaussian assumption). The code below implements the second form of the identity above.
dldx = sapply(1:4, function(i){
A = tf$linalg$pinv(P.filter.vi[[i]] + tf$linalg$pinv(- sensitivityBayesian[[i]][[2]]*2 - tf$einsum('bm,bn->bmn', sensitivityBayesian[[i]][[1]], sensitivityBayesian[[i]][[1]]) ))
sensitivityBayesian[[i]][[1]] - tf$einsum('bij,lbj->lbi', A, tf$stack(ret[[i+1]]) - x.filter.vi[[i]] + tf$einsum('bij,bj->bi',P.filter.vi[[i]], sensitivityBayesian[[i]][[1]]))
})
require(repr)
repr.plot.opt = options(repr.plot.width=17, repr.plot.height=4.5)
par.opt = par(mfcol=c(1,4))
require(MASS)
image(kde2d(dldx[[1]]$numpy(), tf$stack(sensitivity[[2]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[2]])$numpy()), times=2)))
#points(dldx[[1]]$numpy(), tf$stack(sensitivity[[2]])$numpy(), pch='.')
abline(coef=c(0,1))
image(kde2d(dldx[[2]]$numpy(), tf$stack(sensitivity[[3]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[3]])$numpy()), times=2)))
points(dldx[[2]]$numpy(), tf$stack(sensitivity[[3]])$numpy(), pch='.')
abline(coef=c(0,1))
image(kde2d(dldx[[3]]$numpy(), tf$stack(sensitivity[[4]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[4]])$numpy()), times=2)))
points(dldx[[3]]$numpy(), tf$stack(sensitivity[[4]])$numpy(), pch='.')
abline(coef=c(0,1))
image(kde2d(dldx[[4]]$numpy(), tf$stack(sensitivity[[5]])$numpy(), n=100, lims = rep(range(tf$stack(sensitivity[[5]])$numpy()), times=2)))
points(dldx[[4]]$numpy(), tf$stack(sensitivity[[5]])$numpy(), pch='.')
abline(coef=c(0,1))
par(par.opt)
options(repr.plot.opt)
Loading required package: MASS
Fisher divergence with Bayesian neural network
Since the Hessian of the loss with respect to the input is related to the input filter and smoothing densities in a Gaussian Bayesian neural network, we can efficiently estimate the Fisher divergence with one forward propagation and one backward propagation. The following compares the Fisher divergence obtained through Monte Carlo integration over multiple vanilla NNs with the one obtained from a Bayesian NN.
- formulation of an energy-based model $p_\theta(\mathbf{x})=\exp \left(f_\theta(\mathbf{x}) - \log Z(\theta)\right)$.
- (Stein) score function $s_\theta(\mathbf{x}):=\nabla_{\mathbf{x}} \log p_\theta(\mathbf{x})=\nabla_{\mathbf{x}} f_\theta(\mathbf{x})-\nabla_{\mathbf{x}} \log Z(\theta)=\nabla_{\mathbf{x}} f_\theta(\mathbf{x})$.
- Fisher divergence between 2 distributions $p_{\text{data}}(\mathbf{x})$ and $p_\theta(\mathbf{x})$ is \begin{align*} D_F(p_{\text{data}}, p_\theta):= & \frac{1}{2}\mathbf{E}_{\mathbf{x}\sim p_{\text{data }}}\left[\left\Vert \nabla_{\mathbf{x}}\log p_{\text{data }}(\mathbf{x})-\nabla_{\mathbf{x}}\log p_{\theta}(\mathbf{x})\right\Vert _{2}^{2}\right]\\ = & ....\\ = & \mathbf{E}_{\mathbf{x}\sim p_{\text{data }}}\left[\frac{1}{2}\left\Vert \nabla_{\mathbf{x}}\log p_{\boldsymbol{\theta}}\left(\mathbf{x}\right)\right\Vert _{2}^{2}+\Delta_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}\right)\right] + \text{const }\\ \approx & \frac{1}{n}\sum_{i=1}^{n}\left[\frac{1}{2}\left\Vert \nabla_{\mathbf{x}}\log p_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right\Vert _{2}^{2}+\Delta_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right]+\text{const }\\ = & \frac{1}{n}\sum_{i=1}^{n}\left[\frac{1}{2}\left\Vert \nabla_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right\Vert _{2}^{2}+\Delta_{\mathbf{x}}f_{\boldsymbol{\theta}}\left(\mathbf{x}_{i}\right)\right]+\text{const }. \end{align*}
- The score matching algorithm minimizes the Fisher divergence between the data distribution $p_{\text{data}}(\mathbf x)$ and the model distribution $p_{\boldsymbol\theta}(\mathbf x)$, here formulated as an energy-based model. Pro: no sampling; con: the Hessian is hard to estimate. (A minimal numerical sketch of the empirical estimator follows this list.)
- Sample a mini-batch of datapoints $\left\{\mathbf{x}_1, \mathbf{x}_2, \cdots, \mathbf{x}_n\right\} \sim p_{\text {data }}(\mathbf{x})$.
- Estimate the score matching loss with the empirical mean \begin{aligned} & \frac{1}{n} \sum_{i=1}^n\left[\frac{1}{2}\left\|\nabla_{\mathbf{x}} \log p_\theta\left(\mathbf{x}_i\right)\right\|_2^2+\operatorname{tr}\left(\nabla_{\mathbf{x}}^2 \log p_\theta\left(\mathbf{x}_i\right)\right)\right] \\ = & \frac{1}{n} \sum_{i=1}^n\left[\frac{1}{2}\operatorname{tr}\left[\left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right) \left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right)^\top\right]+\operatorname{tr}\left(\nabla_{\mathbf{x}}^2 f_\theta\left(\mathbf{x}_i\right)\right)\right] \end{aligned}
- Score matching with Bayesian neural network.
\begin{align*} & \mathbf{E}_{\gamma(\mathbf{x}_{l})}\left(\nabla_{\mathbf{x}_{l}}\log p(\mathbf{y}|\mathbf{x}_{l},\mathbf{x}_{0}=\mathbf{x})\right)\left(\nabla_{\mathbf{x}_{l}}\log p(\mathbf{y}|\mathbf{x}_{l},\mathbf{x}_{0}=\mathbf{x})\right)^{\top}\\ = & P_{l}^{-1}\left[P_{l|L}-P_{l}+\left(\hat{\mathbf{x}}_{l|L}-\hat{\mathbf{x}}_{l}\right)\left(\hat{\mathbf{x}}_{l|L}-\hat{\mathbf{x}}_{l}\right)^{\top}\right]P_{l}^{-1}+\left(P_{l|L}^{-1}-P_{l}^{-1}\right)\\ = & 2\nabla_{P_{l}}\log p(\mathbf{y}|\hat{\mathbf{x}}_{l},P_{l})-\left[\left(2 \nabla_{P_l} \log p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)-\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)^{\top}\right)^{-1} + P_l\right]^{-1}\\ & \nabla_{\mathbf{x}_l \mathbf{x}_l^{\top}} \log p\left(\mathbf{y} \mid \mathbf{x}_l, \mathbf{x}_0=\mathbf{x}\right) = P_l^{-1} - P_{l \mid L}^{-1}\\ = & \left[\left(2 \nabla_{P_l} \log p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)-\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)\left(\nabla_{\hat{\mathbf{x}}_l} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_l, P_l\right)\right)^{\top}\right)^{-1} + P_l\right]^{-1}\\ \\ & \frac{1}{n} \sum_{i=1}^n\left[\frac{1}{2}\left\|\nabla_{\mathbf{x}} \log p_\theta\left(\mathbf{x}_i\right)\right\|_2^2+\operatorname{tr}\left(\nabla_{\mathbf{x}}^2 \log p_\theta\left(\mathbf{x}_i\right)\right)\right] \\ = & \frac{1}{n} \sum_{i=1}^n \operatorname{tr}\left[\frac{1}{2}\left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right) \left(\nabla_{\mathbf{x}} f_\theta\left(\mathbf{x}_i\right)\right)^\top + \nabla_{\mathbf{x}}^2 f_\theta\left(\mathbf{x}_i\right)\right]\\ = & \frac{1}{n} \sum_{i=1}^n \operatorname{tr}\left\lbrace \nabla_{P_{l}}\log p(\mathbf{y}_i|\hat{\mathbf{x}}_{i},P_{i}) + .5 \left[P_i + \left(2 \nabla_{P_i} \log p\left(\mathbf{y} \mid \hat{\mathbf{x}}_i, P_i\right)-\left(\nabla_{\hat{\mathbf{x}}_i} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_i, P_i\right)\right)\left(\nabla_{\hat{\mathbf{x}}_i} p\left(\mathbf{y} \mid \hat{\mathbf{x}}_i, P_i\right)\right)^{\top}\right)^{-1}\right]^{-1} \right\rbrace. \end{align*}
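Before wiring this up for the Bayesian NN, here is a minimal, self-contained sketch of the generic empirical score-matching estimator above, using TensorFlow autodiff on a toy energy $f_\theta(\mathbf x) = -\tfrac{1}{2}\Vert\mathbf x-\theta\Vert^2$ (illustrative only; theta, xs, sm.loss and the tapes below are not used elsewhere in this notebook). For this energy the score is $\theta-\mathbf x$ and the Laplacian is $-d$:
theta = tf$constant(c(0.5, -1.0), dtype=tf$float32)
xs = tf$constant(matrix(rnorm(10), ncol=2), dtype=tf$float32) # mini-batch of x_i
with(tf$GradientTape() %as% t.outer, {
t.outer$watch(xs)
with(tf$GradientTape() %as% t.inner, {
t.inner$watch(xs)
f = -0.5 * tf$reduce_sum((xs - theta)^2, axis=-1L) # f_theta(x_i)
})
score = t.inner$gradient(f, xs) # nabla_x f_theta(x_i), i.e. the score
})
hess = t.outer$batch_jacobian(score, xs) # nabla_x^2 f_theta(x_i)
sm.loss = tf$reduce_mean(.5 * tf$reduce_sum(score^2, axis=-1L) + tf$linalg$trace(hess))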
FD.mc = .5 * tf$reduce_mean(tf$stack(sensitivity[[2]], axis=0L)^2, axis=0L) - tf$reduce_mean(tf$stack(lapply(hessian[1,], tf$linalg$diag_part), axis=0L), axis=0L)
FD.vi = -tf$linalg$diag_part(sensitivityBayesian[[1]][[2]]) - .5 * tf$linalg$diag_part(H.vi[[1]])
plot(tf$reduce_sum(FD.vi, axis=-1L), tf$reduce_sum(FD.mc, axis=-1L), xlab='VI', ylab='MC', main = 'Fisher divergence')
#abline(coef=c(0,1),col='red')
In the following, we show that the linearization of the densely connected layers does not affect learning, which is gradient-based and assumes that the loss surface is locally linear.
mnist_train_ds = tf$data$Dataset$from_tensor_slices(mnist$train)$shuffle(2048L)$batch(32L)$map(function(xy){
c(x,y) %<-% xy
x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
tuple(x,y.onehot)
})
mnist_test_ds = tf$data$Dataset$from_tensor_slices(mnist$test)$batch(1024L)$map(function(xy){
c(x,y) %<-% xy
x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
tuple(x,y.onehot)
})
inputs = layer_input(shape = 784L, name='input')
predictions = inputs %>%
layer_lambda(function(x) tfp$distributions$MultivariateNormalDiag(loc = x, scale_identity_multiplier = 1e-2 )$sample() ) %>%
layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE) %>%
layer_dense_linearization(units=7, activation = 'sigmoid', use_bias = TRUE) %>%
layer_dense_linearization(units=10, activation = NULL, use_bias = TRUE)
model <- keras_model(inputs = inputs, outputs = predictions)
myloss = function(y, y_hat){
p = tf$math$softmax(y_hat, axis=-1L)
dydx = tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p)
y_pred = tf$stop_gradient(p) + tf$einsum('bij,bi->bj', tf$stop_gradient(dydx), y_hat - tf$stop_gradient(y_hat))
#- tf_probability()$distributions$MultivariateNormalDiag(loc=y_pred, scale_diag = tf$ones_like(y)*.14^.5)$log_prob(y)
- tf_probability()$distributions$MultivariateNormalDiag(loc=y_pred, scale_diag = tf$ones_like(y)*.01^.5)$log_prob(y)
}
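The stop_gradient pattern above freezes the softmax and its Jacobian at the current logits, so the loss value is unchanged while its gradient only sees the linearized prediction. A minimal eager check (a sketch, not part of the training pipeline):
# Sketch: the linearized prediction equals softmax(y_hat) in value, and its Jacobian
# w.r.t. y_hat is the softmax Jacobian dydx frozen at the current logits.
y_hat = tf$constant(matrix(rnorm(10), nrow = 1L), dtype = tf$float32)
with(tf$GradientTape() %as% g, {
  g$watch(y_hat)
  p = tf$math$softmax(y_hat, axis = -1L)
  dydx = tf$linalg$diag(p) - tf$einsum('bi,bj->bij', p, p)
  y_pred = tf$stop_gradient(p) + tf$einsum('bij,bi->bj', tf$stop_gradient(dydx), y_hat - tf$stop_gradient(y_hat))
})
print(tf$reduce_max(tf$abs(y_pred - p)))                               # 0: same value as softmax
print(tf$reduce_max(tf$abs(g$batch_jacobian(y_pred, y_hat) - dydx)))   # ~0: Jacobian is the frozen dydx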
model %>% compile(loss=myloss, optimizer='adam', metrics='accuracy')
history = model %>% fit(mnist_train_ds, epochs=20, validation_data=mnist_test_ds)
model %>% evaluate(mnist_test_ds)
model %>% evaluate(mnist_train_ds)
evaluate(mnist_test_ds): loss = -6.58690929412842, accuracy = 0.904500007629395
evaluate(mnist_train_ds): loss = -7.37333059310913, accuracy = 0.917349994182587
rm(list=ls())
Bayesian layers¶
In this section, we provide Bayesian layers with a mean-field posterior according to Algorithm 1, trained with natural gradient. We hope to show two things:
- The Bayesian layers generally wrap/decorate the non-Bayesian layers, so adding a stochastic model to a non-Bayesian computational graph should involve minimal effort (see the usage sketch after this list).
- Using these layers, we can bring a very deep neural network to convergence with competitive performance even without any normalization layers. This hopefully points to the potential of Bayesian modeling to handle the vanishing/exploding gradient problem that has long plagued deep learning.
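Before the definitions, a minimal usage sketch of the interface the layers share (hypothetical shapes and input variance; this assumes the Layer() wrapper returns a composable constructor when called without an input object, as the built-in layer_* functions do):
# Each Bayesian layer below consumes and returns a list(mean, variance), so a stochastic
# model is assembled by threading the (x, P) pair through the stack, exactly like a
# deterministic model threads a single tensor.
x  = tf$random$normal(shape = shape(8L, 32L, 32L, 3L))   # hypothetical batch
P0 = tf$ones_like(x) * 1e-2                              # assumed input variance
bconv = layer_conv_2d_Bayesian(filters = 16L, kernel_size = 3L, padding = 'same', activation = 'relu')
bpool = layer_average_pooling_2d_Bayesian(pool_size = c(2L, 2L))
c(h, P) %<-% bpool(bconv(list(x, P0)))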
We implement the natural gradient by decorating TensorFlow's autodiff, as in the following code snippet.
grad_fn = function(dh, dP, variables){
c(dx, dP0, dv) %<-% g2$gradient(list(h, P, divergence), list(x, P0, variables), list(dh, dP, tf$constant(1.0, dtype=self$compute_dtype)))
dv2 = lapply(1:length(variables), function(n){
switch(gsub('[_a-zA-Z0-9]+/(A|B|bias):[0-9]+','\\1',variables[[n]]$name),
A = dv[[n]] * tf$math$softplus(self$kernel_scale)^2 ,
B = dv[[n]] * tf$math$softplus(self$kernel_scale)^4 / tf$math$sigmoid(self$kernel_scale)^2 * 2,
bias = dv[[n]] )
})
return(list(list(dx, dP0), dv2))
}
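For reference, a sketch of where the per-variable scaling may come from (our reading of the code, not a derivation quoted from the paper): for a mean-field posterior $q(w)=\mathcal{N}(A,\operatorname{softplus}(B)^{2})$, the Fisher information with respect to the location $A$ is $1/\operatorname{softplus}(B)^{2}$, so multiplying the Euclidean gradient by $\operatorname{softplus}(B)^{2}$ gives the natural-gradient direction for $A$; the factor applied to $B$ analogously combines the Fisher information of the scale with the derivative of the softplus reparameterization, $\mathrm{d}\,\operatorname{softplus}(B)/\mathrm{d}B=\operatorname{sigmoid}(B)$.
\begin{align*} \mathcal{I}_{A} = \mathbf{E}_{q}\!\left[\left(\nabla_{A}\log q(w)\right)^{2}\right] = \frac{1}{\operatorname{softplus}(B)^{2}}, \qquad \tilde{\nabla}_{A}\mathcal{L} = \mathcal{I}_{A}^{-1}\,\nabla_{A}\mathcal{L} = \operatorname{softplus}(B)^{2}\,\nabla_{A}\mathcal{L}. \end{align*}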
my_divergence_fn = function(q, p, ignore=NULL){
  # KL(q || p); multiplied by 0 here to switch off the weight-space regularizer while keeping the graph intact
  tfp$distributions$kl_divergence(q, p) * 0
}
# layer_conv_2d_Bayesian = Layer(
# classname = 'BayesianConv2D',
# inherit = tf$keras$layers$Layer,
# initialize = function(activation = NULL, divergence_fn = my_divergence_fn, use_bias = TRUE, ...) {
# super$initialize()
# self$activation = keras$activations$get(activation)
# self$divergence_fn = divergence_fn
# self$use_bias = use_bias
# self$conv2d = tf$keras$layers$Conv2D(activation = NULL, use_bias=FALSE, ...)
# },
# build = function(input_shape) {
# self$conv2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
# self$kernel_loc = self$add_weight(name='A',shape=self$conv2d$kernel$shape, initializer=self$conv2d$kernel_initializer, trainable=TRUE)
# self$kernel_scale = self$add_weight(name='B',shape=self$conv2d$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)
# #'uniform', default for mean_field_posterior is -3
# if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$conv2d$filters), initializer = 'zeros') else self$bias = NULL
# self$built=TRUE
# },
# call = function(inputs) {
# c(x, P0) %<-% inputs
# with(tf$GradientTape(persistent=TRUE) %as% g, {
# g$watch(list(x))
# self$conv2d$kernel = self$kernel_loc
# a = self$conv2d$call(x)
# if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
# h = self$activation(a)
# `__h__` = self$activation(a)
# })
# # propagate variance (accept a variance, first order approximation of state transition, output a variance)
# #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
# self$conv2d$kernel = tf$stop_gradient(self$kernel_loc^2)
# dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$conv2d$call(P0)
# #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
# self$conv2d$kernel = tf$nn$softplus(self$kernel_scale+1e-9)^2
# dhdw.sq.sum_w = g$gradient(h, a)^2 * self$conv2d$call(tf$stop_gradient(x^2))
# P = dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9
# # reinterpreted_batch_ndims is the number of event dimensions, so every dimensions before that `:-reinterpreted_batch_ndims` is event dimension
# q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=self$kernel_loc, scale = tf$nn$softplus(self$kernel_scale+1e-9)), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
# p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc), scale = 1.0/prod(dim(self$kernel_loc)[1:3])^.5), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
# self$add_loss(self$divergence_fn(q_, p_))
# list(h, P)
# }
# )
layer_conv_2d_Bayesian = Layer(
classname = 'BayesianConv2D',
inherit = tf$keras$layers$Layer,
initialize = function(activation = NULL, divergence_fn = my_divergence_fn, use_bias = TRUE, ...) {
super$initialize()
self$activation = keras$activations$get(activation)
self$divergence_fn = divergence_fn
self$use_bias = use_bias
self$conv2d = tf$keras$layers$Conv2D(activation = NULL, use_bias=FALSE, ...)
},
build = function(input_shape) { #tf.keras.initializers.VarianceScaling
self$conv2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
self$kernel_loc = self$add_weight(name='A',shape=self$conv2d$kernel$shape, initializer=self$conv2d$kernel_initializer, trainable=TRUE)
#self$kernel_loc = self$add_weight(name='A',shape=self$conv2d$kernel$shape, initializer=initializer_variance_scaling(), trainable=TRUE)
self$kernel_scale = self$add_weight(name='B',shape=self$conv2d$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)
#'uniform', default for mean_field_posterior is -3
if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$conv2d$filters), initializer = 'zeros') else self$bias = NULL
self$`__forward__` = tf$custom_gradient(function(inputs){
c(x, P0) %<-% inputs
with(tf$GradientTape(persistent=TRUE) %as% g2, {
g2$watch(list(x, P0))
with(tf$GradientTape(persistent=TRUE) %as% g, {
g$watch(list(x, P0))
self$conv2d$kernel = self$kernel_loc
a = self$conv2d$call(x)
if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
h = self$activation(a)
`__h__` = self$activation(a)
})
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
#dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
self$conv2d$kernel = tf$stop_gradient(self$kernel_loc^2)
dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$conv2d$call(P0)
#dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
self$conv2d$kernel = tf$math$softplus(self$kernel_scale)^2
dhdw.sq.sum_w = tf$stop_gradient(g$gradient(h, a)^2) * self$conv2d$call(tf$stop_gradient(x^2))
P = dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-6
q_ = tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale))
p_ = tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/prod(dim(self$kernel_loc)[1:3])^.5)
divergence = tf$reduce_sum(self$divergence_fn(q_, p_))
#q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale)), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
#p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/prod(dim(self$kernel_loc)[1:3])^.5), reinterpreted_batch_ndims = length(dim(self$kernel_loc)))
#divergence = self$divergence_fn(q_, p_)
self$add_loss(tf$stop_gradient(divergence))
})
grad_fn = function(dh, dP, variables){
c(dx, dP0, dv) %<-% g2$gradient(list(h, P, divergence), list(x, P0, variables), list(dh, dP, tf$constant(1.0, dtype=self$compute_dtype)))
dv2 = lapply(1:length(variables), function(n){
switch(gsub('[_a-zA-Z0-9]+/(A|B|bias):[0-9]+','\\1',variables[[n]]$name),
A = dv[[n]] * tf$math$softplus(self$kernel_scale)^2 ,
B = dv[[n]] * tf$math$softplus(self$kernel_scale)^4 / tf$math$sigmoid(self$kernel_scale)^2 * 2,
bias = dv[[n]] )
})
return(list(list(dx, dP0), dv2))
}
list(list(h, P), grad_fn)
})
self$built=TRUE
},
call = function(inputs) {
return(self$`__forward__`(inputs))
}
)
layer_average_pooling_2d_Bayesian = Layer(
classname = 'BayesianAvgPool2D',
inherit = tf$keras$layers$Layer,
initialize = function(pool_size = c(2L, 2L), strides = NULL, ...) {
super$initialize()
if(is.null(strides)) strides = pool_size
self$depthwise_conv_2d = tf$keras$layers$DepthwiseConv2D(kernel_size = pool_size, strides = strides, use_bias = FALSE, activation = NULL, ...)
},
build = function(input_shape) {
self$depthwise_conv_2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
self$depthwise_conv_2d$depthwise_kernel = self$depthwise_conv_2d$depthwise_kernel * 0 + 1
self$built=TRUE
},
call = function(inputs) {
c(x, P0) %<-% inputs
h = a = self$depthwise_conv_2d$call(x) / self$depthwise_conv_2d$call( tf$ones_like(x))
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
#dhdx.sq.sum_x = tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
dhdx.sq.sum_x = self$depthwise_conv_2d$call(P0) / self$depthwise_conv_2d$call(tf$ones_like(P0))^2
P = dhdx.sq.sum_x + 1e-9
list(h, P)
}
)
layer_average_pooling_2d_Bayesian = Layer(
classname = 'BayesianAvgPool2D',
inherit = tf$keras$layers$Layer,
initialize = function(...) {
super$initialize()
self$avg_pooling_2d = tf$keras$layers$AveragePooling2D(...)
},
build = function(input_shape) {
self$avg_pooling_2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
self$built=TRUE
},
call = function(inputs) {
c(x, P0) %<-% inputs
h = a = self$avg_pooling_2d$call(x)
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
#dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
dhdx.sq.sum_x = self$avg_pooling_2d$call(P0) / tf$cast(tf$reduce_prod(self$avg_pooling_2d$pool_size), P0$dtype)
P = dhdx.sq.sum_x + 1e-9
list(h, P)
}
)
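Both pooling variants above implement the same assumption: the $k = \text{pool\_size}[1]\times\text{pool\_size}[2]$ entries of a pooling window are treated as independent, so averaging divides the average input variance by $k$. This is what dividing the pooled P0 by reduce_prod(pool_size) (or, equivalently, by the squared window sum in the depthwise version) computes:
\begin{align*} \operatorname{Var}\!\left(\frac{1}{k}\sum_{i=1}^{k}x_{i}\right)=\frac{1}{k^{2}}\sum_{i=1}^{k}\operatorname{Var}(x_{i})=\frac{1}{k}\,\overline{\operatorname{Var}(x)}. \end{align*}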
layer_smooth_max_pooling_2d_Bayesian = Layer(
classname = 'BayesianAvgPool2D',
inherit = tf$keras$layers$Layer,
initialize = function(pool_size = c(2L, 2L), strides = NULL, alpha = 50.0, ...) {
super$initialize()
self$alpha = alpha
if(is.null(strides)) strides = pool_size
self$depthwise_conv_2d = tf$keras$layers$DepthwiseConv2D(kernel_size = pool_size, strides = strides, use_bias = FALSE, activation = NULL, ...)
self$conv_2d_transpose = tf$keras$layers$Conv2DTranspose(filters = NULL, kernel_size = pool_size, strides = strides, use_bias = FALSE, activation = NULL)
},
build = function(input_shape) {
self$depthwise_conv_2d$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
self$depthwise_conv_2d$depthwise_kernel = self$depthwise_conv_2d$depthwise_kernel * 0 + 1 # replace with a tensor
channel_axis = self$conv_2d_transpose$`_get_channel_axis`()
if(channel_axis < 0) channel_axis = length(input_shape[[1]]) + 1 + channel_axis # python index => R index
self$conv_2d_transpose$filters = input_shape[[1]][[channel_axis]]
self$conv_2d_transpose$build(input_shape[[1]])
self$conv_2d_transpose$kernel$assign(self$conv_2d_transpose$kernel * 0) # replace with a tensor
for(i in 1L:self$conv_2d_transpose$filters)
self$conv_2d_transpose$kernel[,,i,i]$assign(self$conv_2d_transpose$kernel[,,i,i] + 1.0)
#self$conv_2d_transpose$kernel[,,i,i] = self$conv_2d_transpose$kernel[,,i,i] + tf$constant(1.0, dtype=tf$float32)
self$conv_2d_transpose$kernel = self$conv_2d_transpose$kernel + 0
self$built=TRUE
},
call = function(inputs) {
c(x, P0) %<-% inputs
my_max = tf$nn$max_pool2d(x, self$depthwise_conv_2d$kernel_size, self$depthwise_conv_2d$strides, 'VALID')
my_max_in = self$conv_2d_transpose$call(my_max)
b = tf$math$exp(self$alpha * (x - tf$stop_gradient(my_max_in)))
a = self$depthwise_conv_2d$call(b)
h = tf$math$log(a)/self$alpha + tf$stop_gradient(my_max)
# propagate variance
#dhdx.sq.sum_x = tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
#self$depthwise_conv_2d$kernel are all 1's
dhdx.sq.sum_x = tf$math$divide_no_nan(self$depthwise_conv_2d$call(b^2 * P0), tf$stop_gradient(a^2))
#dhdw.sq.sum_w = g$gradient(h, a)^2 * self$conv2d$call(tf$stop_gradient(x^2)) #???
P = dhdx.sq.sum_x + 1e-9
list(h, P)
}
)
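The layer above is the log-sum-exp (smooth) approximation of max pooling with the usual max-subtraction trick for numerical stability; the variance is propagated through its first-order sensitivities, which are the softmax weights of the window. A sketch of the formulas the code implements:
\begin{align*} h=\frac{1}{\alpha}\log\!\sum_{i\in\text{window}}e^{\alpha x_{i}}=m+\frac{1}{\alpha}\log\!\sum_{i}e^{\alpha\left(x_{i}-m\right)},\quad m=\max_{i}x_{i},\qquad \frac{\partial h}{\partial x_{i}}=\frac{e^{\alpha\left(x_{i}-m\right)}}{\sum_{j}e^{\alpha\left(x_{j}-m\right)}},\qquad P\approx\sum_{i}\left(\frac{\partial h}{\partial x_{i}}\right)^{2}P_{0,i}. \end{align*}
As $\alpha\to\infty$, $h$ approaches exact max pooling.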
# layer_dense_Bayesian = Layer(
# classname = 'BayesianDense',
# inherit = tf$keras$layers$Layer,
# initialize = function(activation=NULL, use_bias = TRUE, divergence_fn = my_divergence_fn, ...) {
# super$initialize()
# self$divergence_fn = divergence_fn
# self$activation = keras$activations$get(activation)
# self$use_bias = use_bias
# self$dense = tf$keras$layers$Dense(activation=NULL, use_bias=FALSE, dtype = self$compute_dtype, ...)
# },
# build = function(input_shape) {
# self$dense$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
# # redefine BNN weights as Gaussian random variables in terms of trainable tf variables
# self$kernel_loc = self$add_weight(name='A',shape=self$dense$kernel$shape, initializer=self$dense$kernel_initializer, trainable=TRUE)
# self$kernel_scale = self$add_weight(name='B',shape=self$dense$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)#'uniform'
# if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$dense$units), initializer='zeros') else self$bias = NULL
# self$built=TRUE
# },
# call = function(inputs) {
# c(x, P0) %<-% inputs
# with(tf$GradientTape(persistent=TRUE) %as% g, {
# g$watch(list(x)) #g$watch(list(x, self$W))
# self$dense$kernel = self$kernel_loc #+ tf$einsum('ij,i->ij',self$B, self$W) # mvnorm with mean and variance
# a = self$dense$call(x) # tf$matmul(a=x, b=self$kernel) + self$bias
# if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
# h = self$activation(a)
# `__h__` = self$activation(a)
# })
# # propagate variance (accept a variance, first order approximation of state transition, output a variance)
# #dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
# self$dense$kernel = tf$stop_gradient(self$kernel_loc^2)
# print('here')
# dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$dense$call(P0)
# #dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
# print('here')
# self$dense$kernel = tf$nn$softplus(self$kernel_scale)^2
# print('here')
# dhdw.sq.sum_w = g$gradient(h, a)^2 * self$dense$call(tf$stop_gradient(x^2))
# P = dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9
# print('here')
# ## add regularization loss
# #q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=self$kernel_loc, scale = tf$nn$softplus(self$kernel_scale+1e-9)), reinterpreted_batch_ndims = 2L)
# #p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc), scale = 1.0/dim(x)[2]^.5), reinterpreted_batch_ndims = 2L)
# #self$add_loss(self$divergence_fn(q_, p_))
# list(h, P)
# }
# )
layer_dense_Bayesian = Layer(
classname = 'BayesianDense',
inherit = tf$keras$layers$Layer,
initialize = function(activation=NULL, use_bias = TRUE, divergence_fn = my_divergence_fn, ...) {
super$initialize()
self$divergence_fn = divergence_fn
self$activation = keras$activations$get(activation)
self$use_bias = use_bias
self$dense = tf$keras$layers$Dense(activation=NULL, use_bias=FALSE, ...)
},
build = function(input_shape) {
self$dense$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
# redefine BNN weights as Gaussian random variables in terms of trainable tf variables
self$kernel_loc = self$add_weight(name='A',shape=self$dense$kernel$shape, initializer=self$dense$kernel_initializer, trainable=TRUE)
self$kernel_scale = self$add_weight(name='B',shape=self$dense$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)#'uniform'
if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$dense$units), initializer='zeros') else self$bias = NULL
self$`__forward__` = tf$custom_gradient(function(inputs){
#message('tracing layer_dense_BayesianEP forward ....')
c(x, P0) %<-% inputs
with(tf$GradientTape(persistent=TRUE) %as% g2, {
g2$watch(list(x, P0))
with(tf$GradientTape(persistent=TRUE) %as% g, {
g$watch(list(x, P0))
self$dense$kernel = self$kernel_loc
a = self$dense$call(x)
if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
h = self$activation(a)
`__h__` = self$activation(a)
})
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
#dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
self$dense$kernel = tf$stop_gradient(self$kernel_loc^2)
dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$dense$call(P0)
#dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
self$dense$kernel = tf$math$softplus(self$kernel_scale)^2
dhdw.sq.sum_w = tf$stop_gradient(g$gradient(h, a)^2) * self$dense$call(tf$stop_gradient(x^2))
P = dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-6
q_ = tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale))
p_ = tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/dim(x)[2]^.5)
divergence = tf$reduce_sum(self$divergence_fn(q_, p_))
#q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$cast(self$kernel_loc, dtype=self$compute_dtype), scale = tf$nn$softplus(self$kernel_scale)), reinterpreted_batch_ndims = 2L)
#p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc, dtype=self$compute_dtype), scale = 1.0/dim(x)[2]^.5), reinterpreted_batch_ndims = 2L)
#divergence = self$divergence_fn(q_, p_)
self$add_loss(tf$stop_gradient(divergence))
})
grad_fn = function(dh, dP, variables){
c(dx, dP0, dv) %<-% g2$gradient(list(h, P, divergence), list(x, P0, variables), list(dh, dP, tf$constant(1.0, dtype=self$compute_dtype)))
dv2 = lapply(1:length(variables), function(n){
switch(gsub('[_a-zA-Z0-9]+/(A|B|bias):[0-9]+','\\1',variables[[n]]$name),
A = dv[[n]] * tf$math$softplus(self$kernel_scale)^2 ,
B = dv[[n]] * tf$math$softplus(self$kernel_scale)^4 / tf$math$sigmoid(self$kernel_scale)^2 * 2.0,
bias = dv[[n]] )
})
return(list(list(dx, dP0), dv2))
}
list(list(h, P), grad_fn)
})
self$built=TRUE
},
call = function(inputs) {
self$`__forward__`(inputs)
}
)
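For reference, the variance propagation in both the dense and the convolutional Bayesian layers is the standard first-order (delta-method) moment matching for a mean-field Gaussian weight posterior; a sketch of the per-unit formula that the two masked kernel substitutions ($A^{2}$ for the input-variance term, $\operatorname{softplus}(B)^{2}$ for the weight-variance term) compute in the dense case:
\begin{align*} a_{j}=\sum_{i}W_{ij}x_{i}+b_{j},\quad W_{ij}\sim\mathcal{N}\!\left(A_{ij},\operatorname{softplus}(B_{ij})^{2}\right),\quad h_{j}=f(a_{j}),\qquad \operatorname{Var}(h_{j})\approx f'(a_{j})^{2}\left[\sum_{i}A_{ij}^{2}\operatorname{Var}(x_{i})+\sum_{i}\operatorname{softplus}(B_{ij})^{2}x_{i}^{2}\right], \end{align*}
which corresponds to dhdx.sq.sum_x + dhdw.sq.sum_w above, plus the small jitter added for numerical stability.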
layer_activation_Bayesian = Layer(
classname = 'BayesianActivation',
inherit = tf$keras$layers$Layer,
initialize = function(activation){
super$initialize()
self$activation = keras$activations$get(activation)
#self$activation = tf$keras$layers$tf$keras$layers$Activation(...)
},
build = function(input_shape) {
self$built=TRUE
},
call = function(inputs, training) {
c(x, P0) %<-% inputs
with(tf$GradientTape(persistent=TRUE) %as% g, {
g$watch(list(x))
a = x
h = self$activation(a)
`__h__` = self$activation(a)
})
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * P0
#dhdw.sq.sum_w = g$gradient(h, a)^2 * tf$stop_gradient(x^2)
dhdw.sq.sum_w = 0 #g$gradient(h, a)^2 * x^2 # need to back propagate to the weight scales from last layer
P = dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9
list(h, P)
}
)
While a "Bayesian" normalization layer is not needed to achieve a competitive convergence rate and generalization, we can certainly write one, as in the following.
layer_batch_normalization_Bayesian = Layer(
classname = 'BayesianBatchNorm',
inherit = tf$keras$layers$Layer,
initialize = function(...) {
super$initialize()
self$batch_normalization = tf$keras$layers$BatchNormalization(...)
},
build = function(input_shape) {
self$batch_normalization$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
self$built=TRUE
},
call = function(inputs, training) {
c(x, P0) %<-% inputs
h = a = self$batch_normalization$call(x, training)
# propagate variance: if is training, scale with batch variance (+P0), else scale with moving average variance, assume last dimension is BN dimension
#dhdx.sq.sum_x = tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
if(training) var = (tf$reduce_mean(x^2, axis=0L:(length(dim(x))-2)) - tf$reduce_mean(x, axis=0L:(length(dim(x))-2))^2 + 1e-3)
else var = (self$batch_normalization$moving_variance + 1e-3)
dhdx.sq.sum_x = P0 * tf$stop_gradient(tf$math$divide_no_nan(self$batch_normalization$gamma^2, var))
P = dhdx.sq.sum_x + 1e-9
list(h, P)
}
)
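A sketch of the variance scaling the layer above applies (treating the normalization statistics as constants at first order):
\begin{align*} h=\gamma\,\frac{x-\mu}{\sqrt{\sigma^{2}+\epsilon}}+\beta \quad\Longrightarrow\quad \operatorname{Var}(h)\approx\frac{\gamma^{2}}{\sigma^{2}+\epsilon}\,\operatorname{Var}(x), \end{align*}
where $\sigma^{2}$ is the batch variance during training and the moving variance at inference, matching the two branches of var in the code.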
rm(list=ls())
NN architectures¶
cifar10 = dataset_cifar10()
cifar10.datagen = image_data_generator(
featurewise_center=TRUE,
featurewise_std_normalization=TRUE,
##rotation_range=15,
zoom_range=0.1,
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=TRUE,
)
cifar10.datagen$fit(cifar10$train$x)
cifar100 = dataset_cifar100()
cifar100.datagen = image_data_generator(
featurewise_center=TRUE, # set input mean to 0 over the dataset
featurewise_std_normalization=TRUE, # divide inputs by std of the dataset
#rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180)
zoom_range=0.1, # Range for random zoom. If a float, [lower, upper] = [1-zoom_range, 1+zoom_range].
width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
horizontal_flip=TRUE, # randomly flip images
)
cifar100.datagen$fit(cifar100$train$x)
DenseNet-BC [L=100, k=12]¶
On all datasets except ImageNet, the DenseNet used in our experiments has three dense blocks that each has an equal number of layers. Before entering the first dense block, a convolution with 16 (or twice the growth rate for DenseNet-BC [note: convolution with $2\times 16 = 32$ features for DenseNet-Bottleneck+Compression]) output channels is performed on the input images. For convolutional layers with kernel size $3\times3$, each side of the inputs is zero-padded by one pixel to keep the feature-map size fixed. We use $1\times1$ convolution followed by $2\times2$ average pooling as transition layers between two contiguous dense blocks. At the end of the last dense block, a global average pooling is performed and then a softmax classifier is attached. The feature-map sizes in the three dense blocks are $32\times 32$, $16\times 16$, and $8\times 8$, respectively. We experiment with the basic DenseNet structure with configurations {L = 40, k = 12}, {L = 100, k = 12} [note: growth rate = 12; among the 100 convolutional layers, 4 sit between the input, the 3 dense blocks, and the output, and the remaining $100-4=96$ layers are distributed over the 3 dense blocks. Since each dense block is composed of convolution blocks with one $1\times1$ convolution layer and one $3\times 3$ convolution layer, each dense block has $96/3/2=16$ convolution blocks] and {L = 100, k = 24}. For DenseNet-BC, the networks with configurations {L = 100, k = 12}, {L = 250, k = 24} and {L = 190, k = 40} are evaluated.
In our experiments on ImageNet, we use a DenseNet-BC structure with 4 dense blocks on $224\times 224$ input images. The initial convolution layer comprises 2k convolutions of size $7\times 7$ with stride 2; the number of feature-maps in all other layers also follow from setting k. The exact network configurations we used on ImageNet are shown in Table 1.
All the networks are trained using stochastic gradient descent (SGD). On CIFAR and SVHN we train using batch size 64 for 300 and 40 epochs, respectively. The initial learning rate is set to 0.1, and is divided by 10 at 50% and 75% of the total number of training epochs. On ImageNet, we train models for 90 epochs with a batch size of 256. The learning rate is set to 0.1 initially, and is lowered by 10 times at epoch 30 and 60.
Following [8], we use a weight decay of $10^{-4}$ and a Nesterov momentum [35] of 0.9 without dampening. We adopt the weight initialization introduced by [10]. ... we add a dropout layer [33] after each convolutional layer (except the first one) and set the dropout rate to 0.2. The test errors were only evaluated once for each task and model setting.
The DenseNet architecture for ImageNet can also be confirmed from tf.keras.applications.DenseNet121. The code is here: https://github.com/keras-team/keras/blob/v2.12.0/keras/applications/densenet.py#L63
DenseNet121 = tf$keras$applications$DenseNet121(input_tensor = inputs %>% layer_resizing(128L, 128L), include_top=TRUE, weights=NULL, classes=10L)
Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4700-4708).
## NN architectures
# output width and height are the same, output channels = input channels + growth rate
conv_bottleneck = function(x, growth_rate){
conv = x %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1
layer_activation_relu() %>%
layer_conv_2d(filters = 4 * growth_rate, kernel_size = c(1L,1L), use_bias = FALSE) %>% # conv1
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
layer_activation_relu() %>%
layer_conv_2d(filters = growth_rate, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) # conv2
list(conv, x) %>% layer_concatenate(axis=-1L) # channel dimension 4 for channel last and 2 for channel first, keras default is channel last
}
# output width and height are half of the input, output channels = reduction * (input channels + growth rate * number of blocks)
dense_layer = function(x, block, growth_rate, num_blocks){
layers = x
for(i in 1:num_blocks) layers = layers %>% block(., growth_rate)
print(sprintf('dense_layer: growth_rate = %d, num_blocks = %d', growth_rate, num_blocks))
print(layers)
layers
}
transition = function(x, reduction = .5){
x %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_conv_2d(filters = floor(reduction*dim(x)[4]), kernel_size = c(1L,1L), use_bias = FALSE) %>%
layer_average_pooling_2d(pool_size = 2L)
}
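With growth rate $k = 12$ and 16 bottleneck blocks per dense block (the $\{L = 100,\ k = 12\}$ DenseNet-BC configuration quoted above, with compression factor 0.5 in the transition layers), the channel counts work out as follows and agree with the shapes printed for the CIFAR-100 model below:
\begin{align*} 2k=24 \;\xrightarrow{+16k}\; 216 \;\xrightarrow{\times 0.5}\; 108 \;\xrightarrow{+16k}\; 300 \;\xrightarrow{\times 0.5}\; 150 \;\xrightarrow{+16k}\; 342, \end{align*}
at spatial resolutions $32\times 32$, $16\times 16$, and $8\times 8$, respectively.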
CIFAR-10¶
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
g.r. = 12L
predictions = inputs %>%
layer_conv_2d(filters = 2*g.r., kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 1
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 2
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 3
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_average_pooling_2d(pool_size = 4L) %>%
layer_flatten() %>%
layer_dense(activation = 'softmax', units = 10L)
densenet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
densenet.cifar10$compile(loss='categorical_crossentropy',optimizer= 'adam',metrics=list('accuracy'))
tf$keras$utils$plot_model(densenet.cifar10, to_file='DenseNetBC_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'DenseNetBC_CIFAR10.png')
PiecewiseConstantDecay, 900 epochs, best accuracy = 94.83%¶
lr = tf$keras$optimizers$schedules$PiecewiseConstantDecay(
boundaries = floor(nrow(cifar10$train$x)/64*3) * c(150, 225), values = c(0.1, 0.01, 0.001))
opt = tf$keras$optimizers$experimental$SGD(learning_rate = lr, momentum = .9, nesterov=TRUE, weight_decay=1e-4)
densenet.cifar10$compile(loss='categorical_crossentropy',optimizer= opt,metrics=list('accuracy'))
mc = callback_model_checkpoint('vanilla_DenseNet_CIFAR10-{epoch:03d}_{val_loss:.2f}_{val_accuracy:.4f}.ckpt', monitor = 'val_accuracy', save_best_only = FALSE, mode = 'auto', save_weights_only = FALSE, save_freq = 'epoch')
py_itertools = import('itertools')
history = densenet.cifar10 %>% fit(
py_itertools$chain(cifar10.datagen$flow(cifar10$train$x, to_categorical(cifar10$train$y, num_classes=10L), batch_size = 64L, shuffle=TRUE)),
epochs = 300,
steps_per_epoch = floor(nrow(cifar10$train$x)/64*3),
validation_data = cifar10.datagen$flow(cifar10$test$x, to_categorical(cifar10$test$y, num_classes=10L), batch_size = 1000L),
callbacks = list(mc, callback_terminate_on_naan(), callback_csv_logger(filename = 'vanilla_DenseNet_CIFAR10.txt')))
saveRDS(history, file='history.RDS')
print(data.frame(history$metrics))
plot(history)
Training log (300 reported epochs; each reported epoch is three passes over the training set, hence "900 epochs" in the heading; the full per-epoch metrics are written by the CSV logger to vanilla_DenseNet_CIFAR10.txt). Selected rows:
epoch      loss      accuracy   val_loss   val_accuracy
    1    1.0606576   0.6191     0.7112     0.7490
  150    0.0465995   0.9838     0.3130     0.9184   (last epoch at lr = 0.1)
  151    0.0179355   0.9943     0.2435     0.9358   (first epoch at lr = 0.01)
  226    0.0013637   0.9997     0.2673     0.9450   (first epoch at lr = 0.001)
  272    0.0010497   0.9998     0.2498     0.9483   (best val_accuracy)
  300    0.0010721   0.9997     0.2677     0.9416
CIFAR-100¶
inputs = layer_input(shape = dim(cifar100$train$x)[-1])
g.r. = 12L
predictions = inputs %>%
layer_conv_2d(filters = 2*g.r., kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 1
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 2
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 3
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_average_pooling_2d(pool_size = 4L) %>%
layer_flatten() %>%
layer_dense(activation = 'softmax', units = 100L)
densenet.cifar100 = keras_model(inputs = inputs, outputs = predictions)
[1] "dense_layer: growth_rate = 12, num_blocks = 16" KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='concatenate_15/concat:0', description="created by layer 'concatenate_15'") [1] "dense_layer: growth_rate = 12, num_blocks = 16" KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='concatenate_31/concat:0', description="created by layer 'concatenate_31'") [1] "dense_layer: growth_rate = 12, num_blocks = 16" KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='concatenate_47/concat:0', description="created by layer 'concatenate_47'")
PiecewiseConstantDecay, 900 epochs, best accuracy = 75.61%¶
lr = tf$keras$optimizers$schedules$PiecewiseConstantDecay(
boundaries = nrow(cifar100$train$x) / 64 * c(150, 225) * 3, values = c(0.1, 0.01, 0.001))
opt = tf$keras$optimizers$experimental$SGD(learning_rate = lr, momentum = .9, nesterov=TRUE, weight_decay=1e-4)
densenet.cifar100$compile(loss='categorical_crossentropy',optimizer= opt,metrics=list('accuracy'))
mc = callback_model_checkpoint('vanilla_DenseNet_CIFAR100-{epoch:03d}_{val_loss:.2f}_{val_accuracy:.4f}.ckpt', monitor = 'val_accuracy', save_best_only = FALSE, mode = 'auto', save_weights_only = TRUE, save_freq = 'epoch')
history = densenet.cifar100 %>% fit(
cifar100.datagen$flow(cifar100$train$x, to_categorical(cifar100$train$y, num_classes=100L), batch_size = 64L, shuffle=TRUE),
epochs = 300 * 3,
validation_data = cifar100.datagen$flow(cifar100$test$x, to_categorical(cifar100$test$y, num_classes=100L), batch_size = 1000L),
callbacks = list(mc, callback_terminate_on_naan(), callback_csv_logger(filename = 'vanilla_DenseNet_CIFAR100.txt')))
saveRDS(history, file='history.RDS')
print(data.frame(history$metrics))
plot(history)
Training log (900 reported epochs scheduled; the printed log is truncated at epoch 535; the full per-epoch metrics are written by the CSV logger to vanilla_DenseNet_CIFAR100.txt). Selected rows:
epoch      loss      accuracy   val_loss   val_accuracy
    1    3.8534389   0.1037     3.5955     0.1520
  450    0.1243447   0.9589     1.4416     0.7229   (just before the first lr drop)
  451    0.0506568   0.9849     1.4077     0.7322   (just after the first lr drop)
  534    0.0065847   0.9986     1.4012     0.7492   (best val_accuracy among the printed epochs)
[printed log truncated at epoch 535; the heading above reports a best accuracy of 75.61% over the full run]
0.99880 1.427776 0.7443 536 0.006161072 0.99894 1.398188 0.7470 537 0.006527155 0.99876 1.416759 0.7476 538 0.006122320 0.99896 1.433329 0.7419 539 0.006534172 0.99880 1.425503 0.7480 540 0.006129253 0.99876 1.417909 0.7467 541 0.006209362 0.99874 1.427651 0.7424 542 0.006163097 0.99906 1.424983 0.7455 543 0.006365490 0.99874 1.411736 0.7498 544 0.005908345 0.99892 1.436467 0.7478 545 0.005721463 0.99904 1.417147 0.7476 546 0.006163025 0.99888 1.404277 0.7475 547 0.005539848 0.99892 1.412556 0.7478 548 0.005734612 0.99894 1.396025 0.7440 549 0.005844006 0.99896 1.433410 0.7478 550 0.005525191 0.99906 1.404340 0.7515 551 0.005642379 0.99910 1.428759 0.7503 552 0.005488868 0.99916 1.408633 0.7469 553 0.006140531 0.99888 1.455970 0.7466 554 0.005134188 0.99920 1.434279 0.7448 555 0.006134112 0.99886 1.426066 0.7467 556 0.005669209 0.99902 1.415759 0.7488 557 0.005865015 0.99870 1.432459 0.7460 558 0.005482425 0.99908 1.402396 0.7495 559 0.005404666 0.99912 1.450093 0.7455 560 0.005720284 0.99886 1.421450 0.7486 561 0.005665937 0.99914 1.424225 0.7473 562 0.005583452 0.99898 1.399133 0.7466 563 0.005681403 0.99896 1.432307 0.7474 564 0.005543076 0.99890 1.406861 0.7475 565 0.005151812 0.99922 1.408505 0.7499 566 0.004985921 0.99928 1.419432 0.7476 567 0.005783788 0.99892 1.413324 0.7448 568 0.005192699 0.99914 1.428858 0.7476 569 0.005511692 0.99906 1.425979 0.7454 570 0.005181483 0.99912 1.404181 0.7457 571 0.005310256 0.99900 1.426186 0.7419 572 0.005155613 0.99908 1.420886 0.7458 573 0.005286245 0.99910 1.410772 0.7471 574 0.005153177 0.99916 1.411543 0.7542 575 0.005183875 0.99900 1.426109 0.7454 576 0.005280633 0.99910 1.395915 0.7494 577 0.004710726 0.99934 1.414770 0.7470 578 0.005187221 0.99894 1.433996 0.7435 579 0.005299417 0.99894 1.408595 0.7457 580 0.005011478 0.99924 1.423391 0.7460 581 0.005053359 0.99924 1.421120 0.7446 582 0.005513133 0.99900 1.398511 0.7466 583 0.005261915 0.99906 1.429351 0.7443 584 0.005426925 0.99902 1.413547 0.7450 585 0.004856053 0.99922 1.419502 0.7482 586 0.004881298 0.99926 1.418158 0.7448 587 0.004796502 0.99940 1.397063 0.7470 588 0.005019376 0.99914 1.395555 0.7448 589 0.005098244 0.99920 1.394117 0.7472 590 0.004879857 0.99912 1.402306 0.7497 591 0.005088651 0.99906 1.420829 0.7463 592 0.004997355 0.99922 1.414816 0.7474 593 0.004906875 0.99918 1.395988 0.7484 594 0.004503739 0.99942 1.394794 0.7464 595 0.004700697 0.99926 1.403609 0.7457 596 0.005166062 0.99902 1.432539 0.7445 597 0.005204814 0.99890 1.406911 0.7462 598 0.004721060 0.99928 1.413863 0.7477 599 0.004791856 0.99932 1.394592 0.7531 600 0.004706315 0.99920 1.403072 0.7470 601 0.004697399 0.99930 1.410332 0.7513 602 0.005670072 0.99886 1.401724 0.7440 603 0.004857319 0.99916 1.393421 0.7516 604 0.004929340 0.99916 1.420501 0.7433 605 0.004938248 0.99920 1.406967 0.7473 606 0.004946361 0.99908 1.411371 0.7483 607 0.004767010 0.99920 1.393423 0.7486 608 0.005068546 0.99906 1.391346 0.7506 609 0.004958915 0.99920 1.445292 0.7450 610 0.004783743 0.99924 1.413274 0.7449 611 0.005048754 0.99906 1.414828 0.7461 612 0.004865469 0.99916 1.391849 0.7481 613 0.005005597 0.99908 1.406553 0.7457 614 0.005072235 0.99920 1.420675 0.7466 615 0.004741145 0.99908 1.416713 0.7429 616 0.004723425 0.99918 1.427763 0.7442 617 0.004865183 0.99926 1.406725 0.7450 618 0.005269920 0.99890 1.392820 0.7402 619 0.004278865 0.99942 1.392555 0.7501 620 0.004817104 0.99924 1.413443 0.7455 621 0.004909006 0.99914 1.396229 0.7517 622 0.004438785 0.99930 1.424255 0.7427 623 0.004656434 0.99914 1.409922 0.7487 624 
0.004672626 0.99918 1.387659 0.7445 625 0.004300877 0.99930 1.408186 0.7465 626 0.004252086 0.99952 1.380214 0.7499 627 0.004458995 0.99936 1.391417 0.7508 628 0.004539465 0.99930 1.376831 0.7517 629 0.004472214 0.99936 1.408462 0.7493 630 0.004452997 0.99926 1.395432 0.7463 631 0.004546248 0.99922 1.398501 0.7442 632 0.004568384 0.99924 1.390968 0.7493 633 0.004459475 0.99930 1.387382 0.7483 634 0.004710467 0.99918 1.392331 0.7484 635 0.004839281 0.99906 1.396761 0.7471 636 0.004367354 0.99932 1.372210 0.7478 637 0.004415509 0.99924 1.400341 0.7481 638 0.004599945 0.99926 1.390681 0.7480 639 0.004971456 0.99904 1.391984 0.7458 640 0.004943959 0.99914 1.404574 0.7441 641 0.004728653 0.99926 1.384480 0.7535 642 0.004177971 0.99942 1.359593 0.7519 643 0.004834421 0.99908 1.422060 0.7468 644 0.004372767 0.99942 1.386489 0.7482 645 0.004461181 0.99920 1.412726 0.7462 646 0.004836538 0.99912 1.412337 0.7530 647 0.004467005 0.99922 1.409206 0.7456 648 0.004232721 0.99934 1.394285 0.7444 649 0.004597748 0.99928 1.382640 0.7469 650 0.004212113 0.99946 1.396048 0.7439 651 0.004277823 0.99940 1.374542 0.7491 652 0.004639094 0.99926 1.394459 0.7490 653 0.004335594 0.99928 1.402994 0.7485 654 0.004215192 0.99926 1.398258 0.7511 655 0.004126162 0.99944 1.401571 0.7453 656 0.004210052 0.99932 1.407287 0.7461 657 0.004066105 0.99948 1.385116 0.7486 658 0.004494595 0.99946 1.399248 0.7464 659 0.004397073 0.99920 1.365847 0.7506 660 0.004421169 0.99920 1.383437 0.7484 661 0.004021584 0.99944 1.372819 0.7491 662 0.004440803 0.99918 1.405566 0.7469 663 0.004477837 0.99920 1.397726 0.7472 664 0.004594364 0.99918 1.392519 0.7490 665 0.004321391 0.99928 1.400039 0.7467 666 0.003884707 0.99936 1.388085 0.7477 667 0.004417255 0.99940 1.409185 0.7443 668 0.004120239 0.99942 1.378207 0.7489 669 0.004492601 0.99934 1.371402 0.7470 670 0.004287312 0.99924 1.382235 0.7508 671 0.004465098 0.99918 1.373153 0.7480 672 0.003855312 0.99942 1.401969 0.7481 673 0.004302789 0.99926 1.390971 0.7502 674 0.004210928 0.99932 1.388709 0.7457 675 0.004166720 0.99926 1.399589 0.7466 676 0.004043411 0.99952 1.395130 0.7484 677 0.004054969 0.99942 1.394243 0.7471 678 0.004202967 0.99940 1.386384 0.7494 679 0.004249067 0.99930 1.398514 0.7464 680 0.004161104 0.99932 1.402154 0.7478 681 0.004315372 0.99934 1.401574 0.7471 682 0.004056135 0.99942 1.395551 0.7503 683 0.004266312 0.99934 1.378255 0.7537 684 0.004363045 0.99912 1.384869 0.7490 685 0.004116495 0.99944 1.398063 0.7503 686 0.004099772 0.99950 1.377018 0.7525 687 0.003875785 0.99944 1.376819 0.7491 688 0.004003887 0.99928 1.386256 0.7510 689 0.003963573 0.99942 1.384135 0.7505 690 0.003895512 0.99956 1.377205 0.7516 691 0.003829194 0.99942 1.380257 0.7509 692 0.003802414 0.99940 1.395562 0.7496 693 0.003770972 0.99946 1.373237 0.7514 694 0.003860362 0.99946 1.387800 0.7529 695 0.003928786 0.99940 1.404927 0.7452 696 0.003780503 0.99950 1.383660 0.7499 697 0.003906305 0.99950 1.397846 0.7510 698 0.003901723 0.99950 1.373569 0.7507 699 0.003651343 0.99944 1.371993 0.7515 700 0.003898600 0.99936 1.384891 0.7506 701 0.003819313 0.99948 1.386870 0.7492 702 0.004223661 0.99936 1.384581 0.7481 703 0.003855399 0.99950 1.384949 0.7520 704 0.003507904 0.99966 1.383653 0.7475 705 0.004018778 0.99936 1.366909 0.7540 706 0.003871137 0.99938 1.381888 0.7482 707 0.004155347 0.99932 1.400326 0.7461 708 0.004133416 0.99918 1.397508 0.7484 709 0.004068155 0.99936 1.373228 0.7492 710 0.003763632 0.99942 1.374093 0.7488 711 0.003687674 0.99944 1.393669 0.7473 712 0.003559326 0.99952 1.398147 
0.7492 713 0.003739260 0.99956 1.370705 0.7518 714 0.003812938 0.99944 1.386667 0.7515 715 0.003516178 0.99950 1.368693 0.7487 716 0.003728548 0.99938 1.379789 0.7508 717 0.003725690 0.99960 1.379278 0.7489 718 0.003798536 0.99954 1.400110 0.7521 719 0.003879965 0.99942 1.403550 0.7457 720 0.004001467 0.99928 1.389689 0.7494 721 0.003742206 0.99946 1.374500 0.7531 722 0.003800071 0.99942 1.380778 0.7513 723 0.003963613 0.99938 1.396288 0.7499 724 0.003654225 0.99942 1.400507 0.7476 725 0.003773930 0.99944 1.391011 0.7496 726 0.003594571 0.99954 1.387267 0.7499 727 0.003741001 0.99944 1.392733 0.7506 728 0.003641232 0.99952 1.392243 0.7506 729 0.003903184 0.99936 1.396375 0.7469 730 0.003855030 0.99934 1.383650 0.7504 731 0.003584256 0.99948 1.380648 0.7477 732 0.003754498 0.99950 1.382176 0.7498 733 0.003618537 0.99954 1.392480 0.7500 734 0.003623980 0.99952 1.381816 0.7509 735 0.003822622 0.99948 1.391498 0.7515 736 0.003614904 0.99940 1.384733 0.7522 737 0.003754279 0.99932 1.382403 0.7496 738 0.003695032 0.99942 1.402883 0.7452 739 0.003441158 0.99960 1.405548 0.7480 740 0.003841226 0.99950 1.398224 0.7452 741 0.003648196 0.99958 1.398626 0.7470 742 0.003622203 0.99954 1.387685 0.7487 743 0.003837872 0.99944 1.394218 0.7487 744 0.003723893 0.99946 1.393869 0.7480 745 0.003575010 0.99956 1.379912 0.7528 746 0.003627359 0.99956 1.376933 0.7506 747 0.003672693 0.99946 1.383514 0.7488 748 0.003955462 0.99942 1.403958 0.7485 749 0.004004475 0.99946 1.391278 0.7477 750 0.003516122 0.99958 1.379871 0.7500 751 0.003871972 0.99934 1.377972 0.7504 752 0.003995358 0.99930 1.403870 0.7493 753 0.003779306 0.99952 1.381453 0.7508 754 0.003834113 0.99932 1.385096 0.7464 755 0.003847099 0.99940 1.380614 0.7519 756 0.003640553 0.99948 1.382856 0.7503 757 0.003545614 0.99950 1.375053 0.7482 758 0.004023343 0.99934 1.380218 0.7507 759 0.003765266 0.99944 1.377648 0.7493 760 0.003566456 0.99966 1.382731 0.7499 761 0.003774222 0.99950 1.373917 0.7497 762 0.003949451 0.99944 1.381418 0.7525 763 0.004019988 0.99944 1.387758 0.7480 764 0.003585106 0.99946 1.381206 0.7493 765 0.003709744 0.99946 1.382514 0.7488 766 0.003641587 0.99948 1.362586 0.7513 767 0.003836775 0.99940 1.391095 0.7490 768 0.003816722 0.99946 1.394251 0.7469 769 0.003510774 0.99968 1.387712 0.7499 770 0.003702947 0.99954 1.393254 0.7478 771 0.003654839 0.99948 1.371176 0.7450 772 0.003893751 0.99942 1.407273 0.7507 773 0.003593972 0.99950 1.385557 0.7469 774 0.003570807 0.99952 1.383940 0.7504 775 0.003660334 0.99956 1.360500 0.7490 776 0.003710129 0.99936 1.373093 0.7526 777 0.003749571 0.99940 1.394454 0.7482 778 0.003621595 0.99948 1.391579 0.7485 779 0.003916530 0.99946 1.404675 0.7462 780 0.003965732 0.99928 1.407158 0.7472 781 0.003507888 0.99954 1.362682 0.7557 782 0.003921149 0.99920 1.397395 0.7483 783 0.003888350 0.99944 1.388364 0.7484 784 0.004002109 0.99946 1.377400 0.7456 785 0.003597769 0.99958 1.386281 0.7507 786 0.003968587 0.99946 1.372886 0.7479 787 0.003746550 0.99942 1.389609 0.7485 788 0.003471009 0.99950 1.376045 0.7561 789 0.003515029 0.99950 1.391333 0.7492 790 0.003740238 0.99946 1.395389 0.7474 791 0.003662769 0.99952 1.381256 0.7506 792 0.003735510 0.99950 1.385208 0.7501 793 0.003777692 0.99936 1.377771 0.7509 794 0.003426865 0.99962 1.377166 0.7532 795 0.003850501 0.99950 1.381726 0.7488 796 0.003854274 0.99942 1.357744 0.7516 797 0.003896037 0.99942 1.376736 0.7502 798 0.003623167 0.99950 1.373237 0.7514 799 0.003336385 0.99958 1.380837 0.7527 800 0.003677030 0.99936 1.367090 0.7544 801 0.003757585 0.99938 
1.381894 0.7539 802 0.003691996 0.99946 1.400380 0.7497 803 0.003496195 0.99952 1.386381 0.7531 804 0.003449135 0.99958 1.383694 0.7471 805 0.003695381 0.99956 1.388623 0.7473 806 0.003624609 0.99942 1.380503 0.7489 807 0.003855928 0.99934 1.399426 0.7477 808 0.003624713 0.99946 1.392185 0.7488 809 0.003710056 0.99944 1.386121 0.7504 810 0.003599518 0.99956 1.388235 0.7457 811 0.003652663 0.99936 1.416115 0.7488 812 0.003694006 0.99932 1.376989 0.7496 813 0.003279331 0.99958 1.370095 0.7499 814 0.003546879 0.99950 1.394469 0.7490 815 0.003811984 0.99938 1.390461 0.7494 816 0.003956810 0.99932 1.386067 0.7486 817 0.003529836 0.99950 1.389819 0.7472 818 0.003880393 0.99950 1.384268 0.7481 819 0.003608852 0.99948 1.376074 0.7542 820 0.003665391 0.99956 1.380117 0.7528 821 0.003703953 0.99942 1.388243 0.7485 822 0.003461364 0.99946 1.401620 0.7469 823 0.003734413 0.99936 1.391865 0.7481 824 0.003790044 0.99940 1.367412 0.7498 825 0.003895878 0.99934 1.378636 0.7477 826 0.003676614 0.99948 1.394193 0.7468 827 0.003633945 0.99954 1.376814 0.7473 828 0.003587150 0.99938 1.366035 0.7526 829 0.003597163 0.99948 1.380416 0.7495 830 0.003318495 0.99962 1.413152 0.7470 831 0.003712012 0.99942 1.371905 0.7471 832 0.003428303 0.99972 1.393037 0.7473 833 0.003723841 0.99938 1.385642 0.7472 834 0.003771480 0.99944 1.398134 0.7466 835 0.003604361 0.99958 1.384968 0.7474 836 0.003694733 0.99952 1.357234 0.7533 837 0.003581399 0.99950 1.394210 0.7483 838 0.003818823 0.99928 1.382098 0.7482 839 0.003807909 0.99934 1.392830 0.7486 840 0.003629315 0.99942 1.379845 0.7494 841 0.003871643 0.99932 1.381863 0.7483 842 0.003787301 0.99952 1.388470 0.7514 843 0.003391408 0.99948 1.392044 0.7521 844 0.003425142 0.99972 1.382607 0.7468 845 0.003776078 0.99942 1.385040 0.7475 846 0.003456869 0.99960 1.382704 0.7475 847 0.003438824 0.99948 1.400075 0.7479 848 0.003668452 0.99948 1.380061 0.7503 849 0.003408735 0.99968 1.403111 0.7484 850 0.003459629 0.99960 1.367557 0.7461 851 0.003659104 0.99944 1.394181 0.7476 852 0.003756752 0.99940 1.380636 0.7505 853 0.003638913 0.99954 1.387339 0.7490 854 0.003841528 0.99934 1.379344 0.7491 855 0.003615711 0.99946 1.382432 0.7498 856 0.003638959 0.99940 1.385617 0.7489 857 0.003685101 0.99940 1.414099 0.7474 858 0.003408466 0.99950 1.380612 0.7508 859 0.003256983 0.99956 1.379402 0.7494 860 0.003420987 0.99956 1.393150 0.7504 861 0.003469174 0.99958 1.391155 0.7528 862 0.003435688 0.99946 1.370455 0.7478 863 0.003549380 0.99956 1.365100 0.7502 864 0.003598281 0.99946 1.376953 0.7494 865 0.003442684 0.99962 1.394220 0.7496 866 0.003403823 0.99956 1.405198 0.7521 867 0.003663774 0.99946 1.386756 0.7482 868 0.003658666 0.99938 1.383487 0.7474 869 0.003466758 0.99950 1.378211 0.7530 870 0.003663639 0.99936 1.400924 0.7450 871 0.003621953 0.99948 1.393332 0.7487 872 0.003522192 0.99950 1.393829 0.7473
ResNet¶
The network inputs are $32 \times 32$ images, with the per-pixel mean subtracted. The first layer is $3 \times 3$ convolutions. Then we use a stack of $6n$ layers with $3 \times 3$ convolutions on the feature maps of sizes $\{32,16,8\}$ respectively, with $2n$ layers for each feature map size. The numbers of filters are $\{16,32,64\}$ respectively. The subsampling is performed by convolutions with a stride of 2. The network ends with a global average pooling, a 10-way fully-connected layer, and softmax. There are totally $6n+2$ stacked weighted layers. The following table summarizes the architecture:
output map size | $32 \times 32$ | $16 \times 16$ | $8 \times 8$ |
---|---|---|---|
# layers | $1+2 n$ | $2 n$ | $2 n$ |
# filters | 16 | 32 | 64 |
When shortcut connections are used, they are connected to the pairs of $3 \times 3$ layers (totally $3 n$ shortcuts). On this dataset we use identity shortcuts in all cases (i.e., option A), so our residual models have exactly the same depth, width, and number of parameters as the plain counterparts.
We use a weight decay of 0.0001 and momentum of 0.9, and adopt the weight initialization in [13] and BN [16] but with no dropout. These models are trained with a minibatch size of 128 on two GPUs. We start with a learning rate of 0.1, divide it by 10 at 32k and 48k iterations, and terminate training at 64k iterations, which is determined on a 45k/5k train/val split. We follow the simple data augmentation in [24] for training: 4 pixels are padded on each side, and a $32 \times 32$ crop is randomly sampled from the padded image or its horizontal flip. For testing, we only evaluate the single view of the original $32 \times 32$ image.
Our implementation for ImageNet follows the practice in [21, 41]. The image is resized with its shorter side randomly sampled in $[256, 480]$ for scale augmentation [41]. A $224 \times 224$ crop is randomly sampled from an image or its horizontal flip, with the per-pixel mean subtracted [21]. The standard color augmentation in [21] is used. We adopt batch normalization (BN) [16] right after each convolution and before activation, following [16]. We initialize the weights as in [13] and train all plain/residual nets from scratch. We use SGD with a mini-batch size of 256. The learning rate starts from 0.1 and is divided by 10 when the error plateaus, and the models are trained for up to $60 \times 10^4$ iterations. We use a weight decay of 0.0001 and a momentum of 0.9. We do not use dropout [14], following the practice in [16]. In testing, for comparison studies we adopt the standard 10-crop testing [21]. For best results, we adopt the fully-convolutional form as in [41, 13], and average the scores at multiple scales (images are resized such that the shorter side is in $\{224, 256, 384, 480, 640\}$). ... See Table 1 for detailed architectures. ...
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).
myconv2d = function(x, filters, strides){
if(strides[1]>1)
x %>%
layer_zero_padding_2d(padding=list(list(1L, 0L),list(1L, 0L))) %>%
layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'valid', use_bias = FALSE)
else
x %>%
layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE)
}
# the basic residual block has 2 conv2d-bn(-relu) layers:
# the 1st conv uses the specified `strides`, the 2nd always uses `strides = 1`
# the `filters` parameter sets the number of filters in both conv layers
# the input is added to the output, so if the input and output shapes differ, a 1x1 conv-bn layer is added on the shortcut to match them
conv_block = function(x, filters, strides = 1L){
#print('x')
#print(x)
out = x %>%
layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
#myconv2d(filters = filters, strides = strides) %>% # conv1
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1, by default channel is last dimension
layer_activation_relu() %>%
layer_conv_2d(filters = filters, strides = 1L, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>% # conv2
layer_batch_normalization(momentum = .9, epsilon = 1e-5) # bn2
#print('out')
#print(out)
if(dim(x)[4]==filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
else shortcut = x %>%
layer_conv_2d(filters = filters, strides = strides, kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
#print('shortcut')
#print(shortcut)
(out + shortcut) %>% layer_activation_relu()
}
# a ResNet stacks several stages (resnet-layers); ImageNet-style implementations use 4 stages with 64, 128, 256, and 512 filters,
# while the CIFAR-10 network built below follows the paper's CIFAR variant: 3 stages with 16, 32, and 64 filters
# the strides of the stages are 1, 2, 2 (and 2 for the 4-stage variant), so from the 2nd stage on the image width/height is halved
# each stage contains a number of blocks/bottlenecks, which determines the total depth of the ResNet
# the network starts with a conv2d-bn-relu stem that outputs the same number of features as stage 1, uses a 3x3 kernel, and keeps width and height (padding = 'same', stride = 1),
# and ends with average pooling, flatten, and a dense softmax layer
conv_bottleneck = function(x, filters, strides = 1L){
expansion = 4L
#print('x')
#print(x)
out = x %>%
layer_conv_2d(filters = filters, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1
layer_activation_relu() %>%
layer_conv_2d(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>% # conv2
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
layer_activation_relu() %>%
layer_conv_2d(filters = expansion * filters, kernel_size = c(1L,1L), padding = 'same', use_bias = FALSE) %>% # conv3
layer_batch_normalization(momentum = .9, epsilon = 1e-5) # bn3
#print('out')
#print(out)
if(dim(x)[4]== expansion * filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
else shortcut = x %>%
layer_conv_2d(filters = expansion * filters, strides = strides, kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
#print('shortcut')
#print(shortcut)
(out + shortcut) %>% layer_activation_relu()
}
conv_layer = function(x, block, filters, num_blocks, strides){
layers = x
print(sprintf('new conv_layer:filters = %d, num_blocks = %d, strides = %d', filters, num_blocks, strides))
for(i in 1:num_blocks) {
layers = layers %>% block(., filters, strides)
strides = 1 # only shrink image once at the first block, if strides > 1
}
layers
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
predictions = inputs %>%
layer_conv_2d(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu %>%
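# num_blocks = 18 per stage, i.e. n = 18, gives 6*18 + 2 = 110 weighted layers (the ResNet-110 configuration from the paper)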
conv_layer(., block = conv_block, filters = 16, num_blocks = 18, strides = 1) %>% # layer 1
conv_layer(., block = conv_block, filters = 32, num_blocks = 18, strides = 2) %>% # layer 2
conv_layer(., block = conv_block, filters = 64, num_blocks = 18, strides = 2) %>% # layer 3
layer_average_pooling_2d(pool_size = 8L) %>%
layer_flatten() %>%
layer_dense(activation = 'softmax', units = 10L)
resnet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
[1] "new conv_layer:filters = 16, num_blocks = 18, strides = 1" [1] "new conv_layer:filters = 32, num_blocks = 18, strides = 2" [1] "new conv_layer:filters = 64, num_blocks = 18, strides = 2"
tf$keras$utils$plot_model(resnet.cifar10, to_file='ResNet_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'ResNet_CIFAR10.png')
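The CIFAR-10 recipe quoted above (SGD with momentum 0.9, weight decay 0.0001, minibatch size 128, learning rate 0.1 divided by 10 at 32k and 48k iterations) is not wired up in this notebook. A minimal sketch of it follows, assuming a TF version whose experimental SGD accepts weight_decay (otherwise the decay can be added via kernel_regularizer); lr_steps and opt_resnet are illustrative names.
lr_steps = tf$keras$optimizers$schedules$PiecewiseConstantDecay(
    boundaries = list(32000L, 48000L), values = list(0.1, 0.01, 0.001)) # divide lr by 10 at 32k and 48k iterations
opt_resnet = tf$keras$optimizers$experimental$SGD(learning_rate = lr_steps, momentum = 0.9, weight_decay = 1e-4)
resnet.cifar10 %>% compile(loss = 'categorical_crossentropy', optimizer = opt_resnet, metrics = list('accuracy'))
# at 50000/128 ~ 391 steps per epoch, 64k iterations correspond to roughly 164 epochs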
EfficientNet B0¶
The implementation follows the Keras reference implementation; the input is resized to $(224, 224)$ inside the model.
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
efficientNetB0.cifar10 = tf$keras$applications$EfficientNetB0(input_tensor = inputs %>%
layer_lambda(function(img) tf$image$resize(img, c(224L, 224L))), include_top=TRUE, weights=NULL, classes=10L)
efficientNetB0.cifar10$compile(loss='categorical_crossentropy',optimizer= 'adam', metrics=list('accuracy'))
mc = callback_model_checkpoint('efficientNetB0_CIFAR10.ckpt', monitor = 'val_accuracy', save_best_only = TRUE, mode = 'auto', save_weights_only = TRUE)
tf$keras$utils$plot_model(efficientNetB0.cifar10, to_file='efficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'efficientNetB0_CIFAR10.png')
Tan, M., & Le, Q. (2019, May). Efficientnet: Rethinking model scaling for convolutional neural networks. In International conference on machine learning (pp. 6105-6114). PMLR.
c_k_i = initializer_variance_scaling(scale=2.0, mode='fan_out')
convBnAct = function(x, filters, kernel_size=3L, stride=1L, padding='valid', bn=TRUE, act='swish', use_bias=FALSE, groups=1L){
#if(stride==2L & kernel_size == 3L) x = x %>%
# layer_zero_padding_2d(padding=list(list(0L, 1L),list(0L, 1L))) %>%
# layer_conv_2d(filters=filters, kernel_size=kernel_size, stride=stride, padding='valid', use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
#else
x = layer_conv_2d(x, filters=filters, kernel_size=kernel_size, stride=stride, padding=padding, use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
if(bn) x = layer_batch_normalization(x, momentum = .9, epsilon = 1e-5)
if(!is.null(act)) x %>% layer_activation(activation = act) else x
}
squeezeExcitation = function(x, filters){
y = x %>%
layer_global_average_pooling_2d(keepdims=TRUE) %>% # with keepdims=TRUE this outputs shape (batch, 1, 1, channels), i.e. adaptive average pooling to 1x1
layer_conv_2d(filters=filters, kernel_size=1L, activation='swish',kernel_initializer = c_k_i) %>%
layer_conv_2d(filters=dim(x)[4], kernel_size=1L, activation='sigmoid',kernel_initializer = c_k_i)
x * y
}
stochasticDepth = function(x, survival_prob = .8){
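# stochastic depth: keep the branch with probability survival_prob (rescaled by 1/survival_prob), otherwise return 0; defined here for reference but not used by MBConvN below, which drops the residual branch per sample via dropout instead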
tf$where(tf$random$uniform(list()) < survival_prob, x / survival_prob, 0)
}
#Mobile-net conv Block with expansion factor N
MBConvN = function(x, filters, kernel_size=3L, stride=1L, expansion_factor = 6L, reduction = 4L, survival_prob = .8){
filters_in = dim(x)[4]
residual = x
intermediate_channels = filters_in * expansion_factor
if(expansion_factor!=1L) x = convBnAct(x, filters = intermediate_channels, kernel_size=1L)
x = x %>%
convBnAct(filters = intermediate_channels, kernel_size = kernel_size, stride = stride, padding = 'same', groups = intermediate_channels) %>%
squeezeExcitation(filters = floor(filters_in/reduction)) %>%
convBnAct(filters = filters, kernel_size = 1L, act=NULL, use_bias=FALSE)
if(stride==1L & filters_in==filters){
residual + layer_dropout(x, rate = .2, noise_shape=c(py_none(), 1L, 1L, 1L))
} else x
}
featureExtractor = function(x, width_mult, depth_mult, last_channel){
kernels = c(3, 3, 5, 3, 5, 5, 3)
expansions = c(1, 6, 6, 6, 6, 6, 6)
scaled_num_channels = 4 * ceiling( c(16, 24, 40, 80, 112, 192, 320) * width_mult/4 )
scaled_num_layers = c(1, 2, 2, 3, 3, 4, 1) * depth_mult
strides = c(1, 2, 2, 2, 1, 2, 1)
x = x %>% convBnAct(filters = 4 * ceiling(32 * width_mult/4), kernel_size = 3L, stride = 2L, padding = 'same')
for(i in 1:length(scaled_num_layers)){
for(j in 1:scaled_num_layers[i]){
x = x %>% MBConvN(filters = scaled_num_channels[i], kernel_size = kernels[i], stride = if(j==1) strides[i] else 1L, expansion_factor = expansions[i])
}
}
x %>% convBnAct(filters = last_channel, kernel_size = 1L, stride = 1L, padding='valid')
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
predictions = inputs %>%
featureExtractor(width_mult = 1, depth_mult = 1, last_channel = 1280) %>%
layer_global_average_pooling_2d() %>%
layer_dropout(.2) %>%
layer_dense(units = 10L, activation='softmax')
efficientNetB0.cifar10 = keras_model(inputs = inputs, outputs = predictions)
#efficientNetB0.cifar10
tf$keras$utils$plot_model(efficientNetB0.cifar10, to_file='efficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'efficientNetB0_CIFAR10.png')
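As with the keras applications version above, this hand-built model still needs to be compiled before training; the same settings can be reused:
efficientNetB0.cifar10 %>% compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = list('accuracy'))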
WideResNet-28-10¶
Compared to the original architecture [11], in [13] the order of batch normalization, activation and convolution in the residual block was changed from conv-BN-ReLU to BN-ReLU-conv. As the latter was shown to train faster and achieve better results, we don't consider the original version.
The general structure of our residual networks is illustrated in table 1: it consists of an initial convolutional layer conv1 that is followed by 3 groups (each of size $N$ ) of residual blocks conv2, conv3 and conv4, followed by average pooling and final classification layer. The size of conv1 is fixed in all of our experiments, while the introduced widening factor $k$ scales the width of the residual blocks in the three groups conv2-4 (e.g., the original «basic» architecture is equivalent to $k=1$ ).
group name | output size | block type $=B(3,3)$ |
---|---|---|
conv1 | $32 \times 32$ | $[3 \times 3,16]$ |
conv2 | $32 \times 32$ | $\left[\begin{array}{c}3 \times 3,16 \times \mathrm{k} \\3 \times 3,16 \times \mathrm{k}\end{array}\right] \times \mathrm{N}$ |
conv3 | $16 \times 16$ | $\left[\begin{array}{c}3 \times 3,32 \times \mathrm{k} \\3 \times 3,32 \times \mathrm{k}\end{array}\right] \times \mathrm{N}$ |
conv4 | $8 \times 8$ | $\left[\begin{array}{c}3 \times 3,64 \times \mathrm{k} \\3 \times 3,64 \times \mathrm{k}\end{array}\right] \times \mathrm{N}$ |
avg-pool | $1 \times 1$ | $[8 \times 8]$ |
As widening increases the number of parameters we would like to study ways of regularization. Residual networks already have batch normalization that provides a regularization effect, however it requires heavy data augmentation, which we would like to avoid, and it's not always possible. We add a dropout layer into each residual block between convolutions as shown in fig. 1(d) and after ReLU to perturb batch normalization in the next residual block and prevent it from overfitting. In very deep residual networks that should help deal with diminishing feature reuse problem enforcing learning in different residual blocks. ... We trained networks with dropout inserted into residual block between convolutions on all datasets. We used cross-validation to determine dropout probability values, 0.3 on CIFAR and 0.4 on SVHN. Also, we didn't have to increase number of training epochs compared to baseline networks without dropout.
In all our experiments we use SGD with Nesterov momentum and cross-entropy loss. The initial learning rate is set to 0.1, weight decay to 0.0005, dampening to 0, momentum to 0.9 and minibatch size to 128. On CIFAR the learning rate is dropped by 0.2 at 60, 120 and 160 epochs and we train for a total of 200 epochs. On SVHN the initial learning rate is set to 0.01 and we drop it at 80 and 120 epochs by 0.1, training for a total of 160 epochs. Our implementation is based on Torch [6]. We use [21] to reduce memory footprints of all our networks. For ImageNet experiments we used the fb.resnet.torch implementation [10]. Our code and models are available at https://github.com/szagoruyko/wide-residual-networks.
Some implementations on GitHub by the authors and others:
- https://github.com/szagoruyko/wide-residual-networks
- https://github.com/titu1994/Wide-Residual-Networks/
- https://github.com/paradoxysm/wideresnet
Zagoruyko, S., & Komodakis, N. (2016, January). Wide Residual Networks. In British Machine Vision Conference 2016. British Machine Vision Association.
conv_block = function(x, base, k, dropout = 0, strides = c(1L, 1L), shortcut = FALSE){
h = x %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu()
out = h %>%
layer_conv_2d(base*k, c(3L,3L), padding = 'same', strides=strides, kernel_initializer='he_normal', use_bias = FALSE) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_dropout(rate = dropout) %>% # conv1
layer_conv_2d(base*k, c(3L,3L), padding = 'same', kernel_initializer='he_normal', use_bias = FALSE)
if(shortcut) skip = h %>%
layer_conv_2d(base*k, c(1L, 1L), strides = strides, kernel_initializer='he_normal', use_bias = FALSE)
else skip = x
layer_add(list(out, skip))
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
k = 10
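# note: with 3 conv_blocks per group the depth is 6*3 + 4 = 22 (a WRN-22-10); WRN-28-10 in the paper uses N = 4 blocks per group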
predictions = inputs %>%
layer_conv_2d(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
conv_block(., 16, k, shortcut=TRUE) %>% conv_block(., 16, k, dropout=.4) %>% conv_block(., 16, k, dropout=.3) %>%
conv_block(., 32, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 32, k, dropout=.4) %>% conv_block(., 32, k, dropout=.4) %>%
conv_block(., 64, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 64, k, dropout=.4) %>% conv_block(., 64, k, dropout=.4) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% layer_activation_relu() %>% layer_average_pooling_2d(pool_size = 8L) %>% layer_flatten() %>%
layer_dense(activation = 'softmax', units = 10L)
wrn_28_10.cifar10 = keras_model(inputs = inputs, outputs = predictions)
tf$keras$utils$plot_model(wrn_28_10.cifar10, to_file='WRN_28_10_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'WRN_28_10_CIFAR10.png')
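The schedule quoted above (SGD with Nesterov momentum 0.9, initial learning rate 0.1 multiplied by 0.2 at epochs 60, 120 and 160, weight decay 0.0005, 200 epochs) is likewise not set up in this notebook. A minimal sketch follows, with the weight decay omitted (it would go into kernel_regularizer or an optimizer that supports it) and wrn_lr as an illustrative helper.
wrn_lr = function(epoch, lr) 0.1 * 0.2^findInterval(epoch, c(60, 120, 160)) # multiply the lr by 0.2 at epochs 60, 120 and 160
wrn_28_10.cifar10 %>% compile(loss = 'categorical_crossentropy',
    optimizer = optimizer_sgd(learning_rate = 0.1, momentum = 0.9, nesterov = TRUE),
    metrics = list('accuracy'))
# history_wrn = wrn_28_10.cifar10 %>% fit(..., epochs = 200L, callbacks = list(callback_learning_rate_scheduler(wrn_lr)))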
MLP-Mixer¶
Tolstikhin, Ilya O., Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung et al. "Mlp-mixer: An all-mlp architecture for vision." Advances in neural information processing systems 34 (2021): 24261-24272.
layer_patch = Layer(
classname = 'Patch',
inherit = tf$keras$layers$Layer,
initialize = function(patch_size, projection_dim, ...) {
super$initialize()
#print('initialize')
self$patch_size = patch_size
self$projection_dim = projection_dim
},
build = function(input_shape) {
# assert: image is square, patch is square
#print('build')
self$num_patches = as.integer(floor(input_shape[[2]]/self$patch_size)^2)
self$projection = tf$keras$layers$Dense(units = self$projection_dim)
self$position_embedding = tf$keras$layers$Embedding(input_dim = self$num_patches, output_dim = self$projection_dim)
self$built=TRUE
#print('end of build')
return()
},
call = function(images) {
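# split the image into non-overlapping patch_size x patch_size patches, flatten each patch, project it to projection_dim, and add a learned position embedding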
#print('call')
patches = tf$image$extract_patches(
images = images,
sizes = c(1L, self$patch_size, self$patch_size, 1L),
strides = c(1L, self$patch_size, self$patch_size, 1L),
rates = c(1L, 1L, 1L, 1L),
padding = 'VALID'
)
patches = tf$reshape(
patches,
shape=c(tf$shape(images)[1L], -1L, tf$reduce_prod(tf$shape(patches)[4L]))
)
positions = tf$range(start=0, limit = self$num_patches, delta = 1)
self$projection(patches) + self$position_embedding(positions)
}
)
mlp_mixer_block = function(x){
num_patches = dim(x)[2]
projection_dim = dim(x)[3]
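# token-mixing MLP: transpose to (batch, channels, patches), apply dense layers across the patch dimension, then transpose back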
mlp1 = x %>%
layer_layer_normalization(epsilon = 1e-6) %>%
tf$linalg$matrix_transpose(.) %>%
layer_dense(units = num_patches, activation='gelu') %>%
layer_dense(units = num_patches) %>%
layer_dropout(rate = .2) %>%
tf$linalg$matrix_transpose(.)
x = layer_add(list(x, mlp1))
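# channel-mixing MLP: apply dense layers across the channel (projection) dimension of each patch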
mlp2 = x %>%
layer_layer_normalization(epsilon = 1e-6) %>%
layer_dense(units = projection_dim, activation='gelu') %>%
layer_dense(units = projection_dim) %>%
layer_dropout(rate = .2)
layer_add(list(x, mlp2))
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
predictions = inputs %>%
layer_resizing(64L, 64L) %>%
layer_patch(patch_size = 8L, projection_dim = 256L) %>%
mlp_mixer_block %>%
mlp_mixer_block %>%
mlp_mixer_block %>%
mlp_mixer_block %>%
layer_global_average_pooling_1d %>%
layer_dropout(rate = .2) %>%
layer_dense(units=10L, activation = 'softmax')
mlp_mixer.cifar10 = keras_model(inputs = inputs, outputs = predictions)
mlp_mixer.cifar10 %>% compile(loss='categorical_crossentropy',optimizer= 'adam',metrics=list('accuracy'))
tf$keras$utils$plot_model(mlp_mixer.cifar10, to_file='MLPMixer_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'MLPMixer_CIFAR10.png')
Bayesian NN architectures for Message Passing¶
Bayesian DenseNet-BC¶
my_divergence_fn = function(q, p, ignore=NULL){
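# KL(q || p) of the variational posterior from the prior, down-weighted by the training-set size (50000) and a further factor of 500 (presumably an extra tempering of the prior term)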
tfp$distributions$kl_divergence(q, p) / 50000 /500
}
conv_bottleneckEP = function(x, growth_rate){
conv = x %>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1
layer_activation_Bayesian('relu') %>%
layer_conv_2d_Bayesian(filters = 4 * growth_rate, kernel_size = c(1L,1L), use_bias = FALSE, activation=NULL) %>% # conv1
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
layer_activation_Bayesian('relu') %>%
layer_conv_2d_Bayesian(filters = growth_rate, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE, activation=NULL) # conv2
# channel dimension 4 for channel last and 2 for channel first, keras default is channel last
layer_lambda(list(x, conv), function(x_conv){
c(x, conv) %<-% x_conv
c(c(conv[[1]], x[[1]]) %>% layer_concatenate(axis=-1L),
c(conv[[2]], x[[2]]) %>% layer_concatenate(axis=-1L) )
})
}
# a dense_layer followed by a transition halves the width and height; output channels = reduction * (input channels + growth_rate * num_blocks)
dense_layerEP = function(x, block, growth_rate, num_blocks){
layers = x
for(i in 1:num_blocks) layers = layers %>% block(., growth_rate)
print(sprintf('dense_layer: growth_rate = %d, num_blocks = %d', growth_rate, num_blocks))
print(layers)
layers
}
transitionEP = function(x, reduction = .5){
x %>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_Bayesian('relu') %>%
layer_conv_2d_Bayesian(filters = floor(reduction*dim(x[[1]])[4]), kernel_size = c(1L,1L), use_bias = FALSE, activation=NULL) %>%
layer_average_pooling_2d_Bayesian(pool_size = 2L)
}
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))
g.r. = 12L
predictionsEP = inputsEP %>%
layer_conv_2d_Bayesian(filters = 2*g.r., kernel_size = 3L, padding = 'same', use_bias = FALSE, activation = NULL) %>%
dense_layerEP(., block = conv_bottleneckEP, growth_rate = g.r., num_blocks = 16) %>% transitionEP(.) %>% # layer 1
dense_layerEP(., block = conv_bottleneckEP, growth_rate = g.r., num_blocks = 16) %>% transitionEP(.) %>% # layer 2
dense_layerEP(., block = conv_bottleneckEP, growth_rate = g.r., num_blocks = 16) %>% transitionEP(.) %>% # layer 3
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_Bayesian('relu') %>%
layer_average_pooling_2d_Bayesian(pool_size = 4L) %>%
layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
layer_dense_Bayesian(activation = NULL, units = 10L) %>%
layer_lambda(f = function(mv){
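# predictive class probabilities: draw 50 samples of the logits from a diagonal Gaussian with the propagated mean and variance, apply softmax, and average over the samples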
tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
})
densenet.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
[1] "dense_layer: growth_rate = 12, num_blocks = 16" [[1]] KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='lambda_92/concatenate/concat:0', description="created by layer 'lambda_92'") [[2]] KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='lambda_92/concatenate_1/concat:0', description="created by layer 'lambda_92'") [1] "dense_layer: growth_rate = 12, num_blocks = 16" [[1]] KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='lambda_108/concatenate/concat:0', description="created by layer 'lambda_108'") [[2]] KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='lambda_108/concatenate_1/concat:0', description="created by layer 'lambda_108'") [1] "dense_layer: growth_rate = 12, num_blocks = 16" [[1]] KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='lambda_124/concatenate/concat:0', description="created by layer 'lambda_124'") [[2]] KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='lambda_124/concatenate_1/concat:0', description="created by layer 'lambda_124'")
tf$keras$utils$plot_model(densenet.cifar10EP, to_file='Bayesian_DenseNetBC_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_DenseNetBC_CIFAR10.png')
densenet.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
lr = tf$keras$optimizers$schedules$CosineDecayRestarts(initial_learning_rate = 1, first_decay_steps = 50000/100*100)
opt = tf$keras$mixed_precision$LossScaleOptimizer(tf$keras$optimizers$experimental$SGD(learning_rate = 1, momentum=1-100/50000, nesterov=TRUE, global_clipnorm = 20))
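# note: the optimizer above uses a constant learning rate of 1; the cosine-restart schedule `lr` defined above is only plotted further below (pass learning_rate = lr to actually use it)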
#opt = tf$keras$optimizers$experimental$SGD(learning_rate = lr, momentum= 1-200/50000, nesterov=TRUE, global_clipnorm = 20) # if use `tf$keras$mixed_precision$set_global_policy('mixed_float16')`
densenet.cifar10EP$compile(loss='categorical_crossentropy',optimizer= opt, metrics=list('accuracy'))
mc = callback_model_checkpoint('Bayesian_DenseNet_CIFAR10-{epoch:03d}_{val_loss:.2f}_{val_accuracy:.4f}.ckpt', monitor = 'val_accuracy', save_best_only = FALSE, mode = 'auto', save_weights_only = TRUE, save_freq = 'epoch')
historyEP = densenet.cifar10EP %>% fit(
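# the two model inputs carry the input mean and its variance: the images are passed as the mean and a small constant (1e-6) as the variance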
cifar10.datagen$flow(list(cifar10$train$x, cifar10$train$x*0 + 1e-6), to_categorical(cifar10$train$y, num_classes=10L), batch_size = 100L, shuffle=TRUE),
epochs = 100 , #* (1 + 2 +4),
validation_data = cifar10.datagen$flow(list(cifar10$test$x, cifar10$test$x*0 + 1e-6), to_categorical(cifar10$test$y, num_classes=10L), batch_size = 200L),
callbacks = list(mc, callback_terminate_on_naan(), callback_csv_logger(filename = 'Bayesian_DenseNet_CIFAR10.txt')))
print(data.frame(historyEP$metrics))
plot(historyEP)
loss accuracy val_loss val_accuracy [training log for epochs 1-100: loss falls from 1.508 to ~0.005 and training accuracy rises from 0.448 to ~0.998; validation accuracy improves from 0.561 to ~0.92-0.93 (best 0.9282 at epoch 91); final epoch 100: loss 0.0051, accuracy 0.99848, val_loss 0.3470, val_accuracy 0.9242]
# cosine decay restarts schedule
lr = tf$keras$optimizers$schedules$CosineDecayRestarts(initial_learning_rate = 1, first_decay_steps = 50000/100*100)
plot( (1:(50000/100*100*(1+2+4)))/(50000/100), lr(1:(50000/100*100*(1+2+4))), type='l', xlab='epoch', ylab='lr schedule', main='cosine decay annealing with warm restart')
abline(v=100*cumsum(c(1,2,4)), h = 0, col='red')
Bayesian ResNet¶
conv_blockEP = function(xP, filters, strides = 1L){
#print('x')
#print(xP)
out = xP %>%
layer_conv_2d_Bayesian(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1, by default channel is last dimension
layer_activation_Bayesian('relu') %>%
layer_conv_2d_Bayesian(filters = filters, strides = 1L, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) #%>% # conv2
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) # bn2
#print('out')
#print(out)
if(dim(xP[[1]])[4]== filters & strides==1L) shortcut = xP # channel dimension 4 for channel last and 2 for channel first
else shortcut = xP %>%
layer_conv_2d_Bayesian(filters = filters, strides = rep(strides,2), kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) #%>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
#print('shortcut')
#print(shortcut)
layer_lambda(list(out, shortcut), function(out_shortcut){
c(out, shortcut) %<-% out_shortcut
list(out[[1]] + shortcut[[1]], out[[2]]+shortcut[[2]]) %>% layer_activation_Bayesian('relu')
})
}
conv_bottleneckEP = function(xP, filters, strides = 1L){
expansion = 4L
#print('x')
#print(xP)
out = xP %>%
layer_conv_2d_Bayesian(filters = filters, kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>%
#layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5) %>% # bn1
layer_activation_Bayesian('relu') %>%
layer_conv_2d_Bayesian(filters = filters, strides = rep(strides,2), kernel_size = c(3L,3L), padding = 'same', use_bias = FALSE) %>% # conv2
#layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5) %>% # bn2
layer_activation_Bayesian('relu') %>%
layer_conv_2d_Bayesian(filters = expansion * filters, kernel_size = c(1L,1L), padding = 'same', use_bias = FALSE, activation = NULL) #%>% # conv3
#layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5) # bn3
#print("out")
#print(out)
if(dim(xP[[1]])[4]== expansion * filters & strides==1L) shortcut = xP # channel dimension 4 for channel last and 2 for channel first
else shortcut = xP %>%
layer_conv_2d_Bayesian(filters = expansion * filters, strides = rep(strides,2), kernel_size = 1L, use_bias = FALSE, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c'))) #%>%
#layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
#print('shortcut')
#print(shortcut)
layer_lambda(list(out, shortcut), function(out_shortcut){
c(out, shortcut) %<-% out_shortcut
list(out[[1]] + shortcut[[1]], out[[2]]+shortcut[[2]]) %>% layer_activation_Bayesian('relu')
})
}
conv_layerEP = function(xP, block, filters, num_blocks, strides){
layers = xP
print(sprintf('new conv_layer:filters = %d, num_blocks = %d, strides = %d', filters, num_blocks, strides))
for(i in 1:num_blocks) {
layers = layers %>% block(., filters, strides)
strides = as.integer(1L) # only shrink image once at the first block, if strides > 1
}
layers
}
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))
predictionsEP = inputsEP %>%
layer_conv_2d_Bayesian(filters = 64, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
#layer_batch_normalization(axis = 1L, momentum = .9, epsilon = 1e-5) %>%
layer_activation_Bayesian('relu') %>%
conv_layerEP(., block = conv_blockEP, filters = 16L, num_blocks = 18, strides = 1L) %>% # layer 1
conv_layerEP(., block = conv_blockEP, filters = 32L, num_blocks = 18, strides = 2L) %>% # layer 2
conv_layerEP(., block = conv_blockEP, filters = 64L, num_blocks = 18, strides = 2L) %>% # layer 3
layer_average_pooling_2d_Bayesian(pool_size = 8L) %>%
layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
layer_dense_Bayesian(units = 10L, activation = NULL) %>%
layer_lambda(f = function(mv){
tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
})
resnet.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
resnet.cifar10EP$compile(loss='categorical_crossentropy',optimizer= 'adam',metrics=list('accuracy'))
[1] "new conv_layer:filters = 16, num_blocks = 18, strides = 1" [1] "new conv_layer:filters = 32, num_blocks = 18, strides = 2" [1] "new conv_layer:filters = 64, num_blocks = 18, strides = 2"
tf$keras$utils$plot_model(resnet.cifar10EP, to_file='Bayesian_ResNet_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_ResNet_CIFAR10.png')
Bayesian EfficientNet B0¶
c_k_i = initializer_variance_scaling(scale=2.0, mode='fan_out')
convBnAct = function(xP, filters, kernel_size=3L, strides=1L, padding='valid', bn=TRUE, act='swish', use_bias=FALSE, groups=1L){
#if(stride==2L & kernel_size == 3L) x = x %>%
# layer_zero_padding_2d(padding=list(list(0L, 1L),list(0L, 1L))) %>%
# layer_conv_2d(filters=filters, kernel_size=kernel_size, stride=stride, padding='valid', use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
#else
xP = layer_conv_2d_Bayesian(xP, filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
#if(bn) x = layer_batch_normalization(x, momentum = .9, epsilon = 1e-5)
if(!is.null(act)) xP %>% layer_activation_Bayesian(activation = act) else xP
}
squeezeExcitation = function(x, filters){
y = x %>%
layer_average_pooling_2d_Bayesian(pool_size = dim(x[[1]])[2:3]) %>% ## ??? layer_global_average_pooling_2d(keepdims=TRUE) %>% # is this adaptive average pool 2d with output dim 1x2?
layer_conv_2d_Bayesian(filters=filters, kernel_size=1L, activation='swish', kernel_initializer = c_k_i) %>%
layer_conv_2d_Bayesian(filters=dim(x[[1]])[4], kernel_size=1L, activation='sigmoid', kernel_initializer = c_k_i)
#list(x[[1]]*y[[1]], x[[2]]*y[[1]]^2 + x[[1]]^2 * y[[2]])#x * y
layer_lambda(list(x,y), function(xy){
c(x,y) %<-% xy
list(x[[1]]*y[[1]], x[[2]]*y[[1]]^2 + x[[1]]^2 * y[[2]])
})#x * y
}
stochasticDepth = function(x, survival_prob = .8){
tf$where(tf$random$uniform(list()) < survival_prob, x / survival_prob, 0)
}
#Mobile-net conv Block with expansion factor N
MBConvN = function(xP, filters, kernel_size=3L, strides=1L, expansion_factor = 6L, reduction = 4L, survival_prob = .8){
filters_in = dim(xP[[1]])[4]
residual = xP
intermediate_channels = filters_in * expansion_factor
if(expansion_factor!=1L) xP = convBnAct(xP, filters = intermediate_channels, kernel_size=1L)
xP = xP %>%
convBnAct(filters = intermediate_channels, kernel_size = kernel_size, strides = strides, padding = 'same', groups = intermediate_channels) %>%
squeezeExcitation(filters = floor(filters_in/reduction)) %>%
convBnAct(filters = filters, kernel_size = 1L, act=NULL, use_bias=FALSE)
if(strides==1L & filters_in==filters)
layer_lambda(list(xP, residual), function(xP_residual){
c(xP, residual) %<-% xP_residual
list(residual[[1]] + xP[[1]], residual[[2]] + xP[[2]]) #layer_dropout(xP, rate = .2, noise_shape=c(py_none(), 1L, 1L, 1L))
})
else xP
}
featureExtractor = function(x, width_mult, depth_mult, last_channel){
kernels = c(3L, 3L, 5L, 3L, 5L, 5L, 3L)
expansions = c(1L, 6L, 6L, 6L, 6L, 6L, 6L)
scaled_num_channels = as.integer(4 * ceiling( c(16, 24, 40, 80, 112, 192, 320) * width_mult/4 ))
scaled_num_layers = as.integer(c(1, 2, 2, 3, 3, 4, 1) * depth_mult)
strides = c(1L, 2L, 2L, 2L, 1L, 2L, 1L)
x = x %>% convBnAct(filters = 4 * ceiling(32 * width_mult/4), kernel_size = 3L, strides = 2L, padding = 'same')
for(i in 1:length(scaled_num_layers)){
for(j in 1:scaled_num_layers[i]){
x = x %>% MBConvN(filters = scaled_num_channels[i], kernel_size = kernels[i], strides = if(j==1) strides[i] else 1L, expansion_factor = expansions[i])
}
}
x %>% convBnAct(filters = last_channel, kernel_size = 1L, strides = 1L, padding='valid')
}
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))
predictionsEP = inputsEP %>%
featureExtractor(width_mult = 1, depth_mult = 1, last_channel = 1280) %>%
layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
layer_dense_Bayesian(units = 10L, activation = NULL) %>%
layer_lambda(f = function(mv){
tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
})
efficientNetB0.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
tf$keras$utils$plot_model(efficientNetB0.cifar10EP, to_file='Bayesian_EfficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_EfficientNetB0_CIFAR10.png')
Bayesian WideResNet-28-10¶
conv_blockEP = function(xP, base, k, strides = c(1L, 1L), shortcut = FALSE){
h = xP %>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_Bayesian('relu')
out = h %>%
layer_conv_2d_Bayesian(filters = base*k, kernel_size = c(3L,3L), padding = 'same', strides=strides, kernel_initializer='he_normal', use_bias = FALSE) %>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_Bayesian('relu') %>%
#layer_dropout(rate = dropout) %>% # conv1
layer_conv_2d_Bayesian(filters = base*k, kernel_size = c(3L,3L), padding = 'same', kernel_initializer='he_normal', use_bias = FALSE)
if(shortcut) skip = h %>%
layer_conv_2d_Bayesian(filters = base*k, kernel_size = c(1L, 1L), strides = strides, kernel_initializer='he_normal', use_bias = FALSE)
else skip = xP
layer_lambda(list(out, skip), function(out_skip){
c(out, skip) %<-% out_skip
list(out[[1]] + skip[[1]], out[[2]]+skip[[2]]) #layer_add(list(out, skip))
})
}
inputsEP = list(layer_input(shape = dim(cifar10$train$x)[-1]),layer_input(shape = dim(cifar10$train$x)[-1]))
k = 10L
predictionsEP = inputsEP %>%
layer_conv_2d_Bayesian(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', use_bias = FALSE) %>%
conv_blockEP(., 16, k, shortcut=TRUE) %>% conv_blockEP(., 16, k) %>% conv_blockEP(., 16, k) %>%
conv_blockEP(., 32, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_blockEP(., 32, k) %>% conv_blockEP(., 32, k) %>%
conv_blockEP(., 64, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_blockEP(., 64, k) %>% conv_blockEP(., 64, k) %>%
#layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% layer_activation_relu() %>% layer_average_pooling_2d(pool_size = 8L) %>% layer_flatten() %>%
layer_activation_Bayesian('relu') %>%
layer_average_pooling_2d_Bayesian(pool_size = 4L) %>%
layer_lambda(function(xP){list(tf$reshape(xP[[1]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))), tf$reshape(xP[[2]], shape=as.integer(c(-1L, prod(dim(xP[[1]])[-1])))))}) %>% #layer_flatten() %>%
layer_dense_Bayesian(units = 10L, activation = NULL) %>%
layer_lambda(f = function(mv){
tf$reduce_mean(tf$nn$softmax(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(50L), axis=-1L), axis=0L)
})
wrn_28_10.cifar10EP = keras_model(inputs = inputsEP, outputs = predictionsEP)
tf$keras$utils$plot_model(wrn_28_10.cifar10EP, to_file='Bayesian_WRN-28-10_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'Bayesian_WRN-28-10_CIFAR10.png')
MLP-Mixer¶
Bayesian NN constructed from tfp.layers¶
We bring tfp.layers to competitive performance by
- sampling multiple non-Bayesian neural networks per train/test step (AverageCase),
- wrapping around the call function to implement natural gradient, and
- considering adversarial training (WorstCase).
A usage sketch follows the two Model subclasses below.
AverageCase(tf$keras$Model) %py_class% {
test_step = function(self, data){
x = data[[1]]
y = data[[2]]
if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL
#y_pred = self$call(x)
y_pred = tf$reduce_mean(tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L), axis=0L)
self$compiled_loss(y, y_pred, regularization_losses=self$losses)
self$compiled_metrics$update_state(y, y_pred, sample_weight=sample_weight)
mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
}
train_step = function(self, data) {
x = data[[1]]
y = data[[2]]
if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL
with(tf$GradientTape() %as% tape, {
#y_pred = tf$reduce_logsumexp(tf$stack(lapply(1:4,function(n)self$call(x)), axis=0L), axis=0L)
y_pred = tf$reduce_mean(tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L), axis=0L)
loss = self$compiled_loss(y, y_pred, regularization_losses=self$losses)
})
gradients = tape$gradient(loss, self$trainable_variables)
self$optimizer$apply_gradients(mapply(list, gradients, self$trainable_variables, SIMPLIFY =FALSE))
self$compiled_metrics$update_state(y, y_pred, sample_weight=sample_weight)
mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
}
}
WorstCase(tf$keras$Model) %py_class% {
test_step = function(self, data){
x = data[[1]]
y = data[[2]]
if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL
#y_pred = self$call(x)
y_pred = tf$reduce_mean(tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L), axis=0L)
self$compiled_loss(y, y_pred, regularization_losses=self$losses)
self$compiled_metrics$update_state(y, y_pred, sample_weight=sample_weight)
mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
}
train_step = function(self, data) {
x = data[[1]]
y = data[[2]]
if(length(data)==3) sample_weight = data[[3]] else sample_weight = NULL
with(tf$GradientTape() %as% tape, {
y_pred = tf$stack(lapply(1:8,function(n)self$call(x)), axis=0L)
self$compiled_loss(y, tf$reduce_mean(y_pred, axis=0L), regularization_losses=self$losses) # loss to display
loss_ = tf$losses$categorical_crossentropy(tf$einsum('b,ij->bij',tf$ones_like(y_pred[,1,1]), y), y_pred)
loss_ = tf$reduce_max(tf$reduce_mean(loss_, axis=1L))
loss = loss_ + self$losses # loss to optimize
})
gradients = tape$gradient(loss, self$trainable_variables)
self$optimizer$apply_gradients(mapply(list, gradients, self$trainable_variables, SIMPLIFY =FALSE))
self$compiled_metrics$update_state(y, tf$reduce_mean(y_pred, axis=0L), sample_weight=sample_weight)
mapply(assign, sapply(self$metrics, function(m) m$name), lapply(self$metrics, function(m) m$result()))
}
}
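A usage sketch, assuming an arbitrary small architecture (the layer sizes below are placeholders, not the models benchmarked here): the functional graph is passed to AverageCase (or WorstCase) exactly as it would be passed to keras_model. The fit line is left commented because it relies on the MNIST pipelines defined later in this notebook.
# Hedged usage sketch: wrap a functional graph in AverageCase so each
# train/test step averages several posterior samples (sizes are placeholders).
sketch_inputs = layer_input(shape = 784L)
sketch_outputs = sketch_inputs %>%
layer_dense_reparameterization(units = 32L, activation = 'relu') %>%
layer_dense_reparameterization(units = 10L, activation = 'softmax')
sketch_model = AverageCase(inputs = sketch_inputs, outputs = sketch_outputs)
sketch_model$compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = list('accuracy'))
#sketch_model %>% fit(mnist_train_ds, epochs = 1L, validation_data = mnist_test_ds) # pipelines defined below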
my.posterior = tfp$layers$default_mean_field_normal_fn(untransformed_scale_constraint = function(w){
tf$clip_by_value(w, tfp$math$softplus_inverse(tf$experimental$numpy$finfo('float32')$resolution*10.0), 1.0)
})
layer_dense_reparameterization_NG = Layer(
classname = 'ReparameterizationNGDense',
inherit = tf$keras$layers$Layer,
initialize = function(...) {
super$initialize()
self$denseReparameterization = tfp$layers$DenseReparameterization(...)
},
build = function(input_shape) {
self$denseReparameterization$build(input_shape)
self$`__forward__` = tf$custom_gradient(function(inputs){
with(tf$GradientTape(persistent=TRUE) %as% tape, {
tape$watch(inputs)
outputs = self$denseReparameterization$call(inputs)
})
grad_fn = function(upstream, variables){
c(dx, dv) %<-% tape$gradient(outputs, list(inputs, variables), upstream)
dv2 = lapply(1:length(variables), function(n){
switch(
gsub('.+/(bias_posterior_loc|kernel_posterior_loc|kernel_posterior_untransformed_scale):[0-9]+','\\1', variables[[n]]$name),
bias_posterior_loc = dv[[n]],
kernel_posterior_loc = dv[[n]] * tf$math$softplus(self$denseReparameterization$variables[[2]])^2,
kernel_posterior_untransformed_scale = dv[[n]] * tf$math$softplus(self$denseReparameterization$variables[[2]])^4 / tf$math$sigmoid(self$denseReparameterization$variables[[2]])^2 * 2)
})
list(dx, dv2)
}
list(outputs, grad_fn)
})
# self$`__forward__` = function(inputs){
# self$denseReparameterization$call(inputs)
# }
self$built=TRUE
},
call = function(inputs) {
self$`__forward__`(inputs)
# self$denseReparameterization$call(inputs)
}
)
layer_conv2d_reparameterization_NG = Layer(
classname = 'ReparameterizationNGConv2D',
inherit = tf$keras$layers$Layer,
initialize = function(...) {
super$initialize()
self$conv2dReparameterization = tfp$layers$Convolution2DReparameterization(...)
},
build = function(input_shape) {
self$conv2dReparameterization$build(input_shape)
self$`__forward__` = tf$custom_gradient(function(inputs){
with(tf$GradientTape(persistent=TRUE) %as% tape, {
tape$watch(inputs)
outputs = self$conv2dReparameterization$call(inputs)
})
grad_fn = function(upstream, variables){
c(dx, dv) %<-% tape$gradient(outputs, list(inputs, variables), upstream)
dv2 = lapply(1:length(variables), function(n){
switch(
gsub('.+/(bias_posterior_loc|kernel_posterior_loc|kernel_posterior_untransformed_scale):[0-9]+','\\1', variables[[n]]$name),
bias_posterior_loc = dv[[n]],
kernel_posterior_loc = dv[[n]] * tf$math$softplus(self$conv2dReparameterization$variables[[2]])^2,
kernel_posterior_untransformed_scale = dv[[n]] * tf$math$softplus(self$conv2dReparameterization$variables[[2]])^4 / tf$math$sigmoid(self$conv2dReparameterization$variables[[2]])^2 * 2)
})
list(dx, dv2)
}
list(outputs, grad_fn)
})
self$built=TRUE
},
call = function(inputs) {
self$`__forward__`(inputs)
}
)
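The two wrappers are intended as drop-in replacements for the tfp.layers they wrap; a hedged sketch of plugging the dense version into a small functional model (the sizes are placeholders, and my.posterior is the clipped posterior defined above):
# Hedged sketch: the NG wrappers are used like the layers they wrap.
ng_inputs = layer_input(shape = 784L)
ng_outputs = ng_inputs %>%
layer_dense_reparameterization_NG(units = 32L, activation = 'relu', kernel_posterior_fn = my.posterior) %>%
layer_dense_reparameterization_NG(units = 10L, activation = 'softmax', kernel_posterior_fn = my.posterior)
ng_model = keras_model(inputs = ng_inputs, outputs = ng_outputs)
ng_model$compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = list('accuracy'))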
DenseNet-BC from tfp.layers
¶
divergence_fn = function(q, p, ignore){
0.0 #tfp$distributions$kl_divergence(q, p) /50000
}
#my.posterior = tfp$layers$default_mean_field_normal_fn(untransformed_scale_initializer=tf$random_normal_initializer(mean=-9.0, stddev=0.05))
my.posterior = tfp$layers$default_mean_field_normal_fn()
# output width and height are the same, output channels = input channels + growth rate
conv_bottleneck = function(x, growth_rate){
conv = x %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1
layer_activation_relu() %>%
layer_conv_2d_reparameterization(filters = 4 * growth_rate, kernel_size = c(1L,1L), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv1
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
layer_activation_relu() %>%
layer_conv_2d_reparameterization(filters = growth_rate, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) # conv2
list(conv, x) %>% layer_concatenate(axis=-1L) # channel dimension 4 for channel last and 2 for channel first, keras default is channel last
}
# output width and height are half of the input, output channels = reduction * (input channels + growth rate * number of blocks)
dense_layer = function(x, block, growth_rate, num_blocks){
layers = x
for(i in 1:num_blocks) layers = layers %>% block(., growth_rate)
print(sprintf('dense_layer: growth_rate = %d, num_blocks = %d', growth_rate, num_blocks))
print(layers)
layers
}
transition = function(x, reduction = .5){
x %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_conv_2d_reparameterization(filters = floor(reduction*dim(x)[4]), kernel_size = c(1L,1L), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_average_pooling_2d(pool_size = 2L)
}
g.r. = 12L
predictions = inputs %>%
#layer_lambda(function(img) tf$image$resize(img, c(224L, 224L))) %>%
layer_conv_2d_reparameterization(filters = 2*g.r., kernel_size = 3L, padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 1
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 2
dense_layer(., block = conv_bottleneck, growth_rate = g.r., num_blocks = 16) %>% transition(.) %>% # layer 3
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_average_pooling_2d(pool_size = 4L) %>%
layer_flatten() %>%
layer_dense(activation = 'softmax', units = 10L)
densenet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
[1] "dense_layer: growth_rate = 12, num_blocks = 16" KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 216), dtype=tf.float32, name=None), name='concatenate_207/concat:0', description="created by layer 'concatenate_207'") [1] "dense_layer: growth_rate = 12, num_blocks = 16" KerasTensor(type_spec=TensorSpec(shape=(None, 16, 16, 300), dtype=tf.float32, name=None), name='concatenate_223/concat:0', description="created by layer 'concatenate_223'") [1] "dense_layer: growth_rate = 12, num_blocks = 16" KerasTensor(type_spec=TensorSpec(shape=(None, 8, 8, 342), dtype=tf.float32, name=None), name='concatenate_239/concat:0', description="created by layer 'concatenate_239'")
tf$keras$utils$plot_model(densenet.cifar10, to_file='tfp_DenseNetBC_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_DenseNetBC_CIFAR10.png')
ResNet from tfp.layers
¶
divergence_fn = function(q, p, ignore){
0.0 #tfp$distributions$kl_divergence(q, p) /50000
}
#my.posterior = tfp$layers$default_mean_field_normal_fn(untransformed_scale_initializer=tf$random_normal_initializer(mean=-9.0, stddev=0.05))
my.posterior = tfp$layers$default_mean_field_normal_fn()
myconv2d = function(x, filters, strides){
if(strides[1]>1)
x %>%
layer_zero_padding_2d(padding=list(list(1L, 0L),list(1L, 0L))) %>%
layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'valid', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
else
x %>%
layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
}
conv_block = function(x, filters, strides = 1L){
#print('x')
#print(x)
out = x %>%
layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
#myconv2d(filters = filters, strides = strides) %>% # conv1
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1, by default channel is last dimension
layer_activation_relu() %>%
layer_conv_2d_reparameterization(filters = filters, strides = 1L, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv2
layer_batch_normalization(momentum = .9, epsilon = 1e-5) # bn2
#print('out')
#print(out)
if(dim(x)[4]==filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
else shortcut = x %>%
layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = 1L, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c')), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
#print('shortcut')
#print(shortcut)
(out + shortcut) %>% layer_activation_relu()
}
conv_bottleneck = function(x, filters, strides = 1L){
expansion = 4L
#print('x')
#print(x)
out = x %>%
layer_conv_2d_reparameterization(filters = filters, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn1
layer_activation_relu() %>%
layer_conv_2d_reparameterization(filters = filters, strides = strides, kernel_size = c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv2
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% # bn2
layer_activation_relu() %>%
layer_conv_2d_reparameterization(filters = expansion * filters, kernel_size = c(1L,1L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>% # conv3
layer_batch_normalization(momentum = .9, epsilon = 1e-5) # bn3
#print('out')
#print(out)
if(dim(x)[4]== expansion * filters & strides==1L) shortcut = x # channel dimension 4 for channel last and 2 for channel first
else shortcut = x %>%
layer_conv_2d_reparameterization(filters = expansion * filters, strides = strides, kernel_size = 1L, name = sprintf('shortcut/conv2d_%d',keras$backend$get_uid(prefix='s/c')), bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5, name=sprintf('shortcut/bn_%d',keras$backend$get_uid(prefix='s/b')))
#print('shortcut')
#print(shortcut)
(out + shortcut) %>% layer_activation_relu()
}
conv_layer = function(x, block, filters, num_blocks, strides){
layers = x
print(sprintf('new conv_layer:filters = %d, num_blocks = %d, strides = %d', filters, num_blocks, strides))
for(i in 1:num_blocks) {
layers = layers %>% block(., filters, strides)
strides = 1 # only shrink image once at the first block, if strides > 1
}
layers
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
predictions = inputs %>%
layer_conv_2d_reparameterization(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu %>%
conv_layer(., block = conv_block, filters = 16, num_blocks = 18, strides = 1) %>% # layer 1
conv_layer(., block = conv_block, filters = 32, num_blocks = 18, strides = 2) %>% # layer 2
conv_layer(., block = conv_block, filters = 64, num_blocks = 18, strides = 2) %>% # layer 3
layer_average_pooling_2d(pool_size = 8L) %>%
layer_flatten() %>%
layer_dense_reparameterization(activation = 'softmax', units = 10L, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
resnet.cifar10 = keras_model(inputs = inputs, outputs = predictions)
[1] "new conv_layer:filters = 16, num_blocks = 18, strides = 1" [1] "new conv_layer:filters = 32, num_blocks = 18, strides = 2" [1] "new conv_layer:filters = 64, num_blocks = 18, strides = 2"
tf$keras$utils$plot_model(resnet.cifar10, to_file='tfp_ResNet_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_ResNet_CIFAR10.png')
EfficientNet B0 from tfp.layers
¶
divergence_fn = function(q, p, ignore){
0.0 #tfp$distributions$kl_divergence(q, p) /50000
}
my.posterior = tfp$layers$default_mean_field_normal_fn(loc_initializer = initializer_variance_scaling(scale=2.0, mode='fan_out'))
convBnAct = function(x, filters, kernel_size=3L, stride=1L, padding='valid', bn=TRUE, act='swish', use_bias=FALSE, groups=1L){
#if(stride==2L & kernel_size == 3L) x = x %>%
# layer_zero_padding_2d(padding=list(list(0L, 1L),list(0L, 1L))) %>%
# layer_conv_2d(filters=filters, kernel_size=kernel_size, stride=stride, padding='valid', use_bias=use_bias, groups = groups, kernel_initializer = c_k_i)
#else
x = layer_conv_2d_reparameterization(x, filters=filters, kernel_size=kernel_size, stride=stride, padding=padding, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
if(bn) x = layer_batch_normalization(x, momentum = .9, epsilon = 1e-5)
if(!is.null(act)) x %>% layer_activation(activation = act) else x
}
squeezeExcitation = function(x, filters){
y = x %>%
layer_global_average_pooling_2d(keepdims=TRUE) %>% # is this adaptive average pool 2d with output dim 1x2?
layer_conv_2d_reparameterization(filters=filters, kernel_size=1L, activation='swish', kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_conv_2d_reparameterization(filters=dim(x)[4], kernel_size=1L, activation='sigmoid', kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
x * y
}
stochasticDepth = function(x, survival_prob = .8){
tf$where(tf$random$uniform(list()) < survival_prob, x / survival_prob, 0)
}
#Mobile-net conv Block with expansion factor N
MBConvN = function(x, filters, kernel_size=3L, stride=1L, expansion_factor = 6L, reduction = 4L, survival_prob = .8){
filters_in = dim(x)[4]
residual = x
intermediate_channels = filters_in * expansion_factor
if(expansion_factor!=1L) x = convBnAct(x, filters = intermediate_channels, kernel_size=1L)
x = x %>%
convBnAct(filters = intermediate_channels, kernel_size = kernel_size, stride = stride, padding = 'same', groups = intermediate_channels) %>%
squeezeExcitation(filters = floor(filters_in/reduction)) %>%
convBnAct(filters = filters, kernel_size = 1L, act=NULL, use_bias=FALSE)
if(stride==1L & filters_in==filters){
residual + layer_dropout(x, rate = .2, noise_shape=c(py_none(), 1L, 1L, 1L))
} else x
}
featureExtractor = function(x, width_mult, depth_mult, last_channel){
kernels = c(3, 3, 5, 3, 5, 5, 3)
expansions = c(1, 6, 6, 6, 6, 6, 6)
scaled_num_channels = 4 * ceiling( c(16, 24, 40, 80, 112, 192, 320) * width_mult/4 )
scaled_num_layers = c(1, 2, 2, 3, 3, 4, 1) * depth_mult
strides = c(1, 2, 2, 2, 1, 2, 1)
x = x %>% convBnAct(filters = 4 * ceiling(32 * width_mult/4), kernel_size = 3L, stride = 2L, padding = 'same')
for(i in 1:length(scaled_num_layers)){
for(j in 1:scaled_num_layers[i]){
x = x %>% MBConvN(filters = scaled_num_channels[i], kernel_size = kernels[i], stride = if(j==1) strides[i] else 1L, expansion_factor = expansions[i])
}
}
x %>% convBnAct(filters = last_channel, kernel_size = 1L, stride = 1L, padding='valid')
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
predictions = inputs %>%
featureExtractor(width_mult = 1, depth_mult = 1, last_channel = 1280) %>%
layer_global_average_pooling_2d() %>%
layer_dropout(.2) %>%
layer_dense_reparameterization(units = 10L, activation='softmax', kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
efficientNetB0.cifar10 = keras_model(inputs = inputs, outputs = predictions)
#efficientNetB0.cifar10
tf$keras$utils$plot_model(efficientNetB0.cifar10, to_file='tfp_EfficientNetB0_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_EfficientNetB0_CIFAR10.png')
Wide Resnet 28-10 from tfp.layers
¶
divergence_fn = function(q, p, ignore){
0.0 #tfp$distributions$kl_divergence(q, p) /50000
}
my.posterior = tfp$layers$default_mean_field_normal_fn(loc_initializer='he_normal')
conv_block = function(x, base, k, dropout = 0, strides = c(1L, 1L), shortcut = FALSE){
h = x %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu()
out = h %>%
layer_conv_2d_reparameterization(base*k, c(3L,3L), padding = 'same', strides=strides, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>%
layer_activation_relu() %>%
layer_dropout(rate = dropout) %>% # conv1
layer_conv_2d_reparameterization(base*k, c(3L,3L), padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
if(shortcut) skip = h %>%
layer_conv_2d_reparameterization(base*k, c(1L, 1L), strides = strides, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
else skip = x
layer_add(list(out, skip))
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
k = 10
predictions = inputs %>%
layer_conv_2d_reparameterization(filters = 16, strides = 1L, kernel_size = 3L, padding = 'same', bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior) %>%
conv_block(., 16, k, shortcut=TRUE) %>% conv_block(., 16, k, dropout=.4) %>% conv_block(., 16, k, dropout=.3) %>%
conv_block(., 32, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 32, k, dropout=.4) %>% conv_block(., 32, k, dropout=.4) %>%
conv_block(., 64, k, strides = c(2L, 2L), shortcut=TRUE) %>% conv_block(., 64, k, dropout=.4) %>% conv_block(., 64, k, dropout=.4) %>%
layer_batch_normalization(momentum = .9, epsilon = 1e-5) %>% layer_activation_relu() %>% layer_average_pooling_2d(pool_size = 8L) %>% layer_flatten() %>%
layer_dense_reparameterization(activation = 'softmax', units = 10L, bias_posterior_fn = NULL, kernel_divergence_fn = divergence_fn, kernel_posterior_fn = my.posterior)
wrn_28_10.cifar10 = keras_model(inputs = inputs, outputs = predictions)
tf$keras$utils$plot_model(wrn_28_10.cifar10, to_file='tfp_WRN-28-10_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_WRN-28-10_CIFAR10.png')
MLP-Mixer¶
layer_patch = Layer(
classname = 'Patch',
inherit = tf$keras$layers$Layer,
initialize = function(patch_size, projection_dim, ...) {
super$initialize()
self$patch_size = patch_size
self$projection_dim = projection_dim
},
build = function(input_shape) {
# assert: image is square, patch is square
self$num_patches = as.integer(floor(input_shape[[2]]/self$patch_size)^2)
self$projection = layer_dense_reparameterization(units = self$projection_dim)
self$position_embedding = tf$keras$layers$Embedding(input_dim = self$num_patches, output_dim = self$projection_dim)
self$built=TRUE
return()
},
call = function(images) {
patches = tf$image$extract_patches(
images = images,
sizes = c(1L, self$patch_size, self$patch_size, 1L),
strides = c(1L, self$patch_size, self$patch_size, 1L),
rates = c(1L, 1L, 1L, 1L),
padding = 'VALID'
)
patches = tf$reshape(
patches,
shape=c(tf$shape(images)[1L], -1L, tf$reduce_prod(tf$shape(patches)[4L]))
)
positions = tf$range(start=0, limit = self$num_patches, delta = 1)
self$projection(patches) + self$position_embedding(positions)
}
)
mlp_mixer_block = function(x){
num_patches = dim(x)[2]
projection_dim = dim(x)[3]
mlp1 = x %>%
layer_layer_normalization(epsilon = 1e-6) %>%
tf$linalg$matrix_transpose(.) %>%
layer_dense_reparameterization(units = num_patches, activation='gelu') %>%
layer_dense_reparameterization(units = num_patches) %>%
layer_dropout(rate = .2) %>%
tf$linalg$matrix_transpose(.)
x = layer_add(list(x, mlp1))
mlp2 = x %>%
layer_layer_normalization(epsilon = 1e-6) %>%
layer_dense_reparameterization(units = projection_dim, activation='gelu') %>%
layer_dense_reparameterization(units = projection_dim) #%>%
#layer_dropout(rate = .2)
layer_add(list(x, mlp2))
}
inputs = layer_input(shape = dim(cifar10$train$x)[-1])
predictions = inputs %>%
layer_resizing(64L, 64L) %>%
layer_patch(patch_size = 8L, projection_dim = 256L) %>%
mlp_mixer_block %>%
mlp_mixer_block %>%
mlp_mixer_block %>%
mlp_mixer_block %>%
layer_global_average_pooling_1d %>%
layer_dropout(rate = .2) %>%
layer_dense_reparameterization(units=10L, activation = 'softmax')
mlp_mixer.cifar10 = keras_model(inputs = inputs, outputs = predictions)
tf$keras$utils$plot_model(mlp_mixer.cifar10, to_file='tfp_MLPMixer_CIFAR10.png', show_shapes=TRUE, show_layer_names=TRUE, show_layer_activations=TRUE)
IRdisplay::display_png(file = 'tfp_MLPMixer_CIFAR10.png')
Datasets involved¶
CIFAR 10¶
cifar10 = dataset_cifar10()
cifar10.datagen = image_data_generator(
featurewise_center=TRUE,
featurewise_std_normalization=TRUE,
##rotation_range=15,
zoom_range=0.1,
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=TRUE,
)
cifar10.datagen$fit(cifar10$train$x)
# stream the train/test sets with `tf.keras.preprocessing.image.ImageDataGenerator.flow` in the following way
#cifar10.datagen$flow(cifar10$train$x, to_categorical(cifar10$train$y, num_classes=10L), batch_size = 64L, shuffle=TRUE)
#cifar10.datagen$flow(cifar10$test$x, to_categorical(cifar10$test$y, num_classes=10L), batch_size = 1000L)
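For example, one of the tfp.layers models above can be trained on this augmented stream as follows; a hedged sketch in which the choice of model, batch sizes, and number of epochs are placeholders (note it launches a full training run):
# Hedged sketch: pass the augmented generators straight to compile()/fit().
resnet.cifar10 %>% compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = 'accuracy')
resnet.cifar10 %>% fit(
cifar10.datagen$flow(cifar10$train$x, to_categorical(cifar10$train$y, num_classes = 10L), batch_size = 64L, shuffle = TRUE),
validation_data = cifar10.datagen$flow(cifar10$test$x, to_categorical(cifar10$test$y, num_classes = 10L), batch_size = 1000L),
epochs = 10L)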
CIFAR 100¶
cifar100 = dataset_cifar100()
cifar100.datagen = image_data_generator(
featurewise_center=TRUE,
featurewise_std_normalization=TRUE,
##rotation_range=15,
zoom_range=0.1,
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=TRUE,
)
cifar100.datagen$fit(cifar100$train$x)
# stream the train/test sets with `tf.keras.preprocessing.image.ImageDataGenerator.flow` in the following way
#cifar100.datagen$flow(cifar100$train$x, to_categorical(cifar100$train$y, num_classes=100L), batch_size = 64L, shuffle=TRUE)
#cifar100.datagen$flow(cifar100$test$x, to_categorical(cifar100$test$y, num_classes=100L), batch_size = 1000L)
ImageNet resized/64x64¶
httr::GET('https://image-net.org/login.php', authenticate("<username>", "<password>"))
tfds = reticulate::import('tensorflow_datasets')
c(imagenet64x64, info) %<-% tfds$load('imagenet_resized/64x64', shuffle_files=FALSE, with_info=TRUE)
info
tfds.core.DatasetInfo( name='imagenet_resized', full_name='imagenet_resized/64x64/0.1.0', description=""" This dataset consists of the ImageNet dataset resized to fixed size. The images here are the ones provided by Chrabaszcz et. al. using the box resize method. For [downsampled ImageNet](http://image-net.org/download.php) for unsupervised learning see `downsampled_imagenet`. WARNING: The integer labels used are defined by the authors and do not match those from the other ImageNet datasets provided by Tensorflow datasets. See the original [label list](https://github.com/PatrykChrabaszcz/Imagenet32_Scripts/blob/master/map_clsloc.txt), and the [labels used by this dataset](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/imagenet_resized_labels.txt). Additionally, the original authors 1 index there labels which we convert to 0 indexed by subtracting one. """, config_description=""" Images resized to 64x64 """, homepage='https://patrykchrabaszcz.github.io/Imagenet32/', data_path='/root/tensorflow_datasets/imagenet_resized/64x64/0.1.0', file_format=tfrecord, download_size=13.13 GiB, dataset_size=10.29 GiB, features=FeaturesDict({ 'image': Image(shape=(64, 64, 3), dtype=uint8), 'label': ClassLabel(shape=(), dtype=int64, num_classes=1000), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'train': <SplitInfo num_examples=1281167, num_shards=128>, 'validation': <SplitInfo num_examples=50000, num_shards=4>, }, citation="""@article{chrabaszcz2017downsampled, title={A downsampled variant of imagenet as an alternative to the cifar datasets}, author={Chrabaszcz, Patryk and Loshchilov, Ilya and Hutter, Frank}, journal={arXiv preprint arXiv:1707.08819}, year={2017} }""", )
fig = tfds$show_examples(imagenet64x64$train, info)
fig$savefig('imagenet64x64.png')
IRdisplay::display_png(file = 'imagenet64x64.png')
# use keras preprocessing layer for normalization and basic augmentation
#inputs = layer_input(imagenet64x64$train$element_spec$image$shape)
normalization = layer_normalization()
normalization$adapt(imagenet64x64$validation$map(function(xy){
c(x,y) %<-% xy
x
})$batch(1000L))
normalization$mean
normalization$variance^.5
tf.Tensor([[[[120.70578 114.888855 102.38361 ]]]], shape=(1, 1, 1, 3), dtype=float32)
tf.Tensor([[[[67.8931 66.11974 69.63358]]]], shape=(1, 1, 1, 3), dtype=float32)
Tiny ImageNet¶
https://github.com/ksachdeva/tiny-imagenet-tfds/blob/master/tiny_imagenet/_imagenet.py
~/.local/share/r-miniconda/bin/conda init
. ~/.bashrc
conda info --envs
conda activate r-reticulate
pip install git+https://github.com/ksachdeva/tiny-imagenet-tfds.git
python -c "import site; print(''.join(site.getsitepackages()))"
- need to comment out the following two fields in the installed _imagenet.py:
vi /root/.local/share/r-miniconda/envs/r-reticulate/lib/python3.9/site-packages/tiny_imagenet/_imagenet.py
#urls=["https://tiny-imagenet.herokuapp.com/"],
#num_shards=1,
import os
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tiny_imagenet import TinyImagenetDataset
# optional
tf.compat.v1.enable_eager_execution()
tiny_imagenet_builder = TinyImagenetDataset()
# this call (download_and_prepare) will trigger the download of the dataset
# and preparation (conversion to tfrecords)
#
# This will be done only once and on next usage tfds will
# use the cached version on your host.
#
# You can pass optional argument to this method called
# DownloadConfig (https://www.tensorflow.org/datasets/api_docs/python/tfds/download/DownloadConfig)
# to customize the location where the dataset is downloaded, extracted and processed.
tiny_imagenet_builder.download_and_prepare()
train_dataset = tiny_imagenet_builder.as_dataset(split="train")
validation_dataset = tiny_imagenet_builder.as_dataset(split="validation")
assert(isinstance(train_dataset, tf.data.Dataset))
assert(isinstance(validation_dataset, tf.data.Dataset))
for a_train_example in train_dataset.take(5):
image, label, id = a_train_example["image"], a_train_example["label"], a_train_example["id"]
print(f"Image Shape - {image.shape}")
print(f"Label - {label.numpy()}")
print(f"Id - {id.numpy()}")
# print info about the data
print(tiny_imagenet_builder.info)
tiny_imagenet = reticulate::import('tiny_imagenet')
tiny_imagenet_builder = tiny_imagenet$TinyImagenetDataset()
tiny_imagenet_builder$info
tiny_imagenet_builder$download_and_prepare()
train_dataset = tiny_imagenet_builder$as_dataset(split="train")
validation_dataset = tiny_imagenet_builder$as_dataset(split="validation")
tfds.core.DatasetInfo( name='tiny_imagenet_dataset', full_name='tiny_imagenet_dataset/0.1.0', description=""" Tiny ImageNet Challenge is a similar challenge as ImageNet with a smaller dataset but less image classes. It contains 200 image classes, a training dataset of 100, 000 images, a validation dataset of 10, 000 images, and a test dataset of 10, 000 images. All images are of size 64×64. """, homepage='https://www.tensorflow.org/datasets/catalog/tiny_imagenet_dataset', data_path='/root/tensorflow_datasets/tiny_imagenet_dataset/0.1.0', file_format=tfrecord, download_size=236.61 MiB, dataset_size=210.18 MiB, features=FeaturesDict({ 'id': Text(shape=(), dtype=string), 'image': Image(shape=(64, 64, 3), dtype=uint8), 'label': ClassLabel(shape=(), dtype=int64, num_classes=200), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'train': <SplitInfo num_examples=100000, num_shards=2>, 'validation': <SplitInfo num_examples=10000, num_shards=1>, }, citation="""@article{tiny-imagenet, author = {Li,Fei-Fei}, {Karpathy,Andrej} and {Johnson,Justin}"}""", )
tfds = reticulate::import('tensorflow_datasets')
fig = tfds$show_examples(train_dataset, tiny_imagenet_builder$info)
fig$savefig('tiny_imagenet.png')
IRdisplay::display_png(file = 'tiny_imagenet.png')
str(train_dataset$take(1L)$get_single_element())
List of 3 $ id :<tf.Tensor: shape=(), dtype=string, numpy=b'n02279972'> $ image:<tf.Tensor: shape=(64, 64, 3), dtype=uint8, numpy=…> $ label:<tf.Tensor: shape=(), dtype=int64, numpy=13>
Characterizing BNN loss and noise¶
require(tensorflow)
require(keras)
require(reticulate)
require(tfprobability)
require(coro)
Sys.setenv("TF_XLA_FLAGS" = "--tf_xla_enable_xla_devices")
mnist <- dataset_mnist()
mnist_train_ds = tf$data$Dataset$from_tensor_slices(mnist$train)$shuffle(2048L)$batch(256L)$map(function(xy){
c(x,y) %<-% xy
x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
tuple(x,y.onehot)
})
mnist_test_ds = tf$data$Dataset$from_tensor_slices(mnist$test)$batch(1024L)$map(function(xy){
c(x,y) %<-% xy
x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
tuple(x,y.onehot)
})
Implementation of tfp.layers.DenseReparameterization¶
According to documentation, it implements the following reparameterization estimator.
- kernel, bias ~ posterior
- outputs = activation(matmul(inputs, kernel) + bias)
"You can access the kernel and/or bias posterior and prior distributions after the layer is built via the kernel_posterior, kernel_prior, bias_posterior and bias_prior properties."
"Upon being built, this layer adds losses (accessible via the losses property) representing the divergences of kernel and/or bias surrogate posteriors and their respective priors. When doing minibatch stochastic optimization, make sure to scale this loss such that it is applied just once per epoch."
tfp.layers.DenseReparameterization is a subclass of tfp.layers._DenseVariational; the forward propagation is defined in the call method, which in turn calls tfp.layers.DenseReparameterization._apply_variational_kernel, tfp.layers._DenseVariational._apply_variational_bias, and tfp.layers._DenseVariational._apply_divergence with the provided kernel_divergence_fn/kernel_posterior/kernel_prior and bias_divergence_fn/bias_posterior/bias_prior.
def _apply_variational_kernel(self, inputs):
self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
self.kernel_posterior)
self.kernel_posterior_affine = None
self.kernel_posterior_affine_tensor = None
return tf.matmul(inputs, self.kernel_posterior_tensor)
_apply_variational_kernel calls tfp.layers._DenseVariational.kernel_posterior_tensor_fn, which by default calls sample() on the tfp.layers._DenseVariational.kernel_posterior attribute, which in turn is constructed in tfp.layers._DenseVariational.build by tfp.layers.DenseReparameterization.kernel_posterior_fn, which by default is tfp.layers.default_mean_field_normal_fn, which in turn calls tfp.layers.default_loc_scale_fn to create Variables representing mean and scale. The scale is a DeferredTensor.
scale = tfp_util.DeferredTensor(
untransformed_scale,
lambda x: (np.finfo(dtype.as_numpy_dtype).eps + tf.nn.softplus(x)))
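For instance, one can build a DenseReparameterization layer directly and inspect the posterior it creates; a hedged sketch, where the 8-to-4 shape is arbitrary and the assumed variable order is kernel loc, kernel untransformed scale, bias loc:
# Hedged sketch: inspect the mean-field posterior of a freshly built layer.
d = tfp$layers$DenseReparameterization(units = 4L)
invisible(d$build(list(NULL, 8L)))
d$kernel_posterior$distribution$loc # the loc Variable
tf$nn$softplus(d$variables[[2]]) # scale (up to the eps term) from the untransformed-scale Variable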
Implementation of tfp.layers.DenseLocalReparameterization¶
def _apply_variational_kernel(self, inputs):
# ... omitted for brevity
self.kernel_posterior_affine = normal_lib.Normal(
loc=tf.matmul(inputs, self.kernel_posterior.distribution.loc),
scale=tf.sqrt(tf.matmul(
tf.square(inputs),
tf.square(self.kernel_posterior.distribution.scale))))
self.kernel_posterior_affine_tensor = (
self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
self.kernel_posterior_tensor = None
return self.kernel_posterior_affine_tensor
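In other words, instead of sampling the kernel, the layer samples the pre-activations directly from a Gaussian whose moments are computed from the kernel posterior; a minimal sketch of that computation, where x, loc and scale are placeholder tensors:
# Sketch of the local reparameterization trick quoted above.
local_reparam_preactivations = function(x, loc, scale){
m = tf$matmul(x, loc) # mean of the pre-activations
s = tf$sqrt(tf$matmul(tf$square(x), tf$square(scale))) # their standard deviation
tfp$distributions$Normal(loc = m, scale = s)$sample()
}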
Implementation of tfp.layers.DenseFlipout¶
def _apply_variational_kernel(self, inputs):
# ... omitted for brevity
self.kernel_posterior_affine = normal_lib.Normal(
loc=tf.zeros_like(self.kernel_posterior.distribution.loc),
scale=self.kernel_posterior.distribution.scale)
self.kernel_posterior_affine_tensor = (
self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
self.kernel_posterior_tensor = None
input_shape = tf.shape(inputs)
batch_shape = input_shape[:-1]
seed_stream = SeedStream(self.seed, salt='DenseFlipout')
sign_input = tfp_random.rademacher(
input_shape,
dtype=inputs.dtype,
seed=seed_stream())
sign_output = tfp_random.rademacher(
tf.concat([batch_shape,
tf.expand_dims(self.units, 0)], 0),
dtype=inputs.dtype,
seed=seed_stream())
perturbed_inputs = tf.matmul(
inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output
outputs = tf.matmul(inputs, self.kernel_posterior.distribution.loc)
outputs += perturbed_inputs
return outputs
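Flipout keeps one shared zero-mean kernel perturbation per step but decorrelates it across examples with Rademacher sign flips on the inputs and outputs; a minimal sketch, where x, loc and scale are placeholder tensors:
# Sketch of the Flipout estimator quoted above.
flipout_outputs = function(x, loc, scale){
m = tf$matmul(x, loc) # deterministic part uses the posterior mean
dW = tfp$distributions$Normal(loc = tf$zeros_like(loc), scale = scale)$sample()
sign_in = tfp$random$rademacher(tf$shape(x), dtype = x$dtype)
sign_out = tfp$random$rademacher(tf$shape(m), dtype = x$dtype)
m + tf$matmul(x * sign_in, dW) * sign_out # pseudo-independent perturbation per example
}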
divergence_fn = function(q, p, ignore){
tfp$distributions$kl_divergence(q, p) / 60000
}
inputs = layer_input(shape = 784L, name='input')
predictionsR = inputs %>%
layer_dense_reparameterization(units = 40L, activation = tf$nn$sigmoid, kernel_divergence_fn = divergence_fn, bias_posterior_fn = NULL) %>%
layer_dense_reparameterization(units = 40L, activation = tf$nn$sigmoid, kernel_divergence_fn = divergence_fn, bias_posterior_fn = NULL) %>%
layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = divergence_fn, bias_posterior_fn = NULL)
modelR = keras_model(inputs = inputs, outputs = predictionsR)
modelR %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelR$save_weights('model0.ckpt')
modelR$load_weights('model0.ckpt')
modelR %>% fit(mnist_train_ds, epochs=100, validation_data=mnist_test_ds, verbose=0L)
modelR %>% evaluate(mnist_test_ds)
modelR$save_weights('modelR.ckpt')
modelR$load_weights('modelR.ckpt')
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f128c19ef70>
- loss
- 0.402956306934357
- accuracy
- 0.947799980640411
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f128a940d30>
c(x,y) %<-% coro::collect(mnist_test_ds, n = 1)[[1]]
gradR = lapply(1:16, function(n){
with(tf$GradientTape(persistent=TRUE) %as% tape, {
y_pred = modelR$call(x)
loss = modelR$compiled_loss(y_true = y, y_pred = y_pred, regularization_losses = modelR$losses)
})
grad = tape$gradient(loss, modelR$weights)
})
b = lapply(1:length(gradR[[1]]), function(m){
tf$reduce_mean(tf$stack(lapply(gradR, function(n) n[[m]]), axis=0L)^2, axis=0L) - tf$reduce_mean(tf$stack(lapply(gradR, function(n) n[[m]]), axis=0L), axis=0L)^2
})
for(m in 0:(length(b)-1)){
if(m %%2 ==0){
print(sprintf('mean %d', m/2))
print(summary(c(b[[m+1]]$numpy())))
}
else {
print(sprintf('scale %d', (m-1)/2))
print(summary(c(b[[m+1]]$numpy())))
}
}
b = lapply(1:length(gradR[[1]]), function(m){
tf$reduce_mean(tf$stack(lapply(gradR, function(n) n[[m]]), axis=0L), axis=0L)
})
for(m in 0:(length(b)-1)){
if(m %%2 ==0){
print(sprintf('mean %d', m/2))
print(summary(c(b[[m+1]]$numpy())))
}
else {
print(sprintf('scale %d', (m-1)/2))
print(summary(c(b[[m+1]]$numpy()))) #tf$math$softplus(
}
}
my_divergence_fn = function(q, p, ignore=NULL){
# the regularization loss would otherwise dominate the total loss (roughly 15 vs. approx. 0), so I scale it by 1e-3
tfp$distributions$kl_divergence(q, p) / 60000 /1e3
}
inputs = layer_input(shape = 784L, name='input')
predictionsR = inputs %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)
modelR = keras_model(inputs = inputs, outputs = predictionsR)
modelR %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelR$save_weights('model0.ckpt')
Below is how to set the Bayesian NN parameter initializers to match vanilla NN parameter initializers. This doesn't seem necessary for MNIST.
my_kernel_posterior_fn = function() tfp$layers$util$default_mean_field_normal_fn(loc_initializer = 'glorot_uniform')
my_bias_posterior_fn = function() tfp$layers$util$default_mean_field_normal_fn(loc_initializer = 'zeros',is_singular=TRUE)
predictionsR = inputs %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn, kernel_posterior_fn = my_kernel_posterior_fn(), bias_posterior_fn = my_bias_posterior_fn()) %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn, kernel_posterior_fn = my_kernel_posterior_fn(), bias_posterior_fn = my_bias_posterior_fn()) %>%
layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn, kernel_posterior_fn = my_kernel_posterior_fn(), bias_posterior_fn = my_bias_posterior_fn())
modelR$load_weights('model0.ckpt')
# there are 60k training examples, and I set the batch size to 300; so 2000 steps is 10 epochs: 60000/300 * 10 = 2000
historyR = modelR %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'R.txt')))
modelR %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6ed06b88e0>
- loss
- 0.113996714353561
- accuracy
- 0.974399983882904
predictionsL = inputs %>%
layer_dense_local_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_local_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_local_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)
modelL = keras_model(inputs = inputs, outputs = predictionsL)
modelL %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelL$load_weights('model0.ckpt')
historyL = modelL %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'L.txt')))
modelL %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6ec86c1060>
- loss
- 0.111960597336292
- accuracy
- 0.978500008583069
predictionsF = inputs %>%
layer_dense_flipout(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_flipout(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_flipout(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)
modelF = keras_model(inputs = inputs, outputs = predictionsF)
modelF %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelF$load_weights('model0.ckpt')
historyF = modelF %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'F.txt')))
modelF %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6ec858e410>
- loss
- 0.116981789469719
- accuracy
- 0.972800016403198
predictions0 = inputs %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)
model0 = keras_model(inputs = inputs, outputs = predictions0)
model0$load_weights('model0.ckpt')
predictions = inputs %>%
layer_dense(units = 512L, activation = 'relu') %>%
layer_dense(units = 512L, activation = 'relu') %>%
layer_dense(units = 10L, activation = 'softmax')
model = keras_model(inputs = inputs, outputs = predictions)
model %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
invisible({
model$weights[[1]]$assign(model0$weights[[1]]) # layer 1 - kernel
model$weights[[2]]$assign(model0$weights[[3]]) # layer 1 - bias
model$weights[[3]]$assign(model0$weights[[4]]) # layer 2 - kernel
model$weights[[4]]$assign(model0$weights[[6]]) # layer 2 - bias
model$weights[[5]]$assign(model0$weights[[7]]) # layer 3 - kernel
model$weights[[6]]$assign(model0$weights[[9]]) # layer 3 - bias
})
rm(predictions0, model0)
history = model %>% fit(mnist_train_ds, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'V.txt')))
model %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6e44783670>
- loss
- 0.0837457105517387
- accuracy
- 0.979200005531311
# E_{N(m0, diag Q0)}[ log N(.; m1, diag Q1) ], i.e. the negative cross entropy between diagonal Gaussians
CE.mvnorm.diag = function(m0, Q0, m1, Q1){
pinv.Q1 = tf$math$divide_no_nan(tf$constant(1.0, dtype=tf$float32), Q1)
(tf$reduce_sum( pinv.Q1 * Q0 + (m1-m0)^2 * pinv.Q1 + tf$math$log(2 * pi * Q1), axis=c(1L))) * (-.5)
}
#my_divergence_fn = function(q, p, ignore=NULL){
# tfp$distributions$kl_divergence(q, p) / 60000 / 1e3
#}
layer_dense_Bayesian = Layer(
classname = 'BayesianDense',
inherit = tf$keras$layers$Layer,
initialize = function(activation=NULL, use_bias = TRUE, divergence_fn = my_divergence_fn, ...) {
super$initialize()
self$divergence_fn = divergence_fn
self$activation = keras$activations$get(activation)
self$use_bias = use_bias
self$dense = tf$keras$layers$Dense(activation=NULL, use_bias=FALSE, ...)
},
build = function(input_shape) {
self$dense$build(input_shape[[1]]) ## input_shape is a list, corresponding to (x, P) tuple
# redefine BNN weights as Gaussian random variables in terms of trainable tf variables
self$kernel_loc = self$add_weight(name='A',shape=self$dense$kernel$shape, initializer=self$dense$kernel_initializer, trainable=TRUE)
self$kernel_scale = self$add_weight(name='B',shape=self$dense$kernel$shape, initializer=tf$random_normal_initializer(mean = -3, stddev = .1), trainable=TRUE)#'uniform'
if(self$use_bias) self$bias = self$add_weight(name='bias', shape = list(self$dense$units), initializer='zeros') else self$bias = NULL
self$built=TRUE
},
call = function(inputs) {
c(x, P0) %<-% inputs
with(tf$GradientTape(persistent=TRUE) %as% g, {
g$watch(list(x)) #g$watch(list(x, self$W))
self$dense$kernel = self$kernel_loc #+ tf$einsum('ij,i->ij',self$B, self$W) # mvnorm with mean and variance
a = self$dense$call(x) # tf$matmul(a=x, b=self$kernel) + self$bias
if(self$use_bias) a = tf$nn$bias_add(a, self$bias)
h = self$activation(a)
`__h__` = self$activation(a)
})
# propagate variance (accept a variance, first order approximation of state transition, output a variance)
#dhdx.sq.sum_x <- tf$einsum("bi,ji,bj->bi", tf$stop_gradient(g$gradient(`__h__`, a)^2), tf$stop_gradient(self$A^2), P0)
self$dense$kernel = tf$stop_gradient(self$kernel_loc^2)
dhdx.sq.sum_x = tf$stop_gradient(g$gradient(`__h__`, a)^2) * self$dense$call(P0)
#dhdw.sq.sum_w <- tf$einsum('bi,ji, bj->bi',g$gradient(h, a)^2, self$B^2, tf$stop_gradient(x^2))
self$dense$kernel = tf$nn$softplus(self$kernel_scale+1e-9)^2
dhdw.sq.sum_w = g$gradient(h, a)^2 * self$dense$call(tf$stop_gradient(x^2))
P = dhdx.sq.sum_x + dhdw.sq.sum_w + 1e-9
## add regularization loss
q_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=self$kernel_loc, scale = tf$nn$softplus(self$kernel_scale+1e-9)), reinterpreted_batch_ndims = 2L)
p_ = tfp$distributions$Independent(tfp$distributions$Normal(loc=tf$zeros_like(self$kernel_loc), scale = 1.0), reinterpreted_batch_ndims = 2L)
self$add_loss(self$divergence_fn(q_, p_))
list(h, P)
}
)
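The call method above propagates variances with a first-order (delta method) approximation: for a = x*W with W ~ N(loc, softplus(scale)^2), the output variance is approximately f'(a)^2 * (loc^2 * P_in + softplus(scale)^2 * x^2), which is what the dhdx.sq.sum_x and dhdw.sq.sum_w terms compute. A toy scalar Monte Carlo check of this approximation (all numbers below are arbitrary):
# Toy sanity check of the first-order variance propagation used above (scalar case).
x0 = 1.5; P0 = 0.04 # input mean and variance
w_loc = 0.8; w_sd = log(1 + exp(-2)) # weight mean and sd = softplus(-2)
f = function(a) 1/(1 + exp(-a)) # sigmoid activation
df = function(a) f(a)*(1 - f(a))
P_analytic = df(x0*w_loc)^2 * (w_loc^2*P0 + w_sd^2*x0^2)
set.seed(1)
P_mc = var(f(rnorm(1e5, x0, sqrt(P0)) * rnorm(1e5, w_loc, w_sd)))
c(P_analytic, P_mc) # should be close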
mnist_train_dsEP = tf$data$Dataset$from_tensor_slices(mnist$train)$`repeat`(-1L)$shuffle(2048L)$batch(300L)$map(function(xy){
c(x,y) %<-% xy
x = tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
P = tf$zeros_like(x) + 1e-6
y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
tuple(tuple(x,P),y.onehot)
})
mnist_test_dsEP = tf$data$Dataset$from_tensor_slices(mnist$test)$batch(1024L)$map(function(xy){
c(x,y) %<-% xy
x <- tf$reshape(tf$cast(x/255.0, dtype=tf$float32), shape=tuple(-1L,784L))
P = tf$zeros_like(x) + 1e-6
y.onehot = tf$one_hot(y, depth = 10L, dtype=tf$float32)
tuple(tuple(x,P),y.onehot)
})
mnist_train_dsEP$element_spec
[[1]] [[1]][[1]] TensorSpec(shape=(None, 784), dtype=tf.float32, name=None) [[1]][[2]] TensorSpec(shape=(None, 784), dtype=tf.float32, name=None) [[2]] TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
inputsEP = list(layer_input(shape = 784L), layer_input(shape=tuple(784L)))
predictionsEP = inputsEP %>%
layer_dense_Bayesian(units=512, activation = 'relu') %>%
layer_dense_Bayesian(units=512, activation = 'relu') %>%
layer_dense_Bayesian(units=10, activation = tf$identity) %>%
layer_lambda(f = function(mv){
tf$nn$softmax(tf$reduce_mean(tfp$distributions$MultivariateNormalDiag(loc = mv[[1]], scale_diag=mv[[2]]^.5)$sample(20L), axis=0L))
})
modelEP = keras_model(inputs = inputsEP, outputs = predictionsEP)
modelEP %>% compile(loss='categorical_crossentropy', optimizer='adam', metrics='accuracy')
modelEP$load_weights('model0.ckpt')
historyEP = modelEP %>% fit(mnist_train_dsEP, steps_per_epoch = 1, epochs=2000, validation_data=mnist_test_dsEP, verbose=0L, callbacks = list(callback_csv_logger(filename = 'MRF.txt')))
modelEP %>% evaluate(mnist_test_dsEP)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6e44515a80>
- loss
- 0.109516456723213
- accuracy
- 0.979200005531311
Vadam(keras$optimizers$experimental$Optimizer) %py_class% {
initialize <- function(train_set_size, learning_rate=1e-3, beta_1=.9, beta_2 = .999, prior_prec=1.0, prec_init=0.0, num_samples=1, name='Vadam') {
super$initialize(name=name)
self$`_learning_rate` = self$`_build_learning_rate`(learning_rate)
self$beta_1 = beta_1
self$beta_2 = beta_2
self$gamma = prior_prec
self$gamma_0 = prec_init
self$N = train_set_size
self$num_samples = num_samples
}
build = function(var_list){
super$build(var_list)
if(py_has_attr(self, "_built") && self$`_built`) return()
self$m = lapply(var_list, function(var) self$add_variable_from_reference(model_variable=var, variable_name="m"))
self$mu = lapply(var_list, function(var) self$add_variable_from_reference(model_variable=var, variable_name="mu", initial_value=var))
self$s = lapply(var_list, function(var) self$add_variable_from_reference(model_variable=var, variable_name="s"))
self$`_built` = TRUE
}
update_step = function(gradient, variable){
lr = tf$cast(self$learning_rate, variable$dtype)
var_key = self$`_var_key`(variable)
s = self$s[[self$`_index_dict`[var_key]]] #+1L? python index or R index?
m = self$m[[self$`_index_dict`[var_key]]]
mu = self$mu[[self$`_index_dict`[var_key]]]
#v.assign_add((tf.square(gradient) - v) * (1 - self.beta_2))
s$assign_add((gradient^2-s)*(1 - self$beta_2))
#m.assign_add((gradient - m) * (1 - self.beta_1))
#m$assign_add((1-self$beta_1)*(gradient-m)) # Adam
m$assign_add((1-self$beta_1)*(gradient-m+self$gamma/self$N*mu)) # Vadam
#local_step = tf.cast(self.iterations + 1, variable.dtype)
#beta_1_power = tf.pow(tf.cast(self.beta_1, variable.dtype), local_step)
#beta_2_power = tf.pow(tf.cast(self.beta_2, variable.dtype), local_step)
#alpha = lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)
#variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
local_step = tf$cast(self$iterations + 1, variable$dtype)
beta_1_power = tf$pow(tf$cast(self$beta_1, variable$dtype), local_step)
beta_2_power = tf$pow(tf$cast(self$beta_2, variable$dtype), local_step)
alpha = lr * tf$sqrt(1 - beta_2_power) / (1 - beta_1_power)
#mu$assign_sub(tf$math$divide_no_nan(m * alpha, tf$sqrt(s))) # Adam
mu$assign_sub(tf$math$divide_no_nan(m * alpha, tf$sqrt(s)+self$gamma/self$N)) # Vadam
#variable$assign(mu) # Adam
variable$assign(tfp$distributions$Normal(loc=mu, scale=tf$math$divide_no_nan(1.0, (self$N*s+self$gamma)^.5))$sample()) #Vadam
}
get_config = function(){
super$get_config()
}
}
predictionsVadam = inputs %>%
layer_dense(units = 512L, activation = 'relu') %>%
layer_dense(units = 512L, activation = 'relu') %>%
layer_dense(units = 10L, activation = 'softmax')
modelVadam = keras_model(inputs = inputs, outputs = predictionsVadam)
predictions0 = inputs %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_reparameterization(units = 512L, activation = 'relu', kernel_divergence_fn = my_divergence_fn) %>%
layer_dense_reparameterization(units = 10L, activation = 'softmax', kernel_divergence_fn = my_divergence_fn)
model0 = keras_model(inputs = inputs, outputs = predictions0)
model0$load_weights('model0.ckpt')
invisible({
modelVadam$weights[[1]]$assign(model0$weights[[1]]) # layer 1 - kernel
modelVadam$weights[[2]]$assign(model0$weights[[3]]) # layer 1 - bias
modelVadam$weights[[3]]$assign(model0$weights[[4]]) # layer 2 - kernel
modelVadam$weights[[4]]$assign(model0$weights[[6]]) # layer 2 - bias
modelVadam$weights[[5]]$assign(model0$weights[[7]]) # layer 3 - kernel
modelVadam$weights[[6]]$assign(model0$weights[[9]]) # layer 3 - bias
})
rm(predictions0, model0)
modelVadam %>% compile(loss='categorical_crossentropy', optimizer=Vadam(train_set_size=60000 * 1e4, learning_rate=0.001, beta_1=.9, beta_2 = .999 ), metrics='accuracy')
historyVadam = modelVadam %>% fit(mnist_train_ds$unbatch()$batch(60L), steps_per_epoch = 5, epochs=2000, validation_data=mnist_test_ds, verbose=0L, callbacks = list(callback_csv_logger(filename = 'Vadam.txt')))
modelVadam %>% evaluate(mnist_test_ds)
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f6f4a78f490>
- loss
- 0.124228939414024
- accuracy
- 0.97189998626709
matplot(filter(cbind(
historyR$metrics$loss,
historyR$metrics$val_loss,
historyL$metrics$loss,
historyL$metrics$val_loss,
historyF$metrics$loss,
historyF$metrics$val_loss,
history$metrics$loss,
history$metrics$val_loss,
historyEP$metrics$loss,
historyEP$metrics$val_loss
), method='recursive', filter=.1)/1.1,
lty = rep(c(1,2), times=5), col = rep(1:5, each=2),
type='l', ylab='loss',xlab='epoch', log='xy')
legend('bottomleft',
lty=rep(1:2, times=5),
col=rep(1:5, each=2),
legend=c('reparam.','reparam. val.', 'local reparam.', 'local reparam. val', 'flipout', 'flipout val', 'vanilla', 'vanilla val', 'VBP', 'VBP val'))
matplot(cbind(
historyR$metrics$accuracy,
historyR$metrics$val_accuracy,
historyL$metrics$accuracy,
historyL$metrics$val_accuracy,
historyF$metrics$accuracy,
historyF$metrics$val_accuracy,
history$metrics$accuracy,
history$metrics$val_accuracy,
historyEP$metrics$accuracy,
historyEP$metrics$val_accuracy
), lty = rep(c(1,2), times=5), col = rep(1:5, each=2),
type='l', ylab='accuracy',xlab='epoch', log='xy')
legend('bottomright',
lty=rep(1:2, times=5),
col=rep(1:5, each=2),
legend=c('reparam.','reparam. val.', 'local reparam.', 'local reparam. val', 'flipout', 'flipout val', 'vanilla', 'vanilla val', 'EP', 'EP val'))