# =================================================================================================#
# Description: Produces the experimental results for the Gadget comparisons
# Author: Ryan Thompson
# =================================================================================================#

using Distributed

Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

cd("/Experiments")
include("pro_dag.jl")
include("metrics.jl")

import CSV, CUDA, DataFrames, Distributions, Graphs, JSON, LinearAlgebra, ProgressMeter, PyCall, 
    Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

"""
    gendata(par)

Simulate data from a random linear structural equation model.

`par` is destructured as `(n, p, s, graph_type, dist, id)`; `id` is not used here.

A graph with `p` nodes is drawn according to `graph_type` (`"erdos_renyi"`, `"scale_free_2"`
or `"scale_free_3"`, with `s` passed as the edge-count argument) and randomly oriented into a
DAG. Each edge gets a weight with magnitude drawn from `Uniform(0.3, 0.7)` and a random sign.
Noise is drawn from `dist` (`"gaussian"`, `"gumbel"` or `"exponential"`) and propagated
through the model `x = x * w + ε`, i.e. `x = ε * inv(I - w)`.

Returns `(x, x_val, w)`:
- `x` is an `n × p` training matrix.
- `x_val` is a `round(Int, 0.1 * n) × p` validation matrix.
- `w` is the true `p × p` weighted adjacency matrix.

Both `x` and `x_val` are centred using the column means of the training data `x`.

Throws `ArgumentError` for an unrecognised `graph_type` or `dist`.
"""
function gendata(par)

    # Save scenario parameters
    n, p, s, graph_type, dist, id = par

    # Generate graph, then orient it into a DAG
    if graph_type == "erdos_renyi"
        g = Graphs.erdos_renyi(p, s)
    elseif graph_type == "scale_free_2"
        g = Graphs.static_scale_free(p, s, 2)
    elseif graph_type == "scale_free_3"
        g = Graphs.static_scale_free(p, s, 3)
    else
        throw(ArgumentError("unknown graph_type: $graph_type"))
    end
    g = Graphs.random_orientation_dag(g)

    # Create weighted adjacency matrix from graph: magnitudes in [0.3, 0.7] with random signs
    w_values = rand(Distributions.Uniform(0.3, 0.7), p, p) .* rand([-1, 1], p, p)
    w = Matrix(Graphs.adjacency_matrix(g)) .* w_values

    # Pick the noise distribution (no random draws happen here, so the RNG stream is unchanged)
    if dist == "gaussian"
        noise = Distributions.Normal(0, 1)
    elseif dist == "gumbel"
        noise = Distributions.Gumbel(0, 1)
    elseif dist == "exponential"
        noise = Distributions.Exponential(1)
    else
        throw(ArgumentError("unknown dist: $dist"))
    end

    # Draw training noise first, then validation noise (same order as before for reproducibility)
    ε = rand(noise, n, p)
    ε_val = rand(noise, round(Int, 0.1 * n), p)

    # Solve x = x * w + ε for x; the inverse is shared by the training and validation sets
    propagate = LinearAlgebra.inv(LinearAlgebra.I - w)
    x = ε * propagate
    x_val = ε_val * propagate

    # Centre both sets with the training means so validation uses no information of its own
    x_mean = mapslices(Statistics.mean, x, dims = 1)
    x_val .-= x_mean
    x .-= x_mean

    # Return generated data
    return x, x_val, w

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

"""
    evaluate!(result, estimator, ŵ, w, par)

Score the estimated weighted adjacency matrices `ŵ` against the true matrix `w` and `push!` one
row `[estimator, bscore, shd, f1score, auroc, sparsity, dag_rate, n, p, s, graph_type, dist, id]`
onto `result`.

`ŵ` is either a single matrix or a 3-d array whose slices `ŵ[:, :, k]` are individual estimates.

These metrics are averaged over the slices:
- structural Hamming distance
- F1 score
- nonzero count (`sparsity`)
- fraction of acyclic graphs (`dag_rate`)

The Brier score and AUROC are computed by `brier_score` / `binary_auroc` on the whole stack.
"""
function evaluate!(result, estimator, ŵ, w, par)

    # Scenario parameters recorded next to the metrics
    n, p, s, graph_type, dist, id = par

    # View a single matrix as a stack holding exactly one estimate
    ŵs = ndims(ŵ) == 2 ? reshape(ŵ, size(ŵ, 1), size(ŵ, 2), 1) : ŵ
    n_sample = size(ŵs, 3)

    # Average a metric(w, estimate) over every estimate in the stack
    average_over_samples(metric) = sum([metric(w, ŵs[:, :, k]) for k in 1:n_sample]) / n_sample

    bscore = brier_score(w, ŵs)
    shd = average_over_samples(struct_hamming_dist)
    f1score = average_over_samples(f1_score)
    auroc = binary_auroc(w, ŵs)

    # Mean number of nonzero entries per estimate
    sparsity = count(≠(0), ŵs) / n_sample

    # Fraction of estimates whose Graphs.SimpleDiGraph is acyclic
    is_dag(slice) = !Graphs.is_cyclic(Graphs.SimpleDiGraph(slice))
    dag_rate = count(is_dag, eachslice(ŵs, dims = 3)) / n_sample

    return push!(result, [estimator, bscore, shd, f1score, auroc, sparsity, dag_rate, n, p, s,
        graph_type, dist, id])

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# Fit ProDAG along a path of prior_α values and return the posterior sample stack that best
# reconstructs the validation data.
#
# The path is:
# - prior_α = Inf;
# - `n_path - 2` intermediate values of a linear grid running from the mean L1 norm of the
#   Inf-fit samples down to 0, each warm-started from the previous fit;
# - a cold fit at prior_α = 0.
#
# Each element of the path is a stack of samples indexed by the third dimension. The stack
# with the smallest mean squared validation residual ‖x_val - x_val * ŵ‖² is returned.
function _fit_prodag(x, x_val; n_path = 10)
    n_path >= 2 || throw(ArgumentError("n_path must be at least 2, got $n_path"))

    fit = ProDAG.fit_linear(x, prior_α = Inf, verbose = false)
    init_μ = fit.μ
    init_σ = fit.σ
    path = Vector{Any}(undef, n_path)
    path[1] = ProDAG.sample(fit)

    # The grid endpoints are replaced by the explicit Inf and 0 fits
    α = range(sum(abs.(path[1])) / size(path[1], 3), 0, n_path)
    for i in 2:(n_path - 1)
        fit = ProDAG.fit_linear(x, prior_α = α[i], init_μ = init_μ, init_σ = init_σ,
            verbose = false)
        init_μ = fit.μ
        init_σ = fit.σ
        path[i] = ProDAG.sample(fit)
    end
    fit = ProDAG.fit_linear(x, prior_α = 0, verbose = false)
    path[n_path] = ProDAG.sample(fit)

    # Mean (over samples) of the summed squared validation residuals
    val_error(ŵ) = Statistics.mean([sum((x_val - x_val * ŵ[:, :, j]) .^ 2)
        for j in 1:size(ŵ, 3)])
    return path[argmin(map(val_error, path))]
end

# Sample DAGs with Gadget (Python package `sumu`) and return their adjacency matrices stacked
# along the third dimension.
function _fit_gadget(x, par)
    numpy = PyCall.pyimport("numpy")
    numpy.random.seed(Int64(hash(par) % Int64(2 ^ 32)))
    sumu = PyCall.pyimport("sumu")
    data = sumu.Data(PyCall.PyObject(x))
    dags, _ = sumu.Gadget(data = data,
        run_mode = Dict("name" => "normal", "params" => Dict("n_target_chain_iters" => 1000000)),
        metropolis_coupling = Dict("name" => "static", "params" => Dict("M" => 16))).sample()
    adj_mats = [sumu.bnet.family_sequence_to_adj_mat(
        PyCall.PyObject([(d[1], d[2]) for d in dags[i, :]])) for i in 1:size(dags, 1)]

    # Stack in a single allocation. Folding with pairwise `cat` re-copies the growing array at
    # every step, which is quadratic in the number of sampled DAGs.
    T = mapreduce(eltype, promote_type, adj_mats)
    stacked = Array{T}(undef, size(first(adj_mats))..., length(adj_mats))
    for (k, adj) in enumerate(adj_mats)
        stacked[:, :, k] = adj
    end
    return stacked
end

"""
    runsim(par)

Run one simulation scenario and return a `DataFrames.DataFrame` with one row of metrics for
ProDAG and one for Gadget.

`par` is destructured as `(n, p, s, graph_type, dist, id)` by `gendata` / `evaluate!`. It is also
hashed to seed:
- Julia's global RNG;
- CUDA's RNG;
- NumPy's RNG.

So a given scenario is reproducible.
"""
function runsim(par)

    # Choose a GPU from the process id (alternating between devices 0 and 1) and seed the RNGs
    process_id = Distributed.myid()
    gpu_id = (process_id - 1) % 2
    CUDA.device!(gpu_id)
    CUDA.seed!(hash(par))
    Random.seed!(hash(par))

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], bscore = [], shd = [], f1score = [], auroc = [], sparsity = [],
        dag_rate = [], n = [], p = [], s = [], graph_type = [], dist = [], id = []
    )

    # Generate data
    x, x_val, w = gendata(par)

    # Evaluate ProDAG
    ŵ_prodag = _fit_prodag(x, x_val)
    evaluate!(result, "ProDAG", ŵ_prodag, w, par)

    # Evaluate Gadget
    ŵ_gadget = _fit_gadget(x, par)
    evaluate!(result, "Gadget", ŵ_gadget, w, par)

    # Release memory cached by CUDA.jl on this process
    CUDA.reclaim()

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters: every combination of the settings below is one scenario (row)
simulations = DataFrames.DataFrame(
    (n = n, p = p, s = s, graph_type = graph_type, dist = dist, id = id) for
    n = 100, # Number of samples
    (p, s) = [(20, 40)], # Number of variables and nonzeros
    graph_type = ["erdos_renyi"], # Type of graph
    dist = ["gaussian"], # Distribution of the noise
    id = 1:10 # Simulation run ID
)

# Run simulations in parallel (`pmap` stays unqualified so `@showprogress` recognises it)
result = ProgressMeter.@showprogress pmap(runsim, eachrow(simulations))
result = reduce(vcat, result)

# Create the output directory if needed so the write cannot fail after all runs have finished
mkpath("Results")
CSV.write("Results/gadget.csv", result)

# Shut down the worker processes
Distributed.rmprocs(Distributed.workers())