
# Plain-function wrapper around the `BernoulliLogit` constructor: forwards `x`
# unchanged and returns whatever `BernoulliLogit(x)` returns.
function BernoulliLogitF(x)
    return BernoulliLogit(x)
end

# Logistic regression with a regularized horseshoe prior on the coefficients.
#
# `X` is an n×d matrix of predictors and `y` holds the n binary responses; each
# `y[i]` is Bernoulli with logit `(X*β)[i] + α`.
Turing.@model function horseshoe(X, y)
    #=
    Piironen, Juho, and Aki Vehtari.
    "Sparsity information and regularization in the horseshoe and other shrinkage priors."
    Electronic Journal of Statistics 11.2 (2017): 5018-5051.
    =#

    n, d = size(X)

    s   = 2    # slab scale
    ν   = 4    # slab degrees of freedom
    m₀  = 3    # prior guess for the number of non-zero coefficients

    # Scale of the global shrinkage prior: τ₀ = m₀/(d - m₀) · σ/√n.
    # NOTE(review): `n` is the row count of whatever `X` is passed in, so when the
    # model is evaluated on a minibatch it is the batch size rather than the full
    # sample size — confirm that this is intended.
    σ   = 1.0
    τ₀  = m₀ / (d - m₀) * σ / sqrt(n)

    λ   ~ Turing.filldist(Turing.Truncated(Cauchy(0.0, 1.0), 0.0, Inf), d)
    τ   ~ Turing.Truncated(Cauchy(0.0, τ₀), 0.0, Inf)
    τ²  = τ.*τ
    λ²  = λ.*λ

    # Slab width: c² ~ Inv-Gamma(ν/2, ν·s²/2). `InverseGamma` is parameterized by
    # (shape, scale), so the second argument is ν/2·s² itself, not its reciprocal.
    c²  ~ InverseGamma(ν/2, ν/2*s^2)

    α   ~ Normal(0, 2)

    # Regularized local scales: λ̃ⱼ = c·λⱼ / √(c² + τ²·λⱼ²).
    λ_tilde = @. sqrt(c²)*λ / sqrt(c² + τ²*λ²)

    # The vector second argument is meant as the per-coefficient standard
    # deviation τ·λ̃ⱼ — TODO confirm against the Distributions version in use.
    β  ~ MvNormal(zeros(d), τ*λ_tilde)

    p = X*β .+ α
    @. y ~ Turing.BernoulliLogit(p)
end

# Load one of the high-dimensional binary classification datasets from
# `datadir("datasets", "highdim", ...)`.
#
# `dataset` must be `:colon`, `:allaml`, or `:prostate`; any other value throws
# an `ArgumentError` before any file is touched.
#
# Returns `(X, y)`, where `X` is the "X" entry of the MAT file and `y` marks the
# rows whose "Y" entry equals 1 (the singleton second dimension is dropped).
function high_dimensional_classification(dataset)
    # https://jundongl.github.io/scikit-feature/datasets.html
    filename = if dataset == :colon
        "colon.mat"
    elseif dataset == :allaml
        "ALLAML.mat"
    elseif dataset == :prostate
        "Prostate_GE.mat"
    else
        # Fail here with a clear message rather than later, when indexing `nothing`.
        throw(ArgumentError(
            "unknown dataset $(repr(dataset)); expected :colon, :allaml, or :prostate",
        ))
    end

    data = MAT.matread(datadir("datasets", "highdim", filename))
    y    = dropdims(data["Y"] .== 1; dims=2)
    X    = data["X"]
    return X, y
end

# Build the minibatched log-density problem for the horseshoe model on `dataset`.
#
# Arguments:
#   ::Val{:horseshoe} — dispatch tag selecting this method.
#   dataset           — forwarded to `high_dimensional_classification` to load `X, y`.
#   batchsize         — requested minibatch size; clamped to the number of labels.
#   rng               — random number generator used when reshuffling data indices.
#
# Returns the tuple `(prob, b⁻¹, sample_batch!, prepare_full_pass!, validate)`.
function model_with_dataset(::Val{:horseshoe}, dataset, batchsize; rng=Random.GLOBAL_RNG)
    X, y      = high_dimensional_classification(dataset)
    n_data    = length(y)
    batchsize = min(batchsize, n_data)

    # Placeholder batch buffers: `similar` allocates without initializing, so these
    # only fix the array types and sizes that the model is first constructed with.
    X_sub = similar(X, batchsize, size(X, 2))
    y_sub = similar(y, batchsize)

    # The initial iterator walks the indices in their natural order; reshuffling
    # only happens inside `sample_batch!` once the iterator is exhausted.
    data_idx = 1:n_data
    data_itr = Iterators.partition(data_idx, batchsize)

    model    = horseshoe(X_sub, y_sub)
    context  = DynamicPPL.MiniBatchContext(; batch_size=batchsize, npoints=n_data)

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    # Return `_prob` rebound so that its model data is the next minibatch and its
    # likelihood scale is n_data / (size of that batch). Advances the captured
    # `data_itr`, starting a freshly shuffled pass whenever it is exhausted.
    # NOTE(review): presumably `@set` is the Setfield/Accessors macro that builds an
    # updated copy, leaving the caller's object untouched — confirm.
    function sample_batch!(_prob)
        if isempty(data_itr)
            data_itr = Iterators.partition(shuffle(rng, data_idx), batchsize)
        end
        # `peel` yields the next batch of indices plus the remaining iterator.
        batch_idx, data_itr = Iterators.peel(data_itr)

        _prob = @set _prob.model.args.X = X[batch_idx,:]
        _prob = @set _prob.model.args.y = y[batch_idx]
        # The last batch of a pass can be shorter, so the scale is recomputed per batch.
        _prob = @set _prob.context.loglike_scalar = n_data/length(batch_idx)
        _prob 
    end

    # Reset the batch iterator to an unshuffled sweep over all indices and return
    # (number of batches in that sweep, 1).
    # NOTE(review): the meaning of the trailing `1` is defined by the caller — confirm.
    function prepare_full_pass!()
        data_itr = Iterators.partition(data_idx, batchsize)
        length(data_itr), 1
    end

    # Classification accuracy on the full data set for the flat parameter vector `z`:
    # reads α and β out of a varinfo rebuilt from `z`, thresholds the predicted
    # probabilities at 0.5, and returns the fraction of labels matched.
    # NOTE(review): whether `z` is expected in constrained or unconstrained space is
    # not visible here — confirm against the caller.
    function validate(z)
        vi_new = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
        α = vi_new.metadata.α.vals
        β = vi_new.metadata.β.vals

        y_pred = logistic.(X*β .+ α) .> 0.5
        mean(y .== y_pred)
    end

    prob, b⁻¹, sample_batch!, prepare_full_pass!, validate
end
