# =================================================================================================#
# Description: Produces the experimental results comparing estimators on synthetic data with
# fixed (context-independent) coefficients
# Author: Ryan Thompson
# =================================================================================================#

import CSV, CUDA, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

using Distributed

# Start two worker processes; `runsim` later pins each worker to CUDA device 0 or 1 using
# `(myid() - 1) % 2`
Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

include("Estimators/lasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/llspin.jl")
include("Estimators/contextual_lasso.jl")

import CSV, CUDA, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

# gendata(par)
#
# Simulate train/validation/test splits for the fixed-coefficient design.
#
# `par` is destructured positionally as (loss, n, p, m, mu_var, rho, s_min, s_max, id); only
# loss, n, p, m, mu_var and rho are used (this design ignores the sparsity bounds and run id).
# Every split has n rows:
#   - x (n x p): zero-mean Gaussian rows with covariance Sigma[i, j] = rho^|i - j|
#   - z (n x m): i.i.d. Uniform(-1, 1)
#   - coefficients: 1 on the first 10 explanatory features, 0 elsewhere, in every row
#     (so they do not depend on z); requires p >= 10
#   - linear predictor scaled by one factor so the training predictor has (uncorrected)
#     variance mu_var; the same factor is applied to validation and test
#   - response: Normal(mu, 1) when loss == Flux.mse, otherwise Bernoulli with a logistic
#     link, stored as floats
#
# Returns (x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test,
# beta_test), where beta_test is the n x p test coefficient matrix. Draws from the default
# RNG happen in the order: all x, then all z, then all y (train, valid, test in each group).
function gendata(par)

    loss, n, p, m, mu_var, rho, _, _, _ = par

    # Explanatory features: three independent n x p draws from N(0, Sigma)
    Sigma = [rho ^ abs(i - j) for i in 1:p, j in 1:p]
    feature_dist = Distributions.MvNormal(zeros(p), Sigma)
    x_train, x_valid, x_test = (permutedims(rand(feature_dist, n)) for _ in 1:3)

    # Contextual features: three independent n x m draws from Uniform(-1, 1)
    context_dist = Distributions.Uniform(-1, 1)
    z_train, z_valid, z_test = (rand(context_dist, n, m) for _ in 1:3)

    # Fixed coefficients shared by every split and every row
    beta = hcat(ones(n, 10), zeros(n, p - 10))

    # Row-wise linear predictors, rescaled so the training predictor has variance mu_var
    raw_train, raw_valid, raw_test =
        (vec(sum(x .* beta, dims = 2)) for x in (x_train, x_valid, x_test))
    kappa = sqrt(mu_var) / Statistics.std(raw_train, corrected = false)
    mu_train, mu_valid, mu_test = raw_train * kappa, raw_valid * kappa, raw_test * kappa

    # Response: unit-variance Gaussian noise for squared-error loss, else logistic Bernoulli
    draw_response = if loss == Flux.mse
        mu -> rand.(Distributions.Normal.(mu, 1))
    else
        mu -> float.(rand.(Distributions.Bernoulli.(1 ./ (1 .+ exp.(- mu)))))
    end
    y_train, y_valid, y_test =
        draw_response(mu_train), draw_response(mu_valid), draw_response(mu_test)

    return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test, beta

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

# evaluate!(result, estimator, y_hat, beta_hat, y_test, beta_test, par)
#
# Score one estimator on the test split and append a row to `result` via `push!`.
#
# Arguments
#   result    : collection receiving the row (a DataFrame in `runsim`)
#   estimator : label stored in the first column
#   y_hat     : test predictions, scored as loss(y_hat, y_test)
#   beta_hat  : estimated coefficient matrix, or `missing` for estimators without
#               coefficients (all selection metrics are then `missing`)
#   y_test    : test responses
#   beta_test : true coefficient matrix; must have the same size as a non-missing beta_hat
#   par       : scenario parameters, destructured as
#               (loss, n, p, m, mu_var, rho, s_min, s_max, id)
#
# The relative loss is loss(y_hat, y_test) / loss(zeros(n), y_test), i.e. relative to the
# all-zero prediction. A coefficient is active when its absolute value exceeds eps(). The
# TP/TN/FP/FN rates and sparsity are fractions of all n * p coefficients, and
# F1 = 2TP / (2TP + FP + FN).
#
# Throws DimensionMismatch if beta_hat is not missing and size(beta_hat) != size(beta_test).
function evaluate!(result, estimator, y_hat, beta_hat, y_test, beta_test, par)

    # Save scenario parameters
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Prediction metric: loss relative to predicting zero everywhere
    rel_loss = loss(y_hat, y_test) / loss(zeros(n), y_test)

    # Selection metrics
    if ismissing(beta_hat)
        true_pos = true_neg = false_pos = false_neg = f1_score = sparsity = missing
    else
        # Comparing index sets of differently shaped arrays would silently yield meaningless
        # counts, so insist on matching shapes before building element-wise masks
        size(beta_hat) == size(beta_test) || throw(DimensionMismatch(
            "beta_hat has size $(size(beta_hat)) but beta_test has size $(size(beta_test))"))
        active_true = abs.(beta_test) .> eps()
        active_hat = abs.(beta_hat) .> eps()
        n_coef = n * p
        true_pos = count(active_hat .& active_true) / n_coef
        true_neg = count(.!active_hat .& .!active_true) / n_coef
        false_pos = count(active_hat .& .!active_true) / n_coef
        false_neg = count(.!active_hat .& active_true) / n_coef
        f1_score = (2 * true_pos) / (2 * true_pos + false_pos + false_neg)
        sparsity = count(active_hat) / n_coef
    end

    # Update results
    return push!(result, [estimator, rel_loss, true_pos, true_neg, false_pos, false_neg,
        f1_score, sparsity, loss, n, p, m, mu_var, rho, s_min, s_max, id])

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# runsim(par)
#
# Run one simulation scenario and return a DataFrame with one row per estimator.
#
# `par` is a scenario row destructured positionally as
# (loss, n, p, m, mu_var, rho, s_min, s_max, id). The worker selects CUDA device
# (myid() - 1) % 2, simulates data with `gendata(par)`, sizes three equal-width hidden
# layers from m and p, and then fits and scores (via `evaluate!`), in this order:
#   1. a deep neural network: `ContextualLasso.classo` with an all-ones explanatory column,
#      [x z] as contextual input and lambda = Inf (no coefficients to score)
#   2. a contextual linear model: `ContextualLasso.classo` on (x, z) with lambda = Inf
#   3. the lasso on [x z], penalty factor 1 on x and 0 on z
#   4. LassoNet and 5. LLSPIN on [x z], both seeded with the scenario id
#   6. the contextual lasso with relax = true
# Finally calls `CUDA.reclaim()`.
function runsim(par)

    # Pin this worker to CUDA device 0 or 1 according to its process id
    CUDA.device!((Distributed.myid() - 1) % 2)

    # Empty results table; evaluate! appends one row per estimator
    result = DataFrames.DataFrame(
        estimator = [], rel_loss = [], true_pos = [], true_neg = [], false_pos = [],
        false_neg = [], f1_score = [], sparsity = [], loss = [], n = [], p = [], m = [],
        mu_var = [], rho = [], s_min = [], s_max = [], id = []
    )

    # Simulate the three data splits
    (x_train, z_train, y_train,
        x_valid, z_valid, y_valid,
        x_test, z_test, y_test, beta_test) = gendata(par)
    loss, n, p, m, _, _, _, _, id = par

    # Width shared by the three hidden layers: the positive root w of
    # 2w^2 + (m + p + 3)w + (p - 32mp) = 0, rounded to the nearest integer
    # (presumably chosen to match a parameter budget — confirm with the paper)
    disc = (m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)
    width = round(Int, 1 / 4 * (sqrt(disc) - m - p - 3))
    hidden_layers = fill(width, 3)

    # 1. Deep neural network (prediction metrics only)
    model = ContextualLasso.classo(
        ones(n, 1), hcat(x_train, z_train), y_train,
        ones(n, 1), hcat(x_valid, z_valid), y_valid,
        intercept = false, lambda = Inf, standardise_x = false, verbose = false,
        loss = loss, hidden_layers = hidden_layers
    )
    y_hat = ContextualLasso.predict(model, ones(n, 1), hcat(x_test, z_test))
    evaluate!(result, "Deep neural network", y_hat, missing, y_test, beta_test, par)

    # 2. Contextual linear model
    model = ContextualLasso.classo(
        x_train, z_train, y_train,
        x_valid, z_valid, y_valid,
        lambda = Inf, verbose = false, intercept = false,
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = ContextualLasso.coef(model, z_test)
    y_hat = ContextualLasso.predict(model, x_test, z_test)
    evaluate!(result, "Contextual linear model", y_hat, beta_hat, y_test, beta_test, par)

    # 3. Lasso: one global coefficient vector, replicated for every test row
    coefs = lasso(
        hcat(x_train, z_train), y_train,
        hcat(x_valid, z_valid), y_valid,
        penalty_factor = vcat(ones(p), zeros(m)), intercept = false,
        loss = loss
    )
    beta_hat = repeat(coefs[1:p]', n)
    y_hat = hcat(x_test, z_test) * coefs
    evaluate!(result, "Lasso", y_hat, beta_hat, y_test, beta_test, par)

    # 4. LassoNet: global coefficients, replicated for every test row
    beta_global, y_hat = lassonet(
        hcat(x_train, z_train), y_train,
        hcat(x_valid, z_valid), y_valid,
        hcat(x_test, z_test),
        verbose = false, seed = id,
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = repeat(beta_global[1:p]', n)
    evaluate!(result, "Lassonet", y_hat, beta_hat, y_test, beta_test, par)

    # 5. LLSPIN: per-row coefficients; keep only the explanatory-feature columns
    beta_rows, y_hat = llspin(
        hcat(x_train, z_train), y_train,
        hcat(x_valid, z_valid), y_valid,
        hcat(x_test, z_test),
        verbose = false, seed = id,
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = beta_rows[:, 1:p]
    evaluate!(result, "LLSPIN", y_hat, beta_hat, y_test, beta_test, par)

    # 6. Contextual lasso (relaxed)
    model = ContextualLasso.classo(
        x_train, z_train, y_train,
        x_valid, z_valid, y_valid,
        verbose = false, intercept = false, relax = true,
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = ContextualLasso.coef(model, z_test)
    y_hat = ContextualLasso.predict(model, x_test, z_test)
    evaluate!(result, "Contextual lasso", y_hat, beta_hat, y_test, beta_test, par)

    # Release cached GPU memory before this worker takes the next scenario
    CUDA.reclaim()

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters (one row per scenario and replication)
simulations = DataFrames.DataFrame(
    (loss = loss, n = n, p = p, m = m, mu_var = mu_var, rho = rho, s_min = s_min, s_max = s_max,
        id = id) for
    loss = [Flux.mse],
    n = round.(Int, exp.(range(log(10), log(10000), 10))), # Number of samples
    p = 100, # Number of explanatory features
    m = 5, # Number of contextual features
    mu_var = 5, # Signal-to-noise ratio
    rho = 0.5, # Correlation coefficient
    s_min = 0.05, # Minimal sparsity level (recorded only; gendata's fixed design ignores it)
    s_max = 0.15, # Maximal sparsity level (recorded only; gendata's fixed design ignores it)
    id = 1:10 # Simulation run ID
    )

# CUDA.jl is not reproducible with the default RNG, so every process replaces
# `Random.default_rng()` with its own MersenneTwister. Binding it to a `const` keeps
# `default_rng()` type-stable (a non-const global would be inferred as `Any`).
Distributed.@sync Distributed.@everywhere begin
    const rng = Random.MersenneTwister((Distributed.myid() - 1) % 2)
    Random.default_rng() = rng
end

# pmap hands each scenario to whichever worker is free, so per-worker seeds alone make a
# scenario's random stream depend on scheduling and on what that worker ran earlier. Reseed
# the worker's RNG with the scenario's row index before each run so every row is
# reproducible regardless of which worker executes it.
run_seeded = ((i, par),) -> begin
    Random.seed!(Random.default_rng(), i)
    runsim(par)
end

# Run all simulations
result = ProgressMeter.@showprogress pmap(run_seeded, enumerate(eachrow(simulations)))
result = reduce(vcat, result)
mkpath("Results")
CSV.write("Results/fixed.csv", result)

Distributed.rmprocs(Distributed.workers())