using Optimisers, Zygote, ProgressBars, Format, Serialization
include("Data.jl"); using .Data; include("Pds.jl"); using .Pds
using Random, PyCall; const np_random = pyimport("numpy.random")

# Experiment hyper-parameters (all consumed by the sweep below).
const nt = 5          # number of trials per (η schedule, dataset) pair
const nm = 100        # mixture size: one Gaussian per column of `data(nm)`
const N = 30_000      # optimisation steps per trial
const itvl = 100      # steps between checkpoints and between η-schedule stages
const n = 100         # batch size drawn via `data(n)` each step; also divides logp
const lr = 3e-3       # Adam learning rate
const baseed = 22     # base RNG seed; trial `t` is seeded with `baseed + 100t`

# η schedules as parallel tuples (η0_cfg[k] => ηT_cfg[k]).  The first entries keep η
# constant (η0 == ηT); the remaining ones anneal from η0 = 1.0 down to each target ηT.
const (η0_cfg, ηT_cfg) = let
    fixed = (2.0, 1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01)
    annealed = (0.2, 0.1, 0.05, 0.01, 0.005, 0.001)
    starts = map(_ -> 1.0, annealed)
    ((fixed..., starts...), (fixed..., annealed...))
end

"""
    eta_tag(x)

Return a compact string tag for the η value `x`, used in output file names.

Integers in [1, 100) print without a decimal point ("2"). Values in [0.1, 1) print with
one decimal ("0.5"). Everything else uses a one-digit mantissa with an exponent ("5e-2",
"1e-3"). If the short tag does not parse back to exactly `x`, the full `string(x)` is
returned instead. Without that fallback, distinct η values (e.g. 1.5 vs 1.0, or 0.025 vs
0.02) would share a tag and their result files would overwrite each other.
"""
function eta_tag(x::Real)
    iszero(x) && return "0"
    a = abs(x)
    tag = if 1 ≤ a < 100
        string(trunc(Int, x))
    elseif 0.1 ≤ a < 1
        string(round(x; digits=1))
    else
        p = floor(Int, log10(a))
        b = round(Int, a / 10.0^p)
        if b == 10  # mantissa rounded up to 10 (or log10 landed just below a decade)
            b, p = 1, p + 1
        end
        "$(x > 0 ? b : -b)e$(p)"
    end
    # Only keep the short form when it is lossless, so file names stay unique per η.
    return parse(Float64, tag) == x ? tag : string(x)
end

# Sweep every (η0 => ηT) schedule × dataset × trial. Each run trains `md` and `ml` for
# `N` Adam steps and serialises their snapshot histories (the initial state plus one
# snapshot every `itvl` steps) to
# save/<dataset>/vmd_<η0>_<ηT>_<trial>.jls and vml_<η0>_<ηT>_<trial>.jls.
foreach(η0_cfg, ηT_cfg) do η0, ηT
    η0_str, ηT_str = eta_tag(η0), eta_tag(ηT)
    println("η [$η0_str => $ηT_str]")
    foreach((scurve, moons, swiss, circles, gaussian8)) do data
        println("Data [$(string(data))]")
        path = mkpath(joinpath(@__DIR__, "save", string(data)))

        foreach(1:nt) do trial
            trial_str = lpad(trial, 2, '0')
            println("Trial [$trial]")
            seed = baseed + 100trial
            # Seed both Julia's and NumPy's global RNGs. Presumably some `Data` generators
            # draw from NumPy via PyCall — confirm.
            Random.seed!(seed); np_random.seed(seed)
            md = Mix(zeros(nm), map(Gaussian, eachcol(data(nm))))
            ml, mdprev = copy(md), copy(md)
            st = Optimisers.setup(Adam(lr), (md, ml))
            pb, fmt = ProgressBar(1:N), "%6.3f"
            hist_md, hist_ml = [copy(md)], [copy(ml)]

            # Objective: η·vi(m, a, ·) + (1-η)·vi(m, b, ·) − logp(l, x)/n.
            f(m, l, a, b, x, y1, y2, η) =
                η * vi(m, a, y1) + (1 - η) * vi(m, b, y2) - n \ logp(l, x)

            # Index of the last schedule stage. The guard keeps η finite when N ÷ itvl == 1.
            nstages = max(N ÷ itvl - 1, 1)
            for i ∈ pb
                # 1/η moves linearly from 1/η0 (first stage) to 1/ηT (last stage) and
                # changes once every `itvl` steps.
                η = 1 / (1 / η0 + (((i - 1) ÷ itvl) / nstages) * (1 / ηT - 1 / η0))
                loss, grads = withgradient(f, md, ml, ml, mdprev, data(n),
                    Pds.sample_comp(md, 64), Pds.sample_comp(md, 64), η)
                # Only the gradients w.r.t. the first two arguments are applied. `ml` passed
                # as `a` and `mdprev` passed as `b` therefore act as fixed targets.
                ∇md, ∇ml = grads[1], grads[2]
                st, (md, ml) = Optimisers.update(st, (md, ml), (∇md, ∇ml))
                if i % itvl == 0
                    # Checkpoint: `mdprev` becomes the target `b` for the next stage.
                    copy!(mdprev, md)
                    push!(hist_md, copy(md)); push!(hist_ml, copy(ml))
                    # A full collection per checkpoint (instead of per step) still bounds
                    # memory without paying for N full GCs per run.
                    GC.gc()
                end
                set_description(pb, "Loss: $(cfmt(fmt, loss))")
            end

            file_md = "vmd_$(η0_str)_$(ηT_str)_$(trial_str).jls"
            file_ml = "vml_$(η0_str)_$(ηT_str)_$(trial_str).jls"
            serialize(joinpath(path, file_md), hist_md)
            serialize(joinpath(path, file_ml), hist_ml)
        end
    end
end
