using Distributions, LinearAlgebra, Random
using DataFrames, CSV
using Lasso

# Run BLAS single-threaded (presumably so per-replication timings below are
# comparable and not affected by BLAS threading — confirm intent).
BLAS.set_num_threads(1)

include("../main/datagen.jl")
include("../main/functions.jl")


# Simulation case id: read from the `case` environment variable, default 1.
case = parse(Int, get(ENV, "case", "1"))


# ---- Simulation settings ----
Np, N = 500, 500_000     # Np: passed to PilotEst2 (presumably pilot size); N: rows per dataset

k = 50                   # length of the zero vector handed to getcase

S = 500                  # replications per (gamma, rho) setting

rhos = [0.0025, 0.005, 0.0075, 0.01]   # `rho` grid passed to UniEst

lthr, gammas = 1e-8, [1] # `lthr` and the `gamma` grid passed to UniEst

dis = "nor"              # tag used only in the output file name here

# `ver` tags the output file; `crtn` tags it too and selects the
# PilotEst2 criterion (as "<crtn>-opt").
ver, crtn = "1", "P"

# ---- True model for the selected case ----
par = getcase(zeros(k), case)
name, true_idx = par.name, par.true_idx
beta0, alpha0 = par.beta0, par.alpha0

Sigma = getSigma(par)

# Uniform experiments.
# For every (gamma, rho) pair, run S replications of
#     gendat -> PilotEst2 -> UniEst
# and append one 16-column row of summary statistics to `rst`, in the order
# (gamma, rho, bmse, pmse, covr, vcovr, ovcor, vovcor, mfvnum, vfvnum,
#  msvnum, vsvnum, auc, lambda, tmest, ttotal).
# The global RNG is seeded once, before all settings.
Random.seed!(2)
rst = Matrix{Float64}(undef, 0, 16)
for gamma in gammas, rho in rhos
    # Per-replication records.
    berr = zeros(S)       # sum of squared errors of est.adpbetas vs beta0
    perr = zeros(S)       # mean squared error of fitted probabilities over the N rows
    cover = zeros(S)      # 1 when true_idx ⊆ est.sscr_idx
    overcover = zeros(S)  # 1 when covered and est.sscr_idx is longer than true_idx
    fvnum = zeros(S)      # length(est.fscr_idx)
    svnum = zeros(S)      # length(est.sscr_idx)
    aucs = zeros(S)       # est.auc
    lambdas = zeros(S)    # est.lambda[1]
    t_plts = zeros(S)     # seconds spent in PilotEst2
    t_ests = zeros(S)     # seconds spent in UniEst
    @time for i in 1:S
        X, y = gendat(N, alpha0, beta0, Sigma)
        t_pl = @elapsed plt = PilotEst2(X, y, Np, criterion = string(crtn, "-opt"),
                                        standardize = true)
        t_es = @elapsed est = UniEst(X, y, plt, gamma, rho, lthr,
                                     nlambda = 100, eps = 0.001,
                                     method = "bic", standardize = false)
        t_plts[i] = t_pl
        t_ests[i] = t_es

        berr[i] = sum(abs2, est.adpbetas - beta0)
        # Probabilities p = 1 - 1 / (1 + exp(eta)) under the true and the
        # estimated coefficients, evaluated on the full simulated X.
        ptrue = 1 .- 1 ./ (1 .+ exp.(alpha0[1] .+ X * beta0))
        pest = 1 .- 1 ./ (1 .+ exp.(est.alpha .+ X * est.adpbetas))
        perr[i] = sum(abs2, ptrue .- pest) / N

        fvnum[i] = length(est.fscr_idx)
        svnum[i] = length(est.sscr_idx)
        aucs[i] = est.auc
        lambdas[i] = est.lambda[1]
        if issubset(true_idx, est.sscr_idx)
            cover[i] = 1
            if length(true_idx) < length(est.sscr_idx)
                overcover[i] = 1
            end
        end
    end

    # Timing totals skip replication 1 of every setting (presumably to drop
    # JIT compilation time), so they sum over S - 1 replications.
    t_plt = sum(@view t_plts[2:end])
    t_est = sum(@view t_ests[2:end])

    bmse = median(berr)
    pmse = median(perr)
    # `covr` is the MISS rate: share of replications where sscr_idx does not
    # contain every true index. `v*` entries are standard errors (std / sqrt(S)).
    covr = 1 - mean(cover)
    vcovr = std(1 .- cover) / sqrt(S)
    # Share of replications that cover true_idx with no extra indices (by length).
    ovcor = mean(cover) - mean(overcover)
    vovcor = std(cover .- overcover) / sqrt(S)
    mfvnum = mean(fvnum)
    vfvnum = std(fvnum) / sqrt(S)
    msvnum = mean(svnum)
    vsvnum = std(svnum) / sqrt(S)
    auc = median(aucs)
    lmbdbst = mean(lambdas)

    row = hcat(gamma, rho, bmse, pmse, covr, vcovr, ovcor, vovcor,
               mfvnum, vfvnum, msvnum, vsvnum, auc, lmbdbst,
               t_est, t_plt + t_est)
    global rst = vcat(rst, row)
end
# Label the 16 summary columns, then save the table as CSV under
# results/uniform/, with the run tags encoded in the file name.
rst = DataFrame(rst, [:gamma, :rho, :bmse, :pmse, :covr, :vcovr, :ovcor, :vovcor,
                      :mfvnum, :vfvnum, :msvnum, :vsvnum, :auc, :lambda,
                      :tmest, :ttotal])
path = let outdir = "results/uniform"
    mkpath(outdir)
    "$(outdir)/Uniform-$(ver)-$(dis)-$(name)-$(crtn).csv"
end
CSV.write(path, rst)
