using LogicCircuits
using ProbabilisticCircuits
using CUDA


"A `BitCircuit` with parameters attached to the elements"
mutable struct CatParamBitCircuit{V,M,W,S,T}
    bitcircuit::CatBitCircuit{V,M}
    params::W
    cat_params::S
    lit_to_var::T
end

"""
    CatParamBitCircuit(pc::ProbCircuit, num_features)

Flatten the probabilistic circuit `pc` into a `CatBitCircuit` and gather its
parameters into a `CatParamBitCircuit`.

# Returns
A `CatParamBitCircuit` whose
- `params` holds, in the order `CatBitCircuit` reports decision nodes through
  its `on_decision` callback, the `log_probs` of each decision node, or a single
  `0.0` (log 1) for a decision slot that has no corresponding PC node;
- `cat_params` is a `num_features × num_categories(pc)` matrix whose row `l`
  holds the `log_probs` of the categorical leaf with literal `l`. Entries that no
  leaf writes (rows of absent literals, columns past a leaf's `num_cats`) keep
  the uniform value `log(1 / num_categories(pc))`;
- `lit_to_var` maps each literal to its leaf's variable, defaulting to `Var(1)`.

# Throws
- `BoundsError` if a categorical leaf's literal exceeds `num_features`, or its
  `num_cats` exceeds `num_categories(pc)`.
"""
function CatParamBitCircuit(pc::ProbCircuit, num_features)
    ncats = num_categories(pc)

    logprobs = Float64[]
    sizehint!(logprobs, num_edges(pc))

    # Called by `CatBitCircuit` for each decision node it lays out.
    function on_decision(n, cs, layer_id, decision_id, first_element, last_element)
        if isnothing(n)
            # This decision slot is not part of the PC: store one 0.0 entry.
            push!(logprobs, 0.0)
        else
            append!(logprobs, n.log_probs)
        end
    end
    bc = CatBitCircuit(pc, num_features; on_decision)

    cat_probs = fill(log(1.0 / ncats), num_features, ncats)
    lit_to_var = fill(Var(1), num_features)
    foreach(pc) do n
        if n isa PlainProbCategoricalNode
            # Indices come from circuit data, not from `axes`, so bounds checks
            # stay on: a too-small `num_features` must error, not corrupt memory.
            cat_probs[n.literal, 1:n.num_cats] .= n.log_probs
            lit_to_var[n.literal] = n.variable
        end
    end

    return CatParamBitCircuit(bc, logprobs, cat_probs, lit_to_var)
end

import ProbabilisticCircuits: num_nodes, num_elements, num_features, num_leafs, nodes, elements

# Structural queries answered directly by a `CatBitCircuit`.
num_nodes(bc::CatBitCircuit) = size(bc.nodes, 2)        # one column per node
num_elements(bc::CatBitCircuit) = size(bc.elements, 2)  # one column per element
# The two counts below go through integer-argument methods defined elsewhere.
# NOTE(review): confirm `num_leafs(::Integer)` is meant to receive `length(bc.layers[1])`.
num_leafs(bc::CatBitCircuit) = num_leafs(length(bc.layers[1]))
num_features(bc::CatBitCircuit) = num_features(num_leafs(bc))
nodes(bc::CatBitCircuit) = bc.nodes
elements(bc::CatBitCircuit) = bc.elements

# A `CatParamBitCircuit` answers every structural query by forwarding it to the
# `CatBitCircuit` it wraps.
num_nodes(pbc::CatParamBitCircuit) = num_nodes(pbc.bitcircuit)
num_elements(pbc::CatParamBitCircuit) = num_elements(pbc.bitcircuit)
num_leafs(pbc::CatParamBitCircuit) = num_leafs(pbc.bitcircuit)
num_features(pbc::CatParamBitCircuit) = num_features(pbc.bitcircuit)
nodes(pbc::CatParamBitCircuit) = nodes(pbc.bitcircuit)
elements(pbc::CatParamBitCircuit) = elements(pbc.bitcircuit)

import ProbabilisticCircuits: to_gpu, to_cpu, isgpu #extend

# Rebuild a `CatParamBitCircuit` by applying `f` to each of its four components,
# in field order.
_map_param_components(f, pbc::CatParamBitCircuit) = CatParamBitCircuit(
    f(pbc.bitcircuit), f(pbc.params), f(pbc.cat_params), f(pbc.lit_to_var))

# Move the circuit structure and every parameter array to the GPU.
to_gpu(pbc::CatParamBitCircuit) = _map_param_components(to_gpu, pbc)

# Move the circuit structure and every parameter array back to the CPU.
to_cpu(pbc::CatParamBitCircuit) = _map_param_components(to_cpu, pbc)

# True only when the structure and all parameter arrays are on the GPU; checked
# in field order, stopping at the first component that is not.
isgpu(pbc::CatParamBitCircuit) =
    all(isgpu, (pbc.bitcircuit, pbc.params, pbc.cat_params, pbc.lit_to_var))

# Rebuild a `CatBitCircuit`, applying `f` to every layer and to the `nodes`,
# `elements` and `parents` arrays; `node2id` is carried over unchanged.
_map_bitcircuit_arrays(f, bc::CatBitCircuit) = CatBitCircuit(
    map(f, bc.layers), f(bc.nodes), f(bc.elements), f(bc.parents), bc.node2id)

to_gpu(bc::CatBitCircuit) = _map_bitcircuit_arrays(to_gpu, bc)
to_cpu(bc::CatBitCircuit) = _map_bitcircuit_arrays(to_cpu, bc)

# The device is read off the type parameters: both `CuArray` means GPU, both
# `Array` means CPU (no method for mixed storage is defined here).
isgpu(::CatBitCircuit{V,M}) where {V<:CuArray,M<:CuArray} = true
isgpu(::CatBitCircuit{V,M}) where {V<:Array,M<:Array} = false