"""
Sparse K-means clustering exploiting sparsity in input data.

Overloads Distances.pairwise! to efficiently compute squared Euclidean
distances between sparse samples and dense centroids using:

    ‖xᵢ - μₖ‖² = ‖xᵢ‖² - 2⟨xᵢ, μₖ⟩ + ‖μₖ‖²

This avoids the O(d) dense computation per sample–centroid pair. The cross term
⟨xᵢ, μₖ⟩ costs only O(nnz(xᵢ)) per pair, and the squared norms are computed once per call.
"""
module SparseKMeans

using Clustering
using Distances
using SparseArrays
using LinearAlgebra
using ThreadedDenseSparseMul

# Optimized pairwise squared Euclidean distances: dense centers × sparse samples.
# Called as: pairwise!(metric, r, centers, X, dims=2)
#
# Uses the expansion ‖aₖ - bᵢ‖² = ‖aₖ‖² - 2⟨aₖ, bᵢ⟩ + ‖bᵢ‖², so the cross term is a
# single dense-sparse product and the squared norms are computed once per call.
#
# Throws `ArgumentError` for `dims != 2` and `DimensionMismatch` when the shapes of
# `a`, `b` and `r` are inconsistent.
#
# NOTE(review): this is type piracy — `Distances.pairwise!`, `SqEuclidean` and
# `SparseMatrixCSC` all come from other packages, so loading this module changes the
# behaviour of every matching `pairwise!` call in the session, not just calls made
# from `sparse_kmeans`. Consider dispatching on a module-owned metric type instead.
function Distances.pairwise!(::SqEuclidean, r::AbstractMatrix{T},
                             a::AbstractMatrix{T}, b::SparseMatrixCSC;
                             dims::Int=2) where T<:Real
    dims == 2 || throw(ArgumentError("Only dims=2 supported for sparse pairwise"))

    # a: d × k_clusters      (dense centers, one per column)
    # b: d × n_samples       (sparse samples, one per column)
    # r: k_clusters × n_samples (output distances)
    d, k_clusters = size(a)
    d_b, n_samples = size(b)
    d == d_b || throw(DimensionMismatch(
        "Dimension mismatch: a has $d rows, b has $d_b rows"))
    size(r) == (k_clusters, n_samples) || throw(DimensionMismatch(
        "Output size mismatch: r is $(size(r)), expected ($k_clusters, $n_samples)"))

    a_sqnorms = vec(sum(abs2, a; dims=1))   # ‖aₖ‖², length k_clusters
    b_sqnorms = vec(sum(abs2, b; dims=1))   # ‖bᵢ‖², length n_samples

    # r = aᵀb via the threaded dense × sparse product (α = true, β = false).
    # The transpose is materialized rather than passed as a lazy adjoint.
    fastdensesparsemul!(r, permutedims(a), b, true, false)

    # One fused pass: r[k,i] = ‖aₖ‖² - 2⟨aₖ, bᵢ⟩ + ‖bᵢ‖².
    # Cancellation can leave tiny negative values when a center nearly coincides
    # with a sample; a squared distance is never negative, so clamp at zero.
    r .= max.(a_sqnorms .- 2 .* r .+ b_sqnorms', zero(T))

    return r
end

"""
    sparse_kmeans(X, k; maxiter=100, tol=1e-4, verbose=false, kwargs...)

Run K-means on the sparse data matrix `X` (d × n, one sample per column) with `k`
clusters by calling `Clustering.kmeans`.

# Arguments
- `X::SparseMatrixCSC`: data matrix, features × samples.
- `k::Int`: number of clusters.

# Keywords
- `maxiter::Integer=100`: maximum number of iterations, forwarded to `kmeans`.
- `tol::Real=1e-4`: convergence tolerance, forwarded to `kmeans`.
- `verbose::Bool=false`: when `true`, `display=:iter` is passed to `kmeans`;
  otherwise `display=:none`.
- `kwargs...`: any further keywords (e.g. `rng`, `init`) are forwarded to `kmeans`
  unchanged. Because they are splatted last, an explicit `display` here takes
  precedence over the value derived from `verbose`.

# Returns
The `Clustering.KmeansResult` produced by `kmeans` (assignments, centers, …).
"""
function sparse_kmeans(X::SparseMatrixCSC, k::Int;
                       maxiter::Integer=100, tol::Real=1e-4, verbose::Bool=false,
                       kwargs...)
    display_mode = verbose ? :iter : :none
    return kmeans(X, k; maxiter=maxiter, tol=tol, display=display_mode, kwargs...)
end

# Fallback for any non-SparseMatrixCSC matrix (dense arrays, views, other sparse
# wrappers): materialize it as a SparseMatrixCSC and hand off to the sparse method,
# forwarding every keyword unchanged.
sparse_kmeans(X::AbstractMatrix, k::Int; kwargs...) =
    sparse_kmeans(sparse(X), k; kwargs...)

"""
    sparse_kmeans_labels(X, k; kwargs...)

Cluster `X` into `k` groups with `sparse_kmeans` and return only the per-sample
cluster assignments, shifted to start at 0 so they can be consumed directly from
Python. All keywords are forwarded to `sparse_kmeans`.
"""
function sparse_kmeans_labels(X::AbstractMatrix, k::Int; kwargs...)
    one_based = sparse_kmeans(X, k; kwargs...).assignments
    # Julia numbers clusters from 1; Python callers expect numbering from 0.
    zero_based = one_based .- 1
    return zero_based
end

export sparse_kmeans, sparse_kmeans_labels

end # module
