# =================================================================================================#
# Description: Implementation of the (standard) relaxed group lasso
# Author: Ryan Thompson
# =================================================================================================#

import Flux, GLMNet, RCall, Statistics

"""
    glasso(x_train, y_train, x_valid, y_valid; group, gamma, nlambda, type)

Fit a relaxed group lasso and return the coefficient vector selected on validation data.

The regularisation path is fit with R's `grpreg::grpreg` (called through RCall). For every
point on the path the active predictors are refit with an unpenalised `GLMNet.glmnet`
(`lambda = [0.0]`) to obtain "polished" coefficients. Relaxed coefficients are the convex
combination `(1 - gamma) * beta + gamma * beta_polish`, evaluated for every value in `gamma`
and scored on the validation set with `Flux.mse`.

# Arguments
- `x_train`, `y_train`: training predictors and response.
- `x_valid`, `y_valid`: validation predictors and response used for model selection.

# Keywords
- `group`: group label of each column of `x_train`; defaults to one group per column.
- `gamma`: relaxation weights to try; defaults to `range(0, 1, 11)`.
- `nlambda`: number of path points requested from `grpreg`.
- `type`: `"min"` picks the (lambda, gamma) pair with the smallest validation loss;
  `"1se"` picks, among pairs whose loss is within one standard error of that minimum, the
  one with the largest lambda (ties broken by the smallest gamma).

# Returns
The selected coefficient vector; its first entry is the intercept.

# Throws
- `ArgumentError` if `type` is neither `"min"` nor `"1se"`.
"""
function glasso(x_train, y_train, x_valid, y_valid; group = collect(1:size(x_train, 2)), 
    gamma = range(0, 1, 11), nlambda = 50, type = "1se")

    # Fail fast on an unknown selection rule instead of fitting everything and then
    # silently returning `nothing`
    type in ("min", "1se") ||
        throw(ArgumentError("type must be \"min\" or \"1se\", got $(repr(type))"))

    # Train model
    fit = RCall.R"grpreg::grpreg($x_train, $y_train, $group, nlambda = $nlambda)"
    beta = RCall.rcopy(RCall.R"$fit$beta")
    lambda = RCall.rcopy(RCall.R"$fit$lambda")

    # Polish model
    # The helper's locals deliberately use names that do not occur in the enclosing function:
    # a nested function that assigns to a name of its parent rebinds the parent's variable.
    function polish_fit(coef)
        polished = zeros(length(coef))
        active = coef[2:end] .!= 0
        if any(active)
            refit = GLMNet.glmnet(x_train[:, active], y_train, lambda = [0.0],
                maxit = 10000)
            # In case doesn't converge: the polished column is then left at zero
            if length(refit.betas) == sum(active)
                polished[1] = refit.a0[1]
                polished[findall(active) .+ 1] = refit.betas
            end
        else
            # No active predictors: keep only the intercept of a fully penalised fit
            polished[1] = GLMNet.glmnet(x_train, y_train, lambda = [Inf]).a0[1]
        end
        return polished
    end
    beta_polish = hcat(map(polish_fit, eachcol(beta))...)

    # Validate models
    # For one relaxation weight, returns one aggregated squared-error value per path point
    function val(gamma; agg = Statistics.mean)
        beta_relax = (1 - gamma) * beta + gamma * beta_polish
        y_hat = map(b -> b[1] .+ x_valid * b[2:end], eachcol(beta_relax))
        return map(pred -> Flux.mse(pred, y_valid, agg = agg), y_hat)
    end
    val_loss = hcat(map(val, gamma)...) # Rows index lambda, columns index gamma

    # Pick best model
    index = argmin(val_loss)
    if type == "1se"
        # Standard error of the squared errors, only needed for the one-standard-error rule
        val_loss_se = hcat(map(g -> val(g, agg = x -> Statistics.std(x, corrected = false) /
            sqrt(length(x))), gamma)...)
        candidates = findall(val_loss .<= val_loss[index] + val_loss_se[index])
        # Most regularised candidate: largest lambda first, then smallest gamma
        index = sort(candidates, by = i -> (- lambda[i[1]], gamma[i[2]]))[1]
    end

    # Return best model
    return (1 - gamma[index[2]]) * beta[:, index[1]] + gamma[index[2]] * 
        beta_polish[:, index[1]]

end