using Distributed
using DataFrames
using Combinatorics
using SpecialFunctions
using Dates

"""
    score(parents, i, D, ESS)

Return the score of column `i` of `D` given the columns indexed by `parents`.

With `N[k, j]` the number of rows in which column `i` takes its `k`-th distinct
value inside the `j`-th group of rows sharing the same parent values, the
result is

    - sum_{j,k} loggamma(N[k, j] + beta) + sum_j loggamma(sum_k N[k, j] + alpha)
    - q * loggamma(alpha) + q * r * loggamma(beta)

where `r` is the number of distinct values observed in column `i`, `q` is the
number of parent-value groups that actually occur in `D`, `alpha = ESS / q`
and `beta = ESS / (q * r)`. This has the form of a negated log BDeu marginal
likelihood (so lower is better), with `q` counting observed rather than all
possible parent configurations.

# Arguments
- `parents`: column indices of the conditioning variables (may be empty).
- `i`: column index of the scored variable.
- `D`: data table, one row per observation.
- `ESS`: equivalent sample size used to derive `alpha` and `beta`.

# Throws
- `ArgumentError` if `D` has no rows (the counts would be undefined).
"""
function score(parents::Set{Int64}, i::Int64, D::DataFrame, ESS::Float64)
    nrow(D) > 0 || throw(ArgumentError("D must contain at least one row"))
    col_names = names(D)
    # Distinct values of column i. Built once and reused for every group so the
    # per-group count vectors all list the states in the same order.
    states = Set(D[:, i])
    num_i_states = length(states)
    grouped = DataFrames.groupby(D, col_names[collect(parents)])
    # Each group contributes `num_i_states` rows: one count per state.
    counts_df = combine(grouped, col_names[i] => (x -> [count(==(s), x) for s in states]))
    state_counts = convert(Vector{Int64}, counts_df[:, ncol(counts_df)])
    num_parents_states = length(state_counts) ÷ num_i_states
    # Column j holds the state counts of the j-th parent configuration.
    state_counts = reshape(state_counts, (num_i_states, num_parents_states))
    alpha = ESS / num_parents_states
    beta = ESS / (num_parents_states * num_i_states)
    cell_term = sum(loggamma.(state_counts .+ beta))
    config_term = sum(loggamma.(vec(sum(state_counts; dims=1)) .+ alpha))
    return -cell_term + config_term - num_parents_states * loggamma(alpha) +
           num_parents_states * num_i_states * loggamma(beta)
end

# Return true when `candidates` contains a proper subset of `W` whose score is
# at most `s + ipshi`, i.e. a smaller set that is at least as good as `W`.
function _cpsi_dominated(W::Set{Int64}, s::Float64, candidates::Vector{Set{Int64}},
                         scores::Vector{Float64}, ipshi::Float64)
    return any(
        length(W_) < length(W) && issubset(W_, W) && scores[h_] <= s + ipshi
        for (h_, W_) in enumerate(candidates)
    )
end

# Score every subset of `X_` with at most `m` elements and keep the
# non-dominated ones.
function _cpsi_exhaustive(i::Int64, D::DataFrame, X_::Vector{Int64}, m::Int64,
                          ESS::Float64, ipshi::Float64)
    candidates = Set{Int64}[Set{Int64}(W) for m_ in 0:m for W in combinations(X_, m_)]
    scores = Float64[score(W, i, D, ESS) for W in candidates]
    keep = Bool[
        !_cpsi_dominated(W, scores[h], candidates, scores, ipshi)
        for (h, W) in enumerate(candidates)
    ]
    return candidates[keep]
end

# Grow parent sets one element at a time from the sets kept at the previous
# level, then drop every set that one of its proper subsets dominates.
function _cpsi_greedy(i::Int64, D::DataFrame, X_::Vector{Int64}, m::Int64,
                      ESS::Float64, ipshi::Float64)
    # Each distinct parent set is scored at most once; the final pass revisits
    # many subsets that were already scored while growing the levels.
    cache = Dict{Set{Int64},Float64}()
    cached_score = W -> get!(() -> score(W, i, D, ESS), cache, W)

    root = Set{Int64}()
    W_list = Set{Int64}[root]
    Score_list = Float64[cached_score(root)]
    frontier = Set{Int64}[root]
    for _ in 1:m
        # All one-element extensions of the sets kept at the previous level.
        expanded = Set{Set{Int64}}()
        for W in frontier, W1 in setdiff(X_, W)
            push!(expanded, push!(copy(W), W1))
        end
        W_new = Set{Int64}[]
        Score_new = Float64[]
        for W_ in expanded
            s = cached_score(W_)
            # `W_list` only holds sets from earlier levels here, so every
            # subset found in it is a proper subset of `W_`.
            if !_cpsi_dominated(W_, s, W_list, Score_list, ipshi)
                push!(W_new, W_)
                push!(Score_new, s)
            end
        end
        append!(W_list, W_new)
        append!(Score_list, Score_new)
        frontier = W_new
        # Nothing kept at this level means nothing can be extended further.
        isempty(frontier) && break
    end

    # Final pass: compare each kept set against ALL of its proper subsets,
    # including subsets that were never generated or kept above.
    survivors = Set{Int64}[]
    for (W, s) in zip(W_list, Score_list)
        members = collect(W)
        dominated = any(
            cached_score(Set{Int64}(W_)) <= s + ipshi
            for m_ in 0:(length(members) - 1) for W_ in combinations(members, m_)
        )
        dominated || push!(survivors, W)
    end
    return survivors
end

"""
    CPSI(i, D, m, ESS, greedy, ipshi=1e-10)

Return the parent sets for column `i` of `D` that are not dominated by any of
their proper subsets.

A set `W` is dominated by a proper subset `W_` when
`score(W_, i, D, ESS) <= score(W, i, D, ESS) + ipshi`; scores are those
returned by `score`, where lower values win.

# Arguments
- `i`: column index of the variable whose parent sets are searched.
- `D`: data table; every column other than `i` is a potential parent.
- `m`: maximum number of parents in a set.
- `ESS`: equivalent sample size forwarded to `score`.
- `greedy`: if `false`, every subset of the other columns with at most `m`
  elements is scored and compared against all of its proper subsets. If
  `true`, sets are grown level by level: only sets kept at the previous level
  are extended, an extension is kept when no already-kept subset dominates it,
  and a final pass removes any kept set dominated by one of its proper subsets.
- `ipshi`: tolerance added to the larger set's score in the comparison.

# Returns
A `Set{Set{Int64}}` of column-index sets.
"""
function CPSI(i::Int64, D::DataFrame, m::Int64, ESS::Float64, greedy::Bool, ipshi::Float64=1e-10)
    X_ = Int64[j for j in 1:ncol(D) if j != i]
    W_list = if greedy
        _cpsi_greedy(i, D, X_, m, ESS, ipshi)
    else
        _cpsi_exhaustive(i, D, X_, m, ESS, ipshi)
    end
    return Set{Set{Int64}}(W_list)
end

"""
    worker(i, D, m, ESS, greedy)

Run `CPSI` for column `i`, print a one-line progress message, and return the
resulting collection of parent sets wrapped in a one-element vector.

The printed `lambda` is the number of returned parent sets minus one.
"""
function worker(i::Int64, D::DataFrame, m::Int64, ESS::Float64, greedy::Bool)
    parent_sets = CPSI(i, D, m, ESS, greedy)
    lambda = length(parent_sets) - 1
    println("i=", i, " completed (lambda=", lambda, ")")
    return [parent_sets]
end

"""
    fit(D, m=ncol(D)-1, ESS=1.0, greedy=false)

Run `worker` for every column of `D` through `pmap` and tabulate the parent
sets it returns.

# Arguments
- `D`: data table, one column per variable. It is not modified; the search
  runs on a sorted copy.
- `m`: maximum parent-set size passed on to `worker`.
- `ESS`: equivalent sample size passed on to `worker` and `score`.
- `greedy`: search mode passed on to `worker`.

# Returns
A tuple `(output_list, running_time)`:
- `output_list`: a `DataFrame` with one row per (variable, parent set) pair.
  Column `i` is the variable's column index, then one 0/1 column per column of
  `D` (named after it) marking membership in the parent set, then `score` with
  the value of `score` for that pair. If `D` itself has a column named `i` or
  `score`, the clashing output names get a numeric suffix instead of raising
  an error after the search has finished.
- `running_time`: wall-clock seconds spent in the `pmap` call (millisecond
  resolution).
"""
function fit(D::DataFrame, m::Int64=ncol(D)-1, ESS::Float64=1.0, greedy::Bool=false)
    # Sort a copy: `fit` has no `!`, so the caller's table must keep its row order.
    data = sort(D)
    nvars = ncol(data)

    time0 = now()
    result_list = pmap(i -> worker(i, data, m, ESS, greedy), 1:nvars)
    running_time = Dates.value(now() - time0) / 1000

    col = vcat("i", names(data), "score")
    # nvars + 1 integer columns ("i" plus one indicator per variable), then the
    # floating-point score column.
    columns = AbstractVector[Int64[] for _ in 1:(nvars + 1)]
    push!(columns, Float64[])
    output_list = DataFrame(columns, col; makeunique=true)

    for (i, result) in enumerate(result_list)
        for W in result[1]
            membership = Int64[j in W for j in 1:nvars]
            # The row is promoted to Float64 by `vcat`; `push!` converts the
            # integer-valued entries back to the Int64 column types.
            push!(output_list, vcat(i, membership, score(W, i, data, ESS)))
        end
    end
    return output_list, running_time
end