using LogicCircuits
using ProbabilisticCircuits
using StatsFuns


"""
    marginal_log_likelihood_avg_cat(pc, data; use_gpu = isgpu(data), n_variables = 0)

Return the mean, over all examples in `data`, of the values found in the last
column of `marginal_all` evaluated on a `CatParamBitCircuit` built from `pc`.

# Arguments
- `pc`: circuit handed to `CatParamBitCircuit`.
- `data`: a single data set or a batched collection of data sets; an unbatched
  input is wrapped into a one-element vector.

# Keywords
- `use_gpu`: move the bit circuit and every batch to the GPU before evaluation.
- `n_variables`: variable count passed to `CatParamBitCircuit`; `0` means
  "use `num_variables(pc)`".
"""
function marginal_log_likelihood_avg_cat(
    pc, data; use_gpu::Bool = isgpu(data), n_variables::Integer = 0
)
    batches = isbatched(data) ? data : [data]
    nvars = n_variables == 0 ? num_variables(pc) : n_variables

    circuit::CatParamBitCircuit = CatParamBitCircuit(pc, nvars)
    if use_gpu
        circuit = to_gpu(circuit)
    end

    acc::Float64 = 0.0
    # `values` is fed back into `marginal_all` on every iteration
    # (presumably so the buffer can be reused -- confirm against marginal_all).
    values = nothing
    for batch in batches
        batch_on_device = use_gpu ? to_gpu(batch) : batch
        values = marginal_all(circuit, batch_on_device, values)
        # NOTE(review): the last column is assumed to hold the per-example
        # log marginal of the circuit root -- confirm.
        last_col = values[:, end]
        if isgpu(last_col)
            acc += sum(to_cpu(last_col))
        else
            acc += sum(last_col)
        end
    end

    return acc / num_examples(batches)
end


"""
    log_likelihood_cat(pc, data::Vector{UInt32})

Evaluate the circuit `pc` bottom-up on a single example `data` and return the
resulting `Float64` value.

`data[var]` is used directly as an index into a literal node's `log_probs`, so
each entry is presumably a 1-based category id for that variable -- confirm
against the caller.

# Throws
- `ErrorException` if the fold reaches a constant node; only literal,
  and-like and or-like nodes are handled here.
"""
function log_likelihood_cat(pc, data::Vector{UInt32})
    # Constant nodes have no value defined here; fail with an explicit message
    # instead of an empty one.
    f_con(n) = error("log_likelihood_cat: constant nodes are not supported")
    # Literal node: look up the log-probability of the observed value.
    f_lit(n) = n.log_probs[data[n.variable]]
    # Product of children in the log domain.
    f_a(_, cns) = sum(cns)
    # Mixture of children weighted by the node's `log_probs`, in the log domain.
    f_o(n, cns) = logsumexp(cns .+ n.log_probs)
    return foldup_aggregate(pc, f_con, f_lit, f_a, f_o, Float64)
end


"""
    marginal_log_likelihood_cat(pc, data; use_gpu = isgpu(data), n_variables = 0)

Return a `Vector{Float64}` with one entry per example in `data`, each taken from
the last column of `marginal_all` evaluated on a `CatParamBitCircuit` built
from `pc`. Entries appear in batch order.

# Arguments
- `pc`: circuit handed to `CatParamBitCircuit`.
- `data`: a single data set or a batched collection of data sets; an unbatched
  input is wrapped into a one-element vector.

# Keywords
- `use_gpu`: move the bit circuit and every batch to the GPU before evaluation.
- `n_variables`: variable count passed to `CatParamBitCircuit`; `0` means
  "use `num_variables(pc)`".
"""
function marginal_log_likelihood_cat(
    pc, data; use_gpu::Bool = isgpu(data), n_variables::Integer = 0
)
    batches = isbatched(data) ? data : [data]
    nvars = n_variables == 0 ? num_variables(pc) : n_variables

    circuit::CatParamBitCircuit = CatParamBitCircuit(pc, nvars)
    if use_gpu
        circuit = to_gpu(circuit)
    end

    result = Vector{Float64}(undef, num_examples(batches))
    offset = 0  # number of entries of `result` already filled
    # `values` is fed back into `marginal_all` on every iteration
    # (presumably so the buffer can be reused -- confirm against marginal_all).
    values = nothing
    for batch in batches
        batch_on_device = use_gpu ? to_gpu(batch) : batch
        values = marginal_all(circuit, batch_on_device, values)
        # NOTE(review): the last column is assumed to hold the per-example
        # log marginal of the circuit root -- confirm.
        last_col = values[:, end]
        count = num_examples(batch_on_device)
        # Results are always stored on the host.
        result[(offset + 1):(offset + count)] .= isgpu(last_col) ? to_cpu(last_col) : last_col
        offset += count
    end

    return result
end