module GloptiNets

using TestItems

using LinearAlgebra
import Random
using Flux.NNlib: batched_mul
using CUDA

# For optimisation of PolyCheby
using Optim
using ProgressMeter

#= 
  ┌──────────────────────────────────────────────────────────────────────────┐
  │ Sampling with the Bessel function                                        │
  └──────────────────────────────────────────────────────────────────────────┘
 =#

export ApproxBesselSampler, ℕ, ℤ, samplesprobas

# Tag types selecting the frequency domain of a sampler. They carry no data and are only
# used as type parameters / dispatch arguments (see `_ι`, `_c`, `_flipsign`, `_grid`).
abstract type Domain end
# Non-negative frequencies (`_grid(ℕ, n)` is `0:n`).
struct ℕ <: Domain end
# Signed frequencies (`_grid(ℤ, n)` is `-n:n`).
struct ℤ <: Domain end

"""
Approximation of a sampler using the Bessel distribution as pdf. Instead of using
```
p(ω) = I(ω, s), ω ∈ (ℕ or ℤ),
```
where ``I`` is the modified Bessel function of the first kind, it approximates `p` on a _finite_ support. It is defined on `ℤ` by symmetry around `0`. 
"""
struct ApproxBesselSampler{T<:AbstractFloat,Domain}
    # NOTE(review): in the header above, `Domain` is an *unbounded* type variable that
    # shares its name with the abstract type `Domain`; the struct itself does not enforce
    # `<: Domain`. The outer constructor only ever passes `D::Type{<:Domain}`.

    # Scale parameters; the outer constructor builds `weights[i]` from `scale[i]`.
    scale::Vector{T}
    # `weights[i][k + 1]` is the weight attached to `|ω| = k` along dimension `i`
    # (this is how `pdf` and `samplesprobas` index it).
    weights::Vector{Vector{T}}
end

# Per-dimension support: number of stored weights minus one, i.e. the largest `|ω|`
# that has a weight along each dimension.
function support(proba::ApproxBesselSampler)
    return map(w -> length(w) - 1, proba.weights)
end

# Number of dimensions, i.e. number of per-dimension weight vectors.
function dim(proba::ApproxBesselSampler)
    return length(proba.weights)
end

# Domain tag carried as the second type parameter of the sampler.
function domain(::ApproxBesselSampler{T,D}) where {T,D}
    return D
end

using Bessels: besselix, besseli0x, besseli1x
"""
Builds an `ApproxBesselSampler` given a list of scale. 

# Arguments
- `maxsupport::Integer`: the maximum support of the probability. For ``|ω| > maxsupport``, will return ``p(ω) = 0``.
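- `tol::AbstractFloat`: the proportion of the probability mass that may be left out. The support is truncated as soon as the accumulated mass exceeds ``1 - tol`` (or `maxsupport` is reached); the kept weights are then renormalized so that they sum to 1.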
"""
function ApproxBesselSampler(D::Type{<:Domain}, scale::AbstractVector{T};
    maxsupport=1000,
    tol=1e-4) where {T}

    weights = Vector{T}[]
    for s ∈ scale
        # Weights of |ω| = 0 and |ω| = 1; every |ω| ≥ 1 carries a factor 2.
        w = [besseli0x(s), 2besseli1x(s)]
        mass = w[1] + w[2]
        # Extend the support until enough mass is covered or `maxsupport` is reached.
        for ω ∈ 2:maxsupport
            contribution = 2besselix(ω, s)
            push!(w, contribution)
            mass += contribution
            mass > 1 - tol && break
        end
        # TODO check that the renormalization is motivated theoretically
        w ./= sum(w)  # the truncated weights are rescaled to sum to 1

        push!(weights, w)
    end

    return ApproxBesselSampler{T,D}(scale, weights)
end

# Divisor applied to a stored weight to get the probability of one frequency:
# always 1 on ℕ; on ℤ it is 2 for ω ≠ 0 and 1 for ω = 0.
_ι(::Type{ℕ}, ω) = 1
_ι(::Type{ℤ}, ω) = ω == 0 ? 1 : 2

# Domain-dependent constant: 1 on ℕ, 2 on ℤ.
_c(::Type{ℕ}) = 1
_c(::Type{ℤ}) = 2

# Sign multipliers of type `T`: the scalar `one(T)` on ℕ, a length-`n` vector of random ±1 on ℤ.
_flipsign(::Type{ℕ}, T, n) = one(T)
function _flipsign(::Type{ℤ}, T, n)
    coin = rand(Bool, n)
    return coin .* 2one(T) .- one(T)
end

using Distributions: Categorical
# Draw `nfreqs` frequency vectors (columns of the returned `dim(proba) × nfreqs` matrix of
# element type `U`) together with the probability of each draw (vector of `T`).
function samplesprobas(proba::ApproxBesselSampler{T,D}, nfreqs, ::Type{U}) where {T,D,U}
    freqs = zeros(U, dim(proba), nfreqs)
    probs = ones(T, nfreqs)

    draws = zeros(U, nfreqs)
    for (i, w) ∈ enumerate(proba.weights)
        # Categorical draws are 1-based indices into `w`.
        copyto!(draws, rand(Categorical(w), nfreqs))
        # Index k corresponds to the magnitude |ω| = k - 1.
        magnitudes = draws .- 1
        # On ℤ a non-zero magnitude is shared by two signed frequencies, hence the division by `_ι`.
        probs .*= w[draws] ./ _ι.(D, magnitudes)
        # Attach a random sign on ℤ (no-op on ℕ).
        freqs[i, :] .= magnitudes .* _flipsign(D, T, nfreqs)
    end

    return freqs, probs
end
samplesprobas(proba::ApproxBesselSampler, nfreqs) = samplesprobas(proba, nfreqs, Int)

"""
    samplesprobas_bycat(proba, nfreqs[, U=Int])

Draw `nfreqs` frequency vectors from `proba` and group identical draws.

Returns `(ωs, ns, ps)`: column `j` of the `dim(proba) × k` matrix `ωs` is a distinct frequency
vector that was drawn, `ns[j]` is how many times it was drawn and `ps[j]` is its probability
under `proba`. The column order is unspecified, but the three outputs are always aligned.
With `nfreqs == 0` the outputs are empty.
"""
function samplesprobas_bycat(proba::ApproxBesselSampler{T,D}, nfreqs, ::Type{U}) where {T,D,U}
    d = dim(proba)
    # A single dictionary holds both the count and the probability of each distinct draw,
    # so the outputs cannot get out of sync.
    stats = Dict{Vector{U},Tuple{U,T}}()
    for _ ∈ 1:nfreqs
        # One categorical draw per dimension (1-based, hence the `- 1`), with a random sign on ℤ.
        ω = convert(Vector{U},
            ((rand(Categorical(wᵢ)) for wᵢ ∈ proba.weights) .- 1) .* _flipsign(D, U, d))
        entry = get(stats, ω, nothing)
        if entry === nothing
            p = prod(proba.weights[i][abs(ω[i])+1] / _ι(D, abs(ω[i])) for i ∈ eachindex(ω))
            stats[ω] = (one(U), p)
        else
            stats[ω] = (entry[1] + one(U), entry[2])
        end
    end

    # Build the outputs in one pass; `ωs` is always a matrix, even with 0 or 1 distinct draws.
    ωs = Matrix{U}(undef, d, length(stats))
    ns = Vector{U}(undef, length(stats))
    ps = Vector{T}(undef, length(stats))
    for (j, (ω, (n, p))) ∈ enumerate(stats)
        ωs[:, j] .= ω
        ns[j] = n
        ps[j] = p
    end
    return ωs, ns, ps
end
samplesprobas_bycat(proba::ApproxBesselSampler, nfreqs) = samplesprobas_bycat(proba, nfreqs, Int)

"""
    _get_grid(U, proba)

Enumerate the grid of magnitudes `0:support(proba)[i]` along every dimension `i` and return
`(ωs, ps)`: the grid points (as columns, element type `U`) whose product weight is at least
`eps(T)`, and their weights renormalized to sum to 1.

The frequency `k` along dimension `i` is paired with `proba.weights[i][k + 1]`, the same
indexing convention as `pdf` and `samplesprobas`. Signs are not enumerated.
"""
function _get_grid(::Type{U}, proba::ApproxBesselSampler{T,D}) where {T,D,U}  # TODO: not type stable; make dimension of proba a value type
    d = dim(proba)
    maxfreq = support(proba)
    npoints = prod(maxfreq .+ 1)
    ωs = zeros(U, d, npoints)
    ps = ones(T, npoints)

    nkept = 0
    for ω ∈ Iterators.product((0:maxfreq[i] for i ∈ 1:d)...)
        # Weights are stored 1-based: frequency k lives at index k + 1.
        p = prod(proba.weights[i][ω[i]+1] for i ∈ 1:d)
        p < eps(T) && continue
        nkept += 1
        ωs[:, nkept] .= ω
        ps[nkept] = p
    end
    # Copy the used prefix so that the oversized buffers can be freed.
    ωs = ωs[:, 1:nkept]
    ps = ps[1:nkept]
    ps ./= sum(ps)

    return ωs, ps
end

# Grid-based variant of `samplesprobas_bycat` (ℕ only): draws `nfreqs` grid points from the
# categorical distribution returned by `_get_grid` and returns the distinct points drawn,
# their counts and their grid probabilities.
function samplesprobas_bycat_wgrid(proba::ApproxBesselSampler{T,ℕ}, nfreqs, ::Type{U}) where {T,U}
    ωs, ps = _get_grid(U, proba)
    draws = rand(Categorical(ps), nfreqs)

    # Tally how many times each grid index was drawn.
    tally = Dict{eltype(draws),eltype(draws)}()
    for r ∈ draws
        tally[r] = get(tally, r, 0) + 1
    end
    ind = collect(keys(tally))
    ns = collect(values(tally))

    return ωs[:, ind], ns, ps[ind]
end

"""
For each `ω ∈ ωs`, computes the probability of obtaining `ω` with `proba`.
"""
function pdf(proba::ApproxBesselSampler{T,D}, ωs::AbstractMatrix{U}) where {T,D,U<:Integer}
    @assert dim(proba) == size(ωs, 1)
    ps = ones(T, size(ωs, 2))

    # The probability of a column is the product of its per-dimension probabilities.
    for i ∈ axes(ωs, 1), j ∈ axes(ωs, 2)
        w = proba.weights[i]
        k = abs(ωs[i, j])
        # Magnitudes beyond the stored support have probability zero.
        ps[j] *= k < length(w) ? w[k+1] / _ι(D, k) : zero(T)
    end

    return ps
end

"Distribute `n` samples uniformly across the size of `out`. Perhaps a better way to do that, but this is good enough for now."
function _rand_distribute!(out, n)
    slots = 1:length(out)
    for _ ∈ 1:n
        # Each of the `n` units goes to a uniformly chosen position of `out`.
        k = rand(slots)
        out[k] += 1
    end
    return nothing
end

"
Median-of-Mean estimator for weighted samples. `y` are the unique values, and `ns` their respective counts.
"
function mom(y, ns, nbatch, batchsize)
    # `y[i]` is the i-th unique value and `ns[i]` how many samples took that value.
    # Returns the median over `nbatch` batches of the mean of `batchsize` samples each.
    N = size(ns, 1)
    @assert size(y, 1) == N
    @assert nbatch * batchsize == sum(ns) "Requires `nbatch * batchsize == sum(ns)` for now"

    # Expand the counts into one entry per sample (the index of its unique value) ...
    owner = Vector{Int}(undef, nbatch * batchsize)
    pos = 0
    for i ∈ 1:N, _ ∈ 1:ns[i]
        pos += 1
        owner[pos] = i
    end
    # ... and shuffle, so that consecutive chunks form random batches of *equal* size.
    Random.shuffle!(owner)

    # Mean of each batch: the sum of its samples divided by `batchsize`.
    batchmeans = map(1:nbatch) do b
        members = view(owner, (b-1)*batchsize+1:b*batchsize)
        mean(view(y, members))
    end

    return median(batchmeans)
end

include("psdmodels.jl")
include("polynomials.jl")

#= 
  ┌──────────────────────────────────────────────────────────────────────────┐
  │ Train functions and other utilities                                      │
  └──────────────────────────────────────────────────────────────────────────┘
 =#

export Hnorm2_bound

# Post-processing applied to a value computed for `f`: identity for `PolyCheby`,
# real part for `PolyTrigo`.
function _realfunc(::PolyCheby, x)
    return x
end
function _realfunc(::PolyTrigo, x)
    return real(x)
end

# Multiplicative constant attached to the non-constant coefficients: 1 for `PolyCheby`,
# 2 for `PolyTrigo`, expressed in the polynomial's element type.
function _coeff(::PolyCheby{T}) where {T}
    return one(T)
end
function _coeff(::PolyTrigo{T}) where {T}
    return 2one(T)
end

"""
Given `f` a Chebyshev or trigonometric polynomial, `g` a Bessel model and `proba` a probability distribution, computes the dot product
```
    ⟨f, f - 2g⟩
```
in the RKHS induced by `proba`. This is useful to compute ``‖f - g‖²`` as 
```
‖f - g‖² = ⟨f, f - 2g⟩ + ‖g‖².
```
"""
function dotproduct_bound(f, g, proba) end
function dotproduct_bound(f::AbstractPoly{T,U,D}, g::AbstractBlockPSDModel{T,D}, proba::ApproxBesselSampler{T,D}) where {T,U,D}
    # Frequencies of `f`, with the zero frequency prepended as first column.
    ghat = _basis_proj(g, hcat(zeros(U, dim(f)), frequencies(f)))
    ps = pdf(proba, hcat(zeros(U, dim(f)), frequencies(f)))

    # Contribution of the constant term (first column).
    c₀ = offset(f)
    dc_part = _realfunc(f, c₀ * (c₀ - 2ghat[1]) / ps[1])

    # Contribution of the remaining frequencies; the stored coefficients are rescaled
    # by 1 / _coeff(f) before being compared with the coefficients of `g`.
    κ = _coeff(f)
    scaled_conj = (1 / κ) * conj(coefficients(f))
    scaled = (1 / κ) * coefficients(f)
    ac_terms = @views scaled_conj .* (scaled .- 2ghat[2:end]) ./ ps[2:end]
    ac_part = κ * _realfunc(f, sum(ac_terms))

    return dc_part + ac_part
end

"""
A bound on the variance of the estimator defined with
```
X = |û(ω)| / p(ω), u = f - g.
```
The variance is upper bounded by ``‖u‖²`` in the norm induced by the probability ``p``. 
"""
function Hnorm2_bound(f, g, proba)
    # ‖f - g‖² bound assembled as ⟨f, f - 2g⟩ plus the `HSnorm2` term for `g`.
    cross = dotproduct_bound(f, g, proba)
    return cross + HSnorm2(g)
end

"""
Hilbert norm in the RKHS associated to `proba` for a polynomial `f`.
"""
function Hnorm2(f::AbstractPoly{T,U,D}, proba::ApproxBesselSampler{T,D}) where {T,U,D}
    # Probabilities of the zero frequency (first entry) followed by the frequencies of `f`.
    ωs = hcat(zeros(U, dim(f)), frequencies(f))
    ps = pdf(proba, ωs)

    dc = abs2(offset(f)) / ps[1]
    ac = _coeff(f) * sum(abs2.(coefficients(f)) ./ view(ps, 2:lastindex(ps)))
    return dc + ac
end


_grid(::Type{ℕ}, n) = 0:n
_grid(::Type{ℤ}, n) = -n:n
"""
Approximation of the F-norm and the H̄² norm for `f`, `g`, and `f - g`, by computing their coefficients on a grid of size `n`.

The grid is `_grid(D, n)` along each of the `dim(f)` dimensions. Returns a named tuple with
fields `diff_fnorm`, `diff_hnorm`, `f_fnorm`, `f_hnorm`, `g_fnorm`, `g_hnorm` and
`dotproduct` (the sum of `conj(f̂) * ĝ / p` over the grid).
"""
function norms_numapprox(f::AbstractPoly{T,U,D}, g::AbstractBlockPSDModel{T,D}, proba::ApproxBesselSampler{T,D}, n::Int) where {T,U,D}
    points = Iterators.product(Iterators.repeated(_grid(D, n), dim(f))...)
    # `vec` flattens the N-d array of grid points so that `reduce(hcat, ...)` hits its
    # efficient vector method and always yields a matrix (even for a single grid point).
    ωs = reduce(hcat, vec(map(collect, points)))
    ps = pdf(proba, ωs)
    fhat = _basis_proj(f, ωs)
    ghat = _basis_proj(g, ωs)

    diff_fnorm = sum(abs.(fhat .- ghat))
    diff_hnorm = √sum(abs2.(fhat .- ghat) ./ ps)
    f_fnorm = sum(abs.(fhat))
    f_hnorm = √sum(abs2.(fhat) ./ ps)
    g_fnorm = sum(abs.(ghat))
    g_hnorm = √sum(abs2.(ghat) ./ ps)
    dotproduct = sum(conj(fhat) .* ghat ./ ps)
    @debug "norms_numapprox" dotproduct

    return (;
        diff_fnorm=diff_fnorm,
        diff_hnorm=diff_hnorm,
        f_fnorm=f_fnorm,
        f_hnorm=f_hnorm,
        g_fnorm=g_fnorm,
        g_hnorm=g_hnorm,
        dotproduct=dotproduct
    )
end

using Statistics: mean, median
# Coefficients of `f` at the frequencies `ωs`, dispatching on the domain tag:
# `cheby` for ℕ-typed polynomials/models, `fourier` for ℤ-typed ones.
function _basis_proj(f::Union{AbstractPoly{T,U,ℕ},AbstractBlockPSDModel{T,ℕ}}, ωs) where {T,U}
    return cheby(f, ωs)
end
function _basis_proj(f::Union{AbstractPoly{T,U,ℤ},AbstractBlockPSDModel{T,ℤ}}, ωs) where {T,U}
    return fourier(f, ωs)
end
"""
Approximate the F-norm with a median of mean. Combined with the variance, this allows a bound with confidence `1 - δ` on the mean of the random variable being estimated.

# Parameters
- `nsamples`: number of frequencies sampled to compute the MoM
- `nbatch`: number of batches over which the median is taken. With `nbatch=32`, the confidence is `delta < 0.03`. Having `rem(nsamples, nbatch) = 0` enables taking exactly `nsamples`. 
"""
function fnorm_approx(f::AbstractPoly{T,U,D}, g::AbstractBlockPSDModel, proba::ApproxBesselSampler{T,D}; nsamples=1024, nbatch=32) where {T,U,D<:Domain}
    est = mom_estimator(f, g, proba; nsamples=nsamples, nbatch=nbatch)
    vals_mom, batchsize, vals_mean = est.vals_mom, est.batchsize, est.vals_mean

    # Standard-deviation bound and the confidence level attached to `nbatch` batches.
    σ = sqrt(Hnorm2_bound(f, g, proba))
    confidence = 1 - exp(-nbatch / 8)

    # Bound built on the median-of-means value, and one built on the plain mean.
    bound_mom = vals_mom + 2σ / sqrt(batchsize)
    bound_cheby = vals_mean + σ / sqrt(nbatch * batchsize * (1 - confidence))

    return (; bound_mom, bound_cheby, σ, confidence, vals_mom, vals_mean)
end

# Sampling front-end: draws `nbatch * div(nsamples, nbatch)` frequencies from `proba`
# (grouped by distinct value) and forwards them to the method below.
function mom_estimator(f::AbstractPoly{T,U,D}, g::AbstractBlockPSDModel, proba::ApproxBesselSampler{T,D}; nsamples=1024, nbatch=32) where {T,U,D<:Domain}
    batchsize = nsamples ÷ nbatch
    ωs, ns, ps = samplesprobas_bycat(proba, nbatch * batchsize)

    return mom_estimator(f, g, ωs, ns, ps; nbatch=nbatch, batchsize=batchsize)
end

# Estimator on pre-drawn frequencies `ωs` with counts `ns` and probabilities `ps`.
# `vals_mom` is currently a placeholder (`Inf`); only the weighted mean is computed.
function mom_estimator(f::AbstractPoly{T,U,D}, g::AbstractBlockPSDModel, ωs, ns, ps; nbatch, batchsize) where {T,U,D<:Domain}
    residual = abs.(_basis_proj(f, ωs) .- _basis_proj(g, ωs))
    ratio = residual ./ ps
    # TODO mom(ratio, ns, nbatch, batchsize)
    weighted_mean = sum(ratio .* ns) / sum(ns)
    return (; vals_mom=Inf, batchsize=batchsize, vals_mean=weighted_mean)
end

using Flux: Flux
using ParameterSchedulers
export get_optimizer, interpolate, NoReg, RegHSNormU, RegHSNormP, RegOrthNorm

"Gives an optimizer which can be passed to Flux's optimise. Allows for writing the optimizer to a config file."
function get_optimizer(type::Symbol, lrdecay::Symbol, lrinit::AbstractFloat, nepochs::Integer)
    # `type`    ∈ (:momentum, :descent, :adam) selects the Flux optimiser.
    # `lrdecay` ∈ (:poly, :cos, :constant) selects the learning-rate schedule over `nepochs`.
    # The optimiser is created with `lrinit` as its learning rate so that `:constant`
    # really uses `lrinit` (instead of Flux's per-optimiser default); with a schedule the
    # optimiser is wrapped in a `ParameterSchedulers.Scheduler` built from `lrinit`.
    opt = if type == :momentum
        Flux.Optimise.Momentum(lrinit)
    elseif type == :descent
        Flux.Optimise.Descent(lrinit)
    elseif type == :adam
        Flux.Optimise.Adam(lrinit)
    else
        error("Unknown optimizer: $(type)")
    end

    if lrdecay == :poly
        return ParameterSchedulers.Scheduler(Poly(lrinit, 1, nepochs), opt)
    elseif lrdecay == :cos
        return ParameterSchedulers.Scheduler(CosAnneal(; λ0=lrinit, λ1=0.0, period=nepochs), opt)
    elseif lrdecay == :constant
        return opt
    else
        error("Unknown scheduler: $(lrdecay)")
    end
end

"Checks that no items in a collection (usually Flux' Params objects) have NaN in them."
function hasnan(grads)
    # Returns `true` as soon as one numeric array among `values(grads)` contains a NaN.
    # Any `AbstractArray` of numbers is inspected (not only `Array`), so GPU arrays and
    # views are checked too; non-array entries such as `nothing` are skipped.
    return any(v -> v isa AbstractArray{<:Number} && any(isnan, v), values(grads))
end

# Fill `X` in place with random points for the model `g` and return `X`.
# Chebyshev models: uniform draws `u` are mapped to `acos(2u - 1) / 2π`.
function _samples!(X, g::PSDBlockBesselCheby)
    Random.rand!(X)
    X .= acos.(X .* 2.0 .- 1.0) ./ 2π
    return X
end
# Fourier models: plain uniform draws from `Random.rand!`.
function _samples!(X, g::PSDBlockBesselFourier)
    return Random.rand!(X)
end

# In-place evaluation of a polynomial at the columns of `X`, written into `r`:
# Chebyshev polynomials go through `evaluate_cos!`, trigonometric ones through `evaluate!`.
function _evaluate!(r, h::PolyCheby, X)
    return evaluate_cos!(r, h, X)
end
function _evaluate!(r, h::PolyTrigo, X)
    return evaluate!(r, h, X)
end

# Allocating evaluation of a polynomial or PSD model at `X`:
# Chebyshev objects go through `evaluate_cos`, Fourier/trigonometric ones are called directly.
function _evaluate(h::Union{PolyCheby,PSDBlockBesselCheby}, X)
    return evaluate_cos(h, X)
end
function _evaluate(h::Union{PolyTrigo,PSDBlockBesselFourier}, X)
    return h(X)
end

# Regularizer interface used by `interpolate`:
#   - `init(RegType, g, params)` builds the regularizer once before training,
#   - `update!(reg, g)` is called at the start of every epoch,
#   - `loss(reg, g)` is added to the data-fitting loss inside the gradient computation.
abstract type AbstractReg end
function init end
function update! end
function loss end

# No regularization: nothing to update, and the penalty is a zero of `eltype(g)`.
struct NoReg <: AbstractReg end
init(::Type{NoReg}, g, params) = NoReg()
update!(reg::NoReg, g) = nothing
loss(reg::NoReg, g) = zero(eltype(g))

# Penalty `val * √HSnorm2_upper(g)`; `params` must provide a `val` field.
struct RegHSNormU{T} <: AbstractReg
    val::T
end
init(::Type{RegHSNormU}, g, params) = RegHSNormU(params.val)
update!(reg::RegHSNormU, g) = nothing
loss(reg::RegHSNormU, g) = reg.val * √HSnorm2_upper(g)

# Penalty `val * √HSnorm2_proxy(g)`; `params` must provide a `val` field.
struct RegHSNormP{T} <: AbstractReg
    val::T
end
init(::Type{RegHSNormP}, g, params) = RegHSNormP(params.val)
update!(reg::RegHSNormP, g) = nothing
loss(reg::RegHSNormP, g) = reg.val * √HSnorm2_proxy(g)

# Regularizer holding a set of sample points and a factorization of their kernel matrix.
struct RegOrthNorm{T,M<:AbstractArray{T,2},F<:Factorization{T}} <: AbstractReg
    # Sample points, one per column (`init` allocates it as `dim(g) × nsamples`).
    Z::M
    # Cholesky factorization of the elementwise-squared kernel matrix of `Z`, shifted by
    # `regularization(g) * I` (see `init` / `update!`).
    C::F
    # Multiplier applied to the penalty in `loss`.
    val::T
end

# Kernel matrix between the columns of `x` and `y`, using the kernel matching the model type
# and the model's `variances(g)` as `γ`.
_kernel_func(x, y, g::PSDBlockBesselCheby) = besselkernel_cheby_cos(x, y; γ=variances(g))
_kernel_func(x, y, g::PSDBlockBesselFourier) = besselkernel_fourier(x, y; γ=variances(g))
# Build a `RegOrthNorm` for `g`. `params` must provide `nsamples` (number of sample points)
# and `val` (penalty multiplier). The sample buffer lives on the GPU when `isgpu(g)`.
function init(::Type{RegOrthNorm}, g::AbstractBlockPSDModel{T}, params) where {T}
    (; nsamples, val) = params
    Z = Array{T}(undef, dim(g), nsamples)
    isgpu(g) && (Z = CuArray(Z))
    # Fill `Z` with actual samples before it is read: factorizing the kernel matrix of an
    # `undef` buffer would operate on arbitrary memory (possibly NaN/Inf) and can make
    # `cholesky` throw or return garbage.
    _samples!(Z, g)
    C = cholesky(_kernel_func(Z, Z, g) .^ 2 + regularization(g) * I)
    return RegOrthNorm(Z, C, val)
end

# Redraw the sample points and overwrite the stored upper Cholesky factor accordingly.
function update!(reg::RegOrthNorm, g)
    _samples!(reg.Z, g)
    gram = _kernel_func(reg.Z, reg.Z, g) .^ 2 + regularization(g) * I
    reg.C.U .= cholesky(gram).U
    return nothing
end

# Penalty: `val` times the square root of `HSnorm2_upper(g)` minus the squared norm of the
# model values at `Z` after a triangular solve with the lower Cholesky factor.
function loss(reg::RegOrthNorm, g)
    upper = HSnorm2_upper(g)
    projected = reg.C.L \ _evaluate(g, reg.Z)
    return reg.val * √(upper - sum(abs2.(projected)))
end

"Interpolate a *positive* polynomial."
function interpolate(f::AbstractPoly{T,U,D}, g::AbstractBlockPSDModel{T,D}, reg_type, reg_params;
    optimizer_params, nepochs, batchsize,
    lossfunc_symb=:mse,
    lossfunc_param=1.0,
    show_progress=true
) where {T<:AbstractFloat,U,D<:Domain}
    # Fits the model `g` to the polynomial `f` by stochastic optimisation: every epoch draws
    # `batchsize` fresh points, evaluates `f` there and takes one optimiser step on the
    # trainable parameters of `g`. The parameters are updated in place; nothing is returned.
    #
    # - `reg_type`, `reg_params`: regularizer type and its parameters, passed to `init`.
    # - `optimizer_params`: must provide `optimizer_type`, `optimizer_lrdecay`, `optimizer_lrinit`.
    # - `lossfunc_symb`: `:mse` selects `Flux.mse`; any other value selects the
    #   log-sum-exp loss below, scaled by `lossfunc_param`.
    # - `show_progress`: toggles the progress bar.
    (; optimizer_type, optimizer_lrdecay, optimizer_lrinit) = optimizer_params
    optimizer = get_optimizer(optimizer_type, optimizer_lrdecay, optimizer_lrinit, nepochs)
    # `f` and `g` must live on the same device (both GPU or both CPU).
    @assert !xor(isgpu(f), isgpu(g))

    params = Flux.params(trainable_params(g))
    # Buffers reused across epochs: sample points (one per column) and the values of `f` there.
    X = Array{T}(undef, dim(g), batchsize)
    yf = Vector{T}(undef, batchsize)
    if isgpu(f)
        X = CuArray(X)
        yf = CuArray(yf)
    end
    # Not sure if it is advised to have a conditional loss function?
    lossfunc = lossfunc_symb == :mse ? (
        (ŷ, y) -> Flux.mse(ŷ, y)
    ) : (
        (ŷ, y) -> 1 / lossfunc_param * Flux.logsumexp(lossfunc_param * abs.(ŷ - y)) - 1 / lossfunc_param * log(batchsize)  # TODO: good idea? 
    )

    reg = init(reg_type, g, reg_params)

    pbar = Progress(nepochs; enabled=show_progress)
    # r = zero(T)
    for _ ∈ 1:nepochs
        # Fresh batch: overwrite `X` with new samples, `yf` with `f` evaluated at them,
        # and let the regularizer refresh its internal state.
        _samples!(X, g)
        _evaluate!(yf, f, X)
        update!(reg, g)

        # r += lossfunc(_evaluate(g, X), yf)
        # Loss value and gradients with respect to `params`.
        loss_epoch, grads_epoch = Flux.withgradient(params) do
            (
                lossfunc(_evaluate(g, X), yf) + loss(reg, g)
                # Previous setup: 
                # regfunc_ps * √HSnorm2_upper(g)
            )
        end
        # Training stops silently (without applying the update) as soon as a NaN gradient appears.
        hasnan(grads_epoch) && break
        Flux.Optimise.update!(optimizer, params, grads_epoch)
        next!(pbar; showvalues=[(:loss, loss_epoch)])
        # next!(pbar; showvalues=[(:r, r)])
    end
end

# Largest absolute difference between `f` and `g` over `nsamples` random points drawn with
# `_samples!` (a sampled estimate of the sup-norm of `f - g`).
function l∞norm_samples(f::AbstractPoly{T}, g::AbstractBlockPSDModel{T}; nsamples=4096) where {T}
    X = Array{T}(undef, dim(g), nsamples)
    if isgpu(f)
        X = CuArray(X)
    end
    _samples!(X, g)

    fvals = _evaluate(f, X)
    gvals = _evaluate(g, X)
    return maximum(abs.(fvals - gvals))
end

include("experimental.jl")

end
