# Main module for BayesianMixtures package
module BayesianMixtures

include("MFM.jl")
include("RandomNumbers.jl")

include("MVNaaC.jl")
include("utils.jl")

using Random
using Statistics
using SpecialFunctions
using MultivariateStats
using TSne
# log|Γ(x)|: the first element of the tuple returned by `logabsgamma`
# (the second element, the sign of Γ(x), is discarded).
lgamma_(x) = first(logabsgamma(x))

SEED = 2025
Random.seed!(SEED)

# ===================================================================
# ===================================================================
# ================== Functions to generate results ==================
# ===================================================================
# ===================================================================

# Create an options object to specify model, data, and MCMC parameters.
#
# Besides forwarding its arguments, this precomputes the partition-prior
# coefficients `log_v` for t = 1:t_max+1 clusters over n = n1 points
# (n1 presumably length(x1)):
#   - "MFM": log_v = MFM.coefficients(log_pk, gamma, n1, t_max+1), (a, b) = (1, 1)
#   - "DPM": log_v[t] = t*log(alpha) - lgamma(alpha+n1) + lgamma(alpha), (a, b) = (1, 0)
#
# Returns `<mode>.Options(...)`, where `<mode>` is the submodule of
# BayesianMixtures named by `mode`. `n_keep` is clamped to at most `n_total`.
#
# Errors if `model_type` is neither "MFM" nor "DPM", if "MVNaaRJ" is combined
# with a DPM, or if `mode` does not name anything defined in BayesianMixtures.
#
# SECURITY: `log_pk` is parsed and evaluated as Julia code (`eval`); never pass
# a string that comes from an untrusted source.
function options(
        mode, # "Normal", "MVN", "MVNaaC", "MVNaaN", or "MVNaaRJ" 
        model_type, # "MFM" or "DPM"
        x1, x2, # data
        n1, n2, 
        n_total, # total number of MCMC sweeps to run the sampler
        ot_ratio, # ratio of subsamples used in transport map proposal
        swap_iter, # total number of iterations used in swap proposal
        swap_set_size, 
        temperature, # temperature in transport map
        block_size, # block size for blocked OT
        subset_ratio, # ratio of subsamples that will have matchings
        sub_swap_iter; # total number of iterations used in subset swap proposal
        n_keep=n_total, # number of MCMC sweeps to keep after thinning
        n_burn=round(Int,n_total/10), # number of MCMC sweeps (out of n_total) to discard as burn-in
        verbose=true, # display information or not
        use_hyperprior=true, # update lambda (base distn parameters)
        t_max=40, # a guess at an upper bound on # of clusters that will be encountered during MCMC

        # MFM options:
        gamma=1.0, # Dirichlet_k(gamma,...,gamma)
        log_pk="k -> log(0.1)+(k-1)*log(0.9)", # string representation of log(p(k))
            # (where p(k) is the log of the prior on # of components, K)
            # log_pk="k -> k == 10 ? 0 : -Inf" puts a hard prior on K=10
            # log_pk = "k -> logbeta(a+1, b+k-1) - logbeta(a, b)" puts a marginalized Geometric prior on K with a hierarchical Beta prior
            
        # DPM options:
        alpha_random=true, # put prior on alpha (DPM concentration parameter) or not
        alpha=1.0, # value of alpha (initial value if alpha_random=true)

        # Jain-Neal split-merge options:
        use_splitmerge=true, # use split-merge or not
        n_split=5, # number of intermediate sweeps for split launch state
        n_merge=5,  #                 "         "       merge    "     "  
        
        # RJMCMC options:
        k_max=t_max, # a guess at an upper bound on # of components that will be encountered during MCMC

        # fill remain options:
        fill_with_centers=false # upsampling with k-medoids centers
    )

    # Validate the model choice up front, before any (possibly expensive) work.
    # Explicit checks rather than @assert, which may be disabled.
    if model_type != "MFM" && model_type != "DPM"
        error("Invalid model_type: $model_type.")
    end
    if mode == "MVNaaRJ" && model_type != "MFM"
        error("RJMCMC is not implemented for DPMs.")
    end
    mode_sym = Symbol(mode)
    if !isdefined(BayesianMixtures, mode_sym)
        error("Invalid mode: $mode.")
    end

    # Both model types place their partition prior on the same n = n1 points.
    n = n1
    if model_type == "MFM"
        # log_pk is kept as a string (it is forwarded to Options unchanged), so it
        # is turned into a callable here.
        lpk = eval(Meta.parse(log_pk))
        # lpk was created by eval at run time (a newer world age than this
        # method), so it has to be called through invokelatest.
        log_pk_fn(k) = Base.invokelatest(lpk, k)
        log_v = MFM.coefficients(log_pk_fn, gamma, n, t_max + 1)
        a = b = 1.0
    else  # model_type == "DPM" (validated above)
        # log_v[t] = t*log(alpha) - log Gamma(alpha + n) + log Gamma(alpha)
        log_v = float(1:t_max+1) * log(alpha) .- lgamma_(alpha + n) .+ lgamma_(alpha)
        a, b = 1.0, 0.0
    end

    n_keep = min(n_keep, n_total)
    module_ = getfield(BayesianMixtures, mode_sym)
    return module_.Options(mode, model_type, x1, x2, n_total, n_keep, n_burn, verbose,
                           use_hyperprior, t_max, gamma, log_pk, alpha_random, alpha,
                           use_splitmerge, n_split, n_merge, k_max, a, b, log_v,
                           ot_ratio, swap_iter, swap_set_size, temperature, block_size,
                           n1, n2, subset_ratio, sub_swap_iter, fill_with_centers)
end


# Run the MCMC sampler with the specified options and wrap its output in the
# mode's `Result` type.
#
# Before the timed run, the sampler is called once (1 sweep, 1 kept sample) on
# a copy of the options whose ot_ratio is 0.0. This is a warm-up so that
# compilation happens outside the measured section. `time_per_step` is the
# measured wall time divided by n_total*(n1+n2).
function run_sampler(options)
    opts = options
    n1, n2 = opts.n1, opts.n2
    n_total, n_keep = opts.n_total, opts.n_keep
    impl = getfield(BayesianMixtures, Symbol(opts.mode))

    # Warm-up run on identical options except for ot_ratio = 0.0.
    opts.verbose && println("Precompiling...")
    warmup_opts = impl.Options(
        opts.mode, opts.model_type, opts.x1, opts.x2,
        opts.n_total, opts.n_keep, opts.n_burn, opts.verbose,
        opts.use_hyperprior, opts.t_max, opts.gamma, opts.log_pk,
        opts.alpha_random, opts.alpha, opts.use_splitmerge,
        opts.n_split, opts.n_merge, opts.k_max, opts.a, opts.b, opts.log_v,
        0.0,  # ot_ratio
        opts.swap_iter, opts.swap_set_size, opts.temperature, opts.block_size,
        opts.n1, opts.n2, opts.subset_ratio, opts.sub_swap_iter,
        opts.fill_with_centers,
    )
    impl.sampler(warmup_opts, 1, 1)

    if opts.verbose
        println(opts.mode, " ", opts.model_type)
        println("n = $n1+$n2, n_total = $n_total, n_keep = $n_keep")
        print("Running... ")
    end

    # Timed main run. @elapsed does not open a new scope, so the sampler's
    # outputs bound inside the block remain visible below.
    elapsed_time = @elapsed begin
        t_r, N_r, z_r, T, ids1_up, theta_r, keepers = impl.sampler(opts, n_total, n_keep)
    end
    time_per_step = elapsed_time / (n_total * (n1 + n2))

    if opts.verbose
        println("complete.")
        println("Elapsed time = $elapsed_time seconds")
        println("Time per step ~ $time_per_step seconds")
    end

    return impl.Result(opts, t_r, N_r, z_r, T, ids1_up, theta_r, keepers,
                       elapsed_time, time_per_step)
end


# ===================================================================
# ===================================================================
# ================== Functions to analyze results ===================
# ===================================================================
# ===================================================================

# Generate the cluster labels of x2 from the labels z1.
#
# `T` must be a permutation of 1:n2. Entry i pairs x2 point `T[i]` with the
# label `z1[ids[i]]`, so the result satisfies z2[T[i]] == z1[ids[i]] for every
# i in 1:n2. Only ids[1:n2] are read.
#
# Returns `similar(z1, n2)` filled with those labels (empty when n2 == 0).
#
# Throws ArgumentError if length(T) != n2, if any entry of T lies outside 1:n2,
# or if T contains a duplicate entry.
function gen_z2(z1, T, n2, ids)
    # Validate inputs with explicit throws (@assert may be disabled). `all`,
    # unlike `maximum`/`minimum`, is also well defined when T is empty.
    length(T) == n2 ||
        throw(ArgumentError("gen_z2: length(T) must equal n2"))
    all(v -> 1 <= v <= n2, T) ||
        throw(ArgumentError("gen_z2: T must contain indices within 1:n2"))

    # Invert the permutation: invT[j] = i  <=>  T[i] == j.
    invT = zeros(Int, n2)
    for (i, v) in enumerate(T)
        invT[v] == 0 ||
            throw(ArgumentError("gen_z2: Duplicate mapping detected in T:$(v)."))
        invT[v] = i
    end
    # No separate "unmapped index" check is needed: T has n2 entries, all in
    # 1:n2 and pairwise distinct, so every slot of invT is now filled.

    z2 = similar(z1, n2)
    for j in 1:n2
        z2[j] = z1[ids[invT[j]]]
    end
    return z2
end


# Negative log marginal likelihood of the stored sample at position `idx`.
#
# Builds the hyperparameters with `<mode>.construct_hyperparameters` from
# `result.options`, sums `<mode>.log_marginal(c, H)` over every entry `c` of
# `result.theta[idx]` (left to right, starting from 0.0), and returns the
# negated total.
function calculate_nll(idx, result)
    opts = result.options
    impl = getfield(BayesianMixtures, Symbol(opts.mode))
    hyper = impl.construct_hyperparameters(opts)
    logm = impl.log_marginal
    draws = result.theta[idx]

    # Left fold keeps the exact accumulation order of a sequential loop.
    total = mapfoldl(c -> logm(c, hyper), +, draws; init=0.0)
    return -total
end


end # module BayesianMixtures
