using Statistics

"""
    compute_utility(xs, zs, centers, K_ids)

Return the total and per-point within-cluster squared distance.

# Arguments
- `xs`: a pair `(x0, x1)` of point collections; they are stacked with `vcat`, so each
  point occupies one row of the combined array.
- `zs`: a pair `(z0, z1)` of cluster labels, stacked in the same order as `xs`.
- `centers`: a collection indexed by cluster id (`centers[k]`). Each center may be a
  length-`d` vector, a `1×d` row, or a scalar.
- `K_ids`: an iterable of the cluster ids to accumulate over.

# Returns
A tuple `(total, average)` where `total` is the sum over all listed clusters of the
squared differences between each member point and its cluster center, and `average` is
`total` divided by the number of stacked rows (`NaN` when there are no rows).
"""
function compute_utility(xs, zs, centers, K_ids)
    x0, x1 = xs
    z0, z1 = zs

    # Stack both groups along rows so one label vector indexes one point matrix.
    x_full = vcat(x0, x1)
    z_full = vcat(z0, z1)
    n = size(x_full, 1)

    total_distance = 0.0
    for k in K_ids
        idxs = findall(==(k), z_full)
        # An empty cluster contributes nothing, so its center is never looked up.
        isempty(idxs) && continue
        Xk = x_full[idxs, :]
        # A plain vector would broadcast as a column against the m×d block; make it a
        # row so every point has the same center subtracted coordinate by coordinate.
        Ck = _center_as_row(centers[k])
        total_distance += sum(abs2, Xk .- Ck)
    end

    return total_distance, total_distance / n
end

# Orient a cluster center so that it broadcasts across the rows of a point matrix.
_center_as_row(c::AbstractVector) = reshape(c, 1, :)
# Rows (1×d matrices) and scalars already broadcast correctly.
_center_as_row(c) = c

"""
    compute_fairness(zs, K_ids)

Summarize how evenly two groups are spread over the clusters listed in `K_ids`.

`zs` is a pair of label collections, one per group. For every cluster id `k`, with `a`
and `b` the number of labels equal to `k` in the first and second group:

- the *balance* is `min(a / (b + 1e-8), b / (a + 1e-8))`;
- the *gap* is `abs(a / (n_a + 1e-8) - b / (n_b + 1e-8))`, where `n_a` and `n_b` are the
  group sizes.

The `1e-8` term keeps every denominator non-zero.

# Returns
A tuple `(minimum balance, mean balance, sum of gaps)` over the clusters in `K_ids`.
"""
function compute_fairness(zs, K_ids)
    group_a, group_b = zs
    smoothing = 1e-8

    # Smoothed group sizes are the same for every cluster, so compute them once.
    denom_a = length(group_a) + smoothing
    denom_b = length(group_b) + smoothing

    # (members from the first group, members from the second group) per cluster id.
    counts = [(count(==(k), group_a), count(==(k), group_b)) for k in K_ids]

    balances = Float64[
        min(a / (b + smoothing), b / (a + smoothing)) for (a, b) in counts
    ]
    gaps = Float64[abs(a / denom_a - b / denom_b) for (a, b) in counts]

    return minimum(balances), mean(balances), sum(gaps)
end