using LogicCircuits
using ProbabilisticCircuits
using SimpleWeightedGraphs: SimpleWeightedGraph, SimpleWeightedEdge
using LightGraphs
using MetaGraphs
using Graphs: num_vertices
using DataFrames
using DataStructures
using Random
using Distributions
using StatsFuns: logsumexp, logaddexp
using CUDA: CUDA, @cuda


"""
    chow_liu_circuit(num_vars, num_cats = 2, ::Type{T} = ProbCircuit;
                     data::DataFrame, data_for_mi = nothing)

Learn a single Chow-Liu tree from `data` and compile it into a circuit whose leaves
are built with circuit type `T`.

# Arguments
- `num_vars`: number of variables, passed to `chow_liu_tree` and `categorical_leafs`.
- `num_cats`: number of categories, passed to `chow_liu_tree` and `categorical_leafs`.
- `T`: circuit type used for the leaves; also forwarded to the compilation step.

# Keywords
- `data`: dataset handed to `chow_liu_tree` and to the compilation step.
- `data_for_mi`: forwarded unchanged to `chow_liu_tree` (presumably an alternative
  dataset for the mutual-information estimate — confirm against `chow_liu_tree`).

# Returns
The circuit produced by `chow_liu_circuit(clt, num_vars, num_cats, observed_leafs, T; data)`.
"""
function chow_liu_circuit(num_vars, num_cats = 2, ::Type{T} = ProbCircuit; data::DataFrame,
                          data_for_mi = nothing) where T
    # Request exactly one tree with dropout disabled and keep that single result.
    clt = chow_liu_tree(data, num_vars, num_cats; num_trees = 1, dropout_prob = 0.0,
                        data_for_mi)[1]

    observed_leafs = categorical_leafs(num_vars, num_cats, T)

    # Forward `T` so the compilation step sees the same circuit type as the leaves;
    # omitting it would silently fall back to that method's `ProbCircuit` default.
    return chow_liu_circuit(clt, num_vars, num_cats, observed_leafs, T; data)
end

"""
    chow_liu_circuit(clt::CLT, num_vars, num_cats, observed_leafs,
                     ::Type{T} = ProbCircuit; data::DataFrame)

Compile the Chow-Liu tree `clt` into a circuit, visiting vertices in the order
returned by `bottom_up_traverse_node_seq(clt)`.

Each visited vertex `v` receives a `:circuits` property via `set_prop!` (so `clt`
is mutated). The property is a vector of single-child sum nodes:
- leaf vertex: one sum node wrapping each entry of `observed_leafs[v, :]`, i.e. Pr(X);
- inner vertex Y with children X_1..X_k: for each category `c` in `1:num_cats`, a sum
  node wrapping the product of `v`'s own leaf for `c` with one new sum node per
  child, i.e. Pr(Y = c) Pr(X_1 | Y = c) ... Pr(X_k | Y = c).

Returns a sum node over the unwrapped `:circuits` of the last vertex in that order.

`num_vars`, `T` and `data` are accepted but not read by this method.
"""
function chow_liu_circuit(clt::CLT, num_vars, num_cats, observed_leafs,
                          ::Type{T} = ProbCircuit; data::DataFrame) where T
    # Every `:circuits` entry is a sum node with a single child; strip that wrapper
    # so the parent can regroup the inner nodes under sum nodes of its own.
    _unwrap_sums(cs) = is⋁gate(cs[1]) ? [children(n)[1] for n in cs] : cs

    order = bottom_up_traverse_node_seq(clt)
    for v in order
        kids = outneighbors(clt, v)

        if isempty(kids)
            # Leaf vertex: its marginal, one wrapped leaf per category.
            set_prop!(clt, v, :circuits, summate.(observed_leafs[v, :]))
            continue
        end

        # Inner vertex: factors[c] gathers the factors of the joint for category c,
        # starting with the vertex's own (wrapped) leaf.
        factors = [[summate(observed_leafs[v, c])] for c in 1:num_cats]
        for kid in kids
            kid_cs = _unwrap_sums(get_prop(clt, kid, :circuits))
            # A separate sum node per category, presumably so each conditional
            # Pr(X | Y = c) carries its own parameters.
            for fs in factors
                push!(fs, summate(kid_cs...))
            end
        end

        joint = [summate(multiply(fs...)) for fs in factors]
        set_prop!(clt, v, :circuits, joint)
    end

    root_cs = _unwrap_sums(get_prop(clt, order[end], :circuits))
    return summate(root_cs...)
end