"""
    Data

Two-dimensional toy point clouds. Most generators are thin wrappers around
`sklearn.datasets` (reached through PyCall). Every generator takes a sample
count `n` and returns a `2 × n` matrix with one sample per column.
"""
module Data

using PyCall
export scurve, moons, swiss, circles, gaussian8
using StatsBase: indicatormat

# A PyObject wraps a raw pointer into the Python runtime. Such a pointer must
# not be stored in a top-level constant of a module that may be precompiled:
# after the precompiled image is reloaded it would point at garbage. Allocate a
# null handle here and fill it in `__init__`, which runs on every module load.
const skdata = PyNULL()

function __init__()
    copy!(skdata, pyimport("sklearn.datasets"))
    return nothing
end

# Thin forwarders to the sklearn generators. They look the attribute up at
# call time, so they stay valid across precompilation.
make_scurve(args...; kws...) = skdata.make_s_curve(args...; kws...)
make_moons(args...; kws...) = skdata.make_moons(args...; kws...)
make_swiss(args...; kws...) = skdata.make_swiss_roll(args...; kws...)
make_circles(args...; kws...) = skdata.make_circles(args...; kws...)
make_blobs(args...; kws...) = skdata.make_blobs(args...; kws...)

"""
    scurve(n::Integer; noise=0.03)

Sample `n` points from sklearn's 3-D S-curve and keep coordinates 1 and 3.
Scale them by `(2.0, 1.3)` and return a `2 × n` matrix. `noise` is forwarded
to `make_s_curve`, i.e. it is applied before the rescaling.
"""
scurve(n::Integer; noise=0.03) =
    make_scurve(n; noise=noise)[1]'[1:2:3, :] .* [2.0, 1.3]

"""
    moons(n::Integer; noise=0.03)

Sample `n` points from sklearn's two-moons dataset as a `2 × n` matrix. Each
point is scaled by `(1.5, 2.0)` and shifted by `(-0.75, -0.43)`. `noise` is
forwarded to `make_moons` before the affine map.
"""
moons(n::Integer; noise=0.03) =
    make_moons(n; noise=noise)[1]' .* [1.5, 2.0] .- [0.75, 0.43]

"""
    swiss(n::Integer; noise=0.3)

Sample `n` points from sklearn's 3-D swiss roll and keep coordinates 1 and 3.
Scale them by `0.18` and return a `2 × n` matrix. `noise` is forwarded to
`make_swiss_roll` before scaling.
"""
swiss(n::Integer; noise=0.3) =
    make_swiss(n; noise=noise)[1]'[1:2:3, :] .* 0.18

"""
    circles(n::Integer; noise=0.03, factor=0.3)

Sample `n` points from sklearn's two concentric circles, scaled by `2.3`, as a
`2 × n` matrix. `noise` and `factor` are forwarded to `make_circles` before
scaling.
"""
circles(n::Integer; noise=0.03, factor=0.3) =
    make_circles(n; noise=noise, factor=factor)[1]' .* 2.3

"""
    gaussian8(n::Integer; noise=0.1, radius=2.25)

Sample `n` points from an equal-weight mixture of 8 isotropic Gaussians and
return them as a `2 × n` matrix. The Gaussians are centred at angles
`0, π/4, …, 7π/4` on a circle of radius `radius`, each with per-coordinate
standard deviation `noise`.
"""
function gaussian8(n::Integer; noise=0.1, radius=2.25)
    θ = (0.25π) .* (0:7)
    centers = radius .* [cos.(θ)'; sin.(θ)']  # 2 × 8, one mixture mode per column
    component = rand(1:8, n)                  # uniform mode choice per sample
    # Index the chosen centre directly and add one 2-D noise draw per sample.
    # This avoids building 8 noisy copies and masking 7 of them away.
    return centers[:, component] .+ noise .* randn(2, n)
end

end
