using Lasso, Distributions, LinearAlgebra, Random
using DataFrames, CSV
import GLM: glm, coeftable
BLAS.set_num_threads(1)


include("../main/datagen.jl")
include("../main/functions.jl")


# Scenario selector handed to `getcase`; taken from the environment variable
# `case` when it is set, otherwise 1.
case = parse(Int, get(ENV, "case", "1"))

# Simulation sizes.
N = 500_000   # rows of each simulated data set (see `gendat(N, ...)` / `ones(N)`)
k = 50        # length of the zero vector passed to `getcase`
S = 500       # number of Monte Carlo replications (loop `for i in 1:S`)

# Settings not referenced in this part of the script.
rhos = [0.0025, 0.005, 0.0075, 0.01]
lthr = 1e-8
gammas = [1]

# Tags embedded in the output CSV file name.
dis = "nor"
ver = "1"
crtn = "P"

#Data generating
# Build the parameter set for the selected scenario. `getcase`/`getSigma` are
# not defined in this file (presumably they come from the included helpers).
par = getcase(zeros(k), case)
name, true_idx, beta0, alpha0 = par.name, par.true_idx, par.beta0, par.alpha0

# Matrix handed to `gendat` alongside the coefficients (presumably the
# covariate covariance matrix — confirm in the helper files).
Sigma = getSigma(par)

#full data adaptive lasso
# Monte Carlo loop: each replication simulates a data set, fits a pilot
# logistic MLE to build adaptive-lasso weights, fits the weighted lasso with
# `lassoBIC`, and records estimation, prediction and selection metrics.
Random.seed!(2)
aerr = zeros(S)       # squared error of the intercept estimate
berr = zeros(S)       # squared L2 error of the slope estimates
perr = zeros(S)       # mean squared difference between true and fitted probabilities
cover = zeros(S)      # 1 when every index in `true_idx` has a nonzero estimate
overcover = zeros(S)  # 1 when covered and at least one extra index is selected
aucs = zeros(S)       # value returned by `AUC` for the fitted probabilities
t_plts = zeros(S)     # seconds spent fitting the pilot MLE
t_ests = zeros(S)     # seconds spent in `lassoBIC`
@time for i in 1:S
    X, y = gendat(N, alpha0, beta0, Sigma)

    # Pilot logistic MLE on the column-standardized design; slopes are mapped
    # back to the original scale by dividing by the column standard deviations.
    Xstd = vec(std(X, dims=1))
    Xmean = vec(mean(X, dims=1))
    Xs = (X .- Xmean') ./ Xstd'
    t_pl = @elapsed fullest = glm([ones(N) Xs], y, Binomial())
    betaMLE = DataFrame(coeftable(fullest))."Coef."[2:end] ./ Xstd

    # Adaptive-lasso weights: inverse magnitude of the pilot slopes.
    adp_weight = 1 ./ abs.(betaMLE)
    # max_j |sum_i y_i * X_ij| / N on the raw (unstandardized) X and uncentered y.
    # NOTE(review): ignores `adp_weight`; confirm this matches how `lassoBIC`
    # builds its lambda path.
    lambda_max = maximum(abs.(vec(sum(y .* X, dims=1)))) / N
    t_es = @elapsed est = lassoBIC(X, y, penalty_factor = adp_weight,
                                   lambda_max = lambda_max,
                                   nlambda = 100, eps = 0.001, fld = 5,
                                   method = "bic")
    betaest = est.coef[2:end]
    alphaest = est.coef[1]
    t_plts[i] = t_pl
    t_ests[i] = t_es

    aerr[i] = (alphaest - alpha0[1])^2
    berr[i] = sum((betaest - beta0) .^ 2)

    # Linear predictors and logistic probabilities, computed once and reused
    # for both the probability error and the AUC (each `X * beta` is a
    # product over all N rows).
    eta_true = alpha0[1] .+ X * beta0
    eta_est = alphaest .+ X * betaest
    ptrue = 1 .- 1 ./ (1 .+ exp.(eta_true))
    pest = 1 .- 1 ./ (1 .+ exp.(eta_est))
    perr[i] = sum((ptrue .- pest) .^ 2) / N
    aucs[i] = AUC(y, pest, 2)

    # Selection: does the estimated support contain the true support, and
    # does it contain anything extra?
    scr_idx = findall(!=(0), betaest)
    if issubset(true_idx, scr_idx)
        cover[i] = 1
        if length(true_idx) < length(scr_idx)
            overcover[i] = 1
        end
    end
end
# Aggregate over replications. The first replication's timings are dropped,
# presumably so that Julia's JIT compilation time is excluded from the totals.
t_plt = sum(t_plts[2:end])
t_est = sum(t_ests[2:end])
amse = median(aerr)   # intercept error summary; not included in the CSV below
bmse = median(berr)
pmse = median(perr)
covr = mean(cover)
ovcor = mean(overcover)
mauc = mean(aucs)
rst = DataFrame(pmse = pmse, bmse = bmse, cover = covr, overcover = ovcor,
                auc = mauc, ttotal = t_plt + t_est)

# Build the output path first and create exactly the directory the CSV is
# written to (relative to the working directory), so the write cannot fail
# on a missing `../results/fulladapt` folder.
path = string("../results/fulladapt/full-adp-", ver, "-", dis, "-", name, "-", crtn, ".csv")
mkpath(dirname(path))
CSV.write(path, rst)
