using LogicCircuits
using ProbabilisticCircuits
using StatsFuns
using CUDA
using LoopVectorization


"""
    decompress_marginal_rans(pbc, code, reuse_d = nothing, reuse = [nothing for _ = 1 : 9];
                             num_vars, num_cats, use_gpu, Float, precision, debug)

Entry point that forwards rANS decoding of `code` to the CPU or the GPU implementation.

With `use_gpu = true` the circuit is moved to the GPU via `to_gpu` unless `isgpu(pbc)`
already holds, and `decompress_marginal_rans_gpu` is called. Otherwise
`decompress_marginal_rans_cpu` is called. Positional arguments and the keywords
`num_vars`, `num_cats`, `Float`, `precision` and `debug` are forwarded unchanged, and the
result of the selected implementation is returned as is.

`num_vars` and `num_cats` default to `num_variables(pbc)` and `num_categories(pbc)`;
`precision` defaults to 28 here and is always passed on explicitly.
"""
function decompress_marginal_rans(pbc::CompressParamBitCircuit, code, reuse_d = nothing, reuse = [nothing for _ = 1 : 9]; 
                                  num_vars = num_variables(pbc), num_cats = num_categories(pbc), 
                                  use_gpu = false, Float = Float64, precision::Integer = 28, debug::Bool = false)
    # CPU path: nothing to prepare, forward everything directly.
    if !use_gpu
        return decompress_marginal_rans_cpu(pbc, code, reuse_d, reuse; num_vars, num_cats, Float, precision, debug)
    end

    # GPU path: transfer the circuit only when it is not on the device yet.
    device_pbc = isgpu(pbc) ? pbc : to_gpu(pbc)
    return decompress_marginal_rans_gpu(device_pbc, code, reuse_d, reuse; num_vars, num_cats, Float, precision, debug)
end

"""
CPU code
"""

"""
    decompress_marginal_rans_cpu(pbc, codes, reuse_d = nothing, reuse = [nothing for _ = 1 : 9];
                                 num_vars, num_cats, Float = Float64, precision = 24, debug = false)

CPU implementation of rANS decoding driven by the circuit `pbc`.

# Arguments
- `pbc`: the parameterised bit circuit; its `bitcircuit` field supplies `var_order`,
  `layers` and `root_node_weights`.
- `codes`: indexable collection with one `BitVector` per example. The last bit selects
  whether the head state occupies two (`true`) or one (`false`) 32-bit words; the leading
  32-bit words form the stack of remaining state words; the bits in between (excluding the
  last bit) are the head state, least-significant bit first.
- `reuse_d`: optional buffer handed to `similar!` for the decoded data matrix.
- `reuse`: nine optional buffers handed to `similar!` for the work arrays, in the order in
  which they are returned.

# Keywords
- `num_vars`: number of columns of the decoded data matrix.
- `num_cats`: number of categories; the bisection searches the range `1 : num_cats`.
- `Float`: element type of the floating-point work arrays.
- `precision`: number of low bits of the head state that carry the cumulative frequency.
- `debug`: when `true`, prints the decoder state while decoding variable 330.

# Returns
`(target_data, buffers)` where `target_data` is a `Matrix{UInt32}` of size
`num_examples x num_vars` holding the decoded category index for each variable, and
`buffers` is the tuple `(values, lls_buffer, cat_idxs_low, cat_idxs_mid, cat_idxs_high,
lls_div, lls_ref, lls_cdf, lls_new)` that can be passed back as `reuse`.
"""
function decompress_marginal_rans_cpu(pbc::CompressParamBitCircuit, codes, reuse_d = nothing, reuse = [nothing for _ = 1 : 9];
                                      num_vars, num_cats, Float = Float64, precision::Integer = 24, debug::Bool = false)
    # Init values
    num_examples = length(codes)
    values = similar!(reuse[1], Matrix{Float}, num_examples, num_nodes(pbc))
    @inbounds @views values[:,:] .= zero(Float)
    @views values[:,LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:,LogicCircuits.FALSE_BITS] .= log(zero(Float))

    # Init target_data
    target_data = similar!(reuse_d, Matrix{UInt32}, num_examples, num_vars)

    # Init reuse buffer(s)
    lls_buffer::Vector{Float} = similar!(reuse[2], Vector{Float}, num_examples)
    cat_idxs_low::Vector{Int} = similar!(reuse[3], Vector{Int}, num_examples)
    cat_idxs_mid::Vector{Int} = similar!(reuse[4], Vector{Int}, num_examples)
    cat_idxs_high::Vector{Int} = similar!(reuse[5], Vector{Int}, num_examples)
    lls_div::Vector{Float} = similar!(reuse[6], Vector{Float}, num_examples)
    @inbounds @views lls_div .= zero(Float)
    lls_ref::Vector{Float} = similar!(reuse[7], Vector{Float}, num_examples)
    lls_cdf::Vector{Float} = similar!(reuse[8], Vector{Float}, num_examples)
    lls_new::Vector{Float} = similar!(reuse[9], Vector{Float}, num_examples)
    
    # Some const vals
    rans_l::UInt64 = UInt64(1) << 31                         # renormalisation threshold of the head state
    prec_tail_bits::UInt64 = (UInt64(1) << precision) - 1    # mask selecting the low `precision` bits
    log_prec::Float64 = precision * log(Float64(2.0))        # log(2^precision)
    debug_var = 330                                          # variable traced when `debug` is set

    # Pack a BitVector into a UInt64, with c[1] as the least-significant bit.
    # An empty vector yields zero because the loop body never runs.
    @inline bitvectortouint64(c::BitVector) = begin
        v::UInt64 = zero(UInt64)
        for idx = length(c) : -1 : 1
            v = (v << 1)
            if c[idx]
                v |= UInt64(1)
            end
        end
        v
    end

    # Convert code to its unflattened form
    states_head = UInt64[rans_l for _ = 1 : num_examples] # The numbers at the head of the state stacks
    states_remain = Vector{UInt32}[UInt32[] for _ = 1 : num_examples]
    for i = 1 : num_examples
        siz::Int = @inbounds length(codes[i]) - 1   # payload length without the trailing flag bit
        h_len::Int = codes[i][end] ? 2 : 1          # number of 32-bit words taken by the head state
        s_len::Int = cld(siz, 32) - h_len           # number of 32-bit words left on the stack
        states_remain[i] = zeros(UInt32, s_len)
        # Bounds-checked copy of the first `s_len` 32-bit words of the bit chunks; this avoids
        # raw pointers, which would need GC.@preserve to be valid.
        copyto!(states_remain[i], 1, reinterpret(UInt32, codes[i].chunks), 1, s_len)
        states_head[i] = @inbounds bitvectortouint64(codes[i][(s_len << 5)+1:end-1])
    end
    
    # Decode the variables one by one
    bc::CompressBitCircuit = pbc.bitcircuit
    for var_idx in bc.var_order
        # Compute the reference LLs for variable #`var_idx`:
        # log(cf) - log(2^precision), where cf is the low `precision` bits of the head state
        for j = 1 : num_examples
            @inbounds lls_ref[j] = log((states_head[j] & prec_tail_bits)) - log_prec
        end
        
        # Diagnostics are emitted only on request
        if debug && var_idx == debug_var
            for j = 1 : num_examples
                println(states_head[j], " ", prec_tail_bits, " ", log_prec)
            end
            println("Decoding variable $(debug_var)")
            println("State before decoding: $(states_head[1])")
        end
        
        # Use the dichotomy method to locate the target categories
        @inbounds @views cat_idxs_low .= Int(1)
        @inbounds @views cat_idxs_high .= Int(num_cats + 1)
        @inbounds @views lls_cdf .= typemin(Float)
        while any((cat_idxs_low .+ 1) .< cat_idxs_high)
            # cat_idxs_mid = (cat_idxs_low + cat_idxs_high) // 2
            @inbounds @views cat_idxs_mid .= cat_idxs_low
            @inbounds @views cat_idxs_mid .+= cat_idxs_high
            @inbounds @views cat_idxs_mid .÷= Int(2)
            
            # Evaluate the circuit for the candidate categories; the result lands in `lls_buffer`.
            # (presumably the cumulative likelihood of the categories below the candidate — confirm
            # against assign_leaf_nodes_cat_min1_cpu)
            assign_leaf_nodes_cat_min1_cpu(values, pbc.lit_to_var, pbc.cumulative_cat_params, 
                bc.layers[var_idx][1]; cat_idxs = cat_idxs_mid)
            compression_marginal_layers_cpu(pbc, values, bc.layers[var_idx])
            record_lls_temp_cpu(values, lls_buffer, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
            
            # Narrow the search interval of every example
            for j = 1 : num_examples
                @inbounds ref_ll = lls_ref[j]
                @inbounds code_ll = lls_buffer[j] - lls_div[j]
                if code_ll < ref_ll
                    @inbounds cat_idxs_low[j] = cat_idxs_mid[j]
                    @inbounds lls_cdf[j] = lls_buffer[j]
                else
                    @inbounds cat_idxs_high[j] = cat_idxs_mid[j]
                end
            end
        end
        
        # Assign decoded categories of variable #`var_idx`
        @inbounds @views target_data[:, var_idx] .= cat_idxs_low
        
        # P(X_1=x_1,...,X_k=x_k)
        assign_leaf_nodes_cpu(values, target_data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params, 
                              bc.layers[var_idx][1]; cumulative = false)
        compression_marginal_layers_cpu(pbc, values, bc.layers[var_idx])
        record_lls_temp_cpu(values, lls_new, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
        
        # Update the rANS states
        for j = 1 : num_examples
            ref_low_int = @inbounds UInt64(ceil(exp(lls_cdf[j] - lls_div[j] + log_prec)))
            ref_ll_int = @inbounds UInt64(ceil(exp(lls_new[j] - lls_div[j] + log_prec)))
            # NOTE(review): if lls_cdf[j] is still typemin(Float), ref_low_int is 0, the test below
            # succeeds and the decrement wraps around in UInt64 — confirm this mirrors the encoder.
            if ref_low_int % ref_ll_int == 0 # Handle this special case to avoid numerical errors
                ref_low_int -= one(UInt64)
            end
            
            # rANS decoding
            cf = (states_head[j] & prec_tail_bits)
            @inbounds states_head[j] = (states_head[j] >> precision) * ref_ll_int + cf - ref_low_int
            # Renormalise: shift in the next 32-bit word (or zeros once the stack is empty)
            if states_head[j] < rans_l
                if length(states_remain[j]) != 0
                    states_head[j] = (states_head[j] << 32) | UInt64(pop!(states_remain[j]))
                else
                    states_head[j] = (states_head[j] << 32)
                end
            end
        end
        
        # The joint likelihood of the decoded prefix becomes the divisor for the next variable
        @inbounds @views lls_div[:] .= lls_new[:]
    end
    
    return target_data, (values, lls_buffer, cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_div, lls_ref, lls_cdf, lls_new)
end

"""
GPU code
"""

"""
    decompress_marginal_rans_gpu(pbc, codes, reuse_d = nothing, reuse = [nothing for _ = 1 : 9];
                                 num_vars, num_cats, Float = Float64, precision = 24, debug = false)

GPU implementation of rANS decoding driven by the circuit `pbc`.

Circuit evaluation runs through the `*_gpu` helpers on `CuArray` buffers, while the rANS
state bookkeeping and the bisection indices are kept in host arrays; results are moved
between the two with `to_gpu` / `to_cpu`.

# Arguments
- `pbc`: the parameterised bit circuit; its `bitcircuit` field supplies `var_order`,
  `layers` and `root_node_weights`.
- `codes`: indexable collection with one `BitVector` per example. The last bit selects
  whether the head state occupies two (`true`) or one (`false`) 32-bit words; the leading
  32-bit words form the stack of remaining state words; the bits in between (excluding the
  last bit) are the head state, least-significant bit first.
- `reuse_d`: optional buffer handed to `similar!` for the decoded data matrix.
- `reuse`: nine optional buffers handed to `similar!` for the work arrays, in the order in
  which they are returned.

# Keywords
- `num_vars`: number of columns of the decoded data matrix.
- `num_cats`: number of categories; the bisection searches the range `1 : num_cats`.
- `Float`: element type of the floating-point work arrays.
- `precision`: number of low bits of the head state that carry the cumulative frequency.
- `debug`: accepted for interface parity; it is not read by this function.

# Returns
`(target_data, buffers)` where `target_data` is a `CuMatrix{UInt32}` of size
`num_examples x num_vars` holding the decoded category index for each variable, and
`buffers` is the tuple `(values, lls_buffer, cat_idxs_low, cat_idxs_mid, cat_idxs_high,
lls_div, lls_ref, lls_cdf, lls_new)` that can be passed back as `reuse`.
"""
function decompress_marginal_rans_gpu(pbc::CompressParamBitCircuit, codes, reuse_d = nothing, reuse = [nothing for _ = 1 : 9];
                                      num_vars, num_cats, Float = Float64, precision::Integer = 24, debug::Bool = false)
    # Init values
    num_examples::Int = length(codes)
    values::CuMatrix{Float} = similar!(reuse[1], CuMatrix{Float}, num_examples, num_nodes(pbc))
    @inbounds @views values[:,:] .= zero(Float)
    @views values[:,LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:,LogicCircuits.FALSE_BITS] .= log(zero(Float))

    # Init target_data
    target_data::CuMatrix{UInt32} = similar!(reuse_d, CuMatrix{UInt32}, num_examples, num_vars)

    # Init reuse buffer(s)
    lls_buffer::CuVector{Float} = similar!(reuse[2], CuVector{Float}, num_examples)
    cat_idxs_low::Vector{Int} = similar!(reuse[3], Vector{Int}, num_examples)
    cat_idxs_mid::Vector{Int} = similar!(reuse[4], Vector{Int}, num_examples)
    cat_idxs_high::Vector{Int} = similar!(reuse[5], Vector{Int}, num_examples)
    lls_div::Vector{Float} = similar!(reuse[6], Vector{Float}, num_examples)
    @inbounds @views lls_div .= zero(Float)
    lls_ref::Vector{Float} = similar!(reuse[7], Vector{Float}, num_examples)
    lls_cdf::Vector{Float} = similar!(reuse[8], Vector{Float}, num_examples)
    lls_new::CuVector{Float} = similar!(reuse[9], CuVector{Float}, num_examples)
    
    # Some const vals
    rans_l::UInt64 = UInt64(1) << 31                         # renormalisation threshold of the head state
    prec_tail_bits::UInt64 = (UInt64(1) << precision) - 1    # mask selecting the low `precision` bits
    prec_val = (UInt64(1) << precision)                      # 2^precision
    log_prec::Float64 = precision * log(Float64(2.0))        # log(2^precision)

    # Pack a BitVector into a UInt64, with c[1] as the least-significant bit.
    # An empty vector yields zero because the loop body never runs.
    @inline bitvectortouint64(c::BitVector) = begin
        v::UInt64 = zero(UInt64)
        for idx = length(c) : -1 : 1
            v = (v << 1)
            if c[idx]
                v |= UInt64(1)
            end
        end
        v
    end

    # Convert code to its unflattened form
    states_head::Vector{UInt64} = UInt64[rans_l for _ = 1 : num_examples] # The numbers at the head of the state stacks
    states_remain::Vector{Vector{UInt32}} = Vector{UInt32}[UInt32[] for _ = 1 : num_examples]
    for i = 1 : num_examples
        siz::Int = @inbounds length(codes[i]) - 1   # payload length without the trailing flag bit
        h_len::Int = codes[i][end] ? 2 : 1          # number of 32-bit words taken by the head state
        s_len::Int = cld(siz, 32) - h_len           # number of 32-bit words left on the stack
        states_remain[i] = zeros(UInt32, s_len)
        # Bounds-checked copy of the first `s_len` 32-bit words of the bit chunks; this avoids
        # raw pointers, which would need GC.@preserve to be valid.
        copyto!(states_remain[i], 1, reinterpret(UInt32, codes[i].chunks), 1, s_len)
        states_head[i] = @inbounds bitvectortouint64(codes[i][(s_len << 5)+1:end-1])
    end
    
    # Decode the variables one by one
    bc::CompressBitCircuit = pbc.bitcircuit
    for var_idx in bc.var_order
        # Compute the reference LLs for variable #`var_idx`
        # NOTE(review): the CPU variant computes log(cf) - precision * log(2) instead of
        # log(cf / 2^precision); the two may round differently — confirm which one the encoder expects.
        @inbounds @views lls_ref .= log.((states_head .& prec_tail_bits) ./ prec_val)
        
        # Use the dichotomy method to locate the target categories
        @inbounds @views cat_idxs_low .= Int(1)
        @inbounds @views cat_idxs_high .= Int(num_cats + 1)
        @inbounds @views lls_cdf .= typemin(Float)
        while any((cat_idxs_low .+ 1) .< cat_idxs_high)
            # cat_idxs_mid = (cat_idxs_low + cat_idxs_high) // 2
            @inbounds @views cat_idxs_mid .= cat_idxs_low
            @inbounds @views cat_idxs_mid .+= cat_idxs_high
            @inbounds @views cat_idxs_mid .÷= Int(2)
            
            # Evaluate the circuit for the candidate categories on the device, then fetch the
            # per-example results back to the host for the comparison below.
            assign_leaf_nodes_cat_min1_gpu(values, pbc.lit_to_var, pbc.cumulative_cat_params, 
                bc.layers[var_idx][1]; cat_idxs = to_gpu(cat_idxs_mid))
            compression_marginal_layers_gpu(pbc, values, bc.layers[var_idx]; 
                dec_per_thread = 4, log2_threads_per_block = 5)
            record_lls_temp_gpu(values, lls_buffer, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
            lls_buffer_cpu = to_cpu(lls_buffer)
            
            # ref_ll = lls_ref[j]; code_ll = lls_buffer_cpu[j] - lls_div[j]; cond = (code_ll < ref_ll)
            @inbounds @views cond = lls_buffer_cpu .- lls_div .< lls_ref
            # if code_ll < ref_ll then cat_idxs_low[j] = cat_idxs_mid[j]
            @inbounds @views cat_idxs_low .= ifelse.(cond, cat_idxs_mid, cat_idxs_low)
            # if code_ll < ref_ll then lls_cdf[j] = lls_buffer_cpu[j]
            @inbounds @views lls_cdf .= ifelse.(cond, lls_buffer_cpu, lls_cdf)
            # if code_ll >= ref_ll then cat_idxs_high[j] = cat_idxs_mid[j]
            @inbounds @views cat_idxs_high .= ifelse.(cond, cat_idxs_high, cat_idxs_mid)
        end
        
        # Assign decoded categories of variable #`var_idx`
        @inbounds @views target_data[:, var_idx] .= to_gpu(cat_idxs_low)
        
        # P(X_1=x_1,...,X_k=x_k)
        assign_leaf_nodes_gpu(values, target_data, pbc.lit_to_var, pbc.cat_params, pbc.cumulative_cat_params, 
                              bc.layers[var_idx][1]; cumulative = false)
        compression_marginal_layers_gpu(pbc, values, bc.layers[var_idx]; 
            dec_per_thread = 4, log2_threads_per_block = 5)
        record_lls_temp_gpu(values, lls_new, bc.layers[var_idx][end], bc.root_node_weights[var_idx])
        lls_new_cpu = to_cpu(lls_new)
        
        # Update the rANS states
        ref_low_int = UInt64.(ceil.(exp.(lls_cdf .- lls_div .+ log_prec)))
        ref_ll_int = UInt64.(ceil.(exp.(lls_new_cpu .- lls_div .+ log_prec)))
        # NOTE(review): where lls_cdf is still typemin(Float), ref_low_int is 0, the test below
        # succeeds and the decrement wraps around in UInt64 — confirm this mirrors the encoder.
        cond = (ref_low_int .% ref_ll_int) .== zero(UInt64) # Handle this special case to avoid numerical errors
        ref_low_int = ifelse.(cond, ref_low_int .- one(UInt64), ref_low_int)
        
        # rANS decoding
        cf = (states_head .& prec_tail_bits)
        states_head .= (states_head .>> precision) .* ref_ll_int .+ cf .- ref_low_int
        # Renormalise: shift in the next 32-bit word (or zeros once the stack is empty)
        for j = 1 : num_examples
            if states_head[j] < rans_l
                if length(states_remain[j]) != 0
                    states_head[j] = (states_head[j] << 32) | UInt64(pop!(states_remain[j]))
                else
                    states_head[j] = (states_head[j] << 32)
                end
            end
        end
        
        # The joint likelihood of the decoded prefix becomes the divisor for the next variable
        @inbounds @views lls_div[:] .= lls_new_cpu[:]
    end
    
    return target_data, (values, lls_buffer, cat_idxs_low, cat_idxs_mid, cat_idxs_high, lls_div, lls_ref, lls_cdf, lls_new)
end