using LogicCircuits
using ProbabilisticCircuits
using CUDA


"""
    compress_marginal(pc::ProbCircuit, data, values = nothing, target_lls = nothing; kwargs...)

Build a `CompressParamBitCircuit` from `pc`, move it with `same_device` so that it
matches `data`, and run the `CompressParamBitCircuit` method of `compress_marginal`
on it. Any keyword arguments (e.g. `Float`, `ll_interval`) are forwarded unchanged.
"""
function compress_marginal(pc::ProbCircuit, data, values = nothing, target_lls = nothing; kwargs...)
    pbc = same_device(CompressParamBitCircuit(pc), data)
    # Forward keywords so callers starting from a ProbCircuit can still pick
    # the float type / checkpoint interval.
    return compress_marginal(pbc, data, values, target_lls; kwargs...)
end

"""
    compress_marginal(pbc::CompressParamBitCircuit, data, values = nothing, target_lls = nothing;
                      Float = Float64, ll_interval::Integer = 5)

Convert `data` to a `UInt32` matrix (its entries are used as 1-based category
indices into the circuit's parameter tables) and dispatch to
`compress_marginal_gpu` or `compress_marginal_cpu`, depending on whether `data`
lives on the GPU. Returns the backend's result, the tuple `(values, target_lls)`.

Throws an `ArgumentError` if `pbc` and `data` are not on the same device.
"""
function compress_marginal(pbc::CompressParamBitCircuit, data, values = nothing, target_lls = nothing; 
                           Float = Float64, ll_interval::Integer = 5)
    # This is user-input validation, so raise a real exception instead of
    # relying on `@assert` (meant for internal invariants).
    isgpu(data) == isgpu(pbc) ||
        throw(ArgumentError("CompressParamBitCircuit and data need to be on the same device"))

    if isgpu(data)
        # The element-type conversion is done on the host, then moved back.
        d = to_gpu(convert(Matrix{UInt32}, to_cpu(data)))
        return compress_marginal_gpu(pbc, d, values, target_lls; Float, ll_interval)
    else
        d = convert(Matrix{UInt32}, data)
        return compress_marginal_cpu(pbc, d, values, target_lls; Float, ll_interval)
    end
end

# ---------------------------------------------------------------------------
# CPU code
# ---------------------------------------------------------------------------

"""
    compress_marginal_cpu(pbc, data, values = nothing, target_lls = nothing;
                          Float = Float64, ll_interval::Integer = 5)

Compute per-example log-marginals on the CPU, visiting the variables in
`pbc.bitcircuit.var_order`. Below, `X_k` denotes the `k`-th variable in that order.

For each step `k`, two upward passes are made over that variable's layers:

1. With the variable's leaves in cumulative mode, the root result is stored in
   `target_lls[:, k]`. This is log Pr(X_1 = x_1, ..., X_{k-1} = x_{k-1}, X_k < x_k),
   assuming `cumulative_cat_params[l, c]` holds log Pr(X ≤ c) for literal `l`.
2. With the leaves in exact mode. Every `ll_interval` steps, and after the last
   variable, the root result (log Pr(X_1 = x_1, ..., X_k = x_k)) is stored in the
   next extra column `num_vars + 1, num_vars + 2, ...`.

Returns `(values, target_lls)`:

- `values` is the `num_examples × num_nodes` node-value matrix (log space).
- `target_lls` is `num_examples × (num_vars + cld(num_vars, ll_interval))`, filled
  with `typemin(Float)` before the passes.

Buffers passed in are reused via `similar!`.

Throws an `ArgumentError` if `ll_interval < 1`.
"""
function compress_marginal_cpu(pbc::CompressParamBitCircuit, data, values = nothing, target_lls = nothing; 
                               Float = Float64, ll_interval::Integer = 5)
    # A non-positive interval would otherwise fail later with DivideError/InexactError.
    ll_interval >= 1 || throw(ArgumentError("ll_interval must be positive, got $ll_interval"))

    # Init values: every node starts at log(1) = 0; the constant nodes get log(1) / log(0).
    values = similar!(values, Matrix{Float}, num_examples(data), num_nodes(pbc))
    @inbounds @views values[:,:] .= zero(Float)
    @views values[:,LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:,LogicCircuits.FALSE_BITS] .= log(zero(Float))

    # Init target_lls: one column per variable plus one checkpoint column per
    # started block of `ll_interval` variables.
    num_vars = num_variables(pbc)
    num_extra_lls = cld(num_vars, ll_interval)
    target_lls = similar!(target_lls, Matrix{Float}, num_examples(data), num_vars + num_extra_lls)
    @inbounds @views target_lls[:,:] .= typemin(Float)

    bc = pbc.bitcircuit
    extra_ll_idx = num_vars + 1
    for (idx, var_idx) in enumerate(bc.var_order)
        var_layers = bc.layers[var_idx]
        root_weights = bc.root_node_weights[var_idx]

        # log Pr(X_1 = x_1, ..., X_{k-1} = x_{k-1}, X_k < x_k) with k = idx.
        # Leaves of variables not visited yet still hold their initial log(1) = 0.
        assign_leaf_nodes_cpu(values, data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params,
                              var_layers[1]; cumulative = true)
        compression_marginal_layers_cpu(pbc, values, var_layers)
        record_lls_cpu(values, target_lls, idx, var_layers[end], root_weights)

        # log Pr(X_1 = x_1, ..., X_k = x_k). These leaf values stay in `values`
        # for the following iterations.
        assign_leaf_nodes_cpu(values, data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params,
                              var_layers[1]; cumulative = false)
        compression_marginal_layers_cpu(pbc, values, var_layers)

        # Checkpoint the joint every `ll_interval` variables and after the last one.
        if idx % ll_interval == 0 || idx == num_vars
            record_lls_cpu(values, target_lls, extra_ll_idx, var_layers[end], root_weights)
            extra_ll_idx += 1
        end
    end

    return values, target_lls
end

"""
    assign_leaf_nodes_cpu(values, data, lit_to_var, cat_params, cumulative_cat_params, layer;
                          cumulative::Bool)

For every node id `dec_id` in `layer`, write its leaf value for every example (row
`j`) into `values[j, dec_id]`, and set the paired node `dec_id + length(lit_to_var)`
to `typemin(eltype(values))` (log 0).

The literal index is `dec_id - 2`. Presumably node ids 1 and 2 are the constant
nodes; TODO confirm. With `c = data[j, lit_to_var[lit]]` as the example's category:

- `cumulative == false`: the value is `cat_params[lit, c]`.
- `cumulative == true`: the value is `cumulative_cat_params[lit, c - 1]`, or log 0
  when `c == 1` (there is no smaller category).

Values are written directly into `values`; no per-leaf temporaries are allocated.
"""
function assign_leaf_nodes_cpu(values, data, lit_to_var, cat_params, cumulative_cat_params, layer; cumulative::Bool)
    n_literals = length(lit_to_var)
    neg_inf = typemin(eltype(values))
    for dec_id in layer
        lit_idx = dec_id - 2
        var_idx = lit_to_var[lit_idx]
        # Column-major friendly: walk the rows of the fixed columns.
        for j in axes(data, 1)
            cat_idx = data[j, var_idx]
            values[j, dec_id] = if cumulative
                # Cumulative mode records one category less; category 1 has none below it.
                cat_idx == 1 ? neg_inf : cumulative_cat_params[lit_idx, cat_idx - 1]
            else
                cat_params[lit_idx, cat_idx]
            end
            values[j, dec_id + n_literals] = neg_inf
        end
    end
    return nothing
end

"""
    compression_marginal_layers_cpu(pbc, values::Matrix, layers)

Upward pass over every layer except the first one (the leaf layer, which the
caller fills beforehand).

- Node `n` owns the element range `nodes[1, n]:nodes[2, n]`.
- Its column `values[:, n]` becomes the log-sum-exp of its elements' log-values.
- Layers run in order; the nodes of one layer are split across threads with
  `Threads.@threads`.
"""
function compression_marginal_layers_cpu(pbc::CompressParamBitCircuit, values::Matrix, layers)
    circuit::CompressBitCircuit = pbc.bitcircuit
    nodes, elements, params = circuit.nodes, circuit.elements, pbc.params
    for layer in Iterators.drop(layers, 1)
        Threads.@threads for node_id in layer
            first_el = @inbounds nodes[1, node_id]
            last_el = @inbounds nodes[2, node_id]
            # Seed the node with one or two elements, then fold in the rest.
            next_el = if first_el == last_el
                assign_marginal(values, node_id, elements[2, first_el], elements[3, first_el],
                                params[first_el])
                first_el + 1
            else
                assign_marginal(values, node_id, elements[2, first_el], elements[3, first_el],
                                elements[2, first_el + 1], elements[3, first_el + 1],
                                params[first_el], params[first_el + 1])
                first_el + 2
            end
            for k in next_el:last_el
                accum_marginal(values, node_id, elements[2, k], elements[3, k], params[k])
            end
        end
    end
end

# Set column `i` of `v` (one row per example) to the log-value of a single
# element: v[:, e1p] + v[:, e1s] + p1. This is a sum of log-values, i.e. a
# product in probability space. `e1p` / `e1s` are presumably the element's prime
# and sub node columns (TODO confirm). The fused `@avx` broadcast writes into
# `v[:, i]` in place.
assign_marginal(v::Matrix{<:AbstractFloat}, i, e1p, e1s, p1) =
    @views @. @avx v[:,i] = v[:,e1p] + v[:,e1s] + p1

# Add one more element into column `i` of `v`, in log space, for every row j:
#   v[j, i] <- logaddexp(v[j, i], v[j, e1p] + v[j, e1s] + p1)
# computed as max(x, y) + log1p(exp(-|x - y|)).
# The `x == y` guard forces Δ = 0 when both operands are equal. This covers
# x == y == -Inf, where x - y would be NaN, so the result stays -Inf.
accum_marginal(v::Matrix{<:AbstractFloat}, i, e1p, e1s, p1) = begin
    @avx for j=1:size(v,1)
        @inbounds x = v[j,i]
        @inbounds y = v[j,e1p] + v[j,e1s] + p1
        Δ = ifelse(x == y, zero(eltype(v)), abs(x - y))
        @inbounds v[j,i] = max(x, y) + log1p(exp(-Δ))
    end
end

# Set column `i` of `v` to the log-sum-exp of two elements, for every row j:
#   x = v[j, e1p] + v[j, e1s] + p1,  y = v[j, e2p] + v[j, e2s] + p2
#   v[j, i] = max(x, y) + log1p(exp(-|x - y|))
# The `x == y` guard avoids the NaN from (-Inf) - (-Inf) when both elements are -Inf.
assign_marginal(v::Matrix{<:AbstractFloat}, i, e1p, e1s, e2p, e2s, p1, p2) = begin
    @avx for j=1:size(v,1)
        @inbounds x = v[j,e1p] + v[j,e1s] + p1
        @inbounds y = v[j,e2p] + v[j,e2s] + p2
        Δ = ifelse(x == y, zero(eltype(v)), abs(x - y))
        @inbounds v[j,i] = max(x, y) + log1p(exp(-Δ))
    end
end

"""
    record_lls_cpu(values, target_lls, var_idx, layer, weights)

For every example (row `j`), store in `target_lls[j, var_idx]` the log-sum-exp of
`values[j, layer[i]] + weights[i]` over all `i`, starting from
`typemin(eltype(target_lls))`. Rows are processed in parallel with
`Threads.@threads`.

`var_idx` is the destination column of `target_lls`. The callers in this file pass
either the variable's position in the order or an extra checkpoint column.

Throws a `DimensionMismatch` if `weights` has fewer entries than `layer`.
"""
function record_lls_cpu(values::Matrix, target_lls::Matrix, var_idx, layer, weights)
    # `weights[i]` is read under @inbounds below, so rule out a short weight vector first.
    length(weights) >= length(layer) || throw(DimensionMismatch(
        "expected at least $(length(layer)) weights, got $(length(weights))"))
    # `@threads` splits a range directly; no need to `collect` it.
    Threads.@threads for j in 1:size(values, 1)
        val = typemin(eltype(target_lls))
        for i in eachindex(layer)
            @inbounds val = logaddexp(val, values[j, layer[i]] + weights[i])
        end
        @inbounds target_lls[j, var_idx] = val
    end
    return nothing
end

# ---------------------------------------------------------------------------
# GPU code
# ---------------------------------------------------------------------------

"""
    compress_marginal_gpu(pbc, data, values = nothing, target_lls = nothing;
                          Float = Float64, ll_interval::Integer = 5)

CUDA version of the compression-marginal computation. Variables are visited in
`pbc.bitcircuit.var_order`; `X_k` below is the `k`-th variable in that order.

For each step `k`, two upward passes are made over that variable's layers:

1. With the variable's leaves in cumulative mode, the root result is stored in
   `target_lls[:, k]`. This is log Pr(X_1 = x_1, ..., X_{k-1} = x_{k-1}, X_k < x_k),
   assuming `cumulative_cat_params[l, c]` holds log Pr(X ≤ c) for literal `l`.
2. With the leaves in exact mode. Every `ll_interval` steps, and after the last
   variable, the root result (log Pr(X_1 = x_1, ..., X_k = x_k)) is stored in the
   next extra column `num_vars + 1, num_vars + 2, ...`.

Returns `(values, target_lls)` as `CuMatrix{Float}`:

- `values` is `num_examples × num_nodes` (log space).
- `target_lls` is `num_examples × (num_vars + cld(num_vars, ll_interval))`,
  pre-filled with `typemin(Float)`.

Throws an `ArgumentError` if `ll_interval < 1`.
"""
function compress_marginal_gpu(pbc::CompressParamBitCircuit, data, values = nothing, target_lls = nothing; 
                               Float = Float64, ll_interval::Integer = 5)
    # A non-positive interval would otherwise fail later with DivideError/InexactError.
    ll_interval >= 1 || throw(ArgumentError("ll_interval must be positive, got $ll_interval"))

    # Init values: every node starts at log(1) = 0; the constant nodes get log(1) / log(0).
    values = similar!(values, CuMatrix{Float}, num_examples(data), num_nodes(pbc))
    @inbounds @views values[:,:] .= zero(Float)
    @views values[:,LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:,LogicCircuits.FALSE_BITS] .= log(zero(Float))

    # Init target_lls: one column per variable plus one checkpoint column per
    # started block of `ll_interval` variables.
    num_vars = num_variables(pbc)
    num_extra_lls = cld(num_vars, ll_interval)
    target_lls = similar!(target_lls, CuMatrix{Float}, num_examples(data), num_vars + num_extra_lls)
    @inbounds @views target_lls[:,:] .= typemin(Float)

    bc = pbc.bitcircuit
    extra_ll_idx = num_vars + 1
    for (idx, var_idx) in enumerate(bc.var_order)
        var_layers = bc.layers[var_idx]
        root_weights = bc.root_node_weights[var_idx]

        # log Pr(X_1 = x_1, ..., X_{k-1} = x_{k-1}, X_k < x_k) with k = idx.
        # Leaves of variables not visited yet still hold their initial log(1) = 0.
        assign_leaf_nodes_gpu(values, data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params,
                              var_layers[1]; cumulative = true)
        compression_marginal_layers_gpu(pbc, values, var_layers; dec_per_thread = 4, log2_threads_per_block = 5)
        record_lls_gpu(values, target_lls, idx, var_layers[end], root_weights)

        # log Pr(X_1 = x_1, ..., X_k = x_k). These leaf values stay in `values`
        # for the following iterations.
        assign_leaf_nodes_gpu(values, data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params,
                              var_layers[1]; cumulative = false)
        compression_marginal_layers_gpu(pbc, values, var_layers; dec_per_thread = 4, log2_threads_per_block = 5)

        # Checkpoint the joint every `ll_interval` variables and after the last one.
        if idx % ll_interval == 0 || idx == num_vars
            record_lls_gpu(values, target_lls, extra_ll_idx, var_layers[end], root_weights)
            extra_ll_idx += 1
        end
    end

    return values, target_lls
end

"""
    assign_leaf_nodes_gpu(values, data, lit_to_var, cat_params, cumulative_cat_params, layer;
                          cumulative::Bool, dec_per_thread = 7, log2_threads_per_block = 8)

Launch the CUDA kernel that writes the leaf values for the node ids in `layer`:

- `cumulative == true`: `assign_leaf_nodes_min1_cuda` with `cumulative_cat_params`.
- `cumulative == false`: `assign_leaf_nodes_cuda` with `cat_params`.

The grid is 2-D: x covers examples, y covers the entries of `layer`.

Returns `nothing`. Nothing is launched when there are no examples or `layer` is empty.
"""
function assign_leaf_nodes_gpu(values, data, lit_to_var, cat_params, cumulative_cat_params, layer; cumulative::Bool,
                               dec_per_thread = 7, log2_threads_per_block = 8)
    n_examples = size(values, 1)
    # The kernels only iterate over `layer`, so size the y-dimension by the layer,
    # not by the total number of literals.
    n_leaves = length(layer)
    # Avoid a degenerate (zero-size) launch configuration.
    (n_examples == 0 || n_leaves == 0) && return nothing

    num_literal_sets = n_leaves / dec_per_thread
    num_threads = balance_threads(n_examples, num_literal_sets, log2_threads_per_block)
    num_blocks = (ceil(Int, n_examples / num_threads[1]),
                  ceil(Int, num_literal_sets / num_threads[2]))

    if cumulative
        @cuda threads=num_threads blocks=num_blocks assign_leaf_nodes_min1_cuda(values, data, lit_to_var, 
            cumulative_cat_params, layer)
    else
        @cuda threads=num_threads blocks=num_blocks assign_leaf_nodes_cuda(values, data, lit_to_var, cat_params, layer)
    end
    return nothing
end

# CUDA kernel, cumulative mode. For example row `j` and node id `node` from `layer`:
# values[j, node] = params[lit, c - 1], or typemin when the category c == 1;
# the paired node `node + length(lit_to_var)` is set to typemin.
# `lit = node - 2` is the literal index and `c = data[j, lit_to_var[lit]]`.
function assign_leaf_nodes_min1_cuda(values, data, lit_to_var, params, layer)
    # Each thread starts at its global (x, y) index and steps by the grid size:
    # x walks example rows, y walks entries of `layer`.
    row0 = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    col0 = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    col_step = blockDim().y * gridDim().y
    neg_offset = length(lit_to_var)
    neg_inf = typemin(eltype(values))
    for j in row0:row_step:size(values, 1), i in col0:col_step:length(layer)
        @inbounds begin
            node = layer[i]
            lit = node - 2
            category = data[j, lit_to_var[lit]]
            # Category 1 has no smaller category, hence log(0).
            values[j, node] = category == 1 ? neg_inf : params[lit, category - 1]
            values[j, node + neg_offset] = neg_inf
        end
    end
    return nothing
end
# CUDA kernel, exact mode. For example row `j` and node id `node` from `layer`:
# values[j, node] = params[lit, c], and the paired node `node + length(lit_to_var)`
# is set to typemin. `lit = node - 2` and `c = data[j, lit_to_var[lit]]`.
function assign_leaf_nodes_cuda(values, data, lit_to_var, params, layer)
    # Global (x, y) start index per thread, advancing by the full grid size.
    row0 = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    col0 = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    col_step = blockDim().y * gridDim().y
    neg_offset = length(lit_to_var)
    neg_inf = typemin(eltype(values))
    for j in row0:row_step:size(values, 1), i in col0:col_step:length(layer)
        @inbounds begin
            node = layer[i]
            lit = node - 2
            category = data[j, lit_to_var[lit]]
            values[j, node] = params[lit, category]
            values[j, node + neg_offset] = neg_inf
        end
    end
    return nothing
end

"""
    compression_marginal_layers_gpu(pbc, values::CuMatrix, layers;
                                    dec_per_thread = 8, log2_threads_per_block = 8)

GPU upward pass over every layer except the first (the leaf layer, filled
beforehand). For each remaining layer, `compression_marginal_layers_cuda` is
launched to set each node's column in `values` to the log-sum-exp of its
elements' log-values. Layers are launched one after another inside
`CUDA.@sync`, which synchronizes once the loop is done. Empty layers (and an
input with zero examples) are skipped. Returns `nothing`.
"""
function compression_marginal_layers_gpu(pbc::CompressParamBitCircuit, values::CuMatrix, layers; 
                             dec_per_thread = 8, log2_threads_per_block = 8)
    bc::CompressBitCircuit = pbc.bitcircuit
    # Loop-invariant; also avoids shadowing the `num_examples` function.
    n_examples = size(values, 1)
    n_examples == 0 && return nothing
    CUDA.@sync for layer in layers[2:end]
        # Nothing to compute, and it would produce a zero-size launch.
        isempty(layer) && continue
        num_decision_sets = length(layer) / dec_per_thread
        num_threads = balance_threads(n_examples, num_decision_sets, log2_threads_per_block)
        num_blocks = (ceil(Int, n_examples / num_threads[1]), 
                      ceil(Int, num_decision_sets / num_threads[2]))
        @cuda threads=num_threads blocks=num_blocks compression_marginal_layers_cuda(layer, bc.nodes, bc.elements, 
            pbc.params, values)
    end
    return nothing
end

# CUDA kernel for one layer. For example row `j` and node `node = layer[i]`, it
# folds the elements `nodes[1, node]:nodes[2, node]` into their log-sum-exp.
# Element `k` contributes values[j, elements[2, k]] + values[j, elements[3, k]] + params[k].
# The result is written to values[j, node].
function compression_marginal_layers_cuda(layer, nodes, elements, params, values)
    row0 = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    col0 = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    col_step = blockDim().y * gridDim().y
    for j in row0:row_step:size(values, 1), i in col0:col_step:length(layer)
        node = @inbounds layer[i]
        first_el = @inbounds nodes[1, node]
        last_el = @inbounds nodes[2, node]
        # Seed with the first element's log-value.
        acc = @inbounds values[j, elements[2, first_el]] + values[j, elements[3, first_el]] + params[first_el]
        for k in (first_el + 1):last_el
            term = @inbounds values[j, elements[2, k]] + values[j, elements[3, k]] + params[k]
            # Stable log-add; equal operands (including both -Inf) give Δ = 0, not NaN.
            Δ = ifelse(acc == term, zero(eltype(values)), CUDA.abs(acc - term))
            acc = max(acc, term) + CUDA.log1p(CUDA.exp(-Δ))
        end
        @inbounds values[j, node] = acc
    end
    return nothing
end

# Launch `record_lls_cuda` with a 1-D grid over examples. This fills column `var_idx`
# of `target_lls` with the weighted log-sum-exp over the nodes of `layer`.
function record_lls_gpu(values::CuMatrix, target_lls::CuMatrix, var_idx, layer, weights; 
                        dec_per_thread = 7, log2_threads_per_block = 8)
    n_examples = size(values, 1)
    # Balance as for a 2-D launch, but only the example share is used as the block size.
    thread_split = balance_threads(n_examples, length(layer) / dec_per_thread, log2_threads_per_block)
    threads_x = thread_split[1]
    blocks_x = ceil(Int, n_examples / threads_x)
    @cuda threads=threads_x blocks=blocks_x record_lls_cuda(values, target_lls, var_idx, layer, weights)
end

# CUDA kernel. For each example row `j` reached by this thread, it computes the
# log-sum-exp of values[j, layer[i]] + weights[i] over all i, starting from
# typemin(eltype(target_lls)), and stores it in target_lls[j, var_idx].
function record_lls_cuda(values, target_lls, var_idx, layer, weights)
    first_row = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    row_step = blockDim().x * gridDim().x
    for j in first_row:row_step:size(values, 1)
        acc = typemin(eltype(target_lls))
        for i in 1:length(layer)
            term = @inbounds values[j, layer[i]] + weights[i]
            # Stable log-add; equal operands (including both -Inf) give Δ = 0, not NaN.
            Δ = ifelse(acc == term, zero(eltype(values)), CUDA.abs(acc - term))
            acc = max(acc, term) + CUDA.log1p(CUDA.exp(-Δ))
        end
        @inbounds target_lls[j, var_idx] = acc
    end
    return nothing
end