using LogicCircuits
using ProbabilisticCircuits
using CUDA
using StatsFuns: logaddexp


"A `BitCircuit` with parameters attached to the elements"
mutable struct CompressParamBitCircuit{V,M,W,S,T,F}
    # Structural part of the circuit (layers, nodes, elements, ...).
    bitcircuit::CompressBitCircuit{V,M,F}
    # Flat collection of log-parameters; the constructor in this file fills it by
    # appending each decision's `log_probs` (or a single `0.0` for a decision that is
    # not part of the PC).
    params::W
    # Log-probabilities of the categorical leaves. The constructor in this file builds a
    # `num_features × num_categories` matrix whose rows are indexed by a node's `literal`.
    cat_params::S
    # Running `logaddexp` of `cat_params` along its second dimension; shares the
    # container type `S` with `cat_params`.
    cumulative_cat_params::S
    # Lookup from a categorical node's `literal` to its `variable`.
    lit_to_var::T
end

"""
    CompressParamBitCircuit(pc::ProbCircuit)

Convenience constructor: forwards to the two-argument method, passing
`num_variables(pc)` as the feature count.
"""
CompressParamBitCircuit(pc::ProbCircuit) = CompressParamBitCircuit(pc, num_variables(pc))

"""
    CompressParamBitCircuit(pc::ProbCircuit, num_features)

Flatten `pc` into a `CompressBitCircuit` and gather its parameters into plain arrays.

The fields of the returned object are filled as follows:
- `params`: log-parameters appended in the order in which `CompressBitCircuit` invokes
  the `on_decision` callback. A decision that is not backed by a node of `pc`
  contributes a single `0.0`; otherwise the node's `log_probs` are appended.
- `cat_params`: a `num_features × num_categories(pc)` matrix. Every entry starts at the
  uniform value `log(1 / num_categories(pc))`; for each `PlainProbCategoricalNode` `n`
  in `pc`, row `n.literal`, columns `1:n.num_cats`, is overwritten with `n.log_probs`.
- `cumulative_cat_params`: running `logaddexp` of `cat_params` along each row.
- `lit_to_var`: `n.variable` stored at index `n.literal` for every categorical node;
  entries that are never written stay at `Var(1)`.

Throws a `BoundsError` if a categorical node refers to a literal or category outside the
allocated `num_features × num_categories(pc)` range.
"""
function CompressParamBitCircuit(pc::ProbCircuit, num_features)
    # Query the category count once instead of on every use.
    n_cats = num_categories(pc)

    logprobs = Vector{Float64}()
    sizehint!(logprobs, num_edges(pc))

    # Callback handed to `CompressBitCircuit`; appends the parameters of one decision.
    function on_decision(n, cs, layer_id, decision_id, first_element, last_element)
        if isnothing(n) # this decision node is not part of the PC
            push!(logprobs, 0.0)
        else
            append!(logprobs, n.log_probs)
        end
    end
    bc = CompressBitCircuit(pc, num_features; on_decision)

    # Defaults: uniform categorical distribution, every literal mapped to variable 1.
    cat_probs = fill(log(1.0 / n_cats), num_features, n_cats)
    lit_to_var = fill(Var(1), num_features)

    foreach(pc) do n
        if n isa PlainProbCategoricalNode
            # These indices come from the circuit, not from the array shapes, so bounds
            # checking stays enabled: a mismatch raises instead of corrupting memory.
            cat_probs[n.literal, 1:n.num_cats] .= n.log_probs
            lit_to_var[n.literal] = n.variable
        end
    end

    # Column j holds logaddexp over columns 1..j of `cat_probs`.
    # Preallocating keeps the container type identical to `cat_probs`.
    cumulative_cat_probs = similar(cat_probs)
    accumulate!(logaddexp, cumulative_cat_probs, cat_probs; dims = 2)

    return CompressParamBitCircuit(bc, logprobs, cat_probs, cumulative_cat_probs, lit_to_var)
end

import ProbabilisticCircuits: num_nodes, num_elements, num_features, num_leafs, nodes, elements, num_variables

"Node count of the wrapped `bitcircuit`."
function num_nodes(c::CompressParamBitCircuit)
    return num_nodes(c.bitcircuit)
end

"Size of the second dimension of `c.nodes`."
function num_nodes(c::CompressBitCircuit)
    return size(c.nodes, 2)
end
"Element count of the wrapped `bitcircuit`."
function num_elements(c::CompressParamBitCircuit)
    return num_elements(c.bitcircuit)
end

"Size of the second dimension of `c.elements`."
function num_elements(c::CompressBitCircuit)
    return size(c.elements, 2)
end
"Feature count of the wrapped `bitcircuit`."
function num_features(c::CompressParamBitCircuit)
    return num_features(c.bitcircuit)
end

# Feature count derived from the leaf count.
# NOTE(review): this calls `num_features` on the value returned by `num_leafs(c)`;
# no method accepting that value is defined in this file -- confirm one exists elsewhere.
function num_features(c::CompressBitCircuit)
    leaf_count = num_leafs(c)
    return num_features(leaf_count)
end
"Leaf count of the wrapped `bitcircuit`."
function num_leafs(c::CompressParamBitCircuit)
    return num_leafs(c.bitcircuit)
end

# Adds up, layer by layer and starting from 0, `num_leafs` applied to the length of
# the first entry of each layer.
# NOTE(review): the `num_leafs` method taking a length is not defined in this file --
# confirm it exists elsewhere.
function num_leafs(c::CompressBitCircuit)
    leafs_in(layer) = num_leafs(length(layer[1]))
    # Strict left-to-right fold, same accumulation order as an explicit loop.
    return mapfoldl(leafs_in, +, c.layers; init = 0)
end

"Node storage of the wrapped `bitcircuit`."
function nodes(c::CompressParamBitCircuit)
    return nodes(c.bitcircuit)
end

"The `nodes` field of `c`."
function nodes(c::CompressBitCircuit)
    return c.nodes
end
"Element storage of the wrapped `bitcircuit`."
function elements(c::CompressParamBitCircuit)
    return elements(c.bitcircuit)
end

"The `elements` field of `c`."
function elements(c::CompressBitCircuit)
    return c.elements
end

"Variable count of the wrapped `bitcircuit`."
function num_variables(c::CompressParamBitCircuit)
    return num_variables(c.bitcircuit)
end

"Length of `c.var_order`."
function num_variables(c::CompressBitCircuit)
    return length(c.var_order)
end

"Size of the second dimension of `c.cat_params`."
function num_categories(c::CompressParamBitCircuit)
    return size(c.cat_params, 2)
end

import ProbabilisticCircuits: to_gpu, to_cpu, isgpu #extend

"Build a new `CompressParamBitCircuit` with `to_gpu` applied to every field of `c`."
function to_gpu(c::CompressParamBitCircuit)
    bitcircuit = to_gpu(c.bitcircuit)
    params = to_gpu(c.params)
    cat_params = to_gpu(c.cat_params)
    cumulative_cat_params = to_gpu(c.cumulative_cat_params)
    lit_to_var = to_gpu(c.lit_to_var)
    return CompressParamBitCircuit(
        bitcircuit, params, cat_params, cumulative_cat_params, lit_to_var
    )
end

"Build a new `CompressParamBitCircuit` with `to_cpu` applied to every field of `c`."
function to_cpu(c::CompressParamBitCircuit)
    bitcircuit = to_cpu(c.bitcircuit)
    params = to_cpu(c.params)
    cat_params = to_cpu(c.cat_params)
    cumulative_cat_params = to_cpu(c.cumulative_cat_params)
    lit_to_var = to_cpu(c.lit_to_var)
    return CompressParamBitCircuit(
        bitcircuit, params, cat_params, cumulative_cat_params, lit_to_var
    )
end

"""
`true` only when `isgpu` holds for every field of `c`. Fields are checked in declaration
order and checking stops at the first one that reports `false`.
"""
function isgpu(c::CompressParamBitCircuit)
    isgpu(c.bitcircuit) || return false
    isgpu(c.params) || return false
    isgpu(c.cat_params) || return false
    isgpu(c.cumulative_cat_params) || return false
    return isgpu(c.lit_to_var)
end

"""
Build a new `CompressBitCircuit` with `to_gpu` applied to every entry of every layer, to
every entry of `root_node_weights`, and to `nodes` and `elements`. `var_order` and
`node2id` are passed through unchanged.
"""
function to_gpu(c::CompressBitCircuit)
    gpu_layers = map(layer -> map(to_gpu, layer), c.layers)
    gpu_root_weights = map(to_gpu, c.root_node_weights)
    gpu_nodes = to_gpu(c.nodes)
    gpu_elements = to_gpu(c.elements)
    return CompressBitCircuit(
        gpu_layers, gpu_root_weights, gpu_nodes, gpu_elements, c.var_order, c.node2id
    )
end

"""
Build a new `CompressBitCircuit` with `to_cpu` applied to every entry of every layer, to
every entry of `root_node_weights`, and to `nodes` and `elements`. `var_order` and
`node2id` are passed through unchanged.
"""
function to_cpu(c::CompressBitCircuit)
    cpu_layers = map(layer -> map(to_cpu, layer), c.layers)
    cpu_root_weights = map(to_cpu, c.root_node_weights)
    cpu_nodes = to_cpu(c.nodes)
    cpu_elements = to_cpu(c.elements)
    return CompressBitCircuit(
        cpu_layers, cpu_root_weights, cpu_nodes, cpu_elements, c.var_order, c.node2id
    )
end

# Decided purely by dispatch on the first two type parameters of `CompressBitCircuit`:
# both `CuArray` subtypes -> true, both `Array` subtypes -> false.
function isgpu(::CompressBitCircuit{<:CuArray,<:CuArray,<:Any})
    return true
end

function isgpu(::CompressBitCircuit{<:Array,<:Array,<:Any})
    return false
end