using Plots;pythonplot()
using DataFrames, CSV
using LaTeXStrings
using Statistics
using ProgressBars

# myplot() is a function to plot the results of simulated data 
"""
    myplot(attr, ver, name, dis, rst_A, rst_L, rst_pA, rst_uni, rst_adp, rst_las)

Plot one metric of the simulated-data results against the sampling rate and save
the figure as a PDF.

# Arguments
- `attr`: metric to plot; also the name of the column read from every result
  table. `"bmse"` and `"pmse"` are plotted on a log scale (y-labels `log(eMSE)`
  and `log(eMSPE)`), `"auc"` is plotted untransformed (y-label `AUC`).
- `ver`, `dis`, `name`: only used to build the output file name.
- `rst_A`, `rst_L`, `rst_pA`, `rst_uni`: result tables with a `gamma` column and
  an `attr` column; only rows with `gamma == 1` are plotted (series `A-OS`,
  `L-OS`, `P-OS`, `Uni`). `rst_pA` must also provide the `rho` column used as
  the x-axis.
- `rst_adp`, `rst_las`: result tables whose first-row `attr` value is drawn as a
  constant reference line (series `A-lasso (full)` and `Lasso (full)`).

The figure is written to `<dir>/plots/<attr>-<ver>-<dis>-<name>.pdf`, where `dir`
is a global variable that must be defined before the call.

# Throws
- `ArgumentError` if `attr` is not one of `"bmse"`, `"pmse"`, `"auc"`.
"""
function myplot(attr, ver, name, dis,
                rst_A, rst_L, rst_pA,
                rst_uni, rst_adp, rst_las)
    # Pick the y-axis label and the transformation applied to every series.
    if attr == "bmse"
        ylab, rescale = "log(eMSE)", log
    elseif attr == "pmse"
        ylab, rescale = "log(eMSPE)", log
    elseif attr == "auc"
        ylab, rescale = "AUC", identity
    else
        throw(ArgumentError(
            "attr must be \"bmse\", \"pmse\" or \"auc\", got $(repr(attr))"))
    end

    # x-axis: sampling rates of the gamma == 1 rows of the P-OS table.
    # NOTE(review): assumes the other tables list the same rates in the same
    # order for gamma == 1 -- confirm against the files that produce them.
    rho = Vector(rst_pA[rst_pA.gamma .== 1, :].rho)

    # Metric values of the gamma == 1 rows of a subsampling result table.
    atgamma1(df) = rescale.(Vector(df[df.gamma .== 1, attr]))
    # Full-data estimators give one value, repeated so it plots as a flat line.
    flatline(df) = rescale.(df[1, attr] .* ones(length(rho)))

    A = atgamma1(rst_A)
    L = atgamma1(rst_L)
    pA = atgamma1(rst_pA)
    uni = atgamma1(rst_uni)
    adp = flatline(rst_adp)
    las = flatline(rst_las)  # the "auc" branch used to read the undefined `rst_lass`

    p = plot(rho,
             [A, L, pA, uni, adp, las],
             markershape = [:circle :rect :star5 :diamond :hexagon :star4],
             lw = 4, markersize = 18,
             tickfontsize = 16, xguidefontsize = 18, yguidefontsize = 18,
             legendfontsize = 14, grid = false, thickness_scaling = 1,
             ls = [:solid :dash :dashdot],  # three styles, cycled over the six series
             lab = ["A-OS" "L-OS" "P-OS" "Uni" "A-lasso (full)" "Lasso (full)"],
             xlabel = "sampling rate", ylabel = ylab,
             legend = :topright,
             size = (600, 600))
    path = string(dir, "/plots/", attr, "-", ver, "-", dis, "-", name, ".pdf")
    return savefig(p, path)
end

# plotreal() is a function to plot the results of real data
"""
    plotreal(attr, A, L, pA, Un, ylab, file)

Plot the real-data results of four methods against the sampling rate and save
the figure as a PDF.

# Arguments
- `attr`: metric name; only used in the output file name.
- `A`, `L`, `pA`, `Un`: y-values of the series labelled `A-OS`, `L-OS`, `P-OS`
  and `Uni`, one value per entry of the global `rhos`.
- `ylab`: y-axis label.
- `file`: data-set name; only used in the output file name.

The x-axis values come from the global `rhos`, and the figure is written to
`<dir>/plots/<attr>-<ver>-<file>.pdf` using the globals `dir` and `ver`; all
three globals must be defined before the call.
"""
function plotreal(attr, A, L, pA, Un, ylab, file)
    p = plot(rhos,
             [A, L, pA, Un],
             markershape = [:circle :rect :star5 :diamond],
             lw = 4, markersize = 15,
             tickfontsize = 16, xguidefontsize = 18, yguidefontsize = 18,
             # `thickness_scaling` was misspelled `thinkness_scaling`, which is
             # not a Plots attribute name.
             legendfontsize = 14, grid = false, thickness_scaling = 1,
             ls = [:solid :dash :dashdot],  # three styles, cycled over the four series
             lab = ["A-OS" "L-OS" "P-OS" "Uni"],
             xlabel = "sampling rate", ylabel = ylab,
             legend = :topright,
             size = (600, 600))
    path = string(dir, "/plots/", attr, "-", ver, "-", file, ".pdf")
    return savefig(p, path)
end

# The following code reads the results of the simulated data and generates
# tables of timing and variable-selection results (the plots themselves are
# drawn by myplot())
# `pretty_table` is called inside the loop below, so PrettyTables has to be
# loaded before the loop runs (it used to be loaded only after it).
using PrettyTables

dir = "../Lasso/results/"
spldir = string(dir, "simulation/subsampling/")
unidir = string(dir, "simulation/uniform/")
adpdir = string(dir, "simulation/fulladapt/")
lasdir = string(dir, "simulation/fulllasso/")
ver = "1"
dis = "nor"
names = ["A", "B", "C"]

# Timing table: one row per case, one column per method. Every column is sized
# from `names` so the table and the loop below always agree on the case count.
ttable = DataFrame(case=names,
                   unitime=zeros(length(names)),
                   Atime=zeros(length(names)),
                   Ltime=zeros(length(names)),
                   Ptime=zeros(length(names)),
                   adptime=zeros(length(names)),
                   lastime=zeros(length(names)))

# Column headers of the printed / exported tables (the same for every case).
hds1 = ["rho" "Uni" "A-OS" "L-OS" "P-OS"]
header1 = vec(Symbol.(hds1))
hds2 = ["rho" "first-stage" "Uni" "A-OS" "L-OS" "P-OS"]
header2 = vec(Symbol.(hds2))

for (i, name) in enumerate(names)
    println(name)

    # Load data
    suffix = string(ver, "-", dis, "-", name)
    rst_A = CSV.read(string(spldir, "Subsampling-", suffix, "-A.csv"), DataFrame)
    rst_L = CSV.read(string(spldir, "Subsampling-", suffix, "-L.csv"), DataFrame)
    rst_pA = CSV.read(string(spldir, "Subsampling-", suffix, "-pA.csv"), DataFrame)
    rst_uni = CSV.read(string(unidir, "Uniform-", suffix, "-pA.csv"), DataFrame)
    rst_adp = CSV.read(string(adpdir, "full-adp-", suffix, "-pA.csv"), DataFrame)
    rst_las = CSV.read(string(lasdir, "full-las-", suffix, "-pA.csv"), DataFrame)

    # NOTE(review): the divisor 499 and the row picked from `ttotal` (row 2 for
    # the subsampling/uniform tables, row 1 for the full-data tables) are kept
    # exactly as they were; presumably 499 relates to the number of
    # replications -- confirm against the code that writes these files.
    ttable.Atime[i] = rst_A.ttotal[2] / 499
    ttable.Ltime[i] = rst_L.ttotal[2] / 499
    ttable.Ptime[i] = rst_pA.ttotal[2] / 499
    ttable.unitime[i] = rst_uni.ttotal[2] / 499
    ttable.adptime[i] = rst_adp.ttotal[1] / 499
    ttable.lastime[i] = rst_las.ttotal[1] / 499

    # Tables built from the `covr`, `ovcor` and `msvnum`/`mfvnum` columns.
    covdf = DataFrame(rho=rst_A.rho, uni=rst_uni.covr,
                      A=rst_A.covr, L=rst_L.covr, P=rst_pA.covr)
    covdf[:, 2:ncol(covdf)] = round.(covdf[:, 2:ncol(covdf)], digits=3)
    ovcovdf = DataFrame(rho=rst_A.rho, uni=rst_uni.ovcor,
                        A=rst_A.ovcor, L=rst_L.ovcor, P=rst_pA.ovcor)
    ovcovdf[:, 2:ncol(ovcovdf)] = round.(ovcovdf[:, 2:ncol(ovcovdf)], digits=3)
    nvdf = DataFrame(rho=rst_A.rho, fs=rst_A.mfvnum,
                     uni=rst_uni.msvnum,
                     A=rst_A.msvnum, L=rst_L.msvnum, P=rst_pA.msvnum)
    # This table has six columns; rounding 2:5 left the last one (P) unrounded.
    nvdf[:, 2:ncol(nvdf)] = round.(nvdf[:, 2:ncol(nvdf)], digits=2)

    # Print each table, then export it as a booktabs LaTeX table to
    # <dir>tex/<stem><name><ver>.tex.
    for (stem, df, hdr) in (("cov-", covdf, header1),
                            ("ovcov-", ovcovdf, header1),
                            ("nv-", nvdf, header2))
        pretty_table(df, header=hdr)
        tx = pretty_table(String, df, header=hdr,
                          backend=Val(:latex), tf=tf_latex_booktabs)
        write(string(dir, "tex/", stem, name, ver, ".tex"), tx)
    end
end

ttable[:, 2:7] = round.(ttable[:, 2:7], digits=2)

hds = ["Case" "Uni" "A-OS" "L-OS" "P-OS" "A-lasso (full)" "Lasso (full)"]
header = vec(Symbol.(hds))
pretty_table(ttable, header=header)
tx = pretty_table(String, ttable, header=header,
                  backend=Val(:latex), tf=tf_latex_booktabs)
write(string(dir, "tex/time", ver, ".tex"), tx)

# The following code draws the plots of the real-data experiments.
# The parameter "file" can be set to "covtype" or "font" to use the results
# for the covtype data or the font data in Section 5.
dir = "results/"
ver = "1"
file = "font"
rhos = [0.003, 0.005, 0.007, 0.01, 0.015]
rr = Int(length(rhos))
attrs = ["auc"]

# Per-method number of rows left after the NaN filter, one entry per rate.
NanA = zeros(rr)
NanL = zeros(rr)
NanpA = zeros(rr)
NanUn = zeros(rr)

for attr in attrs
    # Keep the rows that contain at least one non-NaN entry.
    # NOTE(review): a row that mixes NaN and non-NaN values is kept and would
    # make the median below NaN -- confirm that "any" (not "all") is intended.
    keeprows(tbl) = tbl[any.(!isnan, eachrow(tbl)), :]

    # Medians per sampling rate for each method.
    A, L, pA, Un = zeros(rr), zeros(rr), zeros(rr), zeros(rr)
    ylab = "ylab"
    for i in ProgressBar(1:rr)
        rho = rhos[i]
        ylab = "AUC"

        # Read the four result files for this rate (A, L, pA, then Uniform).
        stem = string(dir, "realdata/Subsampling-", attr, "-", file, "-", ver)
        Adf, Ldf, pAdf = map(("-A-", "-L-", "-pA-")) do tag
            CSV.read(string(stem, tag, rho, ".csv"), DataFrame)
        end
        # NOTE(review): unlike the Subsampling names there is no "-" between
        # `ver` and `rho` here; kept as is -- confirm against the files on disk.
        Undf = CSV.read(string(dir, "realdata/Uniform-", attr, "-", file, "-", ver,
                               rho, ".csv"), DataFrame)

        # Filter every table first, then convert each one to a matrix.
        kept = map(keeprows, (Adf, Ldf, pAdf, Undf))
        Aest, Lest, pAest, Unest = map(Matrix, kept)

        NanA[i] = size(Aest, 1)
        NanL[i] = size(Lest, 1)
        NanpA[i] = size(pAest, 1)
        NanUn[i] = size(Unest, 1)

        # Median over all remaining entries of each table.
        A[i] = median(vec(Aest))
        L[i] = median(vec(Lest))
        pA[i] = median(vec(pAest))
        Un[i] = median(vec(Unest))
    end
    plotreal(attr, A, L, pA, Un, ylab, file)
end



