# =================================================================================================#
# Description: Produces the experimental results for the comparisons of estimators on the house
# pricing data
# Author: Ryan Thompson
# =================================================================================================#

include("Estimators/lasso.jl")
include("Estimators/lassonet.jl")
include("Estimators/llspin.jl")
include("Estimators/contextual_lasso.jl")

import CSV, DataFrames, Flux, ProgressMeter, Random, Statistics, StatsBase

#==================================================================================================#
# Function to generate data
#==================================================================================================#

"""
    gendata(par)

Load the house pricing data and split it into mutually disjoint training, validation, and
testing sets.

# Arguments
- `par`: scenario parameters that destructure as `(n_train, n_valid, n_test, id)`. The first
  three are the numbers of observations to draw for each set; `id` is not used here.

# Returns
The tuple `(x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test)`,
where `x_*` are matrices of explanatory features, `z_*` are matrices of contextual features
(longitude and latitude), and `y_*` are vectors of `totalPrice / square`.

# Throws
- `ArgumentError` if the cleaned data has fewer than `n_train + n_valid + n_test` rows.
"""
function gendata(par)

    # Save scenario parameters (the scenario ID plays no role in building the split)
    n_train, n_valid, n_test, _ = par

    # Load data
    data = CSV.read("Data/new.csv", DataFrames.DataFrame)

    # Remove corrupted observations: keep rows whose bathRoom entry is one of "0", ..., "7"
    valid_bathrooms = string.(0:7)
    data = filter(:bathRoom => x -> in(x, valid_bathrooms), data)
    data.bathRoom = parse.(Float64, data.bathRoom)
    data.livingRoom = parse.(Float64, data.livingRoom)

    # Save data dimension
    n = size(data, 1)

    # Fail early with a clear message if disjoint sets of the requested sizes cannot be drawn
    n_train + n_valid + n_test <= n || throw(ArgumentError(
        "requested $(n_train + n_valid + n_test) observations but only $n are available"
    ))

    # Extract explanatory features, contextual features, and response
    x = DataFrames.select(
        data, 
        :elevator, 
        # The floor number is parsed from the last two characters of the floor string
        :floor => DataFrames.ByRow(x -> parse(Float64, x[end - 1:end])), 
        :renovationCondition => DataFrames.ByRow(x -> max(x - 2.0, 0.0)) => :renovationCondition,
        # :buildingType => DataFrames.ByRow(x -> float(x == 1)) => :buildingType, 
        :livingRoom,
        :bathRoom
    )
    z = DataFrames.select(
        data, 
        :Lng => :Longitude, 
        :Lat => :Latitude
        )
    y = DataFrames.select(data, [:totalPrice, :square] => DataFrames.ByRow(/) => :price)

    # Remove missing type from DataFrame
    DataFrames.disallowmissing!(x)

    # Generate indices of training, validation, and testing sets. Each set is sampled from the
    # indices not yet assigned, so the three sets never share an observation.
    all_id = 1:n
    train_id = StatsBase.sample(all_id, n_train, replace = false)
    nontrain_id = setdiff(all_id, train_id)
    valid_id = StatsBase.sample(nontrain_id, n_valid, replace = false)
    remaining_id = setdiff(nontrain_id, valid_id)
    test_id = StatsBase.sample(remaining_id, n_test, replace = false)

    # Generate training, validation, and testing sets
    x_train = Matrix(x[train_id, :])
    z_train = Matrix(z[train_id, :])
    y_train = Matrix(y[train_id, :])[:]
    x_valid = Matrix(x[valid_id, :])
    z_valid = Matrix(z[valid_id, :])
    y_valid = Matrix(y[valid_id, :])[:]
    x_test = Matrix(x[test_id, :])
    z_test = Matrix(z[test_id, :])
    y_test = Matrix(y[test_id, :])[:]

    # Return generated data
    return x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test

end

#==================================================================================================#
# Function to evaluate a model
#==================================================================================================#

"""
    evaluate!(result, estimator, y_hat, beta_hat, y_test, y_train, par)

Compute the test metrics of one estimator and append them as a new row of `result`.

# Arguments
- `result`: table that receives the row `[estimator, test_loss, test_sparsity, n_train,
  n_valid, n_test, id]`.
- `estimator`: label stored in the first entry of the row.
- `y_hat`, `y_test`: predictions and responses on the testing set.
- `beta_hat`: matrix of coefficients, or `missing` when the estimator has none. Its first
  column is excluded from the sparsity count (presumably the intercept; confirm against the
  estimators).
- `y_train`: training responses; their mean is the benchmark prediction.
- `par`: scenario parameters that destructure as `(n_train, n_valid, n_test, id)`.

# Returns
The value returned by `push!(result, ...)`.
"""
function evaluate!(result, estimator, y_hat, beta_hat, y_test, y_train, par)

    # Unpack the scenario parameters that are recorded alongside the metrics
    n_train, n_valid, n_test, id = par

    # Relative test loss: model MSE divided by the MSE of predicting the training mean
    benchmark = Statistics.mean(y_train)
    model_mse = Flux.mse(y_hat, y_test)
    benchmark_mse = Flux.mse(benchmark, y_test)
    test_loss = model_mse / benchmark_mse

    # Number of coefficients (first column excluded) exceeding eps() in magnitude, divided by
    # the number of rows of beta_hat
    test_sparsity = if ismissing(beta_hat)
        missing
    else
        is_active = abs.(beta_hat[:, 2:end]) .> eps()
        count(is_active) / size(beta_hat, 1)
    end

    # Append the new row of results
    row = [estimator, test_loss, test_sparsity, n_train, n_valid, n_test, id]
    return push!(result, row)

end

#==================================================================================================#
# Function to run a given simulation design
#==================================================================================================#

"""
    runsim(par)

Fit every estimator on one random split of the house pricing data and return the table of
test metrics, one row per estimator.

`par` holds the scenario parameters. It is forwarded to `gendata` and `evaluate!`, and its
`"id"` entry seeds `lassonet` and `llspin`. The estimators are fitted in a fixed order.
"""
function runsim(par)

    # Table that collects one row of metrics per estimator
    result = DataFrames.DataFrame(
        estimator = [],
        test_loss = [],
        test_sparsity = [],
        n_train = [],
        n_valid = [],
        n_test = [],
        id = []
    )

    # Draw the training, validation, and testing sets for this scenario
    x_train, z_train, y_train, x_valid, z_valid, y_valid, x_test, z_test, y_test = gendata(par)

    # Record the sample sizes and the numbers of explanatory (p) and contextual (m) features
    n_train = length(y_train)
    n_valid = length(y_valid)
    n_test = length(y_test)
    p = size(x_train, 2)
    m = size(z_train, 2)

    # Network configuration: three hidden layers of equal width
    discriminant = (m + p + 3) ^ 2 - 8 * p + 8 * (m * p * 32)
    n_neuron = round(Int, 1 / 4 * (sqrt(discriminant) - m - p - 3))
    layers = fill(n_neuron, 3)

    # Seed shared by the estimators that accept one
    seed = par["id"]

    # Deep neural network: a constant explanatory feature with all features as context
    dnn = ContextualLasso.classo(
        ones(n_train, 1), hcat(x_train, z_train), y_train, 
        ones(n_valid, 1), hcat(x_valid, z_valid), y_valid, 
        intercept = false, lambda = Inf, standardise_x = false, verbose = false, 
        hidden_layers = layers
    )
    dnn_pred = ContextualLasso.predict(dnn, ones(n_test, 1), hcat(x_test, z_test))
    evaluate!(result, "Deep neural network", dnn_pred, missing, y_test, y_train, par)

    # Contextual linear model
    clm = ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        lambda = Inf, verbose = false, 
        hidden_layers = layers
    )
    clm_pred = ContextualLasso.predict(clm, x_test, z_test)
    clm_coef = ContextualLasso.coef(clm, z_test)
    evaluate!(result, "Contextual linear model", clm_pred, clm_coef, y_test, y_train, par)

    # Lasso: only the explanatory features carry a nonzero penalty factor
    lasso_coef = lasso(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        penalty_factor = vcat(ones(p), zeros(m)), type = "min"
    )
    lasso_beta = repeat(lasso_coef[1:p + 1]', n_test)
    lasso_pred = hcat(ones(n_test), x_test, z_test) * lasso_coef
    evaluate!(result, "Lasso", lasso_pred, lasso_beta, y_test, y_train, par)

    # Lassonet: a zero is prepended so the coefficient layout matches the other estimators
    lassonet_coef, lassonet_pred = lassonet(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = seed, 
        hidden_layers = layers, type = "min"
    )
    lassonet_beta = repeat(vcat(0, lassonet_coef[1:p])', n_test)
    evaluate!(result, "Lassonet", lassonet_pred, lassonet_beta, y_test, y_train, par)

    # LLSPIN: a zero column is prepended to the first p coefficient columns
    llspin_coef, llspin_pred = llspin(
        hcat(x_train, z_train), y_train, 
        hcat(x_valid, z_valid), y_valid, 
        hcat(x_test, z_test), 
        verbose = false, seed = seed, 
        hidden_layers = layers
    )
    llspin_beta = hcat(zeros(n_test), llspin_coef[:, 1:p])
    evaluate!(result, "LLSPIN", llspin_pred, llspin_beta, y_test, y_train, par)

    # Contextual lasso
    classo_fit = ContextualLasso.classo(
        x_train, z_train, y_train, 
        x_valid, z_valid, y_valid, 
        verbose = false, relax = true, 
        hidden_layers = layers
    )
    classo_pred = ContextualLasso.predict(
        classo_fit, x_test, z_test, lambda = "lambda_min", gamma = "gamma_min"
    )
    classo_coef = ContextualLasso.coef(
        classo_fit, z_test, lambda = "lambda_min", gamma = "gamma_min"
    )

    # The final evaluate! call returns the value of push!(result, ...), which is passed on
    return evaluate!(result, "Contextual lasso", classo_pred, classo_coef, y_test, y_train, par)

end

#==================================================================================================#
# Run simulations
#==================================================================================================#

# Specify simulation parameters: one row (n_train, n_valid, n_test, id) for every scenario ID
simulations = DataFrames.DataFrame(
    (n_train, n_valid, n_test, id = id) for
    n_train = 15000, # Number of samples to use for training set
    n_valid = 15000, # Number of samples to use for validation set
    n_test = 15000, # Number of samples to use for testing set
    id = 1:10 # Scenario ID
)

# Run all simulations
# CUDA.jl is not reproducible with default rng, so Random.default_rng() is redefined to return a
# single MersenneTwister seeded with 2023
rng = Random.MersenneTwister(2023); Random.default_rng() = rng # Random.seed!(2023)
# Apply runsim to every scenario row (with a progress bar), stack the per-scenario results, and
# write the combined table to disk
result = ProgressMeter.@showprogress map(runsim, eachrow(simulations))
result = reduce(vcat, result)
CSV.write("Results/house.csv", result)