######################
# ksd using IMQ kernel (code adapted from https://github.com/jgorham/SteinDiscrepancy.jl/blob/master/src/kernels/SteinInverseMultiquadricKernel.jl)
######################
using Base.Threads, Distances, LinearAlgebra, ProgressMeter

function k0(x::AbstractVector{Float64}, 
            y::AbstractVector{Float64}, 
            grd::Function; 
            c2 = 1.0, β= 0.5)
    # compute all the core values only once and store them
    d = length(x)

    z = x - y
    r2 = sum(abs2, z)
    base = c2 + r2
    base_beta = base^(-β)
    base_beta1 = base_beta / base

    gradlogpx, gradlogpy = grd(x), grd(y)

    coeffk = dot(gradlogpx, gradlogpy)
    coeffgrad = -2.0 * β * base_beta1

    kterm = coeffk * base_beta
    gradandgradgradterms = coeffgrad * (
        (dot(gradlogpy, z) - dot(gradlogpx, z)) +
        (-d + 2 * (β + 1) * r2 / base)
    )
    kterm + gradandgradgradterms
end

function ksd(D, grd::Function; c2 = 1.0, β = 0.5)
    #=
    function that used to estimate ksd 
    D: sample matrix (each row is a data vector)
    grd: ∇logp ---score function of target distribution
    =#
    N = size(D,1)
    ksd = Threads.Atomic{Float64}(0.0)
    @info "computing KSD" 
    prog_bar = ProgressMeter.Progress(N^2, dt=0.5, barglyphs=ProgressMeter.BarGlyphs("[=> ]"), barlen=50, color=:yellow)
    @threads for i in 1:N 
        @simd for j in 1:N
            @inbounds kij = k0(@view(D[i,:]), @view(D[j,:]), grd; c2 = c2, β = β)
            Threads.atomic_add!(ksd, kij)
            ProgressMeter.next!(prog_bar)
        end
    end
    return sqrt(ksd[])/N
end

