using LogicCircuits
using ProbabilisticCircuits
using CUDA
using Statistics: mean, std
using StatsFuns: logaddexp
using Clustering
using DataFrames


"""
    uniform_parameters_cat(pc::ProbCircuit; perturbation::Float64 = 0.0)

Reset the parameters of every node of `pc` in place:

- single-child sum nodes get log-probability `0.0`;
- multi-child sum nodes get uniform weights when `perturbation < 1e-8`, and otherwise
  weights drawn uniformly from `[1 - perturbation, 1 + perturbation)` and normalized;
- `PlainProbCategoricalNode` leaves always get normalized weights drawn from that same
  interval (which is uniform when `perturbation == 0`).
"""
function uniform_parameters_cat(pc::ProbCircuit; perturbation::Float64 = 0.0)
    # `k` normalized log-weights, each drawn from [1 - perturbation, 1 + perturbation)
    # using `k` draws of the global RNG.
    random_log_weights(k) = begin
        weights = (1.0 - perturbation) .+ rand(k) .* 2 .* perturbation
        log.(weights ./ sum(weights))
    end

    foreach(pc) do node
        if is⋁gate(node)
            k = num_children(node)
            if k == 1
                node.log_probs .= 0.0
            elseif perturbation < 1e-8
                node.log_probs .= log(1.0 / k)
            else
                node.log_probs .= random_log_weights(k)
            end
        elseif node isa PlainProbCategoricalNode
            node.log_probs .= random_log_weights(node.num_cats)
        end
    end
end


"""
    kmeans_params_initialization_cat(pc::ProbCircuit, data::DataFrame; same_cluster_prob = 0.8)

Initialize the parameters of `pc` from a k-means clustering of the rows of `data`.

The number of clusters is the number of root weights (`length(pc.log_probs)`). Every
multi-child sum node must have exactly that many children, one per cluster.

- A multi-child sum node reached outside any cluster (e.g. the root) gets uniform weights.
- A multi-child sum node reached inside cluster `c` gives child `c` the probability
  `same_cluster_prob + (1 - same_cluster_prob) / num_clusters`. Every other child gets
  `(1 - same_cluster_prob) / num_clusters`.
- Single-child sum nodes pass the current cluster through to their child.
- Literal leaves get the add-one smoothed empirical distribution of their variable over the
  rows assigned to the enclosing cluster.

Each node is initialized only on its first visit. Nodes shared between clusters keep the
parameters of the first cluster that reached them.
"""
function kmeans_params_initialization_cat(pc::ProbCircuit, data::DataFrame; same_cluster_prob = 0.8)

    # Cluster the rows of `data`; returns one cluster index per row.
    apply_kmeans(data::DataFrame; num_clusters = 10) = begin
        # `kmeans` takes one column per example, hence the transpose.
        data = Float32.(permutedims(convert(Matrix{UInt32}, data), [2, 1])) ./ 256.0

        results = kmeans(data, num_clusters)
        results.assignments
    end

    # Log of the add-one smoothed distribution of column `var_idx` over rows in `cluster_idx`.
    # NOTE(review): assumes entries of `data` are 1-based category indices in 1:num_cats — confirm.
    get_marginal_probs(data::DataFrame, assignments, var_idx, cluster_idx; num_cats) = begin
        probs = ones(Float64, num_cats)
        for i in eachindex(assignments)
            if assignments[i] == cluster_idx
                probs[data[i, var_idx]] += 1.0
            end
        end
        log.(probs ./ sum(probs))
    end

    num_clusters = length(pc.log_probs)
    num_cats = num_categories(pc)

    assignments = apply_kmeans(data; num_clusters)

    visited = Set{ProbCircuit}()
    # `cluster_idx == 0` means "not inside any cluster yet".
    f_rec(n::ProbCircuit; cluster_idx = 0) = begin
        n in visited && return
        push!(visited, n)
        if is⋁gate(n)
            if length(n.children) == 1
                # Pass-through node: stay in the current cluster. Resetting to 0 here would
                # give every leaf below uniform parameters.
                f_rec(n.children[1]; cluster_idx)
            else
                @assert length(n.children) == num_clusters "$(length(n.children)) != $(num_clusters)"
                if cluster_idx == 0
                    n.log_probs .= log(1.0 / num_clusters)
                else
                    n.log_probs .= log((1.0 - same_cluster_prob) / num_clusters)
                    n.log_probs[cluster_idx] = logaddexp(n.log_probs[cluster_idx], log(same_cluster_prob))
                end
                # Child `i` of a multi-child sum node represents cluster `i`.
                for i = 1 : num_clusters
                    f_rec(n.children[i]; cluster_idx = i)
                end
            end
        elseif is⋀gate(n)
            for child in n.children
                f_rec(child; cluster_idx)
            end
        elseif isliteralgate(n)
            n.log_probs .= get_marginal_probs(data, assignments, n.variable, cluster_idx; num_cats)
        else
            error("Unexpected gate type encountered.")
        end
        nothing
    end
    f_rec(pc)
end


"""
    estimate_parameters_em_cat(pc::ProbCircuit, data; pseudocount, use_gpu = isgpu(data),
                               exp_update_factor = 0.0, update_per_batch = false,
                               leaf_node_distribution = "categorical", std_max = 1e5,
                               n_variables = 0, frac = 0.01)

Run one EM parameter-estimation pass over `data` for a circuit with categorical leaves and
store the resulting parameters in `pc`.

# Keywords
- `pseudocount`: smoothing count forwarded to the CPU/GPU estimators.
- `use_gpu`: run on the GPU. It is forced to `true` when `data` already lives on the GPU.
- `exp_update_factor`: only used with `update_per_batch`. After each mini-batch the stored
  parameters become `exp_update_factor * old + (1 - exp_update_factor) * new`, computed in
  log space. The default `0.0` keeps only the newest estimate.
- `update_per_batch`: for batched `data`, update the parameters after every mini-batch
  instead of accumulating statistics over the whole dataset.
- `leaf_node_distribution`, `std_max`, `frac`: forwarded to `compute_cat_params` to
  post-process the leaf parameters.
- `n_variables`: number of variables used to build the bit circuit. `0` means
  `num_variables(pc)`.
"""
function estimate_parameters_em_cat(pc::ProbCircuit, data; pseudocount::Float64, use_gpu::Bool = isgpu(data),
                                    exp_update_factor::Float64 = 0.0, update_per_batch::Bool = false,
                                    leaf_node_distribution::String = "categorical", std_max::AbstractFloat = 1e5,
                                    n_variables::Integer = 0, frac::AbstractFloat = 0.01)
    
    n_variables = n_variables == 0 ? num_variables(pc) : n_variables
    pbc::CatParamBitCircuit = CatParamBitCircuit(pc, n_variables)
    
    # Data that already lives on the GPU forces the GPU path.
    if isgpu(data)
        use_gpu = true
    end
    
    if update_per_batch && isbatched(data)
        # Online EM: re-estimate after every mini-batch and blend the result into `pbc`.
        if use_gpu
            pbc = to_gpu(pbc)
        end
        
        # Buffers handed back and forth between estimator calls to avoid re-allocation.
        reuse_v, reuse_f = nothing, nothing
        reuse_counts = use_gpu ? (nothing, nothing, nothing) : (nothing, nothing, nothing, nothing, nothing)
        
        for idx = 1 : length(data)
            params, leaf_params, reuse_v, reuse_f, reuse_counts = if use_gpu
                estimate_parameters_em_gpu(pbc, to_gpu(data[idx]); pseudocount,
                    reuse = true, reuse_v, reuse_f, reuse_counts)
            else
                estimate_parameters_em_cpu(pbc, data[idx]; pseudocount,
                    reuse = true, reuse_v, reuse_f, reuse_counts)
            end
            
            # Blend into `pbc`: log(f * exp(old) + (1 - f) * exp(new)) with f = exp_update_factor.
            if use_gpu # GPU
                @inbounds @views pbc.params .+= log(exp_update_factor)
                @inbounds @views params .+= log(1.0 - exp_update_factor)
                delta = @inbounds @views @. CUDA.ifelse(pbc.params == params, CUDA.zero(params), CUDA.abs(pbc.params - params))
                @inbounds @views @. pbc.params = CUDA.max(pbc.params, params) + CUDA.log1p(CUDA.exp(-delta))
                # Release this temporary now. `delta` is rebound below, so otherwise it would
                # only be reclaimed by the GC.
                CUDA.unsafe_free!(delta)
                
                @inbounds @views pbc.cat_params .+= log(exp_update_factor)
                @inbounds @views leaf_params .+= log(1.0 - exp_update_factor)
                delta = @inbounds @views @. CUDA.ifelse(pbc.cat_params == leaf_params, CUDA.zero(leaf_params), CUDA.abs(
                        pbc.cat_params - leaf_params))
                @inbounds @views @. pbc.cat_params = CUDA.max(pbc.cat_params, leaf_params) + CUDA.log1p(CUDA.exp(-delta))
                
                CUDA.unsafe_free!(params)
                CUDA.unsafe_free!(leaf_params)
                CUDA.unsafe_free!(delta)
            else # CPU
                @inbounds @views pbc.params .= logaddexp.(pbc.params .+ log(exp_update_factor), params .+ log(1.0 - exp_update_factor))
                @inbounds @views pbc.cat_params .= logaddexp.(pbc.cat_params .+ log(exp_update_factor), leaf_params .+ log(1.0 - exp_update_factor))
            end
            
            pbc = compute_cat_params(pbc, leaf_node_distribution; std_max, frac)
        end
        
        my_cache_parameters!(pc, pbc.bitcircuit, pbc.params, pbc.cat_params)
    else
        # Batch EM: accumulate statistics over the whole dataset, then update once.
        params, leaf_params = if use_gpu
            estimate_parameters_em_gpu(to_gpu(pbc), data; pseudocount)
        else
            estimate_parameters_em_cpu(pbc, data; pseudocount)
        end
        
        leaf_params = compute_cat_params(leaf_params, leaf_node_distribution; std_max, frac)
        
        my_cache_parameters!(pc, pbc.bitcircuit, params, leaf_params)
    end
    
    nothing
end

"""
    compute_cat_params(pbc::CatParamBitCircuit, leaf_node_distribution::String; std_max = 1e4, frac = 0.01)

Post-process the leaf parameters held in `pbc.cat_params` according to
`leaf_node_distribution`, write the result back into `pbc.cat_params`, and return `pbc`.
"""
function compute_cat_params(pbc::CatParamBitCircuit, leaf_node_distribution::String; std_max::AbstractFloat = 1e4,
                            frac::AbstractFloat = 0.01)
    processed = compute_cat_params(pbc.cat_params, leaf_node_distribution; std_max = std_max, frac = frac)
    # Broadcast-assign rather than rebind so the circuit's own array is updated.
    pbc.cat_params .= processed
    return pbc
end
"""
    compute_cat_params(cat_params::Matrix, leaf_node_distribution::String; std_max = 1e4, frac = 0.01)

Reshape the leaf log-probabilities in `cat_params` (one row per leaf, one column per
category) according to `leaf_node_distribution`. The matrix is modified in place and
returned.

- `"categorical"`: returned unchanged.
- `"gaussian"`: each row is replaced by a discretized Gaussian over the column indices
  `1:size(cat_params, 2)`. It uses the row's mean and standard deviation, with the standard
  deviation capped at `std_max`.
- `"sparse_val"`: every probability below `frac` times the row maximum is raised to that
  value.
- `"sparse_frac"`: every probability below the row's value at ascending sorted position
  `min(max(floor(ncats * (1 - frac)), 1), ncats)` is raised to that value. This matches the
  `CuMatrix` method.

All non-categorical variants renormalize every row and store log-probabilities. Any other
`leaf_node_distribution` throws an error.
"""
function compute_cat_params(cat_params::Matrix, leaf_node_distribution::String; std_max::AbstractFloat = 1e4,
                            frac::AbstractFloat = 0.01)
    n_leaves, n_cats = size(cat_params)

    # Renormalize every row in probability space, return to log space and sanity-check.
    normalize_to_log!(p) = begin
        p ./= sum(p; dims = 2)
        p .= log.(p)
        @assert all(isapprox.(sum(exp.(Array(p)); dims = 2), 1.0, atol=1e-3)) "Parameters do not sum to one locally."
        p
    end

    if leaf_node_distribution == "categorical"
        return cat_params
    elseif leaf_node_distribution == "gaussian"
        cat_params .= exp.(cat_params)
        # Category positions 1:n_cats as a 1×n_cats Float64 row, broadcast against each leaf row.
        positions = reshape(Vector(range(1, stop = n_cats, length = n_cats)), 1, n_cats)
        mean_vec = sum(positions .* cat_params; dims = 2)
        std_vec = sqrt.(sum(((positions .- mean_vec).^2) .* cat_params; dims = 2))
        std_vec .= min.(std_vec, std_max)

        Threads.@threads for i = 1 : n_leaves
            for j = 1 : n_cats
                cat_params[i,j] = exp(-(j - mean_vec[i,1])^2 / (2.0 * std_vec[i,1]^2)) / std_vec[i,1] / sqrt(2.0 * π)
            end
        end
    elseif leaf_node_distribution == "sparse_val"
        cat_params .= exp.(cat_params)
        maxval = maximum(cat_params; dims = 2)

        Threads.@threads for i = 1 : n_leaves
            for j = 1 : n_cats
                if cat_params[i,j] / maxval[i] < frac
                    cat_params[i,j] = maxval[i] * frac
                end
            end
        end
    elseif leaf_node_distribution == "sparse_frac"
        cat_params .= exp.(cat_params)
        # Same threshold position as the GPU implementation.
        frac_idx = min(max(Int(floor(n_cats * (1.0 - frac))), Int(1)), n_cats)

        Threads.@threads for i = 1 : n_leaves
            # `frac_idx`-th smallest value of the row; `cat_params[i, :]` is a copy, so
            # partialsort! does not disturb the matrix.
            target_val = partialsort!(cat_params[i, :], frac_idx)
            for j = 1 : n_cats
                if cat_params[i,j] < target_val
                    cat_params[i,j] = target_val
                end
            end
        end
    else
        error("Unknown leaf distribution")
    end

    return normalize_to_log!(cat_params)
end
"""
    compute_cat_params(cat_params::CuMatrix, leaf_node_distribution::String; std_max = 1e4, frac = 0.01)

GPU post-processing of leaf log-probabilities. `cat_params` has one row per leaf and one
column per category. It is rewritten in place and returned.

- `"categorical"`: returned unchanged.
- `"gaussian"`: each row becomes a discretized Gaussian over the column indices. It uses the
  row's mean and standard deviation, with the standard deviation capped at `std_max`.
- `"sparse_val"`: entries below `frac` times the row maximum are raised to that value.
- `"sparse_frac"`: entries below the row's value at ascending sorted position
  `min(max(floor(ncats * (1 - frac)), 1), ncats)` are raised to that value.

Non-categorical variants renormalize every row and convert back to log space. Any other
`leaf_node_distribution` throws an error.
"""
function compute_cat_params(cat_params::CuMatrix, leaf_node_distribution::String; std_max::AbstractFloat = 1e4,
                            frac::AbstractFloat = 0.01)
    leaf_node_distribution == "categorical" && return cat_params
    leaf_node_distribution in ("gaussian", "sparse_val", "sparse_frac") || error("Unknown leaf distribution")

    n_leaves = size(cat_params, 1)
    n_cats = size(cat_params, 2)

    # Work in probability space.
    cat_params .= exp.(cat_params)

    # Launch geometry shared by all kernels: rows (leaves) on x, columns on y. The kernels
    # use grid-stride loops, so this only affects performance, not coverage.
    n_cat_groups = n_cats / 8
    kernel_threads = balance_threads(n_leaves, n_cat_groups, 7)
    kernel_blocks = (ceil(Int, n_leaves / kernel_threads[1]),
                     ceil(Int, n_cat_groups / kernel_threads[2]))

    if leaf_node_distribution == "gaussian"
        positions = CuArray(reshape(Vector(range(1, stop = n_cats, length = n_cats)), 1, n_cats))
        mean_vec = sum(positions .* cat_params; dims = 2)
        std_vec = sqrt.(sum(((positions .- mean_vec).^2) .* cat_params; dims = 2))
        std_vec .= min.(std_vec, std_max)
        @cuda threads=kernel_threads blocks=kernel_blocks compute_gaussian_cuda(cat_params, mean_vec, std_vec)
    elseif leaf_node_distribution == "sparse_val"
        row_max = maximum(cat_params; dims = 2)
        @cuda threads=kernel_threads blocks=kernel_blocks compute_sparse_cat_params_cuda(cat_params, row_max, frac)
    else # "sparse_frac"
        sorted_rows = CUDA.sort(cat_params; dims = 2)
        frac_idx = min(max(Int(floor(n_cats * (1.0 - frac))), Int(1)), n_cats)
        threshold = sorted_rows[:, frac_idx]
        @cuda threads=kernel_threads blocks=kernel_blocks compute_sparse_cat_params_cuda2(cat_params, threshold)
    end

    # Renormalize rows and return to log space.
    cat_params ./= sum(cat_params; dims = 2)
    cat_params .= log.(cat_params)
    @assert all(isapprox.(sum(exp.(Array(cat_params)); dims = 2), 1.0, atol=1e-3)) "Parameters do not sum to one locally."

    cat_params
end

"""
    compute_gaussian_cuda(cat_params, mean_vec, std_vec)

CUDA kernel. It overwrites `cat_params[i, j]` with the Gaussian density at `j` for mean
`mean_vec[i, 1]` and standard deviation `std_vec[i, 1]`. Rows are spread over the grid's x
dimension and columns over its y dimension with grid-stride loops.
"""
function compute_gaussian_cuda(cat_params, mean_vec, std_vec)
    first_row = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    first_col = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    col_step = blockDim().y * gridDim().y
    n_rows, n_cols = size(cat_params)
    for i in first_row:row_step:n_rows
        # Per-row statistics are loop-invariant over the columns.
        μ = @inbounds mean_vec[i, 1]
        σ = @inbounds std_vec[i, 1]
        for j in first_col:col_step:n_cols
            @inbounds cat_params[i, j] = exp(-(j - μ)^2 / (2.0 * σ^2)) / σ / sqrt(2.0 * π)
        end
    end
    return nothing
end

"""
    compute_sparse_cat_params_cuda(cat_params, maxval, frac)

CUDA kernel. It raises every `cat_params[i, j]` whose ratio to `maxval[i]` is below `frac`
to `maxval[i] * frac`. Rows are spread over the grid's x dimension and columns over its y
dimension with grid-stride loops.
"""
function compute_sparse_cat_params_cuda(cat_params, maxval, frac)
    first_row = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    first_col = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    col_step = blockDim().y * gridDim().y
    n_rows, n_cols = size(cat_params)
    for i in first_row:row_step:n_rows
        row_max = maxval[i]
        floor_val = row_max * frac
        for j in first_col:col_step:n_cols
            if cat_params[i, j] / row_max < frac
                @inbounds cat_params[i, j] = floor_val
            end
        end
    end
    return nothing
end

"""
    compute_sparse_cat_params_cuda2(cat_params, target_val)

CUDA kernel. It raises every `cat_params[i, j]` that is below `target_val[i]` to
`target_val[i]`. Rows are spread over the grid's x dimension and columns over its y
dimension with grid-stride loops.
"""
function compute_sparse_cat_params_cuda2(cat_params, target_val)
    first_row = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    first_col = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    col_step = blockDim().y * gridDim().y
    n_rows, n_cols = size(cat_params)
    for i in first_row:row_step:n_rows
        threshold = target_val[i]
        for j in first_col:col_step:n_cols
            # Explicit comparison (not `max`) so only entries strictly below are written.
            if cat_params[i, j] < threshold
                @inbounds cat_params[i, j] = threshold
            end
        end
    end
    return nothing
end

"""
    my_cache_parameters!(pc::ProbCircuit, bc::CatBitCircuit, params, leaf_params; exp_update_factor = 0.0)

Copy flat parameter arrays computed for the bit circuit `bc` back into the nodes of `pc`.

- Multi-child sum nodes read their weights from `params[bc.nodes[1, id]:bc.nodes[2, id]]`.
- Literal leaves read row `pn.literal` of `leaf_params`.
- Single-child sum nodes are set to log-probability `0`.

Each new value is mixed with the node's current one as
`exp_update_factor * old + (1 - exp_update_factor) * new`, computed in log space. The result
is checked to sum to about one and then renormalized. GPU inputs are copied to the CPU first.
"""
function my_cache_parameters!(pc::ProbCircuit, bc::CatBitCircuit, params, leaf_params; exp_update_factor = 0.0)
    bc = isgpu(bc) ? to_cpu(bc) : bc
    params = isgpu(params) ? to_cpu(params) : params
    leaf_params = isgpu(leaf_params) ? to_cpu(leaf_params) : leaf_params

    # Log-space mixing weights for the old and the new parameters.
    log_old_weight = log(exp_update_factor)
    log_new_weight = log(1.0 - exp_update_factor)

    foreach(pc) do pn
        if is⋁gate(pn) && num_children(pn) == 1
            pn.log_probs .= zero(Float64)
        elseif is⋁gate(pn)
            node_id = (bc.node2id[pn]::⋁NodeIds).node_id
            el_range = bc.nodes[1, node_id]:bc.nodes[2, node_id]
            @views pn.log_probs .= logaddexp.(log_old_weight .+ pn.log_probs, log_new_weight .+ params[el_range])
            @assert isapprox(sum(exp.(pn.log_probs)), 1.0, atol=1e-3) "Parameters do not sum to one locally: $(sum(exp.(pn.log_probs))); $(pn.log_probs)"
            pn.log_probs .-= logsumexp(pn.log_probs) # absorb residual rounding error
        elseif isliteralgate(pn)
            @views pn.log_probs .= logaddexp.(log_old_weight .+ pn.log_probs, log_new_weight .+ leaf_params[pn.literal, :])
            @assert isapprox(sum(exp.(pn.log_probs)), 1.0, atol=1e-3) "Leaf parameters do not sum to one locally: $(sum(exp.(pn.log_probs))); $(pn.log_probs)"
            pn.log_probs .-= logsumexp(pn.log_probs) # absorb residual rounding error
        end
    end

    nothing
end

"""
    estimate_parameters_em_cpu(pbc::CatParamBitCircuit, data; pseudocount, reuse = false,
                               reuse_v = nothing, reuse_f = nothing,
                               reuse_counts = (nothing, nothing, nothing, nothing, nothing))

CPU EM step for a circuit with categorical leaves. It accumulates expected counts over
(every mini-batch of) `data` through `marginal_flows` callbacks and converts them into
smoothed log parameters.

Node and leaf counts are accumulated in log space; edge and parent counts in linear space.

Returns `(params, leaf_params)`. With `reuse = true` it also returns the value/flow buffers
and the 5-tuple of scratch buffers
`(node_counts, edge_counts, parent_node_counts, leaf_counts, buffer)`, which can be passed
back as `reuse_counts`. The returned `params` *is* the `edge_counts` buffer, so it is
overwritten by the next call that reuses it.
"""
function estimate_parameters_em_cpu(pbc::CatParamBitCircuit, data; pseudocount::Float64,
                                    reuse::Bool = false, reuse_v = nothing, reuse_f = nothing, 
                                    reuse_counts = (nothing, nothing, nothing, nothing, nothing))
    bc::CatBitCircuit = pbc.bitcircuit
    params = pbc.params
    
    # Scratch buffers, taken from `reuse_counts` when their type and size match.
    node_counts::Vector{Float64} = similar!(reuse_counts[1], Vector{Float64}, num_nodes(bc))            # log space
    edge_counts::Vector{Float64} = similar!(reuse_counts[2], Vector{Float64}, num_elements(bc))         # linear space
    parent_node_counts::Vector{Float64} = similar!(reuse_counts[3], Vector{Float64}, num_elements(bc))  # linear space
    leaf_counts::Matrix{Float64} = similar!(reuse_counts[4], Matrix{Float64}, size(pbc.cat_params)...) # log space
    @inbounds @views node_counts[:] .= typemin(Float64)
    @inbounds @views edge_counts[:] .= zero(Float64)
    @inbounds @views parent_node_counts[:] .= zero(Float64)
    @inbounds @views leaf_counts[:,:] .= typemin(Float64)
    
    # Per-example buffer to save some allocations. It lives in slot 5 of `reuse_counts`;
    # slot 4 holds the leaf-count matrix. `get` keeps callers passing a shorter tuple working.
    buffer::Vector{Float64} = similar!(get(reuse_counts, 5, nothing), Vector{Float64}, 
        isbatched(data) ? num_examples(data[1]) : num_examples(data))
    
    if !isbatched(data)
        data = [data]
    end
    
    # For batched dataset, we want to enable 'estimate' for parent_node_counts only in the last minibatch.
    # A `Ref` avoids boxing a captured variable that is reassigned in the loop below.
    estimate_flag = Ref(true)
    
    # Leaf node with observed categories `d`: accumulate its flow per example into its leaf row.
    # NOTE(review): leaf rows are addressed as `dec_id - 2`; this relies on the bit-circuit layout.
    @inline function on_node(flows, values, dec_id, weight::Nothing, d)
        i::UInt32 = dec_id - 2
        @inbounds @views buffer .= flows[:, dec_id]
        node_counts[dec_id] = logaddexp(node_counts[dec_id], logsumexp(buffer))
        for j = 1 : length(d)
            leaf_counts[i, d[j]] = logaddexp(leaf_counts[i, d[j]], buffer[j])
        end
    end
    @inline function on_node(flows, values, dec_id, weights::Nothing, d::Nothing)
        @inbounds @views buffer .= flows[:, dec_id]
        node_counts[dec_id] = logaddexp(node_counts[dec_id], logsumexp(buffer))
    end
    
    @inline function estimate(element, decision, edge_count)
        edge_counts[element] += exp(edge_count)
        if estimate_flag[] # For batched dataset, we only accumulate parent_node_counts after all node_counts have been cumulated
            parent_node_counts[element] += exp(node_counts[decision])
        end
    end
    
    @inline function on_edge(flows, values, prime, sub, element, grandpa, 
                             single_child, weight::Nothing)
        θ = eltype(flows)(params[element])
        if !single_child
            @inbounds @views buffer .= values[:, prime] .+ values[:, sub] .- values[:, grandpa] .+ flows[:, grandpa] .+ θ
            @inbounds @views buffer .= ifelse.(isnan.(buffer[:]), typemin(eltype(flows)), buffer)
            
            edge_count = logsumexp(buffer)
            
            estimate(element, grandpa, edge_count)
        end # no need to estimate single child params, they are always prob 1
    end
    
    v, f = reuse_v, reuse_f
    for idx = 1 : length(data)    
        # Resize buffer if the current minibatch has a different size
        if size(buffer, 1) != num_examples(data[idx])
            resize!(buffer, num_examples(data[idx]))
        end
        
        estimate_flag[] = (idx == length(data))
        v, f = marginal_flows(pbc, data[idx], v, f; on_node, on_edge, weights = nothing)
    end
    
    # `edge_counts` now becomes "params" (smoothed, log space)
    @simd for i = 1 : num_elements(pbc)
        num_els = num_elements(bc.nodes, bc.elements[1, i])
        if num_els == 1
            @inbounds edge_counts[i] = zero(eltype(edge_counts)) # log(1)
        else
            @inbounds edge_counts[i] = log((edge_counts[i] + pseudocount / num_els) / (parent_node_counts[i] + pseudocount))
        end
    end
    
    # Smoothed leaf parameters: (count + pseudocount / n_cats) / (row total + pseudocount).
    n_cats = size(pbc.cat_params, 2)
    leaf_mass = exp.(leaf_counts)
    leaf_params = log.((leaf_mass .+ (pseudocount ./ n_cats)) ./ (sum(leaf_mass; dims = 2) .+ pseudocount))
    
    if reuse
        edge_counts, leaf_params, v, f, (node_counts, edge_counts, parent_node_counts, leaf_counts, buffer)
    else
        edge_counts, leaf_params
    end
end

"""
    estimate_parameters_em_gpu(pbc::CatParamBitCircuit, data; pseudocount, reuse = false,
                               reuse_v = nothing, reuse_f = nothing,
                               reuse_counts = (nothing, nothing, nothing))

GPU counterpart of `estimate_parameters_em_cpu`. It accumulates expected node, edge and
leaf counts over (every mini-batch of) `data` in linear space with atomic adds, then turns
them into smoothed log parameters.

Returns `(params, leaf_params)`:
- `params`: one log weight per element, or `0` when the parent sum node has a single child.
- `leaf_params`: log probabilities shaped like `pbc.cat_params`.

With `reuse = true` it also returns the value/flow buffers and the count buffers
`(node_counts, edge_counts, leaf_counts)` for the next call. Otherwise those buffers are
freed eagerly.
"""
function estimate_parameters_em_gpu(pbc::CatParamBitCircuit, data; pseudocount::Float64,
                                    reuse::Bool = false, reuse_v = nothing, reuse_f = nothing, 
                                    reuse_counts = (nothing, nothing, nothing))
    bc::CatBitCircuit = pbc.bitcircuit
    
    node_counts::CuVector{Float64} = similar!(reuse_counts[1], CuVector{Float64}, num_nodes(bc))
    edge_counts::CuVector{Float64} = similar!(reuse_counts[2], CuVector{Float64}, num_elements(bc))
    leaf_c_size = size(pbc.cat_params)
    leaf_counts::CuMatrix{Float64} = similar!(reuse_counts[3], CuMatrix{Float64}, leaf_c_size[1], leaf_c_size[2])
    @inbounds @views node_counts[:] .= zero(Float64)
    @inbounds @views edge_counts[:] .= zero(Float64)
    @inbounds @views leaf_counts[:,:] .= zero(Float64)
    
    # need to manually cudaconvert closure variables
    node_counts_device = CUDA.cudaconvert(node_counts)
    edge_counts_device = CUDA.cudaconvert(edge_counts)
    leaf_counts_device = CUDA.cudaconvert(leaf_counts)
    
    if !isbatched(data)
        data = [data]
    end
    
    # Leaf node with observed category `d`.
    # NOTE(review): leaf rows are addressed as `dec_id - 2`, as in the CPU version.
    @inline function on_node(flows, values, dec_id, bit_idx, flow, weight::Nothing, d::UInt32)
        c::Float64 = exp(flow) # cast for @atomic to be happy
        i::UInt32 = dec_id - 2
        CUDA.@atomic leaf_counts_device[i, d] += c
        CUDA.@atomic node_counts_device[dec_id] += c
    end
    @inline function on_node(flows, values, dec_id, bit_idx, flow, weight::Nothing, d::Nothing)
        c::Float64 = exp(flow) # cast for @atomic to be happy
        CUDA.@atomic node_counts_device[dec_id] += c
    end
    
    @inline function on_edge(flows, values, prime, sub, element, grandpa, sample_idx, 
                             edge_flow::AbstractFloat, single_child, weight::Nothing)
        if !single_child
            c::Float64 = exp(edge_flow) # cast for @atomic to be happy
            CUDA.@atomic edge_counts_device[element] += c
        end # no need to estimate single child params, they are always prob 1
    end
    
    v, f = reuse_v, reuse_f
    for idx = 1 : length(data) 
        v, f = marginal_flows(pbc, to_gpu(data[idx]), v, f; on_node, on_edge, weights = nothing)
    end
    
    # Smoothed edge parameters: log((edge_count + pseudocount / n_siblings) / (parent_count + pseudocount)).
    @inbounds parents = bc.elements[1,:]
    @inbounds parent_counts = node_counts[parents]
    par_a = bc.nodes[1,parents]
    par_b = bc.nodes[2,parents]
    @inbounds @views parent_elcount = par_b .- par_a
    @inbounds @views parent_elcount .= parent_elcount .+ UInt32(1)
    params = log.((edge_counts .+ (pseudocount ./ parent_elcount)) 
                    ./ (parent_counts .+ pseudocount))
    # Single-child sum nodes always have probability 1, i.e. log-probability 0.
    params = ifelse.(parent_elcount .== 1, zero(eltype(params)), params)
    
    n_cats = size(pbc.cat_params, 2)
    leaf_params = log.((leaf_counts .+ (pseudocount ./ n_cats)) ./ (sum(leaf_counts; dims = 2) .+ pseudocount))
    
    if reuse
        params, leaf_params, v, f, (node_counts, edge_counts, leaf_counts)
    else
        CUDA.unsafe_free!(v) # save the GC some effort
        CUDA.unsafe_free!(f) # save the GC some effort
        CUDA.unsafe_free!(node_counts) # save the GC some effort
        CUDA.unsafe_free!(edge_counts) # save the GC some effort
        CUDA.unsafe_free!(leaf_counts) # save the GC some effort (was previously left to the GC)
        
        params, leaf_params
    end
end


"""
    sgd_cat(pc::ProbCircuit, data; lr::Float64, use_gpu::Bool = isgpu(data))

Stochastic gradient descent on the parameters of `pc`. It takes one gradient step with
learning rate `lr` per mini-batch of `data`; an unbatched dataset counts as a single batch.
The resulting parameters are then written back into `pc`.

Only the GPU path is implemented. It is forced when `data` already lives on the GPU; the
CPU path throws an error.
"""
function sgd_cat(pc::ProbCircuit, data; lr::Float64, use_gpu::Bool = isgpu(data))
    
    pbc::CatParamBitCircuit = CatParamBitCircuit(pc, num_variables(pc))
    
    if isgpu(data)
        use_gpu = true
    end
    
    if !isbatched(data)
        data = [data]
    end
    
    if use_gpu
        pbc = to_gpu(pbc)
    end
        
    # Buffers carried across mini-batches: values/flows, plus
    # `reuse = (log_grads, param_grads, leaf_param_grads)`.
    reuse_v, reuse_f = nothing, nothing
    reuse = (nothing, nothing, nothing)

    for idx = 1 : length(data)
        # The gradient buffers are bound back to `reuse` so the next mini-batch reuses them.
        # They used to go to an unused variable, so every batch reallocated them.
        param_grads, leaf_param_grads, reuse_v, reuse_f, reuse = if use_gpu
            # Forward pass
            reuse_v = marginal_all(pbc, to_gpu(data[idx]), reuse_v)
            
            # Compute gradient at output
            log_grads = similar!(reuse[1], typeof(pbc.params), num_examples(data[idx]))
            @inbounds @views log_grads[:] .= zero(eltype(reuse_v)) .- reuse_v[:, num_nodes(pbc.bitcircuit)]
    
            param_grads = similar!(reuse[2], typeof(pbc.params), size(pbc.params)...)
            leaf_param_grads = similar!(reuse[3], typeof(pbc.cat_params), size(pbc.cat_params)...)
            
            d = isgpu(reuse_v) ? to_gpu(convert(Matrix{UInt32}, to_cpu(data[idx]))) : convert(Matrix{UInt32}, to_cpu(data[idx]))
            reuse_f, param_grads, leaf_param_grads = backprop_flows_down_gpu(
                pbc, d, log_grads, reuse_v, param_grads, leaf_param_grads, reuse_f)
            
            param_grads, leaf_param_grads, reuse_v, reuse_f, (log_grads, param_grads, leaf_param_grads)
        else
            error("Not implemented.")
        end

        # Update the parameters to `pbc`
        if use_gpu # GPU
            apply_gradients_gpu(pbc, param_grads, leaf_param_grads; lr)
        else # CPU
            error("Not implemented.")
        end
    end
        
    my_cache_parameters!(pc, pbc.bitcircuit, pbc.params, pbc.cat_params)
    
    nothing
end