
"""
    dense_forward(X, ξ, η, σ)

Forward pass of one dense layer over the rows of `X` (no activation applied).

A column of ones is appended to `X` as a bias input, giving `n_in + 1` inputs where
`n_in = size(X, 2)`.

The weight matrix is `σ * (η .* reshape(ξ, n_in + 1, :))`:
- `ξ` holds the raw weights in column-major order.
- `η` (length `n_in + 1`) rescales each input row.
- `σ` is a scalar scale for the whole layer.

The product of the augmented inputs and the weights is divided by `sqrt(n_in + 1)`.

Returns a `size(X, 1) × n_out` matrix, where `n_out = length(ξ) ÷ (n_in + 1)`.
"""
function dense_forward(X, ξ, η, σ)
    n_rows, n_in = size(X)

    X_aug       = hcat(X, ones(n_rows))          # bias input as an extra column
    row_scales  = reshape(η, :, 1)
    raw_weights = reshape(ξ, n_in + 1, :)
    W           = σ * (row_scales .* raw_weights)

    return (X_aug * W) / sqrt(n_in + 1)
end

# Priors for one fully connected layer with `n_in` inputs (plus one bias input)
# and `n_out` outputs. The parameters are consumed by `dense_forward`:
#   ξ — raw weights, standard normal, length (n_in + 1) * n_out;
#   η — per-input-row scales, InverseGamma(2, 1) each, length n_in + 1;
#   σ — layer-wide scale, Normal(0, 3) truncated to [0, ∞).
# The `ArrayType` type argument is accepted but not used in the body.
# The tilde statements must stay in this order: it fixes the parameter layout.
# Returns the tuple (ξ, η, σ).
Turing.@model function DenseLayer(::Type{ArrayType}, n_in, n_out) where {ArrayType <: AbstractArray}
    n_rows = n_in + 1   # inputs plus the bias term

    ξ ~ MvNormal(zeros(n_rows * n_out), I)
    η ~ Turing.filldist(InverseGamma(2, 1), n_rows)
    σ ~ Turing.TruncatedNormal(0, 3, 0, Inf)

    (ξ, η, σ)
end

# Two-layer Bayesian neural network for regression with a tanh hidden layer.
#
# Reference:
#   Cui, Tianyu, et al.
#   "Informative Bayesian neural network priors for weak signals."
#   Bayesian Analysis 17.4 (2022).
#
# Arguments:
#   X         — input matrix, one row per observation.
#   y         — observed targets, one per row of X.
#   n_hidden1 — width of the hidden layer.
#
# The tilde statements define the parameter layout: σ⁻², then the `layer1.*`
# variables, then the `layer2.*` variables. Keep this order stable, because
# flattened parameter vectors depend on it.
Turing.@model function bnn(X, y, n_hidden1)
    _, n_features = size(X)

    σ⁻² ~ Turing.TruncatedNormal(0, 3, 0, Inf)

    Turing.@submodel prefix="layer1" ξ₁, η₁, σ₁ = DenseLayer(typeof(X), n_features, n_hidden1)
    Turing.@submodel prefix="layer2" ξ₂, η₂, σ₂ = DenseLayer(typeof(X), n_hidden1, 1)

    hidden = tanh.(dense_forward(X, ξ₁, η₁, σ₁))
    output = dense_forward(hidden, ξ₂, η₂, σ₂)

    # NOTE(review): σ⁻² multiplies the covariance directly, so it acts as a
    # noise variance despite its precision-like name. Confirm this is intended.
    y ~ MvNormal(vec(output), σ⁻² * I)
end

"""
    model_with_dataset(::Val{:bnn}, dataset, batchsize; rng=Random.GLOBAL_RNG, n_hidden=100)

Build the `bnn` model on the data returned by `uci_regression_dataset(dataset)`.
Return it together with closures for minibatch training and validation.

`round(Int, 0.1 * n_data)` randomly chosen points are held out for validation. The
remaining points form the training set. Training visits it in minibatches of at most
`batchsize` points, in a freshly shuffled order every epoch (including the first).

# Keywords
- `rng`: RNG used for the train/validation split, epoch shuffling, and the draws
  from `q` in `validate`.
- `n_hidden`: width of the hidden layer of `bnn` (default `100`).

# Returns
`(prob, b⁻¹, sample_train_batch!, prepare_valid_pass!, sample_valid_batch!, validate)`:
- `prob`: `DynamicPPL.LogDensityFunction` for the model under a `MiniBatchContext`.
- `b⁻¹`: inverse of `Bijectors.bijector(model)`.
- `sample_train_batch!(prob)`: copy of `prob` holding the next training minibatch.
  The log-likelihood is scaled by `n_train / length(batch)`. It advances captured
  iterator state only; `prob` itself is not mutated.
- `prepare_valid_pass!()`: restarts the validation iterator and returns the number
  of validation minibatches.
- `sample_valid_batch!(prob)`: copy of `prob` holding the next validation
  minibatch. The log-likelihood is scaled by `n_valid / length(batch)`. Call it at
  most as many times as `prepare_valid_pass!()` reported.
- `validate(q, b⁻¹)`: `sqrt(mean((y - ŷ)^2 * σ_y^2))` over the validation points.
  Here `ŷ` is the network output averaged over 1000 draws from `q`, each mapped
  through `b⁻¹`.
"""
function model_with_dataset(
    ::Val{:bnn}, dataset, batchsize; rng=Random.GLOBAL_RNG, n_hidden=100
)
    X, y, σ_y = uci_regression_dataset(dataset)
    n_data    = length(y)
    batchsize = min(batchsize, n_data)

    # Random train/validation split with roughly 10% of the points held out.
    n_valid   = round(Int, n_data*0.1)
    data_idx  = 1:n_data
    valid_idx = sample(rng, data_idx, n_valid, replace=false)
    train_idx = setdiff(data_idx, valid_idx)
    n_train   = length(train_idx)

    # `setdiff` returns `train_idx` in ascending order. Shuffle the first epoch
    # here, just as `sample_train_batch!` reshuffles every later one.
    train_itr = Iterators.partition(shuffle(rng, train_idx), batchsize)
    valid_itr = Iterators.partition(valid_idx, batchsize)

    # Placeholder data used only to instantiate the model and its VarInfo; real
    # minibatches are swapped in by the `sample_*_batch!` closures. Real rows are
    # used instead of `similar`, so the model never runs on uninitialized memory.
    X_sub = X[1:batchsize, :]
    y_sub = y[1:batchsize]

    model   = bnn(X_sub, y_sub, n_hidden)
    context = DynamicPPL.MiniBatchContext(; batch_size=batchsize, npoints=n_data)

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    # Return `_prob` rebound to the next training minibatch. This advances the
    # captured `train_itr` and starts a reshuffled epoch once it is exhausted.
    function sample_train_batch!(_prob)
        if isempty(train_itr)
            train_itr = Iterators.partition(shuffle(rng, train_idx), batchsize)
        end
        batch_idx, train_itr = Iterators.peel(train_itr)

        _prob = @set _prob.model.args.X = X[batch_idx, :]
        _prob = @set _prob.model.args.y = y[batch_idx]
        # Rescale the minibatch log-likelihood to the size of the training set.
        _prob = @set _prob.context.loglike_scalar = n_train/length(batch_idx)
        return _prob
    end

    # Restart the validation pass and return how many minibatches it contains.
    function prepare_valid_pass!()
        valid_itr = Iterators.partition(valid_idx, batchsize)
        return length(valid_itr)
    end

    # Return `_prob` rebound to the next validation minibatch. This errors if the
    # current pass is exhausted, so call `prepare_valid_pass!` first.
    function sample_valid_batch!(_prob)
        batch_idx, valid_itr = Iterators.peel(valid_itr)
        _prob = @set _prob.model.args.X = X[batch_idx, :]
        _prob = @set _prob.model.args.y = y[batch_idx]
        _prob = @set _prob.context.loglike_scalar = n_valid/length(batch_idx)
        return _prob
    end

    # Validation error of the variational approximation `q`. Average the network
    # output over `n_samples` parameter draws, then compare with the targets.
    # The forward pass here must mirror the layers of `bnn`.
    function validate(q, b⁻¹)
        n_samples = 1000
        # Loop-invariant: slice the validation set once, not once per draw.
        X_valid = X[valid_idx, :]
        y_valid = y[valid_idx]

        pred_sum = mapreduce(+, 1:n_samples) do _
            z      = b⁻¹(rand(rng, q))
            vi_new = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
            md     = vi_new.metadata

            ξ₁ = md[Symbol("layer1.ξ")].vals
            η₁ = md[Symbol("layer1.η")].vals
            σ₁ = md[Symbol("layer1.σ")].vals[1]

            ξ₂ = md[Symbol("layer2.ξ")].vals
            η₂ = md[Symbol("layer2.η")].vals
            σ₂ = md[Symbol("layer2.σ")].vals[1]

            hidden = tanh.(dense_forward(X_valid, ξ₁, η₁, σ₁))
            # Return the prediction directly. Assigning it to `y_pred` here would
            # rebind `validate`'s own `y_pred` and force it into a `Core.Box`.
            reshape(dense_forward(hidden, ξ₂, η₂, σ₂), :)
        end
        y_pred = pred_sum / n_samples

        return sqrt(mean(@. (y_valid - y_pred)^2 * σ_y^2))
    end

    return prob, b⁻¹, sample_train_batch!, prepare_valid_pass!, sample_valid_batch!, validate
end
