# =================================================================================================#
# Description: Julia wrapper around the Python lassonet package (called via PyCall)
# Author: Ryan Thompson
# =================================================================================================#

# ENV["PYTHON"]="/home/ryan/.julia/conda/3/x86_64/bin"
import Flux, PyCall, Statistics

"""
    lassonet(x, y, x_val, y_val, x_test; patience = 30, hidden_layers = [16, 16, 16],
        verbose = true, standardise_x = true, standardise_y = true, type = "1se",
        loss = Flux.mse, seed = 0, group = nothing)

Fit a LassoNet regularisation path with the Python `lassonet` package (via PyCall),
select one model on the validation set, and return its selected features together with
predictions at `x_test`.

# Arguments
- `x`, `y`: training features (observations × features) and response.
- `x_val`, `y_val`: validation data, passed to the path fit and used for model selection.
- `x_test`: features at which predictions are returned.

# Keywords
- `patience`, `hidden_layers`, `seed`, `group`: forwarded to the lassonet estimator as
  `patience`, `hidden_dims` (as a tuple), `torch_seed` and `groups`.
- `verbose`: forwarded to the estimator; when `false`, stdout is also redirected to
  `devnull` while the path is fit.
- `standardise_x`: centre and scale every feature column by the training mean and
  (uncorrected) standard deviation; constant columns are only centred.
- `standardise_y`: centre and scale the response (applied for `Flux.mse` only).
- `type`: `"min"` picks the path item with the smallest validation loss; `"1se"` picks
  the largest λ whose validation loss is within one standard error of that minimum.
- `loss`: `Flux.mse` (regression via `LassoNetRegressor`) or `Flux.binarycrossentropy` /
  `Flux.logitbinarycrossentropy` (classification via `LassoNetClassifier`).

# Returns
`(beta_hat, y_hat)`: the chosen path item's `selected` attribute converted with
`.numpy()` (presumably the feature-selection mask — confirm against the lassonet docs),
and test predictions: on the original response scale for regression, as logits for
classification.

# Throws
`ArgumentError` if `loss` or `type` is not one of the supported values.
"""
function lassonet(x, y, x_val, y_val, x_test; patience = 30, hidden_layers = [16, 16, 16],
    verbose = true, standardise_x = true, standardise_y = true, type = "1se", loss = Flux.mse,
    seed = 0, group = nothing)

    hidden_layers = tuple(hidden_layers...)

    # Lassonet outputs probabilities, not logits, so score it with the probability-based loss
    lossfn = loss == Flux.logitbinarycrossentropy ? Flux.binarycrossentropy : loss
    is_regression = lossfn == Flux.mse
    if !is_regression && lossfn != Flux.binarycrossentropy
        throw(ArgumentError("unsupported loss $loss; expected Flux.mse, " *
            "Flux.binarycrossentropy or Flux.logitbinarycrossentropy"))
    end
    if type != "min" && type != "1se"
        throw(ArgumentError("type must be \"min\" or \"1se\", got \"$type\""))
    end

    # Standardise features with training statistics (only when requested)
    if standardise_x
        x_mean = Statistics.mean(x, dims = 1)
        x_sd = Statistics.std(x, dims = 1, corrected = false)
        x_sd[x_sd .== 0] .= 1 # Handles constant columns
        x = (x .- x_mean) ./ x_sd
        x_val = (x_val .- x_mean) ./ x_sd
        x_test = (x_test .- x_mean) ./ x_sd
    end

    # Standardise response (regression only); identity transform otherwise
    if standardise_y && is_regression
        y_mean = Statistics.mean(y)
        y_sd = Statistics.std(y, corrected = false)
        if y_sd == 0
            y_sd = one(y_sd) # Handles a constant response
        end
    else
        y_mean = 0
        y_sd = 1
    end
    y = (y .- y_mean) ./ y_sd
    y_val = (y_val .- y_mean) ./ y_sd

    # Fit lassonet path
    lnet = PyCall.pyimport("lassonet")
    estimator = is_regression ? lnet.LassoNetRegressor : lnet.LassoNetClassifier
    model = estimator(patience = patience, hidden_dims = hidden_layers, verbose = verbose,
        torch_seed = seed, groups = group)
    path = if verbose
        model.path(x, y, X_val = x_val, y_val = y_val)
    else
        # The do-block form restores stdout even if the fit throws
        redirect_stdout(devnull) do
            model.path(x, y, X_val = x_val, y_val = y_val)
        end
    end

    # Validation predictions for every path item, computed once and reused below:
    # standardised response for regression, positive-class probability for classification
    y_val_hats = map(path) do item
        fitted = model.load(item)
        is_regression ? fitted.predict(x_val) : fitted.predict_proba(x_val)[:, 2]
    end
    n_val = length(y_val)
    val_loss = [lossfn(y_hat_i, y_val, agg = Statistics.mean) for y_hat_i in y_val_hats]
    # Standard error of the mean validation loss: sd of per-observation losses / √n_val
    val_loss_se = [lossfn(y_hat_i, y_val, agg = v -> Statistics.std(v, corrected = false))
        for y_hat_i in y_val_hats] ./ sqrt(n_val)

    # Select the best path item
    index_min = argmin(val_loss)
    if type == "min"
        index_best = index_min
    else
        # Largest λ (sparsest model) whose loss is within one SE of the minimum;
        # argmax returns the first maximiser, matching a stable sort on -λ
        lambda = map(item -> item.lambda_, path)
        within_1se = findall(val_loss .<= val_loss[index_min] + val_loss_se[index_min])
        index_best = within_1se[argmax(lambda[within_1se])]
    end
    best_path = path[index_best]
    best_model = model.load(best_path)

    # Return coefficients and predictions
    beta_hat = best_path.selected.numpy()
    if is_regression
        y_hat = best_model.predict(x_test)[:, 1] .* y_sd .+ y_mean
    else
        prob = best_model.predict_proba(x_test)[:, 2]
        y_hat = log.(prob ./ (1 .- prob)) # Callers expect logits
    end
    return beta_hat, y_hat
end