using Serialization, GLMakie, ColorSchemes # CairoMakie 
GLMakie.activate!()   # render with the OpenGL (GLMakie) backend
GLMakie.closeall()    # close any screens left open by a previous run

# Project-local modules. `Pds` provides `logp_each`; `Data` presumably exports the
# `scurve` target used below — confirm.
include("Data.jl"); using .Data
include("Pds.jl"); using .Pds

# The dataset's name selects both the input (`save/<name>`) and output (`fig/<name>`)
# directories; `mkpath` creates them if missing and returns the path.
const data = scurve
const savepath = mkpath(joinpath(@__DIR__, "save", "$data"))
const figpath = mkpath(joinpath(@__DIR__, "fig", "$data"))

# η sweep: display labels and the matching checkpoint-name suffixes
# (presumably suffixes[i] corresponds to labels[i] — confirm).
const labels = [L"\eta=2.0", L"\eta=1.0", L"\eta=0.5", L"\eta=0.1", L"\eta=0.01"]
const suffixes = ["2_2", "1_1", "0.5_0.5", "0.1_0.1", "1e-2_1e-2"]
# Suffix of the single run whose checkpoints are loaded below.
const suffix = "1_5e-2"

# Trial numbers, zero-padded to two digits for use in checkpoint file names.
const md_trial = 1
const md_trial_str = lpad(md_trial, 2, '0')
const ml_trial = 1
const ml_trial_str = lpad(ml_trial, 2, '0')

# Checkpoint indices (not referenced by the live code in this section).
const md_idx = 301
const ml_idx = 301

# Ground-truth distribution plus the two saved checkpoint collections for the selected
# run/trials (presumably sequences of models indexed by training step — confirm).
const gt = joinpath(savepath, "gt.jls") |> deserialize
const mls = joinpath(savepath, "vml_$(suffix)_$(ml_trial_str).jls") |> deserialize
const mds = joinpath(savepath, "vmd_$(suffix)_$(md_trial_str).jls") |> deserialize

# Extents for the ground-truth plot.
const datalim = 3.5     # half-width of the square grid the density is evaluated on
const plotlim = 3.5     # half-width of the x/y axis limits
const ticklim = 3.5     # outermost x/y tick position
const tickitvl = 1.75   # spacing between x/y ticks
const num = 601         # grid points per axis
const α = 0.2
# Log-scaled colormap. Other maps tried: :ice, :inferno, :grays, :darkrainbow, :heat,
# :blues, :amp, :winter, :tokyo, :imola, :broc, :berlin, :batlowK, :bamako, :cividis,
# :coolwarm, :jet, :plasma, :cool, :Blues, :viridis.
const cmap = cgrad(:nuuk; scale = :log)

const x = LinRange(-datalim, datalim, num)
const y = LinRange(-datalim, datalim, num)
# 2×(num^2) matrix of grid points with x varying fastest, so that
# `reshape(_, num, num)` yields z[i, j] evaluated at (x[i], y[j]).
const xy = permutedims(hcat(repeat(x, outer = length(y)), repeat(y, inner = length(x))))
const z_gt = reshape(exp.(Pds.logp_each(gt, xy)), num, num)
# Shared color/z range for every panel in this script.
const (z_gt_min, z_gt_max) = extrema(z_gt)
const crop = "256x256+256+556"   # ImageMagick `-crop` geometry (only in commented-out code here)
const gw = 0                     # grid-line width for the learned-density panels
const size_small = (800, 800)    # pixel size of the regular learned-density figures

f = Figure(size = (1000, 800), fontsize = 22)

# Bare 3D axis for the ground-truth density: no labels, ticks, tick labels, side
# panels or spines; thick white grid lines; z clamped to the density's range.
const ax = Axis3(f[1:2, 1];
    aspect = (1.0, 1.0, 0.1), perspectiveness = 0.5,
    elevation = π / 3, azimuth = -2π / 5,
    viewmode = :stretch,
    protrusions = (0, 0, 0, 0),
    xzpanelvisible = false, yzpanelvisible = false,
    xticks = -ticklim:tickitvl:ticklim,
    yticks = -ticklim:tickitvl:ticklim,
    zticks = WilkinsonTicks(4; k_min = 2),
    xlabelvisible = false, xticksvisible = false, xticklabelsvisible = false,
    ylabelvisible = false, yticksvisible = false, yticklabelsvisible = false,
    zlabelvisible = false, zticksvisible = false, zticklabelsvisible = false,
    zgridcolor = :white, ygridcolor = :white, xgridcolor = :white,
    limits = ((-plotlim, plotlim), (-plotlim, plotlim), (z_gt_min, z_gt_max)),
    xgridwidth = 4, ygridwidth = 4, zgridwidth = 4,
    xspinewidth = 0, yspinewidth = 0, zspinewidth = 0)

# 50 contour levels drawn in the xy-plane at offset 0 (`transformation = (:xy, 0)`),
# underneath the semi-transparent density surface; both share one color range.
contour!(ax, x, y, z_gt;
    levels = 50, colormap = cmap, linewidth = 2,
    colorrange = (z_gt_min, z_gt_max), transformation = (:xy, 0), transparency = true)
surface!(ax, x, y, z_gt;
    colormap = cmap, colorrange = (z_gt_min, z_gt_max), transparency = true)

# Presumably 256 samples from the target; only used by the optional scatter overlay:
# scatter!(ax, sample[1,:], sample[2,:], 0.02 .* ones(size(sample, 2)), color=ColorSchemes.cool[end])
sample = data(256)

# Write the ground-truth figure, then crop its blank margins in place with
# ImageMagick's `convert -trim` (the `convert` executable must be on PATH).
let png = joinpath(figpath, "gt.png")
    save(png, f)
    run(`convert $png -trim $png`)
end


# Learned-density panels (ML and MD checkpoints) for selected training steps.
# Steps used for the full comparison figure: [1, 12, 76, 151, 226, 301].

"""
    learned_density_axis(pos; plotlim, ticklim, tickitvl)

Create the bare `Axis3` at layout position `pos` shared by every learned-density panel.
It has no labels, ticks, tick labels, side panels or spines, and its grid width is `gw`.
x/y limits are `±plotlim` with ticks every `tickitvl` up to `±ticklim`. z-limits are the
ground-truth range `(z_gt_min, z_gt_max)`, so all panels share one vertical scale.
"""
function learned_density_axis(pos; plotlim, ticklim, tickitvl)
    return Axis3(pos;
        aspect = (1.0, 1.0, 0.08), perspectiveness = 0.2,
        elevation = π / 4, azimuth = -π / 2,
        protrusions = (0, 0, 0, 0),
        xzpanelvisible = false, yzpanelvisible = false,
        xticks = -ticklim:tickitvl:ticklim,
        yticks = -ticklim:tickitvl:ticklim,
        zticks = WilkinsonTicks(4; k_min = 2),
        xlabelvisible = false, xticksvisible = false, xticklabelsvisible = false,
        ylabelvisible = false, yticksvisible = false, yticklabelsvisible = false,
        zlabelvisible = false, zticksvisible = false, zticklabelsvisible = false,
        zgridcolor = :black, ygridcolor = :black, xgridcolor = :black,
        limits = ((-plotlim, plotlim), (-plotlim, plotlim), (z_gt_min, z_gt_max)),
        xgridwidth = gw, ygridwidth = gw, zgridwidth = gw,
        xspinewidth = 0, yspinewidth = 0, zspinewidth = 0)
end

"""
    save_trimmed(filename, fig)

Save `fig` to `filename`, then crop its blank margins in place with ImageMagick's
`convert -trim` (the `convert` executable must be on PATH). Return `filename`.
"""
function save_trimmed(filename, fig)
    save(filename, fig)
    run(`convert $filename -trim $filename`)
    return filename
end

"""
    plot_learned_density(dists, prefix, step, x, y, xy; plotlim, ticklim, tickitvl)

Evaluate the density of `dists[step]` at the grid points `xy` (a 2×N matrix with `x`
varying fastest) and render it as a surface twice. The first copy is `size_small` and is
saved to `figpath/<prefix>_<NNN>.png`; the second is 1024×1024 and is saved to
`figpath/<prefix>_highres_<NNN>.png`. `NNN` is `step` zero-padded to three digits. Both
images are trimmed, and colors use the ground-truth range.
"""
function plot_learned_density(dists, prefix, step, x, y, xy; plotlim, ticklim, tickitvl)
    step_str = lpad(step, 3, '0')
    z = reshape(exp.(Pds.logp_each(dists[step], xy)), length(x), length(y))
    for (tag, figsize) in (("", size_small), ("_highres", (1024, 1024)))
        fig = Figure(size = figsize, fontsize = 22)
        ax_step = learned_density_axis(fig[1, 1]; plotlim, ticklim, tickitvl)
        surface!(ax_step, x, y, z; colormap = cmap, colorrange = (z_gt_min, z_gt_max))
        save_trimmed(joinpath(figpath, "$(prefix)$(tag)_$(step_str).png"), fig)
    end
    return nothing
end

foreach([301]) do step
    # Learned densities are shown on a slightly tighter window than the ground truth.
    lim = 3.2
    xs = LinRange(-lim, lim, num)
    ys = LinRange(-lim, lim, num)
    grid = permutedims(hcat(repeat(xs, outer = length(ys)), repeat(ys, inner = length(xs))))
    for (dists, prefix) in ((mls, "ml"), (mds, "md"))
        plot_learned_density(dists, prefix, step, xs, ys, grid;
            plotlim = lim, ticklim = lim, tickitvl = 1.6)
    end
end


    



