using LogicCircuits
using ProbabilisticCircuits
using SimpleWeightedGraphs: SimpleWeightedGraph, SimpleWeightedEdge
using LightGraphs
using MetaGraphs
using Graphs: num_vertices
using DataFrames
using DataStructures
using Random
using Distributions
using StatsFuns: logsumexp, logaddexp
using CUDA: CUDA, @cuda


# Return the indices of the candidate trees to keep, according to `tree_sample_type`
# ("random" or "fixed_interval"); throws `ArgumentError` on an invalid configuration.
function _select_tree_indices(num_tree_candidates::Integer, num_trees::Integer,
                              tree_sample_type::AbstractString)
    num_trees >= 1 || throw(ArgumentError("num_trees must be at least 1, got $num_trees"))
    if tree_sample_type == "random"
        num_trees <= num_tree_candidates || throw(ArgumentError(
            "cannot sample $num_trees distinct trees from $num_tree_candidates candidates"))
        return randperm(num_tree_candidates)[1:num_trees]
    elseif tree_sample_type == "fixed_interval"
        # `LinRange(1, n, 1)` throws for `n != 1` ("endpoints differ"), so a single tree
        # is special-cased to the first candidate.
        num_trees == 1 && return [1]
        return round.(Int, LinRange(1, num_tree_candidates, num_trees))
    else
        throw(ArgumentError("unknown tree_sample_type \"$tree_sample_type\"; " *
                            "expected \"random\" or \"fixed_interval\""))
    end
end

"""
    dual_hidden_chow_liu_circuit(num_vars, num_cats = 2, num_channels = 3, ::Type{T} = ProbCircuit;
                                 data, num_hidden_cats = (16, 8), num_trees = 1,
                                 num_tree_candidates = 1, tree_sample_type = "fixed_interval",
                                 dropout_prob = 0.0, data_for_mi = nothing, debug = false)

Learn `num_tree_candidates` Chow-Liu trees from `data` and keep `num_trees` of them. Build
one dual-hidden circuit per kept tree over a shared set of per-channel categorical leaves,
and return a single sum node over the root children of all those circuits.

# Keywords
- `data`: training data passed to `chow_liu_tree`.
- `num_hidden_cats`: `(outer, inner)` hidden-category counts, forwarded to the per-tree builder.
- `num_trees`, `num_tree_candidates`: number of trees kept / learned.
- `tree_sample_type`: `"fixed_interval"` keeps candidates at rounded, evenly spaced indices
  from `1` to `num_tree_candidates` (only the first candidate when `num_trees == 1`);
  `"random"` keeps a random subset of `num_trees` distinct candidates.
- `dropout_prob`, `data_for_mi`: forwarded to `chow_liu_tree`.
- `debug`: forwarded to the per-tree builder.

# Throws
- `ArgumentError` if `tree_sample_type` is unknown, `num_trees < 1`, or `"random"` sampling
  is asked for more trees than there are candidates.
"""
function dual_hidden_chow_liu_circuit(num_vars, num_cats = 2, num_channels = 3, ::Type{T} = ProbCircuit; data::DataFrame, 
                                      num_hidden_cats::Tuple{Integer,Integer} = (16, 8), num_trees::Integer = 1, 
                                      num_tree_candidates::Integer = 1, tree_sample_type::String = "fixed_interval",
                                      dropout_prob::Float64 = 0.0, data_for_mi = nothing, debug = false) where T
    # Chow-Liu tree (CLT) candidates learned from the data
    clts = chow_liu_tree(data, num_vars, num_cats; num_trees = num_tree_candidates, dropout_prob, data_for_mi)

    # Keep `num_trees` of the `num_tree_candidates` candidates
    clts = clts[_select_tree_indices(num_tree_candidates, num_trees, tree_sample_type)]

    # One leaf matrix per channel; each channel's variables are shifted past the previous
    # channels' encoded variables.
    observed_leafs = map(1:num_channels) do c_idx
        categorical_leafs(num_vars, num_cats, T; var_idx_offset = (c_idx - 1) * num_vars * num_bits_for_cats(num_cats))
    end

    # Every tree reuses the same observed leaves; only the hidden/tree structure differs.
    circuits = map(clts) do clt
        dual_hidden_chow_liu_circuit(clt, num_vars, num_cats, observed_leafs, num_channels, T;
                                     data, num_hidden_cats, debug)
    end

    # Merge the per-tree circuits by gathering the children of each root sum node under
    # a single top-level sum node.
    children = T[]
    for circuit in circuits
        append!(children, circuit.children)
    end
    return summate(children...)
end

"""
    dual_hidden_chow_liu_circuit(clt::CLT, num_vars, num_cats, observed_leafs, num_channels,
                                 ::Type{T} = ProbCircuit; data, var_idx_offset = 0,
                                 num_hidden_cats = (16, 8), debug = false)

Build one "dual hidden" circuit following the structure of the Chow-Liu tree `clt`.

Each variable `v` gets `num_hidden_cats[1]` "joined leaf" sub-circuits. Each of these is a sum
over `num_hidden_cats[2]` products, and each product multiplies one sum node per channel
over `observed_leafs[c][v, :]`. The tree nodes are then processed in the order returned by
`bottom_up_traverse_node_seq`, with children before their parent, because a node reads its
children's `:circuits`. An inner node's circuits are sums over products of its (merged)
child circuits with its own joined leaves.

# Arguments
- `clt`: the tree. It is mutated: a `:circuits` property is set on every traversed node.
- `observed_leafs`: one leaf matrix per channel, indexed as `[var_idx, :]`.
- `num_channels`: number of channels; must be greater than 1.

# Keywords
- `num_hidden_cats`: `(outer, inner)` hidden-category counts.
- `debug`: forwarded to `bottom_up_traverse_node_seq`.
- `data`, `var_idx_offset` (and the positional `num_cats`, `T`): accepted for interface
  compatibility; not read by this method.

# Returns
The first circuit stored on the last node of the traversal (presumably the tree root).

# Throws
- `ArgumentError` if `num_channels <= 1`.
"""
function dual_hidden_chow_liu_circuit(clt::CLT, num_vars, num_cats, observed_leafs, num_channels, 
                                      ::Type{T} = ProbCircuit; data::DataFrame, var_idx_offset::Integer = 0, 
                                      num_hidden_cats::Tuple{Integer,Integer} = (16, 8), debug = false) where T
    num_channels > 1 ||
        throw(ArgumentError("num_channels must be greater than 1, got $num_channels"))
    num_outer_cats, num_inner_cats = num_hidden_cats

    # `joined_leafs[i, j]` is the distribution of the hidden variable of node `i` taking value
    # `j`: a mixture of `num_inner_cats` products, each combining one sum node per channel
    # over the observed leaves of variable `i`.
    gen_joined_leaf(var_idx) = begin
        pcs = map(1:num_inner_cats) do _
            channel_leafs = map(1:num_channels) do c_idx
                summate(observed_leafs[c_idx][var_idx, :])
            end
            multiply(channel_leafs...)
        end
        summate(pcs...)
    end
    joined_leafs = [gen_joined_leaf(var_idx) for var_idx in 1:num_vars, _ in 1:num_outer_cats]

    # Construct the CLT circuit bottom-up
    node_seq = bottom_up_traverse_node_seq(clt; debug)
    for curr_node in node_seq
        out_neighbors = outneighbors(clt, curr_node)

        # `:circuits` of a leaf CLT node holds the marginals Pr(X); for an inner node (var Y)
        # with children X_1, ..., X_k it holds the joints Pr(Y)Pr(X_1|Y)...Pr(X_k|Y).
        circuits = if isempty(out_neighbors)
            joined_leafs[curr_node, :]
        else
            child_sets = [get_prop(clt, child_node, :circuits) for child_node in out_neighbors]
            # Joint over the children, Pr(X_1)...Pr(X_k), one entry per hidden category.
            child_circuits = if length(child_sets) > 1
                [summate(multiply([child_set[cat_idx] for child_set in child_sets]))
                 for cat_idx in 1:num_outer_cats]
            else
                only(child_sets)
            end
            # Pr(X_1)...Pr(X_k) -> Pr(Y)Pr(X_1|Y)...Pr(X_k|Y).
            # The products do not depend on the output category, so build them once and share
            # them; each sum node still gets its own children vector.
            elements = multiply.(child_circuits, joined_leafs[curr_node, :])
            [summate(copy(elements)) for _ in 1:num_outer_cats]
        end
        set_prop!(clt, curr_node, :circuits, circuits)
    end

    return get_prop(clt, node_seq[end], :circuits)[1] # A ProbCircuit node
end