using LogicCircuits
using ProbabilisticCircuits


## Extend plain_prob_nodes.jl

"A probabilistic categorical node"
mutable struct PlainProbCategoricalNode <: PlainProbLeafNode
    # Variable this leaf is defined over.
    variable::Var
    # A universal leaf node identifier.
    literal::Lit
    # Number of categories the leaf distributes over.
    num_cats::UInt32
    # One log-probability per category (normalization is not checked here).
    log_probs::Vector{Float64}

    # Uniform leaf: every category gets log(1 / n_cats).
    function PlainProbCategoricalNode(v, l, n_cats)
        uniform_probs = fill(1.0 / n_cats, n_cats)
        return new(v, l, n_cats, log.(uniform_probs))
    end

    # Leaf with caller-supplied log-probabilities; `p` is stored as given
    # (after conversion to the field type), not copied.
    function PlainProbCategoricalNode(v, l, n_cats, p)
        return new(v, l, n_cats, p)
    end
end

import ProbabilisticCircuits.num_parameters

"""
    num_parameters_cat(c)

Count the parameters of a circuit whose leaves may be categorical: the count
reported by `num_parameters` plus one log-probability per category of every
categorical leaf. Called on a single categorical leaf, return its number of
categories. The result is always an `Int`.
"""
function num_parameters_cat end

@inline num_parameters_cat(c::PlainProbCategoricalNode) = Int(c.num_cats)
@inline num_parameters_cat(c::ProbCircuit) =
    num_parameters(c) + mapreduce(_leaf_parameters_cat, +, literal_nodes(c); init = 0)

# Parameters stored in one leaf. Categorical leaves hold one log-probability per
# category. Any other leaf kind counts as parameter-free here. Routing such a leaf
# through `num_parameters_cat` would hit the `ProbCircuit` method again and recurse
# without end (assuming `literal_nodes` of a leaf contains the leaf itself).
# The `init` in the `mapreduce` above makes an empty leaf list count as 0 instead of
# throwing. Converting to `Int` keeps the sum signed (a `UInt32` sum widens to `UInt64`).
_leaf_parameters_cat(n::PlainProbCategoricalNode) = Int(n.num_cats)
_leaf_parameters_cat(_) = 0

import ProbabilisticCircuits.GateType

# Categorical leaves report themselves as literal gates, so gate-type based
# traversals (e.g. `foldup_aggregate`, whose `f_lit` callbacks in this file read
# `num_cats`/`log_probs`, and `literal_nodes`) treat them as leaves.
@inline function GateType(::Type{<:PlainProbCategoricalNode})
    return LiteralGate()
end

import ProbabilisticCircuits.compile

# Build a categorical leaf over `n_cats` categories of variable `v`, identified
# by literal `l`, using the uniform-distribution constructor. `n_cats` must fit
# in a `UInt32` (otherwise the conversion throws `InexactError`).
function compile(::Type{<:PlainProbCircuit}, v::Var, l::Lit, n_cats::Integer)
    num_cats = UInt32(n_cats)
    return PlainProbCategoricalNode(v, l, num_cats)
end

"""
    compile_cat(::Type{<:ProbCircuit}, circuit::ProbCircuit)

Rebuild `circuit`, whose leaves are categorical nodes, as a new probabilistic
circuit.

- Each categorical leaf is recompiled with `compile(ProbCircuit, ...)` and gets the
  original leaf's `log_probs` values.
- Product nodes are rebuilt with `multiply`.
- Sum nodes are rebuilt with `summate`, and their `log_probs` values are copied over.

The first argument only takes part in dispatch; it is not otherwise used.

Throws an `ErrorException` if `circuit` contains a constant node.
"""
function compile_cat(::Type{<:ProbCircuit}, circuit::ProbCircuit)
    f_con(_) = error("compile_cat: constant nodes are not supported")
    function f_lit(n)
        pc = compile(ProbCircuit, n.variable, n.literal, n.num_cats)
        # Broadcast assignment copies the values into the new node's own vector.
        pc.log_probs .= n.log_probs
        return pc
    end
    f_a(_, cns) = multiply(cns...)
    function f_o(n, cns)
        pc = summate(cns...)
        pc.log_probs .= n.log_probs
        return pc
    end
    return foldup_aggregate(circuit, f_con, f_lit, f_a, f_o, ProbCircuit)
end

"""
    compile_cat_to_bin(pc_cat::ProbCircuit)

Convert a circuit with categorical leaves into one built from the leaves returned
by `categorical_leafs(num_vars, num_cats)`.

- A categorical leaf over variable `v` with `k` categories becomes a sum node over
  `leaves[v, 1:k]` whose `log_probs` are the leaf's `log_probs`.
- Product nodes are rebuilt with `multiply`.
- Sum nodes are rebuilt with `summate`, and their `log_probs` values are copied over.

Throws an `ErrorException` if `pc_cat` contains a constant node.
"""
function compile_cat_to_bin(pc_cat::ProbCircuit)
    # NOTE(review): this two-argument `categorical_leafs` call does not match the
    # methods that take a node type as their first argument; it is assumed to be
    # defined elsewhere and to return a (num_vars × num_cats) matrix with one leaf
    # per category in column order — confirm.
    leaves = categorical_leafs(num_cat_vars(pc_cat), num_categories(pc_cat))
    f_con(_) = error("compile_cat_to_bin: constant nodes are not supported")
    function f_lit(n)
        # `num_categories` is the maximum over all leaves, so the matrix may be wider
        # than this leaf. Take only its own categories so the children line up with
        # `n.log_probs` (identical to taking the full row when all leaves agree).
        pc = summate(leaves[n.variable, 1:n.num_cats]...)
        pc.log_probs .= n.log_probs
        return pc
    end
    f_a(_, cns) = multiply(cns...)
    function f_o(n, cns)
        pc = summate(cns...)
        pc.log_probs .= n.log_probs
        return pc
    end
    return foldup_aggregate(pc_cat, f_con, f_lit, f_a, f_o, ProbCircuit)
end

"""
    clone(pc::ProbCircuit)

Return a structural copy of `pc`, a circuit whose leaves are categorical nodes.

- Leaves are rebuilt with the `PlainProbCategoricalNode` constructor.
- Product nodes are rebuilt with `multiply`, and sum nodes with `summate`.
- Leaf and sum-node `log_probs` are copied, so later parameter updates on one
  circuit do not touch the other.

Throws an `ErrorException` if `pc` contains a constant node.
"""
function clone(pc::ProbCircuit)
    f_con(_) = error("clone: constant nodes are not supported")
    # A flat `Vector{Float64}` only needs `copy`, not `deepcopy`. The copy is still
    # required because the 4-argument constructor stores the vector it receives.
    f_lit(n) = PlainProbCategoricalNode(n.variable, n.literal, n.num_cats,
                                        copy(n.log_probs))
    f_a(_, cns) = multiply(cns...)
    function f_o(n, cns)
        pc = summate(cns...)
        pc.log_probs .= n.log_probs
        return pc
    end
    return foldup_aggregate(pc, f_con, f_lit, f_a, f_o, ProbCircuit)
end

"Generate categorical leaf nodes"
# Vector form: leaf `v` (1-based) gets variable `v + var_idx_offset` and literal
# `v + lit_idx_offset`; every leaf has `n_cats` categories.
function categorical_leafs(::Type{T}, num_leafs::Integer, n_cats::Integer;
                           var_idx_offset = 0, lit_idx_offset = 0) where {T<:LogicCircuit}
    return [compile(T, Var(v + var_idx_offset), Lit(v + lit_idx_offset), n_cats)
            for v in 1:num_leafs]
end

"""
    categorical_leafs(T, num_leafs, num_groups, n_cats; var_idx_offset = 0, lit_idx_offset = 0)

Return a `num_leafs × num_groups` matrix of categorical leaves of circuit type `T`.
Entry `(v, g)` has variable `v + var_idx_offset` and literal
`v + (g - 1) * num_leafs + lit_idx_offset`. Every group therefore covers the same
variables with distinct literals.
"""
function categorical_leafs(::Type{T}, num_leafs::Integer, num_groups::Integer, n_cats::Integer;
                           var_idx_offset = 0, lit_idx_offset = 0) where {T<:LogicCircuit}
    return [compile(T, Var(v + var_idx_offset),
                    Lit(v + (g - 1) * num_leafs + lit_idx_offset), n_cats)
            for v in 1:num_leafs, g in 1:num_groups]
end

"""
    num_categories(pc::ProbCircuit)

Return, as a `UInt32`, the largest `num_cats` among the leaves of `pc`. Leaves are
expected to be categorical nodes. Constant nodes and inner nodes without children
contribute 0.
"""
function num_categories(pc::ProbCircuit)
    f_con(_) = UInt32(0)
    f_lit(n) = n.num_cats
    # `init` keeps the fold defined for a childless inner node, where `maximum`
    # would throw; 0 matches the value already used for constants.
    f_inner(_, cns) = reduce(max, cns; init = UInt32(0))
    return foldup_aggregate(pc, f_con, f_lit, f_inner, f_inner, UInt32)
end

"""
    cat_vars(pc::ProbCircuit)

Return a `BitSet` holding the `variable` of every leaf of `pc`. Constant nodes
contribute no variables.
"""
function cat_vars(pc::ProbCircuit)
    f_con(_) = BitSet()
    f_lit(n) = BitSet((n.variable,))
    # Accumulate into one fresh set per inner node. This avoids splatting the child
    # list, and a childless node yields an empty set; `union()` has no zero-argument
    # method. Child sets are only read, never mutated.
    f_inner(_, cns) = foldl(union!, cns; init = BitSet())
    return foldup_aggregate(pc, f_con, f_lit, f_inner, f_inner, BitSet)
end

"Return the number of distinct variables collected by `cat_vars` from the leaves of `pc`."
num_cat_vars(pc::ProbCircuit) = pc |> cat_vars |> length