
using DrWatson
@quickactivate "BBVIConvergence"

using Plots, StatsPlots
using Random123
using LinearAlgebra
using MAT
using CSV

import MCMCChains
import DataFrames

include("utils.jl")
include("quadratic.jl")
include("linear_regression.jl")
include(srcdir("BBVIConvergence.jl"))

"""
    nonconvex_linear_regression(X, y, α, β)

Construct the unnormalized log-density of a Bayesian linear regression whose
noise scale is parameterized on the log scale.

The parameter vector is laid out as `z = [log σ; w]`, where `w` has length
`size(X, 2)`. The returned log-density `logπ(z)` is the sum of

  * the Gaussian likelihood `y ~ N(X * w, σ² I)`,
  * an `InverseGamma(α, β)` log-density evaluated at `σ`, and
  * the Gaussian prior `w ~ N(0, σ² I)`.

NOTE(review): no log-Jacobian term (`z[1]`) for the map `σ = exp(z[1])` is
included, and the InverseGamma density is evaluated at `σ` rather than `σ²`.
Confirm both are intentional for this test problem.

Returns `(d, logπ, nothing)`, where `d = size(X, 2)` is the number of regression
weights (the full length of `z` is `d + 1`). The third slot is a placeholder
where an inverse transform (`ψ⁻¹`) would go; it is currently `nothing`.
"""
function nonconvex_linear_regression(X, y, α, β)
    n_weights = size(X, 2)

    function logπ(z)
        σ        = exp(z[1])
        weights  = z[2:end]
        variance = σ * σ
        Σ        = variance * I

        log_lik     = logpdf(MvNormal(X * weights, Σ), y)
        log_prior_σ = logpdf(InverseGamma(α, β), σ)
        log_prior_w = logpdf(MvNormal(zeros(n_weights), Σ), weights)
        return log_lik + log_prior_σ + log_prior_w
    end

    return n_weights, logπ, nothing
end

"""
    main()

Run BBVI on the `nonconvex_linear_regression` model built from `wine.mat` and
wrap the trajectory of variational parameters for chain diagnostics.

During optimization, `callback!` records the variational parameters `λ` at
every iteration and re-plots the ELBO trace observed so far. Afterwards the
recorded trajectory is cut into two windows of equal length `n_window`:

  * `chain_early`: the first `n_window` iterations, and
  * `chain_later`: the last `n_window` iterations.

Each window is wrapped as a single-chain `MCMCChains.Chains`
(iterations × parameters × chains).

Returns `(chain_early, chain_later)`.

Throws an error if `callback!` was not invoked for every iteration `1:T`,
because the λ history would then be incomplete.
"""
function main()
    key  = 1
    seed = (0x97dcb950eaebcfba, 0x741d36b68bef6415)
    prng = Random123.Philox4x(UInt64, seed, 8)
    Random123.set_counter!(prng, key)
    Random.seed!(key)

    γ = 1e-5
    M = 10
    T = 2000
    α = 1.0
    β = 1.0

    # Length of each of the early/late windows; both windows must be equal in
    # size so that their diagnostics are comparable.
    n_window = min(1000, T)

    param_type = :cholesky

    X, y         = load_dataset("wine.mat")
    d, logπ, ψ⁻¹ = nonconvex_linear_regression(X, y, α, β)
    p            = d + 1

    φ = Normal()
    ϕ(x) = max(x, 1e-5)
    #ϕ = softplus

    # One slot per iteration. The length of λ depends on `p` and `param_type`,
    # so it is taken from the iterates themselves instead of being hard-coded.
    λ_hist = Vector{Vector{Float64}}(undef, T)

    function callback!(t, stats, λ, q, ℓπ, elbo, g, g_estimator!)
        t_elbo, hist_elbo = filter_stats(:elbo, stats[1:t])
        # Copy, since the optimizer may update λ in place between iterations.
        λ_hist[t] = collect(Float64, λ)

        display(Plots.plot(t_elbo, hist_elbo))

        NamedTuple()
    end

    q, stats  = bbvi(p, M, logπ, γ, T; ϕ=ϕ, φ=φ, prng=prng,
                     show_progress = true,
                     param_type    = param_type,
                     callback!     = callback!,
                     C₀            = Matrix{Float64}(I,p,p))

    all(t -> isassigned(λ_hist, t), 1:T) ||
        error("callback! was not called for every iteration; λ history is incomplete")

    # (n_params × T) → (T × n_params × 1), i.e. iterations × parameters × chains.
    λ_mat         = reduce(hcat, λ_hist)
    λ_hist_chains = reshape(permutedims(λ_mat), T, size(λ_mat, 1), 1)

    chain_early = MCMCChains.Chains(λ_hist_chains[1:n_window, :, :])
    chain_later = MCMCChains.Chains(λ_hist_chains[end-n_window+1:end, :, :])
    return chain_early, chain_later
end
