# =================================================================================================#
# Description: Produces the experimental results for the comparisons of estimators on the 
# Parkinson's telemonitoring data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/lasso.jl")
include("Estimators/glasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/contextual_lasso.jl")

import CSV, DataFrames, Flux, ProgressMeter, Random, RCall, Statistics, StatsBase

#==================================================================================================#
# Function to generate data
#==================================================================================================#

function gendata(par)

    # Unpack the split proportions; the third scenario entry (the ID) is not used in this function
    perc_train, perc_valid, _ = par

    # Read the raw data and rename the subject column
    data = CSV.read("Data/parkinsons_updrs.data", DataFrames.DataFrame)
    DataFrames.rename!(data, "subject#" => :subject)

    # Number of observations
    n = DataFrames.nrow(data)

    # Explanatory features are every column that is not an identifier, a contextual feature, the
    # test time, or one of the two UPDRS scores
    excluded = [:subject, :age, :sex, :test_time, :motor_UPDRS, :total_UPDRS]
    x = data[:, DataFrames.Not(excluded)]

    # Contextual features (converted to floating point) and response
    z = Float64.(data[:, [:age, :sex]])
    y = data[:, [:motor_UPDRS]]

    # Sizes of the training and validation sets; the testing set receives whatever is left over
    n_train = round(Int64, n * perc_train)
    n_valid = round(Int64, n * perc_valid)

    # Draw the three disjoint index sets without replacement (training first, then validation)
    train_id = StatsBase.sample(1:n, n_train, replace = false)
    remaining = setdiff(1:n, train_id)
    valid_id = StatsBase.sample(remaining, n_valid, replace = false)
    test_id = setdiff(remaining, valid_id)

    # Turn a set of row indices into an (x, z, y) triple with matrix features and vector response
    extract(rows) = (Matrix(x[rows, :]), Matrix(z[rows, :]), vec(Matrix(y[rows, :])))

    # Return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test
    (extract(train_id)..., extract(valid_id)..., extract(test_id)...)

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

function evaluate!(result, estimator, y_hat, beta_hat, y_test, y_train, par)

    # Unpack the scenario parameters that are stored next to the metrics
    perc_train, perc_valid, id = par

    # Relative test loss: mean squared error of the predictions divided by the mean squared error
    # of always predicting the training mean of the response
    model_loss = Flux.mse(y_hat, y_test)
    bench_loss = Flux.mse(Statistics.mean(y_train), y_test)
    test_loss = model_loss / bench_loss

    # Selection metric: number of coefficients larger than machine epsilon in magnitude, averaged
    # over the rows of beta_hat; the first column is skipped (presumably the intercept)
    test_sparsity = if ismissing(beta_hat)
        missing
    else
        n_active = count(b -> abs(b) > eps(), @view beta_hat[:, 2:end])
        n_active / size(beta_hat, 1)
    end

    # Append one row to the results table
    push!(result, [estimator, test_loss, test_sparsity, perc_train, perc_valid, id])

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# Fits every estimator on one random split of the data and returns a DataFrame with one row per
# estimator (columns: estimator, test_loss, test_sparsity, perc_train, perc_valid, id).
# `par` is one scenario: it is passed on to gendata and evaluate!, and par["id"] seeds lassonet.
function runsim(par)

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], test_loss = [], test_sparsity = [], perc_train = [], perc_valid = [], 
        id = []
    )

    # Generate data
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test = gendata(par)

    # Save data dimensions (m contextual features, p explanatory features)
    n_train, n_valid, n_test = length(y_train), length(y_valid), length(y_test)
    m, p = size(z_train, 2), size(x_train, 2)

    # Set network configuration
    # n_neuron is the positive root h of 2h^2 + (m + p + 3)h + p = 32mp, rounded to an integer
    # NOTE(review): presumably this makes a network with m inputs, three hidden layers of width h,
    # and p outputs have about 32mp parameters -- confirm
    n_neuron = round(Int, 1 / 4 * (sqrt((m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)) - m - p - 3))
    hidden_layers = repeat([n_neuron], 3)

    # Evaluate deep neural network
    # Fitted through classo with a constant column in place of the explanatory features, all
    # features supplied as contextual input, and lambda = Inf
    fit = ContextualLasso.classo(
        ones(n_train, 1), hcat(x_train, z_train), y_train, 
        ones(n_valid, 1), hcat(x_valid, z_valid), y_valid, 
        intercept = false, lambda = Inf, standardise_x = false, verbose = false, 
        hidden_layers = hidden_layers
    )
    y_hat = ContextualLasso.predict(fit, ones(n_test, 1), hcat(x_test, z_test))
    beta_hat = missing
    evaluate!(result, "Deep neural network", y_hat, beta_hat, y_test, y_train, par)

    # Evaluate lasso
    # The contextual features get a zero penalty factor
    fit = lasso(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        penalty_factor = vcat(ones(p), zeros(m))
    )
    # Same coefficients (first entry plus the p explanatory features) for every test observation
    beta_hat = repeat(fit[1:p + 1]', n_test)
    y_hat = hcat(ones(n_test), x_test, z_test) * fit
    evaluate!(result, "Lasso", y_hat, beta_hat, y_test, y_train, par)

    # Evaluate group lasso
    # The explanatory features and the first contextual feature are expanded into spline bases;
    # the second contextual feature is appended unexpanded as a group of its own
    xz_train_spline, xz_valid_spline, xz_test_spline = _runsim_spline_basis(
        hcat(x_train, z_train[:, 1]), hcat(x_valid, z_valid[:, 1]), hcat(x_test, z_test[:, 1])
    )
    xz_train_spline = hcat(xz_train_spline, z_train[:, 2])
    xz_valid_spline = hcat(xz_valid_spline, z_valid[:, 2])
    xz_test_spline = hcat(xz_test_spline, z_test[:, 2])
    # The grouping assumes each expanded feature contributes 5 basis columns
    group = vcat(repeat(1:p + m - 1, inner = 5), p + m)
    fit = glasso(
        xz_train_spline, y_train, 
        xz_valid_spline, y_valid, 
        group = group
    )
    # One summed coefficient per explanatory feature (sum over its 5 basis coefficients)
    beta_hat = repeat(hcat(fit[1], sum(reshape(fit[2:p * 5 + 1], 5, :), dims = 1)), n_test)
    y_hat = hcat(ones(n_test), xz_test_spline) * fit
    evaluate!(result, "Group lasso", y_hat, beta_hat, y_test, y_train, par)

    # Evaluate lassonet
    beta_hat, y_hat = lassonet(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = par["id"], 
        hidden_layers = hidden_layers
    )
    # Prepend a zero so the first column lines up with the column that evaluate! skips
    beta_hat = repeat(vcat(0, beta_hat[1:p])', n_test)
    evaluate!(result, "Lassonet", y_hat, beta_hat, y_test, y_train, par)

    # Evaluate contextual group lasso
    x_train_spline, x_valid_spline, x_test_spline = _runsim_spline_basis(x_train, x_valid, x_test)
    group = repeat(1:p, inner = 5)
    fit = ContextualLasso.classo(
        x_train_spline, z_train, y_train, 
        x_valid_spline, z_valid, y_valid, 
        verbose = false, relax = true, 
        hidden_layers = hidden_layers, group = group
    )
    y_hat = ContextualLasso.predict(fit, x_test_spline, z_test)
    beta_hat = ContextualLasso.coef(fit, z_test)
    # Collapse each row to its first entry plus one summed coefficient per explanatory feature
    beta_hat = hcat(beta_hat[:, 1], reduce(vcat, map(i -> sum(reshape(beta_hat[i, 2:p * 5 + 1], 5, :), dims = 1), 1:n_test)))
    evaluate!(result, "Contextual group lasso", y_hat, beta_hat, y_test, y_train, par)

    # Snapshot the default RNG; it is restored at the end of the function, so the two spline-based
    # estimators below leave its state as it was at this point
    state = copy(Random.default_rng())

    # Spline bases shared by the two estimators below
    # Recomputed here instead of reusing the matrices handed directly to classo/predict above, so
    # the result does not depend on whether those calls modify their inputs; below, the bases are
    # only used through hcat, which copies, so a single computation serves both estimators
    x_train_spline, x_valid_spline, x_test_spline = _runsim_spline_basis(x_train, x_valid, x_test)

    # Evaluate deep neural network (spline)
    fit = ContextualLasso.classo(
        ones(n_train, 1), hcat(x_train_spline, z_train), y_train, 
        ones(n_valid, 1), hcat(x_valid_spline, z_valid), y_valid, 
        intercept = false, lambda = Inf, standardise_x = false, verbose = false, 
        hidden_layers = hidden_layers
    )
    y_hat = ContextualLasso.predict(fit, ones(n_test, 1), hcat(x_test_spline, z_test))
    beta_hat = missing
    evaluate!(result, "Deep neural network (spline)", y_hat, beta_hat, y_test, y_train, par)

    # Evaluate lassonet (spline)
    # Groups are passed as lists of zero-based column indices: one group of 5 basis columns per
    # explanatory feature, followed by one single-column group per contextual feature
    group = vcat(repeat(1:p, inner = 5), 1 + p:m + p)
    group = [findall(x -> x == i, group) for i in 1:(p + m)]
    group = [g .- 1 for g in group]
    beta_hat, y_hat = lassonet(
        hcat(x_train_spline, z_train), y_train, 
        hcat(x_valid_spline, z_valid), y_valid, 
        hcat(x_test_spline, z_test), 
        verbose = false, seed = par["id"], 
        hidden_layers = hidden_layers, group = group
    )
    beta_hat = repeat(vcat(0, beta_hat[1:p * 5])', n_test)
    beta_hat = hcat(beta_hat[:, 1], reduce(vcat, map(i -> sum(reshape(beta_hat[i, 2:p * 5 + 1], 5, :), dims = 1), 1:n_test)))
    evaluate!(result, "Lassonet (spline)", y_hat, beta_hat, y_test, y_train, par)

    # Restore the RNG state saved before the spline-based estimators
    copy!(Random.default_rng(), state)

    result

end

# Helper for runsim: expands every column of the training, validation, and testing matrices into
# a spline basis using R's npreg::basis_tps and binds the per-column bases side by side. The knots
# are computed from the training matrix only and reused for all three matrices. Returns the three
# expanded matrices in the order (train, valid, test).
function _runsim_spline_basis(train, valid, test)

    # Number of columns to expand (all three matrices are expected to have the same columns)
    n_col = size(train, 2)

    # Knots for each column: quantiles of the training data at 4 equally spaced levels in [0, 1]
    knots = RCall.rcopy(RCall.R"apply($train, 2, \(x) quantile(x, seq(0, 1, l = 4)))")

    # Expand a single matrix column by column with the training knots
    expand(features) = RCall.rcopy(RCall.R"do.call(cbind, lapply(1:$n_col, \(j) npreg::basis_tps($features[, j], $knots[, j], rk = F)))")

    expand(train), expand(valid), expand(test)

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters
# The generator below iterates over every combination of the listed values and yields one named
# tuple (perc_train, perc_valid, id) per combination, so the table has one row per scenario; with
# these values that is 10 rows that differ only in id
simulations = DataFrames.DataFrame(
    (perc_train, perc_valid, id = id) for
    perc_train = 0.6, # Portion of samples to use for training set
    perc_valid = 0.2, # Portion of samples to use for validation set
    id = 1:10 # Scenario ID
)

# Run all simulations
# CUDA.jl is not reproducible with default rng
# The next line therefore creates a MersenneTwister seeded with 2023 and redefines
# Random.default_rng() to return it, so every later call to Random.default_rng() in this session
# yields that same generator
# NOTE(review): this overwrites a method owned by the Random standard library and reads a
# non-constant global; it affects all code loaded in the session, not only this script
rng = Random.MersenneTwister(2023); Random.default_rng() = rng # Random.seed!(2023)
# Run runsim on each scenario row with a progress bar; each call returns one DataFrame
result = ProgressMeter.@showprogress map(runsim, eachrow(simulations))
# Stack the per-scenario tables into a single table and write it to disk
result = reduce(vcat, result)
CSV.write("Results/parkinsons.csv", result)