using Optimisers, Zygote, ProgressBars, Format, Serialization
include("Data.jl"); using .Data; include("Pds.jl"); using .Pds
using Random, PyCall; const np_random = pyimport("numpy.random")

# Hyper-parameters for the mixture-fitting script.
const nt = 1          # independent trials per dataset (merged with Pds.merge_mixtures)
const nm = 1000       # mixture components; each initialised from one column of data(nm)
const N = 100_000     # optimisation steps per trial (length of the progress bar)
const n = 10_000      # samples drawn per step; also the divisor of the loss
const lr = 1e-3       # Adam learning rate
const seed = 32       # seed for both Julia's and NumPy's global RNGs

# Fit a Gaussian mixture to each synthetic dataset and save it.
#
# For every data generator below, the script:
#   1. reseeds Julia's and NumPy's global RNGs with `seed`, so every dataset's run
#      starts from the same RNG state;
#   2. runs `nt` trials, each fitting an `nm`-component `Mix` with Adam for `N` steps,
#      drawing a fresh batch of `n` samples every step;
#   3. combines the trials with `Pds.merge_mixtures` and serializes the result to
#      `save/<dataset name>/gt.jls` next to this script (directories are created).
foreach((scurve, moons, swiss, circles, gaussian8)) do data
    tag = string(data)
    println("Data [$tag]")
    Random.seed!(seed)
    np_random.seed(seed)

    # Train one mixture from scratch and return the final model.
    function fit(trial)
        println("Trial [$trial]")
        # One Gaussian per column of a fresh draw of `nm` samples; the zero vector is
        # presumably the mixture-weight parameterisation (e.g. logits) — confirm in Pds.
        model = Mix(zeros(nm), map(Gaussian, eachcol(data(nm))))
        opt = Optimisers.setup(Adam(lr), model)
        bar = ProgressBar(1:N)
        fmt = "%6.3f"
        # Log-likelihood divided by the batch size and negated: `-n \ y` is `y / (-n)`
        # (assumes `logp` returns a scalar summed over the batch — confirm).
        nll(mdl, batch) = -n \ logp(mdl, batch)
        for _ in bar
            lossval, (∇model, _) = withgradient(nll, model, data(n))
            opt, model = Optimisers.update(opt, model, ∇model)
            set_description(bar, "Loss: $(cfmt(fmt, lossval))")
            # NOTE(review): a full collection on every step is expensive; presumably it
            # releases memory held by the per-step batches — confirm before removing.
            GC.gc()
        end
        return model
    end

    merged = Pds.merge_mixtures(map(fit, 1:nt))
    outdir = mkpath(joinpath(@__DIR__, "save", tag))
    serialize(joinpath(outdir, "gt.jls"), merged)
end
