# =================================================================================================#
# Description: Produces the experimental results for the synthetic nonlinear data
# Author: Ryan Thompson
# =================================================================================================#

using Distributed

# Launch two worker processes. runsim derives a GPU index from each process id as
# (myid() - 1) % 2, so the worker count here is tied to that two-GPU mapping.
Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

# Work from the experiments directory on every process so that the relative paths used by the
# includes and by the output file resolve against it
cd("/Experiments")
# Load the project code. runsim calls ProDAG.fit_mlp / ProDAG.sample and evaluate! calls the
# metric helpers; presumably these files define them — confirm against the files themselves
include("pro_dag.jl")
include("metrics.jl")

import CSV, CUDA, DataFrames, Distributions, Flux, Graphs, JSON, LinearAlgebra, ProgressMeter, 
    PyCall, Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

# Simulate one synthetic nonlinear dataset from a randomly generated DAG.
#
# `par` is unpacked positionally as (n, p, s, id): n training rows, p variables and s nonzeros
# (passed to Graphs.erdos_renyi as its second argument); id is not used here.
#
# Each variable is the output of its own one-hidden-layer ReLU network with 10 hidden units plus
# standard normal noise. The rows of the hidden-layer weights that belong to non-parents are
# zeroed, and variables are filled in topological order so parents are available first.
#
# Returns the tuple (x, x_val, w, ω):
#   x     - n × p training matrix
#   x_val - round(0.1 n) × p validation matrix generated by the same networks
#   w     - p × p matrix whose (j, k) entry is the Euclidean norm of the hidden-layer weights
#           leading from variable j into the network of variable k
#   ω     - vector stacking, variable by variable, the hidden-layer and output-layer weights
function gendata(par)

    n, p, s, _ = par
    n_val = round(Int, 0.1 * n)
    n_hidden = 10

    # Weights with magnitude drawn from U(0.3, 0.7) and an independent random sign
    signed_weights(n_row) = rand(Distributions.Uniform(0.3, 0.7), n_row, n_hidden) .* 
        rand([- 1, 1], n_row, n_hidden)

    # Random undirected graph, then a random acyclic orientation of it
    dag = Graphs.random_orientation_dag(Graphs.erdos_renyi(p, s))
    adjacency = Matrix(Graphs.adjacency_matrix(dag))

    # Containers for the outputs and the noise terms
    w = zeros(p, p)
    x = zeros(n, p)
    x_val = zeros(n_val, p)
    noise = randn(n, p)
    noise_val = randn(n_val, p)
    ω = zeros((p + 1) * n_hidden, p)

    for node in Graphs.topological_sort(dag)

        # Column `node` of the adjacency matrix flags the parents; other rows are zeroed
        hidden = signed_weights(p) .* adjacency[:, node]
        output = signed_weights(1)

        x[:, node] = Flux.relu.(x * hidden) * transpose(output) + noise[:, node]
        x_val[:, node] = Flux.relu.(x_val * hidden) * transpose(output) + noise_val[:, node]

        w[:, node] = sqrt.(sum(hidden .^ 2, dims = 2))
        ω[:, node] = vcat(vec(transpose(hidden)), vec(output))

    end

    return x, x_val, w, vec(ω)

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

# Score the estimate ŵ against the reference matrix w and append one row to `result`.
#
# ŵ is either a single matrix or a three-dimensional array whose third dimension indexes
# samples; a matrix is treated as a stack containing one sample. `par` is unpacked positionally
# as (n, p, s, id) and copied into the appended row.
#
# The Brier score and AUROC are computed on the whole stack. The structural Hamming distance and
# F1 score are computed per sample and averaged. Sparsity is the number of nonzero entries per
# sample, and dag_rate is the share of samples whose graph is acyclic.
#
# Returns the value of push!, i.e. `result` with the new row added.
function evaluate!(result, estimator, ŵ, w, par)

    n, p, s, id = par

    # View a lone matrix as a one-sample stack
    samples = length(size(ŵ)) == 2 ? reshape(ŵ, size(ŵ, 1), size(ŵ, 2), 1) : ŵ
    n_sample = size(samples, 3)

    # Mean over samples of a metric that compares w with a single estimated matrix
    sample_average(metric) = sum([metric(w, samples[:, :, i]) for i in 1:n_sample]) / n_sample

    # Entries are evaluated in the order of the result columns
    row = [
        estimator,
        brier_score(w, samples),
        sample_average(struct_hamming_dist),
        sample_average(f1_score),
        binary_auroc(w, samples),
        count(!iszero, samples) / n_sample,
        count(slice -> !Graphs.is_cyclic(Graphs.SimpleDiGraph(slice)), 
            eachslice(samples, dims = 3)) / n_sample,
        n, p, s, id
    ]

    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# Fit ProDAG over a grid of prior_α values and score the fit with the smallest validation error.
# The first fit uses prior_α = 3p, the last uses prior_α = 0 without warm start, and the fits in
# between are warm-started from the previous fit's μ and σ.
# NOTE(review): the meaning of prior_α, μ and σ is defined by ProDAG, which is not visible here.
function _eval_prodag!(result, x, x_val, w, par)

    n_fit = 10
    ŵ = Vector{Any}(undef, n_fit)
    model = Vector{Any}(undef, n_fit)

    # Initial fit
    fit = ProDAG.fit_mlp(x, prior_α = 3 * par.p, verbose = false, bias = false)
    init_μ = fit.μ
    init_σ = fit.σ
    ŵ[1], model[1] = ProDAG.sample(fit)

    # Grid running from the per-sample absolute sum of the first draw down to zero
    α = range(sum(abs.(ŵ[1])) / size(ŵ[1], 3), 0, n_fit)
    for i in 2:n_fit - 1
        fit = ProDAG.fit_mlp(x, prior_α = α[i], init_μ = init_μ, init_σ = init_σ, 
            verbose = false, bias = false)
        init_μ = fit.μ
        init_σ = fit.σ
        ŵ[i], model[i] = ProDAG.sample(fit)
    end

    # Final fit at the end of the grid, started from scratch
    fit = ProDAG.fit_mlp(x, prior_α = 0, verbose = false, bias = false)
    ŵ[n_fit], model[n_fit] = ProDAG.sample(fit)

    # Validation squared error, averaged over the sampled models of each fit
    val_loss = [Statistics.mean([sum((x_val -  model[i][j](x_val')') .^ 2) 
        for j in 1:size(ŵ[i], 3)]) for i in eachindex(model)]
    best_i = argmin(val_loss)

    return evaluate!(result, "ProDAG", ŵ[best_i], w, par)

end

# Fit DAGMA over a log-spaced grid of lambda1 values and score the fit with the smallest
# validation squared error.
# NOTE(review): fit is expected to return both the estimated matrix and the fitted network —
# confirm against the installed dagma package.
function _eval_dagma!(result, x, x_val, w, par)

    # Import once rather than on every grid point
    dagma = PyCall.pyimport("dagma.nonlinear")
    numpy = PyCall.pyimport("numpy")
    torch = PyCall.pyimport("torch")

    λ = exp.(range(log(1e-3), log(1), 10))
    ŵ = Vector{Any}(undef, length(λ))
    nn = Vector{Any}(undef, length(λ))
    for i in eachindex(λ)
        eq_model = dagma.DagmaMLP(dims = [par.p, 10, 1], bias = false)
        model = dagma.DagmaNonlinear(eq_model)
        ŵ[i], nn[i] = model.fit(x, lambda1 = λ[i], w_threshold = 0.1)
    end

    # Predictions are made on the CPU so they can be compared with the Julia validation matrix
    best_i = argmin([sum((x_val - numpy.array(torch.detach(nn[i].to(torch.device("cpu"))(
        torch.tensor(x_val))))) .^ 2) for i in eachindex(nn)])

    return evaluate!(result, "DAGMA", ŵ[best_i], w, par)

end

# Run DiBS on the JAX device matching this worker's GPU and score two variants: "DiBS" uses the
# 50 particles as they are, and "DiBS+" resamples 1000 particles with probabilities obtained by
# exponentiating and normalizing the first element returned by get_mixture.
function _eval_dibs!(result, x, w, par, gpu_id)

    numpy = PyCall.pyimport("numpy")
    dibs = PyCall.pyimport("dibs")
    PyCall.pyimport("dibs.models")
    PyCall.pyimport("dibs.inference")
    jax = PyCall.pyimport("jax")

    # The device list arrives as a Julia vector, hence the one-based index
    jax.config.update("jax_default_device", jax.devices()[gpu_id + 1])

    graph_model = dibs.models.ErdosReniDAGDistribution(n_vars = par.p)
    likelihood_model = dibs.models.DenseNonlinearGaussian(n_vars = par.p, hidden_layers = [10], 
        obs_noise = 1.0)
    fit = dibs.inference.JointDiBS(x = x, graph_model = graph_model, 
        likelihood_model = likelihood_model)
    key = jax.random.PRNGKey(Int64(hash(par) % Int64))
    g, theta = fit.sample(key = key, n_particles = 50, steps = 5000)

    # Stack the particles along the third dimension, as evaluate! expects
    particles = numpy.stack(g, axis = 2)
    evaluate!(result, "DiBS", particles, w, par)

    # Subtract the maximum before exponentiating to avoid overflow
    log_weights = numpy.array(fit.get_mixture(g, theta)[1])
    weights = exp.(log_weights .- maximum(log_weights))
    weights /= sum(weights)
    ind = rand(Distributions.Categorical(weights), 1000)

    return evaluate!(result, "DiBS+", particles[:, :, ind], w, par)

end

# Run the external BayesDAG implementation over a log-spaced grid of lambda_sparse values and
# score the run whose average number of nonzeros is closest to par.s.
#
# Each run rewrites the per-GPU model config, writes the data to data/<gpu_id>/all.csv, launches
# the Python experiment as a subprocess and reads the sampled DAGs from results/<gpu_id>. The
# working directory is switched to the BayesDAG sources for the duration of the grid and is
# restored to /Experiments even when a run fails.
function _eval_bayesdag!(result, x, w, par, gpu_id)

    numpy = PyCall.pyimport("numpy")
    torch = PyCall.pyimport("torch")

    python_path = "/.cache/pypoetry/virtualenvs/causica-eW3dY-JD-py3.8/bin/python3"
    model_config_path = "configs/bayesdag_nonlinear_$gpu_id.json"
    dataset_config_path = "configs/dataset_config_causal_dataset.json"
    results_dir = "results/$gpu_id"

    # Empty the results directory but keep the directory itself
    clear_results() = foreach(
        entry -> rm(joinpath(results_dir, entry), force = true, recursive = true), 
        readdir(results_dir)
    )

    λ = exp.(range(log(10), log(1000), 10))
    ŵ = Vector{Any}(undef, length(λ))
    avg_nnz = Vector{Float64}(undef, length(λ))

    cd("/Experiments/bayes_dag/src")
    try
        for i in eachindex(λ)

            # Start from a clean results directory so the file search below finds this run only
            clear_results()

            # Write the grid value and the scenario seed into the model config
            json_data = JSON.parsefile(model_config_path)
            json_data["model_hyperparams"]["lambda_sparse"] = λ[i]
            json_data["model_hyperparams"]["random_seed"] = Int64(hash(par) % Int64(2 ^ 32))
            open(model_config_path, "w") do f
                JSON.print(f, json_data)
            end

            CSV.write("data/$gpu_id/all.csv", DataFrames.DataFrame(x, :auto), header = false)
            run(`$python_path -m causica.run_experiment data/$gpu_id \
                -d ./ \
                --output_dir results/$gpu_id \
                --model_config $model_config_path \
                --dataset_config $dataset_config_path \
                --model_type bayesdag_nonlinear \
                --device $gpu_id`)

            # Locate the sampled DAGs written by the run
            dag_files = filter(f -> endswith(f, "bayesdag_dags.pt"), [joinpath(root, file) for 
                (root, dirs, files) in walkdir(results_dir) for file in files])
            isempty(dag_files) && 
                error("BayesDAG produced no bayesdag_dags.pt file under $results_dir")
            dag_path = first(dag_files)

            ŵ[i] = numpy.stack(torch.load(dag_path, weights_only = false), axis = 2)
            avg_nnz[i] = Statistics.mean(sum(ŵ[i] .≠ 0, dims = (1, 2)))

            clear_results()

        end
    finally
        cd("/Experiments")
    end

    best_i = argmin(abs.(avg_nnz .- par.s))

    return evaluate!(result, "BayesDAG", ŵ[best_i], w, par)

end

# Run every estimator on one simulated scenario and collect their scores.
#
# `par` is one row of the simulation design with fields n, p, s and id. The worker's process id
# selects GPU 0 or 1, and both the CUDA and the default Julia RNG are seeded from hash(par).
# ProDAG, DAGMA and BayesDAG are always evaluated; DiBS and DiBS+ only when p equals 10.
#
# Returns a DataFrame with one row per evaluated estimator.
function runsim(par)

    gpu_id = (Distributed.myid() - 1) % 2
    CUDA.device!(gpu_id)
    CUDA.seed!(hash(par))
    Random.seed!(hash(par))

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], bscore = [], shd = [], f1score = [], auroc = [], sparsity = [], 
        dag_rate = [], n = [], p = [], s = [], id = []
    )

    # Generate data; the stacked network weights returned last are not needed here
    x, x_val, w, _ = gendata(par)

    # Evaluate the estimators in a fixed order because they share the seeded RNG
    _eval_prodag!(result, x, x_val, w, par)
    _eval_dagma!(result, x, x_val, w, par)
    if par.p == 10
        _eval_dibs!(result, x, w, par, gpu_id)
    end
    _eval_bayesdag!(result, x, w, par, gpu_id)

    CUDA.reclaim()

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters: every combination of sample size, graph size and run ID
simulations = DataFrames.DataFrame(
        (n = n, p = p, s = s, id = id) for
        n = round.(Int, exp.(range(log(10), log(1000), 5))), # Number of samples
        (p, s) = [(10, 20), (20, 40)], # Number of variables and nonzeros
        id = 1:10 # Simulation run ID
    )

# Create the output directory up front so that a missing folder cannot discard the results
# after all simulations have finished
output_path = "Results/synthetic_nonlinear.csv"
mkpath(dirname(output_path))

# Run simulations in parallel; the workers are released even if a simulation throws
result = try
    ProgressMeter.@showprogress pmap(runsim, eachrow(simulations))
finally
    Distributed.rmprocs(Distributed.workers())
end

# Stack the per-scenario tables and save them
result = reduce(vcat, result)
CSV.write(output_path, result)