using LinearAlgebra, Statistics
using KernelAbstractions
import LoopVectorization: @tturbo, @turbo

# Collect into `sym` every bare symbol and every dotted access (`a.b`) that occurs in
# `ex`, and rewrite each dotted access that is a direct operand into its field name.
# Callee names of call expressions are neither collected nor rewritten.
function grab!(sym, ex::Expr)
    if ex.head == :.
        # A composite name is recorded as a whole; its parts are not visited.
        return union!(sym, [ex])
    end
    lo = ex.head == :call ? 2 : 1          # position 1 of a call is the function name
    operands = lo:lastindex(ex.args)
    for k in operands
        grab!(sym, ex.args[k])
    end
    ex.args[operands] = rep.(ex.args[operands])
end

# A plain symbol is recorded as is.
function grab!(sym, name::Symbol)
    return union!(sym, [name])
end

# Literals and any other node contribute nothing.
grab!(sym, ::Any) = nothing

# Replace a dotted access `a.b` by the symbol `b`; everything else passes through.
rep(other) = other
function rep(ex::Expr)
    if ex.head == :.
        return Symbol(ex.args[2].value)
    else
        return ex
    end
end
"""
    @loop <body> over <I> in <R>

Expand the indexed expression `<body>` into code that runs it for every index of `<R>`.

The expansion defines a KernelAbstractions `@kernel` whose parameters are the names
collected from `<body>` by `grab!`, plus two methods of a generated launcher: one taking
`::CPU` that runs a `@tturbo` loop over `1:length(R)`, and one taking `::GPU` that
launches the kernel with `ndrange = size(R)`. The launcher is called with
`get_backend` of the first collected name.

The whole expansion is escaped, so `<body>` is evaluated in the caller's scope.
"""
macro loop(args...)
    # `ex` is the loop body, the middle argument (written `over` at the call sites in
    # this file) is discarded, and `itr` is the `I in R` expression.
    ex, _, itr = args
    # `itr.args` is destructured as (ignored head, index symbol, range expression).
    _, I, R = itr.args
    sym = []
    grab!(sym, ex)     # get arguments and replace composites in `ex`
    setdiff!(sym, [I]) # don't want to pass I as an argument
    # NOTE(review): `sym[1]` is used below, so a body that references nothing besides
    # the index would fail at expansion time.
    @gensym kern       # fresh name for the generated kernel
    @gensym fn         # fresh name for the generated backend launcher
    return quote
        # Kernel parameters use the field name of each composite (`a.b` becomes `b`),
        # matching the rewrite `grab!` applied to `ex`.
        @kernel function $kern($(rep.(sym)...)) # replace composite arguments
            $I = @index(Global)
            # @fastmath @inbounds 
            $ex
        end
        # CPU method: plain loop over the linear range 1:length(R).
        function $fn(::CPU)
            @fastmath @inbounds @tturbo for $I in 1:length($R) # |> eachindex
                $ex
            end
        end
        # GPU method: the kernel is instantiated with a second argument of 64
        # (presumably the workgroup size -- confirm against KernelAbstractions) and
        # called with the original, unrewritten argument expressions.
        function $fn(::GPU)
            $kern(get_backend($(sym[1])), 64)($(sym...), ndrange = size($R))
        end
        # Dispatch on the backend of the first collected argument.
        $fn(get_backend($(sym[1])))
    end |> esc
end

"""
    Batchwise(len, batchsize, partial)
    Batchwise(n::Int, batchsize::Int = min(2^10, n), partial = false)

Iterator over consecutive index ranges of at most `batchsize` elements that together
cover `1:len`.

When `partial` is `false` and `batchsize < len`, a trailing batch that would be shorter
than `batchsize` is replaced by the last full-size range `len - batchsize + 1:len`, which
overlaps the preceding batch. When `partial` is `true`, the trailing batch is simply
shorter.
"""
struct Batchwise
    len::Int
    batchsize::Int
    partial::Bool
end

function Batchwise(n::Int, batchsize::Int=min(2^10, n), partial=false)
    # `new` is only available inside inner constructors, so delegate to the default
    # constructor; converting `partial` makes the call dispatch there.
    return Batchwise(n, batchsize, convert(Bool, partial))
end

"""
    batchwise(args...)

Lower-case alias that forwards all arguments to `Batchwise`.
"""
batchwise(args...) = Batchwise(args...)

# Number of batches produced by iteration: ceil(len / batchsize).
function Base.length(A::Batchwise)
    A.len <= 0 && return 0            # nothing to iterate; also avoids dividing by zero
    k, r = divrem(A.len, A.batchsize)
    return k + (r != 0)
end

Base.eltype(::Batchwise) = UnitRange{Int}

# `state` is the first index of the batch to produce next.
function Base.iterate(A::Batchwise, state = 1)
    # Stop only once every index has been covered. The comparison must be strict:
    # `state == len` still leaves the final index to be emitted.
    state > A.len && return nothing
    if state + A.batchsize > A.len && !A.partial && A.batchsize < A.len
        # Final batch: slide the window back so it keeps the full batch size.
        state = A.len - A.batchsize + 1
    end
    next = state + A.batchsize
    return (state:min(next - 1, A.len), next)
end

"""
    felbmf_ipalm(X, n_components; batchsize = 2^15, kws...)

Choose how to run the `elb!`-based solver on the collection of matrices `X`.

- If `batchsize` is at least the total row count over all matrices, or `gpuenabled()`
  is false, every matrix is passed through `gpu` up front and `felbmf_ipalm_impl` runs.
- Otherwise, if `batchsize` is at least the largest single row count,
  `felbmf_ipalm_impl` runs on `X` as given.
- Otherwise `felbmf_ipalm_batched` processes each matrix in row batches.

Remaining keywords are forwarded unchanged.
"""
function felbmf_ipalm(X, n_components; batchsize = 2^15, kws...)
    everything_fits = batchsize >= sum(size(x, 1) for x in X)
    if everything_fits || !gpuenabled()
        return felbmf_ipalm_impl(gpu.(X), n_components; kws...)
    end
    if batchsize >= maximum(size(x, 1) for x in X)
        return felbmf_ipalm_impl(X, n_components; kws...)
    end
    return felbmf_ipalm_batched(X, n_components; batchsize = batchsize, kws...)
end

"""
    felbmf_ipalm_impl(X, n_components; kws...)

Run the alternating solver on every matrix in `X` with a shared, averaged right factor.

Factors are created by `felb_init(X, n_components)`. Each outer iteration performs
`rounds` alternating `reduce!`/`elb!` updates of `U[j]` and `V[j]`, adds
`norm(x - u * v)^2` to the loss, and then overwrites every `V[j]` with the `elb!`-processed
mean of all `V`. The loop ends after `maxiter` iterations or as soon as two consecutive
losses agree within `atol = tol`.

# Keywords
- `κ`, `λ`: passed to `elb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)` with extra buffers.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function felbmf_ipalm_impl(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                           maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10,
                           callback = nothing, with_rounding = true)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        opt = PALM()
        for iter in 1:maxiter
            t = tau(iter - 1)
            ℓ0 = ℓ
            ℓ = 0.0
            for j in eachindex(X)
                x, u, v = gpu(X[j]), gpu(U[j]), gpu(V[j])
                for _ in 1:rounds
                    ηu = reduce!(opt, x, u, v)
                    elb!(u, κ * ηu, λ * t * ηu)
                    # The transposed problem updates `v` in place through its adjoint.
                    ηv = reduce!(opt, x', v', u')
                    elb!(v, κ * ηv, λ * t * ηv)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
            end
            μ = mean(V)
            elb!(μ, κ, λ * t)
            for k in eachindex(V)
                copyto!(V[k], μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    else
        opt = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for iter in 1:maxiter
            ℓ0 = ℓ
            ℓ = 0.0
            t = tau(iter - 1)
            for j in eachindex(X)
                x, u, v = gpu(X[j]), gpu(U[j]), gpu(V[j])
                u_, vt_ = gpu(U_[j]), gpu(Vt_[j])
                for _ in 1:rounds
                    ηu = reduce!(opt, x, u, v, u_)
                    elb!(u, κ * ηu, λ * t * ηu)
                    ηv = reduce!(opt, x', v', u', vt_)
                    elb!(v, κ * ηv, λ * t * ηv)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
                copyto_batch!(U_[j], u_)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(V)
            elb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    end
    μ = cpu(μ)
    if with_rounding
        round_felb!(X, U, V, μ)
    end
    return U, V, μ
end

"""
    felbmf_ipalm_batched(X, n_components; batchsize = 2^15, kws...)

Row-batched variant of `felbmf_ipalm_impl`.

For every matrix `X[j]`, the right factor `V[j]` is fetched once with `gpu`, while `X[j]`
and `U[j]` are fetched per row range `b` produced by `batchwise(size(X[j], 1), batchsize)`.
Each batch performs one alternating `reduce!`/`elb!` update and adds `norm(x - u * v)^2`
to the loss, so the loss is accumulated over all rounds and batches. After all matrices
are processed, every `V[j]` is overwritten with the `elb!`-processed mean of all `V`.

# Keywords
- `κ`, `λ`: passed to `elb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `maxiter`, `tol`: stop after `maxiter` iterations or when consecutive losses agree
  within `atol = tol`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `rounds`: number of passes over all batches per matrix and iteration.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function felbmf_ipalm_batched(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                              maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10,
                              callback = nothing, with_rounding = true, batchsize = 2^15)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        palm = PALM()
        for i in 1:maxiter
            t = tau(i - 1)
            ℓ, ℓ0 = 0.0, ℓ
            for j in eachindex(X)
                v = gpu(V[j])
                for _ in 1:rounds, b in batchwise(size(X[j], 1), batchsize)
                    x, u = gpu(X[j], b), gpu(U[j], b)
                    η = reduce!(palm, x, u, v)
                    elb!(u, κ * η, λ * t * η)
                    η = reduce!(palm, x', v', u')
                    elb!(v, κ * η, λ * t * η)
                    ℓ += norm(x - u * v)^2
                    # `u` holds only the rows in `b`, so it must be written back to that
                    # row range (as the iPALM branch does), not to the start of `U[j]`.
                    copyto_batch!(U[j], b, u)
                end
                copyto_batch!(V[j], v)
            end
            μ = mean(V)
            elb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    else
        ipalm = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for i in 1:maxiter
            ℓ, ℓ0 = 0.0, ℓ
            t = tau(i - 1)
            for j in eachindex(X)
                v, vt_ = gpu(V[j]), gpu(Vt_[j])
                for _ in 1:rounds, b in batchwise(size(X[j], 1), batchsize)
                    x, u, u_ = gpu(X[j], b), gpu(U[j], b), gpu(U_[j], b)
                    η = reduce!(ipalm, x, u, v, u_)
                    elb!(u, κ * η, λ * t * η)
                    η = reduce!(ipalm, x', v', u', vt_)
                    elb!(v, κ * η, λ * t * η)
                    ℓ += norm(x - u * v)^2
                    copyto_batch!(U[j], b, u)
                    copyto_batch!(U_[j], b, u_)
                end
                copyto_batch!(V[j], v)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(V)
            elb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    end
    μ = cpu(μ)
    with_rounding && round_felb!(X, U, V, μ)
    return U, V, μ
end

# In-place elementwise update of `x`, followed by clamping to [0, 1].
# Entries `<= 0.5` are mapped by `(z - k*sign(z)) / (1 + l)`; larger entries by
# `(z - k*sign(z - 1) + l) / (1 + l)`. The elementwise pass is generated by `@loop`.
# Returns `x` (the result of `clamp!`).
function elb!(x, k, l)
    T = eltype(x)
    lo, hi = zero(T), one(T)
    @inline @fastmath lower(z) = (z - k * sign(z)) / (hi + l)
    @inline @fastmath upper(z) = (z - k * sign(z - hi) + l) / (hi + l)
    @loop x[I]=(x[I] <= 0.5) ? lower(x[I]) : upper(x[I]) over I in eachindex(x)
    clamp!(x, lo, hi)
end

"""
    felb_init(X, K)

Create initial factors for every matrix in `X` using `rand_like`: a left factor of size
`size(x, 1) × K` and a right factor of size `K × size(x, 2)` per matrix.

Returns `(U, V, mean(V))`. All left factors are drawn before any right factor.
"""
function felb_init(X, K)
    left = [rand_like(x, size(x, 1), K) for x in X]
    right = [rand_like(x, K, size(x, 2)) for x in X]
    shared = mean(right)
    return left, right, shared
end

# Tag type selecting the non-inertial `reduce!` method.
struct PALM end

# One in-place gradient step on `U` for the residual `U * V - A`.
# The step size is `1 / (1.1 * max(norm(V * V'), 1e-4))`; it is returned so the caller
# can scale its regulariser with it.
@fastmath @inline function reduce!(::PALM, A, U, V)
    AVt = A * V'
    VVt = V * V'
    bound = max(norm(VVt), 1e-4)      # floor keeps the step finite when V is ~0
    η = 1 / (1.1 * bound)
    U .= U - η * (U * VVt - AVt)
    return η
end

# Optimiser tag carrying the inertia weight `beta` used by the inertial `reduce!` method.
struct iPALM{T}
    beta::T
end

# Inertial variant of the gradient step.
# `U` is first moved to `U + beta * (U - U_)` and `U_` is overwritten with that
# extrapolated value; then the gradient step is taken from there. Returns the step size
# `2 * (1 - beta) / (1 + 2 * beta) / (1.05 * max(norm(V * V'), 1e-4))`.
@fastmath @inline function reduce!(opt::iPALM, A, U, V, U_)
    β = opt.beta
    AVt = A * V'
    VVt = V * V'
    @. U = U + β * (U - U_)
    @. U_ = U
    bound = 1.05 * max(norm(VVt), 1e-4)
    η = 2 * (1 - β) / (1 + 2 * β) / bound
    U .= U - η * (U * VVt - AVt)
    return η
end

using Distributions

# Parameters for Laplace-distributed noise: `noise_distribution` uses
# `sensitivity / eps` as the scale.
@kwdef struct LaplacianMechanism
    eps::Float64 = 1e-3
    sensitivity::Float64 = 1.0
end

# Parameters for normally distributed noise; `delta` enters the spread computed by
# `noise_distribution`.
@kwdef struct GaussianMechanism
    eps::Float64 = 1e-3
    delta::Float64 = 1e-2
    sensitivity::Float64 = 1.0
end

# Same parameters as `LaplacianMechanism`; `dp` clamps the noisy result to [0, 1]
# for this type.
@kwdef struct LaplacianMechanismClipped
    eps::Float64 = 1e-3
    sensitivity::Float64 = 1.0
end

# Same parameters as `GaussianMechanism`; `dp` clamps the noisy result to [0, 1]
# for this type.
@kwdef struct GaussianMechanismClipped
    eps::Float64 = 1e-3
    delta::Float64 = 1e-2
    sensitivity::Float64 = 1.0
end

# Zero-centred Laplace distribution with scale `sensitivity / eps`.
# `T` only determines the type of the location parameter.
function noise_distribution(mech::Union{LaplacianMechanism, LaplacianMechanismClipped}, T = Float64)
    scale = mech.sensitivity / mech.eps
    return Laplace(zero(T), scale)
end
# Zero-centred normal distribution for the Gaussian mechanism.
# `Normal(μ, σ)` takes a standard deviation, so the spread is
# σ = sqrt(2 * log(1.25 / delta)) * sensitivity / eps.
# The previous expression `2 * log(1.25 / delta) * sensitivity / eps^2` is a variance-like
# quantity (the variance when `sensitivity == 1`) and was passed where σ is expected.
# `T` determines the type of the location parameter and of the 1.25 constant.
function noise_distribution(mech::Union{GaussianMechanism, GaussianMechanismClipped}, T = Float64)
    σ = sqrt(2 * log(T(1.25) / mech.delta)) * mech.sensitivity / mech.eps
    return Normal(zero(T), σ)
end
# Return a noisy copy of `v`: one independent draw from `noise_distribution(mech)` is
# added to every entry. `v` itself is not modified.
function dp(mech::Union{LaplacianMechanism, GaussianMechanism}, v)
    noise = rand(noise_distribution(mech), size(v))
    return v .+ noise
end

# Clipped variant: the freshly allocated noisy array is clamped to [0, 1].
function dp(mech::Union{LaplacianMechanismClipped, GaussianMechanismClipped}, v)
    noisy = v .+ rand(noise_distribution(mech), size(v))
    return clamp!(noisy, 0, 1)
end

# `CuArray` input: the noise is drawn as above and passed through `gpu` before adding.
function dp(mech::Union{LaplacianMechanism, GaussianMechanism}, v::CUDA.CuArray)
    noise = gpu(rand(noise_distribution(mech), size(v)))
    return v .+ noise
end

# `CuArray` input, clipped variant.
function dp(mech::Union{LaplacianMechanismClipped, GaussianMechanismClipped}, v::CUDA.CuArray)
    noisy = v .+ gpu(rand(noise_distribution(mech), size(v)))
    return clamp!(noisy, 0, 1)
end
"""
    Mechanism(ϵ, mechanism, sens)

Build the noise-mechanism object named by the symbol `mechanism`
(`:gaussian_clipped`, `:gaussian`, `:laplacian_clipped` or `:laplacian`) with budget `ϵ`
and sensitivity `sens`. Gaussian variants use a fixed `delta` of `0.05`.

Any other value of `mechanism` returns `ϵ` unchanged.
"""
function Mechanism(ϵ, mechanism, sens)
    mechanism == :gaussian_clipped && return GaussianMechanismClipped(ϵ, 0.05, sens)
    mechanism == :gaussian && return GaussianMechanism(ϵ, 0.05, sens)
    mechanism == :laplacian_clipped && return LaplacianMechanismClipped(ϵ, sens)
    mechanism == :laplacian && return LaplacianMechanism(ϵ, sens)
    return ϵ
end

"""
    felbmf_dp_ipalm(X, n_components; ϵ = ..., batchsize = 2^15, kws...)

Choose how to run the noise-perturbed `elb!`-based solver on the collection `X`.

- If `batchsize` is at least the total row count over all matrices, or `gpuenabled()`
  is false, every matrix is passed through `gpu` and `felbmf_dp_ipalm_impl` runs.
- Otherwise, if `batchsize` is at least the largest single row count,
  `felbmf_dp_ipalm_impl` runs on `X` as given.
- Otherwise `felbmf_dp_ipalm_batched` processes each matrix in row batches.

`ϵ` is forwarded to the chosen solver, which passes it to `dp`. Its default is a clipped
Laplacian mechanism with budget `1/0.0025`.
NOTE(review): the default sensitivity is `maximum(sum, eachrow(X))`, which treats `X` as
a single matrix, while the body iterates `X` as a collection of matrices -- confirm.
"""
function felbmf_dp_ipalm(X, n_components; ϵ = Mechanism(1/0.0025, :laplacian_clipped, maximum(sum, eachrow(X))), batchsize = 2^15, kws...)
    everything_fits = batchsize >= sum(size(x, 1) for x in X)
    if everything_fits || !gpuenabled()
        return felbmf_dp_ipalm_impl(gpu.(X), n_components; ϵ = ϵ, kws...)
    end
    if batchsize >= maximum(size(x, 1) for x in X)
        return felbmf_dp_ipalm_impl(X, n_components; ϵ = ϵ, kws...)
    end
    return felbmf_dp_ipalm_batched(X, n_components; batchsize = batchsize, ϵ = ϵ, kws...)
end

"""
    felbmf_dp_ipalm_impl(X, n_components; ϵ = 0.01, kws...)

Noise-perturbed version of the `elb!`-based solver.

Factors are created by `felb_init(X, n_components)`. Each outer iteration performs
`rounds` alternating `reduce!`/`elb!` updates of `U[j]` and `V[j]` and adds
`norm(x - u * v)^2` to the loss. The shared factor is then the mean of `dp(ϵ, v)` over
all `V`, processed by `elb!` and copied back into every `V[j]`. The loop ends after
`maxiter` iterations or when two consecutive losses agree within `atol = tol`.

# Keywords
- `κ`, `λ`: passed to `elb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `ϵ`: passed as the first argument of `dp`. NOTE(review): the `dp` methods visible in
  this file accept mechanism structs only; the numeric default needs a method defined
  elsewhere.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function felbmf_dp_ipalm_impl(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                              maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10, ϵ = 0.01,
                              callback = nothing, with_rounding = true)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        opt = PALM()
        for iter in 1:maxiter
            t = tau(iter - 1)
            ℓ0 = ℓ
            ℓ = 0.0
            for j in eachindex(X)
                x, u, v = gpu(X[j]), gpu(U[j]), gpu(V[j])
                for _ in 1:rounds
                    ηu = reduce!(opt, x, u, v)
                    elb!(u, κ * ηu, λ * t * ηu)
                    # The transposed problem updates `v` in place through its adjoint.
                    ηv = reduce!(opt, x', v', u')
                    elb!(v, κ * ηv, λ * t * ηv)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
            end
            # Average noisy copies instead of the raw factors.
            μ = mean(dp(ϵ, v) for v in V)
            elb!(μ, κ, λ * t)
            for k in eachindex(V)
                copyto!(V[k], μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    else
        opt = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for iter in 1:maxiter
            ℓ0 = ℓ
            ℓ = 0.0
            t = tau(iter - 1)
            for j in eachindex(X)
                x, u, v = gpu(X[j]), gpu(U[j]), gpu(V[j])
                u_, vt_ = gpu(U_[j]), gpu(Vt_[j])
                for _ in 1:rounds
                    ηu = reduce!(opt, x, u, v, u_)
                    elb!(u, κ * ηu, λ * t * ηu)
                    ηv = reduce!(opt, x', v', u', vt_)
                    elb!(v, κ * ηv, λ * t * ηv)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
                copyto_batch!(U_[j], u_)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(dp(ϵ, v) for v in V)
            elb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    end
    μ = cpu(μ)
    if with_rounding
        round_felb!(X, U, V, μ)
    end
    return U, V, μ
end

"""
    felbmf_dp_ipalm_batched(X, n_components; ϵ = 0.0, batchsize = 2^15, kws...)

Row-batched, noise-perturbed variant of the `elb!`-based solver.

For every matrix `X[j]`, `V[j]` is fetched once with `gpu`, while `X[j]` and `U[j]` are
fetched per row range `b` produced by `batchwise(size(X[j], 1), batchsize)`. Each batch
performs one alternating `reduce!`/`elb!` update and adds `norm(x - u * v)^2` to the
loss, so the loss is accumulated over all rounds and batches. The shared factor is the
mean of `dp(ϵ, v)` over all `V`, processed by `elb!` and copied back into every `V[j]`.

# Keywords
- `κ`, `λ`: passed to `elb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `maxiter`, `tol`: stop after `maxiter` iterations or when consecutive losses agree
  within `atol = tol`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `rounds`: number of passes over all batches per matrix and iteration.
- `ϵ`: passed as the first argument of `dp`. NOTE(review): the `dp` methods visible in
  this file accept mechanism structs only; the numeric default needs a method defined
  elsewhere.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function felbmf_dp_ipalm_batched(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                                 maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10, ϵ = 0.0,
                                 callback = nothing, with_rounding = true, batchsize = 2^15)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        palm = PALM()
        for i in 1:maxiter
            t = tau(i - 1)
            ℓ, ℓ0 = 0.0, ℓ
            for j in eachindex(X)
                v = gpu(V[j])
                for _ in 1:rounds, b in batchwise(size(X[j], 1), batchsize)
                    x, u = gpu(X[j], b), gpu(U[j], b)
                    η = reduce!(palm, x, u, v)
                    elb!(u, κ * η, λ * t * η)
                    η = reduce!(palm, x', v', u')
                    elb!(v, κ * η, λ * t * η)
                    ℓ += norm(x - u * v)^2
                    # `u` holds only the rows in `b`, so it must be written back to that
                    # row range (as the iPALM branch does), not to the start of `U[j]`.
                    copyto_batch!(U[j], b, u)
                end
                copyto_batch!(V[j], v)
            end
            μ = mean(dp(ϵ, v) for v in V)
            elb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    else
        ipalm = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for i in 1:maxiter
            ℓ, ℓ0 = 0.0, ℓ
            t = tau(i - 1)
            for j in eachindex(X)
                v, vt_ = gpu(V[j]), gpu(Vt_[j])
                for _ in 1:rounds, b in batchwise(size(X[j], 1), batchsize)
                    x, u, u_ = gpu(X[j], b), gpu(U[j], b), gpu(U_[j], b)
                    η = reduce!(ipalm, x, u, v, u_)
                    elb!(u, κ * η, λ * t * η)
                    η = reduce!(ipalm, x', v', u', vt_)
                    elb!(v, κ * η, λ * t * η)
                    ℓ += norm(x - u * v)^2
                    copyto_batch!(U[j], b, u)
                    copyto_batch!(U_[j], b, u_)
                end
                copyto_batch!(V[j], v)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(dp(ϵ, v) for v in V)
            elb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    end
    μ = cpu(μ)
    with_rounding && round_felb!(X, U, V, μ)
    return U, V, μ
end


"""
    falbmf_ipalm(X, n_components; batchsize = 2^15, kws...)

Choose how to run the `alb!`-based solver on the collection of matrices `X`.

- If `batchsize` is at least the total row count over all matrices, or `gpuenabled()`
  is false, every matrix is passed through `gpu` up front and `falbmf_ipalm_impl` runs.
- Otherwise, if `batchsize` is at least the largest single row count,
  `falbmf_ipalm_impl` runs on `X` as given.
- Otherwise `falbmf_ipalm_batched` processes each matrix in row batches.

Remaining keywords are forwarded unchanged.
"""
function falbmf_ipalm(X, n_components; batchsize = 2^15, kws...)
    everything_fits = batchsize >= sum(size(x, 1) for x in X)
    if everything_fits || !gpuenabled()
        return falbmf_ipalm_impl(gpu.(X), n_components; kws...)
    end
    if batchsize >= maximum(size(x, 1) for x in X)
        return falbmf_ipalm_impl(X, n_components; kws...)
    end
    return falbmf_ipalm_batched(X, n_components; batchsize = batchsize, kws...)
end

"""
    falbmf_ipalm_impl(X, n_components; kws...)

Run the `alb!`-based alternating solver on every matrix in `X` with a shared, averaged
right factor.

Factors are created by `felb_init(X, n_components)`. Each outer iteration performs
`rounds` alternating `reduce!`/`alb!` updates of `U[j]` and `V[j]`, adds
`norm(x - u * v)^2` to the loss, and then overwrites every `V[j]` with the `alb!`-processed
mean of all `V`. The loop ends after `maxiter` iterations or when two consecutive losses
agree within `atol = tol`.

# Keywords
- `κ`, `λ`: passed to `alb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function falbmf_ipalm_impl(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                           maxiter = 100,
                           tol = 1e-5, beta = 0.01, rounds = 10, callback = nothing,
                           with_rounding = true)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        palm = PALM()
        for i in 1:maxiter
            ℓ, ℓ0 = 0.0, ℓ
            t = tau(i - 1)
            for j in eachindex(X)
                x, u, v = gpu.((X[j], U[j], V[j]))
                for _ in 1:rounds
                    η = reduce!(palm, x, u, v)
                    alb!(u, κ * η, λ * t * η)
                    # The transposed problem updates `v` in place through its adjoint.
                    η = reduce!(palm, x', v', u')
                    alb!(v, κ * η, λ * t * η)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
            end
            μ = mean(V)
            alb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    else
        ipalm = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for i in 1:maxiter
            ℓ, ℓ0 = 0.0, ℓ
            t = tau(i - 1)
            for j in eachindex(X)
                x, u, v, u_, vt_ = gpu.((X[j], U[j], V[j], U_[j], Vt_[j]))
                for _ in 1:rounds
                    η = reduce!(ipalm, x, u, v, u_)
                    alb!(u, κ * η, λ * t * η)
                    η = reduce!(ipalm, x', v', u', vt_)
                    alb!(v, κ * η, λ * t * η)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
                copyto_batch!(U_[j], u_)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(V)
            alb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    end
    # Every sibling solver in this file passes `μ` through `cpu` before rounding and
    # returning; this one omitted the call.
    μ = cpu(μ)
    with_rounding && round_felb!(X, U, V, μ)
    return U, V, μ
end

"""
    falbmf_ipalm_batched(X, n_components; batchsize = 2^15, kws...)

Row-batched variant of `falbmf_ipalm_impl`.

For every matrix `X[j]`, `V[j]` is fetched once with `gpu`, while `X[j]` and `U[j]` are
fetched per row range `b` produced by `batchwise(size(X[j], 1), batchsize)` and written
back to the same rows. Each batch performs one alternating `reduce!`/`alb!` update and
adds `norm(x - u * v)^2` to the loss, so the loss is accumulated over all rounds and
batches. Afterwards every `V[j]` is overwritten with the `alb!`-processed mean of all `V`.

# Keywords
- `κ`, `λ`: passed to `alb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `maxiter`, `tol`: stop after `maxiter` iterations or when consecutive losses agree
  within `atol = tol`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `rounds`: number of passes over all batches per matrix and iteration.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function falbmf_ipalm_batched(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                              maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10,
                              callback = nothing, with_rounding = true, batchsize = 2^15)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        opt = PALM()
        for iter in 1:maxiter
            t = tau(iter - 1)
            ℓ0 = ℓ
            ℓ = 0.0
            for j in eachindex(X)
                v = gpu(V[j])
                for _ in 1:rounds
                    for b in batchwise(size(X[j], 1), batchsize)
                        x = gpu(X[j], b)
                        u = gpu(U[j], b)
                        ηu = reduce!(opt, x, u, v)
                        alb!(u, κ * ηu, λ * t * ηu)
                        ηv = reduce!(opt, x', v', u')
                        alb!(v, κ * ηv, λ * t * ηv)
                        ℓ += norm(x - u * v)^2
                        copyto_batch!(U[j], b, u)
                    end
                end
                copyto_batch!(V[j], v)
            end
            μ = mean(V)
            alb!(μ, κ, λ * t)
            for k in eachindex(V)
                copyto!(V[k], μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    else
        opt = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for iter in 1:maxiter
            ℓ0 = ℓ
            ℓ = 0.0
            t = tau(iter - 1)
            for j in eachindex(X)
                v = gpu(V[j])
                vt_ = gpu(Vt_[j])
                for _ in 1:rounds
                    for b in batchwise(size(X[j], 1), batchsize)
                        x = gpu(X[j], b)
                        u = gpu(U[j], b)
                        u_ = gpu(U_[j], b)
                        ηu = reduce!(opt, x, u, v, u_)
                        alb!(u, κ * ηu, λ * t * ηu)
                        ηv = reduce!(opt, x', v', u', vt_)
                        alb!(v, κ * ηv, λ * t * ηv)
                        ℓ += norm(x - u * v)^2
                        copyto_batch!(U[j], b, u)
                        copyto_batch!(U_[j], b, u_)
                    end
                end
                copyto_batch!(V[j], v)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(V)
            alb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    end
    μ = cpu(μ)
    if with_rounding
        round_felb!(X, U, V, μ)
    end
    return U, V, μ
end

# In-place elementwise update of `x`, followed by clamping to [0, 1].
# Works like `elb!` but with a position-dependent weight
# `λ / (1 - exp(-10 * (z + eps(T)/2)))` in place of the constant one; the weight is
# evaluated at `z` for entries `<= 0.5` and at `1 - z` otherwise.
# `κ_` and `λ_` are converted to the element type of `x` first.
# Returns `x` (the result of `clamp!`).
function alb!(x, κ_, λ_)
    T = eltype(x)
    κ = T(κ_)
    λ = T(λ_)
    σ = T(10)
    c = eps(T) / 2                     # offset that keeps the denominator away from 0 at z == 0
    @inline weight(z) = λ / (1 - exp(-σ * (z + c)))
    lower(z) = (z - κ * sign(z)) / (1 + weight(z))
    function upper(z)
        w = weight(1 - z)
        return (z - κ * sign(z - 1) + w) / (1 + w)
    end
    @loop x[I]=(x[I] <= 0.5) ? lower(x[I]) : upper(x[I]) over I in eachindex(x)
    clamp!(x, zero(T), one(T))
end

"""
    falbmf_dp_ipalm(X, n_components; ϵ = ..., batchsize = 2^15, kws...)

Choose how to run the noise-perturbed `alb!`-based solver on the collection `X`.

- If `batchsize` is at least the total row count over all matrices, or `gpuenabled()`
  is false, every matrix is passed through `gpu` and `falbmf_dp_ipalm_impl` runs.
- Otherwise, if `batchsize` is at least the largest single row count,
  `falbmf_dp_ipalm_impl` runs on `X` as given.
- Otherwise `falbmf_dp_ipalm_batched` processes each matrix in row batches.

`ϵ` is forwarded to the chosen solver, which passes it to `dp`. Its default is a clipped
Laplacian mechanism with budget `1 / 0.0025`.
NOTE(review): the default sensitivity is `maximum(sum, eachrow(X))`, which treats `X` as
a single matrix, while the body iterates `X` as a collection of matrices -- confirm.
"""
function falbmf_dp_ipalm(X, n_components; ϵ = Mechanism(1 / 0.0025, :laplacian_clipped, maximum(sum, eachrow(X))),
                         batchsize = 2^15, kws...)
    everything_fits = batchsize >= sum(size(x, 1) for x in X)
    if everything_fits || !gpuenabled()
        return falbmf_dp_ipalm_impl(gpu.(X), n_components; ϵ = ϵ, kws...)
    end
    if batchsize >= maximum(size(x, 1) for x in X)
        return falbmf_dp_ipalm_impl(X, n_components; ϵ = ϵ, kws...)
    end
    return falbmf_dp_ipalm_batched(X, n_components; batchsize = batchsize, ϵ = ϵ, kws...)
end

"""
    falbmf_dp_ipalm_impl(X, n_components; ϵ = 0.01, kws...)

Noise-perturbed version of the `alb!`-based solver.

Factors are created by `felb_init(X, n_components)`. Each outer iteration performs
`rounds` alternating `reduce!`/`alb!` updates of `U[j]` and `V[j]` and adds
`norm(x - u * v)^2` to the loss. The shared factor is then the mean of `dp(ϵ, v)` over
all `V`, processed by `alb!` and copied back into every `V[j]`. The loop ends after
`maxiter` iterations or when two consecutive losses agree within `atol = tol`.

# Keywords
- `κ`, `λ`: passed to `alb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `ϵ`: passed as the first argument of `dp`. NOTE(review): the `dp` methods visible in
  this file accept mechanism structs only; the numeric default needs a method defined
  elsewhere.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function falbmf_dp_ipalm_impl(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                              maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10, ϵ = 0.01,
                              callback = nothing, with_rounding = true)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        opt = PALM()
        for iter in 1:maxiter
            t = tau(iter - 1)
            ℓ0 = ℓ
            ℓ = 0.0
            for j in eachindex(X)
                x, u, v = gpu(X[j]), gpu(U[j]), gpu(V[j])
                for _ in 1:rounds
                    ηu = reduce!(opt, x, u, v)
                    alb!(u, κ * ηu, λ * t * ηu)
                    # The transposed problem updates `v` in place through its adjoint.
                    ηv = reduce!(opt, x', v', u')
                    alb!(v, κ * ηv, λ * t * ηv)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
            end
            # Average noisy copies instead of the raw factors.
            μ = mean(dp(ϵ, v) for v in V)
            alb!(μ, κ, λ * t)
            for k in eachindex(V)
                copyto!(V[k], μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    else
        opt = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for iter in 1:maxiter
            ℓ0 = ℓ
            ℓ = 0.0
            t = tau(iter - 1)
            for j in eachindex(X)
                x, u, v = gpu(X[j]), gpu(U[j]), gpu(V[j])
                u_, vt_ = gpu(U_[j]), gpu(Vt_[j])
                for _ in 1:rounds
                    ηu = reduce!(opt, x, u, v, u_)
                    alb!(u, κ * ηu, λ * t * ηu)
                    ηv = reduce!(opt, x', v', u', vt_)
                    alb!(v, κ * ηv, λ * t * ηv)
                end
                ℓ += norm(x - u * v)^2
                copyto_batch!(U[j], u)
                copyto_batch!(V[j], v)
                copyto_batch!(U_[j], u_)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(dp(ϵ, v) for v in V)
            alb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            isnothing(callback) || callback(X, U, V, μ)
            if isapprox(ℓ, ℓ0; atol = tol)
                break
            end
        end
    end
    μ = cpu(μ)
    if with_rounding
        round_felb!(X, U, V, μ)
    end
    return U, V, μ
end

"""
    falbmf_dp_ipalm_batched(X, n_components; ϵ = 0.0, batchsize = 2^15, kws...)

Row-batched, noise-perturbed variant of the `alb!`-based solver.

For every matrix `X[j]`, `V[j]` is fetched once with `gpu`, while `X[j]` and `U[j]` are
fetched per row range `b` produced by `batchwise(size(X[j], 1), batchsize)`. Each batch
performs one alternating `reduce!`/`alb!` update and adds `norm(x - u * v)^2` to the
loss, so the loss is accumulated over all rounds and batches. The shared factor is the
mean of `dp(ϵ, v)` over all `V`, processed by `alb!` and copied back into every `V[j]`.

# Keywords
- `κ`, `λ`: passed to `alb!`, multiplied by the step size returned by `reduce!`.
- `tau`: schedule; iteration `i` multiplies `λ` by `tau(i - 1)`.
- `maxiter`, `tol`: stop after `maxiter` iterations or when consecutive losses agree
  within `atol = tol`.
- `beta`: `0` selects `PALM()`; any other value selects `iPALM(beta)`.
- `rounds`: number of passes over all batches per matrix and iteration.
- `ϵ`: passed as the first argument of `dp`. NOTE(review): the `dp` methods visible in
  this file accept mechanism structs only; the numeric default needs a method defined
  elsewhere.
- `callback`: unless `nothing`, called as `callback(X, U, V, μ)` after every iteration.
- `with_rounding`: when true, `round_felb!(X, U, V, μ)` is called before returning.

Returns `(U, V, μ)` with `μ` passed through `cpu`.
"""
function falbmf_dp_ipalm_batched(X, n_components; κ = 0.01, λ = 0.01, tau = t -> 1.005^t,
                                 maxiter = 100, tol = 1e-5, beta = 0.01, rounds = 10, ϵ = 0.0,
                                 callback = nothing, with_rounding = true, batchsize = 2^15)
    U, V, μ = felb_init(X, n_components)
    ℓ = typemax(tol)
    if beta == 0
        palm = PALM()
        for i in 1:maxiter
            t = tau(i - 1)
            ℓ, ℓ0 = 0.0, ℓ
            for j in eachindex(X)
                v = gpu(V[j])
                for _ in 1:rounds, b in batchwise(size(X[j], 1), batchsize)
                    x, u = gpu(X[j], b), gpu(U[j], b)
                    η = reduce!(palm, x, u, v)
                    alb!(u, κ * η, λ * t * η)
                    η = reduce!(palm, x', v', u')
                    alb!(v, κ * η, λ * t * η)
                    ℓ += norm(x - u * v)^2
                    # `u` holds only the rows in `b`, so it must be written back to that
                    # row range (as the iPALM branch does), not to the start of `U[j]`.
                    copyto_batch!(U[j], b, u)
                end
                copyto_batch!(V[j], v)
            end
            μ = mean(dp(ϵ, v) for v in V)
            alb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    else
        ipalm = iPALM(beta)
        U_ = deepcopy(U)
        Vt_ = permutedims.(V)
        for i in 1:maxiter
            ℓ, ℓ0 = 0.0, ℓ
            t = tau(i - 1)
            for j in eachindex(X)
                v, vt_ = gpu(V[j]), gpu(Vt_[j])
                for _ in 1:rounds, b in batchwise(size(X[j], 1), batchsize)
                    x, u, u_ = gpu(X[j], b), gpu(U[j], b), gpu(U_[j], b)
                    η = reduce!(ipalm, x, u, v, u_)
                    alb!(u, κ * η, λ * t * η)
                    η = reduce!(ipalm, x', v', u', vt_)
                    alb!(v, κ * η, λ * t * η)
                    ℓ += norm(x - u * v)^2
                    copyto_batch!(U[j], b, u)
                    copyto_batch!(U_[j], b, u_)
                end
                copyto_batch!(V[j], v)
                copyto_batch!(Vt_[j], vt_)
            end
            μ = mean(dp(ϵ, v) for v in V)
            alb!(μ, κ, λ * t)
            for v in V
                copyto!(v, μ)
            end
            (callback !== nothing) && callback(X, U, V, μ)
            isapprox(ℓ, ℓ0; atol = tol) && break
        end
    end
    μ = cpu(μ)
    with_rounding && round_felb!(X, U, V, μ)
    return U, V, μ
end