using Distributed
using DataFrames
using Combinatorics
using SpecialFunctions
using Dates

"""
    score(parents::Set{Int64}, i::Int64, D::DataFrame, ESS::Float64)

Return the *negated* BDeu-form log marginal likelihood of column `i` of `D` given the
columns listed in `parents`, so smaller values are better:

    -Σⱼₖ loggamma(N[k, j] + β) + Σⱼ loggamma(N_j + α) - q·loggamma(α) + q·r·loggamma(β)

Here:
- `r` is the number of distinct values in column `i`.
- `q` is the number of parent configurations (groups produced by `groupby`).
- `N[k, j]` counts child state `k` within parent configuration `j`, and
  `N_j = Σₖ N[k, j]`.
- `α = ESS / q` and `β = ESS / (q·r)`.

An empty `parents` set puts the whole frame in a single group, so `q = 1`.

NOTE(review): `q` counts only the parent configurations that occur in `D`, not the
product of the parents' cardinalities used in textbook BDeu. Confirm this is intended.

Throws `ArgumentError` when `D` has no rows (no child states to count).
"""
function score(parents::Set{Int64}, i::Int64, D::DataFrame, ESS::Float64)
    colnames = names(D)
    # Distinct child states, materialised once so every parent group is counted in
    # the same state order. The original rebuilt this Set inside the per-group lambda.
    child_states = collect(Set(D[!, i]))
    r = length(child_states)
    r == 0 && throw(ArgumentError("score: D has no rows, so there is nothing to score"))

    groups = DataFrames.groupby(D, colnames[collect(parents)])
    # The function returns one count per child state. combine expands that vector
    # into r rows per parent configuration, with configurations contiguous.
    counts_df = combine(groups, colnames[i] => (x -> [count(==(s), x) for s in child_states]))
    counts = convert(Vector{Int64}, counts_df[:, ncol(counts_df)])
    q = div(length(counts), r)
    N = reshape(counts, (r, q))  # N[k, j]: child state k under parent configuration j

    alpha = ESS / q
    beta = ESS / (q * r)
    # Summation order matches the original: column-major over N, then over parent
    # configurations.
    return -sum(n -> loggamma(n + beta), N) +
           sum(n -> loggamma(n + alpha), sum(N; dims=1)) -
           q * loggamma(alpha) +
           q * r * loggamma(beta)
end

"""
    CPSD(W_list::Set{Set{Int64}}, i::Int64, D)

Greedily split the candidate parent sets `W_list` into two factors and return them
as `(U_list, V_list)`.

Let `Y` be the union of all sets in `W_list`. For a subset `Z` of `Y`, the factors
are `U_list = {W ∩ Z : W ∈ W_list}` and `V_list = {W ∩ setdiff(Y, Z) : W ∈ W_list}`.
Their cost is `lambda_pq = length(U_list) + length(V_list) - 2`.

The search starts from `Z = ∅` and the trivial split `U_list = W_list`,
`V_list = {∅}`, whose cost is `length(W_list) - 1`. Each round tries adding one
more variable of `Y` to `Z`. It keeps the cheapest candidate if that is strictly
cheaper than the best so far; among equal costs, the first one met in `Set`
iteration order wins. The search stops when a round brings no strict improvement
or `Z == Y`.

Because every `W` satisfies `W == (W ∩ Z) ∪ (W ∩ setdiff(Y, Z))`, each original set
is some `U ∪ V`. The product of the factors may also contain sets that are not in
`W_list`.

`i` and `D` are not read by this function. They are accepted so existing call sites
keep working.
"""
function CPSD(W_list::Set{Set{Int64}}, i::Int64, D)
    # `D` is intentionally untyped (it was `D::DataFrame`). It is never used, so the
    # pure set algorithm does not need a data frame.
    Y = Set{Int64}([w for W in W_list for w in W])
    lambda_pq = length(W_list) - 1
    U_list = copy(W_list)  # copy so the result never aliases the caller's input
    V_list = Set{Set{Int64}}([Set{Int64}()])
    Z = Set{Int64}()
    Z_best = Set{Int64}()
    while true
        remaining = setdiff(Y, Z)
        isempty(remaining) && break
        for z in remaining
            ZZ = Z ∪ z
            rest = setdiff(Y, ZZ)  # hoisted: the original recomputed this per W
            U_cand = Set{Set{Int64}}([W ∩ ZZ for W in W_list])
            V_cand = Set{Set{Int64}}([W ∩ rest for W in W_list])
            # NOTE(review): the original also required
            # `issubset(W_list, {U ∪ V : U ∈ U_cand, V ∈ V_cand})`. That always holds,
            # because every W ⊆ Y and W == (W ∩ ZZ) ∪ (W ∩ rest), so the check was
            # removed. If an exact factorisation (product == W_list) was intended,
            # the check must compare for equality instead.
            lambda_cand = length(U_cand) + length(V_cand) - 2
            if lambda_cand < lambda_pq
                lambda_pq = lambda_cand
                U_list = U_cand
                V_list = V_cand
                Z_best = ZZ
            end
        end
        # No one-variable extension of Z lowered the cost, so stop.
        Z_best == Z && break
        Z = Z_best
    end
    return U_list, V_list
end

"""
    worker(W_list::Set{Set{Int64}}, i::Int64, D::DataFrame)

Per-variable task body. It decomposes the candidate parent sets `W_list` of
variable `i` with `CPSD` and prints a one-line progress message containing the
resulting `lambda_pq`. It returns the two factors as a vector `[U_list, V_list]`.
"""
function worker(W_list::Set{Set{Int64}}, i::Int64, D::DataFrame)
    # Turn the (U_list, V_list) tuple into a Vector{Set{Set{Int64}}} with the same
    # two elements in the same order.
    factors = collect(CPSD(W_list, i, D))
    lambda_pq = sum(length, factors) - 2
    println("i=", i, " completed (lambda_pq=", lambda_pq, ")")
    return factors
end

"""
    fit(D_W_list::DataFrame, D::DataFrame, ESS::Float64=1.0)

Decompose every variable's candidate parent sets with `CPSD`, running in parallel
through `pmap` and `worker`. Then score each resulting parent set `U ∪ V` with
`score`.

# Arguments
- `D_W_list`: one row per candidate parent set. Column 1 holds the index `i` of the
  child variable. Column `1 + ii` equals `1` when variable `ii` (column `ii` of `D`)
  belongs to that parent set.
- `D`: observed data, one column per variable. It is not modified; rows are sorted
  in a private copy.
- `ESS`: equivalent sample size forwarded to `score`.

# Returns
A tuple `(output_list, running_time)`:
- `output_list::DataFrame` has one row per `(i, U, V)`. Its columns, in order, are:
  - `"i"`;
  - `"<name>_1"` (0/1 membership in `U`) for every column name of `D`;
  - `"<name>_2"` (membership in `V`) for every column name;
  - `"<name>"` (membership in `U ∪ V`) for every column name;
  - `"score"`.
- `running_time` is the wall-clock seconds spent in the parallel `CPSD` step only.
  The scoring loop that follows is not timed.
"""
function fit(D_W_list::DataFrame, D::DataFrame, ESS::Float64=1.0)
    # The original called `sort!(D)`, silently reordering the caller's data frame.
    # Sorting a private copy hands the same rows, in the same order, to the rest of
    # the pipeline. Using a new name rather than `D = sort(D)` also keeps the
    # variable captured by the `pmap` closure below from being boxed.
    data = sort(D)
    n = ncol(data)
    varnames = names(data)
    # 0/1 membership vector of the set `S` over variables 1:n.
    indicator(S) = Int64[j in S for j in 1:n]

    time0 = now()
    # W_lists[i] holds the candidate parent sets of variable i, decoded from the rows
    # of D_W_list whose first column equals i.
    W_lists = Vector{Set{Set{Int64}}}([
        Set{Set{Int64}}([
            Set{Int64}([ii for ii in 1:n if D_W_list[num, 1 + ii] == 1])
            for num in 1:nrow(D_W_list) if D_W_list[num, 1] == i
        ])
        for i in 1:n
    ])
    result_list = pmap(i -> worker(W_lists[i], i, data), 1:n)
    running_time = (now() - time0).value / 1000

    col = vcat(["i"],
               [name * "_" * string(k) for k in 1:2 for name in varnames],
               varnames,
               ["score"])
    # Every column holds Int64 except the trailing "score" column.
    columns = [k == length(col) ? Float64[] : Int64[] for k in eachindex(col)]
    output_list = DataFrame(columns, col)

    for (i, (U_list, V_list)) in enumerate(result_list)
        for U in U_list
            u_flags = indicator(U)
            for V in V_list
                UV = U ∪ V
                row = vcat([i], u_flags, indicator(V), indicator(UV),
                           [score(UV, i, data, ESS)])
                push!(output_list, row)
            end
        end
    end
    return output_list, running_time
end