# =================================================================================================#
# Description: Plots fitted feature effects at several ages for the Parkinson's telemonitoring data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/contextual_lasso.jl")

import Cairo, ColorSchemes, CSV, DataFrames, Dates, Fontconfig, Gadfly, Random, RCall, StatsBase

# Route all default-RNG draws through one seeded Mersenne Twister, since CUDA.jl is not
# reproducible with the default rng (this is used in place of Random.seed!(2023))
rng = Random.MersenneTwister(2023)
Random.default_rng() = rng

# Read the telemonitoring data; the "subject#" column is renamed so that it can be
# referred to with a plain symbol
data = DataFrames.rename!(
    CSV.read("Data/parkinsons_updrs.data", DataFrames.DataFrame), "subject#" => :subject
)

# Record the number of observations (rows)
n = DataFrames.nrow(data)

# Split the columns into explanatory features (every column not listed below),
# contextual features (age and sex, converted to Float64), and the response (motor_UPDRS)
non_explanatory = [:subject, :age, :sex, :test_time, :motor_UPDRS, :total_UPDRS]
x = DataFrames.select(data, DataFrames.Not(non_explanatory))
z = Float64.(DataFrames.select(data, :age, :sex))
y = DataFrames.select(data, :motor_UPDRS)

# Randomly partition the row indices 1:n into training (60%), validation (20%), and
# testing (the remainder) sets; sampling is without replacement and each sampled set is
# removed from the pool before the next draw, so the three sets are disjoint
train_id = StatsBase.sample(1:n, round(Int64, n * 0.6); replace = false)
id = setdiff(1:n, train_id)
valid_id = StatsBase.sample(id, round(Int64, n * 0.2); replace = false)
test_id = id = setdiff(id, valid_id)

# Materialise each split as dense arrays: feature tables become matrices and the
# single-column response table becomes a vector (no testing response is built)
x_train, x_valid, x_test = (Matrix(x[rows, :]) for rows in (train_id, valid_id, test_id))
z_train, z_valid, z_test = (Matrix(z[rows, :]) for rows in (train_id, valid_id, test_id))
y_train, y_valid = (vec(Matrix(y[rows, :])) for rows in (train_id, valid_id))

# Choose the hidden-layer width: n_neuron is the positive root h (rounded to the nearest
# integer) of 2h^2 + (m + p + 3)h + p = 32mp, where m and p are the numbers of contextual
# and explanatory features; the same width is used for three hidden layers
# NOTE(review): the left-hand side looks like the parameter count of a network with m
# inputs, three hidden layers of width h, and p outputs -- confirm against ContextualLasso
m = size(z_train, 2)
p = size(x_train, 2)
discriminant = (m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)
n_neuron = round(Int, 1 / 4 * (sqrt(discriminant) - m - p - 3))
hidden_layers = fill(n_neuron, 3)

# Fit contextual lasso on a spline expansion of the explanatory features

# Knots for each feature: the 0, 1/3, 2/3, and 1 sample quantiles of its training column
knots = RCall.rcopy(RCall.R"apply($x_train, 2, \(x) quantile(x, seq(0, 1, l = 4)))")

# Expand every feature with npreg::basis_tps at the training knots and bind the per-feature
# bases column-wise, so the columns belonging to feature j are contiguous; the validation
# and testing sets reuse the training knots
x_train_spline = RCall.rcopy(RCall.R"do.call(cbind, lapply(1:$p, \(j) npreg::basis_tps($x_train[, j], $knots[, j], rk = F)))")
x_valid_spline = RCall.rcopy(RCall.R"do.call(cbind, lapply(1:$p, \(j) npreg::basis_tps($x_valid[, j], $knots[, j], rk = F)))")
x_test_spline = RCall.rcopy(RCall.R"do.call(cbind, lapply(1:$p, \(j) npreg::basis_tps($x_test[, j], $knots[, j], rk = F)))")

# Label each spline column with the index of the feature it was generated from; the
# number of columns per feature is read off the basis rather than hard-coded, so it stays
# correct if the knot sequence above is changed
n_basis, leftover = divrem(size(x_train_spline, 2), p)
leftover == 0 || throw(DimensionMismatch(
    "spline basis has $(size(x_train_spline, 2)) columns, which is not a multiple of p = $p"
))
group = repeat(1:p, inner = n_basis)

# NOTE(review): `relax` and `group` are passed straight to ContextualLasso.classo; their
# exact semantics are defined in Estimators/contextual_lasso.jl -- confirm there
fit = ContextualLasso.classo(
    x_train_spline, z_train, y_train,
    x_valid_spline, z_valid, y_valid,
    verbose = false, relax = true,
    hidden_layers = hidden_layers, group = group
)

# Build the plotting data: the fitted spline contribution of feature j on the testing set,
# evaluated at three contexts that differ only in their first entry (70, 75, and 80)

# Index of the explanatory feature to plot
j = 15

# Number of spline columns per feature, read off the basis instead of being hard-coded
n_basis = size(x_test_spline, 2) ÷ p

# Columns of the spline basis that belong to feature j
j1 = (j - 1) * n_basis + 1:j * n_basis

# Matching coefficient positions, shifted by one
# NOTE(review): the shift presumably skips a leading intercept in the output of
# ContextualLasso.coef -- confirm against Estimators/contextual_lasso.jl
j2 = j1 .+ 1

# One data frame per context; z = (age, sex) follows the column order used to build the
# contextual features, with the second entry held at 1.0 throughout
panels = map((70, 75, 80)) do age
    DataFrames.DataFrame(
        x = x_test[:, j],
        fhat = x_test_spline[:, j1] * ContextualLasso.coef(fit, [Float64(age) 1.0])[1, j2],
        z = "$(age)-year-old"
    )
end

# Stack the per-context frames in the order 70, 75, 80
df = vcat(panels...)

# Draw the fitted function against the feature value as a line, with one panel per value
# of z, and write the figure to a 7in x 2.5in PDF
Gadfly.plot(
    df, xgroup = :z, x = :x, y = :fhat,
    Gadfly.Guide.xlabel("Detrended fluctuation analysis"),
    Gadfly.Guide.ylabel("f̂(x)"),
    Gadfly.Geom.subplot_grid(Gadfly.Geom.line()),
    Gadfly.Theme(default_color = "black", line_width = 4Gadfly.pt, plot_padding = [0Gadfly.mm]),
) |> 
Gadfly.PDF("Figures/parkinsons.pdf", 7Gadfly.inch, 2.5Gadfly.inch)