# =================================================================================================#
# Description: Benchmarks the run times of the Bayesian DAG estimators ProDAG, DiBS, and BayesDAG
# Author: Ryan Thompson
# =================================================================================================#

using Distributed

Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

cd("/Experiments")
include("pro_dag.jl")

import CSV, CUDA, DataFrames, Distributions, Graphs, LinearAlgebra, ProgressMeter, Random, 
    Statistics, PyCall, JSON, Flux

#==================================================================================================#
# Function to generate data
#==================================================================================================#

"""
    gendata(par)

Simulate one data set from a random DAG for the scenario `par`.

`par` is destructured positionally as `(n, p, s, sem_type, id)`: `n` samples, `p` variables,
`s` passed to `Graphs.erdos_renyi` (an edge count when `s` is an integer), and `sem_type`,
which must be `"linear"` or `"nonlinear"`. `id` is not used here.

Returns `(x, w)`: `x` is the `n × p` data matrix and `w` is a `p × p` matrix whose nonzero
pattern follows `Graphs.adjacency_matrix` of the sampled DAG. For the linear SEM `w` holds the
edge weights; for the nonlinear SEM `w[i, k]` is the Euclidean norm of the hidden-layer
weights from variable `i` into node `k`.

Throws `ArgumentError` for any other `sem_type`.
"""
function gendata(par)

    # Save scenario parameters
    n, p, s, sem_type, id = par

    # Generate graph: random undirected skeleton, then a random acyclic orientation
    g = Graphs.erdos_renyi(p, s)
    g = Graphs.random_orientation_dag(g)

    if sem_type == "linear"

        # Create weighted adjacency matrix: magnitudes ~ U(0.3, 0.7) with random signs
        w_values = rand(Distributions.Uniform(0.3, 0.7), p, p) .* rand([-1, 1], p, p)
        w = Matrix(Graphs.adjacency_matrix(g)) .* w_values

        # Generate features from x = x * w + ε, i.e. x = ε * (I - w)⁻¹; right division
        # solves the linear system without forming the explicit inverse
        ε = randn(n, p)
        x = ε / (LinearAlgebra.I - w)

    elseif sem_type == "nonlinear"

        # Generate features column by column in topological order, so each node's parent
        # columns of `x` are filled before the node itself is computed
        a = Matrix(Graphs.adjacency_matrix(g))
        w = zeros(p, p)
        x = zeros(n, p)
        ε = randn(n, p)
        for k in Graphs.topological_sort(g)
            # Hidden-layer weights (10 units); rows for non-parents of k are zeroed out
            ω_h = rand(Distributions.Uniform(0.3, 0.7), p, 10) .* rand([-1, 1], p, 10)
            ω_h = ω_h .* a[:, k]
            ω_o = rand(Distributions.Uniform(0.3, 0.7), 1, 10) .* rand([-1, 1], 1, 10)
            # NNlib/Flux activations are scalar functions and must be broadcast over arrays
            x[:, k] = Flux.relu.(x * ω_h) * vec(ω_o) .+ ε[:, k]
            w[:, k] = vec(sqrt.(sum(abs2, ω_h, dims = 2)))
        end

    else
        throw(ArgumentError(
            "unknown sem_type \"$sem_type\"; expected \"linear\" or \"nonlinear\""
        ))
    end

    # Return generated data
    return x, w

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

"""
    evaluate!(result, estimator, time, par)

Append one timing row `[estimator, time, n, p, s, sem_type, id]` to `result`, where the last
five values are taken positionally from `par`. Returns whatever `push!` returns (`result`).
"""
function evaluate!(result, estimator, time, par)

    # Pull the five scenario fields out of `par`, in column order
    (n, p, s, sem_type, id) = par
    row = [estimator, time, n, p, s, sem_type, id]

    # Record the row
    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# Delete every entry inside `dir`, keeping `dir` itself.
function _clear_dir(dir)
    foreach(entry -> rm(joinpath(dir, entry), force = true, recursive = true), readdir(dir))
    return nothing
end

# Time DiBS (JointDiBS, 50 particles, 5000 steps) on the data `x` and record it in `result`.
# The likelihood model is chosen from `par.sem_type`.
function _time_dibs!(result, x, par, gpu_id)
    dibs = PyCall.pyimport("dibs")
    # Import the submodules so they are reachable as attributes of `dibs`
    PyCall.pyimport("dibs.models")
    PyCall.pyimport("dibs.inference")
    jax = PyCall.pyimport("jax")
    # PyCall converts the returned Python list to a Julia vector, hence 1-based indexing
    jax.config.update("jax_default_device", jax.devices()[gpu_id + 1])
    graph_model = dibs.models.ErdosReniDAGDistribution(n_vars = par.p)
    likelihood_model = if par.sem_type == "linear"
        dibs.models.LinearGaussian(n_vars = par.p, min_edge = 0.1, obs_noise = 1.0)
    else
        dibs.models.DenseNonlinearGaussian(n_vars = par.p, hidden_layers = [10],
            obs_noise = 1.0)
    end
    fit = dibs.inference.JointDiBS(x = x, graph_model = graph_model,
        likelihood_model = likelihood_model)
    # `% Int64` wraps the UInt64 hash into a (possibly negative) Int64 seed
    key = jax.random.PRNGKey(hash(par) % Int64)
    eval_time = @elapsed fit.sample(key = key, n_particles = 50, steps = 5000)
    return evaluate!(result, "DiBS", eval_time, par)
end

# Time a BayesDAG run (causica's `run_experiment` CLI) on the data `x` and record it in
# `result`. Files are exchanged through per-GPU directories so the two workers don't collide.
function _time_bayesdag!(result, x, par, gpu_id)
    model_type = "bayesdag_$(par.sem_type)"
    python_path = "/.cache/pypoetry/virtualenvs/causica-eW3dY-JD-py3.8/bin/python3"
    model_config_path = "configs/$(model_type)_$gpu_id.json"
    dataset_config_path = "configs/dataset_config_causal_dataset.json"
    data_dir = "data/$gpu_id"
    results_dir = "results/$gpu_id"
    # Build the argument vector explicitly rather than relying on line continuations
    # inside a backtick literal
    cmd = Cmd([
        python_path, "-m", "causica.run_experiment", data_dir,
        "-d", "./",
        "--output_dir", results_dir,
        "--model_config", model_config_path,
        "--dataset_config", dataset_config_path,
        "--model_type", model_type,
        "--device", string(gpu_id),
    ])
    # `cd(f, dir)` returns to the previous working directory even if the run throws
    eval_time = cd("/Experiments/bayes_dag/src") do
        _clear_dir(results_dir)
        # Seed the BayesDAG run through its model config file
        json_data = JSON.parsefile(model_config_path)
        json_data["model_hyperparams"]["random_seed"] = Int64(hash(par) % Int64(2 ^ 32))
        open(model_config_path, "w") do f
            JSON.print(f, json_data)
        end
        CSV.write("$data_dir/all.csv", DataFrames.DataFrame(x, :auto), header = false)
        t = @elapsed run(cmd)
        _clear_dir(results_dir)
        t
    end
    return evaluate!(result, "BayesDAG", eval_time, par)
end

"""
    runsim(par)

Time ProDAG, DiBS and BayesDAG (in that order) on one data set simulated for scenario `par`.

`par` must support positional destructuring `(n, p, s, sem_type, id)` (as used by
`gendata`) and field access (`par.p`, `par.sem_type`); `DataFrameRow`s of the simulation grid
satisfy both. Returns a DataFrame with one timing row per estimator.

Side effects: selects a CUDA device from the worker id, reseeds the Julia and CUDA RNGs from
`hash(par)`, and rewrites BayesDAG's per-GPU config, data and results files under
`/Experiments/bayes_dag/src`.
"""
function runsim(par)

    # Workers share the two GPUs by process id
    gpu_id = (Distributed.myid() - 1) % 2
    CUDA.device!(gpu_id)
    CUDA.seed!(hash(par))
    Random.seed!(hash(par))

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], time = [], n = [], p = [], s = [], sem_type = [], id = []
    )

    # Generate data
    x, w = gendata(par)

    # Evaluate ProDAG; `prior_α` is set to the sum of absolute entries of the true `w`.
    # NOTE(review): `@elapsed` includes JIT compilation on a worker's first call.
    if par.sem_type == "linear"
        eval_time = @elapsed ProDAG.fit_linear(x, prior_α = sum(abs.(w)), verbose = false)
    elseif par.sem_type == "nonlinear"
        eval_time = @elapsed ProDAG.fit_mlp(x, prior_α = sum(abs.(w)), verbose = false,
            bias = false)
    else
        throw(ArgumentError(
            "unknown sem_type \"$(par.sem_type)\"; expected \"linear\" or \"nonlinear\""
        ))
    end
    evaluate!(result, "ProDAG", eval_time, par)

    # Evaluate DiBS
    _time_dibs!(result, x, par, gpu_id)

    # Evaluate BayesDAG
    _time_bayesdag!(result, x, par, gpu_id)

    CUDA.reclaim()

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters
simulations = vcat(

    # Linear experiments ==========================================================================#
    DataFrames.DataFrame(
        (n = n, p = p, s = s, sem_type = sem_type, id = id) for
        n = 100, # Number of samples
        (p, s) = [(20, 40)], # Number of variables and nonzeros
        sem_type = ["linear"], # Type of SEM
        id = 1:10 # Simulation run ID
    ),
    #==============================================================================================#

    # Nonlinear experiments =======================================================================#
    DataFrames.DataFrame(
        (n = n, p = p, s = s, sem_type = sem_type, id = id) for
        n = 100, # Number of samples
        (p, s) = [(10, 20)], # Number of variables and nonzeros
        sem_type = ["nonlinear"], # Type of SEM
        id = 1:10 # Simulation run ID
    ),
    #==============================================================================================#

)

# Run simulations in parallel; the worker processes are released even if a simulation throws
result = try
    per_run = ProgressMeter.@showprogress pmap(runsim, eachrow(simulations))
    reduce(vcat, per_run)
finally
    Distributed.rmprocs(Distributed.workers())
end

# Make sure the output directory exists so a long run isn't lost at the final write
mkpath("Results")
CSV.write("Results/timings_comparisons.csv", result)