using Bijectors, Flux

"""
    planarflow(q0, nlayers, d; param_type = Flux.f32)

Stack `nlayers` `PlanarLayer`s of dimension `d` into a single bijector and use it to
transform the base distribution `q0`.

# Arguments
- `q0`: base distribution wrapped by `transformed`.
- `nlayers`: number of `PlanarLayer`s composed with `∘`.
- `d`: dimension passed to each `PlanarLayer`.

# Keywords
- `param_type`: function applied to the composed bijector, e.g. `Flux.f32` (default)
  or `Flux.f64`, to set parameter precision.

# Returns
A tuple `(flow, F, ps)`: the transformed distribution, the composed bijector, and the
parameters collected from the flow by `Flux.params`.
"""
function planarflow(q0, nlayers, d; param_type = Flux.f32)
    layers = [PlanarLayer(d) for _ in 1:nlayers]
    bijector = param_type(∘(layers...))
    transformed_dist = transformed(q0, bijector)
    return transformed_dist, bijector, Flux.params(transformed_dist)
end


"""
    radialflow(q0, nlayers, d; param_type = Flux.f32)

Stack `nlayers` `RadialLayer`s of dimension `d` into a single bijector and use it to
transform the base distribution `q0`.

# Arguments
- `q0`: base distribution wrapped by `transformed`.
- `nlayers`: number of `RadialLayer`s composed with `∘`.
- `d`: dimension passed to each `RadialLayer`.

# Keywords
- `param_type`: function applied to the composed bijector, e.g. `Flux.f32` (default)
  or `Flux.f64`, to set parameter precision.

# Returns
A tuple `(flow, F, ps)`: the transformed distribution, the composed bijector, and the
parameters collected from the flow by `Flux.params`.
"""
function radialflow(q0, nlayers, d; param_type = Flux.f32)
    # Radial flows must be built from RadialLayer; this previously (wrongly) used
    # PlanarLayer, making "Radial" indistinguishable from "Planar".
    F = ∘([RadialLayer(d) for _ in 1:nlayers]...) |> param_type
    flow = transformed(q0, F)
    ps = Flux.params(flow)
    return flow, F, ps
end

"""
    trainsnf(q0, logp, logq, d::Int, niters::Int;
             nlayers = 5, flow_type::String = "Planar", param_type = Flux.f64,
             elbo_size::Int = 10, optimizer = Flux.ADAM(1e-3), kwargs...)

Build a planar or radial normalizing flow on top of `q0`, then train it by minimizing
the negative of `nf_ELBO` through `vi_train!`.

# Arguments
- `q0`: base distribution of the flow.
- `logp`, `logq`: forwarded unchanged to `nf_ELBO` (presumably the target log-density
  and the base log-density; confirm against `nf_ELBO`).
- `d`: dimension of each flow layer.
- `niters`: forwarded to `vi_train!` as its first argument (presumably the number of
  optimization iterations).

# Keywords
- `nlayers`: number of layers in the flow.
- `flow_type`: `"Planar"` selects `planarflow`; `"Radial"` selects `radialflow`.
- `param_type`: precision conversion (e.g. `Flux.f64`) passed to the flow builder.
- `elbo_size`: forwarded to `nf_ELBO` (presumably the Monte Carlo sample count per ELBO
  estimate; TODO confirm).
- `optimizer`: Flux optimizer forwarded to `vi_train!`.
- `kwargs...`: extra keyword arguments forwarded to `vi_train!`.

# Returns
`(flow, Layers, elbo_log, ps)`: the transformed distribution, its composed bijector,
the negation of the first value returned by `vi_train!` (presumably the loss history,
so this is an ELBO history), and the trainable parameters.

# Throws
- `ArgumentError` if `flow_type` is neither `"Planar"` nor `"Radial"`.
"""
function trainsnf(q0, logp, logq, d::Int, niters::Int;
                  nlayers = 5, flow_type::String = "Planar", param_type = Flux.f64,
                  elbo_size::Int = 10, optimizer = Flux.ADAM(1e-3), kwargs...)
    # Resolve the builder first so an unknown flow_type fails with a clear error
    # instead of a later UndefVarError on `flow`.
    builder = if flow_type == "Planar"
        planarflow
    elseif flow_type == "Radial"
        radialflow
    else
        throw(ArgumentError(
            "flow_type must be \"Planar\" or \"Radial\", got \"$flow_type\""))
    end
    # Single assignment keeps `flow` from being boxed when captured by `loss`.
    flow, Layers, ps = builder(q0, nlayers, d; param_type = param_type)

    # Training objective: minimize the negative ELBO of the current flow.
    loss = () -> -nf_ELBO(flow, logp, logq; elbo_size = elbo_size)

    ls_log, _ = vi_train!(niters, loss, ps, optimizer; logging_ps = false, kwargs...)

    # Negate the logged losses to report them on the ELBO scale.
    return flow, Layers, -ls_log, ps
end
