# =================================================================================================#
# Description: Produces the experimental results for the comparisons of estimators on the news
# popularity data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/lasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/llspin.jl")
include("Estimators/contextual_lasso.jl")

import CSV, DataFrames, Flux, ProgressMeter, Random, Statistics, StatsBase

#==================================================================================================#
# Function to generate data
#==================================================================================================#

"""
    gendata(par)

Load the online news popularity data and randomly split it into training, validation, and
testing sets.

# Arguments
- `par`: scenario parameters, unpacked positionally as the training proportion, the validation
  proportion, and the scenario ID (the ID is not used by this function).

# Returns
The tuple `(x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test)`. Each
`x_*` is a matrix of explanatory features, each `z_*` is a matrix of contextual features (the
data channel indicators), and each `y_*` is the vector of dichotomised responses.
"""
function gendata(par)

    # Unpack the split proportions; the third entry (scenario ID) is not needed here
    perc_train, perc_valid, _ = par

    # Read the raw data and trim surrounding whitespace from the column names
    data = CSV.read("Data/OnlineNewsPopularity.csv", DataFrames.DataFrame)
    clean_names = strip.(names(data))
    DataFrames.rename!(data, clean_names)

    # Remove corrupted observations: keep only rows where each of these columns is at most one
    for col in (:n_unique_tokens, :n_non_stop_words, :n_non_stop_unique_tokens)
        data = filter(col => <=(1), data)
    end

    # Number of observations that survive the filtering
    n = DataFrames.nrow(data)

    # Dichotomise response: an article is popular when it has more than 1400 shares
    DataFrames.transform!(data, :shares => DataFrames.ByRow(>(1400)) => :popular)

    # Columns that act as contextual features
    context_cols = [
        :data_channel_is_lifestyle,
        :data_channel_is_entertainment,
        :data_channel_is_bus,
        :data_channel_is_socmed,
        :data_channel_is_tech,
        :data_channel_is_world,
    ]

    # Columns that are excluded from the explanatory features without being contextual features
    dropped_cols = [
        :url, # Not predictive
        :timedelta, # Not predictive
        :shares,
        :popular,
        :weekday_is_sunday,
    ]

    # Explanatory features are everything that is neither dropped nor contextual
    x = DataFrames.select(data, DataFrames.Not(vcat(dropped_cols, context_cols)))
    z = DataFrames.select(data, context_cols)
    y = DataFrames.select(data, :popular)

    # Draw the training rows first, then the validation rows from what is left; the testing set
    # is whatever remains after both draws
    n_train = round(Int64, n * perc_train)
    n_valid = round(Int64, n * perc_valid)
    train_id = StatsBase.sample(1:n, n_train, replace = false)
    remaining = setdiff(1:n, train_id)
    valid_id = StatsBase.sample(remaining, n_valid, replace = false)
    test_id = setdiff(remaining, valid_id)

    # Pull the rows of one split out of each table as plain arrays
    extract(rows) = (Matrix(x[rows, :]), Matrix(z[rows, :]), vec(Matrix(y[rows, :])))

    x_train, z_train, y_train = extract(train_id)
    x_valid, z_valid, y_valid = extract(valid_id)
    x_test, z_test, y_test = extract(test_id)

    return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

"""
    evaluate!(result, estimator, y_hat, beta_hat, y_test, y_train, par)

Compute the testing metrics of one estimator and append them to `result` as a new row.

# Arguments
- `result`: table that receives the new row.
- `estimator`: label stored in the row to identify the estimator.
- `y_hat`: predictions for the testing set, passed to `Flux.logitbinarycrossentropy`.
- `beta_hat`: coefficient matrix with one row per testing observation, or `missing` when the
  estimator has no coefficients to report.
- `y_test`: testing responses.
- `y_train`: training responses, used only to build the benchmark prediction.
- `par`: scenario parameters, unpacked positionally as `perc_train`, `perc_valid`, and `id`.

# Returns
The value returned by `push!`.
"""
function evaluate!(result, estimator, y_hat, beta_hat, y_test, y_train, par)

    # Unpack the scenario parameters so they can be recorded alongside the metrics
    perc_train, perc_valid, id = par

    # Benchmark prediction: the log-odds of the mean training response, repeated for every
    # testing observation
    prop_train = Statistics.mean(y_train)
    y_bench = fill(log(prop_train / (1 - prop_train)), length(y_test))

    # Testing loss of the estimator expressed relative to the loss of the benchmark
    model_loss = Flux.logitbinarycrossentropy(y_hat, y_test)
    bench_loss = Flux.logitbinarycrossentropy(y_bench, y_test)
    test_loss = model_loss / bench_loss

    # Number of coefficients larger than machine epsilon in magnitude, skipping the first column,
    # divided by the number of rows of the coefficient matrix
    test_sparsity = if ismissing(beta_hat)
        missing
    else
        count(abs.(beta_hat[:, 2:end]) .> eps()) / size(beta_hat, 1)
    end

    # Append the metrics and scenario parameters as a new row
    return push!(result, [estimator, test_loss, test_sparsity, perc_train, perc_valid, id])

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

"""
    runsim(par)

Run one scenario: generate a data split with `gendata(par)`, fit every estimator on it, and
record the testing metrics of each via `evaluate!`.

# Arguments
- `par`: scenario parameters; passed on to `gendata` and `evaluate!`, and indexed with `"id"` to
  seed the lassonet and LLSPIN fits.

# Returns
The value of the final `evaluate!` call, which appends to and returns the results table.
"""
function runsim(par)

    # Empty table that collects one row of metrics per estimator
    result = DataFrames.DataFrame(
        estimator = [],
        test_loss = [],
        test_sparsity = [],
        perc_train = [],
        perc_valid = [],
        id = [],
    )

    # Training, validation, and testing sets for this scenario
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test = gendata(par)

    # Sample sizes, number of explanatory features (p), and number of contextual features (m)
    n_train = length(y_train)
    n_valid = length(y_valid)
    n_test = length(y_test)
    p = size(x_train, 2)
    m = size(z_train, 2)

    # Three hidden layers of equal width, with the width given by the closed-form expression below
    # NOTE(review): the expression presumably equates parameter counts across architectures --
    # confirm against the accompanying paper
    disc = (m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)
    n_neuron = round(Int, 1 / 4 * (sqrt(disc) - m - p - 3))
    hidden_layers = fill(n_neuron, 3)

    # Deep neural network: the only explanatory feature is a column of ones, and all features are
    # supplied as contextual features; no coefficients are reported
    dnn = ContextualLasso.classo(
        ones(n_train, 1), hcat(x_train, z_train), y_train,
        ones(n_valid, 1), hcat(x_valid, z_valid), y_valid,
        intercept = false, lambda = Inf, standardise_x = false, verbose = false,
        loss = Flux.logitbinarycrossentropy, hidden_layers = hidden_layers
    )
    dnn_pred = ContextualLasso.predict(dnn, ones(n_test, 1), hcat(x_test, z_test))
    evaluate!(result, "Deep neural network", dnn_pred, missing, y_test, y_train, par)

    # Contextual linear model
    clm = ContextualLasso.classo(
        x_train, z_train, y_train,
        x_valid, z_valid, y_valid,
        lambda = Inf, verbose = false,
        loss = Flux.logitbinarycrossentropy, hidden_layers = hidden_layers
    )
    clm_pred = ContextualLasso.predict(clm, x_test, z_test)
    clm_coef = ContextualLasso.coef(clm, z_test)
    evaluate!(result, "Contextual linear model", clm_pred, clm_coef, y_test, y_train, par)

    # Lasso on the explanatory and contextual features jointly; the penalty factor is one for
    # the explanatory features and zero for the contextual features
    lasso_fit = lasso(
        hcat(x_train, z_train), y_train,
        hcat(x_valid, z_valid), y_valid,
        penalty_factor = vcat(ones(p), zeros(m)),
        loss = Flux.logitbinarycrossentropy
    )
    lasso_pred = hcat(ones(n_test), x_test, z_test) * lasso_fit
    # Repeat the first p + 1 coefficients as one identical row per testing observation
    lasso_coef = repeat(lasso_fit[1:p + 1]', n_test)
    evaluate!(result, "Lasso", lasso_pred, lasso_coef, y_test, y_train, par)

    # Lassonet
    lassonet_coef, lassonet_pred = lassonet(
        hcat(x_train, z_train), y_train,
        hcat(x_valid, z_valid), y_valid,
        hcat(x_test, z_test),
        verbose = false, seed = par["id"],
        loss = Flux.logitbinarycrossentropy, hidden_layers = hidden_layers
    )
    # Prepend a zero to the first p coefficients and repeat them for every testing observation
    lassonet_coef = repeat(vcat(0, lassonet_coef[1:p])', n_test)
    evaluate!(result, "Lassonet", lassonet_pred, lassonet_coef, y_test, y_train, par)

    # LLSPIN
    llspin_coef, llspin_pred = llspin(
        hcat(x_train, z_train), y_train,
        hcat(x_valid, z_valid), y_valid,
        hcat(x_test, z_test),
        verbose = false, seed = par["id"],
        loss = Flux.logitbinarycrossentropy, hidden_layers = hidden_layers
    )
    # Prepend a column of zeros to the first p coefficient columns
    llspin_coef = hcat(zeros(n_test), llspin_coef[:, 1:p])
    evaluate!(result, "LLSPIN", llspin_pred, llspin_coef, y_test, y_train, par)

    # Contextual lasso
    cl = ContextualLasso.classo(
        x_train, z_train, y_train,
        x_valid, z_valid, y_valid,
        verbose = false, relax = true,
        loss = Flux.logitbinarycrossentropy, hidden_layers = hidden_layers
    )
    cl_pred = ContextualLasso.predict(cl, x_test, z_test)
    cl_coef = ContextualLasso.coef(cl, z_test)
    return evaluate!(result, "Contextual lasso", cl_pred, cl_coef, y_test, y_train, par)

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters; the generator yields one scenario (row) per combination of the
# values listed below
simulations = DataFrames.DataFrame(
    (perc_train, perc_valid, id = id) for
    perc_train = 0.6, # Portion of samples to use for training set
    perc_valid = 0.2, # Portion of samples to use for validation set
    id = 1:10 # Scenario ID
)

# Create the output directory before any model is fitted, so that a missing folder cannot cause
# the results of a completed run to be lost when they are written to disk
mkpath("Results")

# Run all simulations
# CUDA.jl is not reproducible with default rng, so Random.default_rng() is redefined to return a
# seeded generator instead of calling Random.seed!(2023)
rng = Random.MersenneTwister(2023)
Random.default_rng() = rng

# Each scenario returns its own table of results; stack them into one table and save it
result = ProgressMeter.@showprogress map(runsim, eachrow(simulations))
result = reduce(vcat, result)
CSV.write("Results/news.csv", result)