using LogicCircuits
using ProbabilisticCircuits
using StatsFuns: logsumexp


"""
    prune_hclt(pc::ProbCircuit; atol = 1e-3, prune_structure = false)

Return a new circuit, rebuilt from `pc` with `foldup_aggregate`, in which low-weight
edges of the sum nodes are pruned.

For every sum node except the root `pc` itself, child `i` is kept only if its weight
relative to the node's largest weight, `exp(log_probs[i] - maximum(log_probs))`, exceeds
`atol`. The surviving log-weights are then renormalized with `logsumexp` so that they sum
to one in probability space. The root's weights are copied unchanged. Literal nodes are
rebuilt as `PlainProbCategoricalNode`s with a copy of their `log_probs`.

# Keywords
- `atol = 1e-3`: relative weight threshold; must be `< 1` so that the most probable child
  of every sum node always survives.
- `prune_structure = false`: if `true`, pruned children are removed from the sum node;
  if `false`, all children are kept and pruned edges get log-weight `-Inf`.

# Throws
- `ArgumentError` if `atol >= 1` (or `atol` is `NaN`).
- `ErrorException` if the fold reaches a constant node (not supported).
"""
function prune_hclt(pc::ProbCircuit; atol = 1e-3, prune_structure = false)
    # With atol >= 1 even the heaviest child (relative weight exactly 1) would be dropped,
    # leaving a sum node with no children / all -Inf weights (NaN after renormalizing).
    atol < 1 || throw(ArgumentError("atol must be < 1, got $atol"))

    f_con(n)::ProbCircuit = error("prune_hclt: constant nodes are not supported")
    f_lit(n)::ProbCircuit = PlainProbCategoricalNode(
        n.variable, n.literal, n.num_cats, copy(n.log_probs))
    f_a(_, cns)::ProbCircuit = multiply(cns...)
    f_o(n, cns)::ProbCircuit = begin
        # The root sum node is rebuilt with its original weights and is never pruned.
        if n === pc
            m = summate(cns...)
            m.log_probs .= n.log_probs
            return m
        end
        # Relative weights computed in log space: shifting by the max before `exp`
        # avoids 0/0 = NaN when every weight of this node underflows in `exp`.
        keep = exp.(n.log_probs .- maximum(n.log_probs)) .> atol
        if prune_structure
            idx = findall(keep)
            @assert !isempty(idx) "sum node has no child above the pruning threshold"
            log_probs = n.log_probs[idx]
            m = summate(cns[idx]...)
        else
            # Keep the structure; a log-weight of -Inf is probability zero.
            log_probs = ifelse.(keep, n.log_probs, typemin(Float64))
            m = summate(cns...)
        end
        log_probs .-= logsumexp(log_probs)
        m.log_probs .= log_probs
        return m
    end
    return foldup_aggregate(pc, f_con, f_lit, f_a, f_o, ProbCircuit)
end