using Base: _ensure_array
using Serialization, CairoMakie, Random, ColorSchemes

# Seed the default RNG once so that repeated runs of this script draw the same random numbers.
Random.seed!(0)

# Project-local modules; `include` resolves these paths relative to this file.
include("Data.jl")
using .Data
include("Pds.jl")
using .Pds

# Per-configuration plot styling. The four vectors are parallel: entry k of each one describes
# the same configuration (legend label, file-name suffix, marker shape, line color).
const labels = [
   L"\eta=2.0",
   L"\eta=1.0",
   L"\eta=0.5",
   L"\eta=0.1",
   L"\eta=0.01",
   L"\eta=1.0\to0.05",
]
const suffixes = [
   "2_2",
   "1_1",
   "0.5_0.5",
   "0.1_0.1",
   "1e-2_1e-2",
   "1_5e-2",
]
const markers = [
   :circle,
   :rect,
   :diamond,
   :cross,
   :xcross,
   :star5,
]
# One color per suffix: a categorical :Blues gradient with length(suffixes)+1 entries, minus its
# first two entries, followed by a single orange for the last configuration.
const colors = let
   blues = collect(cgrad(:Blues, length(suffixes) + 1; categorical=true))
   orange = RGBf(246/255, 175/255, 49/255)
   [blues[3:end]..., orange]
end


# Index of the last stored snapshot that is plotted (snapshots 1 .. maxiter+1 are sampled).
const maxiter = 300
# Stride between two plotted snapshots.
const itvl = 20
# Number of points on each curve, i.e. how many values 0, itvl, 2itvl, ... fit into 0..maxiter.
const num = length(0:itvl:maxiter)
# Number of trials averaged per configuration (alternative value noted by the author: 2).
const nt = 5
# Sample count handed to `Pds.kl_mc` as `n` (alternative value noted by the author: 4000).
const sample_size = 100_000
# Color of the "Direct KL" reference curve: the 4th entry of the tab10 scheme, made slightly
# transparent (alpha = 0.85).
const kl_color = let base = ColorSchemes.tab10[4], alpha = 0.85
   RGBAf(base.r, base.g, base.b, alpha)
end

# Datasets to plot (one panel each) and the matching panel titles, in the same order.
# NOTE(review): `scurve`, `moons` and `swiss` are presumably exported by the `Data` module — confirm.
data_tup = (scurve, moons, swiss)
titles = ("S-curve", "Moons", "Swiss Roll")

fig = Figure(size = (1210,560))

# All panels go into one nested grid that occupies cell (1, 1) of the figure layout.
g = GridLayout()
fig.layout[1, 1] = g
# Axes of the panels in creation order. Typed as `Axis[]` instead of the untyped `[]`
# (a `Vector{Any}`) so the element type is concrete and anything else pushed here is rejected.
axes = Axis[]
# Draw one panel per dataset.
#
# For every configuration in `suffixes`, the serialized "vmd" and "vml" model sequences of all
# `nt` trials are loaded, every `itvl`-th snapshot is scored with `cost`, and the scores are
# averaged over the trials. The averaged "vmd" curve is drawn per configuration; the "vml" curves
# are additionally averaged over all configurations and drawn once as the "Direct KL" line.
for (i, (data, title)) in enumerate(zip(data_tup, titles))
   println("[$data]")
   # `mkpath` returns the path it was given and creates the directory when it is missing.
   savepath = mkpath(joinpath(@__DIR__, "save", string(data)))
   # Called for its side effect only: make sure the per-dataset figure directory exists.
   mkpath(joinpath(@__DIR__, "fig", string(data)))

   gt = deserialize(joinpath(savepath, "gt.jls"))
   # Score of a model against the stored ground truth `gt`.
   # NOTE(review): `Pds.kl_mc` is presumably a Monte Carlo KL estimate using `n` samples — confirm.
   cost(m) = Pds.kl_mc(m, gt; n=sample_size)

   # Indices of the stored snapshots that are scored (every `itvl`-th one).
   snapshots = 1:itvl:maxiter+1
   # Add the cost of each sampled snapshot of `models` onto `acc`, in snapshot order.
   function add_costs!(acc, models)
      for (k, m) in enumerate(models[snapshots])
         acc[k] += cost(m)
      end
      return acc
   end

   # Tick positions, labels and the x-limit are derived from `maxiter` instead of being
   # hard-coded for maxiter = 300; the labels keep the original pattern (position p -> "<p÷10>K").
   tickpos = 0:100:maxiter
   ticklabels = [p == 0 ? "0" : "$(p ÷ 10)K" for p in tickpos]

   ax = Axis(g[1, i]; title=title, titlesize=36,
      xticks=(tickpos, ticklabels),
      xtickwidth=3, ytickwidth=3, aspect=nothing, yscale=log10,
      limits=((0, maxiter + 10), nothing),
      xgridvisible=false, ygridvisible=false, spinewidth=3)
   push!(axes, ax)

   x = LinRange(0, maxiter, num)
   zz = zeros(num)   # "vml" curve summed over all configurations
   for (label, suffix, marker, color) ∈ zip(labels, suffixes, markers, colors)
      println(label)
      y = zeros(num)   # "vmd" costs summed over trials
      z = zeros(num)   # "vml" costs summed over trials
      for trial ∈ 1:nt
         trial_str = lpad(trial, 2, '0')
         println("[Trial $trial_str]")
         mds = deserialize(joinpath(savepath, "vmd_$(suffix)_$(trial_str).jls"))
         mls = deserialize(joinpath(savepath, "vml_$(suffix)_$(trial_str).jls"))
         # Keep this order ("vmd" first, then "vml"): if `cost` draws from the global RNG,
         # the order of the calls determines which random numbers each estimate sees.
         add_costs!(y, mds)
         add_costs!(z, mls)
      end
      y ./= nt
      z ./= nt
      # Plot into `ax` explicitly rather than relying on Makie's implicit current axis.
      scatterlines!(ax, x, y; color=color, marker=marker, label=label,
         linewidth=6, markersize=24, strokewidth=3)
      zz .+= z
   end
   zz ./= length(suffixes)

   # The first point (snapshot 1) is left out of the reference curve.
   lines!(ax, x[2:end], zz[2:end]; color=kl_color, label="Direct KL", linewidth=12,
      linestyle=Linestyle([0.5, 1.0, 1.5, 2.5]))
end

# println(g[1,1] |> typeof)

# Shared legend below the panels, built from the labelled plots of the first axis
# (every panel draws the same set of labels).
Legend(fig[2, 1], axes[1]; orientation=:horizontal, labelsize=28, patchsize=(36, 36),
   padding=(28, 28, 8, 8), framewidth=3)
#           left right bottom top

# Write the figure next to this script rather than relative to the current working directory,
# consistent with the other paths in this file, so the script also works when launched from
# another directory. The directory is created here so this step does not depend on earlier code.
let plotfile = joinpath(@__DIR__, "fig", "plot.pdf")
   mkpath(dirname(plotfile))
   save(plotfile, fig)
   # Trim the white margins in place; requires the external `pdfcrop` tool on the PATH.
   run(`pdfcrop $plotfile $plotfile`)
end
