# =================================================================================================#
# Description: Produces the experimental results for the timing experiments (Results/timings.csv)
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/lasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/llspin.jl")
include("Estimators/contextual_lasso.jl")

import BenchmarkTools, CSV, DataFrames, Distributions, Flux, LinearAlgebra, ProgressMeter, Random, 
    Statistics

#==================================================================================================#
# Function to generate data
#==================================================================================================#

function gendata(par)

    # Simulate train/valid/test splits for one scenario `par`, unpacked in the order
    # (loss, n, p, m, mu_var, rho, s_min, s_max, id); the run id is not used here.
    # Returns (x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test,
    # beta_test), where each x is n × p, each z is n × m and beta_test is n × p.
    loss, n, p, m, mu_var, rho, s_min, s_max, _ = par

    # The first `n_active` explanatory features carry signal; the remaining p - n_active have
    # zero coefficients, so p must be at least n_active
    n_active = 10
    p >= n_active || throw(ArgumentError(
        "p must be at least $n_active (the number of active features), got $p"
    ))

    # Generate explanatory features: N(0, Sigma) with correlation rho ^ |i - j|.
    # Splits are always drawn in train, valid, test order so the random stream is reproducible.
    Sigma = [rho ^ abs(i - j) for i in 1:p, j in 1:p]
    draw_x() = permutedims(rand(Distributions.MvNormal(zeros(p), Sigma), n))
    x_train, x_valid, x_test = draw_x(), draw_x(), draw_x()

    # Generate contextual features, uniform on [-1, 1]^m
    draw_z() = rand(Distributions.Uniform(- 1, 1), n, m)
    z_train, z_valid, z_test = draw_z(), draw_z(), draw_z()

    # Generate coefficients: active feature j is nonzero only inside a ball around centre
    # c[j, :] whose radius r[j] is the s[j]-quantile of training distances to that centre,
    # so roughly a fraction s[j] of training samples have a nonzero coefficient for feature j
    s = rand(Distributions.Uniform(s_min, s_max), n_active)
    c = rand(Distributions.Uniform(- 1, 1), n_active, m)
    r = map(1:n_active) do j
        dist = [LinearAlgebra.norm(z_train[i, :] - c[j, :], 2) for i in 1:n]
        Statistics.quantile(dist, s[j])
    end

    # Coefficient is 1 at the centre, decreases linearly to 0.5 at the ball boundary and is
    # 0 outside the ball
    function coef(zi, ci, ri)
        d = LinearAlgebra.norm(zi - ci, 2)
        return (1 - 0.5 * d / ri) * (d <= ri)
    end

    # Full n × p coefficient matrix: active columns first, then zero columns
    function beta_matrix(z)
        active = [coef(z[i, :], c[j, :], r[j]) for i in 1:size(z, 1), j in 1:n_active]
        return hcat(active, zeros(size(z, 1), p - n_active))
    end

    beta_train = beta_matrix(z_train)
    beta_valid = beta_matrix(z_valid)
    beta_test = beta_matrix(z_test)

    # Generate response: row-wise <x_i, beta_i>, rescaled so the training signal has
    # (population) variance mu_var; the same scale factor is applied to valid/test
    linpred(x, b) = vec(sum(x .* b, dims = 2))
    mu_train = linpred(x_train, beta_train)
    mu_valid = linpred(x_valid, beta_valid)
    mu_test = linpred(x_test, beta_test)

    kappa = sqrt(mu_var) / Statistics.std(mu_train, corrected = false)
    mu_train *= kappa
    mu_valid *= kappa
    mu_test *= kappa

    # Squared-error loss: Gaussian noise with unit variance. Otherwise: Bernoulli responses
    # with a logistic link, stored as floats.
    function draw_y(mu)
        if loss == Flux.mse
            return rand.(Distributions.Normal.(mu, 1))
        end
        return float.(rand.(Distributions.Bernoulli.(1 ./ (1 .+ exp.(- mu)))))
    end
    y_train, y_valid, y_test = draw_y(mu_train), draw_y(mu_valid), draw_y(mu_test)

    # Return generated data
    return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test,
        beta_test

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

function evaluate!(result, estimator, time, par)

    # Append a single record to `result`: the estimator label and its measured time,
    # followed by the nine scenario parameters in the order they are unpacked from `par`.
    # Returns whatever `push!` returns for `result`.
    (loss, n, p, m, mu_var, rho, s_min, s_max, id) = par
    row = [estimator, time, loss, n, p, m, mu_var, rho, s_min, s_max, id]
    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

function runsim(par)

    # Generate one dataset for scenario `par` (same field order as `gendata`) and time the
    # enabled estimators on it with `BenchmarkTools.@belapsed`. Returns a DataFrame with one
    # row per timed estimator: its label, elapsed time, and the scenario parameters.
    # Estimators are toggled by (un)commenting the blocks below.

    # Set aside space for results
    result = DataFrames.DataFrame(
        estimator = [], time = [], loss = [], n = [], p = [], m = [], mu_var = [], rho = [], 
        s_min = [], s_max = [], id = []
    )

    # Generate data (the test split and n, p, m below are only used by the commented-out
    # estimators)
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test, 
        beta_test = gendata(par)

    # Save scenario parameters
    loss, n, p, m, _, _, _, _, _ = par
    
    # Set network configuration
    # n_neuron = round(Int, 1 / 4 * (sqrt((m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)) - m - p - 3))
    # hidden_layers = repeat([n_neuron], 3)
    hidden_layers = [100, 100, 100]

    # # Evaluate deep neural network
    # time = BenchmarkTools.@belapsed ContextualLasso.classo(
    #     ones($n, 1), hcat($x_train, $z_train), $y_train, 
    #     ones($n, 1), hcat($x_valid, $z_valid), $y_valid, 
    #     intercept = false, lambda = Inf, standardise_x = false, verbose = false, 
    #     loss = $loss, hidden_layers = $hidden_layers
    # )
    # evaluate!(result, "Deep neural network", time, par)

    # # Evaluate contextual linear model
    # time = BenchmarkTools.@belapsed ContextualLasso.classo(
    #     $x_train, $z_train, $y_train, 
    #     $x_valid, $z_valid, $y_valid, 
    #     lambda = Inf, verbose = false, intercept = false, 
    #     loss = $loss, hidden_layers = $hidden_layers
    # )
    # evaluate!(result, "Contextual linear model", time, par)

    # # Evaluate lasso
    # time = BenchmarkTools.@belapsed lasso(
    #     hcat($x_train, $z_train), $y_train, 
    #     hcat($x_valid, $z_valid), $y_valid, 
    #     penalty_factor = vcat(ones($p), zeros($m)), intercept = false, 
    #     loss = $loss
    # )
    # evaluate!(result, "Lasso", time, par)

    # # Evaluate lassonet
    # time = BenchmarkTools.@belapsed lassonet(
    #     hcat($x_train, $z_train), $y_train, 
    #     hcat($x_valid, $z_valid), $y_valid, 
    #     hcat($x_test, $z_test), 
    #     verbose = false, seed = $par["id"], 
    #     loss = $loss, hidden_layers = $hidden_layers
    # )
    # evaluate!(result, "Lassonet", time, par)

    # # Evaluate LLSPIN
    # time = BenchmarkTools.@belapsed llspin(
    #     hcat($x_train, $z_train), $y_train, 
    #     hcat($x_valid, $z_valid), $y_valid, 
    #     hcat($x_test, $z_test), 
    #     verbose = false, seed = $par["id"], 
    #     loss = $loss, hidden_layers = $hidden_layers
    # )
    # evaluate!(result, "LLSPIN", time, par)

    # Evaluate contextual lasso
    time = BenchmarkTools.@belapsed ContextualLasso.classo(
        $x_train, $z_train, $y_train, 
        $x_valid, $z_valid, $y_valid, 
        verbose = false, intercept = false, relax = true,
        loss = $loss, hidden_layers = $hidden_layers
    )
    evaluate!(result, "Contextual lasso", time, par)

    # # Evaluate contextual lasso (non-relaxed)
    # time = BenchmarkTools.@belapsed ContextualLasso.classo(
    #     $x_train, $z_train, $y_train, 
    #     $x_valid, $z_valid, $y_valid, 
    #     verbose = false, intercept = false, relax = false,
    #     loss = $loss, hidden_layers = $hidden_layers
    # )
    # evaluate!(result, "Contextual lasso (non-relaxed)", time, par)

    # Return the results table explicitly: the caller stacks these with `reduce(vcat, ...)`,
    # so it must not depend on which estimator block happens to be the last statement
    return result

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters
# Two designs, each replicated for id = 1:10 with squared-error loss:
#   (i)  vary the sample size n over 1000:1000:10000 with p = 50 fixed;
#   (ii) vary the number of explanatory features p over 100:100:1000 with n = 1000 fixed.
# Each generator yields one NamedTuple per parameter combination, i.e. one row per run.
simulations = 
vcat(
    DataFrames.DataFrame(
        (loss = loss, n = n, p = p, m = m, mu_var = mu_var, rho = rho, s_min = s_min, s_max = s_max, 
            id = id) for
        loss = [Flux.mse],
        n = Int.(range(1000, 10000, 10)), # Number of samples
        p = 50, # Number of explanatory features
        m = 5, # Number of contextual features
        mu_var = 5, # Signal-to-noise ratio
        
        rho = 0.5, # Correlation coefficient
        s_min = 0.05, # Minimal sparsity level
        s_max = 0.15, # Maximal sparsity level
        id = 1:10 # Simulation run ID
    ),
    DataFrames.DataFrame(
        (loss = loss, n = n, p = p, m = m, mu_var = mu_var, rho = rho, s_min = s_min, s_max = s_max, 
            id = id) for
        loss = [Flux.mse],
        n = 1000, # Number of samples (held fixed while p varies)
        p = Int.(range(100, 1000, 10)), # Number of explanatory features
        m = 5, # Number of contextual features
        mu_var = 5, # Signal-to-noise ratio
        rho = 0.5, # Correlation coefficient
        s_min = 0.05, # Minimal sparsity level
        s_max = 0.15, # Maximal sparsity level
        id = 1:10 # Simulation run ID
    )
)

# Run all simulations
# CUDA.jl is not reproducible with default rng, so a seeded MersenneTwister is installed by
# redefining the zero-argument method `Random.default_rng()`; the trailing comment shows the
# `Random.seed!(2023)` alternative.
# NOTE(review): redefining a method owned by Random is type piracy and affects every caller of
# `Random.default_rng()` in this session; confirm it is still needed and behaves as intended
# on the Julia / CUDA.jl versions in use.
# `gendata` calls `rand` without an explicit rng, so (assuming those calls resolve to
# `Random.default_rng()`) all runs draw sequentially from this one stream, and the row order
# of `simulations` determines which data each run sees.
rng = Random.MersenneTwister(2023); Random.default_rng() = rng # Random.seed!(2023)
# One results DataFrame per simulation row, stacked into a single table and saved to CSV
result = ProgressMeter.@showprogress map(runsim, eachrow(simulations))
result = reduce(vcat, result)
CSV.write("Results/timings.csv", result)