module Felb

using LinearAlgebra
using Random
using CUDA
import Flux
import SparseArrays: issparse, AbstractSparseMatrix

"""
    rand_like(x::Array, args...)

Return `rand(eltype(x), args...)`: a fresh host array of random values whose element
type matches the host array `x` and whose dimensions are given by `args`.

Any dense host `Array` is accepted (vectors as well as matrices); only the element
type of `x` is used, never its contents or shape. A `CUDA.CuArray` method that moves
the result to the device is added by `enable_gpu(true)`.
"""
rand_like(x::Array, args...) = rand(eltype(x), args...)
"""
    enable_gpu(enable = true)

Switch this module's device helpers between GPU mode (CUDA via Flux) and CPU mode.

The helpers `gpuenabled`, `gpu`, `cpu` and `free!` (plus, in GPU mode, a
`CUDA.CuArray` method of `rand_like`) are (re)defined with `@eval` in this module.
Because the definitions are made at run time, code that is already executing (for
example a function that calls `enable_gpu` and then `gpu` in the same body) does not
see the new methods until control returns to top level (Julia world age); use
`Base.invokelatest` in that situation.

Switching modes only adds or overwrites methods with the same signature; methods
added by a previous call (e.g. the `CuArray`-specific ones) are never removed.

# Arguments
- `enable`: GPU mode is selected when `enable == true`; any other value selects CPU
  mode.

NOTE(review): the one-argument `gpu`/`cpu`/`free!`/`gpuenabled` methods exist only
after some call to `enable_gpu`; confirm that something (e.g. an `__init__`) calls it
before use.
"""
function enable_gpu(enable = true)
    if enable == true
        @eval gpuenabled() = true
        @eval gpu(a::CUDA.CuArray) = a
        @eval gpu(a) = Flux.gpu(a)
        @eval cpu(a) = Flux.cpu(a)
        @eval rand_like(x::CUDA.CuArray, args...) = rand(eltype(x), args...) |> gpu
        # Generic no-op fallback, mirroring CPU mode. Without it, `free!` on anything
        # that is not a `CuArray` (e.g. a host array that `Flux.gpu` may leave on the
        # CPU) raised a `MethodError` in GPU mode.
        @eval free!(arg) = nothing
        # Release device memory eagerly instead of waiting for the GC.
        @eval free!(u::CUDA.CuArray) = CUDA.unsafe_free!(u)
    else
        @eval gpuenabled() = false
        @eval gpu(a) = a
        @eval cpu(a) = a
        @eval free!(arg) = nothing
    end
end
# Copy `from` into `to`, skipping the work when both names refer to the very same
# object. Returns whatever `copyto!` returns (i.e. `to`), or `to` itself when skipped.
function copyto_batch!(to, from)
    to === from && return to
    return copyto!(to, from)
end

# Copy `from` into the rows `b` (all columns) of `to`. When `to` and `from` are the
# same object nothing is copied and `to` is returned; otherwise the destination view
# returned by `copyto!` is the result.
# NOTE(review): `copyto!` copies in linear order, so this assumes `from` is shaped like
# `view(to, b, :)` — confirm at call sites.
function copyto_batch!(to, b::UnitRange, from)
    to === from && return to
    rows = view(to, b, :)
    return copyto!(rows, from)
end
# Device-to-host batch copy: write the rows `b` (all columns) of the host array `to`
# from the device array `from`. The data is first brought to the host with `cpu`,
# presumably so that `copyto!` into the strided host view runs entirely on the CPU.
# Returns the destination view.
function copyto_batch!(to::Array, b::UnitRange, from::CuArray)
    host_copy = cpu(from)
    rows = view(to, b, :)
    return copyto!(rows, host_copy)
end
# Apply `fn` to the `i`-th element of `x`, or return `default` when `x` is `nothing`
# (useful for optional containers).
# NOTE(review): there is no `import Base: get` before this line, so this creates a
# module-local `Felb.get` that shadows `Base.get` inside this module.
function get(fn::Function, x, i, default = nothing)
    x === nothing && return default
    return fn(x[i])
end
# Move the rows `b` (all columns) of `x` with the one-argument `gpu`. When `b` is at
# least as long as `x` has rows, `x` is moved whole, skipping the row copy that
# `x[b, :]` would make.
# NOTE(review): the whole-array shortcut assumes such a `b` covers every row of `x`
# (i.e. starts at 1) — confirm at call sites.
function gpu(x, b::UnitRange{Int})
    length(b) < size(x, 1) || return gpu(x)
    return gpu(x[b, :])
end

include("Felb_ipalm.jl")

# Public entry points are plain aliases for the iPALM implementations (presumably
# defined by the `include` above).
const felbmf, falbmf = felbmf_ipalm, falbmf_ipalm
const felbmf_dp, falbmf_dp = felbmf_dp_ipalm, falbmf_dp_ipalm

# Public API: the solver aliases, the GPU switch and `Mechanism`
# (`Mechanism` is not defined in this file — presumably in the included source).
export felbmf, falbmf, felbmf_dp, falbmf_dp
export enable_gpu, Mechanism

end # module Felb
