
using CUDA

# Bayesian linear regression with an isotropic Gaussian likelihood.
#
# `X` is indexed as a matrix whose second dimension gives the number of
# regression coefficients; `y` holds the observed targets.
# Priors: zero-truncated normal scales `σ_α`, `σ_β` and `σ`, a normal
# intercept `α`, and an isotropic normal coefficient vector `β`.
Turing.@model function linearreg(X, y)
    n_features = size(X)[2]

    # Scale hyperparameters, all constrained to be non-negative.
    σ_α ~ Turing.TruncatedNormal(0.0, 10.0, 0.0, Inf)
    σ_β ~ Turing.TruncatedNormal(0.0, 10.0, 0.0, Inf)
    σ   ~ Turing.TruncatedNormal(0.0, 0.3, 0.0, Inf)

    # Intercept.
    α ~ Normal(0, σ_α)

    # Regression coefficients with shared prior variance σ_β².
    β_var = σ_β*σ_β
    β ~ MvNormal(zeros(n_features), β_var*I)

    # Observation model: y is normal around the linear predictor with variance σ².
    μ     = X*β .+ α
    y_var = σ*σ
    y ~ MvNormal(μ, y_var*I)
end

"""
    model_with_dataset(::Val{:linearreg}, dataset, batchsize;
                       rng=Random.GLOBAL_RNG, gpu_batchsize=1_000)

Set up the `linearreg` model on the dataset returned by
`uci_regression_dataset(dataset)` for minibatched inference.

# Arguments
- `dataset`: forwarded unchanged to `uci_regression_dataset`.
- `batchsize`: requested minibatch size; clamped to the number of data points.

# Keywords
- `rng`: random number generator used for reshuffling the data between epochs
  and for drawing samples in `compute_full_elbo!`.
- `gpu_batchsize`: number of rows per chunk moved to the GPU when
  `compute_full_elbo!` sweeps over the full dataset.

# Returns
A tuple `(prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate)`:
- `prob`: a `DynamicPPL.LogDensityFunction` for the model conditioned on the
  first `batchsize` rows, evaluated under a `DynamicPPL.MiniBatchContext`.
- `b⁻¹`: the inverse of `Bijectors.bijector(model)`.
- `sample_train_batch!(_prob)`: returns a copy of `_prob` whose `X`/`y` model
  arguments are the next minibatch and whose likelihood scale is
  `n_data/length(batch)`. The first pass visits the data in its stored order;
  every later pass uses a fresh shuffle drawn from `rng`.
- `compute_full_elbo!(q, b⁻¹, M)`: `M`-sample Monte Carlo estimate of the ELBO
  of `q` using the whole dataset, with the likelihood evaluated on the GPU.
- `validate(z)`: root-mean-square error of `X*β .+ α` against `y` over the
  whole dataset, with `α` and `β` read from the flat parameter vector `z`.
"""
function model_with_dataset(::Val{:linearreg}, dataset, batchsize;
                            rng=Random.GLOBAL_RNG, gpu_batchsize=1_000)
    # The third returned value is not needed by this model.
    X, y, _   = uci_regression_dataset(dataset)
    n_data    = length(y)
    batchsize = min(batchsize, n_data)

    data_idx = 1:n_data
    data_itr = Iterators.partition(data_idx, batchsize)
    init_idx = first(data_itr)

    # Lazily converts each chunk of the dataset to Float32 and wraps the
    # chunks in a CuIterator for the full-data likelihood sweep.
    full_batch_itr = Iterators.Generator(
        Iterators.partition(data_idx, gpu_batchsize)) do _batch_idx
        (X[_batch_idx,:] |> Array{Float32},
         y[_batch_idx]   |> Array{Float32})
    end |> CUDA.CuIterator

    model   = linearreg(X[init_idx,:], y[init_idx])

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)

    minibatch_ctxt = DynamicPPL.MiniBatchContext(; batch_size=batchsize, npoints=n_data)
    prob           = DynamicPPL.LogDensityFunction(model, varinfo, minibatch_ctxt)

    prior_ctxt = DynamicPPL.PriorContext()
    prior_prob = DynamicPPL.LogDensityFunction(model, varinfo, prior_ctxt)

    function sample_train_batch!(_prob)
        # `data_itr` is shared with the enclosing scope so the position within
        # the current epoch persists across calls; `batch_idx` is local here.
        if isempty(data_itr)
            # Epoch exhausted: start a new one in freshly shuffled order.
            data_itr = Iterators.partition(shuffle(rng, data_idx), batchsize)
        end
        local batch_idx
        batch_idx, data_itr = Iterators.peel(data_itr)

        _prob = @set _prob.model.args.X = X[batch_idx,:]
        _prob = @set _prob.model.args.y = y[batch_idx]
        # The final batch of an epoch may be short, so rescale by its true size.
        _prob = @set _prob.context.loglike_scalar = n_data/length(batch_idx)
        return _prob
    end

    function compute_full_elbo!(q, b⁻¹, M)
        zs, ∑logdetjac = rand_and_logjac(rng, q, b⁻¹, M)

        # Average log-prior over the M samples (one column of `zs` per sample).
        𝔼ℓprior = mapreduce(+, 1:M) do m
            LogDensityProblems.logdensity(prior_prob, view(zs, :, m))
        end / M

        samples = unflatten_and_stack(zs, prob, [:β, :α, :σ])
        β_dev   = samples[:β] |> Array{Float32} |> CuArray
        α_dev   = samples[:α] |> Array{Float32} |> CuArray
        σ_dev   = samples[:σ] |> Array{Float32} |> CuArray

        # Gaussian normalising constant, summed over all samples and data points:
        # each (sample, point) pair contributes -log(σ) - log(2π)/2.
        ∑ℓZ = n_data*(-mapreduce(log, +, σ_dev) - log(2*π)/2*M)

        # Quadratic part of the Gaussian log-likelihood, accumulated chunk by chunk.
        ∑ℓlike_unnorm = mapreduce(+, full_batch_itr) do (X_batch_dev, y_batch_dev)
            μ_batch_dev = X_batch_dev*β_dev .+ α_dev
            z_batch_dev = (μ_batch_dev .- reshape(y_batch_dev, (:,1))) ./ σ_dev
            -sum(abs2, z_batch_dev)/2
        end

        𝔼ℓlike = ∑ℓlike_unnorm/M + ∑ℓZ/M
        return 𝔼ℓlike + 𝔼ℓprior + entropy(q) + ∑logdetjac/M
    end

    function validate(z)
        vi_new = DynamicPPL.unflatten(prob.varinfo, prob.context, z)
        α = vi_new.metadata.α.vals
        β = vi_new.metadata.β.vals
        y_pred = X*β .+ α
        # Root-mean-square error over the full dataset.
        return sqrt(mean(abs2, y .- y_pred))
    end

    return prob, b⁻¹, sample_train_batch!, compute_full_elbo!, validate
end
