# =================================================================================================#
# Description: Produces the experimental results for the comparisons with the localized lasso
# Author: Ryan Thompson
# =================================================================================================#

import CSV, CUDA, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

# `using` brings `pmap` into scope unqualified for the run section at the bottom of the file
using Distributed

# Start two worker processes; the @everywhere block below loads the estimators and the
# simulation functions on every process so pmap can run them remotely
Distributed.addprocs(2)

Distributed.@sync Distributed.@everywhere begin

include("Estimators/localized_lasso.jl")
include("Estimators/contextual_lasso.jl")

import CSV, CUDA, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

function gendata(par)

    # Simulate training, validation and test sets for one scenario.
    # par holds, in order: loss, n, p, m, mu_var, rho, s_min, s_max, id (id is unused here).
    # Returns x, z, y for each of the three splits (each split has n rows), followed by the
    # true test-set coefficient matrix beta_test (n x p).
    loss, n, p, m, mu_var, rho, s_min, s_max, _ = par

    # Explanatory features: each row is drawn from N(0, Sigma) with Sigma[i, j] = rho^|i - j|
    Sigma = [rho ^ abs(i - j) for i in 1:p, j in 1:p]
    feature_dist = Distributions.MvNormal(zeros(p), Sigma)
    draw_x() = permutedims(rand(feature_dist, n))
    x_train, x_valid, x_test = draw_x(), draw_x(), draw_x()

    # Contextual features: entries drawn independently from Uniform(-1, 1)
    unit_box = Distributions.Uniform(- 1, 1)
    draw_z() = rand(unit_box, n, m)
    z_train, z_valid, z_test = draw_z(), draw_z(), draw_z()

    # Coefficients: feature j gets a centre c[j, :] in context space and a radius r[j] equal to
    # the s[j]-quantile of the training contexts' distances to that centre, so beta_j is
    # nonzero for roughly a fraction s[j] of the training samples
    s = rand(Distributions.Uniform(s_min, s_max), p)
    c = rand(unit_box, p, m)
    dists_to(centre) = [LinearAlgebra.norm(z_train[i, :] - centre, 2) for i in 1:n]
    r = [Statistics.quantile(dists_to(c[j, :]), s[j]) for j in 1:p]

    # Inside the ball, beta_j(z) = 1 - 0.5 * ||z - c_j|| / r_j (from 1 at the centre down to
    # 0.5 at the edge); outside the ball it is 0
    function beta(z, centre, radius)
        dist = LinearAlgebra.norm(z - centre, 2)
        return (1 - 0.5 * dist / radius) * (dist <= radius)
    end
    coefs(zs) = [beta(zs[i, :], c[j, :], r[j]) for i in 1:n, j in 1:p]
    beta_train, beta_valid, beta_test = coefs(z_train), coefs(z_valid), coefs(z_test)

    # Linear predictor sum_j x_ij * beta_ij, rescaled by one common factor so that the
    # training predictor has population (uncorrected) variance mu_var
    linpred(x, b) = vec(sum(x .* b, dims = 2))
    raw_train = linpred(x_train, beta_train)
    kappa = sqrt(mu_var) / Statistics.std(raw_train, corrected = false)
    mu_train = raw_train * kappa
    mu_valid = linpred(x_valid, beta_valid) * kappa
    mu_test = linpred(x_test, beta_test) * kappa

    # Response: unit-variance Gaussian under squared-error loss; otherwise Bernoulli with
    # success probability 1 / (1 + exp(-mu)), stored as Float64 zeros and ones
    function draw_y(mu)
        if loss == Flux.mse
            return rand.(Distributions.Normal.(mu, 1))
        end
        return float.(rand.(Distributions.Bernoulli.(1 ./ (1 .+ exp.(- mu)))))
    end
    y_train, y_valid, y_test = draw_y(mu_train), draw_y(mu_valid), draw_y(mu_test)

    return (x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test,
        beta_test)

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

function evaluate!(result, estimator, y_hat, beta_hat, y_test, beta_test, par)

    # Score one estimator on the test set and append a row of metrics to `result`.
    # par holds, in order: loss, n, p, m, mu_var, rho, s_min, s_max, id; all of them are
    # copied into the row, and n and p also normalise the selection counts.
    # Pass beta_hat = missing when the estimator has no coefficients; the selection metrics
    # are then recorded as missing.
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Prediction metric: test loss relative to the loss of predicting all zeros
    rel_loss = loss(y_hat, y_test) / loss(zeros(n), y_test)

    # Selection metrics: a coefficient counts as active when its magnitude exceeds eps().
    # Confusion-matrix cells are reported as fractions of all n * p coefficients.
    if ismissing(beta_hat)
        true_pos = true_neg = false_pos = false_neg = f1_score = sparsity = missing
    else
        is_active = b -> abs.(b) .> eps()
        # Written as .!(... .> eps()) rather than .<= so NaN entries count as inactive
        truth_on, truth_off = findall(is_active(beta_test)), findall(.!is_active(beta_test))
        est_on, est_off = findall(is_active(beta_hat)), findall(.!is_active(beta_hat))
        overlap = (a, b) -> length(intersect(a, b)) / (n * p)
        true_pos = overlap(est_on, truth_on)
        true_neg = overlap(est_off, truth_off)
        false_pos = overlap(est_on, truth_off)
        false_neg = overlap(est_off, truth_on)
        f1_score = (2 * true_pos) / (2 * true_pos + false_pos + false_neg)
        sparsity = length(est_on) / (n * p)
    end

    # Append the metrics followed by the scenario parameters
    row = [estimator, rel_loss, true_pos, true_neg, false_pos, false_neg, f1_score,
        sparsity, loss, n, p, m, mu_var, rho, s_min, s_max, id]
    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

function runsim(par)

    # Run one simulation scenario: generate data with `gendata`, fit the localized lasso and
    # the contextual lasso, and return a DataFrame with one row of metrics per estimator
    # (rows are built by `evaluate!`). par is one row of the simulation design.

    # Pin this worker to a GPU, cycling through all visible devices by worker ID. Using the
    # device count rather than a hard-coded 2 keeps the mapping valid on single-GPU machines.
    CUDA.device!((Distributed.myid() - 1) % length(CUDA.devices()))

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], rel_loss = [], true_pos = [], true_neg = [], false_pos = [],
        false_neg = [], f1_score = [], sparsity = [], loss = [], n = [], p = [], m = [],
        mu_var = [], rho = [], s_min = [], s_max = [], id = []
    )

    # Generate data
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test,
        beta_test = gendata(par)

    # Scenario parameters needed here: loss, p (explanatory) and m (contextual features)
    loss, _, p, m, _, _, _, _, _ = par

    # Network configuration: three hidden layers of equal width h, where h is the positive
    # root of 2h^2 + (m + p + 3)h + p = 32mp (rounded to an integer). That polynomial is the
    # parameter count of an m -> h -> h -> h -> p dense network with biases.
    # NOTE(review): confirm this matches ContextualLasso's architecture.
    n_neuron = round(Int, 1 / 4 * (sqrt((m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)) - m - p - 3))
    hidden_layers = repeat([n_neuron], 3)

    try
        # Evaluate localized lasso
        beta_hat, y_hat = llasso(
            x_train, z_train, y_train,
            x_valid, z_valid, y_valid,
            x_test, z_test,
            intercept = false
        )
        evaluate!(result, "Localized lasso", y_hat, beta_hat, y_test, beta_test, par)

        # Evaluate contextual lasso
        fit = ContextualLasso.classo(
            x_train, z_train, y_train,
            x_valid, z_valid, y_valid,
            verbose = false, intercept = false, relax = true,
            loss = loss, hidden_layers = hidden_layers
        )
        beta_hat = ContextualLasso.coef(fit, z_test)
        y_hat = ContextualLasso.predict(fit, x_test, z_test)
        evaluate!(result, "Contextual lasso", y_hat, beta_hat, y_test, beta_test, par)
    finally
        # Release cached GPU memory even when an estimator throws, so the worker's next job
        # does not inherit it
        CUDA.reclaim()
    end

    return result

end

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters: one row per combination of the values listed below. The
# generator iterates its clauses as a product with the first clause varying fastest, so the
# sample size n cycles within each simulation run ID.
simulations = DataFrames.DataFrame(
        (loss = loss_fn, n = n_obs, p = n_expl, m = n_ctx, mu_var = snr, rho = corr,
            s_min = lo_sparsity, s_max = hi_sparsity, id = run_id) for
        loss_fn = [Flux.mse],
        n_obs = round.(Int, exp.(range(log(10), log(1000), 10))), # Log-spaced sample sizes, 10 to 1000
        n_expl = 10, # Number of explanatory features
        n_ctx = 2, # Number of contextual features
        snr = 5, # Signal-to-noise ratio
        corr = 0.5, # Correlation coefficient
        lo_sparsity = 0.05, # Minimal sparsity level
        hi_sparsity = 0.15, # Maximal sparsity level
        run_id = 1:10 # Simulation run ID
        )

# Run all simulations
# CUDA.jl is not reproducible with the default RNG, so every process replaces
# Random.default_rng with its own MersenneTwister.
# NOTE(review): this overwrites a method owned by Random (type piracy; Julia warns about it).
Distributed.@sync Distributed.@everywhere begin
    rng = Random.MersenneTwister((Distributed.myid() - 1) % 2)
    Random.default_rng() = rng
end
# pmap gives each job to whichever worker is free, so the job-to-worker assignment and the
# order in which a worker sees its jobs vary from run to run. Reseeding the worker's generator
# with the job's row number at the start of every job makes each simulation's random stream
# independent of that scheduling.
result = ProgressMeter.@showprogress pmap(
    ((row, par),) -> (Random.seed!(rng, row); runsim(par)),
    enumerate(eachrow(simulations))
)
result = reduce(vcat, result)
CSV.write("Results/localized.csv", result)

Distributed.rmprocs(Distributed.workers())