using LogicCircuits
using ProbabilisticCircuits
using StatsFuns: logaddexp, logsumexp

using LogicCircuits: NodeId, TRUE_BITS, FALSE_BITS

# Common supertype of the id records that the `CompressBitCircuit` constructor
# produces for each circuit node while folding the circuit bottom-up.
abstract type CompressNodeIds end
# Id record carrying a single node id. The constructor in this file returns it
# for constant, literal and ⋁ nodes.
mutable struct ⋁CompressNodeIds <: CompressNodeIds
    # Layer number per variable index; the constructor stores 0 at indices where
    # the node is not registered in a layer.
    layer_id::Vector{NodeId}
    # Id of the node itself.
    node_id::NodeId
end
"""
    ⋀CompressNodeIds(p, s)

Id record for a conjunction of the two children `p` (prime) and `s` (sub).

Both children must expose `layer_id` and `node_id`. The record keeps the ids of
the two children and, per variable index, the larger of their layer numbers.
"""
mutable struct ⋀CompressNodeIds <: CompressNodeIds
    layer_id::Vector{NodeId}
    prime_id::NodeId
    sub_id::NodeId
    function ⋀CompressNodeIds(p, s)
        # Element-wise maximum of the children's per-variable layer numbers.
        merged_layers = max.(p.layer_id, s.layer_id)
        return new(merged_layers, p.node_id, s.node_id)
    end
end


"""
    CompressBitCircuit{V,M,F}

BitCircuit designed for compression/decompression.

# Fields
- `layers`: for every variable index, a list of layers; the constructor in this
  file fills each layer with the ids of the nodes registered in it.
- `root_node_weights`: for every variable index, the normalized log-weights of
  the nodes in that index's last layer.
- `nodes`: node table; the constructor stores a `2 × N` matrix whose column for
  a decision node holds its first and last element id.
- `elements`: element table; the constructor stores a `3 × N` matrix whose
  columns are `(decision id, prime id, sub id)`.
- `var_order`: variable order obtained from the vtree.
- `node2id`: dictionary handed to the bottom-up fold that maps circuit nodes to
  their `CompressNodeIds` record.
"""
struct CompressBitCircuit{V,M,F}
    layers::Vector{Vector{V}}
    root_node_weights::Vector{F}
    nodes::M
    elements::M
    var_order
    node2id
end

"""
    CompressBitCircuit(pc::ProbCircuit, data; on_decision=noop)

Build a `CompressBitCircuit` for `pc`, taking the number of features from
`num_features(data)`. `on_decision` is forwarded unchanged.
"""
CompressBitCircuit(pc::ProbCircuit, data; on_decision=noop) =
    CompressBitCircuit(pc, num_features(data); on_decision=on_decision)

"""
    CompressBitCircuit(pc::ProbCircuit; on_decision=noop)

Build a `CompressBitCircuit` for `pc`, taking the number of features from
`num_variables(pc)`. `on_decision` is forwarded unchanged.
"""
CompressBitCircuit(pc::ProbCircuit; on_decision=noop) =
    CompressBitCircuit(pc, num_variables(pc); on_decision=on_decision)

"""
    CompressBitCircuit(pc::ProbCircuit, num_features::Int; on_decision=noop)

Construct a new `CompressBitCircuit` accommodating the given number of features.

`pc` is first passed through `convert_to_sd_pc`. The resulting circuit is folded
bottom-up to assign node ids and, for every variable index, a layer number to each
node whose vtree node belongs to that index's set in `get_vtree_computation_order`.
The last layer of every variable index supplies the `root_node_weights`.

# Keywords
- `on_decision`: callback invoked from the ⋁-node fold step as
  `on_decision(n, cs, layer_id, last_dec_id, first_el_id, last_el_id)`.
"""
function CompressBitCircuit(pc::ProbCircuit, num_features::Int; on_decision=noop)
    # NOTE(review): `convert_to_sd_pc` is not visible here; presumably it returns a
    # structured-decomposable circuit, its vtree and node<->vtree maps — confirm.
    pc, vtree, pc2vtree, vtree2pc = convert_to_sd_pc(pc)
    var_order::Vector{Var} = get_flattened_vtree_var_order(vtree)
    vtree_order::Vector{Set{Vtree}} = get_vtree_computation_order(vtree, var_order)
    
    num_vars = length(var_order)
    
    # store data in vectors to facilitate push!
    # NOTE(review): 2 + 2*num_features presumably reserves ids for the two constants
    # plus a positive and a negative literal per feature — confirm.
    num_leafs = 2+2*num_features
    # Flat storage for a 2-row node table; the first `num_leafs` columns stay zero.
    nodes::Vector{NodeId} = zeros(NodeId, 2*num_leafs)
    # Flat storage for a 3-row element table, filled by `f_or`.
    elements::Vector{NodeId} = NodeId[]
    # Decision ids continue right after the reserved leaf ids, so the id of a
    # decision node equals its column in the reshaped `nodes` table.
    last_dec_id::NodeId = 2*num_features+2
    last_el_id::NodeId = zero(NodeId)
    # layers[var_idx][k] lists the node ids registered in layer k for that variable
    # index; layers_pc mirrors it with the circuit nodes themselves.
    layers::Vector{Vector{Vector{NodeId}}} = Vector{Vector{NodeId}}[Vector{NodeId}[] for i = 1 : num_vars]
    layers_pc::Vector{Vector{Vector{ProbCircuit}}} = Vector{Vector{ProbCircuit}}[Vector{ProbCircuit}[] for i = 1 : num_vars]
    
    # Constants use the TRUE_BITS / FALSE_BITS ids and report layer 1 for every
    # variable index; they are not registered in `layers`.
    f_con(n) = ⋁CompressNodeIds(ones(NodeId, num_vars), istrue(n) ? TRUE_BITS : FALSE_BITS)
    # Literals are registered in layer 1 of every variable index whose vtree set
    # contains the literal's vtree node, and get layer 0 elsewhere.
    f_lit(n) = begin
        layer_id::Vector{NodeId} = zeros(NodeId, num_vars)
        # NOTE(review): the id is 2 + n.literal, which looks like it assumes positive
        # literals at this point — confirm against `convert_to_sd_pc`.
        dec_id::NodeId = NodeId(2+n.literal)
        n_vtree = pc2vtree[n]
        for var_idx = 1 : num_vars
            if n_vtree in vtree_order[var_idx]
                layer_id[var_idx] = 1
                # Create the layer lazily the first time it is needed.
                length(layers[var_idx]) < layer_id[var_idx] && push!(layers[var_idx], NodeId[])
                push!(layers[var_idx][layer_id[var_idx]], dec_id)
                length(layers_pc[var_idx]) < layer_id[var_idx] && push!(layers_pc[var_idx], ProbCircuit[])
                push!(layers_pc[var_idx][layer_id[var_idx]], n)
            else
                layer_id[var_idx] = 0
            end
        end
        ⋁CompressNodeIds(layer_id, dec_id)
    end
    # ⋀ nodes only combine the ids of their two children; the enclosing ⋁ node
    # turns them into elements.
    f_and(n, cs) = begin
        @assert length(cs) == 2 "Should have exactly two children"
        ⋀CompressNodeIds(cs[1], cs[2])
    end
    # ⋁ nodes become decision nodes: they take the next decision id and append one
    # element per child.
    f_or(n, cs) = begin
        # Element id of the first element this node is about to add.
        first_el_id::NodeId = last_el_id + one(NodeId)
        layer_id::Vector{NodeId} = zeros(NodeId, num_vars)
        last_dec_id::NodeId += one(NodeId)
        
        # A ⋀ child contributes the element (decision, prime, sub).
        f_or_child(c::⋀CompressNodeIds) = begin
            layer_id = max.(layer_id, c.layer_id)
            last_el_id += one(NodeId)
            push!(elements, last_dec_id, c.prime_id, c.sub_id)
        end
        # A child that is itself a single node is stored with TRUE_BITS as its sub.
        f_or_child(c::⋁CompressNodeIds) = begin
            layer_id = max.(layer_id, c.layer_id)
            last_el_id += one(NodeId)
            push!(elements, last_dec_id, c.node_id, TRUE_BITS)
        end
        foreach(f_or_child, cs)
        
        # Where this node's vtree node belongs to a variable index's set, it is placed
        # one layer above the highest layer among its children; elsewhere its layer
        # is reset to 0.
        n_vtree = pc2vtree[n]
        for var_idx = 1 : num_vars
            if n_vtree in vtree_order[var_idx]
                @inbounds layer_id[var_idx] += one(NodeId)
                length(layers[var_idx]) < layer_id[var_idx] && push!(layers[var_idx], NodeId[])
                push!(layers[var_idx][layer_id[var_idx]], last_dec_id)
                length(layers_pc[var_idx]) < layer_id[var_idx] && push!(layers_pc[var_idx], ProbCircuit[])
                push!(layers_pc[var_idx][layer_id[var_idx]], n)
            else
                @inbounds layer_id[var_idx] = zero(NodeId)
            end
        end
        
        # Record the range of element ids owned by this decision node.
        push!(nodes, first_el_id, last_el_id)
        on_decision(n, cs, layer_id, last_dec_id, first_el_id, last_el_id)
        ⋁CompressNodeIds(layer_id, last_dec_id)
    end
    
    # NOTE(review): `node2id` is handed to `foldup_aggregate`, presumably as its
    # result cache — confirm; it is stored in the returned struct.
    node2id = Dict{ProbCircuit,CompressNodeIds}()
    r = foldup_aggregate(pc, f_con, f_lit, f_and, f_or, CompressNodeIds, node2id)
    
    # Finalize CompressBitCircuit
    # One column per node (first/last element id) and one column per element
    # (decision, prime, sub).
    nodes_m = reshape(nodes, 2, :)
    elements_m = reshape(elements, 3, :)
    
    # Log-domain top-down probabilities of all nodes.
    top_down_probs = get_top_down_probs(pc)
    root_node_weights::Vector{Vector{Float64}} = Vector{Float64}[Float64[] for i = 1 : num_vars]
    for var_idx = 1 : num_vars
        # Uses the last layer of this variable index; indexing with `[end]` throws if
        # no node was registered for it.
        for n in layers_pc[var_idx][end]
            push!(root_node_weights[var_idx], top_down_probs[n])
        end
        @assert isapprox(sum(exp.(root_node_weights[var_idx])), 1.0, atol=1e-3) "Top-down parameters do not sum to one locally: $(sum(exp.(root_node_weights[var_idx])))."
        root_node_weights[var_idx] .-= logsumexp(root_node_weights[var_idx]) # Normalize out left-over errors
    end
    
    CompressBitCircuit(layers, root_node_weights, nodes_m, elements_m, var_order, node2id)
end


### Helper functions ###

"""
    get_flattened_vtree_var_order(vtree::Vtree; cache::Set{Vtree} = Set{Vtree}())

Return the variables of `vtree` in the order in which their leaves are reached by
a depth-first traversal that descends into `left` before `right`.

Vtree nodes already contained in `cache` are skipped; every visited node is added
to `cache`.
"""
function get_flattened_vtree_var_order(vtree::Vtree; cache::Set{Vtree} = Set{Vtree}())
    flattened = Var[]
    pending = Vtree[vtree]

    while !isempty(pending)
        node = pop!(pending)
        node in cache && continue
        push!(cache, node)

        if node isa PlainVtreeLeafNode
            push!(flattened, node.var)
        else
            # The right child is pushed first so the left subtree is popped, and
            # therefore fully expanded, before the right one.
            push!(pending, node.right, node.left)
        end
    end

    flattened
end


"""
    get_vtree_computation_order(vtree::Vtree, var_order::Vector{Var})

For every variable in `var_order`, collect a set of vtree nodes: the variable's
leaf plus its ancestors, walking up through `parent` until reaching a node whose
variables cover that variable and all variables preceding it in `var_order`.

The result is indexed by the variable itself (`result[var]`), so it has
`length(var_order)` entries and expects the variables to be valid indices into it.
"""
function get_vtree_computation_order(vtree::Vtree, var_order::Vector{Var})
    num_vars = length(var_order)
    vtree_order::Vector{Set{Vtree}} = Set{Vtree}[Set{Vtree}() for i = 1 : num_vars]

    # Map every variable to its vtree leaf.
    leaf_nodes::Dict{Var,Vtree} = Dict{Var,Vtree}()
    foreach(vtree) do v
        if v isa PlainVtreeLeafNode
            leaf_nodes[v.var] = v
        end
    end

    # Variables seen so far in `var_order`, grown incrementally instead of being
    # rebuilt from a slice on every iteration.
    seen = BitSet()
    for var in var_order
        v = leaf_nodes[var]
        push!(seen, var)
        push!(vtree_order[var], v)
        # Climb until the current vtree node covers every variable seen so far.
        while !issubset(seen, v isa PlainVtreeLeafNode ? (v.var,) : v.variables)
            v = v.parent
            push!(vtree_order[var], v)
        end
    end

    return vtree_order
end


"""
    get_top_down_probs(pc::ProbCircuit)

Return a dictionary mapping the nodes visited by `foreach_down` from `pc` to their
top-down probability in the log domain.

The root starts at log-probability `0`. A ⋁ node passes its value plus the edge's
`log_probs` entry to each child, a ⋀ node passes its value unchanged, and values
arriving at the same child are combined with `logaddexp`.
"""
function get_top_down_probs(pc::ProbCircuit)
    td = Dict{ProbCircuit, Float64}(pc => zero(Float64))

    # Merge `incoming` log-mass into whatever `child` has received so far.
    add_mass!(child, incoming) = (td[child] = logaddexp(get(td, child, -Inf), incoming))

    foreach_down(pc) do n
        here = td[n]
        if is⋁gate(n)
            for (child, edge_logp) in zip(n.children, n.log_probs)
                add_mass!(child, here + edge_logp)
            end
        elseif is⋀gate(n)
            for child in n.children
                add_mass!(child, here)
            end
        end
    end

    td
end