using LogicCircuits
using ProbabilisticCircuits
using StatsFuns


"""
    convert_to_sd_pc(pc::ProbCircuit)

Rebuild `pc` bottom-up with `foldup_aggregate`, pairing every rebuilt node with a vtree
node (the `sd` in the name presumably stands for "structured-decomposable" -- confirm).

Product nodes with more than two children are folded into a chain of binary products;
every intermediate product in that chain is wrapped in a single-child sum node. A child
that is itself a product is wrapped in a single-child sum before being multiplied.
Vtree nodes are cached by a string key (`"(var)"` for leaves, `"(left|right)"` for
inner nodes) so that identical variable structures reuse one vtree node.

The input circuit is not modified: leaf and sum parameters are copied into the new nodes.

# Returns
A 4-tuple `(new_pc, vtree, pc2vtree, vtree2pc)`:
- `new_pc`: root of the rebuilt circuit.
- `vtree`: vtree node paired with that root.
- `pc2vtree::Dict{ProbCircuit,Vtree}`: vtree node for each node created here.
- `vtree2pc::Dict{Vtree,Vector{ProbCircuit}}`: created nodes grouped by vtree node.

# Throws
Raises an error (via `error`) when the circuit contains a constant node.
"""
function convert_to_sd_pc(pc::ProbCircuit)
    vtree_cache = Dict{String,Vtree}()
    vtree_str_cache = Dict{Vtree,String}()
    pc2vtree = Dict{ProbCircuit,Vtree}()
    vtree2pc = Dict{Vtree,Vector{ProbCircuit}}()

    # Record `node` under vtree node `v` in both lookup tables. Always appends, so
    # nodes registered earlier for the same vtree node are kept.
    register!(node, v) = begin
        pc2vtree[node] = v
        push!(get!(() -> ProbCircuit[], vtree2pc, v), node)
        nothing
    end

    # Cached vtree leaf for variable `var`.
    leaf_vtree(var) = begin
        key = "($(var))"
        get!(vtree_cache, key) do
            leaf = PlainVtreeLeafNode(Var(var))
            vtree_str_cache[leaf] = key
            leaf
        end
    end

    # Cached vtree inner node with children `left` and `right` (order matters).
    inner_vtree(left, right) = begin
        key = "($(vtree_str_cache[left])|$(vtree_str_cache[right]))"
        get!(vtree_cache, key) do
            inner = PlainVtreeInnerNode(left, right)
            vtree_str_cache[inner] = key
            inner
        end
    end

    # Number of variables below a vtree node.
    vtree_size(v) = (v isa PlainVtreeLeafNode) ? 1 : length(v.variables)

    # Wrap a product node in a single-child sum; any other node is returned as is.
    as_sum(node, v) = begin
        if is⋀gate(node)
            wrapped = summate(node)
            register!(wrapped, v)
            wrapped, v
        else
            node, v
        end
    end

    f_con(n)::Tuple{ProbCircuit,Vtree} = error("Do not support constant node.")

    f_lit(n)::Tuple{ProbCircuit,Vtree} = begin
        new_n = PlainProbCategoricalNode(
            n.variable, n.literal, n.num_cats, deepcopy(n.log_probs)
        )
        v = leaf_vtree(n.variable)
        register!(new_n, v)
        new_n, v
    end

    f_a(n, cns)::Tuple{ProbCircuit,Vtree} = begin
        first_node, first_vtree = cns[1]
        m, vm = as_sum(first_node, first_vtree)
        last_idx = length(cns) - 1
        for (idx, (cm, vcm)) in enumerate(cns[2:end])
            cm, vcm = as_sum(cm, vcm)
            # Keep the operand whose vtree covers more variables on the left.
            if vtree_size(vcm) > vtree_size(vm)
                m, cm = cm, m
                vm, vcm = vcm, vm
            end

            product = multiply(m, cm)
            vm = inner_vtree(vm, vcm)
            if idx < last_idx
                # Intermediate product: wrap it in a sum before multiplying further.
                m = summate(product)
                register!(m, vm)
                register!(m.children[1], vm)
            else
                # Final product of the chain is returned unwrapped.
                m = product
                register!(m, vm)
            end
        end
        m, vm
    end

    f_o(n, cns)::Tuple{ProbCircuit,Vtree} = begin
        new_n = summate(map(first, cns)...)
        # Copy so the rebuilt circuit does not share parameters with the input.
        new_n.log_probs = copy(n.log_probs)
        # A sum node takes the vtree node of its first child.
        v = pc2vtree[cns[1][1]]
        register!(new_n, v)
        new_n, v
    end

    new_pc, root_vtree = foldup_aggregate(
        pc, f_con, f_lit, f_a, f_o, Tuple{ProbCircuit,Vtree}
    )

    new_pc, root_vtree, pc2vtree, vtree2pc
end


"""
    pc_add_pseudo_category(pc::ProbCircuit; w = 1.0)

Rebuild `pc` so that every categorical leaf has one extra ("pseudo") category appended
after its existing ones.

The new category receives unnormalized probability mass `w / n.num_cats`, where
`n.num_cats` is the leaf's original category count; each leaf distribution is then
renormalized in log space. Sum-node parameters are copied unchanged and product nodes
are rebuilt from their converted children.

# Keywords
- `w`: mass assigned to the pseudo category before renormalization. Must be
  non-negative; `w = 0` gives the new category zero probability.

# Returns
The root of the rebuilt circuit.

# Throws
Raises an error (via `error`) when the circuit contains a constant node.
"""
function pc_add_pseudo_category(pc::ProbCircuit; w = 1.0)
    f_con(n)::ProbCircuit = error("Do not support constant node.")
    f_lit(n)::ProbCircuit = begin
        new_n = PlainProbCategoricalNode(n.variable, n.literal, n.num_cats + 1)
        new_n.log_probs[1:end-1] .= n.log_probs
        # `log_probs` lives in log space, so the pseudo mass has to be stored as a log.
        new_n.log_probs[end] = log(w / n.num_cats)
        new_n.log_probs .-= logsumexp(new_n.log_probs)
        new_n
    end
    f_a(n, cns)::ProbCircuit = multiply(cns...)
    f_o(n, cns)::ProbCircuit = begin
        new_n = summate(cns...)
        new_n.log_probs .= n.log_probs
        new_n
    end
    foldup_aggregate(pc, f_con, f_lit, f_a, f_o, ProbCircuit)
end


"""
    smooth_pc(pc::ProbCircuit; w = 1.0, leaf_only = true)

Return a rebuilt copy of `pc` with additively smoothed parameters.

For every categorical leaf, `w / n.num_cats` is added to each category probability and
the distribution is renormalized in log space. When `leaf_only` is `false`, the same
smoothing is applied to every sum node, using `w / length(n.log_probs)` as the added
mass. The input circuit's parameters are left untouched.

# Keywords
- `w`: total pseudo mass spread uniformly over the entries of each smoothed distribution.
- `leaf_only`: when `true` (default) sum-node parameters are copied without smoothing.

# Returns
The root of the rebuilt circuit.

# Throws
Raises an error (via `error`) when the circuit contains a constant node.
"""
function smooth_pc(pc::ProbCircuit; w = 1.0, leaf_only = true)
    # Add `mass` to every probability in `log_probs`, then renormalize, all in place.
    smooth!(log_probs, mass) = begin
        log_probs .= log.(exp.(log_probs) .+ mass)
        log_probs .-= logsumexp(log_probs)
        log_probs
    end

    f_con(n)::ProbCircuit = error("Do not support constant node.")
    f_lit(n)::ProbCircuit = begin
        # Copy first: smoothing is done in place and must not alter the input circuit.
        new_n = PlainProbCategoricalNode(
            n.variable, n.literal, n.num_cats, copy(n.log_probs)
        )
        smooth!(new_n.log_probs, w / n.num_cats)
        new_n
    end
    f_a(n, cns)::ProbCircuit = multiply(cns...)
    f_o(n, cns)::ProbCircuit = begin
        new_n = summate(cns...)
        new_n.log_probs .= n.log_probs
        if !leaf_only
            smooth!(new_n.log_probs, w / length(n.log_probs))
        end
        new_n
    end
    foldup_aggregate(pc, f_con, f_lit, f_a, f_o, ProbCircuit)
end