using LogicCircuits
using ProbabilisticCircuits
using StatsFuns

"""
Background reading on the rANS (range Asymmetric Numeral Systems) coder implemented below: https://kedartatwawadi.github.io/post--ANS/
"""


"""
    compress_data_rans(pbc::CompressParamBitCircuit, data::DataFrame, reuse = (nothing, nothing);
                       use_gpu = false, precision = 28, debug = false, Float = Float64)

Encode every row of `data` with rANS (range Asymmetric Numeral Systems), using the
per-variable intervals returned by `compress_marginal` on the circuit `pbc`.

Variables are pushed in reverse column order so that a decoder can pop them in forward
order. Each interval is quantized to `precision` bits before being pushed onto a 64-bit
rANS state that emits 32-bit words on renormalization.

# Arguments
- `pbc`: the compression circuit. It is moved to the GPU together with `data` when
  `use_gpu` is set and `pbc` is not already on the GPU.
- `data`: the examples to encode, one per row.
- `reuse`: a pair forwarded as the 3rd and 4th positional arguments of `compress_marginal`.

# Keywords
- `use_gpu`: run `compress_marginal` on the GPU and copy its log-likelihoods back to the CPU.
- `precision`: number of bits used to quantize each interval. It must be in `1:31`,
  because the state's lower bound is `2^31`.
- `debug`: currently unused; kept for interface compatibility.
- `Float`: floating-point type forwarded to `compress_marginal`.

# Returns
`(codes, (values, target_lls))`:
- `codes[i]::BitVector` is the code for example `i`.
- `(values, target_lls)` is the raw output of `compress_marginal`. It is presumably reusable
  as `reuse` in a later call (TODO confirm).

# Throws
- `ArgumentError` if `precision` is outside `1:31`.
- An error asking to increase `precision` if an interval width quantizes to zero. This is
  raised inside the threaded encoding loop.
"""
function compress_data_rans(pbc::CompressParamBitCircuit, data::DataFrame, reuse = (nothing, nothing); 
                            use_gpu::Bool = false, precision::Integer = 28, debug::Bool = false, Float = Float64)
    # For `precision > 31`, `rans_l >> precision` below is zero: the renormalization
    # threshold collapses and the state loses information. Negative values would turn
    # `>>` into a left shift.
    1 <= precision <= 31 ||
        throw(ArgumentError("`precision` must be between 1 and 31, got $(precision)"))

    if use_gpu && !isgpu(pbc)
        pbc = to_gpu(pbc)
        data = to_gpu(data)
    end

    # Per-example log-space interval bounds; they are encoded on the CPU.
    values, target_lls = compress_marginal(pbc, data, reuse[1], reuse[2]; Float, ll_interval = 1)
    lls = use_gpu ? to_cpu(target_lls) : target_lls

    # rANS constants: 64-bit state kept in [rans_l, ...), 32 bits emitted per renormalization.
    rans_l::UInt64 = UInt64(1) << 31
    tail_bits::UInt64 = (UInt64(1) << 32) - 1
    rans_prec_l::UInt64 = (rans_l >> precision) << 32  # multiplied by the width => renorm threshold
    log_prec::Float64 = precision * log(2.0)           # log(2^precision)

    num_examples = size(lls, 1)
    num_vars = num_features(data)
    states_head = fill(rans_l, num_examples)                 # current state of each example
    states_remain = [UInt32[] for _ = 1:num_examples]        # words emitted by renormalization

    Threads.@threads for i = 1:num_examples
        # Encode in reverse order so that decoding can proceed in forward order.
        for cdf_idx = num_vars:-1:1
            # Columns 1:num_vars of `lls` hold the log lower bounds (`ref_low`).
            # Columns num_vars+1:2num_vars hold the log widths (`ref_ll`) of the intervals.
            ll_idx = cdf_idx + num_vars
            @inbounds ref_low = lls[i, cdf_idx]
            @inbounds ref_ll = lls[i, ll_idx]
            if cdf_idx > 1
                # Normalize by the previous width column. This is presumably the log-likelihood
                # of the preceding variables, which makes the interval conditional (TODO confirm).
                @inbounds ll_div = lls[i, ll_idx - 1]
                ref_low -= ll_div
                ref_ll -= ll_div
            end

            # Quantize to `precision` bits: the symbol occupies [ref_low_int, ref_low_int + ref_ll_int).
            ref_low_int = UInt64(ceil(exp(ref_low + log_prec)))
            ref_ll_int = UInt64(ceil(exp(ref_ll + log_prec)))

            # Must run before the `%` below, which would otherwise throw a DivideError first.
            ref_ll_int > 0 || error("Please increase encoding precision (i.e., `precision`).")

            # Handle this special case to avoid numerical errors. Skip it when the lower bound is
            # zero: decrementing an unsigned zero would wrap around to typemax(UInt64).
            if ref_low_int > 0 && ref_low_int % ref_ll_int == 0
                ref_low_int -= one(UInt64)
            end

            # rANS push: emit the low 32 bits if the state would overflow, then encode the symbol.
            @inbounds x = states_head[i]
            if x >= rans_prec_l * ref_ll_int
                @inbounds push!(states_remain[i], UInt32(x & tail_bits))
                x >>= 32
            end
            @inbounds states_head[i] = ((x ÷ ref_ll_int) << precision) + (x % ref_ll_int) + ref_low_int
        end
    end

    # Least-significant-first bits of `v`, with leading zeros dropped (empty for `v == 0`).
    uint64_bits(v::UInt64) = BitVector([isodd(v >> k) for k = 0:(63 - leading_zeros(v))])

    # Each code is laid out as:
    #   - the emitted 32-bit words, in push order;
    #   - the significant bits of the final state;
    #   - one flag bit recording whether that state exceeds 32 bits.
    codes = BitVector[BitVector() for _ = 1:num_examples]
    Threads.@threads for i = 1:num_examples
        head = states_head[i]
        remain = states_remain[i]
        head_bits = uint64_bits(head)
        nremain_bits = length(remain) << 5
        bv = falses(nremain_bits + length(head_bits) + 1)
        # Write the words straight into the chunk storage. On little-endian hosts, word k fills
        # bits 32(k-1)+1:32k. Going through `reinterpret` avoids unrooted raw pointers.
        copyto!(reinterpret(UInt32, bv.chunks), 1, remain, 1, length(remain))
        bv[nremain_bits+1:end-1] = head_bits
        bv[end] = head > tail_bits
        codes[i] = bv
    end

    return codes, (values, target_lls)
end