# =================================================================================================#
# Description: Produces the experimental results for the comparisons of estimators on the synthetic
# classification data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/lasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/llspin.jl")
include("Estimators/contextual_lasso.jl")

import CSV, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

function gendata(par)

    # Unpack the scenario; par must iterate in exactly this order (id is unpacked but not used
    # in this function)
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Explanatory features: zero-mean Gaussian with covariance rho^|i - j| between features i
    # and j; the draws are transposed so that each row is one observation
    feature_dist = Distributions.MvNormal(zeros(p), [rho ^ abs(i - j) for i in 1:p, j in 1:p])
    draw_x() = permutedims(rand(feature_dist, n))
    x_train = draw_x()
    x_valid = draw_x()
    x_test = draw_x()

    # Contextual features: independent Uniform(-1, 1) entries in an n x m matrix
    context_dist = Distributions.Uniform(- 1, 1)
    draw_z() = rand(context_dist, n, m)
    z_train = draw_z()
    z_valid = draw_z()
    z_test = draw_z()

    # Each coefficient j gets a quantile level s[j], a centre c[j, :] and a radius r[j], where
    # r[j] is the s[j]-quantile of the distances from the training contexts to the centre
    s = rand(Distributions.Uniform(s_min, s_max), p)
    c = rand(context_dist, p, m)
    r = [
        Statistics.quantile(
            [LinearAlgebra.norm(z_train[i, :] - c[j, :], 2) for i in 1:n], s[j]
        )
        for j in 1:p
    ]

    # Coefficient value at context z: decreases linearly with the distance to the centre and is
    # zero once the distance exceeds the radius
    function beta(z, centre, radius)
        dist = LinearAlgebra.norm(z - centre, 2)
        return (1 - 0.5 * dist / radius) * (dist <= radius)
    end

    # n x p matrix of coefficients, one row per observation
    coefficients(z) = [beta(z[i, :], c[j, :], r[j]) for i in 1:n, j in 1:p]
    beta_train = coefficients(z_train)
    beta_valid = coefficients(z_valid)
    beta_test = coefficients(z_test)

    # Linear predictor of every observation (row-wise inner product of features and coefficients)
    linear_predictor(x, b) = vec(sum(x .* b, dims = 2))
    raw_train = linear_predictor(x_train, beta_train)
    raw_valid = linear_predictor(x_valid, beta_valid)
    raw_test = linear_predictor(x_test, beta_test)

    # Rescale all three sets by the factor that gives the training predictor variance mu_var
    kappa = sqrt(mu_var) / Statistics.std(raw_train, corrected = false)
    mu_train = raw_train * kappa
    mu_valid = raw_valid * kappa
    mu_test = raw_test * kappa

    # Response: Gaussian with unit standard deviation around mu for the squared-error loss,
    # otherwise Bernoulli with success probability logistic(mu), converted to floating point
    draw_y = if loss == Flux.mse
        mu -> rand.(Distributions.Normal.(mu, 1))
    else
        mu -> float.(rand.(Distributions.Bernoulli.(1 ./ (1 .+ exp.(- mu)))))
    end
    y_train = draw_y(mu_train)
    y_valid = draw_y(mu_valid)
    y_test = draw_y(mu_test)

    # Return generated data (only the test coefficients are returned)
    return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test,
        beta_test

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

function evaluate!(result, estimator, y_hat, beta_hat, y_test, beta_test, par)

    # Appends one row of metrics for `estimator` to `result` and returns the value of push!.
    # beta_hat may be `missing` for estimators without coefficients; the selection metrics are
    # then recorded as missing.

    # Unpack the scenario; par must iterate in exactly this order
    loss, n, p, m, mu_var, rho, s_min, s_max, id = par

    # Prediction metric: loss of the predictions relative to the loss of an all-zero prediction
    rel_loss = loss(y_hat, y_test) / loss(zeros(n), y_test)

    # Helpers for the selection metrics; all proportions are taken over the n * p entries
    total = n * p
    isactive(b) = abs.(b) .> eps()
    shared_fraction(a, b) = length(intersect(a, b)) / total

    if ismissing(beta_hat)
        true_pos = true_neg = false_pos = false_neg = f1_score = sparsity = missing
    else
        # Positions of the active (nonzero up to eps()) and inactive entries
        mask_true = isactive(beta_test)
        mask_hat = isactive(beta_hat)
        on_true = findall(mask_true)
        on_hat = findall(mask_hat)
        off_true = findall(.!mask_true)
        off_hat = findall(.!mask_hat)

        # Confusion-matrix proportions, F1 score and proportion of active estimates
        true_pos = shared_fraction(on_hat, on_true)
        true_neg = shared_fraction(off_hat, off_true)
        false_pos = shared_fraction(on_hat, off_true)
        false_neg = shared_fraction(off_hat, on_true)
        f1_score = (2 * true_pos) / (2 * true_pos + false_pos + false_neg)
        sparsity = length(on_hat) / total
    end

    # Update results with the metrics followed by the scenario parameters
    row = [estimator, rel_loss, true_pos, true_neg, false_pos, false_neg, f1_score,
        sparsity, loss, n, p, m, mu_var, rho, s_min, s_max, id]
    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

function runsim(par)

    # Runs every estimator on one simulated scenario and returns a data frame with one row of
    # metrics per estimator. par must iterate as
    # (loss, n, p, m, mu_var, rho, s_min, s_max, id), e.g. a row of the simulation design.

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], rel_loss = [], true_pos = [], true_neg = [], false_pos = [], 
        false_neg = [], f1_score = [], sparsity = [], loss = [], n = [], p = [], m = [], 
        mu_var = [], rho = [], s_min = [], s_max = [], id = []
    )

    # Generate data
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test, 
        beta_test = gendata(par)

    # Save scenario parameters (id is used to seed the estimators that accept a seed)
    loss, n, p, m, _, _, _, _, id = par
    
    # Set network configuration. The width h is the positive root of
    # 2h^2 + (m + p + 3)h + p = 32 * m * p, rounded to an integer; the left-hand side is the
    # number of weights and biases of a dense network with m inputs, three hidden layers of
    # width h and p outputs (presumably a parameter budget of 32 * m * p - TODO confirm)
    n_neuron = round(Int, 1 / 4 * (sqrt((m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)) - m - p - 3))
    hidden_layers = repeat([n_neuron], 3)

    # Evaluate deep neural network: a column of ones takes the place of the explanatory
    # features and all features enter as contextual features; no coefficients are reported
    fit = ContextualLasso.classo(
        ones(n, 1), hcat(x_train, z_train), y_train, 
        ones(n, 1), hcat(x_valid, z_valid), y_valid, 
        intercept = false, lambda = Inf, standardise_x = false, verbose = false, 
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = missing
    y_hat = ContextualLasso.predict(fit, ones(n, 1), hcat(x_test, z_test))
    evaluate!(result, "Deep neural network", y_hat, beta_hat, y_test, beta_test, par)

    # Evaluate contextual linear model (same estimator as the contextual lasso below, but
    # fitted with lambda = Inf)
    fit = ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        lambda = Inf, verbose = false, intercept = false, 
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = ContextualLasso.coef(fit, z_test)
    y_hat = ContextualLasso.predict(fit, x_test, z_test)
    evaluate!(result, "Contextual linear model", y_hat, beta_hat, y_test, beta_test, par)

    # Evaluate lasso: penalty factor one on the explanatory features and zero on the contextual
    # features; the fitted coefficients of the explanatory features are repeated for every
    # test observation
    fit = lasso(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        penalty_factor = vcat(ones(p), zeros(m)), intercept = false, 
        loss = loss
    )
    beta_hat = repeat(fit[1:p]', n)
    y_hat = hcat(x_test, z_test) * fit
    evaluate!(result, "Lasso", y_hat, beta_hat, y_test, beta_test, par)

    # Evaluate lassonet: the first p returned values are repeated for every test observation
    beta_hat, y_hat = lassonet(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = id, 
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = repeat(beta_hat[1:p]', n)
    evaluate!(result, "Lassonet", y_hat, beta_hat, y_test, beta_test, par)

    # Evaluate LLSPIN: only the columns of the explanatory features are kept
    beta_hat, y_hat = llspin(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = id, 
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = beta_hat[:, 1:p]
    evaluate!(result, "LLSPIN", y_hat, beta_hat, y_test, beta_test, par)

    # Evaluate contextual lasso (relax = true, default choice of lambda and gamma)
    fit = ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        verbose = false, intercept = false, relax = true,
        loss = loss, hidden_layers = hidden_layers
    )
    beta_hat = ContextualLasso.coef(fit, z_test)
    y_hat = ContextualLasso.predict(fit, x_test, z_test)
    evaluate!(result, "Contextual lasso", y_hat, beta_hat, y_test, beta_test, par)

    # Evaluate contextual lasso (non-relaxed): reuse the fit above with gamma = 0 and, among
    # the lambdas whose validation loss is within one standard error of the minimum, the
    # smallest lambda
    index_min = argmin(fit.val_loss)
    index_1se = findall(fit.val_loss .<= fit.val_loss[index_min] + fit.val_loss_se[index_min])
    index_1se = first(sort(index_1se, by = i -> fit.lambda[i]))
    lambda_1se = fit.lambda[index_1se]
    gamma_1se = 0.0
    beta_hat = ContextualLasso.coef(fit, z_test, lambda = lambda_1se, gamma = gamma_1se)
    y_hat = ContextualLasso.predict(fit, x_test, z_test, lambda = lambda_1se, gamma = gamma_1se)
    evaluate!(result, "Contextual lasso (non-relaxed)", y_hat, beta_hat, y_test, beta_test, par)

    # Return the collected metrics explicitly rather than relying on the value of the last
    # evaluate! call
    return result

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters
simulations = let

    # Ten sample sizes evenly spaced on the log scale between 100 and 100000
    sample_sizes = round.(Int, exp.(range(log(100), log(100000), 10)))

    # Full factorial design for a given number of explanatory features and one or more numbers
    # of contextual features; one row per combination, with the simulation run ID varying
    # slowest
    design(n_explanatory, n_contextual) = DataFrames.DataFrame(
        (loss = loss, n = n, p = p, m = m, mu_var = mu_var, rho = rho, s_min = s_min,
            s_max = s_max, id = id) for
        loss = [Flux.logitbinarycrossentropy], # Loss function
        n = sample_sizes, # Number of samples
        p = n_explanatory, # Number of explanatory features
        m = n_contextual, # Number of contextual features
        mu_var = 5, # Variance of the linear predictor on the training set (signal strength)
        rho = 0.5, # Correlation coefficient
        s_min = 0.05, # Minimal sparsity level
        s_max = 0.15, # Maximal sparsity level
        id = 1:10 # Simulation run ID
    )

    # Low-dimensional scenarios followed by high-dimensional scenarios
    vcat(design(10, 2), design(50, [2, 5]))

end

# Run all simulations
# CUDA.jl is not reproducible with default rng
# NOTE(review): the next line replaces the zero-argument method of Random.default_rng so that
# it returns one seeded MersenneTwister shared by the whole session; this overrides a standard
# library method and is kept as is because changing it would change the generated data
rng = Random.MersenneTwister(2023); Random.default_rng() = rng # Random.seed!(2023)

# Create the output directory before the long run so that a missing folder cannot make the
# final write fail after every simulation has already finished
output_path = "Results/classification.csv"
mkpath(dirname(output_path))

# One data frame of metrics per scenario, stacked into a single table and written to disk
result = ProgressMeter.@showprogress map(runsim, eachrow(simulations))
result = reduce(vcat, result)
CSV.write(output_path, result)