using LogicCircuits
using ProbabilisticCircuits
using CUDA
using StatsFuns: logaddexp
using LoopVectorization


"""
    decompress_marginal(pbc, code, values = nothing, target_data = nothing, reuse = (nothing, ...);
                        num_vars, num_cats, Float = Float64, ll_interval = 5)

Device-checking entry point for decompression.

Asserts that `pbc` and `code` are on the same device and forwards every argument to
`decompress_marginal_cpu`, returning its result.

The GPU implementation is not wired in here. Passing GPU inputs raises an error with an
actionable message rather than silently returning `nothing`.
"""
function decompress_marginal(pbc::CompressParamBitCircuit, code, values = nothing, target_data = nothing, 
                             reuse = (nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing); 
                             num_vars, num_cats, Float = Float64, ll_interval::Integer = 5)
    @assert isgpu(code) == isgpu(pbc) "CompressParamBitCircuit and code need to be on the same device"
    
    if isgpu(code)
        # The GPU call is intentionally not dispatched; fail loudly so callers do not try to
        # destructure a `nothing` result.
        # decompress_marginal_gpu(pbc, code, values, target_data, reuse; num_vars, num_cats, Float, ll_interval)
        error("decompress_marginal: the GPU path is currently disabled; move `pbc` and `code` to the CPU first")
    else
        decompress_marginal_cpu(pbc, code, values, target_data, reuse; num_vars, num_cats, Float, ll_interval)
    end
end

"""
CPU code
"""

"""
    decompress_marginal_cpu(pbc, code, values = nothing, target_data = nothing, reuse = (nothing, ...);
                            num_vars, num_cats, Float = Float64, ll_interval = 5)

CPU decoder. Visits the variables in `pbc.bitcircuit.var_order` and, for every example (row of
`code`), binary-searches the category index of the current variable by comparing circuit
log-values against the reference log-values decoded from `code`.

# Arguments
- `pbc`: circuit with parameters (`bitcircuit`, `lit_to_var`, `cat_params`, `cumulative_cat_params`).
- `code`: matrix with one row per example; `code_to_ll_cpu` converts it entry-wise to log-values.
- `values`, `target_data`, `reuse`: optional buffers passed to `similar!` so they can be recycled
  between calls.

# Keywords
- `num_vars`: number of columns of `target_data`; also identifies the last variable.
- `num_cats`: the category search runs over `1:num_cats`.
- `Float`: floating-point type of every log-value buffer.
- `ll_interval`: after every `ll_interval` variables (and after the last one) the next column of
  the reference log-values is selected and `lls_div` / `lls_to_sub` are refreshed.

# Returns
`(values, target_data, reuse)` where `reuse` is the tuple of the eight scratch buffers in the
same order as the `reuse` argument.
"""
function decompress_marginal_cpu(pbc::CompressParamBitCircuit, code, values = nothing, target_data = nothing, 
                                 reuse = (nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing);
                                 num_vars, num_cats, Float = Float64, ll_interval::Integer = 5)
    n_examples = size(code, 1)

    # Init values
    values = similar!(values, Matrix{Float}, n_examples, num_nodes(pbc))
    @inbounds @views values[:,:] .= zero(Float)
    @views values[:,LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:,LogicCircuits.FALSE_BITS] .= log(zero(Float))

    # Init target_data
    target_data = similar!(target_data, Matrix{UInt32}, n_examples, num_vars)

    # Init reuse buffer(s)
    lls_buffer::Vector{Float} = similar!(reuse[1], Vector{Float}, n_examples)
    cat_idxs_low::Vector{Int} = similar!(reuse[2], Vector{Int}, n_examples)
    cat_idxs_mid::Vector{Int} = similar!(reuse[3], Vector{Int}, n_examples)
    cat_idxs_high::Vector{Int} = similar!(reuse[4], Vector{Int}, n_examples)
    lls_div::Vector{Float} = similar!(reuse[5], Vector{Float}, n_examples)
    @inbounds @views lls_div .= zero(Float)
    lls_cdf::Vector{Float} = similar!(reuse[6], Vector{Float}, n_examples)
    ref_lls::Matrix{Float} = similar!(reuse[7], Matrix{Float}, n_examples, size(code, 2))
    lls_to_sub::Vector{Float} = similar!(reuse[8], Vector{Float}, n_examples)
    @inbounds @views lls_to_sub .= typemin(Float)

    # Decode every code word into a reference log-value.
    code_to_ll_cpu(code, ref_lls; Float)
    
    bc = pbc.bitcircuit
    ref_ll_idx::Int = 1
    for (idx, var_idx) in enumerate(bc.var_order)
        # Search interval per example: low = 1, high = num_cats + 1.
        @inbounds @views cat_idxs_low .= Int(1)
        @inbounds @views cat_idxs_high .= Int(num_cats + 1)
        @inbounds @views lls_cdf .= typemin(Float)

        # Keep bisecting until every example's interval has width one.
        while any((cat_idxs_low .+ 1) .< cat_idxs_high)
            @inbounds @views cat_idxs_mid .= cat_idxs_low
            @inbounds @views cat_idxs_mid .+= cat_idxs_high
            @inbounds @views cat_idxs_mid .÷= Int(2)
            
            # Evaluate the circuit with the leaves set from the cumulative parameters at mid - 1.
            assign_leaf_nodes_cat_min1_cpu(values, pbc.lit_to_var, pbc.cumulative_cat_params, 
                bc.layers[var_idx][1]; cat_idxs = cat_idxs_mid)
            compression_marginal_layers_cpu(pbc, values, bc.layers[var_idx])
            record_lls_temp_cpu(values, lls_buffer, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
            
            # Bound once per pass so the threaded loop body reads a variable that is never reassigned.
            ref_col = ref_ll_idx
            Threads.@threads for j = 1 : n_examples
                # `ref_ll` and `code_ll` must be local to the iteration: declaring them outside
                # the threaded loop would make all threads share (and race on) the same variables.
                @inbounds ref_ll = logsubexp(ref_lls[j,ref_col], lls_to_sub[j])
                @inbounds code_ll = lls_buffer[j] - lls_div[j]
                if code_ll < ref_ll
                    @inbounds cat_idxs_low[j] = cat_idxs_mid[j]
                    @inbounds lls_cdf[j] = code_ll
                else
                    @inbounds cat_idxs_high[j] = cat_idxs_mid[j]
                end
            end
        end
        
        # Advance to the next reference column at interval boundaries and after the last variable.
        if idx % ll_interval == 0 || idx == num_vars
            ref_ll_idx += 1
        end
        
        # The lower end of the interval is the decoded category.
        @inbounds @views target_data[:, var_idx] .= cat_idxs_low
        @inbounds @views lls_to_sub .= logaddexp.(lls_to_sub, lls_cdf)
        
        # P(X_1=x_1,...,X_k=x_k)
        assign_leaf_nodes_cpu(values, target_data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params, 
                              bc.layers[var_idx][1]; cumulative = false)
        compression_marginal_layers_cpu(pbc, values, bc.layers[var_idx])

        if idx % ll_interval == 0 || idx == num_vars
            # New divisor for the next interval; the subtracted mass starts over.
            record_lls_temp_cpu(values, lls_div, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
            @inbounds @views lls_to_sub .= typemin(Float)
        end
    end
    
    return values, target_data, (lls_buffer, cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_div, lls_cdf, ref_lls, lls_to_sub)
end

"""
    assign_leaf_nodes_cat_min1_cpu(values, lit_to_var, cumulative_cat_params, layer; cat_idxs)

For every node id `d` in `layer`, overwrite two columns of `values`:

- column `d` receives, for row `r`, `cumulative_cat_params[d - 2, cat_idxs[r] - 1]`, or the
  lowest value of `eltype(values)` when `cat_idxs[r] == 1`;
- column `d + length(lit_to_var)` is filled with the lowest value of `eltype(values)`.

Returns `nothing`.
"""
function assign_leaf_nodes_cat_min1_cpu(values, lit_to_var, cumulative_cat_params, layer; cat_idxs)
    neg_offset = length(lit_to_var)
    lowest = typemin(eltype(values))
    n_rows = size(values, 1)
    for node in layer
        param_row = node - 2
        # Build the whole column first, then store it in one go.
        pos_column = [
            cat_idxs[r] == 1 ? lowest : (@inbounds cumulative_cat_params[param_row, cat_idxs[r] - 1])
            for r in 1:n_rows
        ]
        @inbounds @views values[:, node] .= pos_column
        @inbounds @views values[:, node + neg_offset] .= lowest
    end
    return nothing
end
"""
    assign_leaf_nodes_cat_cpu(values, lit_to_var, cumulative_cat_params, layer; cat_idxs)

For every node id `d` in `layer`, overwrite two columns of `values`:

- column `d` receives, for row `r`, `cumulative_cat_params[d - 2, cat_idxs[r]]`;
- column `d + length(lit_to_var)` is filled with the lowest value of `eltype(values)`.

Returns `nothing`.
"""
function assign_leaf_nodes_cat_cpu(values, lit_to_var, cumulative_cat_params, layer; cat_idxs)
    neg_offset = length(lit_to_var)
    lowest = typemin(eltype(values))
    n_rows = size(values, 1)
    for node in layer
        param_row = node - 2
        # Build the whole column first, then store it in one go.
        pos_column = [(@inbounds cumulative_cat_params[param_row, cat_idxs[r]]) for r in 1:n_rows]
        @inbounds @views values[:, node] .= pos_column
        @inbounds @views values[:, node + neg_offset] .= lowest
    end
    return nothing
end

"""
    record_lls_temp_cpu(values, lls_buffer, layer, weights)

For every row `j` of `values`, store in `lls_buffer[j]` the log-sum-exp over `i` of
`values[j, layer[i]] + weights[i]`. An empty `layer` yields the lowest value of
`eltype(lls_buffer)`.

Rows are processed in parallel with `Threads.@threads`; each iteration writes only its own
`lls_buffer[j]`.
"""
function record_lls_temp_cpu(values, lls_buffer, layer, weights)
    # Compute the LL corresponds to P(X_1=x_1,...,X_k=Any)
    # Iterate the range directly: materialising it with `collect` only allocates.
    Threads.@threads for j in 1:size(values, 1)
        val = typemin(eltype(lls_buffer))
        for i = 1 : length(layer)
            @inbounds dec_id = layer[i]
            @inbounds w = weights[i]
            @inbounds val = logaddexp(val, values[j,dec_id] + w)
        end
        @inbounds lls_buffer[j] = val
    end
end

"""
    code_to_ll_cpu(code::Matrix{UInt64}, lls::Matrix; Float = Float64)

Decode every 64-bit code word in `code` into a log-value and store it at the same position of
`lls`.

Decoding of a word `c`:
- `typemax(UInt64)` decodes to `0.0`;
- otherwise the highest set bit is treated as a marker and discarded; the bits below it are
  read from most to least significant, the k-th of them contributing `k * log(0.5)` (combined
  with `logaddexp`) when set.

A word with no bits below the marker decodes to `typemin(Float)`. A zero word (no marker at
all) is decoded the same way instead of spinning forever in the marker search.
"""
function code_to_ll_cpu(code::Matrix{UInt64}, lls::Matrix; Float = Float64)
    log_0_5 = log(Float(0.5))
    top_bit = UInt64(0x8000000000000000)
    
    @inline decode(c::UInt64) = begin
        if c == typemax(UInt64)
            return Float(0.0)
        end
        # Move the marker bit to the top position. For c == 0 this shifts by 64 and leaves 0,
        # whereas a bit-by-bit search for the marker would never terminate.
        c <<= leading_zeros(c)
        # Drop the marker itself.
        c <<= 1
        ll::Float = typemin(Float)
        v::Float = log_0_5
        while c != UInt64(0)
            if (c & top_bit) != UInt64(0)
                ll = logaddexp(ll, v)
            end
            c <<= 1
            v += log_0_5
        end
        ll
    end
    
    Threads.@threads for i = 1 : size(code, 1)
        for j = 1 : size(code, 2)
            @inbounds lls[i, j] = decode(code[i, j])
        end
    end
end

"""
GPU code
"""

"""
    decompress_marginal_gpu(pbc, code, values = nothing, target_data = nothing, reuse = (nothing, ...);
                            num_vars, num_cats, Float = Float64, ll_interval = 5)

GPU counterpart of `decompress_marginal_cpu`: the same per-variable binary search over category
indices, with the buffers held in `CuArray`s and the per-example work done by CUDA kernels.

# Arguments
- `pbc`: circuit with parameters (`bitcircuit`, `lit_to_var`, `cat_params`, `cumulative_cat_params`).
- `code`: matrix with one row per example; `code_to_ll_gpu` converts it to reference log-values.
- `values`, `target_data`, `reuse`: optional buffers passed to `similar!` so they can be recycled
  between calls.

# Keywords
- `num_vars`: number of columns of `target_data`; also identifies the last variable.
- `num_cats`: the category search runs over `1:num_cats`.
- `Float`: floating-point type of every log-value buffer.
- `ll_interval`: after every `ll_interval` variables (and after the last one) the next column of
  the reference log-values is selected and `lls_div` / `lls_to_sub` are refreshed.

# Returns
`(values, target_data, reuse)` with `reuse` being the eight scratch buffers in the order of the
`reuse` argument, matching the return shape of `decompress_marginal_cpu`.
"""
function decompress_marginal_gpu(pbc::CompressParamBitCircuit, code, values = nothing, target_data = nothing,
                                 reuse = (nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing);
                                 num_vars, num_cats, Float = Float64, ll_interval::Integer = 5)
    # Init values
    values = similar!(values, CuMatrix{Float}, size(code, 1), num_nodes(pbc))
    @inbounds @views values[:,:] .= zero(Float)
    @views values[:,LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:,LogicCircuits.FALSE_BITS] .= log(zero(Float))
    
    # Init target_data
    target_data = similar!(target_data, CuMatrix{UInt32}, size(code, 1), num_vars)
    
    # Init reuse buffer(s)
    lls_buffer::CuVector{Float} = similar!(reuse[1], CuVector{Float}, size(code, 1))
    cat_idxs_low::CuVector{Int} = similar!(reuse[2], CuVector{Int}, size(code, 1))
    cat_idxs_mid::CuVector{Int} = similar!(reuse[3], CuVector{Int}, size(code, 1))
    cat_idxs_high::CuVector{Int} = similar!(reuse[4], CuVector{Int}, size(code, 1))
    lls_div::CuVector{Float} = similar!(reuse[5], CuVector{Float}, size(code, 1))
    @inbounds @views lls_div .= zero(Float)
    lls_cdf::CuVector{Float} = similar!(reuse[6], CuVector{Float}, size(code, 1))
    ref_lls::CuMatrix{Float} = similar!(reuse[7], CuMatrix{Float}, size(code, 1), size(code, 2))
    lls_to_sub::CuVector{Float} = similar!(reuse[8], CuVector{Float}, size(code, 1))
    @inbounds @views lls_to_sub .= typemin(Float)
    
    # Decode every code word into a reference log-value.
    code_to_ll_gpu(code, ref_lls; Float)
    
    bc = pbc.bitcircuit
    ref_ll_idx::Int = 1
    for (idx, var_idx) in enumerate(bc.var_order)
        # Search interval per example: low = 1, high = num_cats + 1.
        @inbounds @views cat_idxs_low .= Int(1)
        @inbounds @views cat_idxs_high .= Int(num_cats + 1)
        @inbounds @views lls_cdf .= typemin(Float)
        
        # Keep bisecting until every example's interval has width one.
        while any((cat_idxs_low .+ 1) .< cat_idxs_high)
            @inbounds @views cat_idxs_mid .= cat_idxs_low
            @inbounds @views cat_idxs_mid .+= cat_idxs_high
            @inbounds @views cat_idxs_mid .÷= Int(2)
            
            assign_leaf_nodes_cat_min1_gpu(values, pbc.lit_to_var, pbc.cumulative_cat_params, 
                bc.layers[var_idx][1]; cat_idxs = cat_idxs_mid)
            compression_marginal_layers_gpu(pbc, values, bc.layers[var_idx])
            update_cat_idxs_and_lls_cdf_gpu(values, bc.layers[var_idx][end], bc.root_node_weights[var_idx],
                ref_lls, lls_to_sub, lls_div, cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_cdf, ref_ll_idx)
        end
        
        # Advance to the next reference column at interval boundaries and after the last variable.
        if idx % ll_interval == 0 || idx == num_vars
            ref_ll_idx += 1
        end
        
        # The lower end of the interval is the decoded category.
        @inbounds @views target_data[:, var_idx] .= cat_idxs_low
        @inbounds @views lls_to_sub .= logaddexp.(lls_to_sub, lls_cdf) # TODO: convert to gpu code
        
        # P(X_1=x_1,...,X_k=x_k)
        assign_leaf_nodes_gpu(values, target_data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params, 
                              bc.layers[var_idx][1]; cumulative = false)
        compression_marginal_layers_gpu(pbc, values, bc.layers[var_idx])

        if idx % ll_interval == 0 || idx == num_vars
            # New divisor for the next interval; the subtracted mass starts over.
            record_lls_temp_gpu(values, lls_div, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
            @inbounds @views lls_to_sub .= typemin(Float)
        end
    end

    # Hand the results and scratch buffers back, exactly like the CPU variant does.
    return values, target_data, (lls_buffer, cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_div, lls_cdf, ref_lls, lls_to_sub)
end

"""
    assign_leaf_nodes_cat_min1_gpu(values, lit_to_var, cumulative_cat_params, layer; cat_idxs,
                                   dec_per_thread = 7, log2_threads_per_block = 8)

Host-side launcher for `assign_leaf_nodes_cat_min1_cuda`.

The 2-D thread shape comes from `balance_threads`, fed with the number of rows of `values` and
`length(lit_to_var) / dec_per_thread`; the block counts are the ceilings of those two
quantities divided by the respective thread counts.
"""
function assign_leaf_nodes_cat_min1_gpu(values, lit_to_var, cumulative_cat_params, layer; cat_idxs,
                                        dec_per_thread = 7, log2_threads_per_block = 8)
    example_count = size(values, 1)
    work_groups = length(lit_to_var) / dec_per_thread
    block_shape = balance_threads(example_count, work_groups, log2_threads_per_block)
    grid_x = ceil(Int, example_count / block_shape[1])
    grid_y = ceil(Int, work_groups / block_shape[2])
    grid_shape = (grid_x, grid_y)

    @cuda threads=block_shape blocks=grid_shape assign_leaf_nodes_cat_min1_cuda(values, lit_to_var,
            cumulative_cat_params, layer, cat_idxs)
end

"""
    assign_leaf_nodes_cat_min1_cuda(values, lit_to_var, cumulative_cat_params, layer, cat_idxs)

CUDA kernel. Grid-stride loops cover rows of `values` along x and entries of `layer` along y.
For row `r` and node id `d = layer[k]`:

- `values[r, d]` becomes `cumulative_cat_params[d - 2, cat_idxs[r] - 1]`, or the lowest value of
  `eltype(values)` when `cat_idxs[r] == 1`;
- `values[r, d + length(lit_to_var)]` becomes the lowest value of `eltype(values)`.
"""
function assign_leaf_nodes_cat_min1_cuda(values, lit_to_var, cumulative_cat_params, layer, cat_idxs)
    row_first = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    node_first = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    row_step = blockDim().x * gridDim().x
    node_step = blockDim().y * gridDim().y
    neg_offset = length(lit_to_var)
    for row = row_first:row_step:size(values,1)
        for k = node_first:node_step:length(layer)
            @inbounds node = layer[k]
            @inbounds cat_no = cat_idxs[row]

            # Category 1 has nothing below it; otherwise look up the cumulative entry at cat_no - 1.
            pos_val = typemin(eltype(values))
            if cat_no != 1
                @inbounds pos_val = cumulative_cat_params[node - 2, cat_no - 1]
            end

            @inbounds values[row, node] = pos_val
            @inbounds values[row, node + neg_offset] = typemin(eltype(values))
        end
    end
    return nothing
end

"""
    update_cat_idxs_and_lls_cdf_gpu(values, layer, weights, ref_lls, lls_to_sub, lls_div,
                                    cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_cdf, ref_ll_idx;
                                    dec_per_thread = 7, log2_threads_per_block = 8)

Host-side launcher for `update_cat_idxs_and_lls_cdf_cuda`.

Uses a 1-D launch: the thread count is the first component returned by `balance_threads`, and
the block count is the ceiling of the number of rows of `values` divided by it.
"""
function update_cat_idxs_and_lls_cdf_gpu(values, layer, weights, ref_lls, lls_to_sub, lls_div, 
                                         cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_cdf, ref_ll_idx; 
                                         dec_per_thread = 7, log2_threads_per_block = 8)
    example_count = size(values, 1)
    work_groups = length(layer) / dec_per_thread
    thread_shape = balance_threads(example_count, work_groups, log2_threads_per_block)
    threads_x = thread_shape[1]
    blocks_x = ceil(Int, example_count / threads_x)

    @cuda threads=threads_x blocks=blocks_x update_cat_idxs_and_lls_cdf_cuda(values, layer, weights, ref_lls,
        lls_to_sub, lls_div, cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_cdf, ref_ll_idx)
end

"""
    update_cat_idxs_and_lls_cdf_cuda(values, layer, weights, ref_lls, lls_to_sub, lls_div,
                                     cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_cdf, ref_ll_idx)

CUDA kernel performing one bisection step per row `j` of `values` (grid-stride loop along x):

1. `val` = log-sum-exp over `i` of `values[j, layer[i]] + weights[i]`;
2. `ref_ll` = `logsubexp(ref_lls[j, ref_ll_idx], lls_to_sub[j])`;
3. `code_ll = val - lls_div[j]`; if `code_ll < ref_ll` then `cat_idxs_low[j]` is raised to
   `cat_idxs_mid[j]` and `lls_cdf[j]` is set to `code_ll`, otherwise `cat_idxs_high[j]` is
   lowered to `cat_idxs_mid[j]`.
"""
function update_cat_idxs_and_lls_cdf_cuda(values, layer, weights, ref_lls, lls_to_sub, lls_div, cat_idxs_low, 
                                          cat_idxs_mid, cat_idxs_high, lls_cdf, ref_ll_idx)
    index_x = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    stride_x = blockDim().x * gridDim().x
    for j = index_x:stride_x:size(values,1)
        # Accumulator type follows `values`, the only log-value array read in the inner loop.
        val = typemin(eltype(values))
        for i = 1 : length(layer)
            @inbounds dec_id = layer[i]
            @inbounds w = weights[i]
            
            # val = logaddexp(val, values[j, dec_id] + w)
            @inbounds y = values[j,dec_id] + w
            Δ = ifelse(val == y, zero(eltype(values)), CUDA.abs(val - y))
            val = max(val, y) + CUDA.log1p(CUDA.exp(-Δ))
        end
        # `val` is consumed below; this kernel receives no output buffer to store it in.
        
        # ref_ll = logsubexp(ref, sub)
        @inbounds ref = ref_lls[j, ref_ll_idx]
        @inbounds sub = lls_to_sub[j]
        gap = ifelse(ref == sub, zero(ref - sub), CUDA.abs(ref - sub))
        ref_ll = max(ref, sub) + my_log1mexp(-gap)
        
        @inbounds code_ll = val - lls_div[j]
        if code_ll < ref_ll
            @inbounds cat_idxs_low[j] = cat_idxs_mid[j]
            @inbounds lls_cdf[j] = code_ll
        else
            @inbounds cat_idxs_high[j] = cat_idxs_mid[j]
        end
    end
    return nothing
end

# Device-side log(1 - exp(v)).
# Below log(0.5) it evaluates log1p(-exp(v)), otherwise log(-expm1(v)).
# The threshold is written as a literal converted to the type of `v`, so the helper does not
# rely on an externally defined `loghalf` constant.
@inline my_log1mexp(v) = begin
    ifelse(v < oftype(v, -0.6931471805599453), CUDA.log1p(-CUDA.exp(v)), CUDA.log(-CUDA.expm1(v)))
end

"""
    record_lls_temp_gpu(values, lls_buffer, layer, weights; dec_per_thread = 7, log2_threads_per_block = 8)

Host-side launcher for `record_lls_temp_cuda`.

Uses a 1-D launch: the thread count is the first component returned by `balance_threads`, and
the block count is the ceiling of the number of rows of `values` divided by it.
"""
function record_lls_temp_gpu(values, lls_buffer, layer, weights; dec_per_thread = 7, log2_threads_per_block = 8)
    example_count = size(values, 1)
    work_groups = length(layer) / dec_per_thread
    thread_shape = balance_threads(example_count, work_groups, log2_threads_per_block)
    threads_x = thread_shape[1]
    blocks_x = ceil(Int, example_count / threads_x)

    @cuda threads=threads_x blocks=blocks_x record_lls_temp_cuda(values, lls_buffer, layer, weights)
end

"""
    record_lls_temp_cuda(values, lls_buffer, layer, weights)

CUDA kernel. For every row `j` of `values` (grid-stride loop along x), write to `lls_buffer[j]`
the log-sum-exp over `i` of `values[j, layer[i]] + weights[i]`.
"""
function record_lls_temp_cuda(values, lls_buffer, layer, weights)
    row_first = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    row_step = blockDim().x * gridDim().x
    for row = row_first:row_step:size(values,1)
        acc = typemin(eltype(values))
        for k = 1 : length(layer)
            @inbounds node = layer[k]
            @inbounds weight = weights[k]

            # acc = logaddexp(acc, term), written out by hand
            @inbounds term = values[row, node] + weight
            gap = ifelse(acc == term, zero(eltype(values)), CUDA.abs(acc - term))
            acc = max(acc, term) + CUDA.log1p(CUDA.exp(-gap))
        end
        @inbounds lls_buffer[row] = acc
    end
    return nothing
end

"""
    code_to_ll_gpu(code::CuMatrix{UInt64}, lls::CuMatrix; Float = Float64,
                   dec_per_thread = 7, log2_threads_per_block = 8)

Host-side launcher for `code_to_ll_cuda`, which decodes the code words in `code` into `lls`.

The 2-D thread shape comes from `balance_threads`, fed with the number of rows of `code` and
`size(code, 2) / dec_per_thread`; the block counts are the ceilings of those two quantities
divided by the respective thread counts.

`Float` is accepted for signature parity with `code_to_ll_cpu`; it is not forwarded to the
kernel.
"""
function code_to_ll_gpu(code::CuMatrix{UInt64}, lls::CuMatrix; Float = Float64,
                        dec_per_thread = 7, log2_threads_per_block = 8)
    n_examples = size(code, 1)
    num_lls_sets = size(code, 2) / dec_per_thread
    num_threads = balance_threads(n_examples, num_lls_sets, log2_threads_per_block)
    # The y extent is the number of column groups computed above.
    num_blocks = (ceil(Int, n_examples/num_threads[1]), 
                  ceil(Int, num_lls_sets/num_threads[2]))
    
    @cuda threads=num_threads blocks=num_blocks code_to_ll_cuda(code, lls)
end

"""
    code_to_ll_cuda(code, lls)

CUDA kernel. Decodes `code[j, i]` into `lls[j, i]` using grid-stride loops (x over rows, y over
columns). The floating-point type is `eltype(lls)`.

Decoding of a word `c`:
- `typemax(UInt64)` decodes to zero;
- otherwise the highest set bit is treated as a marker and discarded; the bits below it are
  read from most to least significant, the k-th of them contributing `k * log(0.5)` (combined
  by log-add-exp) when set.

A word with no set bits below the marker, including the zero word, decodes to
`typemin(eltype(lls))`.

The arguments are left untyped because `@cuda` passes device-side array wrappers, not the
host-side `CuMatrix` objects.
"""
function code_to_ll_cuda(code, lls)
    index_x = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    index_y = (blockIdx().y - 1) * blockDim().y + threadIdx().y
    stride_x = blockDim().x * gridDim().x
    stride_y = blockDim().y * gridDim().y
    T = eltype(lls)
    log_0_5 = T(-0.6931471805599453) # log(0.5)
    top_bit = UInt64(0x8000000000000000)
    for j = index_x:stride_x:size(code,1)
        for i = index_y:stride_y:size(code,2)
            @inbounds c = code[j, i]
            ll = zero(T)
            if c != typemax(UInt64)
                # Bring the marker bit to the top (a zero word stays zero), then drop it.
                c <<= leading_zeros(c)
                c <<= 1
                ll = typemin(T)
                v = log_0_5
                while c != UInt64(0)
                    if (c & top_bit) != UInt64(0)
                        Δ = ifelse(ll == v, zero(T), CUDA.abs(ll - v))
                        ll = max(ll, v) + CUDA.log1p(CUDA.exp(-Δ)) # ll = logaddexp(ll, v)
                    end
                    c <<= 1
                    v += log_0_5
                end
            end
            # Row index first: j runs over size(code, 1), i over size(code, 2).
            @inbounds lls[j, i] = ll
        end
    end
    return nothing
end