
using SpecialFunctions
"""
    lgamma_(x)

Return `log|Γ(x)|`, i.e. the first component of the `(value, sign)` pair
produced by `SpecialFunctions.logabsgamma`.
"""
lgamma_(x) = first(logabsgamma(x))

"""
    logsumexp(a, b)

Return `log(exp(a) + exp(b))` without overflow or underflow.

Special cases:
- If both arguments are `-Inf`, the result is `-Inf`.
- If either argument is `+Inf`, the result is `+Inf`.

`log1p` keeps full precision when the smaller term is negligible.
"""
function logsumexp(a, b)
    m = max(a, b)
    # m == -Inf: both terms vanish; m == +Inf: the sum is infinite.
    # Subtracting m in either case would produce Inf - Inf = NaN.
    isinf(m) && return m
    return m + log1p(exp(-abs(a - b)))
end

"""
    randp(p, k)

Draw an index from `1:k` with probability proportional to the nonnegative,
unnormalized weights `p[1:k]`, using a linear inverse-CDF search.
Entries of `p` beyond `k` are ignored.
"""
function randp(p, k)
    total = 0.0
    for idx in 1:k
        total += p[idx]
    end
    threshold = total * rand()

    pick = 1
    cumulative = p[pick]
    # Advance until the running cumulative weight reaches the threshold.
    while threshold > cumulative
        pick += 1
        cumulative += p[pick]
    end
    @assert(pick <= k)
    return pick
end

"""
    randlogp!(log_p, k)

Draw an index from `1:k` with probability proportional to `exp.(log_p[1:k])`.

The buffer `log_p` is reused: on return, `log_p[1:k]` holds the normalized
probabilities rather than the log-weights.
"""
function randlogp!(log_p, k)
    log_total = -Inf
    for idx in 1:k
        log_total = logsumexp(log_total, log_p[idx])
    end
    for idx in 1:k
        # Normalize in place; the buffer now stores probabilities.
        log_p[idx] = exp(log_p[idx] - log_total)
    end
    return randp(log_p, k)
end

"""
    ordered_insert!(index, list, t)

Insert `index` into the sorted prefix `list[1:t]`, shifting larger entries one
slot to the right. After the call, `list[1:t+1]` is sorted.

`list` must have room for at least `t + 1` entries. Returns `index`.
"""
function ordered_insert!(index, list, t)
    slot = t + 1
    while slot > 1
        prev = list[slot - 1]
        prev > index || break
        list[slot] = prev        # open a gap for the new entry
        slot -= 1
    end
    list[slot] = index
    return index
end

"""
    ordered_remove!(index, list, t)

Remove `index` from the sorted prefix `list[1:t]`. Every entry `>= index` is
shifted one slot to the left, pulling in `list[t+1]`, so `list` must have at
least `t + 1` entries. Returns `nothing`.
"""
function ordered_remove!(index, list, t)
    for pos in 1:t
        list[pos] >= index && (list[pos] = list[pos + 1])
    end
    return nothing
end

"""
    ordered_next(list)

Return the smallest positive integer `j` with `list[j] != j`. For a sorted list
of active IDs padded with zeros, this is the first unused ID.

Throws a `BoundsError` if no such `j` exists within `list`.
"""
function ordered_next(list)
    candidate = 0
    while true
        candidate += 1
        list[candidate] == candidate || return candidate
    end
end
                

"""
    restricted_Gibbs!(zsa,zsb,tia,tib,tja,tjb,cia,cib,cja,cjb,ni,nj,i,j,S,ns,x1,x2,ids1_up,T,b,H,active)

Run one restricted Gibbs sweep over the points `S[1:ns]`, skipping the anchor
points `i` and `j`. Each point is reassigned between two clusters with
sufficient statistics `tia` and `tja`, and sizes `ni` and `nj`.

Removal and re-insertion of each point `k`:
- It is removed from the cluster given by `zsa[k]`: `cia` means `tia`/`ni`;
  any other value means `tja`/`nj`.
- It is then added to the cluster given by `zsb[k]`: `cib` means `tia`/`ni`;
  any other value means `tja`/`nj`.
- The `x2` points attached to `x1[k]` are `x2[T[idx]]` for every `idx` with
  `ids1_up[idx] == k`.
- A point contributes one plus that count to `ni`/`nj`.

Modes:
- `active == true`: `zsb[k]` is sampled (`cib` or `cjb`) from its conditional,
  and `tia`/`tja` are updated in place.
- `active == false`: `zsb` is not modified. The sweep only evaluates the
  probability of moving to the assignment already stored in `zsb`. `tia` and
  `tja` are deep-copied first, so the caller's statistics are left untouched.

Returns `(log_p, ni, nj)`: the log-probability of the sweep's assignments, and
the resulting sizes.

`tib`, `tjb` and `cja` are not referenced in the body.

NOTE: The sufficient statistics of `tia` and `tja` must be in sync with `zsa`.
The sufficient statistics of `tib` and `tjb` are not updated in this procedure.
"""
function restricted_Gibbs!(zsa,zsb,tia,tib,tja,tjb,cia,cib,cja,cjb,ni,nj,i,j,S,ns,x1,x2,ids1_up,T,b,H,active)
    if !active; tia,tja = deepcopy(tia),deepcopy(tja); end
    # Group upsampled positions by the x1 point they duplicate.
    n1 = length(x1)
    idx_groups = [Int[] for _ in 1:n1]
    @inbounds for pos in eachindex(ids1_up)
        push!(idx_groups[ids1_up[pos]], pos)
    end

    log_p = 0.
    for ks = 1:ns
        k = S[ks]
        if k!=i && k!=j
            idx_list = idx_groups[k]
            weight = length(idx_list) + 1
            # Take point k (and its attached x2 points) out of its current cluster.
            if zsa[k]==cia
                ni -= weight; Theta_remove!(tia,x1[k])
                for idx in idx_list
                    Theta_remove2!(tia,x2[T[idx]])
                end
            else
                nj -= weight; Theta_remove!(tja,x1[k])
                for idx in idx_list
                    Theta_remove2!(tja,x2[T[idx]])
                end
            end
            # Log predictive of point k under each cluster.
            x2_list = [x2[T[idx]] for idx in idx_list]
            Li = log_marginal_multiple(x1[k],x2_list,tia,H) - log_marginal(tia,H)
            Lj = log_marginal_multiple(x1[k],x2_list,tja,H) - log_marginal(tja,H)
            # Normalize in log space. Accumulating log_Pi / log_Pj directly avoids
            # the precision loss of log(Pi) / log(1-Pi) when Pi is near 0 or 1.
            log_wi = log(ni+b) + Li
            log_wj = log(nj+b) + Lj
            log_norm = logsumexp(log_wi, log_wj)
            log_Pi = log_wi - log_norm
            log_Pj = log_wj - log_norm

            if active
                zsb[k] = rand() < exp(log_Pi) ? cib : cjb
            end
            # Put point k into the cluster recorded in zsb.
            if zsb[k]==cib
                ni += weight; Theta_adjoin!(tia,x1[k])
                for idx in idx_list
                    Theta_adjoin2!(tia,x2[T[idx]])
                end
                log_p += log_Pi
            else
                nj += weight; Theta_adjoin!(tja,x1[k])
                for idx in idx_list
                    Theta_adjoin2!(tja,x2[T[idx]])
                end
                log_p += log_Pj
            end
        end
    end
    return log_p,ni,nj
end


"""
    split_merge!(x1,x2,ids1_up,T,z,zs,S,theta,list,N,t,H,a,b,log_v,n_split,n_merge)

Perform one Metropolis–Hastings split–merge move on the cluster assignments
`z`, using a restricted-Gibbs launch state.

How the move is chosen:
- Two distinct anchor points `i` and `j` are drawn uniformly.
- If they share a cluster, a split of that cluster is proposed.
- Otherwise, a merge of their two clusters is proposed.

Data layout:
- Each `x1[k]` carries the `x2` points `x2[T[idx]]` for every `idx` with
  `ids1_up[idx] == k`.
- Cluster sizes count the `x1` point plus these attached `x2` points.

Effects:
- Mutates `z`, `zs`, `S`, `theta`, `list` and `N`.
- Returns the updated number of active clusters `t`.
- `zs` and `S` are scratch buffers indexed by `1:length(x1)`.
- `n_merge` is accepted for interface compatibility but is not used here.
- If `length(x1) < 2`, no move is possible and `t` is returned unchanged.
"""
function split_merge!(x1,x2,ids1_up,T,z,zs,S,theta,list,N,t,H,a,b,log_v,n_split,n_merge)
    n1 = length(x1)
    # Two distinct anchor points are required; otherwise the move is a no-op.
    n1 < 2 && return t

    # Group upsampled positions by the x1 point they duplicate.
    idx_groups = [Int[] for _ in 1:n1]
    @inbounds for pos in eachindex(ids1_up)
        push!(idx_groups[ids1_up[pos]], pos)
    end

    # Randomly choose an ordered pair of distinct indices.
    i = rand(1:n1)
    j = rand(1:n1-1); if j>=i; j += 1; end
    ci0,cj0 = z[i],z[j]
    ti0,tj0 = theta[ci0],theta[cj0]

    # Set S[1],...,S[ns] to the indices of the points in clusters ci0 and cj0.
    ns = 0; n_total = 0
    for k = 1:n1
        if z[k]==ci0 || z[k]==cj0
            ns += 1; S[ns] = k
            n_total += (length(idx_groups[k]) + 1)
        end
    end

    # Find unused cluster IDs for the merged cluster (cm) and the two split clusters (ci, cj).
    k = 1
    while list[k]==k; k += 1; end; cm = k
    while list[k]==k+1; k += 1; end; ci = k+1
    while list[k]==k+2; k += 1; end; cj = k+2
    tm,ti,tj = theta[cm],theta[ci],theta[cj]

    # Merge state: accumulate the sufficient statistics of all points in S.
    for ks = 1:ns
        k = S[ks]
        Theta_adjoin!(tm,x1[k])
        for idx in idx_groups[k]
            Theta_adjoin2!(tm,x2[T[idx]])
        end
    end

    # Seed the two split clusters with the anchors i and j.
    zs[i] = ci; Theta_adjoin!(ti,x1[i])
    idx_list = idx_groups[i]
    ni = length(idx_list) + 1
    for idx in idx_list
        Theta_adjoin2!(ti,x2[T[idx]])
    end

    zs[j] = cj; Theta_adjoin!(tj,x1[j])
    idx_list = idx_groups[j]
    nj = length(idx_list) + 1
    for idx in idx_list
        Theta_adjoin2!(tj,x2[T[idx]])
    end

    # Start with a uniformly chosen split of the remaining points.
    for ks = 1:ns
        k = S[ks]
        if k!=i && k!=j
            idx_list = idx_groups[k]
            if rand()<0.5
                zs[k] = ci; Theta_adjoin!(ti,x1[k])
                for idx in idx_list
                    Theta_adjoin2!(ti,x2[T[idx]])
                end
                ni += (length(idx_list) + 1)
            else
                zs[k] = cj; Theta_adjoin!(tj,x1[k])
                for idx in idx_list
                    Theta_adjoin2!(tj,x2[T[idx]])
                end
                nj += (length(idx_list) + 1)
            end
        end
    end
    # Refine the launch state with several restricted Gibbs sweeps.
    for rep = 1:n_split
        log_p,ni,nj = restricted_Gibbs!(zs,zs,ti,ti,tj,tj,ci,ci,cj,cj,ni,nj,i,j,S,ns,x1,x2,ids1_up,T,b,H,true)
    end

    if ci0==cj0  # propose a split
        # Make one final sweep and compute its probability density.
        log_prop_ab,ni,nj = restricted_Gibbs!(zs,zs,ti,ti,tj,tj,ci,ci,cj,cj,ni,nj,i,j,S,ns,x1,x2,ids1_up,T,b,H,true)

        # Probability of going from the split state back to the original (merged) state.
        log_prop_ba = 0.0  # log(1)

        # Compute the acceptance probability.
        log_prior_b = log_v[t+1] + lgamma_(ni+b)+lgamma_(nj+b)-2*lgamma_(a)
        log_prior_a = log_v[t] + lgamma_(n_total+b)-lgamma_(a)
        log_lik_ratio = log_marginal(ti,H) + log_marginal(tj,H) - log_marginal(ti0,H)
        p_accept = min(1.0, exp(log_prop_ba-log_prop_ab + log_prior_b-log_prior_a + log_lik_ratio))

        if rand()<p_accept  # accept split: update z, list, N, theta, t
            for ks = 1:ns; z[S[ks]] = zs[S[ks]]; end
            ordered_remove!(ci0,list,t)
            ordered_insert!(ci,list,t-1)
            ordered_insert!(cj,list,t)
            N[ci0],N[ci],N[cj] = 0,ni,nj
            t += 1
            Theta_clear!(ti0)
        else  # reject split
            Theta_clear!(ti)
            Theta_clear!(tj)
        end
        Theta_clear!(tm)

    else  # propose a merge
        # Probability of going to the merge state.
        log_prop_ab = 0.0  # log(1)

        # Probability density of going from the split launch state to the original state.
        # This also returns ni, nj as the sizes of the original clusters ci0, cj0.
        log_prop_ba,ni,nj = restricted_Gibbs!(zs,z,ti,ti0,tj,tj0,ci,ci0,cj,cj0,ni,nj,i,j,S,ns,x1,x2,ids1_up,T,b,H,false)

        # Compute the acceptance probability.
        log_prior_b = log_v[t-1] + lgamma_(n_total+b)-lgamma_(a)
        log_prior_a = log_v[t] + lgamma_(ni+b)+lgamma_(nj+b)-2*lgamma_(a)
        log_lik_ratio = log_marginal(tm,H) - log_marginal(ti0,H) - log_marginal(tj0,H)
        p_accept = min(1.0, exp(log_prop_ba-log_prop_ab + log_prior_b-log_prior_a + log_lik_ratio))

        if rand()<p_accept  # accept merge
            for ks = 1:ns; z[S[ks]] = cm; end
            ordered_remove!(ci0,list,t)
            ordered_remove!(cj0,list,t-1)
            ordered_insert!(cm,list,t-2)
            N[cm],N[ci0],N[cj0] = n_total,0,0
            t -= 1
            Theta_clear!(ti0)
            Theta_clear!(tj0)
        else  # reject merge
            Theta_clear!(tm)
        end
        Theta_clear!(ti)
        Theta_clear!(tj)
    end
    return t
end


"""
    sampler(options, n_total, n_keep)

Run `n_total` MCMC iterations for the cluster assignments `z` of the `x1`
points, jointly with a transport map that attaches each `x2` point to an
(upsampled) `x1` point.

Each iteration performs:
1. A Metropolis–Hastings update of the upsampling (`ids1_up`) and the
   transport map (`TT`).
2. A collapsed Gibbs sweep over `z`.
3. Optionally, one split–merge move.

# Arguments
- `options`: configuration object. Fields read directly here are `x1`, `x2`,
  `n1`, `n2`, `subset_ratio`, `t_max`, `a`, `b`, `log_v`, `use_splitmerge`,
  `n_split`, `n_merge`, `use_hyperprior`, `fill_with_centers`, `ot_ratio`,
  `swap_iter`, `temperature` and `block_size`. `construct_hyperparameters`
  may read more.
- `n_total`: number of iterations.
- `n_keep`: number of (evenly spaced) iterations whose `z` and transport map
  are stored.

# Returns
`(t_r, N_r, z_r, T_r, ids1_up_r, theta_r, keepers)`:
- `t_r`: number of clusters at each iteration.
- `N_r`: cluster sizes at each iteration, indexed by cluster ID.
- `z_r`: assignments at the kept iterations.
- `T_r`: transport map `TT` at the kept iterations (one column per kept iteration).
- `ids1_up_r`: the upsampling indices of the last iteration.
- `theta_r`: per-iteration copies of the statistics of the nonempty clusters.
- `keepers`: the kept iteration numbers.
"""
function sampler(options,n_total,n_keep)
    x1,x2,n1,n2 = options.x1,options.x2,options.n1,options.n2
    n = n1+n2
    subset_num = round(Int, options.subset_ratio*n1) # number of subset indices
    t_max = options.t_max
    a,b,log_v = options.a,options.b,options.log_v
    use_splitmerge,n_split,n_merge = options.use_splitmerge,options.n_split,options.n_merge
    use_hyperprior = options.use_hyperprior
    fill_with_centers = options.fill_with_centers

    ot_ratio = options.ot_ratio
    swap_iter = options.swap_iter
    temperature = options.temperature
    block_size = options.block_size

    # Initial upsampling of x1 so that there is one x1 entry per x2 point.
    if n1 < n2
        if fill_with_centers
            x1_up_current, ids1_up_current = upsampling_balancing_data_with_centers(x1, x2, subset_num, seed=2026)
        else
            x1_up_current, ids1_up_current = upsampling_balancing_data(x1, x2, subset_num, seed=2026)
        end
    else
        x1_up_current = deepcopy(x1)
        ids1_up_current = deepcopy(1:n2)
    end
    n1_up_current = size(x1_up_current, 1)
    @assert n1_up_current == n2 "n1_up_current should be equal to n2"

    # Initial transport map: random, optionally refined with partial OT.
    T_current = propose_T(x1_up_current,x2;method="random")
    if ot_ratio > 0
        T_current = propose_T(x1_up_current,x2;T_current=deepcopy(T_current),method="partial_ot",ot_ratio=ot_ratio,block_size=block_size)
    end
    @assert length(unique(T_current)) == n2 "length(unique(T_current)) should be equal to n2"

    T0_current = propose_T(x1_up_current,x2;method="random_replace")
    E_current = sample(1:n2, subset_num, replace=false)

    # NOTE(review): only this initial TT uses aggregate_T_T0; inside the loop TT is
    # always a copy of the proposed T — confirm this is intended.
    if subset_num < n1
        TT_current = aggregate_T_T0(T_current, T0_current, E_current)
    else
        TT_current = deepcopy(T_current)
    end

    current_energy = compute_energy(x1_up_current,x2,p=2,T=T_current,temperature=temperature)

    keepers = zeros(Int,n_keep)
    keepers[:] = round.(Int,range(round(Int,n_total/n_keep),stop=n_total,length=n_keep))
    keep_index = 0

    t = 1  # number of clusters
    z = ones(Int,n1)  # z[i] = the cluster ID associated with data point i
    # list[1:t] = the list of active cluster IDs.
    # list is kept in increasing order over 1:t, and is 0 after that.
    # Further, maximum(list[1:t]) is always <= t_max+3.
    list = zeros(Int,t_max+3); list[1] = 1
    c_next = 2  # an available cluster ID to be used next
    # N[c] = size of cluster c (x1 points plus their attached x2 points)
    N = zeros(Int,t_max+3); N[1] = n1+n2

    H = construct_hyperparameters(options)
    theta = [new_theta(H)::Theta for c = 1:t_max+3]  # theta[c] = parameters for cluster c

    log_p = zeros(n1+1)
    zs = ones(Int,n1)  # temporary variable used for split-merge assignments
    S = zeros(Int,n1)  # temporary variable used for split-merge indices

    log_Nb = log.((1:n) .+ b)

    # Record-keeping variables
    t_r = zeros(Int16,n_total); @assert(t_max < 2^15)
    N_r = zeros(Int32,t_max+3,n_total); @assert(n1 < 2^31)
    z_r = zeros(Int16,n1,n_keep); @assert(t_max < 2^15)
    T_r = zeros(Int64,n2,n_keep)  # transport map TT at the kept iterations
    ids1_up_r = zeros(Int64,n2)
    theta_r = []

    # Initially every x1 and x2 point belongs to cluster 1.
    for j in eachindex(x1); Theta_adjoin!(theta[1],x1[j]); end
    for j in eachindex(x2); Theta_adjoin2!(theta[1],x2[j]); end
    log_lik_current = 0.0
    for j = 1:t
        log_lik_current += log_marginal(theta[list[j]],H)
    end

    for iteration = 1:n_total
        # -------------- MH update of the upsampling and transport map --------------
        if fill_with_centers
            x1_up = deepcopy(x1_up_current)
            ids1_up = deepcopy(ids1_up_current)
        else
            # Fresh random upsampling for each iteration (seed depends on the iteration).
            x1_up, ids1_up = upsampling_balancing_data(x1, x2, subset_num, seed=2025+iteration)
        end
        n1_up = size(x1_up, 1)
        @assert n1_up == n2 "n1_up should be equal to n2"

        # Chain the swap proposals starting from a private copy of T_current.
        # This keeps T_current intact for the rejection branch, even if propose_T
        # mutates its T_current argument (the init code above defensively copies too).
        T = deepcopy(T_current)
        for _ in 1:swap_iter
            T = propose_T(x1_up,x2;T_current=T,method="sequential_binary")
        end

        TT = deepcopy(T)
        @assert length(unique(TT)) == n2 "Aggregated TT must be a valid permutation over x2 indices"

        new_energy = compute_energy(x1_up,x2,p=2,T=T,temperature=temperature)

        # Proposal likelihood: keep each cluster's x1 statistics, and rebuild its
        # x2 statistics from the proposed (ids1_up, TT) pairing.
        theta_proposal = [deepcopy(theta[list[c]]) for c in 1:t]
        for c in 1:t
            Theta_clear2!(theta_proposal[c])
        end

        idx_groups_proposal = [Int[] for _ in 1:n1]
        @inbounds for pos in eachindex(ids1_up)
            push!(idx_groups_proposal[ids1_up[pos]], pos)
        end

        @inbounds for k = 1:n1
            for idx in idx_groups_proposal[k]
                original_cluster_id = z[ids1_up[idx]]
                list_idx = findfirst(==(original_cluster_id), list)
                if list_idx !== nothing
                    Theta_adjoin2!(theta_proposal[list_idx], x2[TT[idx]])
                end
            end
        end

        log_lik_proposal = 0.0
        for j = 1:t
            log_lik_proposal += log_marginal(theta_proposal[j], H)
        end

        # Acceptance probability
        log_alpha = (log_lik_proposal - log_lik_current) + log(new_energy) - log(current_energy)
        log_α = min(0.0, log_alpha)

        if log(rand()) < log_α
            T_current = deepcopy(T)
            TT_current = deepcopy(TT)
            x1_up_current = deepcopy(x1_up)
            ids1_up_current = deepcopy(ids1_up)
            current_energy = new_energy
            log_lik_current = log_lik_proposal
            for c = 1:t
                theta[list[c]] = theta_proposal[c]
            end
        else
            # Revert the quantities used by the rest of this iteration.
            TT = deepcopy(TT_current)
            ids1_up = deepcopy(ids1_up_current)
        end

        # Recompute the cluster sizes N from z and the accepted upsampling.
        N .= 0
        for c in z
            N[c] += 1
        end
        for id in ids1_up
            N[z[id]] += 1
        end

        # Validate that list[1:t] contains exactly t non-zero elements.
        non_zero_count = count(!iszero, list[1:t])
        @assert non_zero_count == t "Inconsistent state: list[1:t] has $non_zero_count non-zero elements but t=$t"
        t = count(!iszero, list)

        c_next = ordered_next(list)

        # Group upsampled positions by the x1 point they duplicate.
        idx_groups = [Int[] for _ in 1:n1]
        @inbounds for pos in eachindex(ids1_up)
            push!(idx_groups[ids1_up[pos]], pos)
        end

        if use_hyperprior
            update_hyperparameters!(H,theta,list,t,x1,x2,z)
        end

        # -------------- Resample z's --------------
        for i in 1:n1
            # Remove point i (and its attached x2 points) from its cluster.
            c = z[i]
            idx_list = idx_groups[i]
            cnt = length(idx_list)
            N[c] -= (cnt + 1)
            Theta_remove!(theta[c],x1[i])
            for idx in idx_list
                Theta_remove2!(theta[c],x2[TT[idx]])
            end
            if N[c]>0
                c_prop = c_next
            else
                c_prop = c
                # Remove the now-empty cluster, keeping list in proper order.
                ordered_remove!(c,list,t)
                t -= 1
                # Clear the sufficient statistics for the now-empty cluster to avoid stale state.
                Theta_clear!(theta[c])
            end

            # Compute the probabilities for resampling.
            x2_list = [x2[TT[idx]] for idx in idx_list]
            for j = 1:t; cc = list[j]
                log_p[j] = log_Nb[N[cc]] + log_marginal_multiple(x1[i],x2_list,theta[cc],H) - log_marginal(theta[cc],H)
            end
            log_p[t+1] = log_v[t+1] - log_v[t] + log(a) + log_marginal_multiple(x1[i],x2_list,theta[c_prop],H)

            # Sample a new cluster for it.
            j = randlogp!(log_p,t+1)

            # Add point i to its new cluster.
            if j<=t
                c = list[j]
            else
                c = c_prop
                ordered_insert!(c,list,t)
                t += 1
                c_next = ordered_next(list)
                @assert(t<=t_max, "Sampled t has exceeded t_max. Increase t_max and retry.")
            end
            # Update the sufficient statistics.
            Theta_adjoin!(theta[c],x1[i])
            for idx in idx_list
                Theta_adjoin2!(theta[c],x2[TT[idx]])
            end
            z[i] = c
            N[c] += (cnt + 1)
        end

        # -------------- Split/merge move --------------
        if use_splitmerge
            t = split_merge!(x1,x2,ids1_up,TT,z,zs,S,theta,list,N,t,H,a,b,log_v,n_split,n_merge)
            c_next = ordered_next(list)
            @assert(t<=t_max, "Sampled t has exceeded t_max. Increase t_max and retry.")
        end

        # Recalculate the current log-likelihood.
        log_lik_current = 0.0
        for j = 1:t
            log_lik_current += log_marginal(theta[list[j]], H)
        end

        # -------------- Record results --------------
        t_r[iteration] = t
        for j = 1:t
            N_r[list[j],iteration] = N[list[j]]
        end
        # T_r has n_keep columns, so it is recorded at the kept iterations, like z_r.
        if keep_index < n_keep && iteration == keepers[keep_index+1]
            keep_index += 1
            for i = 1:n1; z_r[i,keep_index] = z[i]; end
            T_r[:, keep_index] = TT
        end

        theta_r_ = Any[]
        for j = 1:t
            c = list[j]
            if N[c] > 0
                push!(theta_r_, deepcopy(theta[c]))
            end
        end
        push!(theta_r, theta_r_)

        ids1_up_r = ids1_up
    end

    return t_r,N_r,z_r,T_r,ids1_up_r,theta_r,keepers
end