
# Unicode alias for `sum`. Declared `const` so that calls through the alias are
# resolved against a fixed binding instead of an untyped, reassignable global.
const ∑ = sum
# implementation of the 2-Wasserstein distance
# Build the empirical quantile (inverse CDF) function of `samples`.
#
# Returns a closure `p -> x` mapping a probability level `p` to the order statistic at
# 1-based position `floor(p * n) + 1` of the sorted samples, where `n` is the number of
# samples. The position is capped at `n`, so `p = 1` yields the largest sample instead
# of indexing one past the end of the array. Negative levels (and an empty `samples`)
# still raise a `BoundsError` when the closure is called.
function quantile(samples)
    samples_sorted = sort(samples)
    n = length(samples_sorted)
    return p -> samples_sorted[min(floor(Int, p * n) + 1, n)]
end
# Squared 2-Wasserstein distance between the empirical distributions of two 1-D samples
# (adapted from https://github.com/nklb/wasserstein-distance).
#
# Both empirical quantile functions are piecewise constant, jumping at multiples of
# 1/length(u_samples) and 1/length(v_samples) respectively. On the merged grid of those
# jump points the integral of (F_u⁻¹(p) - F_v⁻¹(p))² over [0, 1] is a finite sum of
# (difference of order statistics)² × (cell width).
# The two sample sets may have different sizes; both must be non-empty.
function W₂²(u_samples, v_samples)
    # sort once and index the sorted samples directly
    u_sorted = sort(u_samples)
    v_sorted = sort(v_samples)
    nu = length(u_sorted)
    nv = length(v_sorted)
    u_breaks = [i / nu for i in 0:nu]
    v_breaks = [i / nv for i in 0:nv]
    grids = sort(unique([u_breaks; v_breaks]))
    # Look the order statistics up at the midpoint of each grid cell rather than at its
    # left edge: at an edge p = i/n, round-off can make p * n fall just below i
    # (e.g. (1/49) * 49 == 0.9999999999999999), which would select the previous sample.
    mids = (grids[1:end-1] .+ grids[2:end]) ./ 2
    u_idx = [floor(Int, m * nu) + 1 for m in mids]
    v_idx = [floor(Int, m * nv) + 1 for m in mids]
    return sum((u_sorted[u_idx] .- v_sorted[v_idx]) .^ 2 .* diff(grids))
end


# Moment-matching loss between the reference data `U` and simulations run with `p′`.
#
# `sde_construct_solve(u0, p′)` is called once for every entry of the global `U0s`; the
# j-th entries of all returned solutions are concatenated into one vector per index j.
# The loss adds the squared differences of the per-index means to `var_weight` times the
# absolute differences of the per-index variances, and divides by `length(U)`.
# NOTE(review): `U`, `U0s` and `sde_construct_solve` are defined outside this block;
# presumably `U` holds one vector of observations per saved time point — confirm.
function loss_avg_std(p′; var_weight = 1.0)
    solutions = map(u0 -> sde_construct_solve(u0, p′), U0s)
    n_entries = length(solutions[1])
    # pool the j-th entry of every solution into a single vector
    U′ = [vcat([solutions[i][j] for i in eachindex(solutions)]...) for j in 1:n_entries]
    mean_gap = mean.(U) - mean.(U′)
    var_gap = var.(U) - var.(U′)
    return (∑(abs2, mean_gap) + ∑(abs, var_gap) * var_weight) / length(U)
end

# Distribution-matching loss: solve the SDE defined by the globals (f!, σ!) from `U0`
# with parameters `p′`, saving at `times`, then average the squared 2-Wasserstein
# distance between each saved state and the matching entry of the reference data `U`.
function loss_W₂²(p′)
    # deterministic alternative kept for reference:
    # prob = ODEProblem(f!, u0, tspan, p′)
    sde = SDEProblem(SDEFunction(f!, σ!), σ!, U0, tspan, p′)
    U′ = solve(sde, LambaEM(), saveat = times).u
    distances = [W₂²(uₜ′, uₜ) for (uₜ′, uₜ) in zip(U′, U)]
    return ∑(distances) / length(U)
end

# Mean-squared-error loss: solve the SDE defined by the globals (f!, σ!) with
# parameters `p′`, saving at `times`, then average over the saved states the sum of
# squared differences to the matching entry of the reference data `U`.
# NOTE(review): the initial condition here is the global `u0`, whereas the sibling
# losses start from `U0` — confirm this difference is intended.
function loss_mse(p′)
    # deterministic alternative kept for reference:
    # prob = ODEProblem(f!, u0, tspan, p′)
    sde = SDEProblem(SDEFunction(f!, σ!), σ!, u0, tspan, p′)
    U′ = solve(sde, LambaEM(), saveat = times).u
    squared_errors = [∑((uₜ′ - uₜ) .^ 2) for (uₜ′, uₜ) in zip(U′, U)]
    return ∑(squared_errors) / length(U)
end

# Mean-matching loss: solve the SDE defined by the globals (f!, σ!) from `U0` with
# parameters `p′`, saving at `times`, then compare the mean of every saved state with
# the mean of the matching entry of the reference data `U`.
# Returns the sum of squared mean differences divided by `length(U)`.
# Only the means enter the loss, so no spread statistics are computed here.
function loss_avg(p′)
    prob = SDEProblem(SDEFunction(f!, σ!), σ!, U0, tspan, p′)
    # prob = ODEProblem(f!, u0, tspan, p′)
    sol = solve(prob, LambaEM(), saveat = times)
    U′ = sol.u
    𝔼u = mean.(U)
    𝔼u′ = mean.(U′)
    return ∑(abs2, 𝔼u - 𝔼u′) / length(U)
end

# Fit a 2-element parameter vector by minimising `loss_func`.
#
# The RNG is seeded with a fixed value so the random starting point `rand(2)` is
# reproducible. 1000 ADAMW(0.01) steps are taken; after each step the loss is
# re-evaluated and shown, together with both parameters, in the progress bar.
# Returns the parameter vector after the last step.
# (Tracking of the best parameters seen so far is currently not implemented.)
function train!(loss_func)
    Random.seed!(56789)
    p′ = rand(2)
    @info "Using $(loss_func) as loss function"
    @show ∑(loss_func(p′) for i in 1:1)

    optimiser = ADAMW(0.01)
    bar = ProgressBar(1:1000)
    @show p′
    for _ in bar
        # gradient of the loss with respect to the implicit parameters wrapping p′
        grads = gradient(() -> ∑(loss_func(p′) for i in 1:1), Flux.params(p′))
        Flux.update!(optimiser, Flux.params(p′), grads)
        loss_now = ∑(loss_func(p′) for i in 1:1)
        set_multiline_postfix(bar, "loss=$(loss_now)\np₁=$(p′[1])\np₂=$(p′[2])")
    end
    return p′
end
# Minimise `loss_func` starting from the caller-supplied parameter vector `p′`.
#
# 1000 ADAM(0.01) steps are taken. `Flux.update!` is applied to `Flux.params(p′)`, so
# `p′` itself is expected to be modified in place; a copy of it is recorded after every
# step. The progress bar shows the re-evaluated loss and the first two parameters.
# A garbage collection is forced at the end of every iteration.
# Returns the vector of recorded parameter snapshots (one per step).
# (Tracking of the best parameters seen so far is currently not implemented.)
function train!(loss_func, p′)
    @info "Using $(loss_func) as loss function"
    @show ∑(loss_func(p′) for i in 1:1)

    optimiser = ADAM(0.01)
    history = []
    bar = ProgressBar(1:1000)
    @show p′
    for _ in bar
        # gradient of the loss with respect to the implicit parameters wrapping p′
        grads = gradient(() -> ∑(loss_func(p′) for i in 1:1), Flux.params(p′))
        Flux.update!(optimiser, Flux.params(p′), grads)
        loss_now = ∑(loss_func(p′) for i in 1:1)
        set_multiline_postfix(bar, "loss=$(loss_now)\np₁=$(p′[1])\np₂=$(p′[2])")
        push!(history, copy(p′))
        GC.gc()
    end
    return history
end


# Define an approximate log-likelihood function given a sequence of observations U at
# the (e.g. regularly spaced) time points `times` and the functions (f, σ) defining the
# dynamics of the SDE
#     du = f(u, p, t) dt + σ(u, p, t) dW
# where p is a vector of parameters and W is a Wiener process; f and σ are called in
# their out-of-place form.
# The basic idea of the approximation is to assume that, for sufficiently small dt,
# the transition probability density function is approximately Gaussian with mean
# u + f(u, p, t) dt and (diagonal) covariance σ(u, p, t)² dt.
# The log-likelihood is then approximated by the sum of the logs of these Gaussian
# densities evaluated at the observed values in U.
# The function returns the approximate log-likelihood.
# U = [U₁, U₂, U₃, ...], each Uᵢ is a vector of iid observations at time times[i]
# times = 0.0:0.1:20.0
using LinearAlgebra
# let uₜ represent a trajectory of the SDE
# Approximate log-likelihood of the observations `U` under the SDE with drift `f` and
# diagonal diffusion `σ`, using a Gaussian one-step transition density.
#
# Arguments
#   U     : U[t][i] is the state of trajectory i at times[t]
#   times : observation times; the step dt is computed per pair of neighbours
#   f, σ  : out-of-place functions called as f(u, p, t) and σ(u, p, t)
#   p     : parameters passed through to f and σ
# Returns the sum over all trajectories and all consecutive time pairs of the log
# density of N(u + f dt, diag(σ²) dt) evaluated at the next observed state.
#
# States may be scalars or vectors. Whether the vector formula is used is decided by
# the type of the state, not by its length, so 1-element vector states are handled by
# the vector branch (the scalar formula would subtract a number from a vector).
function ℓ̂(U, times, f, σ, p)
    # must use out-of-place version of f and σ
    # regroup the observations by trajectory: Uₜ[i][t] == U[t][i]
    Uₜ = [[U[t][i] for t in eachindex(times)] for i in 1:length(U[1])]
    ℓ = 0.0
    for uₜ ∈ Uₜ
        for k ∈ 1:length(times)-1
            dt = times[k+1] - times[k]
            state = uₜ[k]
            drift = f(state, p, times[k])
            diffusion = σ(state, p, times[k])
            if state isa AbstractArray
                n = length(state)
                x = uₜ[k+1] - state - drift * dt
                Σ = diffusion .^ 2 .* I(n) * dt
            else
                # scalar state: f and σ may return a number or a 1-element container
                n = 1
                x = uₜ[k+1] - state - drift[1] * dt
                Σ = diffusion[1]^2 * dt * I(n)
            end
            # remark: building the covariance Σ from σ assumes diagonal noise
            ℓ += -0.5 * n * log(2π) - 0.5 * logdet(Σ) - 0.5 * (x' * inv(Σ) * x)[]
        end
    end
    return ℓ
end

# Negative approximate log-likelihood of the global observations `U` (at the global
# `times`, with the global out-of-place dynamics `f` and `σ`) for parameters `p′`,
# so that minimising this loss maximises the likelihood.
function loss_likelihood(p′)
    log_likelihood = ℓ̂(U, times, f, σ, p′)
    return -log_likelihood
end

# loss_likelihood([0.5,1.0])



# # alternative implementation by vectorizing the computation
# function logLik(U, times, f, σ, p)
#     ℓ = 0.0
#     for i ∈ 1:length(times) - 1
#         dt = times[i+1] - times[i]
#         X = U[i+1] - U[i] - f(U[i], p, times[i]) * dt
#         Σ = σ(U[i], p, times[i]) * dt * I(length(X))
#         n = typeof(X) <: AbstractArray ? length(X) : 1
#         ℓ += -0.5 * n * log(2π) - 0.5 * logdet(Σ) - 0.5 * sum(X .* (Σ \ X))
#     end
#     return ℓ
# end

# using BenchmarkTools
# @benchmark ℓ̂(U, times, f, σ, p)
# BenchmarkTools.Trial: 10000 samples with 1 evaluation.
#  Range (min … max):  84.383 μs …  1.648 ms  ┊ GC (min … max): 0.00% … 93.95%
#  Time  (median):     85.029 μs              ┊ GC (median):    0.00%
#  Time  (mean ± σ):   87.885 μs ± 46.880 μs  ┊ GC (mean ± σ):  2.01% ±  3.56%

#   ▅██▆▃        ▁▁                             ▁▁              ▁
#   █████▇███▇▇▆▇██▇▅▄▅▄▄▃▃▃▅▄▃▂▂▂▃▄▂▄▃▂▂▄██▇█████▇▇▅▅▆▅▆▆▆▆▆▆▇ █
#   84.4 μs      Histogram: log(frequency) by time       102 μs <

#  Memory estimate: 71.00 KiB, allocs estimate: 42.

# @benchmark logLik(U, times, f, σ, p)
# BenchmarkTools.Trial: 10000 samples with 1 evaluation.
#  Range (min … max):  120.298 μs …   1.885 ms  ┊ GC (min … max):  0.00% … 91.86%
#  Time  (median):     128.209 μs               ┊ GC (median):     0.00%
#  Time  (mean ± σ):   149.556 μs ± 157.140 μs  ┊ GC (mean ± σ):  13.13% ± 11.21%

#   █▁                                                            ▁
#   ███▃▄▁▄█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆█ █
#   120 μs        Histogram: log(frequency) by time       1.32 ms <

#  Memory estimate: 721.89 KiB, allocs estimate: 2001.
