
using DistributionsAD

@eval Bijectors begin
# Method of `with_logabsdet_jacobian` for `Stacked`, defined inside the `Bijectors`
# module via `@eval`. Component bijector `sb.bs[i]` is applied to the slice
# `x[sb.ranges[i]]`. The transformed slices are concatenated with `vcat`, and the
# per-component log-abs-det-Jacobians are summed.
#
# Each component is evaluated exactly once. The mapping function returns the pair
# `(y, sum(l))`, and the reducer merges pairs, so no bijector runs a second time
# just to obtain its log-Jacobian.
#
# `sum(l)` (not `first(l)`) keeps every term when a component reports its
# log-Jacobian as an array. For a scalar `l` it returns `l` unchanged.
#
# Returns `(ys, logjac)`.
# NOTE(review): `mapreduce` without `init` throws on an empty collection, so this
# assumes `sb.bs` contains at least one bijector (same as the previous version).
function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    # Merge two `(output, logjac)` pairs: concatenate outputs, add log-Jacobians.
    combine((y1, l1), (y2, l2)) = (vcat(y1, y2), l1 + l2)

    ys, logjac = mapreduce(combine, sb.bs, sb.ranges) do b, r
        y, l = with_logabsdet_jacobian(b, x[r])
        (y, sum(l))
    end
    return (ys, logjac)
end
end
