using LogicCircuits
using ProbabilisticCircuits


# Common supertype of the per-node records that `save_pc` collects and that
# `compile_line` renders into one line of text each.
abstract type NodeLine end

# Record for one leaf node. `compile_line` renders it as `L <node_idx> <lit>`.
struct LeafNodeLine <: NodeLine
    # Index given to the node by `save_pc` (its counter starts at 1).
    node_idx::UInt32
    # Value `save_pc` obtains from `literal(n)` for the leaf.
    lit::Lit
end

# Record for one `PlainProbCategoricalNode`. `compile_line` renders it as
# `C <node_idx> <var> <lit> <log_prob> ...` -- note that `var` is written before
# `lit`, the reverse of the field order below.
struct CatNodeLine <: NodeLine
    # Index given to the node by `save_pc` (its counter starts at 1).
    node_idx::UInt32
    # Taken from the node's `literal` field in `save_pc`.
    lit::Lit
    # Taken from the node's `variable` field in `save_pc`.
    var::Var
    # Taken from the node's `log_probs` field; written out in order.
    log_probs::Vector{Float64}
end

# Record for one `⋀Gate` node. `compile_line` renders it as
# `P <node_idx> <child_idx> ...`.
struct ProductNodeLine <: NodeLine
    # Index given to the node by `save_pc` (its counter starts at 1).
    node_idx::UInt32
    # Indices `save_pc` assigned to this node's children, in child order.
    children::Vector{UInt32}
end

# Record for one `⋁Gate` node. `compile_line` renders it as
# `S <node_idx> <child_idx> <log_prob> ...`, one child/weight pair at a time.
struct SumNodeLine <: NodeLine
    # Index given to the node by `save_pc` (its counter starts at 1).
    node_idx::UInt32
    # Indices `save_pc` assigned to this node's children, in child order.
    children::Vector{UInt32}
    # Taken from the node's `log_probs` field; `log_probs[i]` is written next to
    # `children[i]`.
    log_probs::Vector{Float64}
end

"""
    compile_line(node_line::LeafNodeLine)

Return the text form of a leaf record: `L <node_idx> <lit>`, space-separated.
"""
function compile_line(node_line::LeafNodeLine)::String
    fields = ("L", string(node_line.node_idx), string(node_line.lit))
    return join(fields, " ")
end
"""
    compile_line(node_line::CatNodeLine)

Return the text form of a categorical record:
`C <node_idx> <var> <lit> <log_prob> ...`, space-separated, with the log
probabilities appended in order.
"""
function compile_line(node_line::CatNodeLine)::String
    header = String[
        "C",
        string(node_line.node_idx),
        string(node_line.var),
        string(node_line.lit),
    ]
    weights = String[string(lp) for lp in node_line.log_probs]
    return join(vcat(header, weights), " ")
end
"""
    compile_line(node_line::ProductNodeLine)

Return the text form of a product record: `P <node_idx> <child_idx> ...`,
space-separated, with the child indices appended in order.
"""
function compile_line(node_line::ProductNodeLine)::String
    fields = String["P", string(node_line.node_idx)]
    for c in node_line.children
        push!(fields, string(c))
    end
    return join(fields, " ")
end
"""
    compile_line(node_line::SumNodeLine)

Return the text form of a sum record:
`S <node_idx> <child_idx> <log_prob> ...`, space-separated. Children and log
probabilities are paired positionally with `zip`, so pairing stops at the
shorter of the two vectors.
"""
function compile_line(node_line::SumNodeLine)::String
    return sprint() do io
        print(io, "S ", string(node_line.node_idx))
        for (c, lp) in zip(node_line.children, node_line.log_probs)
            print(io, " ", string(c), " ", string(lp))
        end
    end
end

"""
    save_pc(file_name::String, pc::ProbCircuit)

Write the circuit rooted at `pc` to the text file `file_name`, one node per line.

Nodes are visited with `foreach(pc)` and numbered from 1 in visit order. Each
visited node becomes one `NodeLine` record, rendered by `compile_line`:

- `S <idx> <child> <log_prob> ...` for `⋁Gate` nodes,
- `P <idx> <child> ...` for `⋀Gate` nodes,
- `C <idx> <var> <lit> <log_prob> ...` for `PlainProbCategoricalNode`s,
- `L <idx> <lit>` for any other `LeafGate` node.

The file is opened in `"w"` mode, so existing content is replaced. Returns
`nothing`.

# Throws
- `ErrorException` if a node matches none of the cases above.
- `KeyError` if a child has not been numbered by the time its parent is visited.
"""
function save_pc(file_name::String, pc::ProbCircuit)
    node2idx = Dict{ProbCircuit, UInt32}()
    node_lines = Vector{NodeLine}()
    foreach(pc) do n
        # Exactly one record is pushed per visited node, so the next 1-based index
        # is the record count plus one. Deriving it this way avoids reassigning a
        # captured counter inside the closure.
        idx = UInt32(length(node_lines) + 1)
        node2idx[n] = idx

        if GateType(n) isa ⋁Gate
            # Child lookups assume `foreach` visits children before their parent.
            child_idxs = UInt32[node2idx[c] for c in n.children]
            push!(node_lines, SumNodeLine(idx, child_idxs, n.log_probs))
        elseif GateType(n) isa ⋀Gate
            child_idxs = UInt32[node2idx[c] for c in n.children]
            push!(node_lines, ProductNodeLine(idx, child_idxs))
        elseif n isa PlainProbCategoricalNode
            push!(node_lines, CatNodeLine(idx, n.literal, n.variable, n.log_probs))
        elseif GateType(n) isa LeafGate
            push!(node_lines, LeafNodeLine(idx, literal(n)))
        else
            # `@assert false, "msg"` would parse as `@assert (false, "msg")` and fail
            # with a TypeError instead of this message, so raise explicitly.
            error("Unexpected GateType $(GateType(n))")
        end
    end

    open(file_name, "w") do f
        for node_line in node_lines
            println(f, compile_line(node_line))
        end
    end
    return nothing
end

"""
    load_pc(file_name::String)

Rebuild a circuit from the line-oriented text format written by `save_pc`.

Each line is split on whitespace and dispatched on its first character:

- `L <idx> <lit>`: literal leaf built with `compile(ProbCircuit, lit)`.
- `C <idx> <var> <lit> <log_prob> ...`: `PlainProbCategoricalNode`; the number of
  categories is the number of log probabilities on the line.
- `P <idx> <child> ...`: product of previously defined nodes via `multiply`.
- `S <idx> <child> <log_prob> ...`: sum of previously defined nodes via
  `summate`, whose `log_probs` are then overwritten with the values on the line.

Lines that start with any other character (including blank lines) are skipped.

# Returns
The node defined by the last recognized line, or `nothing` if the file contains
no recognized line.

# Throws
- `ArgumentError` if a sum line has a child index without a matching log
  probability.
- `KeyError` if a line references a child index that has not been defined yet.
- `BoundsError` if a recognized line has fewer fields than its kind requires.
"""
function load_pc(file_name::String)
    idx2node = Dict{UInt32, ProbCircuit}()
    return open(file_name, "r") do f
        circuit = nothing
        # Stream the file line by line instead of reading it whole and splitting.
        for node_line in eachline(f)
            tokens = split(node_line)
            if startswith(node_line, 'L')
                n = compile(ProbCircuit, parse(Lit, tokens[3]))
            elseif startswith(node_line, 'C')
                variable = parse(Var, tokens[3])
                lit = parse(Lit, tokens[4])
                log_probs = Float64[parse(Float64, tokens[i]) for i in 5:length(tokens)]
                n = PlainProbCategoricalNode(variable, lit, length(log_probs), log_probs)
            elseif startswith(node_line, 'P')
                children = ProbCircuit[
                    idx2node[parse(UInt32, tokens[i])] for i in 3:length(tokens)
                ]
                n = multiply(children...)
            elseif startswith(node_line, 'S')
                # After the tag and the node index, tokens come in
                # (child, log_prob) pairs, so a well-formed line has an even count.
                iseven(length(tokens)) || throw(ArgumentError(
                    "Malformed sum line (unpaired child/log-prob token): $(node_line)",
                ))
                child_positions = 3:2:length(tokens)
                children = ProbCircuit[
                    idx2node[parse(UInt32, tokens[i])] for i in child_positions
                ]
                log_probs = Float64[parse(Float64, tokens[i + 1]) for i in child_positions]
                n = summate(children...)
                n.log_probs .= log_probs
            else
                continue
            end
            idx2node[parse(UInt32, tokens[2])] = n
            # The most recently defined node is what gets returned at the end.
            circuit = n
        end
        circuit
    end
end