using Bijectors, Flux, Zygote, Random

"""
    lrelu_layer(xdims::Int; hdims::Int = 20, param_type = Flux.f64)

Build a three-layer fully connected network mapping `xdims` inputs to `xdims` outputs.

The first two `Dense` layers have width `hdims` and use `leakyrelu`; the final `Dense`
layer is constructed without an explicit activation.

# Keywords
- `hdims`: width of the two hidden layers.
- `param_type`: callable applied to the finished `Chain` (via `|>`) to set the parameter
  precision, e.g. `Flux.f64` or `Flux.f32`.

# Returns
The `Chain` after `param_type` has been applied to it.
"""
function lrelu_layer(xdims::Int; hdims::Int=20, param_type = Flux.f64)
    return Chain(
        Flux.Dense(xdims, hdims, leakyrelu),
        Flux.Dense(hdims, hdims, leakyrelu),
        Flux.Dense(hdims, xdims),
    ) |> param_type
end

"""
    affine_coupling_layer(shifting_layer, scaling_layer, dims, masking_idx)

Build a `Bijectors.Coupling` whose inner bijector is an affine map: a `Bijectors.Scale`
followed by a `Bijectors.Shift`.

Both the scale and the shift are produced by calling `scaling_layer` and `shifting_layer`
on the conditioning input `θ` that the coupling supplies. The scaling output is passed to
`Scale` as-is; no exponentiation or positivity constraint is applied here.

# Arguments
- `shifting_layer`, `scaling_layer`: callables evaluated on the conditioning input.
- `dims`: total dimension handed to `Bijectors.PartitionMask`.
- `masking_idx`: indices handed to `Bijectors.PartitionMask`
  (presumably the coordinates that get transformed — confirm against the Bijectors docs).
"""
function affine_coupling_layer(shifting_layer, scaling_layer, dims, masking_idx)
    # Shift ∘ Scale: the input is scaled first, then shifted.
    conditioner(θ) = Bijectors.Shift(shifting_layer(θ)) ∘ Bijectors.Scale(scaling_layer(θ))
    mask = Bijectors.PartitionMask(dims, masking_idx)
    return Bijectors.Coupling(conditioner, mask)
end


"""
    create_rnvp(q0; nlayers = 10, hdims = 20, param_type = Flux.f64)

Build a flow on top of the base distribution `q0` by composing `nlayers` affine coupling
layers whose masks alternate between the two halves of the coordinates.

Each coupling layer owns one shifting network and one scaling network, both created with
`lrelu_layer(d ÷ 2; hdims, param_type)` where `d = length(q0)`.

# Keywords
- `nlayers`: number of coupling layers (must be at least 1).
- `hdims`: hidden width of every shifting/scaling network.
- `param_type`: precision converter forwarded to `lrelu_layer`.

# Returns
A tuple `(flow, layers, ps)`:
- `flow`: `Bijectors.transformed(q0, layers)`,
- `layers`: the composed coupling bijector,
- `ps`: `Flux.Params` collecting the parameters of all shifting and scaling networks.

# Throws
- `ArgumentError` if `length(q0)` is odd or `nlayers < 1`.
"""
function create_rnvp(q0; nlayers = 10, hdims=20, param_type = Flux.f64)
    d = length(q0)
    # The coordinates are split into two equal halves, so `d` has to be even.
    iseven(d) || throw(ArgumentError(
        "create_rnvp requires an even-dimensional q0, got length(q0) = $d"
    ))
    nlayers >= 1 || throw(ArgumentError("nlayers must be at least 1, got $nlayers"))
    xdims = d ÷ 2

    scaling_layers = [
        lrelu_layer(xdims; hdims = hdims, param_type = param_type) for _ in 1:nlayers
    ]
    shifting_layers = [
        lrelu_layer(xdims; hdims = hdims, param_type = param_type) for _ in 1:nlayers
    ]

    # The first layer uses the second half of the coordinates as its mask indices.
    ps = Flux.params(shifting_layers[1], scaling_layers[1])
    layers = affine_coupling_layer(shifting_layers[1], scaling_layers[1], d, xdims+1:d)

    # Remaining layers alternate the mask: even `i` -> 1:xdims, odd `i` -> xdims+1:d.
    for i in 2:nlayers
        Flux.params!(ps, (shifting_layers[i], scaling_layers[i]))
        offset = (i % 2) * xdims
        layers = layers ∘ affine_coupling_layer(
            shifting_layers[i], scaling_layers[i], d, (offset + 1):(offset + xdims)
        )
    end

    flow = Bijectors.transformed(q0, layers)
    return flow, layers, ps
end

"""
    elbo_single_sample(x, flow, logp, logq)

Evaluate the single-sample ELBO term for one input `x`.

`x` is pushed through `flow.transform`, yielding the transformed point `y` and the
log-absolute-determinant of the Jacobian. The result is
`logp(y) - logq(x) + logabsdetjac`.

# Arguments
- `x`: one sample, passed to `flow.transform` and to `logq`.
- `flow`: object exposing a `transform` field usable with `with_logabsdet_jacobian`.
- `logp`: callable evaluated on the transformed point.
- `logq`: callable evaluated on the untransformed sample.
"""
function elbo_single_sample(x, flow, logp, logq)
    pushed = with_logabsdet_jacobian(flow.transform, x)
    y, ladj = pushed
    log_target = logp(y)
    log_base = logq(x)
    return (log_target - log_base) + ladj
end

"""
    elbo(xs, flow::Bijectors.MultivariateTransformed, logp, logq)

Average `elbo_single_sample` over the columns of `xs`.

Each column of `xs` is treated as one sample; the per-column values are summed and
divided by `size(xs, 2)`.
"""
function elbo(xs, flow::Bijectors.MultivariateTransformed, logp, logq)
    per_sample = map(eachcol(xs)) do col
        elbo_single_sample(col, flow, logp, logq)
    end
    return sum(per_sample) / size(xs, 2)
end

"""
    elbo(rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, logq, n_samples)

Monte Carlo ELBO estimate: draw `n_samples` samples from `flow.dist` using `rng`, then
average the single-sample ELBO over them via `elbo(xs, flow, logp, logq)`.
"""
function elbo(
    rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, logq, n_samples
)
    xs = rand(rng, flow.dist, n_samples)
    return elbo(xs, flow, logp, logq)
end

# training function
"""
    train_rnvp!(rng, flow, logp, logq, ps; elbo_size = 10, niters = 50000,
                optimizer = Flux.ADAM(1e-3))

Train the parameters `ps` by minimising the negative Monte Carlo ELBO of `flow`.

At every iteration the negative ELBO is estimated from `elbo_size` fresh samples, its
gradient with respect to `ps` is obtained through `pullback`, and `ps` is updated in place
with `Flux.update!`.

# Arguments
- `rng`: random number generator used to draw the ELBO samples.
- `flow`: the transformed distribution being trained.
- `logp`, `logq`: callables forwarded to `elbo`.
- `ps`: `Flux.Params` to optimise.

# Keywords
- `elbo_size`: number of samples per ELBO estimate.
- `niters`: number of optimisation steps.
- `optimizer`: optimiser passed to `Flux.update!`.

# Returns
A tuple `(losses, flow)` where `losses[i]` is the negative ELBO estimate at step `i`.

NOTE(review): `@showprogress` is not provided by the packages listed in this file's
`using` line (it is normally exported by ProgressMeter) — confirm it is loaded by the
including code.
"""
function train_rnvp!(
    rng::AbstractRNG, 
    flow::Bijectors.MultivariateTransformed,
    logp, logq, 
    ps::Flux.Params; 
    elbo_size::Int = 10, 
    niters::Int = 50000,
    optimizer = Flux.ADAM(1e-3)
)
    neg_elbo_trace = zeros(niters)
    @showprogress for step in 1:niters
        # Forward pass and pullback in one go; the objective is the negative ELBO.
        neg_elbo, backprop = pullback(ps) do
            -elbo(rng, flow, logp, logq, elbo_size)
        end
        param_grads = backprop(1.0)
        # In-place optimiser step on every array tracked by `ps`.
        Flux.update!(optimizer, ps, param_grads)
        neg_elbo_trace[step] = neg_elbo
    end
    return neg_elbo_trace, flow
end

