using LogicCircuits
using ProbabilisticCircuits
using StatsFuns: logaddexp, logsumexp, logsubexp


"""
    compress(pc::ProbCircuit, x::Vector{Bool}; precision = 512)

Encode the binary assignment `x` as a bit string using arithmetic coding driven by the
probabilistic circuit `pc`.

Assignments are ordered lexicographically with `false < true`. The encoder computes the
interval `[lower, upper)` with `lower = P(X < x)` and `upper = lower + P(x)`. It then emits
bits greedily: a `1` is appended whenever adding the next power of two keeps the running
value strictly below `upper`, otherwise a `0`. It stops as soon as the value reaches `lower`.

Side effect: sets the *global* `BigFloat` precision to `precision` bits. Decoding with
`decompress` should presumably use the same `precision`.

# Returns
A `BitVector` with the code bits `b₁, b₂, …` of the binary fraction `0.b₁b₂…`.
It is empty when `lower == 0`, i.e. `x` is all `false`.

# Throws
`ErrorException` if the internal invariant `num < upper` is ever violated.
"""
function compress(pc::ProbCircuit, x::Vector{Bool}; precision = 512)
    # NOTE: mutates the global BigFloat precision (kept for backward compatibility).
    setprecision(BigFloat, precision)

    n_features = length(x)
    true_positions = findall(x)
    n_samples = 1 + length(true_positions)

    # Column 1 is `x` itself, so its marginal is P(x).
    # For every position `i` with x[i] == true, add a column (x[1:i-1], false, missing…).
    # Its marginal is the mass of all assignments that share x's first i-1 bits and have a
    # 0 at position i. These sets are disjoint and together cover exactly the assignments
    # that are lexicographically below `x`.
    queries = Matrix{Union{Bool,Missing}}(missing, n_features, n_samples)
    queries[:, 1] .= x
    for (col, i) in enumerate(true_positions)
        queries[1:i-1, col + 1] .= view(x, 1:i-1)
        queries[i, col + 1] = false
    end

    log_mars = marginals(pc, queries)
    # log P(X < x); no assignment precedes the all-false one, so its mass is log(0).
    log_lower = n_samples == 1 ? BigFloat(-Inf) : logsumexp(log_mars[2:end])
    log_upper = logaddexp(log_lower, log_mars[1])

    # logsubexp yields log P(x) (in nats, ≤ 0). Negate it and convert to bits to get the
    # ideal code length -log2 P(x).
    expected_num_bits = -logsubexp(log_upper, log_lower) / log(2.0)

    lower, upper = exp(log_lower), exp(log_upper)

    code = Vector{Bool}()
    num = BigFloat("0.0")
    bit_num = BigFloat("0.5")
    while !(lower <= num < upper)
        # Bits are only added while num + bit_num < upper, so num >= upper is impossible.
        num < lower || error("Encountered num >= upper.")
        if num + bit_num < upper
            num += bit_num
            push!(code, true)
        else
            push!(code, false)
        end
        @debug "arithmetic coding step" length(code) num
        bit_num *= BigFloat("0.5")
    end

    # The greedy encoder needs at most ceil(-log2 P(x)) bits. A longer code signals
    # numerical trouble, e.g. an interval narrower than the BigFloat precision.
    if length(code) > 2 + expected_num_bits
        @warn "Expected about $(Float64(expected_num_bits)) bits (-log2 P(x)) but the true bit-length is $(length(code))."
    end

    return BitVector(code)
end

"""
    decompress(pc::ProbCircuit, n_features::Integer, code::BitVector; precision = 512)

Decode a bit string produced by `compress` back into an assignment of `n_features` binary
variables, using the same circuit `pc`.

`code` is read as the binary fraction `val = 0.b₁b₂…`. Variables are decoded in order.
Let `curr_lower` be the probability mass of the assignments that are lexicographically below
the decoded prefix. Variable `i` is decoded as `true` iff
`curr_lower + P(prefix, Xᵢ = false) <= val`.

Decoding stops early once `val` lands in the interval of the complete assignment
`(prefix, true, false, …, false)`. All remaining variables are then `false`.

Side effect: sets the *global* `BigFloat` precision to `precision` bits.

# Returns
A `Vector{Bool}` of length `n_features` (empty when `n_features == 0`).

# Throws
`ArgumentError` if `n_features` is negative.
"""
function decompress(pc::ProbCircuit, n_features::Integer, code::BitVector; precision = 512)
    setprecision(BigFloat, precision)

    n_features >= 0 ||
        throw(ArgumentError("n_features must be non-negative, got $n_features"))
    decoded_x = Vector{Bool}()
    n_features == 0 && return decoded_x

    # Partial assignment: the decoded prefix followed by `missing` (marginalized) entries.
    x = Matrix{Union{Bool,Missing}}(missing, n_features, 1)

    # Value of the code as the binary fraction 0.b₁b₂…
    val = BigFloat("0.0")
    bit_num = BigFloat("0.5")
    for bit in code
        bit && (val += bit_num)
        bit_num *= BigFloat("0.5")
    end

    curr_lower = BigFloat("0.0")
    for i in 1:n_features
        x[i] = false

        # Complete assignment (prefix, true, false, …, false): the lexicographically first
        # assignment that extends the prefix with `true` at position i.
        x_first_true = coalesce.(x, false)
        x_first_true[i] = true

        mar, mar_1 = exp.(marginals(pc, hcat(x, x_first_true)))
        if curr_lower + mar <= val
            push!(decoded_x, true)
            x[i] = true
            reached_end = curr_lower + mar + mar_1 > val
            curr_lower += mar
            reached_end && break
        else
            push!(decoded_x, false)
        end
    end

    # After an early stop the remaining variables are all `false`.
    append!(decoded_x, fill(false, n_features - length(decoded_x)))
    return decoded_x
end

"""
    marginals(pc::ProbCircuit, x::Matrix{Union{Bool,Missing}})

Compute, in `BigFloat` arithmetic, the log-probability that `pc` assigns to each column of
`x`. Rows index variables; `missing` entries are marginalized out, because a literal over a
missing variable contributes `log(1) = 0`.

The node callbacks passed to `foldup_aggregate` behave as follows:
- Literal nodes yield `log(1)` or `log(0)` per column.
- The `f_a` callback sums its children's log values.
- The `f_o` callback combines its children with `logaddexp`, each weighted by
  `n.log_probs`.
- Constant nodes raise an error.

Returns a `Vector{BigFloat}` with one entry per column. This is presumably the root value
produced by `foldup_aggregate`; confirm against LogicCircuits.
"""
function marginals(pc::ProbCircuit, x::Matrix{Union{Bool,Missing}})
    n_cols = size(x, 2)

    on_const(n)::Vector{BigFloat} = error("Do not support constant nodes.")

    on_literal(n)::Vector{BigFloat} = begin
        row = x[variable(n), :]
        # A positive literal is satisfied by `true`, a negative one by `false`.
        satisfied = literal(n) > Lit(0) ? row : .!row
        out = Vector{BigFloat}(undef, n_cols)
        out .= log.(coalesce.(satisfied, true))
        out
    end

    on_and(n, cs)::Vector{BigFloat} = begin
        total = deepcopy(cs[1])
        for k in 2:num_children(n)
            total .+= cs[k]
        end
        total
    end

    on_or(n, cs)::Vector{BigFloat} = begin
        acc = fill(BigFloat(-Inf), n_cols)
        for k in 1:num_children(n)
            weight = BigFloat(n.log_probs[k])
            acc .= logaddexp.(acc, weight .+ cs[k])
        end
        acc
    end

    return foldup_aggregate(pc, on_const, on_literal, on_and, on_or, Vector{BigFloat})
end