using LogicCircuits
using ProbabilisticCircuits

using LogicCircuits: NodeId, NodeIds, ⋁NodeIds, ⋀NodeIds, TRUE_BITS, FALSE_BITS

"""
A bit circuit is a low-level representation of a logical circuit structure.
They are a "flat" representation of a circuit, essentially a bit string,
that can be processed by lower level code (e.g., GPU kernels)
The wiring of the circuit is captured by two matrices: nodes and elements.
  * Nodes are either leafs or decision (disjunction) nodes in the circuit.
  * Elements are conjunction nodes in the circuit.
  * In addition, there is a vector of layers, where each layer is a list of node ids.
    Layer 1 is the leaf/input layer. Layer end is the circuit root.
  * And there is a vector of parents, pointing to element id parents of decision nodes.
Nodes are represented as a 4xN matrix where
  * nodes[1,:] is the first element id belonging to this decision
  * nodes[2,:] is the last element id belonging to this decision
  * nodes[3,:] is the first parent index belonging to this decision
  * nodes[4,:] is the last parent index belonging to this decision
  Elements belonging to node `i` are `elements[:, nodes[1,i]:nodes[2,i]]`
  Parents belonging to node `i` are `parents[nodes[3,i]:nodes[4,i]]`
Elements are represented by a 3xE matrix, where 
  * elements[1,:] is the decision node id (parents of the element),
  * elements[2,:] is the prime node id (child of the element) 
  * elements[3,:] is the sub node id (child of the element)
"""
struct CatBitCircuit{V,M}
    layers::Vector{V}
    nodes::M
    elements::M
    parents::V
    node2id
end

function CatBitCircuit(circuit::LogicCircuit, data; on_decision=noop)
    CatBitCircuit(circuit, num_features(data); on_decision)
end

"construct a new `CatBitCircuit` accomodating the given number of features"
function CatBitCircuit(circuit::LogicCircuit, num_features::Int; on_decision=noop)
    #TODO: consider not using foldup_aggregate and instead calling twice to ensure order but save allocations
    #TODO add inbounds annotations
    
    f_con(n) = ⋁NodeIds(one(NodeId), istrue(n) ? TRUE_BITS : FALSE_BITS)

    f_lit(n) = begin
        ⋁NodeIds(one(NodeId), NodeId(2+n.literal))
    end
      
    # store data in vectors to facilitate push!
    num_leafs = 2+2*num_features
    layers::Vector{Vector{NodeId}} = Vector{NodeId}[collect(1:num_leafs)]
    nodes::Vector{NodeId} = zeros(NodeId, 4*num_leafs)
    elements::Vector{NodeId} = NodeId[]
    parents::Vector{Vector{NodeId}} = Vector{NodeId}[NodeId[] for i = 1:num_leafs]
    last_dec_id::NodeId = 2*num_features+2
    last_el_id::NodeId = zero(NodeId)

    to⋁NodeIds(c::⋁NodeIds) = c
    to⋁NodeIds(c::⋀NodeIds) = begin
        # need to add a dummy decision node in between AND nodes
        last_dec_id += one(NodeId)
        last_el_id += one(NodeId)
        push!(elements, last_dec_id, c.prime_id, c.sub_id)
        push!(parents[c.prime_id], last_el_id)
        push!(parents[c.sub_id], last_el_id)
        layer_id = c.layer_id + one(NodeId)
        push!(nodes, last_el_id, last_el_id, zero(NodeId), zero(NodeId))
        push!(parents, NodeId[])
        length(layers) < layer_id && push!(layers, NodeId[])
        push!(layers[layer_id], last_dec_id)
        on_decision(nothing, c, layer_id, last_dec_id, last_el_id, last_el_id)
        ⋁NodeIds(layer_id, last_dec_id)
    end

    f_and(n, cs) = begin
        @assert length(cs) > 1 "CatBitCircuits only support AND gates with at least two children"
        a12 = ⋀NodeIds(to⋁NodeIds(cs[1]), to⋁NodeIds(cs[2]))
        if length(cs) == 2
            a12
        elseif length(cs) > 1000 # Avoid using recursion here since it will cause StackOverflowError
            ⋀node_idx = cs[end]
            for idx = length(cs) - 1 : -1 : 1
                ⋀node_idx = ⋀NodeIds(to⋁NodeIds(cs[idx]), to⋁NodeIds(⋀node_idx))
            end
            ⋀node_idx
        else
            f_and(n, [a12, cs[3:end]...])
        end
    end
    
    f_or(n, cs) = begin
        first_el_id::NodeId = last_el_id + one(NodeId)
        layer_id::NodeId = zero(NodeId)
        last_dec_id::NodeId += one(NodeId)

        f_or_child(c::⋀NodeIds) = begin
            layer_id = max(layer_id, c.layer_id)
            last_el_id += one(NodeId)
            push!(elements, last_dec_id, c.prime_id, c.sub_id)
            @inbounds push!(parents[c.prime_id], last_el_id)
            @inbounds push!(parents[c.sub_id], last_el_id)
        end

        f_or_child(c::⋁NodeIds) = begin
            layer_id = max(layer_id, c.layer_id)
            last_el_id += one(NodeId)
            push!(elements, last_dec_id, c.node_id, TRUE_BITS)
            @inbounds push!(parents[c.node_id], last_el_id)
            @inbounds push!(parents[TRUE_BITS], last_el_id)
        end

        foreach(f_or_child, cs)

        layer_id += one(NodeId)
        length(layers) < layer_id && push!(layers, NodeId[])
        push!(nodes, first_el_id, last_el_id, zero(NodeId), zero(NodeId))
        push!(parents, NodeId[])
        push!(layers[layer_id], last_dec_id)
        on_decision(n, cs, layer_id, last_dec_id, first_el_id, last_el_id)
        ⋁NodeIds(layer_id, last_dec_id)
    end
    
    node2id = Dict{LogicCircuit,NodeIds}()
    r = foldup_aggregate(circuit, f_con, f_lit, f_and, f_or, NodeIds, node2id)
    to⋁NodeIds(r)
    
    nodes_m = reshape(nodes, 4, :)
    elements_m = reshape(elements, 3, :)
    parents_m = Vector{NodeId}(undef, size(elements_m,2)*2)
    last_parent = zero(NodeId)
    @assert last_dec_id == size(nodes_m,2) == size(parents,1)
    @assert sum(length, parents) == length(parents_m)
    for i in 1:last_dec_id-1
        if !isempty(parents[i])
            nodes_m[3,i] = last_parent + one(NodeId)
            parents_m[last_parent + one(NodeId):last_parent + length(parents[i])] .= parents[i] 
            last_parent += length(parents[i])
            nodes_m[4,i] = last_parent
        else
            @assert i <= num_leafs "Only root and leaf nodes can have no parents: $i"
        end
    end
    
    return CatBitCircuit(layers, nodes_m, elements_m, parents_m, node2id)
end
