# =================================================================================================#
# Description: Implementation of the (standard) relaxed lasso
# Author: Ryan Thompson
# =================================================================================================#

import Distributions, Flux, GLMNet, Statistics

"""
    lasso(x_train, y_train, x_valid, y_valid; intercept = true, loss = Flux.mse,
        penalty_factor = ones(size(x_train, 2)), gamma = range(0, 1, 11), nlambda = 50,
        type = "1se")

Fit a relaxed lasso and return the coefficient vector selected on the validation set.

A lasso path is fit on the training data with `GLMNet.glmnet`. For every point on the path
the active set is refit with `lambda = 0` ("polished"), and for each value `g` in `gamma`
the relaxed coefficients are `(1 - g) * lasso + g * polished`. Every `(lambda, gamma)` pair
is scored on `(x_valid, y_valid)` with `loss`.

# Keywords
- `intercept`: whether to fit an intercept. When `true` the intercept is the first entry of
  the returned vector.
- `loss`: `Flux.mse` (Gaussian model) or `Flux.logitbinarycrossentropy` (binomial model;
  `y_train` is expanded to the two columns `1 .- y_train` and `y_train`, so it is presumably
  expected to hold 0/1 labels — confirm against callers).
- `penalty_factor`: per-feature penalty factors forwarded to `GLMNet.glmnet`.
- `gamma`: relaxation weights to evaluate.
- `nlambda`: number of `lambda` values requested from `GLMNet.glmnet`.
- `type`: `"min"` selects the pair with the smallest validation loss; `"1se"` selects,
  among the pairs whose loss is within one standard error of the minimum, the one with the
  largest `lambda` and then the smallest `gamma`.

# Throws
- `ArgumentError` if `type` or `loss` is not one of the supported values.
"""
function lasso(x_train, y_train, x_valid, y_valid; intercept = true, loss = Flux.mse, 
    penalty_factor = ones(size(x_train, 2)), gamma = range(0, 1, 11), nlambda = 50, type = "1se")

    # Reject an unknown selection rule before doing any fitting; otherwise the whole path
    # would be computed only to return `nothing`
    type in ("min", "1se") || throw(ArgumentError(
        "type must be \"min\" or \"1se\", got $(repr(type))"))

    # Determine the response distribution (and the response layout GLMNet expects) from the
    # loss function. Single assignment keeps the closures below from capturing a variable
    # that is rebound in several places
    dist, y_fit = if loss == Flux.mse
        (Distributions.Normal(), y_train)
    elseif loss == Flux.logitbinarycrossentropy
        (Distributions.Binomial(), hcat(1 .- y_train, y_train))
    else
        throw(ArgumentError(
            "loss must be Flux.mse or Flux.logitbinarycrossentropy, got $(repr(loss))"))
    end

    # Train model
    path = GLMNet.glmnet(x_train, y_fit, dist, intercept = intercept, nlambda = nlambda,
        penalty_factor = penalty_factor)
    beta = intercept ? vcat(path.a0', path.betas) : path.betas
    lambda = path.lambda

    # Polish model: refit the active set of one path solution without a penalty. Local names
    # are deliberately distinct from the enclosing function's variables so that the closure
    # does not rebind them
    function polish_fit(coef)
        polished = zeros(length(coef))
        active = coef[1 + intercept:end] .!= 0
        n_active = count(active)
        if n_active > 0
            refit = GLMNet.glmnet(x_train[:, active], y_fit, dist, intercept = intercept,
                lambda = [0.0], maxit = 10000)
            # In case the refit doesn't converge, the polished coefficients stay at zero
            if length(refit.betas) == n_active
                if intercept
                    polished[1] = refit.a0[1]
                    polished[findall(active) .+ 1] = refit.betas
                else
                    polished[active] = refit.betas
                end
            end
        elseif intercept
            # Empty active set: take the intercept from a fit with lambda = Inf
            polished[1] = GLMNet.glmnet(x_train, y_fit, dist, lambda = [Inf]).a0[1]
        end
        return polished
    end
    beta_polish = reduce(hcat, map(polish_fit, eachcol(beta)))

    # Validate models: one aggregated validation loss per path solution for a given gamma
    function val(g; agg = Statistics.mean)
        beta_relax = (1 - g) * beta + g * beta_polish
        return map(eachcol(beta_relax)) do b
            y_hat = intercept ? b[1] .+ x_valid * b[2:end] : x_valid * b
            loss(y_hat, y_valid, agg = agg)
        end
    end
    std_error(v) = Statistics.std(v, corrected = false) / sqrt(length(v))
    val_loss = reduce(hcat, map(val, gamma))
    val_loss_se = reduce(hcat, map(g -> val(g, agg = std_error), gamma))

    # Return best model; indices are (lambda index, gamma index)
    index_min = argmin(val_loss)
    best = if type == "min"
        index_min
    else
        threshold = val_loss[index_min] + val_loss_se[index_min]
        candidates = findall(val_loss .<= threshold)
        # Prefer the largest lambda, then the smallest gamma
        first(sort(candidates, by = i -> (-lambda[i[1]], gamma[i[2]])))
    end
    return (1 - gamma[best[2]]) * beta[:, best[1]] +
        gamma[best[2]] * beta_polish[:, best[1]]

end