#=
Normalizing flows (NFs) are implemented using the Bijectors.jl interface.
=#
using LinearAlgebra, Distributions, Random, StatsBase, SpecialFunctions, Parameters
using ProgressMeter, Flux, Bijectors, TickTock
using Flux: Params
include("../util/train.jl")


"""
    single_elbo(flow, logp, logq)

Return one ELBO evaluation for `flow`.

Calls `Bijectors.forward(flow)` once and unpacks its first three results as the
input `x`, the output `y` and the log-Jacobian term `logjac`; the fourth result
is ignored. The value returned is `logp(y) - logq(x) + logjac`.

# Arguments
- `flow`: transformed distribution passed to `Bijectors.forward`.
- `logp`: callable evaluated at `y` (presumably the target log-density).
- `logq`: callable evaluated at `x` (presumably the reference log-density).
"""
function single_elbo(flow::Bijectors.MultivariateTransformed, logp, logq)
    # The fourth value returned by `forward` is not needed here.
    x, y, logjac, _ = Bijectors.forward(flow)
    density_gap = logp(y) - logq(x)
    return density_gap + logjac
end

"""
    nf_ELBO(flow, logp, logq; elbo_size = 1)

Average `elbo_size` evaluations of `single_elbo(flow, logp, logq)`.

# Arguments
- `flow`: transformed distribution forwarded to `single_elbo`.
- `logp`, `logq`: callables forwarded to `single_elbo`.

# Keywords
- `elbo_size`: number of `single_elbo` evaluations to average.

# Returns
The running sum of `single_elbo(...) / elbo_size`. When `elbo_size < 1` the loop
body never runs and `0.0` is returned.
"""
function nf_ELBO(flow::Bijectors.MultivariateTransformed, logp, logq; elbo_size = 1)
    # Loop-invariant weight applied to every evaluation.
    weight = 1 / elbo_size
    el = 0.0
    # Plain loop on purpose: `@simd` would assert that iterations can be reordered
    # or overlapped, which cannot be guaranteed for an opaque call like
    # `single_elbo`, and it offers no speed-up for a non-inlined call anyway.
    for _ in 1:elbo_size
        el += weight * single_elbo(flow, logp, logq)
    end
    return el
end

"""
    nf_train!(flow, Layers, ps, logp, logq;
              elbo_size = 10, niters = 1000, optimizer = Flux.ADAM(1e-3), kwargs...)

Train `flow` by minimising the negated `nf_ELBO` estimate with `vi_train!`.

# Arguments
- `flow`: flow passed to `nf_ELBO` inside the loss.
- `Layers`: returned unchanged by this function; presumably the layers whose
  parameters are collected in `ps` (confirm against callers).
- `ps::Flux.Params`: parameters handed to `vi_train!`.
- `logp`, `logq`: callables forwarded to `nf_ELBO`.

# Keywords
- `elbo_size::Int`: number of evaluations averaged per loss call.
- `niters::Int`: iteration count passed to `vi_train!`.
- `optimizer`: optimiser passed to `vi_train!`; the default is only constructed
  when the caller does not supply one.
- `kwargs...`: forwarded to `vi_train!` after `logging_ps = false`.

# Returns
`(flow, Layers, -ls_log)`, where `ls_log` is the first value returned by
`vi_train!` (presumably the loss trace, so its negation is the ELBO trace).

NOTE(review): per the original in-body note this mutates `flow`, `Layers` and
`ps`; the mutation itself would happen inside `vi_train!` — confirm there.
"""
function nf_train!(
    flow, Layers, ps::Flux.Params, logp, logq;
    elbo_size::Int = 10,
    niters::Int = 1000,
    optimizer = Flux.ADAM(1e-3),
    kwargs...,
)
    # Loss to minimise: the negated ELBO estimate.
    loss = () -> begin
        elbo = nf_ELBO(flow, logp, logq; elbo_size = elbo_size)
        return -elbo
    end

    # `niters`, `optimizer` and `kwargs` are now real keyword arguments instead of
    # undeclared names that only resolved if same-named globals happened to exist.
    ls_log, _ = vi_train!(niters, loss, ps, optimizer; logging_ps = false, kwargs...)
    return flow, Layers, -ls_log
end


include("simpleflow.jl")  # planar flow and radial flow
# include("realnvp.jl")     # real NVP with ReLu activation

