# =================================================================================================#
# Description: Produces the experimental results for the comparisons of estimators on the energy
# consumption data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/lasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/llspin.jl")
include("Estimators/contextual_lasso.jl")

import CSV, DataFrames, Dates, Flux, ProgressMeter, Random, Statistics, StatsBase

#==================================================================================================#
# Function to generate data
#==================================================================================================#

# Load the energy consumption data, build the explanatory features, contextual features, and
# response, and randomly split the observations into training, validation, and testing sets.
#
# Arguments:
#   par - scenario parameters that destructure into (perc_train, perc_valid, id), where
#         perc_train and perc_valid are the portions of observations to use for the training
#         and validation sets; the scenario ID is not needed to generate the data
#
# Returns the tuple
#   (x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test)
# where the x and z entries are matrices and the y entries are vectors. Observations that are
# in neither the training set nor the validation set form the testing set.
#
# Throws an ArgumentError if the requested training and validation set sizes are negative or
# together exceed the number of observations.
function gendata(par)

    # Save scenario parameters (the scenario ID is not used here)
    perc_train, perc_valid, _ = par

    # Load data
    data = CSV.read("Data/energydata_complete.csv", DataFrames.DataFrame)
    DataFrames.transform!(data, :date => DataFrames.ByRow(x -> Dates.DateTime(x, 
        Dates.dateformat"yyyy-mm-dd HH:MM:SS")) => :date)

    # Save data dimension
    n = size(data, 1)

    # Create functions to extract the elapsed fraction of the year, week, and day from a time
    # stamp (differences are in milliseconds and divided by the approximate period length)
    yearly(x) = (x - floor(x, Dates.Year)).value / 3.154e+10
    weekly(x) = (x - floor(x, Dates.Week)).value / 6.048e+8
    daily(x) = (x - floor(x, Dates.Day)).value / 8.64e+7

    # Weekend indicator taken from the calendar day. The previous test weekly(x) > 5 / 7 is
    # strict, and weekly(x) equals exactly 5 / 7 at midnight on Saturday, so that time stamp
    # was labelled as a weekday.
    weekend(x) = float(Dates.dayofweek(x) >= Dates.Saturday)

    # Extract explanatory features, contextual features, and response
    x = DataFrames.select(data, DataFrames.Not([:date, :Appliances, :rv1, :rv2]))
    z = DataFrames.select(
        data, 
        :date => DataFrames.ByRow(weekend) => :weekend,
        :date => DataFrames.ByRow(x -> cospi(2 * yearly(x))) => :monthcos,
        :date => DataFrames.ByRow(x -> sinpi(2 * yearly(x))) => :monthsin,
        :date => DataFrames.ByRow(x -> cospi(2 * weekly(x))) => :daycos,
        :date => DataFrames.ByRow(x -> sinpi(2 * weekly(x))) => :daysin,
        :date => DataFrames.ByRow(x -> cospi(2 * daily(x))) => :hourcos,
        :date => DataFrames.ByRow(x -> sinpi(2 * daily(x))) => :hoursin
        )
    y = DataFrames.select(data, :Appliances => DataFrames.ByRow(log) => :Appliances)

    # Compute set sizes and check that they can be drawn without replacement
    n_train = round(Int64, n * perc_train)
    n_valid = round(Int64, n * perc_valid)
    if n_train < 0 || n_valid < 0 || n_train + n_valid > n
        throw(ArgumentError(
            "perc_train and perc_valid must be non-negative and sum to at most one, got " *
            "perc_train = $perc_train and perc_valid = $perc_valid"
        ))
    end

    # Generate indices of training, validation, and testing sets; the validation set is drawn
    # from what remains after the training set, and the testing set is whatever is left
    train_id = StatsBase.sample(1:n, n_train, replace = false)
    remaining = setdiff(1:n, train_id)
    valid_id = StatsBase.sample(remaining, n_valid, replace = false)
    test_id = setdiff(remaining, valid_id)

    # Generate training, validation, and testing sets
    x_train = Matrix(x[train_id, :])
    z_train = Matrix(z[train_id, :])
    y_train = Matrix(y[train_id, :])[:]
    x_valid = Matrix(x[valid_id, :])
    z_valid = Matrix(z[valid_id, :])
    y_valid = Matrix(y[valid_id, :])[:]
    x_test = Matrix(x[test_id, :])
    z_test = Matrix(z[test_id, :])
    y_test = Matrix(y[test_id, :])[:]

    # Return generated data
    return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

# Compute the test metrics of one estimator and append them to `result` as a new row.
#
# Arguments:
#   result    - table of results that the new row is pushed onto
#   estimator - label of the estimator being evaluated
#   y_hat     - predictions on the testing set
#   beta_hat  - estimated coefficients with one row per testing observation and the first
#               column skipped when counting nonzeros, or `missing` if the estimator does not
#               produce coefficients
#   y_test    - response on the testing set
#   y_train   - response on the training set (used for the benchmark prediction)
#   par       - scenario parameters that destructure into (perc_train, perc_valid, id)
#
# Returns the value of the `push!` call on `result`.
function evaluate!(result, estimator, y_hat, beta_hat, y_test, y_train, par)

    # Unpack scenario parameters so they can be recorded alongside the metrics
    perc_train, perc_valid, id = par

    # Test loss relative to that of predicting the training mean for every observation
    benchmark_loss = Flux.mse(Statistics.mean(y_train), y_test)
    test_loss = Flux.mse(y_hat, y_test) / benchmark_loss

    # Number of coefficients larger than machine epsilon in magnitude (first column excluded),
    # divided by the number of rows of beta_hat
    test_sparsity = if ismissing(beta_hat)
        missing
    else
        n_active = count(abs.(beta_hat[:, 2:end]) .> eps())
        n_active / size(beta_hat, 1)
    end

    # Append the new row of results
    return push!(result, [estimator, test_loss, test_sparsity, perc_train, perc_valid, id])

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

# Run every estimator on one scenario and collect the test metrics.
#
# Arguments:
#   par - scenario parameters; passed on to gendata and evaluate!, and indexed with "id" to
#         seed lassonet and llspin
#
# Returns a table with one row per estimator and columns estimator, test_loss, test_sparsity,
# perc_train, perc_valid, and id.
#
# The estimators are fitted in a fixed order: deep neural network, contextual linear model,
# lasso, lassonet, LLSPIN, and contextual lasso.
function runsim(par)

    # Empty table of results; evaluate! pushes one row onto it per estimator
    result = DataFrames.DataFrame(
        estimator = [],
        test_loss = [],
        test_sparsity = [],
        perc_train = [],
        perc_valid = [],
        id = []
    )

    # Generate data
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test = gendata(par)

    # Save data dimensions (m contextual features and p explanatory features)
    n_train, n_valid, n_test = length.((y_train, y_valid, y_test))
    m = size(z_train, 2)
    p = size(x_train, 2)

    # Set network configuration: three hidden layers of equal width
    # NOTE(review): the width formula looks like the root of a quadratic in the layer width,
    # presumably matching a target number of network parameters -- confirm against the paper
    discriminant = (m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)
    n_neuron = round(Int, 1 / 4 * (sqrt(discriminant) - m - p - 3))
    hidden_layers = fill(n_neuron, 3)

    # Evaluate deep neural network: a column of ones takes the place of the explanatory
    # features and all features are supplied as contextual features; no coefficients recorded
    dnn = ContextualLasso.classo(
        ones(n_train, 1), hcat(x_train, z_train), y_train, 
        ones(n_valid, 1), hcat(x_valid, z_valid), y_valid, 
        intercept = false, lambda = Inf, standardise_x = false, verbose = false, 
        hidden_layers = hidden_layers
    )
    dnn_pred = ContextualLasso.predict(dnn, ones(n_test, 1), hcat(x_test, z_test))
    evaluate!(result, "Deep neural network", dnn_pred, missing, y_test, y_train, par)

    # Evaluate contextual linear model
    clm = ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        lambda = Inf, verbose = false, 
        hidden_layers = hidden_layers
    )
    clm_pred = ContextualLasso.predict(clm, x_test, z_test)
    clm_coef = ContextualLasso.coef(clm, z_test)
    evaluate!(result, "Contextual linear model", clm_pred, clm_coef, y_test, y_train, par)

    # Evaluate lasso: the penalty factor is one for explanatory features and zero for
    # contextual features; the first p + 1 fitted coefficients are copied to every test row
    lasso_coef = lasso(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        penalty_factor = vcat(ones(p), zeros(m))
    )
    lasso_beta = repeat(lasso_coef[1:p + 1]', n_test)
    lasso_pred = hcat(ones(n_test), x_test, z_test) * lasso_coef
    evaluate!(result, "Lasso", lasso_pred, lasso_beta, y_test, y_train, par)

    # Evaluate lassonet: a leading zero is prepended to the first p coefficients because
    # evaluate! skips the first column, and the row is copied to every test observation
    lassonet_coef, lassonet_pred = lassonet(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = par["id"], 
        hidden_layers = hidden_layers
    )
    lassonet_beta = repeat(vcat(0, lassonet_coef[1:p])', n_test)
    evaluate!(result, "Lassonet", lassonet_pred, lassonet_beta, y_test, y_train, par)

    # Evaluate LLSPIN: a leading column of zeros is prepended to the first p columns of the
    # returned coefficients because evaluate! skips the first column
    llspin_coef, llspin_pred = llspin(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = par["id"], 
        hidden_layers = hidden_layers
    )
    llspin_beta = hcat(zeros(n_test), llspin_coef[:, 1:p])
    evaluate!(result, "LLSPIN", llspin_pred, llspin_beta, y_test, y_train, par)

    # Evaluate contextual lasso
    classo_fit = ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        verbose = false, relax = true, 
        hidden_layers = hidden_layers
    )
    classo_pred = ContextualLasso.predict(classo_fit, x_test, z_test)
    classo_coef = ContextualLasso.coef(classo_fit, z_test)
    evaluate!(result, "Contextual lasso", classo_pred, classo_coef, y_test, y_train, par)

    return result

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters: one row per combination of the values listed below
simulations = DataFrames.DataFrame(
    (perc_train = perc_train, perc_valid = perc_valid, id = id) for
    perc_train in 0.6, # Portion of samples to use for training set
    perc_valid in 0.2, # Portion of samples to use for validation set
    id in 1:10 # Scenario ID
)

# Fix the generator returned by Random.default_rng() to a seeded Mersenne Twister
# CUDA.jl is not reproducible with default rng (alternative left unused: Random.seed!(2023))
# NOTE(review): this redefines Random.default_rng() for the whole session
rng = Random.MersenneTwister(2023)
Random.default_rng() = rng

# Run all simulations with a progress bar, stack the per-scenario tables, and save them
result = reduce(vcat, ProgressMeter.@showprogress map(runsim, eachrow(simulations)))
CSV.write("Results/energy.csv", result)