# =================================================================================================#
# Description: Produces the experimental results for the Sachs data
# Author: Ryan Thompson
# =================================================================================================#

using Distributed

# Start two worker processes; runsim assigns each process a GPU via (myid() - 1) % 2
Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

# Work from the experiment directory; the relative paths used below (includes, bayes_dag/src,
# boss/, Results/) are resolved against it
cd("/Experiments")
# Project-local code; presumably provides the ProDAG module and the metric functions
# (brier_score, struct_hamming_dist, f1_score, binary_auroc) used below -- confirm in those files
include("pro_dag.jl")
include("metrics.jl")

import CSV, CUDA, CausalInference, DataFrames, Distributions, Graphs, JSON, LinearAlgebra, PyCall, 
    ProgressMeter, Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

function gendata(par)

    # Unpack the scenario parameters (only the run ID; it is not used further in this function)
    id, = par

    # Python dependencies, resolved through PyCall
    cdt = PyCall.pyimport("cdt")
    networkx = PyCall.pyimport("networkx")
    numpy = PyCall.pyimport("numpy")

    # Load the Sachs observations together with the accompanying graph
    frame, graph = cdt.data.load_dataset("sachs")

    # Position of each data column among the graph's nodes, used to permute the adjacency
    # matrix into data-column order
    node_names = collect(graph.nodes)
    perm = map(name -> findfirst(==(name), node_names), collect(frame.columns))
    data = numpy.array(frame)
    adjacency = numpy.array(networkx.adjacency_matrix(graph).todense())
    w = adjacency[perm, perm]

    # Hold out a random 40% of the rows for validation; the remainder is used for training
    n_obs = size(data, 1)
    n_val = round(Int, 0.4 * n_obs)
    val_rows = Random.randperm(n_obs)[1:n_val]
    train_rows = setdiff(1:n_obs, val_rows)
    x_val = data[val_rows, :]
    x = data[train_rows, :]

    # Center both sets with the training column means
    center = mapslices(Statistics.mean, x, dims = 1)
    for block in (x, x_val)
        block .-= center
    end

    # Scale both sets with the (centered) training column standard deviations
    scale = mapslices(Statistics.std, x, dims = 1)
    for block in (x, x_val)
        block ./= scale
    end

    return x, x_val, w

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

function evaluate!(result, estimator, ŵ, w, par)

    # Unpack the run ID from the scenario parameters
    id, = par

    # Treat a single weight matrix as one draw; the third dimension indexes draws
    draws = ndims(ŵ) == 2 ? reshape(ŵ, size(ŵ, 1), size(ŵ, 2), 1) : ŵ
    n_sample = size(draws, 3)
    draw_ids = 1:n_sample

    # Brier score, computed on the full set of draws
    bscore = brier_score(w, draws)

    # Structural Hamming distance, averaged over draws
    shd = sum([struct_hamming_dist(w, draws[:, :, k]) for k in draw_ids]) / n_sample

    # F1 score, averaged over draws
    f1score = sum([f1_score(w, draws[:, :, k]) for k in draw_ids]) / n_sample

    # AUROC, computed on the full set of draws
    auroc = binary_auroc(w, draws)

    # Average number of nonzero entries per draw
    sparsity = count(≠(0), draws) / n_sample

    # Fraction of draws that Graphs reports as acyclic
    n_acyclic = count(
        slice -> !Graphs.is_cyclic(Graphs.SimpleDiGraph(slice)), eachslice(draws, dims = 3)
    )
    dag_rate = n_acyclic / n_sample

    # Append one row of metrics for this estimator and return the updated table
    row = [estimator, bscore, shd, f1score, auroc, sparsity, dag_rate, id]
    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# Sum of squared residuals left after reconstructing x_val as x_val * ŵ
_val_loss(x_val, ŵ) = sum(abs2, x_val - x_val * ŵ)

# Fit ProDAG over a grid of prior_α values and return one sample array per grid point
function _prodag_path(x)

    # First fit uses prior_α = Inf; its sample sets the upper end of the prior_α grid
    fit = ProDAG.fit_linear(x, prior_α = Inf, verbose = false, prior_σ = sqrt(10))
    init_μ = fit.μ
    init_σ = fit.σ
    ŵ = Vector{Any}(undef, 10)
    ŵ[1] = ProDAG.sample(fit)
    α = range(sum(abs.(ŵ[1])) / size(ŵ[1], 3), 0, length(ŵ))

    # Interior grid points are warm-started from the previous fit's μ and σ
    for i in 2:length(ŵ) - 1
        fit = ProDAG.fit_linear(x, prior_α = α[i], init_μ = init_μ, init_σ = init_σ, 
            verbose = false, prior_σ = sqrt(10))
        init_μ = fit.μ
        init_σ = fit.σ
        ŵ[i] = ProDAG.sample(fit)
    end

    # Last grid point (prior_α = 0) is fitted without the warm start
    fit = ProDAG.fit_linear(x, prior_α = 0, verbose = false, prior_σ = sqrt(10))
    ŵ[length(ŵ)] = ProDAG.sample(fit)

    return ŵ

end

# Fit DAGMA once per lambda1 value in λ and return the estimated weight matrices in order
function _dagma_path(x, λ)

    dagma = PyCall.pyimport("dagma.linear")
    return map(λ) do λ_i
        # A fresh model is created for every penalty value
        model = dagma.DagmaLinear(loss_type = "l2")
        model.fit(x, lambda1 = λ_i, w_threshold = 0.1)
    end

end

# Run the BayesDAG command-line experiment for one lambda_sparse value and return its saved
# DAGs stacked along the third dimension
function _bayesdag_sample(x, λ, gpu_id, seed)

    torch = PyCall.pyimport("torch")
    numpy = PyCall.pyimport("numpy")

    # cd(f, dir) returns to the previous working directory afterwards, also when f throws, so
    # the relative paths used by the caller keep resolving against the experiment directory
    return cd("bayes_dag/src") do

        result_dir = "results/$gpu_id"
        clear_results() = foreach(
            entry -> rm(joinpath(result_dir, entry), force = true, recursive = true), 
            readdir(result_dir)
        )

        # Start from an empty results directory so only this run's output is picked up
        clear_results()

        # Write this run's sparsity penalty and seed into the per-GPU model config
        python_path = "/.cache/pypoetry/virtualenvs/causica-WXFkYpeo-py3.8/bin/python3"
        model_config_path = "configs/bayesdag_linear_$gpu_id.json"
        json_data = JSON.parsefile(model_config_path)
        json_data["model_hyperparams"]["lambda_sparse"] = λ
        json_data["model_hyperparams"]["random_seed"] = seed
        open(model_config_path, "w") do f
            JSON.print(f, json_data)
        end

        # Hand the training data to the external process as a header-less CSV
        dataset_config_path = "configs/dataset_config_causal_dataset.json"
        CSV.write("data/$gpu_id/all.csv", DataFrames.DataFrame(x, :auto), header = false)
        run(`$python_path -m causica.run_experiment data/$gpu_id \
            -d ./ \
            --output_dir results/$gpu_id \
            --model_config $model_config_path \
            --dataset_config $dataset_config_path \
            --model_type bayesdag_linear \
            --device $gpu_id`)

        # Locate the saved DAGs anywhere below the results directory
        dag_paths = filter(f -> endswith(f, "bayesdag_dags.pt"), [joinpath(root, file) for 
            (root, dirs, files) in walkdir(result_dir) for file in files])
        isempty(dag_paths) && error("BayesDAG wrote no bayesdag_dags.pt file under $result_dir")
        dags = numpy.stack(torch.load(first(dag_paths), weights_only = false), axis = 2)

        clear_results()
        dags

    end

end

# Run BOSS through its Python script and record its metrics; failures are logged, not raised
function _evaluate_boss!(result, x, w, par, process_id)

    data_path = "boss/data_$process_id.csv"
    adjacency_path = "boss/adjacency_matrix_$process_id.csv"
    CSV.write(data_path, DataFrames.DataFrame(x, :auto))
    try
        run(`/.julia/conda/3/x86_64/bin/python boss/boss.py $process_id`)
        ŵ = Matrix(CSV.read(adjacency_path, DataFrames.DataFrame))
        evaluate!(result, "BOSS", ŵ, w, par)
    catch err
        # BOSS stays best-effort: the run continues without a BOSS row, but the reason is logged
        @warn "BOSS failed; no BOSS row recorded for this run" exception = 
            (err, catch_backtrace())
    finally
        # Remove the exchange files whether or not BOSS succeeded
        rm(adjacency_path, force = true)
        rm(data_path, force = true)
    end

    return result

end

# Run one replication of the Sachs experiment.
#
# `par` is one row of the simulation design; it seeds every random number generator used here
# and its first field (the run ID) is stored with each result row.
#
# Returns a DataFrame with one row per estimator that completed (ProDAG, DAGMA, DiBS, DiBS+, 
# BayesDAG and, when its external script succeeds, BOSS).
function runsim(par)

    # Pick a GPU from the process ID (alternating between devices 0 and 1) and seed the RNGs
    process_id = Distributed.myid()
    gpu_id = (process_id - 1) % 2
    CUDA.device!(gpu_id)
    CUDA.seed!(hash(par))
    Random.seed!(hash(par))

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], bscore = [], shd = [], f1score = [], auroc = [], sparsity = [], 
        dag_rate = [], id = []
    )

    x, x_val, w = gendata(par)

    # Evaluate ProDAG: keep the grid point with the smallest mean validation loss over samples
    ŵ_prodag = _prodag_path(x)
    best_i = argmin([Statistics.mean([_val_loss(x_val, ŵ_prodag[i][:, :, j]) 
        for j in 1:size(ŵ_prodag[i], 3)]) for i in 1:length(ŵ_prodag)])
    evaluate!(result, "ProDAG", ŵ_prodag[best_i], w, par)

    # Evaluate DAGMA: keep the penalty with the smallest validation loss
    ŵ_dagma = _dagma_path(x, exp.(range(log(1e-3), log(1), 10)))
    best_i = argmin([_val_loss(x_val, ŵ_dagma[i]) for i in 1:length(ŵ_dagma)])
    evaluate!(result, "DAGMA", ŵ_dagma[best_i], w, par)

    # Evaluate DiBS
    numpy = PyCall.pyimport("numpy")
    dibs = PyCall.pyimport("dibs")
    PyCall.pyimport("dibs.models")
    PyCall.pyimport("dibs.inference")
    jax = PyCall.pyimport("jax")
    jax.config.update("jax_default_device", jax.devices()[gpu_id + 1])
    graph_model = dibs.models.ErdosReniDAGDistribution(n_vars = size(x, 2))
    likelihood_model = dibs.models.LinearGaussian(n_vars = size(x, 2), min_edge = 0.1, 
        obs_noise = 1.0)
    fit = dibs.inference.JointDiBS(x = x, graph_model = graph_model, 
        likelihood_model = likelihood_model)
    key = jax.random.PRNGKey(Int64(hash(par) % Int64))
    g, theta = fit.sample(key = key, n_particles = 50, steps = 5000)
    # Elementwise product of the stacked graphs and parameters, one particle per slice
    particles = numpy.stack(g, axis = 2) .* numpy.stack(theta, axis = 2)
    evaluate!(result, "DiBS", particles, w, par)

    # Evaluate DiBS+: resample the particles in proportion to their mixture weights
    log_weights = numpy.array(fit.get_mixture(g, theta)[1])
    # Subtract the maximum before exponentiating so the largest weight is exp(0) = 1
    weights = exp.(log_weights .- maximum(log_weights))
    weights /= sum(weights)
    ind = rand(Distributions.Categorical(weights), 1000)
    evaluate!(result, "DiBS+", particles[:, :, ind], w, par)

    # Evaluate BayesDAG: keep the penalty whose mean edge count is closest to that of w
    λ = exp.(range(log(10), log(1000), 10))
    seed = Int64(hash(par) % Int64(2 ^ 32))
    ŵ_bayesdag = [_bayesdag_sample(x, λ_i, gpu_id, seed) for λ_i in λ]
    nnz = [Statistics.mean(sum(dags .≠ 0, dims = (1, 2))) for dags in ŵ_bayesdag]
    best_i = argmin(abs.(nnz .- sum(w .!= 0)))
    evaluate!(result, "BayesDAG", ŵ_bayesdag[best_i], w, par)

    # Evaluate BOSS
    _evaluate_boss!(result, x, w, par, process_id)

    CUDA.reclaim()

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters
simulations = DataFrames.DataFrame(
    id = 1:10 # Simulation run ID
)

# Run simulations in parallel (one call to runsim per row) and stack the per-run tables
result = ProgressMeter.@showprogress pmap(runsim, eachrow(simulations))
result = reduce(vcat, result)

# Make sure the output directory exists so the finished results are not lost at write time
mkpath(dirname("Results/sachs.csv"))
CSV.write("Results/sachs.csv", result)

# Shut down the worker processes started at the top of the script
Distributed.rmprocs(Distributed.workers())