
using SparseArrays

Turing.@model function nnmf(Y_batch, users_batch,
                            I, U, α₀, λ₀, ::Type{T}=Float64) where {T <: Real}
    #=
    Non-negative matrix factorization model with a Poisson likelihood,
    evaluated on a minibatch of users (columns).

    Reference cited by the original author:
    Mnih, Andriy, and Russ R. Salakhutdinov.
    "Probabilistic matrix factorization."
    Advances in neural information processing systems 20 (2007).
    NOTE(review): that paper uses a Gaussian likelihood, whereas this model
    uses Exponential priors on β, Dirichlet priors on θ and a Poisson
    likelihood. Confirm that this is the intended reference.

    Arguments
    - Y_batch:     observations for the current minibatch. It is broadcast
                   against λ = β*θ[:, users_batch], so it is presumably an
                   I × length(users_batch) matrix of non-negative counts.
    - users_batch: user (column) indices of θ that correspond, in order, to
                   the columns of Y_batch.
    - I:           number of items (rows of β).
    - U:           total number of users (one Dirichlet draw per user).
    - α₀:          Dirichlet concentration vector; its length sets the rank K.
    - λ₀:          parameter passed to `Exponential` for every entry of β.
    - T:           element-type parameter; it is not referenced in the body.
    =#
    # Latent rank of the factorization.
    K = length(α₀)
    # Item factors: an I × K array with i.i.d. Exponential(λ₀) entries.
    β ~ Turing.filldist(Exponential(λ₀), I, K)
    # User factors: one Dirichlet(α₀) draw per user. The θ[:, users_batch]
    # indexing below assumes these are stacked as the columns of a K × U array.
    θ ~ Turing.arraydist([Dirichlet(α₀) for u = 1:U ])

    # Poisson rates for the minibatch only; the remaining users' θ columns
    # do not enter the likelihood here.
    λ = β*θ[:,users_batch]
    # Every entry of Y_batch is an independent Poisson observation with its
    # matching rate in λ.
    @. Y_batch ~ Poisson(λ)
end

"""
    movielens_dataset()

Load the MovieLens-100k ratings file `u.data` as a sparse item × user matrix.

Only the first three integer columns of the file are used:
column 1 is the user id, column 2 is the item id and column 3 is the rating.
Entry `(item, user)` of the result holds the rating. Item/user pairs without a
rating are structural zeros. `sparse` sums any duplicate pairs.
"""
function movielens_dataset()
    path    = datadir("datasets", "movielens-100k", "u.data")
    table   = readdlm(path, Int)
    users   = table[:, 1]
    items   = table[:, 2]
    ratings = table[:, 3]
    # The size is inferred from the largest ids; the 100k release is expected
    # to give 1682 items × 943 users.
    return sparse(items, users, ratings)
end

"""
    model_with_dataset(::Val{:nnmf}, dataset, n_batch; rng=Random.GLOBAL_RNG)

Build a minibatched `nnmf` model on the MovieLens-100k rating matrix.

Users (the columns of the item × user rating matrix) are shuffled with `rng`
and split into minibatches of at most `n_batch` users. The `dataset` argument
is accepted for interface compatibility but is not used; the data always comes
from `movielens_dataset()`.

Returns a tuple `(prob, b⁻¹, sample_batch!, prepare_full_pass!)`:
- `prob`: a `DynamicPPL.LogDensityFunction` over the first minibatch, with a
  `MiniBatchContext` that rescales the likelihood to the full data size.
- `b⁻¹`: the inverse of the model's bijector.
- `sample_batch!(prob)`: returns a copy of `prob` retargeted at the next
  minibatch, reshuffling (with `rng`) once every batch has been consumed.
- `prepare_full_pass!()`: a no-op placeholder that returns `nothing`.
"""
function model_with_dataset(::Val{:nnmf}, dataset, n_batch; rng=Random.GLOBAL_RNG)
    Y = movielens_dataset()
    I = size(Y, 1)
    U = size(Y, 2)

    # Minibatches are formed over users (columns of Y).
    n_data  = U
    n_batch = min(n_batch, n_data)

    data_itr    = Iterators.partition(shuffle(rng, 1:n_data), n_batch)
    users_batch = first(data_itr)
    # Densifying turns unrated (structurally zero) entries into observed zeros.
    Y_batch     = Matrix(Y[:, users_batch])

    K  = 10
    α₀ = fill(0.3, K)
    λ₀ = 1.0

    context = DynamicPPL.MiniBatchContext(; batch_size=n_batch, npoints=n_data)
    model   = nnmf(Y_batch, users_batch, I, U, α₀, λ₀)

    function sample_batch!(prob)
        # Start a freshly shuffled epoch once all batches have been consumed.
        if isempty(data_itr)
            data_itr = Iterators.partition(shuffle(rng, 1:n_data), n_batch)
        end
        batch_idx, data_itr = Iterators.peel(data_itr)
        prob = @set prob.model.args.Y_batch     = Matrix(Y[:, batch_idx])
        # The user indices must move with the batch so that the columns of θ
        # selected in the model line up with the columns of Y_batch.
        prob = @set prob.model.args.users_batch = batch_idx
        # Rescale so the minibatch log-likelihood estimates the full-data one.
        # The last batch of an epoch may be smaller than n_batch.
        prob = @set prob.context.loglike_scalar = n_data / length(batch_idx)
        return prob
    end

    # Placeholder: nothing needs resetting before a full pass over the data.
    prepare_full_pass!() = nothing

    varinfo = DynamicPPL.VarInfo(model)
    b       = Bijectors.bijector(model)
    b⁻¹     = inverse(b)
    prob    = DynamicPPL.LogDensityFunction(model, varinfo, context)

    return prob, b⁻¹, sample_batch!, prepare_full_pass!
end
