
using DrWatson
@quickactivate "BBVIConvergence"

using DelimitedFiles
using Random, Random123
using Base.GC
#using Folds, FoldsThreads, ThreadPinning
#using Base.Threads

include(srcdir("BBVIConvergence.jl"))
include("utils.jl")

"""
    system_setup(; use_mkl=false, is_hyper=false, start)

Prepare the calling process for single-threaded numerical work.

On any process other than the master (`myid() > 1`), pin the process to one
CPU core with `taskset`. Worker `w` is pinned to core
`start + (is_hyper ? 2 : 1) * (w - 2)`, so the first worker (id 2) gets core
`start`. NOTE(review): the stride of 2 under `is_hyper` presumably skips
hyper-thread siblings; confirm against the target machine's core numbering.

On every process, including the master, BLAS is then limited to one thread.
`use_mkl` is accepted but is not read by this function.
"""
function system_setup(; use_mkl=false, is_hyper=false, start)
    worker_id = myid()
    if worker_id > 1
        core_stride = is_hyper ? 2 : 1
        core        = start + core_stride * (worker_id - 2)
        run(`taskset -pc $core $(getpid())`)
    end
    return BLAS.set_num_threads(1)
end

# Draw one sample from `q` with `rng`, map it through `b⁻¹`, and average
# `LogDensityProblems.logdensity` at that point over the `n_batch` minibatches
# reported by `prepare_full_pass!()`. Returns `(average, prob)`, where `prob` is
# the problem object returned by the last `sample_batch!` call, so the caller
# can keep threading it through.
function _full_pass_logdensity(rng, prob, q, b⁻¹, sample_batch!, prepare_full_pass!)
    ζ       = rand(rng, q)
    z       = b⁻¹(ζ)
    n_batch = prepare_full_pass!()
    n_batch ≥ 1 || throw(ArgumentError(
        "prepare_full_pass! must return a positive batch count, got $(n_batch)"))

    prob  = sample_batch!(prob)
    total = LogDensityProblems.logdensity(prob, z)
    for _ in 2:n_batch
        prob   = sample_batch!(prob)
        total += LogDensityProblems.logdensity(prob, z)
    end
    return total / n_batch, prob
end

"""
    estimate_valid_elbo(rng, logdensityprob, q, b⁻¹, M, sample_batch!, prepare_full_pass!)

Monte Carlo estimate of the expected log density under `q`, where the log
density for each sample is averaged over one full pass of minibatches.

For each of the `M` samples:
1. draw `ζ` from `q` using `rng` and map it through `b⁻¹`;
2. call `prepare_full_pass!()` to obtain the number of minibatches `n_batch`;
3. average `LogDensityProblems.logdensity` at the mapped point over `n_batch`
   successive `sample_batch!` calls.

The problem object returned by each `sample_batch!` call is passed to the next
call, including across samples. Returns the mean of the `M` per-sample
averages.

Throws `ArgumentError` if `M < 1`, or if `prepare_full_pass!()` returns a
count below 1.
"""
function estimate_valid_elbo(rng, logdensityprob, q, b⁻¹, M,
                             sample_batch!, prepare_full_pass!)
    M ≥ 1 || throw(ArgumentError("M must be at least 1, got $(M)"))

    # Explicit loops (instead of nested `mapreduce` do-blocks) so that
    # reassigning the problem object does not turn it into a boxed,
    # type-unstable captured variable.
    elbo_sum, prob = _full_pass_logdensity(rng, logdensityprob, q, b⁻¹,
                                           sample_batch!, prepare_full_pass!)
    for _ in 2:M
        elbo_m, prob = _full_pass_logdensity(rng, prob, q, b⁻¹,
                                             sample_batch!, prepare_full_pass!)
        elbo_sum    += elbo_m
    end
    return elbo_sum / M
end

# Return `(optimizer, estimator)` for the optimizer named `name` at `stepsize`.
# Proximal variants use `ProxGenAdam`, the others `Optimisers.Adam`. The `*_stl`
# variants use `StickingTheLanding`, the others `ClosedFormEntropy`. The
# estimator's type flag is `true` exactly for the proximal variants.
# Throws `ArgumentError` for an unrecognised name.
function _make_optimizer(name, stepsize)
    if name == :proximal_adam
        return ProxGenAdam(stepsize), ClosedFormEntropy{true}()
    elseif name == :adam
        return Optimisers.Adam(stepsize), ClosedFormEntropy{false}()
    elseif name == :proximal_adam_stl
        return ProxGenAdam(stepsize), StickingTheLanding{true}()
    elseif name == :adam_stl
        # Qualified like the `:adam` branch, so both use the same Adam rule.
        return Optimisers.Adam(stepsize), StickingTheLanding{false}()
    else
        throw(ArgumentError("unknown optimizer `$(name)`"))
    end
end

# Return `(ϕ, ϕ⁻¹)`: the transform applied to the scale parameters and its
# inverse. `:linear` and `:linear_narrow` use `identity`; every other
# `param_type` uses softplus / inverse softplus.
function _scale_transforms(param_type)
    if param_type == :linear || param_type == :linear_narrow
        return identity, identity
    else
        return StatsFuns.softplus, StatsFuns.invsoftplus
    end
end

"""
    run_config(key, config)

Run one BBVI trial described by `config` and return a `DataFrame` of ELBO traces.

`config` must provide `problem`, `dataset`, `covariance_type`, `param_type`,
`optimizer` and `logstepsize`. `key` sets the counter of a fixed-seed Philox
generator and also seeds the global RNG, so each `key` gives a reproducible
trial.

The result is an outer join on `:t` of two traces:
- `:elbo_minibatch`: every 100th recorded `:elbo` statistic;
- `:elbo`: the value returned by `compute_full_elbo!`, which the callback
  evaluates every 1000 iterations (presumably the full-data ELBO; confirm in
  `model_with_dataset`).

Throws `ArgumentError` if `config.optimizer` is not one of `:proximal_adam`,
`:adam`, `:proximal_adam_stl`, `:adam_stl`.
"""
function run_config(key, config)
    @unpack problem, dataset, covariance_type, param_type, optimizer, logstepsize = config

    M            = 10
    # Every problem runs 50k iterations; only the minibatch size differs.
    T, batchsize = problem == :bradleyterry ? (50_000, 500) : (50_000, 100)

    seed = (0x97dcb950eaebcfba, 0x741d36b68bef6415)
    rng  = Random123.Philox4x(UInt64, seed, 8)
    Random123.set_counter!(rng, key)
    Random.seed!(key)

    ad_type = covariance_type == :cholesky ? ZygoteAD : ReverseDiffAD

    prob, b⁻¹, sample_train_batch!, compute_full_elbo!, _ = model_with_dataset(
        Val(problem), dataset, batchsize; rng=rng)

    ϕ, ϕ⁻¹ = _scale_transforms(param_type)

    stepsize            = 10^logstepsize
    opt_rule, estimator = _make_optimizer(optimizer, stepsize)

    d = LogDensityProblems.dimension(prob)

    # Returns `(full_elbo = …,)` on every 1000th iteration and an empty
    # `NamedTuple` otherwise. These entries are read back below with
    # `filter_stats(:full_elbo, stats)`.
    function callback!(t, stats, λ, q, sub_elbo, g)
        if mod(t, 1000) == 0
            return (full_elbo = compute_full_elbo!(q, b⁻¹, 1000),)
        end
        return NamedTuple()
    end

    # Zero mean. The diagonal scale starts at 1, or at 1e-3 for the "narrow"
    # variants, expressed in the unconstrained parameterization via ϕ⁻¹.
    m₀            = zeros(d)
    initial_scale = occursin("narrow", string(param_type)) ? 1e-3 : 1.0
    C₀            = Diagonal(fill(ϕ⁻¹(initial_scale), d))

    q, _, stats = bbvi(prob, M, T, m₀, C₀;
                       rng            = rng,
                       ψ⁻¹            = b⁻¹,
                       ϕ              = ϕ,
                       optimizer      = opt_rule,
                       show_progress  = myid() == 2,
                       callback!      = callback!,
                       param_type     = covariance_type,
                       estimator_type = estimator,
                       ad_type        = ad_type,
                       sample_batch   = sample_train_batch!)

    t_sub,  elbo_sub  = filter_stats(:elbo,      stats)
    t_full, elbo_full = filter_stats(:full_elbo, stats)

    # Thin the minibatch trace to every 100th recorded entry.
    t_sub    = t_sub[100:100:end]
    elbo_sub = elbo_sub[100:100:end]

    df_sub  = DataFrame(:t => t_sub,  :elbo_minibatch => elbo_sub)
    df_full = DataFrame(:t => t_full, :elbo           => elbo_full)
    return outerjoin(df_sub, df_full, on = :t)
end

"""
    main(n_trials)

Run the experiment grid selected by the `TASK` environment variable.

`TASK` selects one problem/dataset/covariance combination from the table
below. The grid is the product of:
- the log step sizes `-4.5:0.25:-0.5`;
- the enabled `(param_type, optimizer)` pairs;
- the selected problem.

For each configuration, `n_trials` trials (keys `1:n_trials`) are run with
`pmap`, concatenated, tagged with the configuration's fields as extra columns,
and stored via `DrWatson.produce_or_load` under `datadir("experiment")`.

Throws `ArgumentError` if `n_trials < 1`, or if `TASK` is unset or not a known
task name.
"""
function main(n_trials)
    n_trials ≥ 1 || throw(ArgumentError("n_trials must be at least 1, got $(n_trials)"))

    optimization  = [
                     #(param_type = :linear,    optimizer = :proximal_adam),
                     #(param_type = :linear,    optimizer = :adam),
                     (param_type = :nonlinear, optimizer = :adam),
                     #(param_type = :linear,    optimizer = :adam_stl),
                     #(param_type = :linear,    optimizer = :proximal_adam_stl),

                     (param_type = :linear_narrow,    optimizer = :proximal_adam),
                     (param_type = :linear_narrow,    optimizer = :adam),
                     (param_type = :nonlinear_narrow, optimizer = :adam),
                     (param_type = :linear_narrow,    optimizer = :adam_stl),
                     (param_type = :linear_narrow,    optimizer = :proximal_adam_stl),
                     ]

    logstepsizes = [(logstepsize = logstepsize,) for logstepsize ∈ range(-4.5, -0.5, step=0.25)]

    # TASK name => problem configuration. "mf_" = mean-field, "fr_" = full-rank
    # (Cholesky) covariance.
    tasks = Dict(
        "mf_keggu"    => (problem = :linearreg,    dataset = :keggundirected, covariance_type = :meanfield),
        "mf_buzz"     => (problem = :linearreg,    dataset = :buzz,           covariance_type = :meanfield),
        "mf_song"     => (problem = :linearreg,    dataset = :song,           covariance_type = :meanfield),
        "mf_electric" => (problem = :linearreg,    dataset = :houseelectric,  covariance_type = :meanfield),
        "mf_radon"    => (problem = :radon,        dataset = :radon,          covariance_type = :meanfield),
        "mf_tennis"   => (problem = :bradleyterry, dataset = :tennis,         covariance_type = :meanfield),
        "mf_election" => (problem = :election,     dataset = :election,       covariance_type = :meanfield),
        "mf_eeg"      => (problem = :autoregress,  dataset = :eeg,            covariance_type = :meanfield),

        "fr_keggu"    => (problem = :linearreg,    dataset = :keggundirected, covariance_type = :cholesky),
        "fr_buzz"     => (problem = :linearreg,    dataset = :buzz,           covariance_type = :cholesky),
        #"fr_song"     => (problem = :linearreg,    dataset = :song,           covariance_type = :cholesky),
        #"fr_electric" => (problem = :linearreg,    dataset = :houseelectric,  covariance_type = :cholesky),
        #"fr_radon"    => (problem = :radon,        dataset = :radon,          covariance_type = :cholesky),
        "fr_election" => (problem = :election,     dataset = :election,       covariance_type = :cholesky),
        "fr_eeg"      => (problem = :autoregress,  dataset = :eeg,            covariance_type = :cholesky),
    )

    task = get(ENV, "TASK", nothing)
    isnothing(task) && throw(ArgumentError("environment variable TASK is not set"))
    haskey(tasks, task) || throw(ArgumentError(
        "unknown TASK `$(task)`; expected one of $(sort!(collect(keys(tasks))))"))
    problems = [tasks[task]]

    # `Iterators.product` varies its first argument fastest, and `vec` flattens
    # in the same column-major order.
    configs = vec([merge(x...) for x ∈ Iterators.product(logstepsizes, optimization, problems)])

    @showprogress for config ∈ configs
        DrWatson.produce_or_load(datadir("experiment"), config) do _
            dfs = @showprogress pmap(1:n_trials) do key
                run_config(key, config)
            end
            # `reduce(vcat, …)` avoids splatting `n_trials` arguments.
            df = reduce(vcat, dfs)
            # Tag every row with the configuration that produced it.
            for (k, v) ∈ pairs(config)
                df[:,k] .= v
            end
            GC.gc()
            Dict(:data => df, :config => config)
        end
    end
    return nothing
end
