using LogicCircuits
using ProbabilisticCircuits
using StatsFuns


"""
    compress_data(pc::ProbCircuit, data::DataFrame, values = nothing, target_lls = nothing;
                  use_gpu = false, ll_interval = 5)

Build a `CompressParamBitCircuit` from `pc` and compress `data` with it.

This is a convenience wrapper around the `CompressParamBitCircuit` method of
`compress_data`: `values`, `target_lls`, `use_gpu` and `ll_interval` are forwarded
unchanged, so both entry points share the same GPU handling and `compress_marginal` call.

# Arguments
- `pc`: circuit converted with `CompressParamBitCircuit(pc)`.
- `data`: data to compress.
- `values`, `target_lls`: passed through positionally (default `nothing`); presumably
  previously accumulated state for `compress_marginal` — TODO confirm.

# Keywords
- `use_gpu`: request that the computation run on the GPU.
- `ll_interval`: forwarded to `compress_marginal` (default `5`).

# Returns
The `(values, target_lls)` pair returned by the `CompressParamBitCircuit` method.
"""
function compress_data(pc::ProbCircuit, data::DataFrame, values = nothing, target_lls = nothing;
                       use_gpu::Bool = false, ll_interval::Integer = 5)
    pbc = CompressParamBitCircuit(pc)
    # Pass `values`/`target_lls` positionally: writing `values = nothing` inside a call
    # would turn them into keyword arguments and drop the caller's inputs.
    return compress_data(pbc, data, values, target_lls; use_gpu, ll_interval)
end

"""
    compress_data(pbc::CompressParamBitCircuit, data::DataFrame, values = nothing,
                  target_lls = nothing; use_gpu = false, ll_interval = 5)

Compress `data` with the bit circuit `pbc` by calling `compress_marginal` with
`Float = Float64` and the given `ll_interval`.

When `use_gpu` is `true`, the circuit and the data are each moved to the GPU if they are
not already there. They are checked independently, so an on-GPU circuit with on-CPU data,
or the reverse, still ends up on one device.

# Arguments
- `values`, `target_lls`: passed through to `compress_marginal` (default `nothing`);
  presumably accumulated state from earlier calls — TODO confirm.

# Returns
The `(values, target_lls)` pair returned by `compress_marginal`.
"""
function compress_data(pbc::CompressParamBitCircuit, data::DataFrame, values = nothing, target_lls = nothing; 
                       use_gpu::Bool = false, ll_interval::Integer = 5)
    if use_gpu
        isgpu(pbc) || (pbc = to_gpu(pbc))
        # Checked separately from the circuit: skipping this when `pbc` is already on the
        # GPU would leave the circuit and the data on different devices.
        isgpu(data) || (data = to_gpu(data))
    end

    values, target_lls = compress_marginal(pbc, data, values, target_lls; Float = Float64, ll_interval)

    return values, target_lls
end


"""
    lls_to_code(lls::Matrix, ::Type{I} = UInt64; num_vars, ll_interval = 5, Float = Float64,
                sanity_check = false) -> Matrix{I}

Encode, for every row of `lls`, a sequence of log-space intervals as fixed-width binary
codes of integer type `I`.

# Column layout of `lls` (as read by this function)
- Columns `1:num_vars` are consumed in consecutive groups of `ll_interval` columns; the
  last group may be shorter.
- Column `num_vars + k` is paired with group `k`.
- For every group after the first, column `num_vars + k - 1` (the previous group's paired
  column) is subtracted as an offset from the group's values and its paired value.

For group `k` of row `i` the interval is
`ref_low = logsumexp(group .- offset)` and
`ref_high = logaddexp(ref_low, lls[i, num_vars + k] - offset)`.

# Code format
A code holds a leading sentinel `1` bit followed by bits `b₁ b₂ …`. The code stands for
`log(Σⱼ bⱼ 2^-j)`. Bits are chosen greedily: a bit is set if doing so keeps the
running sum below `exp(ref_high)`, and encoding stops once the sum reaches
`exp(ref_low)`. If `I` runs out of bits first, the code may fall below `ref_low`.
The value `typemax(I)` is reserved: it is emitted when `ref_high >= 0` and decodes to `0.0`.

# Keywords
- `num_vars`: number of leading columns of `lls` that are grouped.
- `ll_interval`: group size; must be `>= 1`.
- `Float`: floating-point type used for the encoding arithmetic.
- `sanity_check`: decode every code and `@assert` it lies in `[ref_low, ref_high)` (or is
  the `typemax(I)` sentinel).

# Returns
A `Matrix{I}` of size `(size(lls, 1), size(lls, 2) - num_vars)`; column `k` holds the
codes for group `k`.

# Throws
- `ArgumentError` if `ll_interval < 1`.
"""
function lls_to_code(lls::Matrix, ::Type{I} = UInt64; num_vars::Integer, ll_interval::Integer = 5, Float = Float64,
                     sanity_check::Bool = false)::Matrix{I} where I <: Integer
    # With `ll_interval == 0` the group loop below would never advance.
    ll_interval >= 1 || throw(ArgumentError("ll_interval must be at least 1, got $(ll_interval)"))

    # log(1/2): each successive code bit weighs half as much as the previous one.
    log_half = log(Float(0.5))
    # Top bit of `I`. Built by shifting because `~(typemax(I) >> 1)` sets the top *two*
    # bits when `I` is signed (typemax has a clear sign bit), which corrupts decoding.
    msb = one(I) << (8 * sizeof(I) - 1)

    # log(exp(x) + exp(y)), evaluated around the larger argument; NaN propagates.
    @inline mylogaddexp(x::AbstractFloat, y::AbstractFloat) = begin
        tmp = x - y
        if tmp > 0
            x + log1p(exp(-tmp))
        elseif tmp <= 0
            y + log1p(exp(tmp))
        else
            tmp
        end
    end

    # Greedy dyadic encoding of a value in [ref_low, ref_high) (log space); see docstring.
    @inline encode(ref_low, ref_high) = begin
        cum_v = typemin(Float)   # log(0): running sum of accepted bit weights
        v = log_half             # log weight of the next bit
        c::I = I(1)              # sentinel bit; data bits are shifted in below it
        if ref_high < 0.0
            while (c & msb) == I(0)
                new_v = mylogaddexp(cum_v, v)
                if new_v < ref_high
                    cum_v = new_v
                    c = (c << 1) | I(1)
                    if new_v >= ref_low
                        break
                    end
                else
                    c = c << 1
                end

                v += log_half
            end
        else
            c = typemax(I)
        end
        c
    end

    # Inverse of `encode`: returns the log of the dyadic value a code represents.
    @inline decode(c::I) = begin
        if c == typemax(I)
            return Float(0.0)
        end
        # Shift the sentinel up to the top bit, then shift it out.
        while (c & msb) == I(0)
            c = (c << 1)
        end
        c = (c << 1)
        ll::Float = typemin(Float)
        v::Float = log_half
        while c != I(0)
            if (c & msb) != I(0)
                ll = mylogaddexp(ll, v)
            end
            c = (c << 1)
            v += log_half
        end
        ll
    end

    code = Matrix{I}(undef, size(lls, 1), size(lls, 2) - num_vars)
    buffer = Vector{Float}(undef, ll_interval)
    for i in axes(lls, 1)
        mar_idx1 = 1
        mar_idx2 = min(ll_interval, num_vars)
        ll_idx = num_vars + 1
        while mar_idx1 <= num_vars
            # Offset: the previous group's paired column (nothing for the first group).
            if mar_idx1 > 1
                @inbounds div_val = lls[i, ll_idx-1]
            else
                div_val = zero(Float)
            end
            group_len = mar_idx2 - mar_idx1 + 1
            @inbounds @views buffer[1:group_len] .= lls[i, mar_idx1:mar_idx2]
            @inbounds @views buffer[1:group_len] .-= div_val
            @inbounds @views ref_cdf = logsumexp(buffer[1:group_len])
            # Bounds-checked read: also guards the `@inbounds` write into `code` below.
            ref_ll = lls[i, ll_idx] - div_val
            ref_low, ref_high = ref_cdf, logaddexp(ref_cdf, ref_ll)

            @inbounds code[i, ll_idx-num_vars] = encode(ref_low, ref_high)

            if sanity_check
                ll = decode(code[i, ll_idx-num_vars])
                @assert (ll >= ref_low && ll < ref_high) || (code[i, ll_idx-num_vars] == typemax(I)) "($(i), $(mar_idx1), $(mar_idx2), $(ll_idx)); ll: $(ll); ref_low: $(ref_low); ref_high: $(ref_high); ref_ll: $(ref_ll)"
            end

            mar_idx1 += ll_interval
            mar_idx2 = min(mar_idx2 + ll_interval, num_vars)
            ll_idx += 1
        end
    end
    return code
end