module RobustDistributedPCA

import Arpack
import LinearAlgebra
import StatsBase

import LinearMaps: LinearMap

# Short module-level aliases for the StatsBase names used below.
const sample = StatsBase.sample            # draw an index according to weights
const mean = StatsBase.mean
const Weights = StatsBase.AnalyticWeights  # weight type used for samples

# Operations with the empirical covariance for matrix-valued inputs.
include("emp_cov_ops.jl")

# Procrustes operations
include("procrustes.jl")

"""
    scattermat(X::AbstractArray{Float64, 3}, wts::AbstractWeights)

Compute the weighted scatter matrix `∑ᵢ wts[i] * X[:, i, :] * X[:, i, :]'` of a
tensor `X` holding matrix-valued samples across its second dimension.

Unlike `StatsBase.scattermat` for matrices, no mean is subtracted: `X` is used
as given, so pass already-centered samples when a covariance-type quantity is
wanted. The result is a `size(X, 1) × size(X, 1)` matrix.

Throws a `DimensionMismatch` if `length(wts) != size(X, 2)`.
"""
function StatsBase.scattermat(
  X::AbstractArray{Float64, 3},
  wts::StatsBase.AbstractWeights,
)
  # NOTE(review): this adds a method to `StatsBase.scattermat` for argument
  # types this module does not own (type piracy); consider a module-local name.
  d, n, _ = size(X)
  # Validate up front: every sample must have exactly one weight.
  length(wts) == n || throw(DimensionMismatch(
    "length(wts) = $(length(wts)) does not match size(X, 2) = $(n)",
  ))
  result = zeros(d, d)
  # `mul!(C, A, B, α, β)` computes `C = α*A*B + β*C` in place; `@views` avoids
  # copying each `d × r` slice on every iteration.
  @views for i in 1:n
    LinearAlgebra.mul!(result, X[:, i, :], X[:, i, :]', wts[i], true)
  end
  return result
end

"""
    remove_outlier(outlier_scores::AbstractVector{<:Real}, randomized::Bool = true)

Return the index of a candidate outlier among points whose outlier scores are
`outlier_scores`.

If `randomized` is set, the index is drawn at random with probability
proportional to each score. Otherwise, the index of the largest score is
returned.

# Throws
- `ArgumentError` if `outlier_scores` is empty.
- `ArgumentError` if `randomized` is set and the scores do not have a positive,
  finite sum, because the sampling probabilities would then be undefined.
"""
function remove_outlier(outlier_scores::AbstractVector{<:Real}, randomized::Bool = true)
  isempty(outlier_scores) && throw(ArgumentError("outlier_scores must be non-empty"))
  randomized || return argmax(outlier_scores)
  total = sum(outlier_scores)
  # A zero or NaN total would yield NaN probabilities, and `sample` would then
  # silently return an arbitrary index.
  (isfinite(total) && total > 0) || throw(ArgumentError(
    "outlier scores must have a positive, finite sum to sample from; got $(total)",
  ))
  return sample(Weights(outlier_scores ./ total))
end

"""
    filter(samples::Array{Float64, 3}, τ::Float64, randomized::Bool = true)

Estimate the robust mean of the matrix-valued `samples` with the filtering
algorithm.

Sample `i` is the slice `samples[:, i, :]`. `τ` is taken to be an upper bound
on the operator norm of the inliers' empirical covariance. Candidate outliers
are removed one at a time until the top eigenvalue of the weighted empirical
covariance is at most `τ`, or only one sample remains. Each removed candidate
is chosen either at random, in proportion to its outlier score (`randomized`
set), or as the highest-scoring point.
"""
function filter(samples::Array{Float64, 3}, τ::Float64, randomized::Bool = true)
  n_samples = size(samples, 2)
  # Every sample starts with unit weight; removed points get their weight zeroed.
  initial_weights = Weights(fill(1, n_samples))
  return filter_aux(samples, initial_weights, τ, randomized)
end

"""
    filter_adaptive(samples::Array{Float64, 3}, τ::Float64, error_fun, randomized::Bool = true; τ_min = 1e-6)

Run `filter` over a geometrically decreasing sequence of thresholds. This
computes the robust mean of `samples` without knowing a tight value of `τ` in
advance.

The tensor `samples` contains individual samples across its 2nd dimension.

# Procedure
- Start from the initial guess `τ` and halve the threshold after each round,
  as long as it stays above `τ_min`.
- Round `j` produces the estimate `xⱼ = filter(samples, τⱼ, randomized)` and
  the error radius `error_fun(τⱼ)`.
- Stop as soon as the current estimate is farther (measured by `dist`) than
  `error_fun(τᵢ) + error_fun(τⱼ)` from some earlier estimate `xᵢ`, and return
  the estimate from the previous round.
- If no such conflict occurs, return the estimate for the smallest threshold
  tried.

# Keywords
- `τ_min = 1e-6`: thresholds at or below this value are not tried.

# Throws
- `ArgumentError` if `τ ≤ τ_min`, since then no threshold would be tried.
"""
function filter_adaptive(
  samples::Array{Float64, 3},
  τ::Float64,
  error_fun,
  randomized::Bool = true;
  τ_min::Real = 1e-6,
)
  # Without this check the loop below would not run and `estimates[end]` would
  # throw an opaque BoundsError.
  τ > τ_min || throw(ArgumentError("initial τ = $(τ) must be larger than τ_min = $(τ_min)"))
  estimates = Any[]
  while τ > τ_min
    @debug "Trying τ = $(τ)"
    curr_est = filter(samples, τ, randomized)
    curr_err = error_fun(τ)
    # Continue decreasing τ as long as d(xᵢ, xⱼ) ≤ f(τᵢ) + f(τⱼ) for every
    # earlier estimate xᵢ.
    if any(prev -> dist(prev.est, curr_est) > prev.err + curr_err, estimates)
      # The last accepted estimate was computed with the previous τ.
      @debug "Terminating with τ_opt = $(2 * τ)"
      return estimates[end].est
    end
    push!(estimates, (est = curr_est, err = curr_err))
    τ /= 2
  end
  return estimates[end].est
end

# Run the filter algorithm on a set of matrix-valued samples.
#
# `samples[:, i, :]` is the i-th sample and `weights[i]` its current weight.
# Each round computes the weighted empirical mean and covariance. If the top
# covariance eigenvalue exceeds `τ`, one candidate outlier (chosen by
# `remove_outlier`) gets its weight set to 0 and the round repeats.
# Returns the weighted mean of the remaining samples as a matrix shaped like a
# single sample.
#
# NOTE: `weights` is modified in place (removed points are zeroed); pass a copy
# if the caller still needs the original weights.
# Throws `DimensionMismatch` if `length(weights) != size(samples, 2)`.
function filter_aux(
  samples::Array{Float64, 3},
  weights::Weights,
  τ::Float64,
  randomized::Bool = true,
)
  _, n, _ = size(samples)
  n == length(weights) || throw(DimensionMismatch(
    "length(weights) is not equal to second tensor dimension.",
  ))
  # Iterate instead of recursing: one pass per removed point, so a recursive
  # formulation grows the call stack linearly with the number of removals.
  while true
    emp_mean = StatsBase.mean(samples, weights, dims=2)
    if n == 1 || sum(weights) == 1
      @debug "Only 1 sample left - terminating."
      return emp_mean[:, 1, :]
    end
    centered = samples .- emp_mean
    cov_matr = (1 / sum(weights)) * StatsBase.scattermat(centered, weights)
    # `Symmetric` forces the real-symmetric eigensolver even if rounding left
    # `cov_matr` not bit-for-bit symmetric; the general path would return
    # complex eigenvalues, which cannot be compared with `τ`.
    λ, v = _top_eigpair(LinearAlgebra.Symmetric(cov_matr))
    if λ ≤ τ
      @debug "Top eigenvalue: $(λ) - upper bound: $(τ)."
      @debug "Terminating with n = $(sum(weights)) samples remaining."
      return emp_mean[:, 1, :]
    end
    outlier_scores = weights .* _emp_cov_scores(centered, v)
    outlier_idx = remove_outlier(outlier_scores, randomized)
    @debug "Removing candidate outlier at idx = $(outlier_idx)"
    weights[outlier_idx] = 0
  end
end

# Return the largest eigenvalue of the symmetric matrix `S` and an eigenvector
# for it.
function _top_eigpair(S::LinearAlgebra.Symmetric)
  d = size(S, 1)
  if d <= 2
    # ARPACK requires nev < size(S, 1), so it rejects 1×1 inputs outright; for
    # such tiny problems a dense solve for just the top eigenpair is exact and cheap.
    F = LinearAlgebra.eigen(S, d:d)
    return F.values[1], F.vectors[:, 1]
  end
  λ, v, _ = Arpack.eigs(S, nev=1, which=:LR)
  return λ[1], v[:, 1]
end

end # module
